In [1]:
!pip install transformers



In [2]:
import torch
import pandas as pd
from torch import nn
# from transformers import AutoModel, ElectraConfig
# from transformers import ElectraTokenizer
from transformers import BertTokenizer, BertModel

# ELECTRA_TOKENIZER = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
BERT_TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased')

# ELECTRA model with custom output layer.
# class Electra(nn.Module):
#     def __init__(self, output_size, device='cpu'):
#         super().__init__()
#         self.device = device
#         config = ElectraConfig.from_pretrained('google/electra-small-discriminator')
#         self.model = AutoModel.from_config(config).to(device)
#         self.output = nn.Linear(self.model.config.hidden_size, output_size).to(device)

#     # What happens when passing input into the model.
#     def forward(self, sents, locs):
#         # sents = torch.tensor(sents).to(self.device)
#         # print(self.electra(sents))
#         sents = self.model(sents)[0]
#         abbs = torch.stack([sents[n, idx, :] for n, idx in enumerate(locs)])  # (B * M)
#         return self.output(abbs)

class Bert(nn.Module):
    def __init__(self, output_size, device='cpu'):
        super().__init__()
        self.device = device
        self.model = BertModel.from_pretrained('bert-base-uncased').to(device)
        self.output = nn.Linear(self.model.config.hidden_size, output_size).to(device)

    # What happens when passing input into the model.
    def forward(self, sents, locs):
        # sents = torch.tensor(sents).to(self.device)
        # print(self.electra(sents))
        sents = self.model(sents)[0]
        abbs = torch.stack([sents[n, idx, :] for n, idx in enumerate(locs)])  # (B * M)
        return self.output(abbs)

In [3]:
class MedalDatasetTokenizer(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, dictionary_file, max_length=256, device='cpu'):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device
        self.df = df
        label_df = pd.read_csv(dictionary_file, sep='\t', index_col = "EXPANSION")
        self.label_ser = label_df["LABEL"].squeeze()


    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idxs):
        
        # Code to remove entries that are larger than the max length size
        # batch_df = self.df.iloc[idxs]
        # # ic(batch_df['TEXT'].apply(lambda string: len(string.split())))
        # filter = batch_df['TEXT'].apply(lambda string: len(string.split()) < self.max_length).to_list()
        # # ic(idxs, filter)
        # idxs = list(compress(idxs, filter))


        batch_df = self.df.iloc[idxs]
        locs = batch_df['LOCATION'].values
        label_strings = batch_df['LABEL'].values
        labels = self.label_ser[label_strings].to_numpy()

        # ic(batch_df['TEXT'].tolist())
        # ic(type(batch_df['TEXT'].tolist()[0]))
        batch_encode = self.tokenizer(batch_df['TEXT'].tolist(), max_length=self.max_length, \
                    padding=True, truncation = True)
        
        # ic(batch_encode)
        # ic(type(batch_encode))

        tokenized = batch_encode['input_ids']
        # decoded = self.tokenizer.batch_decode(tokenized)
        # ic(decoded, len(decoded[0].split()))
        # ic(len(tokenized[0]), len(tokenized[1]), type(tokenized))
        return torch.tensor(tokenized).to(self.device), torch.tensor(locs).to(self.device), torch.tensor(labels).to(self.device)


In [4]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from datetime import datetime
import os
from time import time 

In [5]:
def train_loop(train_data, model, loss_fn, optimizer, train_loader, max = -1):
    
    size = len(train_data)

    # Switches model to training mode.
    model.train()
    
    # List of all values for the loss. Output at the end.
    loss_list = []
    # For computing accuracy
    correct = 0

    for batch, idx in enumerate(tqdm(train_loader)):
    # for batch, idx in enumerate(train_loader):
        # print(idx)
        X = train_data[idx][0]
        loc = train_data[idx][1]
        y = train_data[idx][2]

        # Compute prediction and loss
        pred = model(X, loc)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the loss and accuracy
        loss_value = loss.item()
        batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
        correct += batch_correct
        loss_list.append(loss_value)

        # Terminate early for testing purposes
        if max > 0:
            if batch >= max:
                print("\nMax iterations reached.")
                break

        # Minibatch loss
        if batch % 20 == 0 and batch != 0:
            print(f"\nBatch loss: {loss_value:>7f}")
    
    loss_list = np.array(loss_list)
    mean_loss = np.mean(loss_list)
    accuracy = correct/size
    print(f"Accuracy: {accuracy:>3f} | Average Loss: {mean_loss:>7f}\n")
    return mean_loss, accuracy

