In [1]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn.modules.loss import *
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from pathlib import Path
# %matplotlib notebook

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
from Loss.triplet import *
from session import *
from LR_Schedule.cyclical import Cyclical
from LR_Schedule.cos_anneal import CosAnneal
from LR_Schedule.lr_find import lr_find
from callbacks import *
from validation import *
from validation import _AccuracyMeter
import Datasets.ImageData as ImageData
from Transforms.ImageTransforms import *
import util
import Datasets.ModelData as md

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
torch.cuda.set_device(0); torch.backends.cudnn.benchmark=True;

    Found GPU0 GeForce GTX 770 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability that we support is 3.5.
    


In [5]:
data_path = Path("/media/drake/MX500/Datasets/Kannada-MNIST")
train_path = data_path/'train.csv'
test_path = data_path/'test.csv'
df = pd.read_csv(train_path)
test_df = pd.read_csv(test_path)

In [6]:
labels = df[df.columns[0]].to_numpy()
inputs = df[df.columns[1:]].to_numpy().reshape(-1, 28, 28)

ids = test_df[test_df.columns[0]].to_numpy()
test_inputs = test_df[test_df.columns[1:]].to_numpy().reshape(-1, 28, 28)

inputs.shape, labels.shape

((60000, 28, 28), (60000,))

In [7]:
class KannadaMNISTDataset(Dataset):
    def __init__(self, inputs, labels, transform):
        self.data = inputs
        self.targets = labels
        self.tsfm = transform
        
    def __len__(self): return self.targets.shape[0]
        
    def __getitem__(self, i):
        x, y = self.data[i], self.targets[i]  
        x = x.astype(np.uint8)
        x = self.tsfm(x)
        x = x.float()
        
        return x, y

In [8]:
i_dict = md.make_partition_indices(labels.shape[0], {'train': .9, 'valid': .1})

In [9]:
transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.RandomRotation(9),
                                transforms.RandomResizedCrop(28, scale=(.95, 1.05)),
                                transforms.ToTensor(), 
                                transforms.Normalize((0.5,), (0.5,))])

train_dataset = KannadaMNISTDataset(inputs[i_dict['train']], labels[i_dict['train']], transform)
valid_dataset = KannadaMNISTDataset(inputs[i_dict['valid']], labels[i_dict['valid']], transform)
test_dataset = KannadaMNISTDataset(test_inputs, ids, transform)
trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valloader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [10]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

## Baseline

In [None]:
model_base = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    
    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
     
    nn.MaxPool2d(kernel_size=2, stride=2),
    
    nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    
    nn.MaxPool2d(kernel_size=2, stride=2),
    
    nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    
    nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    
    nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    
    Flatten(),
    
    nn.Linear(7 * 7 * 32, 50),
    nn.Linear(50, 25),
    nn.Linear(25, 10)
    # nn.LogSoftmax()
)

sess_base = Session(model_base, nn.CrossEntropyLoss(), optim.Adam, 1e-3)

In [None]:
lr_find(sess_base, trainloader, start_lr=1e-7)

In [None]:
sess_base.set_lr(1e-3)

In [None]:
validator = Validator(valloader, OneHotAccuracy(), save_best=True, model_dir='./base')
lr_scheduler = CosAnneal(len(trainloader), T_mult=2)
schedule = TrainingSchedule(trainloader, [lr_scheduler, validator])
sess_base.train(schedule, 7)

In [None]:
def inference_test_data(model, submission_path=None):
    if submission_path is None:
        submission_path = Path('./Kannada.submission.csv')
        
    submission = []
    
    model = util.to_gpu(model)
    
    with EvalModel(model):       
        for input, id in testloader:
            outputs = model(util.to_gpu(input))
            _, preds = torch.max(outputs, 1)
            

            batch = torch.stack([id, preds.cpu()], dim=1).numpy()          

            submission.append(batch)

    submission = np.concatenate(submission)
   
    df = pd.DataFrame(submission, columns=['id', 'label'])
    
    df.to_csv(submission_path, index=False)
        
    return df

In [None]:
sub = inference_test_data(model_base)
sub

## Triplet Loss 


