<a href="https://colab.research.google.com/github/ko-bbie/bert_livedoor_classifier/blob/master/notebook/livedoor_albert_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers==2.11

Collecting transformers==2.11
[?25l  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
[K     |▌                               | 10kB 31.0MB/s eta 0:00:01[K     |█                               | 20kB 6.2MB/s eta 0:00:01[K     |█▌                              | 30kB 7.1MB/s eta 0:00:01[K     |██                              | 40kB 7.7MB/s eta 0:00:01[K     |██▍                             | 51kB 7.3MB/s eta 0:00:01[K     |███                             | 61kB 8.1MB/s eta 0:00:01[K     |███▍                            | 71kB 8.0MB/s eta 0:00:01[K     |███▉                            | 81kB 9.0MB/s eta 0:00:01[K     |████▍                           | 92kB 8.5MB/s eta 0:00:01[K     |████▉                           | 102kB 8.5MB/s eta 0:00:01[K     |█████▍                          | 112kB 8.5MB/s eta 0:00:01[K     |█████▉                          | 122

In [None]:
from google.colab import drive

# Directory on the Colab virtual machine where Google Drive is mounted.
DIR_DRIVE = './gdrive/'

# Paths of the notebooks and project files inside the mounted Drive.
DIR_COLAB = f'{DIR_DRIVE}My Drive/Colab Notebooks/'
# NOTE(review): "PROJCET" looks like a misspelling of "PROJECT"; the name is
# kept unchanged because other cells or notebooks may reference it.
DIR_PROJCET = f'{DIR_COLAB}livedoor_classification/'

# Mount Google Drive (asks for an authorization code interactively).
drive.mount(DIR_DRIVE)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at ./gdrive/


In [None]:
# Drive folder that holds the training CSV and receives the saved model weights.
DATA_DIR = f'{DIR_DRIVE}My Drive/nlp/data/'

## Imports

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import transformers
import re
import unicodedata
from tqdm import tqdm
from sklearn import model_selection
from sklearn import metrics
from transformers import AdamW, get_linear_schedule_with_warmup

## Config

In [None]:
# --- Tokenisation and batching ---
MAX_LEN = 256  # every example is padded or cut to this many tokens
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 16
EPOCHS = 10

# --- Pretrained checkpoint and file locations ---
ALBERT_PATH = "ALINEAR/albert-japanese-v2"  # identifier handed to from_pretrained
MODEL_PATH = f"{DATA_DIR}albert-model.bin"  # destination of the best state_dict
TRAINING_FILE = f"{DATA_DIR}train.csv"

# Tokenizer for the checkpoint above; created once and shared by all datasets.
TOKENIZER = transformers.AutoTokenizer.from_pretrained(ALBERT_PATH)


## Utils

In [None]:
def category_dict():
    """Build the category-name -> one-hot label lookup.

    Returns:
        dict: maps each of the nine category names to a list of nine ints
        containing a single 1 at the category's position (positions follow
        the order of the tuple below). A fresh dict is built on every call.
    """
    categories = (
        "sports-watch",
        "smax",
        "dokujo-tsushin",
        "movie-enter",
        "it-life-hack",
        "kaden-channel",
        "peachy",
        "topic-news",
        "livedoor-homme",
    )
    width = len(categories)
    return {
        name: [1 if col == row else 0 for col in range(width)]
        for row, name in enumerate(categories)
    }


def normalize(text):
    """Clean raw article text before it is tokenised.

    Applies Unicode normalisation first, then converts the ASCII punctuation
    handled by ``capitalize_symbols`` to its full-width form.

    Args:
        text: the raw string.

    Returns:
        The normalised string.
    """
    return capitalize_symbols(normalize_unicode(text))


def normalize_unicode(text, form="NFKC"):
    """Unicode-normalise ``text`` and turn each whitespace character into a space.

    Args:
        text: string to normalise.
        form: normalisation form passed to ``unicodedata.normalize``.

    Returns:
        The normalised string; every character matched by ``\\s`` (tabs,
        newlines, ...) is replaced one-for-one by ``" "`` (runs are not
        collapsed).
    """
    whitespace = re.compile(r"\s")
    return whitespace.sub(" ", unicodedata.normalize(form, text))


def capitalize_symbols(text):
    """Drop middle dots and convert selected ASCII punctuation to full-width.

    Note from the original author: this kind of conversion is often used ahead
    of MeCab morphological analysis.

    Args:
        text: input string.

    Returns:
        ``text`` with every "・" removed and each of ``!?(),.:;/@%&[]``
        replaced by its full-width counterpart ("," -> "、", "." -> "。").
    """
    half_width = "!?(),.:;/@%&[]"
    full_width = "！？（）、。：；／＠％＆［］"
    # One table does both jobs: map half- to full-width and delete "・".
    table = str.maketrans(half_width, full_width, "・")
    return text.translate(table)

## Dataset

In [None]:
class LivedoorDataset:
    """Map-style dataset that turns livedoor articles into fixed-length model inputs.

    Args:
        article: indexable sequence of raw article strings.
        targets: indexable sequence of category names aligned with ``article``.

    The tokenizer, the target length and the category lookup are taken from
    the module-level ``TOKENIZER``, ``MAX_LEN`` and ``category_dict()``.
    """

    def __init__(self, article, targets):
        self.article = article
        self.targets = targets
        self.tokenizer = TOKENIZER
        self.max_len = MAX_LEN
        self.category_dic = category_dict()

    def __len__(self):
        """Number of articles in the dataset."""
        return len(self.article)

    def _fit_to_max_len(self, values):
        """Return ``values`` zero-padded or cut to exactly ``self.max_len`` items.

        Over-long sequences keep their head and their tail and drop the middle.
        For an odd ``max_len`` the tail gets the extra slot, so the result is
        always ``max_len`` long (taking ``max_len / 2`` from both ends would
        produce ``max_len - 1`` items and break batch collation).
        """
        values = list(values)
        missing = self.max_len - len(values)
        if missing >= 0:
            return values + [0] * missing
        head_len = self.max_len // 2
        tail_len = self.max_len - head_len
        # Explicit start index so that tail_len == 0 yields an empty tail.
        return values[:head_len] + values[len(values) - tail_len:]

    def __getitem__(self, item):
        """Tokenise one article and return its tensors.

        Returns:
            dict with ``ids``, ``mask``, ``token_type_ids`` (long tensors of
            length ``max_len``), ``targets`` (one-hot long tensor) and
            ``orig_targets`` (the category name).

        Raises:
            KeyError: if the category name is not in the category lookup.
        """
        article = normalize(self.article[item])
        inputs = self.tokenizer.encode_plus(article, add_special_tokens=True)

        ids = self._fit_to_max_len(inputs["input_ids"])
        mask = self._fit_to_max_len(inputs["attention_mask"])
        token_type_ids = self._fit_to_max_len(inputs["token_type_ids"])

        category = self.targets[item]
        if category not in self.category_dic:
            # Fail with a readable message instead of torch.tensor(None).
            raise KeyError(f"unknown category {category!r} at index {item}")

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "targets": torch.tensor(self.category_dic[category], dtype=torch.long),
            "orig_targets": category,
        }

## Model

In [None]:
class AlbertBaseJapanese(nn.Module):
    """ALBERT encoder followed by dropout and a linear head with 9 outputs.

    ``forward`` returns raw (un-normalised) scores of shape
    ``(batch, 9)``, one column per category.
    """

    def __init__(self):
        super(AlbertBaseJapanese, self).__init__()
        self.albert = transformers.AlbertModel.from_pretrained(ALBERT_PATH)
        self.drop = nn.Dropout(0.3)
        # Take the head's input width from the loaded encoder instead of
        # hard-coding 768, so the head always matches the checkpoint.
        self.out = nn.Linear(self.albert.config.hidden_size, 9)

    def forward(self, ids, mask, token_type_ids):
        """Encode a batch and return the category logits.

        Args:
            ids: token id tensor.
            mask: attention mask tensor.
            token_type_ids: segment id tensor.

        Returns:
            Tensor of logits produced by the linear head.
        """
        encoder_outputs = self.albert(
            ids, attention_mask=mask, token_type_ids=token_type_ids
        )
        # The pooled representation is the second item. Indexing (rather than
        # unpacking exactly two values) keeps working when the encoder also
        # returns hidden states or attentions.
        pooled_output = encoder_outputs[1]
        return self.out(self.drop(pooled_output))

## Engine

In [None]:
def loss_fn(outputs, targets):
    """Mean binary cross-entropy between ``outputs`` and ``targets``.

    Uses ``nn.BCEWithLogitsLoss`` with its default settings, which applies a
    sigmoid to ``outputs`` internally.

    Args:
        outputs: float tensor of scores.
        targets: float tensor of the same shape with the 0/1 labels.

    Returns:
        Scalar loss tensor.
    """
    criterion = nn.BCEWithLogitsLoss()
    return criterion(outputs, targets)


def train_fn(data_loader, model, optimizer, device, scheduler):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: iterable of batches; each batch is a dict with the keys
            ``ids``, ``mask``, ``token_type_ids`` and ``targets``.
        model: module called as ``model(ids=..., mask=..., token_type_ids=...)``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch.
    """
    model.train()
    progress = tqdm(data_loader, total=len(data_loader))
    for batch in progress:
        ids = batch["ids"].to(device, dtype=torch.long)
        mask = batch["mask"].to(device, dtype=torch.long)
        token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
        # The loss works on floats, so the one-hot labels are cast here.
        targets = batch["targets"].to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)

        # loss_fn is BCEWithLogitsLoss, which applies the sigmoid itself, so
        # it must receive the raw logits. Feeding it softmax probabilities
        # squashes the scores twice and distorts the loss and its gradients.
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()


def eval_fn(data_loader, model, device):
    """Predict a class for every example in ``data_loader``.

    Args:
        data_loader: iterable of batches; each batch is a dict with the keys
            ``ids``, ``mask``, ``token_type_ids`` and ``targets`` (one-hot).
        model: module called as ``model(ids=..., mask=..., token_type_ids=...)``.
        device: device the batch tensors are moved to.

    Returns:
        tuple ``(fin_outputs, fin_targets)``: two lists of ints holding the
        predicted and the true class index (argmax over dimension 1) of every
        example, in loader order.
    """
    model.eval()
    fin_targets = []
    fin_outputs = []

    # Inference only: disabling autograd avoids building a graph for every
    # validation batch, which saves memory and time.
    with torch.no_grad():
        for batch in data_loader:
            ids = batch["ids"].to(device, dtype=torch.long)
            mask = batch["mask"].to(device, dtype=torch.long)
            token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
            targets = batch["targets"].to(device, dtype=torch.long)

            outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
            fin_targets.extend(targets.argmax(1).cpu().tolist())
            fin_outputs.extend(outputs.argmax(1).cpu().tolist())
    return fin_outputs, fin_targets

## Train

In [None]:
def run():
    """Fine-tune the classifier on ``TRAINING_FILE`` and save the best weights.

    Steps:
        1. Read the CSV (missing values become the string "none") and make a
           stratified 90/10 train/validation split on the ``category`` column.
        2. Build datasets and loaders from the ``article`` / ``category`` columns.
        3. Train for ``EPOCHS`` epochs with AdamW and a linear schedule without
           warm-up, evaluating after every epoch.
        4. Whenever validation accuracy improves, write the model's
           ``state_dict`` to ``MODEL_PATH``.

    Prints the accuracy and the best accuracy so far after each epoch.
    """
    df = pd.read_csv(TRAINING_FILE).fillna("none")

    df_train, df_valid = model_selection.train_test_split(
        df, test_size=0.1, random_state=42, stratify=df.category.values
    )

    df_train = df_train.reset_index(drop=True)
    df_valid = df_valid.reset_index(drop=True)

    train_dataset = LivedoorDataset(
        article=df_train.article.values, targets=df_train.category.values
    )

    # Reshuffle every epoch so the batches differ between epochs.
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=4
    )

    valid_dataset = LivedoorDataset(
        article=df_valid.article.values, targets=df_valid.category.values
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=VALID_BATCH_SIZE, num_workers=1
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AlbertBaseJapanese()
    model.to(device)

    # Biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.001,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    # The scheduler is stepped once per batch, so the total is the real number
    # of batches per epoch (the last, partial batch included) times EPOCHS.
    num_train_steps = len(train_data_loader) * EPOCHS
    optimizer = AdamW(optimizer_parameters, lr=3e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
    )

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model.to(device)

    best_accuracy = 0
    for epoch in range(EPOCHS):
        train_fn(train_data_loader, model, optimizer, device, scheduler)
        outputs, targets = eval_fn(valid_data_loader, model, device)
        accuracy = metrics.accuracy_score(targets, outputs)
        print(f"Accuracy Score = {accuracy}")
        if accuracy > best_accuracy:
            # Save the underlying module so the checkpoint keys carry no
            # "module." prefix and load into a plain AlbertBaseJapanese.
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(model_to_save.state_dict(), MODEL_PATH)
            best_accuracy = accuracy
        print(f"epoch = {epoch}, best_accuracy = {best_accuracy}")

In [None]:
# Start training; progress bars and the per-epoch accuracy are printed below.
run()


  0%|          | 0/166 [00:00<?, ?it/s][A
  1%|          | 1/166 [00:01<05:01,  1.83s/it][A
  1%|          | 2/166 [00:03<04:36,  1.68s/it][A
  2%|▏         | 3/166 [00:04<04:17,  1.58s/it][A
  2%|▏         | 4/166 [00:05<04:04,  1.51s/it][A
  3%|▎         | 5/166 [00:07<03:54,  1.46s/it][A
  4%|▎         | 6/166 [00:08<03:47,  1.42s/it][A
  4%|▍         | 7/166 [00:09<03:43,  1.40s/it][A
  5%|▍         | 8/166 [00:11<03:40,  1.39s/it][A
  5%|▌         | 9/166 [00:12<03:37,  1.38s/it][A
  6%|▌         | 10/166 [00:13<03:34,  1.38s/it][A
  7%|▋         | 11/166 [00:15<03:32,  1.37s/it][A
  7%|▋         | 12/166 [00:16<03:30,  1.37s/it][A
  8%|▊         | 13/166 [00:18<03:29,  1.37s/it][A
  8%|▊         | 14/166 [00:19<03:28,  1.37s/it][A
  9%|▉         | 15/166 [00:20<03:27,  1.37s/it][A
 10%|▉         | 16/166 [00:22<03:26,  1.38s/it][A
 10%|█         | 17/166 [00:23<03:25,  1.38s/it][A
 11%|█         | 18/166 [00:24<03:23,  1.38s/it][A
 11%|█▏        | 19/166 [00:2

Accuracy Score = 0.8728813559322034
epoch = 0, best_accuracy = 0.8728813559322034



  1%|          | 1/166 [00:01<05:08,  1.87s/it][A
  1%|          | 2/166 [00:03<04:53,  1.79s/it][A
  2%|▏         | 3/166 [00:05<04:40,  1.72s/it][A
  2%|▏         | 4/166 [00:06<04:32,  1.68s/it][A
  3%|▎         | 5/166 [00:08<04:26,  1.65s/it][A
  4%|▎         | 6/166 [00:09<04:21,  1.63s/it][A
  4%|▍         | 7/166 [00:11<04:17,  1.62s/it][A
  5%|▍         | 8/166 [00:12<04:15,  1.61s/it][A
  5%|▌         | 9/166 [00:14<04:12,  1.61s/it][A
  6%|▌         | 10/166 [00:16<04:10,  1.60s/it][A
  7%|▋         | 11/166 [00:17<04:08,  1.60s/it][A
  7%|▋         | 12/166 [00:19<04:06,  1.60s/it][A
  8%|▊         | 13/166 [00:20<04:05,  1.60s/it][A
  8%|▊         | 14/166 [00:22<04:03,  1.60s/it][A
  9%|▉         | 15/166 [00:24<04:01,  1.60s/it][A
 10%|▉         | 16/166 [00:25<04:00,  1.60s/it][A
 10%|█         | 17/166 [00:27<03:58,  1.60s/it][A
 11%|█         | 18/166 [00:28<03:57,  1.60s/it][A
 11%|█▏        | 19/166 [00:30<03:55,  1.60s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9288135593220339
epoch = 1, best_accuracy = 0.9288135593220339



  1%|          | 1/166 [00:01<05:14,  1.91s/it][A
  1%|          | 2/166 [00:03<04:58,  1.82s/it][A
  2%|▏         | 3/166 [00:05<04:45,  1.75s/it][A
  2%|▏         | 4/166 [00:06<04:36,  1.71s/it][A
  3%|▎         | 5/166 [00:08<04:29,  1.68s/it][A
  4%|▎         | 6/166 [00:09<04:24,  1.65s/it][A
  4%|▍         | 7/166 [00:11<04:21,  1.64s/it][A
  5%|▍         | 8/166 [00:13<04:17,  1.63s/it][A
  5%|▌         | 9/166 [00:14<04:14,  1.62s/it][A
  6%|▌         | 10/166 [00:16<04:12,  1.62s/it][A
  7%|▋         | 11/166 [00:17<04:10,  1.62s/it][A
  7%|▋         | 12/166 [00:19<04:08,  1.61s/it][A
  8%|▊         | 13/166 [00:21<04:06,  1.61s/it][A
  8%|▊         | 14/166 [00:22<04:04,  1.61s/it][A
  9%|▉         | 15/166 [00:24<04:02,  1.61s/it][A
 10%|▉         | 16/166 [00:26<04:01,  1.61s/it][A
 10%|█         | 17/166 [00:27<03:59,  1.61s/it][A
 11%|█         | 18/166 [00:29<03:57,  1.60s/it][A
 11%|█▏        | 19/166 [00:30<03:55,  1.60s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9474576271186441
epoch = 2, best_accuracy = 0.9474576271186441



  1%|          | 1/166 [00:01<05:00,  1.82s/it][A
  1%|          | 2/166 [00:03<04:47,  1.75s/it][A
  2%|▏         | 3/166 [00:04<04:37,  1.70s/it][A
  2%|▏         | 4/166 [00:06<04:30,  1.67s/it][A
  3%|▎         | 5/166 [00:08<04:24,  1.64s/it][A
  4%|▎         | 6/166 [00:09<04:19,  1.62s/it][A
  4%|▍         | 7/166 [00:11<04:16,  1.61s/it][A
  5%|▍         | 8/166 [00:12<04:13,  1.60s/it][A
  5%|▌         | 9/166 [00:14<04:10,  1.60s/it][A
  6%|▌         | 10/166 [00:16<04:08,  1.60s/it][A
  7%|▋         | 11/166 [00:17<04:06,  1.59s/it][A
  7%|▋         | 12/166 [00:19<04:04,  1.59s/it][A
  8%|▊         | 13/166 [00:20<04:03,  1.59s/it][A
  8%|▊         | 14/166 [00:22<04:01,  1.59s/it][A
  9%|▉         | 15/166 [00:24<03:59,  1.59s/it][A
 10%|▉         | 16/166 [00:25<03:58,  1.59s/it][A
 10%|█         | 17/166 [00:27<03:56,  1.59s/it][A
 11%|█         | 18/166 [00:28<03:55,  1.59s/it][A
 11%|█▏        | 19/166 [00:30<03:53,  1.59s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9559322033898305
epoch = 3, best_accuracy = 0.9559322033898305



  1%|          | 1/166 [00:01<05:13,  1.90s/it][A
  1%|          | 2/166 [00:03<04:57,  1.81s/it][A
  2%|▏         | 3/166 [00:05<04:44,  1.74s/it][A
  2%|▏         | 4/166 [00:06<04:35,  1.70s/it][A
  3%|▎         | 5/166 [00:08<04:28,  1.67s/it][A
  4%|▎         | 6/166 [00:09<04:23,  1.65s/it][A
  4%|▍         | 7/166 [00:11<04:20,  1.64s/it][A
  5%|▍         | 8/166 [00:13<04:17,  1.63s/it][A
  5%|▌         | 9/166 [00:14<04:14,  1.62s/it][A
  6%|▌         | 10/166 [00:16<04:12,  1.62s/it][A
  7%|▋         | 11/166 [00:17<04:10,  1.62s/it][A
  7%|▋         | 12/166 [00:19<04:08,  1.61s/it][A
  8%|▊         | 13/166 [00:21<04:06,  1.61s/it][A
  8%|▊         | 14/166 [00:22<04:04,  1.61s/it][A
  9%|▉         | 15/166 [00:24<04:02,  1.61s/it][A
 10%|▉         | 16/166 [00:25<04:00,  1.61s/it][A
 10%|█         | 17/166 [00:27<03:58,  1.60s/it][A
 11%|█         | 18/166 [00:29<03:57,  1.60s/it][A
 11%|█▏        | 19/166 [00:30<03:54,  1.60s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9525423728813559
epoch = 4, best_accuracy = 0.9559322033898305



  1%|          | 1/166 [00:01<05:06,  1.86s/it][A
  1%|          | 2/166 [00:03<04:52,  1.78s/it][A
  2%|▏         | 3/166 [00:05<04:40,  1.72s/it][A
  2%|▏         | 4/166 [00:06<04:32,  1.68s/it][A
  3%|▎         | 5/166 [00:08<04:26,  1.65s/it][A
  4%|▎         | 6/166 [00:09<04:21,  1.63s/it][A
  4%|▍         | 7/166 [00:11<04:17,  1.62s/it][A
  5%|▍         | 8/166 [00:12<04:15,  1.62s/it][A
  5%|▌         | 9/166 [00:14<04:12,  1.61s/it][A
  6%|▌         | 10/166 [00:16<04:10,  1.61s/it][A
  7%|▋         | 11/166 [00:17<04:08,  1.60s/it][A
  7%|▋         | 12/166 [00:19<04:06,  1.60s/it][A
  8%|▊         | 13/166 [00:21<04:05,  1.61s/it][A
  8%|▊         | 14/166 [00:22<04:04,  1.61s/it][A
  9%|▉         | 15/166 [00:24<04:02,  1.61s/it][A
 10%|▉         | 16/166 [00:25<04:00,  1.60s/it][A
 10%|█         | 17/166 [00:27<03:59,  1.60s/it][A
 11%|█         | 18/166 [00:29<03:57,  1.61s/it][A
 11%|█▏        | 19/166 [00:30<03:55,  1.61s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9677966101694915
epoch = 5, best_accuracy = 0.9677966101694915



  1%|          | 1/166 [00:01<05:12,  1.89s/it][A
  1%|          | 2/166 [00:03<04:56,  1.81s/it][A
  2%|▏         | 3/166 [00:05<04:44,  1.74s/it][A
  2%|▏         | 4/166 [00:06<04:35,  1.70s/it][A
  3%|▎         | 5/166 [00:08<04:28,  1.67s/it][A
  4%|▎         | 6/166 [00:09<04:23,  1.65s/it][A
  4%|▍         | 7/166 [00:11<04:19,  1.63s/it][A
  5%|▍         | 8/166 [00:13<04:16,  1.62s/it][A
  5%|▌         | 9/166 [00:14<04:13,  1.61s/it][A
  6%|▌         | 10/166 [00:16<04:10,  1.61s/it][A
  7%|▋         | 11/166 [00:17<04:08,  1.60s/it][A
  7%|▋         | 12/166 [00:19<04:06,  1.60s/it][A
  8%|▊         | 13/166 [00:21<04:04,  1.60s/it][A
  8%|▊         | 14/166 [00:22<04:02,  1.59s/it][A
  9%|▉         | 15/166 [00:24<04:00,  1.59s/it][A
 10%|▉         | 16/166 [00:25<03:58,  1.59s/it][A
 10%|█         | 17/166 [00:27<03:56,  1.59s/it][A
 11%|█         | 18/166 [00:28<03:54,  1.59s/it][A
 11%|█▏        | 19/166 [00:30<03:53,  1.59s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.964406779661017
epoch = 6, best_accuracy = 0.9677966101694915



  1%|          | 1/166 [00:01<05:03,  1.84s/it][A
  1%|          | 2/166 [00:03<04:49,  1.77s/it][A
  2%|▏         | 3/166 [00:05<04:38,  1.71s/it][A
  2%|▏         | 4/166 [00:06<04:31,  1.68s/it][A
  3%|▎         | 5/166 [00:08<04:24,  1.65s/it][A
  4%|▎         | 6/166 [00:09<04:20,  1.63s/it][A
  4%|▍         | 7/166 [00:11<04:17,  1.62s/it][A
  5%|▍         | 8/166 [00:12<04:14,  1.61s/it][A
  5%|▌         | 9/166 [00:14<04:11,  1.60s/it][A
  6%|▌         | 10/166 [00:16<04:09,  1.60s/it][A
  7%|▋         | 11/166 [00:17<04:07,  1.60s/it][A
  7%|▋         | 12/166 [00:19<04:05,  1.59s/it][A
  8%|▊         | 13/166 [00:20<04:03,  1.59s/it][A
  8%|▊         | 14/166 [00:22<04:02,  1.59s/it][A
  9%|▉         | 15/166 [00:24<04:00,  1.59s/it][A
 10%|▉         | 16/166 [00:25<03:59,  1.59s/it][A
 10%|█         | 17/166 [00:27<03:57,  1.59s/it][A
 11%|█         | 18/166 [00:28<03:55,  1.59s/it][A
 11%|█▏        | 19/166 [00:30<03:54,  1.60s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9694915254237289
epoch = 7, best_accuracy = 0.9694915254237289



  1%|          | 1/166 [00:01<05:25,  1.97s/it][A
  1%|          | 2/166 [00:03<05:04,  1.86s/it][A
  2%|▏         | 3/166 [00:05<04:49,  1.78s/it][A
  2%|▏         | 4/166 [00:06<04:39,  1.73s/it][A
  3%|▎         | 5/166 [00:08<04:30,  1.68s/it][A
  4%|▎         | 6/166 [00:09<04:24,  1.65s/it][A
  4%|▍         | 7/166 [00:11<04:19,  1.63s/it][A
  5%|▍         | 8/166 [00:13<04:16,  1.62s/it][A
  5%|▌         | 9/166 [00:14<04:13,  1.61s/it][A
  6%|▌         | 10/166 [00:16<04:10,  1.60s/it][A
  7%|▋         | 11/166 [00:17<04:07,  1.60s/it][A
  7%|▋         | 12/166 [00:19<04:05,  1.59s/it][A
  8%|▊         | 13/166 [00:21<04:03,  1.59s/it][A
  8%|▊         | 14/166 [00:22<04:01,  1.59s/it][A
  9%|▉         | 15/166 [00:24<03:59,  1.59s/it][A
 10%|▉         | 16/166 [00:25<03:58,  1.59s/it][A
 10%|█         | 17/166 [00:27<03:56,  1.59s/it][A
 11%|█         | 18/166 [00:28<03:55,  1.59s/it][A
 11%|█▏        | 19/166 [00:30<03:53,  1.59s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9728813559322034
epoch = 8, best_accuracy = 0.9728813559322034



  1%|          | 1/166 [00:01<05:02,  1.84s/it][A
  1%|          | 2/166 [00:03<04:49,  1.76s/it][A
  2%|▏         | 3/166 [00:05<04:38,  1.71s/it][A
  2%|▏         | 4/166 [00:06<04:31,  1.67s/it][A
  3%|▎         | 5/166 [00:08<04:24,  1.64s/it][A
  4%|▎         | 6/166 [00:09<04:20,  1.63s/it][A
  4%|▍         | 7/166 [00:11<04:16,  1.62s/it][A
  5%|▍         | 8/166 [00:12<04:14,  1.61s/it][A
  5%|▌         | 9/166 [00:14<04:11,  1.60s/it][A
  6%|▌         | 10/166 [00:16<04:09,  1.60s/it][A
  7%|▋         | 11/166 [00:17<04:07,  1.60s/it][A
  7%|▋         | 12/166 [00:19<04:05,  1.59s/it][A
  8%|▊         | 13/166 [00:20<04:03,  1.59s/it][A
  8%|▊         | 14/166 [00:22<04:02,  1.59s/it][A
  9%|▉         | 15/166 [00:24<04:00,  1.59s/it][A
 10%|▉         | 16/166 [00:25<03:58,  1.59s/it][A
 10%|█         | 17/166 [00:27<03:57,  1.59s/it][A
 11%|█         | 18/166 [00:28<03:55,  1.59s/it][A
 11%|█▏        | 19/166 [00:30<03:53,  1.59s/it][A
 12%|█▏        | 20/

Accuracy Score = 0.9694915254237289
epoch = 9, best_accuracy = 0.9728813559322034