# Tests the model on the validation data
def valid_loop(valid_data, model, loss_fn, valid_loader, max = -1):

    # Switches model to evaluation mode
    model.eval()

    size = len(valid_data)
    loss_list = [] 
    correct = 0

    with torch.no_grad():
        for batch, idx in tqdm(enumerate(valid_loader)):
        # for batch, idx in enumerate(valid_loader):
            # idx = torch.tensor([id])
            # print(id, idx)
            X = valid_data[idx][0]
            loc = valid_data[idx][1]
            y = valid_data[idx][2]
            pred = model(X, loc)
            loss_list.append(loss_fn(pred, y).item())
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

            if max > 0:
                if batch > max:
                    break

    valid_loss = np.mean(np.array(loss_list))
    correct /= size
    print(f"Validation| \nAccuracy: {correct:>3f} | Average loss: {valid_loss:>8f} \n")
    return valid_loss, correct

In [6]:
num_abbr = "two_abbr"
folder = "drive/MyDrive/Bootcamp"

In [7]:
def save_model(model, save_dir):
    now = datetime.now()
    now_formatted = now.strftime("%d")+"_"+now.strftime("%H")+"_"+now.strftime("%M")
    torch.save(model, save_dir + f"{now_formatted}_two_abbr_Electra.pt")
    print("Model saved\n")

In [8]:
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu" 
device = torch.device(dev) 

# N_CPU_CORES = 2
# torch.set_num_threads(N_CPU_CORES)

# tokenizer = ELECTRA_TOKENIZER
tokenizer = BERT_TOKENIZER
max = -1

# Data
num_abbr = "two_abbr"
# folder = "datasets/medal"
train_df = pd.read_csv(f"{folder}/{num_abbr}/train.csv")
dictionary_file = f"{folder}/{num_abbr}/dict.txt"
output_size = 25 # Should be set to the size of the dictionary
train_data = MedalDatasetTokenizer(train_df, tokenizer, dictionary_file, device = device)

valid_df = pd.read_csv(f"{folder}/{num_abbr}/valid.csv")
valid_data = MedalDatasetTokenizer(valid_df, tokenizer, dictionary_file, device = device)

# Hyperparameters
learning_rate = 4e-5
batch_size = 16
epochs = 10

# Model
model = Bert(output_size, device)
#Electra(output_size=output_size, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
loss_fn = nn.CrossEntropyLoss()


# Train the model
for t in range(epochs):
    print(f"\nEpoch {t+1}\n-------------------------------")
    
    train_loader = DataLoader(
        range(len(train_data)), 
        shuffle=True, 
        batch_size=batch_size
    )

    valid_loader = DataLoader(
        range(len(valid_data)), 
        shuffle=True, 
        batch_size=batch_size
    )

    start = time()
    train_loss, train_accuracy = train_loop(train_data, model, loss_fn, optimizer, train_loader, max = max)
    end = time()
    print(f"Training time: {end-start}\n")

    valid_loss, valid_accuracy = valid_loop(valid_data, model, loss_fn, valid_loader, max = max)

    with open(f"{folder}/saves/loss.txt", "a") as file:
        file.writelines(f"\n{t+1},{train_loss},{train_accuracy},{valid_loss},{valid_accuracy}")

    save_model(model, f"{folder}/saves/{num_abbr}_epoch{t}")



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  0%|          | 0/271 [00:00<?, ?it/s]


Epoch 1
-------------------------------


  8%|▊         | 21/271 [00:19<03:58,  1.05it/s]


Batch loss: 3.310792


 15%|█▌        | 41/271 [00:39<03:36,  1.06it/s]


Batch loss: 2.934976


 23%|██▎       | 61/271 [00:57<03:13,  1.09it/s]


