In [1]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt

from torch import nn
from torch.utils.data import TensorDataset, DataLoader

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.models import vgg16, vgg16_bn, resnet18, resnet34, resnet50, resnet101, resnet152

from sklearn.model_selection import train_test_split
import sklearn.metrics as skm
from sklearn.metrics import f1_score

import seaborn as sn

from os import listdir, path
from PIL import Image
from collections import defaultdict
import csv
import re
import os

from IPython.display import display, clear_output, Image as IPython_Image

In [2]:
torch.cuda.is_available()

True

# Data loading

In [3]:
IMAGE_PATH = '../data/images'
LABEL_PATH = '../data/annotations'

# GIVEN DATASET
MEAN = (0.43672, 0.40107, 0.36762)
STD = (0.30139, 0.28781, 0.29236)

# IMAGENET
#MEAN = (0.485, 0.456, 0.406)
#STD = (0.229, 0.224, 0.225)

# RESNET
#MEAN = (0.485, 0.456, 0.406)
#STD = (0.229, 0.224, 0.225)
       
# Define default pos_weights for nn.BCEWithLogitsLoss(pos_weights).
label_pos_weights_for_loss = np.array([209.52631579, 55.87203791, 58.40594059, 16.77777778, 44.80152672, 5.25, 25.14379085, 5.75675676, 33.09090909, 2.15540363, 5.51465798, 163.38356164, 119., 37.46153846], dtype=np.float32)

In [4]:
def number_of_classes():
    return len(listdir(LABEL_PATH))

In [5]:
def get_class_map():
    ret = {}

    i = 0
    for fname in listdir(LABEL_PATH):
        img_class, _ = fname.split('.')
        ret[img_class] = i
        i += 1

    return ret

In [6]:
def write_labels_to_csv(name_of_set, label_array):
    filepath = f'../data/labels_{name_of_set}.csv'
    
    label_arr = np.array(label_array).astype(int)

    # Save 2D numpy array to csv file
    np.savetxt(filepath, label_arr, delimiter=',', fmt='%d')

In [7]:
def get_data(train_fr=.6, max_images_per_class=1e9):
    # mapping from class names to integers
    class_map = get_class_map()

    # create a dictionary to hold our label vectors
    n_classes = len(class_map.keys())
    img_to_class = defaultdict(lambda: np.zeros(n_classes))

    # another dictionary to hold the actual image data
    img_to_data = dict()
    
    # loop through all the annotations
    for fname in listdir(LABEL_PATH):
        img_class, _ = fname.split('.')
        print(f'Reading class: {img_class}')
        
        # open the annotation file
        i = 0
        with open(f'{LABEL_PATH}/{fname}', 'r') as fh:

            # get image ids from annotation file
            img_ids = fh.read().splitlines()
            
            # gather the images with labels
            for i, img_id in enumerate(img_ids):
                
                # let's not process images unnecessarily
                if not img_id in img_to_data:

                    img_path = f'{IMAGE_PATH}/im{img_id}.jpg'
                    img = Image.open(img_path)

                    # append to dict
                    img_to_data[img_id] = img.convert('RGB')

                # get one-hot encoded vector of image classes
                img_classes = img_to_class[img_id]

                # add new class to image vector
                img_class_id = class_map[img_class]
                img_classes[img_class_id] = 1

                # store the updated vector back
                img_to_class[img_id] = img_classes

                if i >= max_images_per_class:
                    break

                i += 1

    # load also all the images that do not have any labels
    i = 0
    print(f'Reading images without labels..')
    for fname in listdir(IMAGE_PATH):
        m = re.match('im(\d+)', fname)
        img_id = m.group(1)

        if img_id not in img_to_data:
            img_path = f'{IMAGE_PATH}/im{img_id}.jpg'
            img = Image.open(img_path)

            # append to dict
            img_to_data[img_id] = img.convert('RGB')

            if i >= max_images_per_class:
                break

            i += 1

    print('Creating train/valid/test split..')
    # collect data to a single array
    X = []
    y = []
    for img_id in img_to_data.keys():
        X.append(img_to_data[img_id])
        y.append(img_to_class[img_id])

    X_train, X_tmp, y_train, y_tmp = train_test_split(X, y, train_size=train_fr, random_state=42)
    X_test, X_valid, y_test, y_valid = train_test_split(X_tmp, y_tmp, train_size=.5, test_size=.5, random_state=42)
    
    print('Done.')

    return X_train, X_valid, X_test, y_train, y_valid, y_test

In [8]:
class TransformingDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, transforms=None):
        self.X = X
        self.y = y
        self.transforms = transforms

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        img_data = self.X[idx]
        img_class = self.y[idx]

        if transforms:
            img_data = self.transforms(img_data)

        return img_data, img_class

# Models

