In [1]:
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
from tqdm import tqdm
from tabulate import tabulate
import os
import pandas as pd
from PIL import Image
import imageio

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import MNIST

## Loss Definitions

In [2]:


# Create nn.module class called ArcFace that can be plugged in at the end of any backbone network

class ArcFaceLayer(nn.Module):
    def __init__(self, in_features, num_classes, s=4, m=0.5):
        super(ArcFaceLayer, self).__init__()

        #Margin parameter and scaling factor
        self.s = s
        self.m = m

        #Input feature dimension and output number of classes
        self.in_features = in_features
        self.num_classes = num_classes

        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, labels):

        #Cosine similarity between normalized input features and normalized weights
        cos_theta = F.linear(F.normalize(x), F.normalize(self.weight))

        # Clamp the values between -1 and 1 with 1e-7 for numerical stability
        cos_theta = torch.clamp(cos_theta, -1 + 1e-7, 1 - 1e-7)

        #Get the angle using arccos
        theta = torch.acos(cos_theta)

        #Add margin to the angle
        theta += self.m

        #Apply cosine to the angle to get adjusted cosine similarity
        adjusted_cos_theta = torch.cos(theta)

        #One hot encode labels
        one_hot = torch.zeros(cos_theta.size(), device=x.device)

        #Fill the one hot encoded tensor with 1s at the label indices
        one_hot.scatter_(1, labels.view(-1, 1).long(), 1)

        output = (one_hot * adjusted_cos_theta) + ((1.0 - one_hot) * cos_theta)
        output *= self.s

        return output

## Model Definitions

In [3]:


# Create a model class that combines the backbone network (Resnet18) and (optinally) the ArcFace layer
# The model class will be used for training and inference

# Create a base class called EmbeddingModel that contains the backbone network
class EmbeddingModel(nn.Module):
    def __init__(self, embedding_size):
        super(EmbeddingModel, self).__init__()
        self.embedding_size = embedding_size

        # Load a pre-trained backbone network (e.g., ResNet18)
        self.backbone = models.resnet18(weights= models.ResNet18_Weights.DEFAULT)

        # Freeze the backbone network
        for param in self.backbone.parameters():
            param.requires_grad = False


        # Replace the last fc layer of the backbone network with a one that outputs embedding_size
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, self.embedding_size)


# ArcFaceModel is a subclass of EmbeddingModel that adds an ArcFace layer
class ArcFaceModel(EmbeddingModel):
    def __init__(self, num_classes, embedding_size):
        super(ArcFaceModel, self).__init__(embedding_size)

        # Number of classes
        self.num_classes = num_classes

        # Create an new ArcFace layer
        self.arcface = ArcFaceLayer(self.embedding_size, self.num_classes)

    def forward(self, x, labels=None):

        # Get the output of the backbone network, so that we can use it to compute the embedding vectors
        x = self.backbone(x)

        # A cpu copy of the embedding vectors with detached gradients
        embedding_vectors = x.cpu().detach()

        output = self.arcface(x, labels)

        return output, embedding_vectors

    # Add a method to get the embedding vector
    def get_embedding(self, x):
        #no grad
        with torch.no_grad():
            return self.backbone(x)

class VGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers):
        super(VGGBlock, self).__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            in_channels = out_channels
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class VGG8ArcFace(nn.Module):
    def __init__(self, num_features, num_classes):
        super(VGG8ArcFace, self).__init__()
        self.block1 = VGGBlock(1, 16, 2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block2 = VGGBlock(16, 32, 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block3 = VGGBlock(32, 64, 2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.batchnorm = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 3 * 3, num_features)
        self.batchnorm2 = nn.BatchNorm1d(num_features)
        self.arcface = ArcFaceLayer(num_features, num_classes)

    # Forward pass, returns the embedding vectors and the output of the ArcFace layer
    def forward(self, x, labels):
        x = self.block1(x)
        x = self.maxpool1(x)
        x = self.block2(x)
        x = self.maxpool2(x)
        x = self.block3(x)
        x = self.maxpool3(x)
        x = self.batchnorm(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.batchnorm2(x)
        embedding_vectors = x.cpu().detach()
        output = self.arcface(x, labels)

        return output, embedding_vectors

    def get_embedding(self, x):

        #no grad
        with torch.no_grad():
            x = self.block1(x)
            x = self.maxpool1(x)
            x = self.block2(x)
            x = self.maxpool2(x)
            x = self.block3(x)
            x = self.maxpool3(x)
            x = self.batchnorm(x)
            x = self.dropout(x)
            x = self.flatten(x)
            x = self.fc1(x)
            x = self.batchnorm2(x)
            return x



## Visualization Functions

In [4]:
def visualize_embeddings(all_embeddings, all_labels, visualize_val=False):

    # First plot train embeddings
    plot_embeddings(all_embeddings['train'], all_labels['train'], title_poststr="Training_Embeddings")

    if visualize_val:
        # Then plot val embeddings
        plot_embeddings(all_embeddings['val'], all_labels['val'], title_poststr="Validation_Embeddings")


# Function to plot the embeddings using matplotlib
def plot_embeddings(embeddings, labels, title_poststr = "Training_Embeddings"):

    #Get the number of epochs
    num_epochs = len(embeddings)

    #Num classes
    num_classes = len(np.unique(labels[0]))

    # Create a GIF of the embeddings

    #Create list to hold frames for train and val
    frames = []

    #Get the min and max values for the train embeddings across all epochs
    x_min = np.min([np.min(embeddings[epoch][:,0]) for epoch in range(num_epochs)])
    x_max = np.max([np.max(embeddings[epoch][:,0]) for epoch in range(num_epochs)])
    y_min = np.min([np.min(embeddings[epoch][:,1]) for epoch in range(num_epochs)])
    y_max = np.max([np.max(embeddings[epoch][:,1]) for epoch in range(num_epochs)])

    #max absolute value
    max_abs = max(abs(x_min), abs(x_max), abs(y_min), abs(y_max))

    # Set the limits of the axes
    #xlim = [x_min - 0.1, x_max + 0.1]
    #ylim = [y_min - 0.1, y_max + 0.1]
    xlim = [-max_abs - 0.1, max_abs + 0.1]
    ylim = [-max_abs - 0.1, max_abs + 0.1]

    #Create gif for train
    for epoch in range(num_epochs):
        epoch_embeddings = embeddings[epoch]
        epoch_labels = labels[epoch]

        #Create a scatter plot with colored labels
        fig = plt.figure(figsize=(10,10))
        ax = fig.add_subplot(111)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # Define a custom colormap excluding the color red
        colors = ['blue', 'green', 'yellow', 'cyan', 'magenta', 'orange', 'purple']
        cmap = ListedColormap(colors)

        # Add a unit circle to the plot
        circle = plt.Circle((0,0), 1, color='red', fill=False)
        ax.add_artist(circle)

        for class_id in range(num_classes):
            class_indices = np.where(epoch_labels == class_id)[0]
            ax.scatter(epoch_embeddings[class_indices, 0], epoch_embeddings[class_indices, 1], label=f"Class {class_id}", alpha=1)



            # Add axes to the plot with arrows
            ax.arrow(xlim[0], 0, xlim[1] - xlim[0], 0, length_includes_head=True, head_width=0.05, color='black')
            ax.arrow(0, ylim[0], 0, ylim[1] - ylim[0], length_includes_head=True, head_width=0.05, color='black')

        plt.title(f"Epoch [{epoch+1}/{num_epochs}] - {title_poststr}")
        plt.xlabel('Dimension 1')
        plt.ylabel('Dimension 2')
        #Set the limits of the axes
        fig.gca().set_xlim([xlim[0]-0.1, xlim[1]+0.1])
        fig.gca().set_ylim([ylim[0]-0.1, ylim[1]+0.1])

        # Add the unit circle to the legend
        ax.scatter([], [], color='red', label="Unit Circle")

        plt.legend()

        #Save the plot as an image
        plt.savefig(f"data/{title_poststr}_{epoch}.png")
        plt.close(fig)

        #Add the image to the list of frames
        frames.append(imageio.imread(f"data/{title_poststr}_{epoch}.png"))

    #Save the list of frames as a GIF
    imageio.mimsave(f"data/{title_poststr}.gif", frames, duration=125)


## Hyperparameters

In [5]:
# Hyperparameters
num_epochs = 20
batch_size = 4
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the random seed for reproducible results
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## Training Loop

In [6]:
# Function to train the model
def train(model, train_loader, val_loader, optimizer, criterion, num_epochs, save_embeddings=False, visualize_val=False):
    best_val_loss = float("inf")  # Initialize with a very high value
    model.train()

    train_loop = tqdm(total=len(train_loader), leave=False)
    train_loop.set_description(f"Epochs: 0/{num_epochs}")

    table_data = []  # Table data to store epoch, train loss, and val loss

    #Dictionary to store the embeddings and labels for visualization
    all_embeddings = {}
    all_labels = {}

    all_embeddings['train'] = []
    all_embeddings['val'] = []

    all_labels['train'] = []
    all_labels['val'] = []

    for epoch in range(num_epochs):
        total_loss = 0.0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output, train_batch_embeddings = model(data, target)  # Pass both data and target to the model's forward method

            # if epoch == 2:
            #     print(train_batch_embeddings)

            #Save the embeddings and labels for visualization
            if save_embeddings:
                if batch_idx == 0:
                    train_embeddings = train_batch_embeddings
                    train_labels = target.cpu().detach()
                else:
                    train_embeddings = torch.cat((train_embeddings, train_batch_embeddings), dim=0)
                    train_labels = torch.cat((train_labels, target.cpu().detach()), dim=0)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            train_loop.update(1)

        train_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch_idx, (val_data, val_target) in enumerate(tqdm(val_loader, leave=False)):
                val_data, val_target = val_data.to(device), val_target.to(device)
                val_output, val_batch_embeddings = model(val_data, val_target)  # Pass both val_data and val_target to the model's forward method

                if save_embeddings:
                    if batch_idx == 0:
                        val_embeddings = val_batch_embeddings
                        val_labels = val_target.cpu().detach()
                    else:
                        val_embeddings = torch.cat((val_embeddings, val_batch_embeddings), dim=0)
                        val_labels = torch.cat((val_labels, val_target.cpu().detach()), dim=0)

                val_loss += criterion(val_output, val_target).item()

        val_loss /= len(val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "models/best_model.pth")  # Save the model with the best validation loss

        #Save the embeddings and labels for visualization
        if save_embeddings:
            all_embeddings['train'].append(train_embeddings.numpy())
            all_embeddings['val'].append(val_embeddings.numpy())

            all_labels['train'].append(train_labels.numpy())
            all_labels['val'].append(val_labels.numpy())

        train_loop.set_description(f"Epochs: {epoch+1}/{num_epochs}")
        train_loop.set_postfix(train_loss=train_loss, val_loss=val_loss)

        train_loop.reset()

        # Store epoch, train loss, and val loss in table data
        table_data.append([epoch+1, train_loss, val_loss])

        # Create a table with epoch, train loss, and val loss
        table = tabulate(table_data, headers=["Epoch", "Train Loss", "Val Loss"], tablefmt="presto")

        # Print the table with updated data
        if epoch == 0:
            # First epoch, print the table normally
            print(table)
        else:
            # Subsequent epochs, use carriage return to overwrite the previous table
            table = tabulate(table_data, headers=["Epoch", "Train Loss", "Val Loss"], tablefmt="presto")
            print(table, end="\r")

    train_loop.close()

    print("Training complete!")

    #Create a visualization of the embeddings
    if save_embeddings:
        visualize_embeddings(all_embeddings, all_labels, visualize_val= visualize_val)




In [7]:

# Transforms to be applied on the MNIST images
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

#Load MNIST dataset
dataset = MNIST(root="data/", download=True, transform=transform)

#number of classes
num_classes = len(dataset.classes)


# Split the dataset into training, validation and test sets

# Define the split sizes
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Use the random_split function to split dataset into non-overlapping training, validation and test sets
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])



# Create data loaders with transforms, no shuffling for repeatable results
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Create an instance of ArcFaceModel
#model = ArcFaceModel(num_classes=num_classes, embedding_size=2).to(device)
model = VGG8ArcFace(num_classes=num_classes, num_features=2).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model using train() function
train(model, train_loader, val_loader, optimizer, criterion, num_epochs, save_embeddings=True, visualize_val=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 145459641.96it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 42563490.45it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 35945007.83it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 20180644.88it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



Epochs: 0/20: 100%|█████████▉| 11993/12000 [01:34<00:00, 145.62it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 35/1500 [00:00<00:04, 340.96it/s][A
  5%|▍         | 70/1500 [00:00<00:04, 314.63it/s][A
  7%|▋         | 102/1500 [00:00<00:04, 307.23it/s][A
  9%|▉         | 133/1500 [00:00<00:04, 304.27it/s][A
 11%|█         | 164/1500 [00:00<00:04, 303.53it/s][A
 13%|█▎        | 195/1500 [00:00<00:04, 304.24it/s][A
 15%|█▌        | 226/1500 [00:00<00:04, 297.65it/s][A
 17%|█▋        | 259/1500 [00:00<00:04, 306.99it/s][A
 19%|█▉        | 292/1500 [00:00<00:03, 313.04it/s][A
 22%|██▏       | 324/1500 [00:01<00:03, 313.87it/s][A
 24%|██▍       | 357/1500 [00:01<00:03, 316.81it/s][A
 26%|██▌       | 389/1500 [00:01<00:03, 305.28it/s][A
 28%|██▊       | 421/1500 [00:01<00:03, 308.30it/s][A
 30%|███       | 452/1500 [00:01<00:03, 306.79it/s][A
 32%|███▏      | 484/1500 [00:01<00:03, 308.64it/s][A
 35%|███▍      | 518/1500 [00:01<00:03, 315.19it/s][A
 37%|███▋

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |     1.8126


Epochs: 1/20: 100%|█████████▉| 11995/12000 [01:20<00:00, 160.49it/s, train_loss=2.91, val_loss=1.81]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 35/1500 [00:00<00:04, 346.32it/s][A
  5%|▍         | 70/1500 [00:00<00:04, 324.79it/s][A
  7%|▋         | 105/1500 [00:00<00:04, 332.16it/s][A
  9%|▉         | 140/1500 [00:00<00:04, 335.72it/s][A
 12%|█▏        | 174/1500 [00:00<00:04, 329.28it/s][A
 14%|█▍        | 208/1500 [00:00<00:03, 332.54it/s][A
 16%|█▌        | 242/1500 [00:00<00:03, 326.73it/s][A
 18%|█▊        | 275/1500 [00:00<00:03, 326.85it/s][A
 21%|██        | 308/1500 [00:00<00:03, 325.34it/s][A
 23%|██▎       | 341/1500 [00:01<00:03, 322.21it/s][A
 25%|██▍       | 374/1500 [00:01<00:03, 313.33it/s][A
 27%|██▋       | 406/1500 [00:01<00:03, 314.34it/s][A
 29%|██▉       | 439/1500 [00:01<00:03, 318.35it/s][A
 31%|███▏      | 472/1500 [00:01<00:03, 319.59it/s][A
 34%|███▎      | 504/1500 [00:01<00:03, 311.07it/s][A
 36%|███▌      | 538/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946

Epochs: 2/20: 100%|█████████▉| 11997/12000 [01:19<00:00, 160.71it/s, train_loss=1.34, val_loss=1.19]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 32/1500 [00:00<00:04, 312.66it/s][A
  4%|▍         | 64/1500 [00:00<00:04, 310.78it/s][A
  7%|▋         | 98/1500 [00:00<00:04, 321.73it/s][A
  9%|▊         | 131/1500 [00:00<00:04, 315.67it/s][A
 11%|█         | 163/1500 [00:00<00:04, 314.85it/s][A
 13%|█▎        | 196/1500 [00:00<00:04, 319.16it/s][A
 15%|█▌        | 228/1500 [00:00<00:04, 314.43it/s][A
 17%|█▋        | 260/1500 [00:00<00:03, 314.31it/s][A
 19%|█▉        | 292/1500 [00:00<00:03, 314.27it/s][A
 22%|██▏       | 326/1500 [00:01<00:03, 320.81it/s][A
 24%|██▍       | 359/1500 [00:01<00:03, 320.12it/s][A
 26%|██▌       | 392/1500 [00:01<00:03, 316.52it/s][A
 28%|██▊       | 426/1500 [00:01<00:03, 322.45it/s][A
 31%|███       | 459/1500 [00:01<00:03, 315.29it/s][A
 33%|███▎      | 492/1500 [00:01<00:03, 318.27it/s][A
 35%|███▌      | 525/1500 [00:01<

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243

Epochs: 3/20: 100%|██████████| 12000/12000 [01:25<00:00, 155.64it/s, train_loss=1.2, val_loss=1.14]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 32/1500 [00:00<00:04, 314.40it/s][A
  4%|▍         | 64/1500 [00:00<00:04, 305.89it/s][A
  6%|▋         | 95/1500 [00:00<00:04, 304.00it/s][A
  8%|▊         | 126/1500 [00:00<00:04, 304.31it/s][A
 10%|█         | 157/1500 [00:00<00:04, 294.67it/s][A
 13%|█▎        | 188/1500 [00:00<00:04, 299.11it/s][A
 15%|█▍        | 219/1500 [00:00<00:04, 301.07it/s][A
 17%|█▋        | 250/1500 [00:00<00:04, 298.52it/s][A
 19%|█▉        | 283/1500 [00:00<00:03, 307.57it/s][A
 21%|██        | 315/1500 [00:01<00:03, 309.39it/s][A
 23%|██▎       | 346/1500 [00:01<00:03, 307.52it/s][A
 25%|██▌       | 379/1500 [00:01<00:03, 313.18it/s][A
 27%|██▋       | 412/1500 [00:01<00:03, 314.49it/s][A
 30%|██▉       | 444/1500 [00:01<00:03, 312.67it/s][A
 32%|███▏      | 476/1500 [00:01<00:03, 305.03it/s][A
 34%|███▍      | 508/1500 [00:01<0

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383

Epochs: 4/20: 100%|██████████| 12000/12000 [01:20<00:00, 160.96it/s, train_loss=1.14, val_loss=1.12]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 37/1500 [00:00<00:04, 362.00it/s][A
  5%|▍         | 74/1500 [00:00<00:04, 348.97it/s][A
  7%|▋         | 109/1500 [00:00<00:04, 343.34it/s][A
 10%|▉         | 144/1500 [00:00<00:04, 337.64it/s][A
 12%|█▏        | 179/1500 [00:00<00:03, 339.06it/s][A
 14%|█▍        | 213/1500 [00:00<00:03, 331.43it/s][A
 16%|█▋        | 247/1500 [00:00<00:03, 332.46it/s][A
 19%|█▊        | 281/1500 [00:00<00:03, 325.91it/s][A
 21%|██        | 314/1500 [00:00<00:03, 326.49it/s][A
 23%|██▎       | 347/1500 [00:01<00:03, 323.51it/s][A
 25%|██▌       | 380/1500 [00:01<00:03, 324.49it/s][A
 28%|██▊       | 413/1500 [00:01<00:03, 325.35it/s][A
 30%|██▉       | 446/1500 [00:01<00:03, 318.43it/s][A
 32%|███▏      | 480/1500 [00:01<00:03, 322.34it/s][A
 34%|███▍      | 514/1500 [00:01<00:03, 324.95it/s][A
 36%|███▋      | 547/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186

Epochs: 5/20: 100%|█████████▉| 11994/12000 [01:19<00:00, 157.81it/s, train_loss=1.11, val_loss=1.13]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 31/1500 [00:00<00:04, 307.66it/s][A
  4%|▍         | 63/1500 [00:00<00:04, 311.84it/s][A
  6%|▋         | 96/1500 [00:00<00:04, 319.82it/s][A
  9%|▊         | 129/1500 [00:00<00:04, 320.37it/s][A
 11%|█         | 162/1500 [00:00<00:04, 318.72it/s][A
 13%|█▎        | 194/1500 [00:00<00:04, 318.44it/s][A
 15%|█▌        | 228/1500 [00:00<00:03, 322.54it/s][A
 17%|█▋        | 261/1500 [00:00<00:03, 321.73it/s][A
 20%|█▉        | 294/1500 [00:00<00:03, 318.51it/s][A
 22%|██▏       | 327/1500 [00:01<00:03, 320.81it/s][A
 24%|██▍       | 360/1500 [00:01<00:03, 307.31it/s][A
 26%|██▌       | 393/1500 [00:01<00:03, 313.82it/s][A
 28%|██▊       | 425/1500 [00:01<00:03, 311.86it/s][A
 31%|███       | 458/1500 [00:01<00:03, 316.80it/s][A
 33%|███▎      | 491/1500 [00:01<00:03, 318.68it/s][A
 35%|███▍      | 523/1500 [00:01<

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975

Epochs: 6/20: 100%|█████████▉| 11992/12000 [01:19<00:00, 157.48it/s, train_loss=1.1, val_loss=1.1]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 36/1500 [00:00<00:04, 356.09it/s][A
  5%|▍         | 72/1500 [00:00<00:04, 327.51it/s][A
  7%|▋         | 105/1500 [00:00<00:04, 315.01it/s][A
  9%|▉         | 139/1500 [00:00<00:04, 321.36it/s][A
 11%|█▏        | 172/1500 [00:00<00:04, 314.92it/s][A
 14%|█▎        | 206/1500 [00:00<00:04, 322.87it/s][A
 16%|█▌        | 239/1500 [00:00<00:03, 324.12it/s][A
 18%|█▊        | 272/1500 [00:00<00:03, 324.77it/s][A
 20%|██        | 305/1500 [00:00<00:03, 323.99it/s][A
 23%|██▎       | 339/1500 [00:01<00:03, 327.37it/s][A
 25%|██▍       | 372/1500 [00:01<00:03, 326.20it/s][A
 27%|██▋       | 405/1500 [00:01<00:03, 324.56it/s][A
 29%|██▉       | 438/1500 [00:01<00:03, 315.88it/s][A
 31%|███▏      | 472/1500 [00:01<00:03, 321.85it/s][A
 34%|███▎      | 505/1500 [00:01<00:03, 324.02it/s][A
 36%|███▌      | 540/1500 [00:01<0

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742

Epochs: 7/20: 100%|█████████▉| 11993/12000 [01:19<00:00, 160.03it/s, train_loss=1.09, val_loss=1.1]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 35/1500 [00:00<00:04, 343.24it/s][A
  5%|▍         | 70/1500 [00:00<00:04, 315.21it/s][A
  7%|▋         | 103/1500 [00:00<00:04, 317.92it/s][A
  9%|▉         | 136/1500 [00:00<00:04, 320.85it/s][A
 11%|█▏        | 170/1500 [00:00<00:04, 323.47it/s][A
 14%|█▎        | 203/1500 [00:00<00:03, 325.46it/s][A
 16%|█▌        | 237/1500 [00:00<00:03, 327.76it/s][A
 18%|█▊        | 271/1500 [00:00<00:03, 328.93it/s][A
 20%|██        | 305/1500 [00:00<00:03, 331.99it/s][A
 23%|██▎       | 339/1500 [00:01<00:03, 326.53it/s][A
 25%|██▍       | 373/1500 [00:01<00:03, 328.38it/s][A
 27%|██▋       | 406/1500 [00:01<00:03, 319.76it/s][A
 29%|██▉       | 439/1500 [00:01<00:03, 321.84it/s][A
 32%|███▏      | 473/1500 [00:01<00:03, 325.83it/s][A
 34%|███▎      | 506/1500 [00:01<00:03, 321.28it/s][A
 36%|███▌      | 539/1500 [00:01<

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093

Epochs: 8/20: 100%|█████████▉| 11984/12000 [01:20<00:00, 154.12it/s, train_loss=1.09, val_loss=1.09]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 36/1500 [00:00<00:04, 357.24it/s][A
  5%|▍         | 72/1500 [00:00<00:04, 327.81it/s][A
  7%|▋         | 106/1500 [00:00<00:04, 329.49it/s][A
  9%|▉         | 140/1500 [00:00<00:04, 321.00it/s][A
 12%|█▏        | 174/1500 [00:00<00:04, 325.23it/s][A
 14%|█▍        | 207/1500 [00:00<00:03, 324.79it/s][A
 16%|█▌        | 241/1500 [00:00<00:03, 326.72it/s][A
 18%|█▊        | 274/1500 [00:00<00:03, 319.45it/s][A
 20%|██        | 307/1500 [00:00<00:03, 320.72it/s][A
 23%|██▎       | 341/1500 [00:01<00:03, 325.04it/s][A
 25%|██▍       | 374/1500 [00:01<00:03, 324.20it/s][A
 27%|██▋       | 407/1500 [00:01<00:03, 323.65it/s][A
 29%|██▉       | 441/1500 [00:01<00:03, 325.78it/s][A
 32%|███▏      | 474/1500 [00:01<00:03, 322.89it/s][A
 34%|███▍      | 507/1500 [00:01<00:03, 320.32it/s][A
 36%|███▌      | 540/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364

Epochs: 9/20: 100%|██████████| 12000/12000 [01:20<00:00, 155.61it/s, train_loss=1.08, val_loss=1.09]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 33/1500 [00:00<00:04, 323.78it/s][A
  4%|▍         | 66/1500 [00:00<00:04, 325.54it/s][A
  7%|▋         | 100/1500 [00:00<00:04, 328.98it/s][A
  9%|▉         | 133/1500 [00:00<00:04, 323.70it/s][A
 11%|█         | 166/1500 [00:00<00:04, 307.04it/s][A
 13%|█▎        | 198/1500 [00:00<00:04, 310.71it/s][A
 15%|█▌        | 230/1500 [00:00<00:04, 309.42it/s][A
 17%|█▋        | 262/1500 [00:00<00:03, 311.04it/s][A
 20%|█▉        | 294/1500 [00:00<00:03, 305.54it/s][A
 22%|██▏       | 325/1500 [00:01<00:03, 306.31it/s][A
 24%|██▍       | 358/1500 [00:01<00:03, 310.54it/s][A
 26%|██▌       | 390/1500 [00:01<00:03, 311.79it/s][A
 28%|██▊       | 422/1500 [00:01<00:03, 311.91it/s][A
 30%|███       | 454/1500 [00:01<00:03, 306.56it/s][A
 32%|███▏      | 486/1500 [00:01<00:03, 308.23it/s][A
 35%|███▍      | 518/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653

Epochs: 10/20: 100%|█████████▉| 11991/12000 [01:20<00:00, 144.40it/s, train_loss=1.07, val_loss=1.09]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 35/1500 [00:00<00:04, 348.93it/s][A
  5%|▍         | 70/1500 [00:00<00:04, 335.02it/s][A
  7%|▋         | 104/1500 [00:00<00:04, 312.37it/s][A
  9%|▉         | 136/1500 [00:00<00:04, 314.67it/s][A
 11%|█         | 168/1500 [00:00<00:04, 312.47it/s][A
 13%|█▎        | 200/1500 [00:00<00:04, 309.70it/s][A
 15%|█▌        | 232/1500 [00:00<00:04, 305.62it/s][A
 18%|█▊        | 264/1500 [00:00<00:04, 308.50it/s][A
 20%|█▉        | 295/1500 [00:00<00:03, 308.32it/s][A
 22%|██▏       | 326/1500 [00:01<00:03, 305.41it/s][A
 24%|██▍       | 357/1500 [00:01<00:03, 304.63it/s][A
 26%|██▌       | 389/1500 [00:01<00:03, 306.88it/s][A
 28%|██▊       | 420/1500 [00:01<00:03, 304.47it/s][A
 30%|███       | 453/1500 [00:01<00:03, 310.30it/s][A
 32%|███▏      | 485/1500 [00:01<00:03, 306.77it/s][A
 35%|███▍      | 518/1500 [00:0

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252

Epochs: 11/20: 100%|█████████▉| 11996/12000 [01:20<00:00, 110.72it/s, train_loss=1.06, val_loss=1.07]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 35/1500 [00:00<00:04, 349.85it/s][A
  5%|▍         | 70/1500 [00:00<00:04, 335.25it/s][A
  7%|▋         | 104/1500 [00:00<00:04, 323.99it/s][A
  9%|▉         | 137/1500 [00:00<00:04, 322.13it/s][A
 11%|█▏        | 170/1500 [00:00<00:04, 320.90it/s][A
 14%|█▎        | 203/1500 [00:00<00:04, 320.12it/s][A
 16%|█▌        | 236/1500 [00:00<00:03, 316.37it/s][A
 18%|█▊        | 269/1500 [00:00<00:03, 318.68it/s][A
 20%|██        | 301/1500 [00:00<00:03, 315.89it/s][A
 22%|██▏       | 334/1500 [00:01<00:03, 319.84it/s][A
 24%|██▍       | 367/1500 [00:01<00:03, 322.85it/s][A
 27%|██▋       | 400/1500 [00:01<00:03, 322.04it/s][A
 29%|██▉       | 433/1500 [00:01<00:03, 323.94it/s][A
 31%|███       | 468/1500 [00:01<00:03, 329.29it/s][A
 33%|███▎      | 501/1500 [00:01<00:03, 328.68it/s][A
 36%|███▌      | 534/1500 [00:0

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775

Epochs: 12/20: 100%|█████████▉| 11993/12000 [01:20<00:00, 116.00it/s, train_loss=1.06, val_loss=1.08]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 26/1500 [00:00<00:05, 258.12it/s][A
  3%|▎         | 52/1500 [00:00<00:06, 237.08it/s][A
  5%|▌         | 77/1500 [00:00<00:05, 239.31it/s][A
  7%|▋         | 102/1500 [00:00<00:06, 232.44it/s][A
  8%|▊         | 126/1500 [00:00<00:06, 227.57it/s][A
 10%|▉         | 149/1500 [00:00<00:06, 219.99it/s][A
 12%|█▏        | 173/1500 [00:00<00:05, 224.64it/s][A
 14%|█▎        | 203/1500 [00:00<00:05, 247.27it/s][A
 16%|█▌        | 236/1500 [00:00<00:04, 269.93it/s][A
 18%|█▊        | 266/1500 [00:01<00:04, 278.21it/s][A
 20%|█▉        | 294/1500 [00:01<00:04, 272.72it/s][A
 22%|██▏       | 325/1500 [00:01<00:04, 283.44it/s][A
 24%|██▎       | 354/1500 [00:01<00:04, 285.29it/s][A
 26%|██▌       | 386/1500 [00:01<00:03, 294.51it/s][A
 28%|██▊       | 416/1500 [00:01<00:03, 291.79it/s][A
 30%|██▉       | 447/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621

Epochs: 13/20: 100%|██████████| 12000/12000 [01:21<00:00, 107.48it/s, train_loss=1.06, val_loss=1.09]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 25/1500 [00:00<00:06, 241.59it/s][A
  3%|▎         | 50/1500 [00:00<00:06, 218.43it/s][A
  5%|▌         | 76/1500 [00:00<00:06, 234.42it/s][A
  7%|▋         | 107/1500 [00:00<00:05, 262.23it/s][A
  9%|▉         | 141/1500 [00:00<00:04, 287.52it/s][A
 12%|█▏        | 173/1500 [00:00<00:04, 296.08it/s][A
 14%|█▎        | 205/1500 [00:00<00:04, 300.88it/s][A
 16%|█▌        | 238/1500 [00:00<00:04, 308.93it/s][A
 18%|█▊        | 269/1500 [00:00<00:04, 304.44it/s][A
 20%|██        | 300/1500 [00:01<00:04, 299.46it/s][A
 22%|██▏       | 332/1500 [00:01<00:03, 305.18it/s][A
 24%|██▍       | 365/1500 [00:01<00:03, 310.15it/s][A
 26%|██▋       | 397/1500 [00:01<00:03, 309.48it/s][A
 29%|██▊       | 428/1500 [00:01<00:03, 300.37it/s][A
 31%|███       | 460/1500 [00:01<00:03, 304.64it/s][A
 33%|███▎      | 491/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621
      14 |      1.05388 |    1.08023

Epochs: 14/20: 100%|█████████▉| 11993/12000 [01:21<00:00, 107.96it/s, train_loss=1.05, val_loss=1.08]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  1%|▏         | 22/1500 [00:00<00:06, 219.20it/s][A
  3%|▎         | 50/1500 [00:00<00:05, 252.94it/s][A
  5%|▌         | 81/1500 [00:00<00:05, 277.50it/s][A
  7%|▋         | 109/1500 [00:00<00:05, 264.70it/s][A
  9%|▉         | 138/1500 [00:00<00:04, 273.08it/s][A
 11%|█▏        | 170/1500 [00:00<00:04, 286.73it/s][A
 13%|█▎        | 200/1500 [00:00<00:04, 289.11it/s][A
 15%|█▌        | 232/1500 [00:00<00:04, 297.92it/s][A
 17%|█▋        | 262/1500 [00:00<00:04, 293.90it/s][A
 20%|█▉        | 293/1500 [00:01<00:04, 297.02it/s][A
 22%|██▏       | 324/1500 [00:01<00:03, 299.71it/s][A
 24%|██▎       | 354/1500 [00:01<00:03, 297.27it/s][A
 26%|██▌       | 385/1500 [00:01<00:03, 299.92it/s][A
 28%|██▊       | 416/1500 [00:01<00:03, 292.85it/s][A
 30%|██▉       | 446/1500 [00:01<00:03, 293.67it/s][A
 32%|███▏      | 477/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621
      14 |      1.05388 |    1.08023
      15 |      1.05785 |    1.06531

Epochs: 15/20: 100%|█████████▉| 11991/12000 [01:21<00:00, 116.94it/s, train_loss=1.06, val_loss=1.07]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 33/1500 [00:00<00:04, 325.88it/s][A
  4%|▍         | 66/1500 [00:00<00:04, 317.53it/s][A
  7%|▋         | 98/1500 [00:00<00:04, 312.14it/s][A
  9%|▊         | 131/1500 [00:00<00:04, 316.53it/s][A
 11%|█         | 163/1500 [00:00<00:04, 314.67it/s][A
 13%|█▎        | 195/1500 [00:00<00:04, 312.84it/s][A
 15%|█▌        | 227/1500 [00:00<00:04, 308.23it/s][A
 17%|█▋        | 260/1500 [00:00<00:03, 312.39it/s][A
 19%|█▉        | 292/1500 [00:00<00:03, 307.39it/s][A
 22%|██▏       | 323/1500 [00:01<00:03, 299.58it/s][A
 24%|██▎       | 355/1500 [00:01<00:03, 305.54it/s][A
 26%|██▌       | 386/1500 [00:01<00:03, 304.73it/s][A
 28%|██▊       | 417/1500 [00:01<00:03, 305.39it/s][A
 30%|██▉       | 448/1500 [00:01<00:03, 303.61it/s][A
 32%|███▏      | 480/1500 [00:01<00:03, 306.67it/s][A
 34%|███▍      | 513/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621
      14 |      1.05388 |    1.08023
      15 |      1.05785 |    1.06531
      16 |      1.05592 |    1.08674

Epochs: 16/20: 100%|█████████▉| 11985/12000 [01:21<00:00, 153.83it/s, train_loss=1.06, val_loss=1.09]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 34/1500 [00:00<00:04, 339.64it/s][A
  5%|▍         | 68/1500 [00:00<00:04, 303.04it/s][A
  7%|▋         | 99/1500 [00:00<00:04, 304.92it/s][A
  9%|▉         | 132/1500 [00:00<00:04, 311.50it/s][A
 11%|█         | 164/1500 [00:00<00:04, 310.97it/s][A
 13%|█▎        | 198/1500 [00:00<00:04, 317.60it/s][A
 15%|█▌        | 230/1500 [00:00<00:04, 313.67it/s][A
 18%|█▊        | 263/1500 [00:00<00:03, 316.64it/s][A
 20%|█▉        | 295/1500 [00:00<00:03, 309.63it/s][A
 22%|██▏       | 327/1500 [00:01<00:03, 312.26it/s][A
 24%|██▍       | 359/1500 [00:01<00:03, 307.66it/s][A
 26%|██▌       | 390/1500 [00:01<00:03, 301.76it/s][A
 28%|██▊       | 423/1500 [00:01<00:03, 309.61it/s][A
 30%|███       | 455/1500 [00:01<00:03, 311.49it/s][A
 32%|███▏      | 487/1500 [00:01<00:03, 313.78it/s][A
 35%|███▍      | 519/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621
      14 |      1.05388 |    1.08023
      15 |      1.05785 |    1.06531
      16 |      1.05592 |    1.08674
      17 |      1.05188 |    1.06806

Epochs: 17/20: 100%|█████████▉| 11998/12000 [01:21<00:00, 154.70it/s, train_loss=1.05, val_loss=1.07]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 35/1500 [00:00<00:04, 346.71it/s][A
  5%|▍         | 70/1500 [00:00<00:04, 323.35it/s][A
  7%|▋         | 103/1500 [00:00<00:04, 323.28it/s][A
  9%|▉         | 136/1500 [00:00<00:04, 325.49it/s][A
 11%|█▏        | 169/1500 [00:00<00:04, 319.05it/s][A
 13%|█▎        | 201/1500 [00:00<00:04, 310.37it/s][A
 16%|█▌        | 233/1500 [00:00<00:04, 309.64it/s][A
 18%|█▊        | 265/1500 [00:00<00:03, 310.98it/s][A
 20%|█▉        | 297/1500 [00:00<00:03, 304.87it/s][A
 22%|██▏       | 329/1500 [00:01<00:03, 308.50it/s][A
 24%|██▍       | 361/1500 [00:01<00:03, 309.29it/s][A
 26%|██▌       | 392/1500 [00:01<00:03, 309.21it/s][A
 28%|██▊       | 425/1500 [00:01<00:03, 313.51it/s][A
 30%|███       | 457/1500 [00:01<00:03, 313.19it/s][A
 33%|███▎      | 490/1500 [00:01<00:03, 316.98it/s][A
 35%|███▍      | 522/1500 [00:0

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621
      14 |      1.05388 |    1.08023
      15 |      1.05785 |    1.06531
      16 |      1.05592 |    1.08674
      17 |      1.05188 |    1.06806
      18 |      1.051   |    1.09049

Epochs: 18/20: 100%|██████████| 12000/12000 [01:21<00:00, 150.46it/s, train_loss=1.05, val_loss=1.09]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 34/1500 [00:00<00:04, 333.34it/s][A
  5%|▍         | 68/1500 [00:00<00:04, 325.21it/s][A
  7%|▋         | 101/1500 [00:00<00:04, 322.94it/s][A
  9%|▉         | 134/1500 [00:00<00:04, 317.28it/s][A
 11%|█         | 166/1500 [00:00<00:04, 310.15it/s][A
 13%|█▎        | 198/1500 [00:00<00:04, 305.15it/s][A
 15%|█▌        | 229/1500 [00:00<00:04, 305.85it/s][A
 17%|█▋        | 261/1500 [00:00<00:04, 307.24it/s][A
 20%|█▉        | 293/1500 [00:00<00:03, 310.73it/s][A
 22%|██▏       | 325/1500 [00:01<00:03, 313.46it/s][A
 24%|██▍       | 358/1500 [00:01<00:03, 317.15it/s][A
 26%|██▌       | 390/1500 [00:01<00:03, 310.85it/s][A
 28%|██▊       | 422/1500 [00:01<00:03, 302.68it/s][A
 30%|███       | 453/1500 [00:01<00:03, 290.00it/s][A
 32%|███▏      | 483/1500 [00:01<00:03, 290.64it/s][A
 34%|███▍      | 516/1500 [00:0

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621
      14 |      1.05388 |    1.08023
      15 |      1.05785 |    1.06531
      16 |      1.05592 |    1.08674
      17 |      1.05188 |    1.06806
      18 |      1.051   |    1.09049
      19 |      1.04746 |    1.06095

Epochs: 19/20: 100%|██████████| 12000/12000 [01:20<00:00, 153.76it/s, train_loss=1.05, val_loss=1.06]
  0%|          | 0/1500 [00:00<?, ?it/s][A
  2%|▏         | 31/1500 [00:00<00:04, 306.01it/s][A
  4%|▍         | 64/1500 [00:00<00:04, 318.73it/s][A
  6%|▋         | 96/1500 [00:00<00:04, 318.00it/s][A
  9%|▊         | 128/1500 [00:00<00:04, 307.27it/s][A
 11%|█         | 161/1500 [00:00<00:04, 312.37it/s][A
 13%|█▎        | 193/1500 [00:00<00:04, 308.21it/s][A
 15%|█▌        | 226/1500 [00:00<00:04, 313.46it/s][A
 17%|█▋        | 258/1500 [00:00<00:04, 309.91it/s][A
 19%|█▉        | 291/1500 [00:00<00:03, 313.16it/s][A
 22%|██▏       | 323/1500 [00:01<00:03, 312.54it/s][A
 24%|██▎       | 355/1500 [00:01<00:03, 310.00it/s][A
 26%|██▌       | 387/1500 [00:01<00:03, 312.03it/s][A
 28%|██▊       | 419/1500 [00:01<00:03, 310.74it/s][A
 30%|███       | 451/1500 [00:01<00:03, 313.23it/s][A
 32%|███▏      | 483/1500 [00:01<00:03, 308.28it/s][A
 34%|███▍      | 514/1500 [00:01

   Epoch |   Train Loss |   Val Loss
---------+--------------+------------
       1 |      2.91419 |    1.8126
       2 |      1.34318 |    1.18946
       3 |      1.19823 |    1.14243
       4 |      1.14477 |    1.12383
       5 |      1.11325 |    1.13186
       6 |      1.10478 |    1.09975
       7 |      1.08924 |    1.09742
       8 |      1.08683 |    1.09093
       9 |      1.07598 |    1.09364
      10 |      1.07207 |    1.08653
      11 |      1.06442 |    1.07252
      12 |      1.05747 |    1.07775
      13 |      1.0615  |    1.08621
      14 |      1.05388 |    1.08023
      15 |      1.05785 |    1.06531
      16 |      1.05592 |    1.08674
      17 |      1.05188 |    1.06806
      18 |      1.051   |    1.09049
      19 |      1.04746 |    1.06095
      20 |      1.04686 |    1.05946Training complete!


  frames.append(imageio.imread(f"data/{title_poststr}_{epoch}.png"))