Batch loss: 2.109861


 30%|██▉       | 81/271 [01:15<02:53,  1.10it/s]


Batch loss: 1.824819


 37%|███▋      | 101/271 [01:34<02:39,  1.06it/s]


Batch loss: 1.351421


 45%|████▍     | 121/271 [01:52<02:23,  1.05it/s]


Batch loss: 1.101399


 52%|█████▏    | 141/271 [02:11<01:59,  1.09it/s]


Batch loss: 0.978357


 59%|█████▉    | 161/271 [02:29<01:40,  1.10it/s]


Batch loss: 0.948114


 67%|██████▋   | 181/271 [02:48<01:22,  1.09it/s]


Batch loss: 0.600164


 74%|███████▍  | 201/271 [03:06<01:04,  1.09it/s]


Batch loss: 0.526303


 82%|████████▏ | 221/271 [03:25<00:45,  1.09it/s]


Batch loss: 0.702705


 89%|████████▉ | 241/271 [03:43<00:27,  1.08it/s]


Batch loss: 0.552193


 96%|█████████▋| 261/271 [04:02<00:09,  1.07it/s]


Batch loss: 1.002547


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.583430 | Average Loss: 1.461768

Training time: 250.76734805107117



89it [00:37,  2.37it/s]


Validation| 
Accuracy: 0.797163 | Average loss: 0.645020 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 2
-------------------------------


  8%|▊         | 21/271 [00:19<03:53,  1.07it/s]


Batch loss: 0.624406


 15%|█▌        | 41/271 [00:38<03:31,  1.09it/s]


Batch loss: 0.222061


 23%|██▎       | 61/271 [00:56<03:11,  1.10it/s]


Batch loss: 0.076710


 30%|██▉       | 81/271 [01:15<02:53,  1.09it/s]


Batch loss: 0.431845


 37%|███▋      | 101/271 [01:33<02:38,  1.07it/s]


Batch loss: 0.131959


 45%|████▍     | 121/271 [01:52<02:20,  1.07it/s]


Batch loss: 0.367015


 52%|█████▏    | 141/271 [02:10<02:00,  1.08it/s]


Batch loss: 0.798499


 59%|█████▉    | 161/271 [02:29<01:42,  1.08it/s]


Batch loss: 0.296125


 67%|██████▋   | 181/271 [02:48<01:23,  1.08it/s]


Batch loss: 0.360743


 74%|███████▍  | 201/271 [03:06<01:04,  1.08it/s]


Batch loss: 0.175611


 82%|████████▏ | 221/271 [03:25<00:45,  1.09it/s]


Batch loss: 0.192083


 89%|████████▉ | 241/271 [03:43<00:27,  1.09it/s]


Batch loss: 0.134823


 96%|█████████▋| 261/271 [04:01<00:09,  1.08it/s]


Batch loss: 0.366375


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.866003 | Average Loss: 0.410680

Training time: 250.44839715957642



89it [00:37,  2.38it/s]


Validation| 
Accuracy: 0.843972 | Average loss: 0.453943 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 3
-------------------------------


  8%|▊         | 21/271 [00:19<03:55,  1.06it/s]


Batch loss: 0.366711


 15%|█▌        | 41/271 [00:38<03:34,  1.07it/s]


Batch loss: 0.290180


 23%|██▎       | 61/271 [00:57<03:13,  1.08it/s]


Batch loss: 0.378052


 30%|██▉       | 81/271 [01:15<02:56,  1.07it/s]


Batch loss: 0.057320


 37%|███▋      | 101/271 [01:34<02:36,  1.09it/s]


Batch loss: 0.255987


 45%|████▍     | 121/271 [01:52<02:20,  1.07it/s]


Batch loss: 0.101104


 52%|█████▏    | 141/271 [02:11<02:00,  1.07it/s]


Batch loss: 0.210586


 59%|█████▉    | 161/271 [02:29<01:43,  1.07it/s]


Batch loss: 0.516430


 67%|██████▋   | 181/271 [02:48<01:23,  1.07it/s]


Batch loss: 0.245214


 74%|███████▍  | 201/271 [03:06<01:04,  1.09it/s]


Batch loss: 0.453671


 82%|████████▏ | 221/271 [03:25<00:46,  1.07it/s]