In [11]:
class TripletRegularizedCrossEntropyLoss(nn.Module):
    def __init__(self, alpha, margin):     
        super().__init__()
        self.alpha = alpha
        self.margin = margin
        
    def forward(self, x, y):
        loss = F.cross_entropy(x[-1], y)
        triplet = 0
        
        for layer in x[:-1]:
            triplet += batch_hard_triplet_loss(layer.view(layer.size(0), -1), y, self.margin)
            
        # triplet *= min(self.alpha/math.sqrt(loss.item()), 1)
        triplet *= self.alpha
            
        return loss + triplet
    
    
class CustomOneHotAccuracy(OneHotAccuracy):
    def __init__(self):
        super().__init__()
        self.reset()

    def update(self, output, label):
        super().update(output[-1], label)
        
class SelectiveSequential(nn.Module):
    def __init__(self, to_select, modules_dict):
        super(SelectiveSequential, self).__init__()
        for key, module in modules_dict.items():
            self.add_module(key, module)
        self._to_select = to_select
    
    def forward(self, x):
        list = []
        for name, module in self._modules.items():               
            x = module(x)      
            
            if name in self._to_select:
                list.append(x)
                
        return list
    

In [12]:
model = SelectiveSequential(
    ['act64a', 'act128', 'act64b', 'act32b', 'fc1', 'fc2', 'out'],
    {'conv16': nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
     'batch16': nn.BatchNorm2d(16),
     'act16': nn.ReLU(),
     
     'drop1': nn.Dropout(.12),
    
     'conv32a': nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
     'batch32a': nn.BatchNorm2d(32),
     'act32a': nn.ReLU(),
     
     'drop1': nn.Dropout(.15),
     
     'max1': nn.MaxPool2d(kernel_size=2, stride=2),
    
     'conv64a': nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
     'batch64a': nn.BatchNorm2d(64),
     'act64a': nn.ReLU(),
    
     'max2': nn.MaxPool2d(kernel_size=2, stride=2),
    
     'conv128': nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2),
     'batch128': nn.BatchNorm2d(128),
     'act128': nn.ReLU(),
     
     'drop1': nn.Dropout(.2),
    
     'conv64b': nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
     'batch64b': nn.BatchNorm2d(64),
     'act64b': nn.ReLU(),
     
     'drop1': nn.Dropout(.15),
    
     'conv32b': nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
     'batch32b': nn.BatchNorm2d(32),
     'act32b': nn.ReLU(),
     
     'drop1': nn.Dropout(.12),
    
     'flatten': Flatten(),
    
     'fc1': nn.Linear(7 * 7 * 32, 50),
     'drop1': nn.Dropout(.05),
     'fc2': nn.Linear(50, 25),
     'drop1': nn.Dropout(.05),
     'out': nn.Linear(25, 10)})

In [13]:
criterion = TripletRegularizedCrossEntropyLoss(.25, 1)
sess = Session(model, criterion, optim.AdamW, 1e-3)

In [None]:
lr_find(sess, trainloader, start_lr=1e-7)

In [14]:
sess.set_lr(2e-4)

In [None]:
lr_scheduler.plot(len(trainloader) * 4)

In [15]:
validator = Validator(valloader, CustomOneHotAccuracy(), save_best=True, model_dir='./triplet')
lr_scheduler = CosAnneal(len(trainloader), T_mult=2)
schedule = TrainingSchedule(trainloader, [lr_scheduler, validator])
sess.train(schedule, 63)

HBox(children=(IntProgress(value=0, description='Epochs', max=63, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.69it/s]


Training Loss: 1.408563494682312  Validaton Loss: 1.5481818914413452 Validation Accuracy: 0.9866666666666666


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.01it/s]


Training Loss: 0.9349409937858582  Validaton Loss: 0.9910321235656738 Validation Accuracy: 0.9911666666666666


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.83it/s]


Training Loss: 0.7632272243499756  Validaton Loss: 0.8499147295951843 Validation Accuracy: 0.9924999999999999


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.53it/s]


Training Loss: 0.683520495891571  Validaton Loss: 0.7270594835281372 Validation Accuracy: 0.9923333333333333


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.10it/s]


Training Loss: 0.5708628296852112  Validaton Loss: 0.6129052042961121 Validation Accuracy: 0.9926666666666666


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.00it/s]


Training Loss: 0.4965168237686157  Validaton Loss: 0.5488703846931458 Validation Accuracy: 0.9954999999999999


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.89it/s]

Training Loss: 0.4726262390613556  Validaton Loss: 0.5414458513259888 Validation Accuracy: 0.9954999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.02it/s]