In [9]:
class TwoLayerModel(nn.Module):
    def __init__(self, n_input, n_hidden1, n_hidden2, n_classes):
        super().__init__()
        self.bs = bs
        self.input_layer = nn.Linear(n_input, n_hidden1)
        self.hidden1 = nn.Linear(n_hidden1, n_hidden2)
        self.hidden2 = nn.Linear(n_hidden2, n_classes)
        self.relu = nn.ReLU()
        self.bn0 = nn.BatchNorm1d(n_input)
        self.bn1 = nn.BatchNorm1d(n_hidden1)
        self.bn2 = nn.BatchNorm1d(n_hidden2)

    def forward(self, x):
        x = self.bn0(x)
        x = self.input_layer(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.hidden1(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.hidden2(x)

        return x

In [10]:
class OneLayerModel(nn.Module):
    def __init__(self, n_input, n_hidden, n_classes):
        super().__init__()

        self.input_layer = nn.Linear(n_input, n_hidden)
        self.hidden = nn.Linear(n_hidden, n_classes)
        self.relu = nn.ReLU()
        self.bn0 = nn.BatchNorm1d(n_input)
        self.bn1 = nn.BatchNorm1d(n_hidden)

    def forward(self, x):
        print(f'X.SHAPE: {x.shape}')
        x = self.bn0(x)
        x = self.input_layer(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.hidden(x)

        return x

In [11]:
class ConvNetModel(nn.Module):
    def __init__(self, n_classes, keep_prob=.5):
        super(ConvNetModel, self).__init__()
        # Common layers used multiple times
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=1-keep_prob)
        
        # Unique layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=3, stride=1, padding=1) #(n samples, channels, height, width)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=14, kernel_size=3, stride=1, padding=1)
        self.fc3 = nn.Linear(in_features=256*4*14, out_features=n_classes)
        
    def forward(self, x):
        x = x.reshape(-1, 3, 128, 128)
        
        out = self.conv1(x)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.dropout(out)
        
        out = self.conv2(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.dropout(out)
        
        out = out.reshape(out.size(0), -1)  # Flatten for FC
        out = self.fc3(out)
        return out    

# Training and evaluation functions

In [12]:
def evaluate(dataloader, model, criterion, device, threshold=0.5):
    model.eval()

    f1_scores = []
    losses = []

    with torch.no_grad():
        for batch in dataloader:
            X, y = batch
            X = X.to(device)
            y = y.to(device)
            y_pred = model(X)

            loss = criterion(y_pred, y)
            losses.append(loss)

            with torch.no_grad():
                score = f1_score(y.cpu() == 1, y_pred.cpu() > threshold, average='micro')
                f1_scores.append(score)

    model.train()

    return torch.mean(torch.tensor(losses)), torch.mean(torch.tensor(f1_scores))

In [13]:
def train(train_dataloader, valid_dataloader, model, optimizer, scheduler, criterion, device, n_epochs=50, verbose=True):
    model.train()

    if verbose:
        fmt = '{:<5} {:12} {:12} {:<9} {:<9}'
        print(fmt.format('Epoch', 'Train loss', 'Valid loss', 'Train F1', 'Valid F1'))

    for epoch in range(n_epochs):
        
        for i, batch in enumerate(train_dataloader):
            X, y = batch
            
            X = X.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            y_pred = model(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            print(f'Epoch: {epoch+1}, iteration: {i+1}, loss: {loss}')

        if verbose:
            train_loss, train_score = evaluate(train_dataloader, model, criterion, device)
            valid_loss, valid_score = evaluate(valid_dataloader, model, criterion, device)

            fmt = '{:<5} {:03.10f} {:03.10f} {:02.7f} {:02.7f}'
            print(fmt.format(epoch, train_loss, valid_loss, train_score, valid_score))
            
    print('Done training!')

In [14]:
def visualize_predictions(model, device, dataloader, mean=MEAN, std=STD, n_to_show=3, threshold=0.5):
    
    class_to_label = { v: k for k, v in get_class_map().items() }
    
    # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/5
    inv_transform = transforms.Compose([
        transforms.Normalize(mean = -1 * np.multiply(mean, std), std=np.divide(1, std))
    ])
    
    n_shown = 0
    for i, batch in enumerate(dataloader):        
        X, y = batch
        X = X.to(device)

        y_pred_raw = model(X).cpu()
        y_pred = y_pred_raw > threshold 
        y = y == 1

        for i in range(len(y)):
            pred_classes = np.where(y_pred[i] == 1)[0]
            true_classes = np.where(y[i] == 1)[0]
            
            true_classes_str = ', '.join([class_to_label[i] for i in true_classes])
            pred_classes_str = ', '.join([class_to_label[i] for i in pred_classes])

            img = inv_transform(X[i].cpu())              # inverse transforms
            img = img.permute(2, 1, 0)                   # BGR -> RGB
            img = np.rot90(img, 3)
                                    
            plt.title(f'True: {true_classes_str}, Predictions: {pred_classes_str}')
            plt.imshow(img)
            plt.pause(0.001)

            n_shown += 1
            
            if n_shown >= n_to_show:
                return            

It would be nice to use same naming conventions everywhere. Might be a good idea to rename `y_hat` to `y_pred` or vice versa everywhere (it makes it easier to combine the notebooks). I don't care which way it is. Also the `Xs` and `ys` etc.

In [15]:
def predict_X(fr, threshold=0.5):
    
    y_hat = fr > threshold
    
    return y_hat if (np.sum(y_hat) > 0) else fr == np.max(fr)

def predict(model, device, dataloader):
    
    ys_all = []  # Array of np.array(14) 
    y_hats_all = []
    
    for i, batch in enumerate(dataloader):
        
        Xs, ys = batch
        Xs = model(Xs.to(device))
        y_hats = np.apply_along_axis(predict_X, axis=1, arr=Xs.cpu().detach().numpy())
        
        y_hats_all.extend(y_hat for y_hat in y_hats)
        ys_all.extend(y.numpy() for y in ys==1)

    return np.array(ys_all), np.array(y_hats_all)

def visualize_confusion_matrix(y_true, y_pred, labels, file_path):

    plt.ioff()
    
    # Get confusion matrices
    cn_tensor = skm.multilabel_confusion_matrix(y_true, y_pred)
    
    # Get precision, recall, f1-score
    scores = skm.classification_report(y_true, y_pred, output_dict=True)

    fig, ax = plt.subplots(nrows=5, ncols=3,sharey=True, figsize=(20, 20), 
                           gridspec_kw={'hspace': 0.3, 'wspace': 0.0})
    gn = ['True Neg','False Pos','False Neg','True Pos']
    n = cn_tensor[0].sum()
    
    # Loop all labels
    for i, cn_matrix in enumerate(cn_tensor):

        j, k = int(i/3), i%3
        
        # Annotations
        annot = np.asarray(
            ['{}\n{:0.0f}\n{:.2%}'.format(gn[i], x, x/n) for i, x in enumerate(cn_matrix.flatten())]
        ).reshape(2,2)
        
        # Plot heatmap
        sn.heatmap(cn_matrix, annot=annot, fmt='', cmap='Blues', ax=ax[j, k])
        
        # Precision, recall, f1-score
        title = '{}\nprec.={:.3}, rec.={:.3}, f1={:.3}'.format(
            labels[i], scores[str(i)]['precision'], scores[str(i)]['recall'], scores[str(i)]['f1-score'])
        ax[j, k].set_title(title)
        
    plt.savefig(file_path, bbox_inches='tight')
    plt.close()

# Do the magic!

In [16]:
if torch.cuda.is_available():
    print('Using GPU!')
    device = torch.device('cuda')
else:
    print('Using CPU')
    device = torch.device('cpu')

lr = 0.01
n_epochs = 20
bs = 64
n_classes = len(get_class_map().keys())

Using GPU!


## Create and save / load dataloaders from disk

In [34]:
max_images_per_class = int(1e9)
#max_images_per_class = 200

transformations = {
    'train': transforms.Compose([
        transforms.RandomApply([
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomRotation((-10, 10)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
            transforms.RandomGrayscale(p=1),
            transforms.RandomPerspective(),
        ], p=0.5),
        transforms.ToTensor(),                
        transforms.Normalize(mean=MEAN, std=STD)            
    ]),
    'valid': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ]),
}

if os.path.isfile(f'X_train_n{max_images_per_class}.dat'):
    X_train, X_valid, X_test, y_train, y_valid, y_test = get_data(max_images_per_class=max_images_per_class)
    torch.save(X_train, f'../data/X_train_n{max_images_per_class}.dat')
    torch.save(X_valid, f'../data/X_valid_n{max_images_per_class}.dat')
    torch.save(X_test, f'../data/X_test_n{max_images_per_class}.dat')
    torch.save(y_train, f'../data/y_train_n{max_images_per_class}.dat')
    torch.save(y_valid, f'../data/y_valid_n{max_images_per_class}.dat')
    torch.save(y_test, f'../data/y_test_n{max_images_per_class}.dat')
else:
    X_train = torch.load(f'../data/X_train_n{max_images_per_class}.dat')
    X_valid = torch.load(f'../data/X_valid_n{max_images_per_class}.dat')
    X_test = torch.load(f'../data/X_test_n{max_images_per_class}.dat')
    y_train = torch.load(f'../data/y_train_n{max_images_per_class}.dat')
    y_valid = torch.load(f'../data/y_valid_n{max_images_per_class}.dat')
    y_test = torch.load(f'../data/y_test_n{max_images_per_class}.dat')

train_dataloader = DataLoader(
    TransformingDataset(X_train, y_train, transforms=transformations['train']),
    shuffle=True,
    batch_size=bs)

valid_dataloader = DataLoader(
    TransformingDataset(X_valid, y_valid, transforms=transformations['valid']),
    shuffle=True,
    batch_size=bs)

test_dataloader = DataLoader(
    TransformingDataset(X_test, y_test, transforms=transformations['test']),
    shuffle=True,
    batch_size=bs)

## Pretrained models

NB: The mean and std in transformations most probably need to be the same as for VGG and RESNET. Not 100% sure about this. Something to investigate!

More models here: https://pytorch.org/docs/stable/torchvision/models.html

If the models do not start to converge, try lowering the learning rate!

_VGG16_

Currently getting validation f1 scores around 0.67. 

Now around 0.71 with one cycle policy.

Surprisingly after quick testing the vgg16_bn (with BatchNorm layers) did not do as well? Maybe more to investigate here.

In [18]:
if True:
    model = vgg16(pretrained=True).to(device)

    for param in model.parameters():
        param.requires_grad = False

    model.classifier = nn.Sequential(
        nn.Linear(25088, 4096),
        nn.ReLU(),
        nn.Linear(4096, 2048),
        nn.ReLU(),
        nn.Linear(2048, 14),
    ).to(device)

_RESNET_


In [19]:
if False:
    model = resnet18(pretrained=True).to(device)

    for layer in model.children():
        layer.requires_grad = False

    model.fc = nn.Linear(512, 14).to(device)

In [20]:
if False:
    model = resnet34(pretrained=True).to(device)

    for layer in model.children():
        layer.requires_grad = False

    model.fc = nn.Linear(512, 14).to(device)

In [21]:
if False:
    model = resnet50(pretrained=True).to(device)

    for layer in model.children():
        layer.requires_grad = False

    model.fc = nn.Linear(2048, 14).to(device)

In [22]:
if False:
    model = resnet101(pretrained=True).to(device)

    for layer in model.children():
        layer.requires_grad = False

    model.fc = nn.Linear(2048, 14).to(device)

In [23]:
if False:
    model = resnet152(pretrained=True).to(device)

    for layer in model.children():
        layer.requires_grad = False

    model.fc = nn.Linear(2048, 14).to(device)

## Train a model or load an existing model from disk

In [32]:
# loss function
pos_weight = torch.from_numpy(label_pos_weights_for_loss).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# learning rate and momentum will be overriden by the scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    base_momentum=0.5,
    max_momentum=0.99,
    steps_per_epoch=len(train_dataloader),
    epochs=n_epochs,
)

In [25]:
# Paths for model saving and loading

# Only state dictionary
#model_save_path = '../data/resnet-valid-acc-aug-0.70.pth'
#model_save_path = '../data/resnet-101-valid-acc-0.73.pth'
#model_save_path = '../data/resnet-valid-acc-aug-0.716.pth'
#model_save_path = '../data/vgg16-9epochs-valid-acc-0.66.pth'

# Whole model
#model_whole_save_path = '../data/vgg16-7epochs-valid-acc-0.703.pth'

In [26]:
# Plain model saving and loading (only state dictionary).
#torch.save(model.state_dict(), model_save_path)
#model.load_state_dict(torch.load(model_save_path, map_location=torch.device(device)))

In [27]:
# Save an entire model (not just the state dict)
#torch.save(model, model_whole_save_path)

In [28]:
# Load an entire model (not just the state dict)
#model = torch.load(model_whole_save_path)
#model.eval()

In [30]:
# This is for trying to train all the layers...
#for param in model.parameters():
#    param.requires_grad = True

In [31]:
%time train(train_dataloader, valid_dataloader, model, optimizer, scheduler, criterion, device, n_epochs=n_epochs)

Epoch Train loss   Valid loss   Train F1  Valid F1 
Epoch: 1, iteration: 1, loss: 4.914584707823632
Epoch: 1, iteration: 2, loss: 3.9467705518424125
Epoch: 1, iteration: 3, loss: 3.135240022370196
Epoch: 1, iteration: 4, loss: 2.4401800074936935
Epoch: 1, iteration: 5, loss: 1.8699341844621462
Epoch: 1, iteration: 6, loss: 1.6754820654931406
Epoch: 1, iteration: 7, loss: 1.575284567548322
Epoch: 1, iteration: 8, loss: 1.8948483989164728
Epoch: 1, iteration: 9, loss: 2.11453765530687
Epoch: 1, iteration: 10, loss: 1.6269044793327798
Epoch: 1, iteration: 11, loss: 1.639783476224051
Epoch: 1, iteration: 12, loss: 1.6024120041219778
Epoch: 1, iteration: 13, loss: 1.6891106679575971
Epoch: 1, iteration: 14, loss: 1.9678760716375612
Epoch: 1, iteration: 15, loss: 1.7651557487428178
Epoch: 1, iteration: 16, loss: 1.6569346953704487
Epoch: 1, iteration: 17, loss: 2.0637901058167225
Epoch: 1, iteration: 18, loss: 1.9704490387488762
Epoch: 1, iteration: 19, loss: 1.883867752456107
Epoch: 1, iter

Epoch: 1, iteration: 164, loss: 0.949254024001734
Epoch: 1, iteration: 165, loss: 1.0683517465718095
Epoch: 1, iteration: 166, loss: 1.0458701918302293
Epoch: 1, iteration: 167, loss: 0.9607339008968547
Epoch: 1, iteration: 168, loss: 1.4059721038962738
Epoch: 1, iteration: 169, loss: 1.0788369078929412
Epoch: 1, iteration: 170, loss: 1.0269718398876726
Epoch: 1, iteration: 171, loss: 1.0207748888242323
Epoch: 1, iteration: 172, loss: 1.0332984428003336
Epoch: 1, iteration: 173, loss: 1.0209423710650587
Epoch: 1, iteration: 174, loss: 1.1497414434849171
Epoch: 1, iteration: 175, loss: 0.9103804759156666
Epoch: 1, iteration: 176, loss: 0.8959105372573908
Epoch: 1, iteration: 177, loss: 1.0858162530922757
Epoch: 1, iteration: 178, loss: 0.9776646376154943
Epoch: 1, iteration: 179, loss: 1.0203664167423492
Epoch: 1, iteration: 180, loss: 0.8764746367031718
Epoch: 1, iteration: 181, loss: 1.0412620106737698
Epoch: 1, iteration: 182, loss: 1.1785063123382233
Epoch: 1, iteration: 183, loss: 

Epoch: 2, iteration: 139, loss: 1.0402844826356974
Epoch: 2, iteration: 140, loss: 0.8252338949699444
Epoch: 2, iteration: 141, loss: 1.1047490977805021
Epoch: 2, iteration: 142, loss: 0.9449351458133544
Epoch: 2, iteration: 143, loss: 0.7271217465300271
Epoch: 2, iteration: 144, loss: 1.0433470516805101
Epoch: 2, iteration: 145, loss: 0.8614337418433536
Epoch: 2, iteration: 146, loss: 1.104852122907366
Epoch: 2, iteration: 147, loss: 1.1520677948599893
Epoch: 2, iteration: 148, loss: 0.8012939914370004
Epoch: 2, iteration: 149, loss: 0.7228323916598851
Epoch: 2, iteration: 150, loss: 0.9562745221717455
Epoch: 2, iteration: 151, loss: 0.7459229751780839
Epoch: 2, iteration: 152, loss: 0.7648956749164342
Epoch: 2, iteration: 153, loss: 0.7218887192073155
Epoch: 2, iteration: 154, loss: 1.2289078662217843
Epoch: 2, iteration: 155, loss: 0.8688317506553057
Epoch: 2, iteration: 156, loss: 0.8149393392912591
Epoch: 2, iteration: 157, loss: 0.865595392033135
Epoch: 2, iteration: 158, loss: 0

Epoch: 3, iteration: 114, loss: 0.5569904679769706
Epoch: 3, iteration: 115, loss: 0.6588393509521553
Epoch: 3, iteration: 116, loss: 0.9458604736624929
Epoch: 3, iteration: 117, loss: 0.9684875540335719
Epoch: 3, iteration: 118, loss: 0.803687476386638
Epoch: 3, iteration: 119, loss: 0.8382183288773241
Epoch: 3, iteration: 120, loss: 0.878649528662636
Epoch: 3, iteration: 121, loss: 1.2556702101666368
Epoch: 3, iteration: 122, loss: 0.5970130349286198
Epoch: 3, iteration: 123, loss: 0.7684471997838566
Epoch: 3, iteration: 124, loss: 0.6326481187918888
Epoch: 3, iteration: 125, loss: 0.7971529117723904
Epoch: 3, iteration: 126, loss: 0.7631923839950706
Epoch: 3, iteration: 127, loss: 0.8006917042985962
Epoch: 3, iteration: 128, loss: 0.9503026222402257
Epoch: 3, iteration: 129, loss: 0.7904448613524314
Epoch: 3, iteration: 130, loss: 0.8003204659359022
Epoch: 3, iteration: 131, loss: 0.6599534224103303
Epoch: 3, iteration: 132, loss: 0.7249489970298317
Epoch: 3, iteration: 133, loss: 0

Epoch: 4, iteration: 89, loss: 0.5896212477546653
Epoch: 4, iteration: 90, loss: 0.6702012988954887
Epoch: 4, iteration: 91, loss: 1.1159890027933652
Epoch: 4, iteration: 92, loss: 0.7426147341321448
Epoch: 4, iteration: 93, loss: 0.6052481667338272
Epoch: 4, iteration: 94, loss: 0.542508395107798
Epoch: 4, iteration: 95, loss: 0.5804809698872203
Epoch: 4, iteration: 96, loss: 0.7549133339392217
Epoch: 4, iteration: 97, loss: 0.5990464493254589
Epoch: 4, iteration: 98, loss: 0.6040594646384376
Epoch: 4, iteration: 99, loss: 0.5951373934236982
Epoch: 4, iteration: 100, loss: 0.609391525429277
Epoch: 4, iteration: 101, loss: 0.7228131778459886
Epoch: 4, iteration: 102, loss: 0.7657952354171366
Epoch: 4, iteration: 103, loss: 0.7946066972078332
Epoch: 4, iteration: 104, loss: 0.5414203296497349
Epoch: 4, iteration: 105, loss: 0.6024954210127319
Epoch: 4, iteration: 106, loss: 0.6593118655169973
Epoch: 4, iteration: 107, loss: 0.943259881424324
Epoch: 4, iteration: 108, loss: 0.78146986493

Epoch: 5, iteration: 63, loss: 0.49750453657371424
Epoch: 5, iteration: 64, loss: 0.6434869839697513
Epoch: 5, iteration: 65, loss: 0.49107072875504887
Epoch: 5, iteration: 66, loss: 0.50737154505866
Epoch: 5, iteration: 67, loss: 0.4505584738514699
Epoch: 5, iteration: 68, loss: 0.48709200266742725
Epoch: 5, iteration: 69, loss: 0.4837890643134233
Epoch: 5, iteration: 70, loss: 0.4903826783704982
Epoch: 5, iteration: 71, loss: 0.7235780757071031
Epoch: 5, iteration: 72, loss: 0.516533898116847
Epoch: 5, iteration: 73, loss: 0.5878673244288309
Epoch: 5, iteration: 74, loss: 0.49283216506253696
Epoch: 5, iteration: 75, loss: 0.4326770623949226
Epoch: 5, iteration: 76, loss: 0.5158918023575947
Epoch: 5, iteration: 77, loss: 0.7085097744321526
Epoch: 5, iteration: 78, loss: 0.49495649719806417
Epoch: 5, iteration: 79, loss: 0.48268645714462494
Epoch: 5, iteration: 80, loss: 0.5174454069944215
Epoch: 5, iteration: 81, loss: 0.7879895487020233
Epoch: 5, iteration: 82, loss: 0.70282101846274

Epoch: 6, iteration: 37, loss: 0.6779076845987448
Epoch: 6, iteration: 38, loss: 0.6814908068834233
Epoch: 6, iteration: 39, loss: 0.5215174024044267
Epoch: 6, iteration: 40, loss: 0.5154773912391152
Epoch: 6, iteration: 41, loss: 0.637303850241451
Epoch: 6, iteration: 42, loss: 0.5265767602510394
Epoch: 6, iteration: 43, loss: 0.4871514696357766
Epoch: 6, iteration: 44, loss: 0.5058791914914303
Epoch: 6, iteration: 45, loss: 0.4269416314353955
Epoch: 6, iteration: 46, loss: 0.6389571966797378
Epoch: 6, iteration: 47, loss: 0.45195459877797917
Epoch: 6, iteration: 48, loss: 0.60176156068293
Epoch: 6, iteration: 49, loss: 0.5786630670972606
Epoch: 6, iteration: 50, loss: 0.6117055827926013
Epoch: 6, iteration: 51, loss: 0.4451970976686408
Epoch: 6, iteration: 52, loss: 0.5473221552614641
Epoch: 6, iteration: 53, loss: 0.446041139710984
Epoch: 6, iteration: 54, loss: 0.5263281613097425
Epoch: 6, iteration: 55, loss: 0.8003531495464996
Epoch: 6, iteration: 56, loss: 0.49732447310678923
Ep

Epoch: 7, iteration: 11, loss: 0.4986613292729313
Epoch: 7, iteration: 12, loss: 0.5208348188584397
Epoch: 7, iteration: 13, loss: 0.47933857340701636
Epoch: 7, iteration: 14, loss: 0.47794240684265993
Epoch: 7, iteration: 15, loss: 0.5089461534915893
Epoch: 7, iteration: 16, loss: 0.4864184156733258
Epoch: 7, iteration: 17, loss: 0.41407401932349847
Epoch: 7, iteration: 18, loss: 0.5560965479645984
Epoch: 7, iteration: 19, loss: 0.45658893665153966
Epoch: 7, iteration: 20, loss: 0.41607472468559764
Epoch: 7, iteration: 21, loss: 0.7646666202775767
Epoch: 7, iteration: 22, loss: 0.4218218076260419
Epoch: 7, iteration: 23, loss: 0.5827527108700359
Epoch: 7, iteration: 24, loss: 0.5290346342583229
Epoch: 7, iteration: 25, loss: 0.5447067745404413
Epoch: 7, iteration: 26, loss: 0.8898965934006347
Epoch: 7, iteration: 27, loss: 0.48746218833134575
Epoch: 7, iteration: 28, loss: 0.41418480809648256
Epoch: 7, iteration: 29, loss: 0.5127221550707113
Epoch: 7, iteration: 30, loss: 0.5537824227

Epoch: 7, iteration: 174, loss: 0.5351560288595812
Epoch: 7, iteration: 175, loss: 0.7450569867026064
Epoch: 7, iteration: 176, loss: 0.4092973911999732
Epoch: 7, iteration: 177, loss: 0.6215763692788436
Epoch: 7, iteration: 178, loss: 0.4391506878401634
Epoch: 7, iteration: 179, loss: 0.42244305133372395
Epoch: 7, iteration: 180, loss: 0.6044911892740195
Epoch: 7, iteration: 181, loss: 0.5049962126855475
Epoch: 7, iteration: 182, loss: 0.4737195047895154
Epoch: 7, iteration: 183, loss: 0.6914322327209175
Epoch: 7, iteration: 184, loss: 0.6647816340570439
Epoch: 7, iteration: 185, loss: 0.4646527841975526
Epoch: 7, iteration: 186, loss: 0.5476492527042052
Epoch: 7, iteration: 187, loss: 0.45504077030821105
Epoch: 7, iteration: 188, loss: 0.8028680521423716
6     0.5192239959 0.7885950211 0.5411207 0.5740912
Epoch: 8, iteration: 1, loss: 0.4475812198917814
Epoch: 8, iteration: 2, loss: 0.48479223462466564
Epoch: 8, iteration: 3, loss: 0.5317713673606826
Epoch: 8, iteration: 4, loss: 0.4

Epoch: 8, iteration: 148, loss: 0.46192110989812535
Epoch: 8, iteration: 149, loss: 0.45768612896627775
Epoch: 8, iteration: 150, loss: 0.4603024811304586
Epoch: 8, iteration: 151, loss: 0.4890223933345283
Epoch: 8, iteration: 152, loss: 0.46785545872071305
Epoch: 8, iteration: 153, loss: 0.5143539436546349
Epoch: 8, iteration: 154, loss: 0.3844987011426876
Epoch: 8, iteration: 155, loss: 0.5268120180883304
Epoch: 8, iteration: 156, loss: 0.44107733703232166
Epoch: 8, iteration: 157, loss: 0.4113772847319699
Epoch: 8, iteration: 158, loss: 0.7616812769424011
Epoch: 8, iteration: 159, loss: 0.3749434878361752
Epoch: 8, iteration: 160, loss: 0.5675851886357907
Epoch: 8, iteration: 161, loss: 0.4102415702925585
Epoch: 8, iteration: 162, loss: 0.42774084155324116
Epoch: 8, iteration: 163, loss: 0.6666795711204465
Epoch: 8, iteration: 164, loss: 0.5131878603312665
Epoch: 8, iteration: 165, loss: 0.4267896127755406
Epoch: 8, iteration: 166, loss: 0.514734154670415
Epoch: 8, iteration: 167, l

Epoch: 9, iteration: 122, loss: 0.47111054308356726
Epoch: 9, iteration: 123, loss: 0.42407564241415496
Epoch: 9, iteration: 124, loss: 0.42196378737398366
Epoch: 9, iteration: 125, loss: 0.5186157606885492
Epoch: 9, iteration: 126, loss: 0.474093338152226
Epoch: 9, iteration: 127, loss: 0.4253848913074156
Epoch: 9, iteration: 128, loss: 0.4891185123558151
Epoch: 9, iteration: 129, loss: 0.5338315661274324
Epoch: 9, iteration: 130, loss: 0.5159149615596453
Epoch: 9, iteration: 131, loss: 0.4376268107985358
Epoch: 9, iteration: 132, loss: 0.5017032505344869
Epoch: 9, iteration: 133, loss: 0.472957597207263
Epoch: 9, iteration: 134, loss: 0.42954748799721
Epoch: 9, iteration: 135, loss: 0.4156225756696374
Epoch: 9, iteration: 136, loss: 0.5366255639975442
Epoch: 9, iteration: 137, loss: 0.47356121623699676
Epoch: 9, iteration: 138, loss: 0.41453123382481843
Epoch: 9, iteration: 139, loss: 0.4602502801254708
Epoch: 9, iteration: 140, loss: 0.3881818287950515
Epoch: 9, iteration: 141, loss

Epoch: 10, iteration: 93, loss: 0.5270504934791018
Epoch: 10, iteration: 94, loss: 0.5031026273764053
Epoch: 10, iteration: 95, loss: 0.35849778162838
Epoch: 10, iteration: 96, loss: 0.37118085176833754
Epoch: 10, iteration: 97, loss: 0.5963851067352671
Epoch: 10, iteration: 98, loss: 0.35065987880010857
Epoch: 10, iteration: 99, loss: 0.38238194844220724
Epoch: 10, iteration: 100, loss: 0.5646205029810261
Epoch: 10, iteration: 101, loss: 0.3402454305072163
Epoch: 10, iteration: 102, loss: 0.4020664223131729
Epoch: 10, iteration: 103, loss: 0.38414151638856736
Epoch: 10, iteration: 104, loss: 0.5010704923739349
Epoch: 10, iteration: 105, loss: 0.4114386836609968
Epoch: 10, iteration: 106, loss: 0.406437244452797
Epoch: 10, iteration: 107, loss: 0.3879489187477521
Epoch: 10, iteration: 108, loss: 0.39984381494968846
Epoch: 10, iteration: 109, loss: 0.3743705843968669
Epoch: 10, iteration: 110, loss: 0.3554823501885932
Epoch: 10, iteration: 111, loss: 0.3029320860464304
Epoch: 10, iterat

Epoch: 11, iteration: 63, loss: 0.4827612217861664
Epoch: 11, iteration: 64, loss: 0.3816790456322395
Epoch: 11, iteration: 65, loss: 0.3865169274482141
Epoch: 11, iteration: 66, loss: 0.362991348922292
Epoch: 11, iteration: 67, loss: 0.40683332119952953
Epoch: 11, iteration: 68, loss: 0.3820373730276395
Epoch: 11, iteration: 69, loss: 0.41562569220410667
Epoch: 11, iteration: 70, loss: 0.3754015170162292
Epoch: 11, iteration: 71, loss: 0.40867556357142626
Epoch: 11, iteration: 72, loss: 0.47880078581011565
Epoch: 11, iteration: 73, loss: 0.31969254255385193
Epoch: 11, iteration: 74, loss: 0.3610156279294381
Epoch: 11, iteration: 75, loss: 0.4016128714234519
Epoch: 11, iteration: 76, loss: 0.3938368391791922
Epoch: 11, iteration: 77, loss: 0.3320007494988235
Epoch: 11, iteration: 78, loss: 0.5516679388578497
Epoch: 11, iteration: 79, loss: 0.4703366223646764
Epoch: 11, iteration: 80, loss: 0.6143731524646642
Epoch: 11, iteration: 81, loss: 0.35297462326830303
Epoch: 11, iteration: 82, 

Epoch: 12, iteration: 33, loss: 0.3461187087584045
Epoch: 12, iteration: 34, loss: 0.48654326610283827
Epoch: 12, iteration: 35, loss: 0.5685338981214566
Epoch: 12, iteration: 36, loss: 0.31864766645735965
Epoch: 12, iteration: 37, loss: 0.38068983790264727
Epoch: 12, iteration: 38, loss: 0.6357297536053437
Epoch: 12, iteration: 39, loss: 0.4057234250867113
Epoch: 12, iteration: 40, loss: 0.4103788271573215
Epoch: 12, iteration: 41, loss: 0.36453984389796135
Epoch: 12, iteration: 42, loss: 0.37384546074827685
Epoch: 12, iteration: 43, loss: 0.48427040090906154
Epoch: 12, iteration: 44, loss: 0.38444624640914643
Epoch: 12, iteration: 45, loss: 0.4146958488676472
Epoch: 12, iteration: 46, loss: 0.8355030189062527
Epoch: 12, iteration: 47, loss: 0.3400289634656345
Epoch: 12, iteration: 48, loss: 0.3710851095943803
Epoch: 12, iteration: 49, loss: 0.3947735597620727
Epoch: 12, iteration: 50, loss: 0.3699872157293392
Epoch: 12, iteration: 51, loss: 0.44395785687883393
Epoch: 12, iteration: 5

Epoch: 13, iteration: 2, loss: 0.3411848878359181
Epoch: 13, iteration: 3, loss: 0.4471518287171624
Epoch: 13, iteration: 4, loss: 0.5260819666754226
Epoch: 13, iteration: 5, loss: 0.3667303736408688
Epoch: 13, iteration: 6, loss: 0.3644002890501109
Epoch: 13, iteration: 7, loss: 0.3722851497900423
Epoch: 13, iteration: 8, loss: 0.3931185595868661
Epoch: 13, iteration: 9, loss: 0.33577987898074535
Epoch: 13, iteration: 10, loss: 0.39633049852412633
Epoch: 13, iteration: 11, loss: 0.4171963279364564
Epoch: 13, iteration: 12, loss: 0.42781527179342443
Epoch: 13, iteration: 13, loss: 0.41755530512210287
Epoch: 13, iteration: 14, loss: 0.3290515571688311
Epoch: 13, iteration: 15, loss: 0.4112774288174217
Epoch: 13, iteration: 16, loss: 0.34214097840167956
Epoch: 13, iteration: 17, loss: 0.37276983458396346
Epoch: 13, iteration: 18, loss: 0.41131100589061637
Epoch: 13, iteration: 19, loss: 0.36041965650292496
Epoch: 13, iteration: 20, loss: 0.5697128098313716
Epoch: 13, iteration: 21, loss:

Epoch: 13, iteration: 161, loss: 0.46704620363838517
Epoch: 13, iteration: 162, loss: 0.5342163227605253
Epoch: 13, iteration: 163, loss: 0.500316149047743
Epoch: 13, iteration: 164, loss: 0.3592774468448972
Epoch: 13, iteration: 165, loss: 0.48546827125666075
Epoch: 13, iteration: 166, loss: 0.3477728196231283
Epoch: 13, iteration: 167, loss: 0.4262372753384704
Epoch: 13, iteration: 168, loss: 0.5170744324965674
Epoch: 13, iteration: 169, loss: 0.3624190210415852
Epoch: 13, iteration: 170, loss: 0.3790250675980418
Epoch: 13, iteration: 171, loss: 0.5764370146266515
Epoch: 13, iteration: 172, loss: 0.4198156879969038
Epoch: 13, iteration: 173, loss: 0.472601741129574
Epoch: 13, iteration: 174, loss: 0.3722801750357455
Epoch: 13, iteration: 175, loss: 0.35783883106490505
Epoch: 13, iteration: 176, loss: 0.4195697793419543
Epoch: 13, iteration: 177, loss: 0.3047336275505469
Epoch: 13, iteration: 178, loss: 0.3140048907617519
Epoch: 13, iteration: 179, loss: 0.43698096806880776
Epoch: 13,

Epoch: 14, iteration: 131, loss: 0.3734836691334959
Epoch: 14, iteration: 132, loss: 0.3648374325191062
Epoch: 14, iteration: 133, loss: 0.30798569928137864
Epoch: 14, iteration: 134, loss: 0.2892945574155244
Epoch: 14, iteration: 135, loss: 0.3581669306007006
Epoch: 14, iteration: 136, loss: 0.3721328318592509
Epoch: 14, iteration: 137, loss: 0.32737645431310997
Epoch: 14, iteration: 138, loss: 0.3355514261302839
Epoch: 14, iteration: 139, loss: 0.34053367002206314
Epoch: 14, iteration: 140, loss: 0.33578044085514863
Epoch: 14, iteration: 141, loss: 0.41492909426558067
Epoch: 14, iteration: 142, loss: 0.3102980231078768
Epoch: 14, iteration: 143, loss: 0.2955898372349141
Epoch: 14, iteration: 144, loss: 0.4760138473654194
Epoch: 14, iteration: 145, loss: 0.3479932620879668
Epoch: 14, iteration: 146, loss: 0.28291804921574104
Epoch: 14, iteration: 147, loss: 0.31672002118975334
Epoch: 14, iteration: 148, loss: 0.3626995738439159
Epoch: 14, iteration: 149, loss: 0.41067366173871767
Epoc

Epoch: 15, iteration: 101, loss: 0.4067599072425921
Epoch: 15, iteration: 102, loss: 0.3218813284686763
Epoch: 15, iteration: 103, loss: 0.3260876173982122
Epoch: 15, iteration: 104, loss: 0.35999743811131635
Epoch: 15, iteration: 105, loss: 0.42224691866459585
Epoch: 15, iteration: 106, loss: 0.35965630712332614
Epoch: 15, iteration: 107, loss: 0.3227296514634471
Epoch: 15, iteration: 108, loss: 0.39350377906099954
Epoch: 15, iteration: 109, loss: 0.3642442478088419
Epoch: 15, iteration: 110, loss: 0.30513368587485323
Epoch: 15, iteration: 111, loss: 0.39360928608899104
Epoch: 15, iteration: 112, loss: 0.7392036775702328
Epoch: 15, iteration: 113, loss: 0.3112213894868761
Epoch: 15, iteration: 114, loss: 0.3839961427539263
Epoch: 15, iteration: 115, loss: 0.3725672082522472
Epoch: 15, iteration: 116, loss: 0.3555839306362508
Epoch: 15, iteration: 117, loss: 0.3632025049769505
Epoch: 15, iteration: 118, loss: 0.35675376657618346
Epoch: 15, iteration: 119, loss: 0.5994489259400893
Epoch

Epoch: 16, iteration: 70, loss: 0.48886734505017976
Epoch: 16, iteration: 71, loss: 0.59332736466499
Epoch: 16, iteration: 72, loss: 0.2879349827476208
Epoch: 16, iteration: 73, loss: 0.3885101959578515
Epoch: 16, iteration: 74, loss: 0.35962700482225246
Epoch: 16, iteration: 75, loss: 0.3470101862098495
Epoch: 16, iteration: 76, loss: 0.3262662282610753
Epoch: 16, iteration: 77, loss: 0.42415562189111133
Epoch: 16, iteration: 78, loss: 0.3507870487089078
Epoch: 16, iteration: 79, loss: 0.2918974147203241
Epoch: 16, iteration: 80, loss: 0.2868521841405965
Epoch: 16, iteration: 81, loss: 0.41629367967024156
Epoch: 16, iteration: 82, loss: 0.4113364926466134
Epoch: 16, iteration: 83, loss: 0.3701015817763555
Epoch: 16, iteration: 84, loss: 0.512212587039103
Epoch: 16, iteration: 85, loss: 0.30469304574849476
Epoch: 16, iteration: 86, loss: 0.27968128300364964
Epoch: 16, iteration: 87, loss: 0.35227983006337493
Epoch: 16, iteration: 88, loss: 0.3966204213940837
Epoch: 16, iteration: 89, l

Epoch: 17, iteration: 39, loss: 0.255254936352361
Epoch: 17, iteration: 40, loss: 0.2861643031985179
Epoch: 17, iteration: 41, loss: 0.37337944963958825
Epoch: 17, iteration: 42, loss: 0.46582073673552604
Epoch: 17, iteration: 43, loss: 0.2933631318917494
Epoch: 17, iteration: 44, loss: 0.7470749079586105
Epoch: 17, iteration: 45, loss: 0.44437882040208454
Epoch: 17, iteration: 46, loss: 0.3271812842734015
Epoch: 17, iteration: 47, loss: 0.28552355568253773
Epoch: 17, iteration: 48, loss: 0.3126682788627091
Epoch: 17, iteration: 49, loss: 0.3943230416689608
Epoch: 17, iteration: 50, loss: 0.329544602167257
Epoch: 17, iteration: 51, loss: 0.3657450146474447
Epoch: 17, iteration: 52, loss: 0.3450015609523653
Epoch: 17, iteration: 53, loss: 0.299602979310783
Epoch: 17, iteration: 54, loss: 0.30172758098377617
Epoch: 17, iteration: 55, loss: 0.33428610776550166
Epoch: 17, iteration: 56, loss: 0.3596093991856193
Epoch: 17, iteration: 57, loss: 0.27516636816996315
Epoch: 17, iteration: 58, l

Epoch: 18, iteration: 8, loss: 0.2407768338189768
Epoch: 18, iteration: 9, loss: 0.36403424107959653
Epoch: 18, iteration: 10, loss: 0.37833666980274316
Epoch: 18, iteration: 11, loss: 0.3398395893850863
Epoch: 18, iteration: 12, loss: 0.3283273647340267
Epoch: 18, iteration: 13, loss: 0.3781197059988364
Epoch: 18, iteration: 14, loss: 0.33602350659708624
Epoch: 18, iteration: 15, loss: 0.331931462747754
Epoch: 18, iteration: 16, loss: 0.3925361005139226
Epoch: 18, iteration: 17, loss: 0.2674638906873642
Epoch: 18, iteration: 18, loss: 0.3024989869323129
Epoch: 18, iteration: 19, loss: 0.27651967515152226
Epoch: 18, iteration: 20, loss: 0.2939068402913423
Epoch: 18, iteration: 21, loss: 0.2874253057113892
Epoch: 18, iteration: 22, loss: 0.24563726929863022
Epoch: 18, iteration: 23, loss: 0.32727183939689314
Epoch: 18, iteration: 24, loss: 0.2895979227987354
Epoch: 18, iteration: 25, loss: 0.3150143834510763
Epoch: 18, iteration: 26, loss: 0.425420599822478
Epoch: 18, iteration: 27, los

Epoch: 18, iteration: 167, loss: 0.2909654743510043
Epoch: 18, iteration: 168, loss: 0.2528408368244098
Epoch: 18, iteration: 169, loss: 0.34433873343918464
Epoch: 18, iteration: 170, loss: 0.3547920470908826
Epoch: 18, iteration: 171, loss: 0.2946415666516906
Epoch: 18, iteration: 172, loss: 0.28192540338236083
Epoch: 18, iteration: 173, loss: 0.2766915188162635
Epoch: 18, iteration: 174, loss: 0.3734887934932639
Epoch: 18, iteration: 175, loss: 0.269237139134762
Epoch: 18, iteration: 176, loss: 0.3054420200358533
Epoch: 18, iteration: 177, loss: 0.37310876185882136
Epoch: 18, iteration: 178, loss: 0.33968394172177174
Epoch: 18, iteration: 179, loss: 0.3139818134115654
Epoch: 18, iteration: 180, loss: 0.285012129195457
Epoch: 18, iteration: 181, loss: 0.31691965745764905
Epoch: 18, iteration: 182, loss: 0.3072388278760015
Epoch: 18, iteration: 183, loss: 0.27994696666010344
Epoch: 18, iteration: 184, loss: 0.2920795077809256
Epoch: 18, iteration: 185, loss: 0.39871847567363267
Epoch: 

Epoch: 19, iteration: 137, loss: 0.28273632625775164
Epoch: 19, iteration: 138, loss: 0.38405974725088055
Epoch: 19, iteration: 139, loss: 0.2129766223393533
Epoch: 19, iteration: 140, loss: 0.27097982094849526
Epoch: 19, iteration: 141, loss: 0.24630553824583504
Epoch: 19, iteration: 142, loss: 0.3212266161518046
Epoch: 19, iteration: 143, loss: 0.32541750904659333
Epoch: 19, iteration: 144, loss: 0.317612444308289
Epoch: 19, iteration: 145, loss: 0.29235417800355223
Epoch: 19, iteration: 146, loss: 0.2725124755047494
Epoch: 19, iteration: 147, loss: 0.30477124233856323
Epoch: 19, iteration: 148, loss: 0.355633411303183
Epoch: 19, iteration: 149, loss: 0.3146162635043888
Epoch: 19, iteration: 150, loss: 0.22396518781550662
Epoch: 19, iteration: 151, loss: 0.30134705509757787
Epoch: 19, iteration: 152, loss: 0.3486658335521556
Epoch: 19, iteration: 153, loss: 0.2639557834289565
Epoch: 19, iteration: 154, loss: 0.24661797451571402
Epoch: 19, iteration: 155, loss: 0.2744406862162277
Epoc

Epoch: 20, iteration: 107, loss: 0.2404506801814551
Epoch: 20, iteration: 108, loss: 0.2910506488005902
Epoch: 20, iteration: 109, loss: 0.23079596658334067
Epoch: 20, iteration: 110, loss: 0.27969324324878947
Epoch: 20, iteration: 111, loss: 0.24478637635201897
Epoch: 20, iteration: 112, loss: 0.24853316177420673
Epoch: 20, iteration: 113, loss: 0.22376710282792509
Epoch: 20, iteration: 114, loss: 0.291786931771005
Epoch: 20, iteration: 115, loss: 0.24831928403806783
Epoch: 20, iteration: 116, loss: 0.43997652983906144
Epoch: 20, iteration: 117, loss: 0.30143524907528557
Epoch: 20, iteration: 118, loss: 0.35868817923872826
Epoch: 20, iteration: 119, loss: 0.44877021582676696
Epoch: 20, iteration: 120, loss: 0.25201604962779817
Epoch: 20, iteration: 121, loss: 0.2299514221990578
Epoch: 20, iteration: 122, loss: 0.806797162669034
Epoch: 20, iteration: 123, loss: 0.2782972177543697
Epoch: 20, iteration: 124, loss: 0.2679663762299836
Epoch: 20, iteration: 125, loss: 0.20154561176772787
Ep

In [None]:
evaluate(valid_dataloader, model, criterion, device, threshold=0.75)

In [None]:
# This is for finding the optimal threshold.
# Haven't proved out to be very useful yet.
if True:
    f1_scores = []
    for threshold in np.arange(0.05, 1, 0.05):
        _, f1 = evaluate(valid_dataloader, model, criterion, device, threshold=threshold)
        f1_scores.append(f1)
        print(f'threshold: {threshold}, f1 score: {f1}')

    plt.plot(np.arange(0.05, 1, 0.05), f1_scores)

# Visualization and evaluation

### Show some images with predictions

In [None]:
visualize_predictions(model, device, train_dataloader, n_to_show=10, threshold=0.75)

### Confusion matrix

In [None]:
re_predict = True

if re_predict:

    # Predict
    y_true, y_pred = predict(model, device, valid_dataloader)
    np.save(f'../data/valid_true_labels.npy', y_true)
    np.save(f'../data/valid_pred_labels.npy', y_pred)
    
    # Save classification report
    with open(f'../data/valid_classification_report.txt', 'w') as file:
        file.write(skm.classification_report(y_true, y_pred))
    
    # Save confusion matrix plot
    labels = [k for k, v in get_class_map().items()]
    visualize_confusion_matrix(y_true, y_pred, labels, f'../data/valid_confusion_matrix.png')

In [None]:
# Show classification report
with open(f'../data/valid_classification_report.txt', 'r') as file:
    report = ''.join(file.readlines())
    print(report)

In [None]:
# Show confus|ion matrix plot
IPython_Image(filename=f'../data/valid_confusion_matrix.png', width=1000)