Batch loss: 0.731229


 89%|████████▉ | 241/271 [03:43<00:27,  1.09it/s]


Batch loss: 0.246549


 96%|█████████▋| 261/271 [04:02<00:09,  1.07it/s]


Batch loss: 0.217380


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.923397 | Average Loss: 0.205631

Training time: 250.6691951751709



89it [00:37,  2.37it/s]


Validation| 
Accuracy: 0.873759 | Average loss: 0.366283 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 4
-------------------------------


  8%|▊         | 21/271 [00:19<03:52,  1.08it/s]


Batch loss: 0.233164


 15%|█▌        | 41/271 [00:38<03:35,  1.07it/s]


Batch loss: 0.207571


 23%|██▎       | 61/271 [00:57<03:13,  1.09it/s]


Batch loss: 0.017368


 30%|██▉       | 81/271 [01:15<02:54,  1.09it/s]


Batch loss: 0.012102


 37%|███▋      | 101/271 [01:33<02:37,  1.08it/s]


Batch loss: 0.072531


 45%|████▍     | 121/271 [01:52<02:18,  1.08it/s]


Batch loss: 0.051926


 52%|█████▏    | 141/271 [02:10<01:59,  1.09it/s]


Batch loss: 0.043944


 59%|█████▉    | 161/271 [02:29<01:40,  1.10it/s]


Batch loss: 0.202715


 67%|██████▋   | 181/271 [02:47<01:22,  1.08it/s]


Batch loss: 0.057594


 74%|███████▍  | 201/271 [03:06<01:05,  1.07it/s]


Batch loss: 0.237028


 82%|████████▏ | 221/271 [03:25<00:46,  1.07it/s]


Batch loss: 0.124443


 89%|████████▉ | 241/271 [03:43<00:27,  1.08it/s]


Batch loss: 0.031766


 96%|█████████▋| 261/271 [04:02<00:09,  1.08it/s]


Batch loss: 0.014196


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.952789 | Average Loss: 0.127414

Training time: 250.57128357887268



89it [00:37,  2.38it/s]


Validation| 
Accuracy: 0.875177 | Average loss: 0.399051 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 5
-------------------------------


  8%|▊         | 21/271 [00:20<03:56,  1.06it/s]


Batch loss: 0.070622


 15%|█▌        | 41/271 [00:38<03:34,  1.07it/s]


Batch loss: 0.090524


 23%|██▎       | 61/271 [00:57<03:13,  1.09it/s]


Batch loss: 0.071106


 30%|██▉       | 81/271 [01:15<02:54,  1.09it/s]


Batch loss: 0.056957


 37%|███▋      | 101/271 [01:33<02:38,  1.08it/s]


Batch loss: 0.040615


 45%|████▍     | 121/271 [01:52<02:18,  1.08it/s]


Batch loss: 0.008462


 52%|█████▏    | 141/271 [02:10<01:58,  1.10it/s]


Batch loss: 0.217199


 59%|█████▉    | 161/271 [02:29<01:42,  1.07it/s]


Batch loss: 0.023296


 67%|██████▋   | 181/271 [02:47<01:23,  1.07it/s]


Batch loss: 0.080455


 74%|███████▍  | 201/271 [03:06<01:05,  1.07it/s]


Batch loss: 0.248467


 82%|████████▏ | 221/271 [03:24<00:46,  1.08it/s]


Batch loss: 0.164121


 89%|████████▉ | 241/271 [03:43<00:27,  1.07it/s]


Batch loss: 0.117223


 96%|█████████▋| 261/271 [04:01<00:09,  1.09it/s]


Batch loss: 0.187166


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.960194 | Average Loss: 0.109548

Training time: 250.39299774169922



89it [00:37,  2.37it/s]


Validation| 
Accuracy: 0.883688 | Average loss: 0.398527 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 6
-------------------------------


  8%|▊         | 21/271 [00:19<03:55,  1.06it/s]


Batch loss: 0.069770


 15%|█▌        | 41/271 [00:38<03:31,  1.09it/s]


Batch loss: 0.017403


 23%|██▎       | 61/271 [00:57<03:16,  1.07it/s]