Training Loss: 0.5322763919830322  Validaton Loss: 0.5482181906700134 Validation Accuracy: 0.9931666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.64it/s]

Training Loss: 0.4884168207645416  Validaton Loss: 0.5329318642616272 Validation Accuracy: 0.992





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.55it/s]

Training Loss: 0.42312416434288025  Validaton Loss: 0.4571601152420044 Validation Accuracy: 0.9953333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.55it/s]

Training Loss: 0.384813517332077  Validaton Loss: 0.44114622473716736 Validation Accuracy: 0.9934999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.90it/s]

Training Loss: 0.37912648916244507  Validaton Loss: 0.4006510376930237 Validation Accuracy: 0.9954999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.68it/s]

Training Loss: 0.32645273208618164  Validaton Loss: 0.39921385049819946 Validation Accuracy: 0.9946666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.96it/s]

Training Loss: 0.32715749740600586  Validaton Loss: 0.37222516536712646 Validation Accuracy: 0.9964999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.75it/s]


Training Loss: 0.30470097064971924  Validaton Loss: 0.3682711124420166 Validation Accuracy: 0.9971666666666666


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.90it/s]

Training Loss: 0.38470658659935  Validaton Loss: 0.4565311372280121 Validation Accuracy: 0.9911666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.15it/s]

Training Loss: 0.39669251441955566  Validaton Loss: 0.43066349625587463 Validation Accuracy: 0.994





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.81it/s]

Training Loss: 0.37050744891166687  Validaton Loss: 0.41601020097732544 Validation Accuracy: 0.9954999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.91it/s]

Training Loss: 0.3430190086364746  Validaton Loss: 0.37087199091911316 Validation Accuracy: 0.9944999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.97it/s]

Training Loss: 0.30776506662368774  Validaton Loss: 0.3418295085430145 Validation Accuracy: 0.9948333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.65it/s]

Training Loss: 0.31713736057281494  Validaton Loss: 0.3481385409832001 Validation Accuracy: 0.995





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.08it/s]

Training Loss: 0.2614837884902954  Validaton Loss: 0.31894952058792114 Validation Accuracy: 0.9956666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.84it/s]

Training Loss: 0.2689815163612366  Validaton Loss: 0.313056081533432 Validation Accuracy: 0.9953333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.25it/s]

Training Loss: 0.26719164848327637  Validaton Loss: 0.30614718794822693 Validation Accuracy: 0.9954999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.89it/s]

Training Loss: 0.25422921776771545  Validaton Loss: 0.28537192940711975 Validation Accuracy: 0.9961666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.76it/s]

Training Loss: 0.25265517830848694  Validaton Loss: 0.31050634384155273 Validation Accuracy: 0.9956666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.78it/s]

Training Loss: 0.2546689510345459  Validaton Loss: 0.27317094802856445 Validation Accuracy: 0.9956666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.85it/s]

Training Loss: 0.240036278963089  Validaton Loss: 0.2733834683895111 Validation Accuracy: 0.9963333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.09it/s]

Training Loss: 0.2301771342754364  Validaton Loss: 0.26964277029037476 Validation Accuracy: 0.9966666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.99it/s]

Training Loss: 0.22571080923080444  Validaton Loss: 0.28680700063705444 Validation Accuracy: 0.9966666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.25it/s]


Training Loss: 0.21662336587905884  Validaton Loss: 0.2729403078556061 Validation Accuracy: 0.9973333333333333


HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.22it/s]

Training Loss: 0.2774008810520172  Validaton Loss: 0.31880828738212585 Validation Accuracy: 0.9926666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.06it/s]

Training Loss: 0.2650086581707001  Validaton Loss: 0.31911805272102356 Validation Accuracy: 0.9954999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.31it/s]

Training Loss: 0.2575075328350067  Validaton Loss: 0.3465617299079895 Validation Accuracy: 0.9943333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.28it/s]

Training Loss: 0.2519906461238861  Validaton Loss: 0.28872567415237427 Validation Accuracy: 0.9953333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.79it/s]

Training Loss: 0.27018189430236816  Validaton Loss: 0.27148011326789856 Validation Accuracy: 0.994





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.75it/s]

Training Loss: 0.23573730885982513  Validaton Loss: 0.2771734297275543 Validation Accuracy: 0.9954999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.87it/s]