Batch loss: 0.019669


 30%|██▉       | 81/271 [01:15<02:53,  1.09it/s]


Batch loss: 0.050840


 37%|███▋      | 101/271 [01:34<02:37,  1.08it/s]


Batch loss: 0.020941


 45%|████▍     | 121/271 [01:52<02:18,  1.08it/s]


Batch loss: 0.010726


 52%|█████▏    | 141/271 [02:10<01:58,  1.09it/s]


Batch loss: 0.027503


 59%|█████▉    | 161/271 [02:29<01:41,  1.09it/s]


Batch loss: 0.099136


 67%|██████▋   | 181/271 [02:47<01:24,  1.07it/s]


Batch loss: 0.013957


 74%|███████▍  | 201/271 [03:06<01:04,  1.08it/s]


Batch loss: 0.037093


 82%|████████▏ | 221/271 [03:24<00:46,  1.08it/s]


Batch loss: 0.063869


 89%|████████▉ | 241/271 [03:43<00:27,  1.08it/s]


Batch loss: 0.368743


 96%|█████████▋| 261/271 [04:01<00:09,  1.08it/s]


Batch loss: 0.138911


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.979866 | Average Loss: 0.071694

Training time: 250.30532598495483



89it [00:37,  2.38it/s]


Validation| 
Accuracy: 0.880142 | Average loss: 0.443946 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 7
-------------------------------


  8%|▊         | 21/271 [00:19<03:53,  1.07it/s]


Batch loss: 0.153193


 15%|█▌        | 41/271 [00:38<03:32,  1.08it/s]


Batch loss: 0.021352


 23%|██▎       | 61/271 [00:56<03:16,  1.07it/s]


Batch loss: 0.008294


 30%|██▉       | 81/271 [01:15<02:54,  1.09it/s]


Batch loss: 0.030085


 37%|███▋      | 101/271 [01:33<02:36,  1.08it/s]


Batch loss: 0.044540


 45%|████▍     | 121/271 [01:52<02:19,  1.07it/s]


Batch loss: 0.005900


 52%|█████▏    | 141/271 [02:10<02:00,  1.08it/s]


Batch loss: 0.023631


 59%|█████▉    | 161/271 [02:29<01:41,  1.08it/s]


Batch loss: 0.005066


 67%|██████▋   | 181/271 [02:47<01:22,  1.09it/s]


Batch loss: 0.014002


 74%|███████▍  | 201/271 [03:06<01:04,  1.09it/s]


Batch loss: 0.013588


 82%|████████▏ | 221/271 [03:24<00:46,  1.08it/s]


Batch loss: 0.012006


 89%|████████▉ | 241/271 [03:43<00:27,  1.08it/s]


Batch loss: 0.003857


 96%|█████████▋| 261/271 [04:01<00:09,  1.08it/s]


Batch loss: 0.005158


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.989586 | Average Loss: 0.037939

Training time: 250.1384608745575



89it [00:37,  2.38it/s]


Validation| 
Accuracy: 0.897872 | Average loss: 0.388607 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 8
-------------------------------


  8%|▊         | 21/271 [00:19<03:56,  1.06it/s]


Batch loss: 0.005655


 15%|█▌        | 41/271 [00:38<03:34,  1.07it/s]


Batch loss: 0.009291


 23%|██▎       | 61/271 [00:56<03:10,  1.10it/s]


Batch loss: 0.010747


 30%|██▉       | 81/271 [01:15<02:53,  1.09it/s]


Batch loss: 0.003477


 37%|███▋      | 101/271 [01:34<02:38,  1.07it/s]


Batch loss: 0.007125


 45%|████▍     | 121/271 [01:52<02:18,  1.08it/s]


Batch loss: 0.263866


 52%|█████▏    | 141/271 [02:10<02:00,  1.08it/s]


Batch loss: 0.090840


 59%|█████▉    | 161/271 [02:29<01:41,  1.08it/s]


Batch loss: 0.003169


 67%|██████▋   | 181/271 [02:47<01:22,  1.09it/s]


Batch loss: 0.106448


 74%|███████▍  | 201/271 [03:06<01:05,  1.07it/s]


Batch loss: 0.021050


 82%|████████▏ | 221/271 [03:24<00:47,  1.06it/s]


Batch loss: 0.099168


 89%|████████▉ | 241/271 [03:43<00:27,  1.10it/s]


Batch loss: 0.034458


 96%|█████████▋| 261/271 [04:01<00:09,  1.08it/s]


Batch loss: 0.003969


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.993520 | Average Loss: 0.023542

Training time: 250.1896312236786



89it [00:37,  2.37it/s]


Validation| 
Accuracy: 0.905674 | Average loss: 0.375498 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 9
-------------------------------


  8%|▊         | 21/271 [00:19<03:56,  1.06it/s]


Batch loss: 0.021877


 15%|█▌        | 41/271 [00:38<03:34,  1.07it/s]


Batch loss: 0.009642


 23%|██▎       | 61/271 [00:57<03:14,  1.08it/s]


Batch loss: 0.004514


 30%|██▉       | 81/271 [01:15<02:55,  1.08it/s]


Batch loss: 0.047199


 37%|███▋      | 101/271 [01:33<02:36,  1.09it/s]


Batch loss: 0.004040


 45%|████▍     | 121/271 [01:52<02:18,  1.08it/s]


Batch loss: 0.003882


 52%|█████▏    | 141/271 [02:11<02:01,  1.07it/s]


Batch loss: 0.011097


 59%|█████▉    | 161/271 [02:29<01:43,  1.07it/s]


Batch loss: 0.006221


 67%|██████▋   | 181/271 [02:48<01:21,  1.11it/s]


Batch loss: 0.002317


 74%|███████▍  | 201/271 [03:06<01:04,  1.09it/s]


Batch loss: 0.003598


 82%|████████▏ | 221/271 [03:25<00:45,  1.10it/s]


Batch loss: 0.008398


 89%|████████▉ | 241/271 [03:43<00:28,  1.07it/s]


Batch loss: 0.034789


 96%|█████████▋| 261/271 [04:02<00:09,  1.09it/s]


Batch loss: 0.002996


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.995603 | Average Loss: 0.018452

Training time: 250.613703250885



89it [00:37,  2.38it/s]


Validation| 
Accuracy: 0.902837 | Average loss: 0.407997 



  0%|          | 0/271 [00:00<?, ?it/s]

Model saved


Epoch 10
-------------------------------


  8%|▊         | 21/271 [00:20<03:57,  1.05it/s]


Batch loss: 0.005797


 15%|█▌        | 41/271 [00:38<03:30,  1.09it/s]


Batch loss: 0.002581


 23%|██▎       | 61/271 [00:56<03:13,  1.08it/s]


Batch loss: 0.003574


 30%|██▉       | 81/271 [01:15<02:55,  1.08it/s]


Batch loss: 0.010474


 37%|███▋      | 101/271 [01:33<02:38,  1.07it/s]


Batch loss: 0.003708


 45%|████▍     | 121/271 [01:52<02:18,  1.08it/s]


Batch loss: 0.004471


 52%|█████▏    | 141/271 [02:10<02:00,  1.08it/s]


Batch loss: 0.004792


 59%|█████▉    | 161/271 [02:29<01:40,  1.09it/s]


Batch loss: 0.059056


 67%|██████▋   | 181/271 [02:48<01:24,  1.07it/s]


Batch loss: 0.009720


 74%|███████▍  | 201/271 [03:06<01:04,  1.08it/s]


Batch loss: 0.222378


 82%|████████▏ | 221/271 [03:25<00:46,  1.07it/s]


Batch loss: 0.027238


 89%|████████▉ | 241/271 [03:43<00:27,  1.10it/s]


Batch loss: 0.085924


 96%|█████████▋| 261/271 [04:01<00:09,  1.08it/s]


Batch loss: 0.006113


100%|██████████| 271/271 [04:10<00:00,  1.08it/s]
0it [00:00, ?it/s]

Accuracy: 0.985651 | Average Loss: 0.050514

Training time: 250.39060235023499



89it [00:37,  2.37it/s]


Validation| 
Accuracy: 0.870922 | Average loss: 0.590501 

Model saved