Training Loss: 0.22736598551273346  Validaton Loss: 0.2888261079788208 Validation Accuracy: 0.9964999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.05it/s]

Training Loss: 0.26019561290740967  Validaton Loss: 0.2819655239582062 Validation Accuracy: 0.997





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.11it/s]

Training Loss: 0.21986456215381622  Validaton Loss: 0.27879536151885986 Validation Accuracy: 0.9938333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.29it/s]

Training Loss: 0.2274598926305771  Validaton Loss: 0.24652668833732605 Validation Accuracy: 0.9963333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.49it/s]

Training Loss: 0.2161606103181839  Validaton Loss: 0.24795009195804596 Validation Accuracy: 0.9964999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.04it/s]

Training Loss: 0.22570231556892395  Validaton Loss: 0.27840638160705566 Validation Accuracy: 0.9946666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.84it/s]

Training Loss: 0.18998663127422333  Validaton Loss: 0.2453434020280838 Validation Accuracy: 0.996





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.87it/s]

Training Loss: 0.18508446216583252  Validaton Loss: 0.2575629949569702 Validation Accuracy: 0.9943333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.75it/s]

Training Loss: 0.21535171568393707  Validaton Loss: 0.2358287274837494 Validation Accuracy: 0.9948333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.84it/s]

Training Loss: 0.18585319817066193  Validaton Loss: 0.24402737617492676 Validation Accuracy: 0.996





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.94it/s]

Training Loss: 0.18151076138019562  Validaton Loss: 0.22069621086120605 Validation Accuracy: 0.9953333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.03it/s]

Training Loss: 0.19285590946674347  Validaton Loss: 0.23351985216140747 Validation Accuracy: 0.9954999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.78it/s]

Training Loss: 0.17989923059940338  Validaton Loss: 0.22271154820919037 Validation Accuracy: 0.997





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.88it/s]

Training Loss: 0.1726403683423996  Validaton Loss: 0.2261311560869217 Validation Accuracy: 0.9961666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.13it/s]

Training Loss: 0.16152265667915344  Validaton Loss: 0.22721272706985474 Validation Accuracy: 0.9944999999999999





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.30it/s]

Training Loss: 0.1677948236465454  Validaton Loss: 0.2122623324394226 Validation Accuracy: 0.9966666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.27it/s]

Training Loss: 0.1661091446876526  Validaton Loss: 0.2135084569454193 Validation Accuracy: 0.9971666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.71it/s]

Training Loss: 0.1533605456352234  Validaton Loss: 0.2218945026397705 Validation Accuracy: 0.9953333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.46it/s]

Training Loss: 0.15917879343032837  Validaton Loss: 0.20991143584251404 Validation Accuracy: 0.997





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.74it/s]

Training Loss: 0.16392505168914795  Validaton Loss: 0.22365359961986542 Validation Accuracy: 0.9961666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.09it/s]

Training Loss: 0.15130335092544556  Validaton Loss: 0.2126944363117218 Validation Accuracy: 0.997





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.91it/s]

Training Loss: 0.14637556672096252  Validaton Loss: 0.20446738600730896 Validation Accuracy: 0.9966666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.06it/s]

Training Loss: 0.16149507462978363  Validaton Loss: 0.20848314464092255 Validation Accuracy: 0.9968333333333333





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 33.04it/s]

Training Loss: 0.13694842159748077  Validaton Loss: 0.19868730008602142 Validation Accuracy: 0.9971666666666666





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.96it/s]

Training Loss: 0.15716391801834106  Validaton Loss: 0.23361383378505707 Validation Accuracy: 0.997





HBox(children=(IntProgress(value=0, description='Steps', max=844, style=ProgressStyle(description_width='initi…

Validating: 100%|██████████| 94/94 [00:02<00:00, 32.58it/s]

Training Loss: 0.15016405284404755  Validaton Loss: 0.21592000126838684 Validation Accuracy: 0.9958333333333333






sess2 = Session(model, TripletRegularizedCrossEntropyLoss(0, 1), optim.AdamW, 1e-4)
validator = Validator(valloader, CustomOneHotAccuracy(), save_best=True, model_dir='./triplet')
lr_scheduler = CosAnneal(len(trainloader), T_mult=2)
schedule = TrainingSchedule(trainloader, [lr_scheduler, validator])
sess2.train(schedule, 7)