In [1]:
import PIL
import gc
import torch
import torchvision
import os
import sys

import numpy as np
import matplotlib.pyplot as plt  
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import Subset
from IPython.core.display import display, HTML
from numpy.random import RandomState
from wide_resnet import WideResNet
from auto_augment import AutoAugment, Cutout
from efficientnet_pytorch import EfficientNet
from cifar_loader import SmallSampleController

sys.path.insert(0,'glico-learning-small-sample/glico_model')

from tester import runGlico



# display(HTML("<style>.container { width:40% !important; }</style>"))


In [2]:

def getAcc(preds,targets):
    """
    Summary: fraction of positions where a prediction equals its target
    == params ==
    preds: indexable sequence of predicted labels; its length sets the denominator
    targets: indexable sequence of true labels, read at indices 0..len(preds)-1
    == output ==
    number of matches divided by len(preds)
    """
    hits = [int(preds[k] == targets[k]) for k in range(len(preds))]
    return np.sum(hits) / len(preds)

def train(model, device, train_loader, optimizer, epoch, display=True):
    """
    Summary: runs one training epoch of `model` over `train_loader`
    == params ==
    model: the model to train (put into train mode here)
    device: cuda or cpu; every batch is moved onto it
    train_loader: iterable of (data, target) batches
    optimizer: optimizer stepped once per batch
    epoch: epoch number, used only in the printed message
    display: if True, print the epoch accuracy and mean loss
    == output ==
    (accuracy, meanLoss): accuracy is the fraction of correct predictions (via getAcc),
    meanLoss is the mean of the per-batch cross-entropy losses
    """
    model.train()
    batchLosses, allPreds, allTargets = [], [], []

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batchLoss = F.cross_entropy(logits, labels)
        batchLoss.backward()
        optimizer.step()

        # keep detached tensors; the .item() conversion is deferred until after the loop
        batchLosses.append(batchLoss.detach())
        with torch.no_grad():
            allPreds.extend(torch.argmax(logits, 1).cpu().numpy())
            allTargets.extend(labels.cpu().numpy())

    meanLoss = np.mean([t.item() for t in batchLosses])
    accuracy = getAcc(allPreds, allTargets)
    if display:
        print('Train Epoch: {} [acc: {:.0f}%]\tLoss: {:.6f}'.format(
            epoch, 100. * accuracy, meanLoss))

    return accuracy, meanLoss


def glicoTrain(model, device, train_loader, optimizer, epoch, glicoLoader,replaceProb=0.5,display=True):
    """
    Summary: one training epoch in which every batch is first passed through
    glicoLoader.replaceBatch (intended to swap samples for GLICO-generated ones
    with probability replaceProb) before the forward pass
    == params ==
    model: the model to train (put into train mode here)
    device: cuda or cpu; every batch is moved onto it
    train_loader: iterable of (data, target) batches
    optimizer: optimizer stepped once per batch
    epoch: epoch number, used only in the printed message
    glicoLoader: object exposing replaceBatch(data, target, replaceProb) -> data
    replaceProb: forwarded unchanged to glicoLoader.replaceBatch
    display: if True, print the epoch accuracy and mean loss
    == output ==
    (accuracy, meanLoss): accuracy is the fraction of correct predictions (via getAcc),
    meanLoss is the mean of the per-batch cross-entropy losses
    """
    model.train()
    stepLosses = []
    predictedLabels = []
    trueLabels = []

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # may substitute generated samples for some of the inputs
        inputs = glicoLoader.replaceBatch(inputs, labels, replaceProb)

        optimizer.zero_grad()
        logits = model(inputs)
        stepLoss = F.cross_entropy(logits, labels)
        stepLoss.backward()
        optimizer.step()

        stepLosses.append(stepLoss.detach())
        with torch.no_grad():
            predictedLabels.extend(torch.argmax(logits, 1).cpu().numpy())
            trueLabels.extend(labels.cpu().numpy())

    meanLoss = np.mean([t.item() for t in stepLosses])
    accuracy = getAcc(predictedLabels, trueLabels)
    if display:
        print('Train Epoch: {} [acc: {:.0f}%]\tLoss: {:.6f}'.format(
            epoch, 100. * accuracy, meanLoss))

    return accuracy, meanLoss



def test(model, device, test_loader,verbose=True):
    """
    Summary: Implements the testing procedure for a given model
    == params ==
    model: the model to test
    device: cuda or cpu 
    test_loader: dataloader for our test data
    verbose: output flag
    == output ==
    the mean test loss, the test accuracy
    """
    
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            

    meanLoss = test_loss / len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    if verbose: print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        mean_test_loss, correct, len(test_loader.dataset),
        accuracy))
        
    return accuracy, meanLoss


def checkTest(model,device,valSets,valTracker,latexTracker,epoch,
              model_name,optim_name,lr,totalTestSamples,seed,verbose=True,
              dataAug="Nothing",valSampleNum=2000):
    """
    Summary: evaluates the model on every validation loader, optionally prints the
    aggregate, and saves statistics
    == params ==
    valSets: iterable of validation loaders, each evaluated with test()
    valTracker: dict whose "allLoss", "allAcc", "meanLoss", "meanAcc" and "stdAcc"
                lists are extended/appended in place
    latexTracker: list; one row from getLatexRow is appended
    epoch, model_name, optim_name, lr, totalTestSamples, seed: recorded in the LaTeX row
    verbose: print the summary line
    dataAug: augmentation description recorded in the LaTeX row
             (defaults to the previously hard-coded "Nothing")
    valSampleNum: images per validation set, used only in the printed summary
                  (defaults to the previously hard-coded 2000)
    """
    tempAcc = []
    tempLoss = []
    for val_loader in valSets:
        acc,loss = test(model, device, val_loader,verbose = False)
        tempAcc.append(acc)
        tempLoss.append(loss)

    meanAcc = np.mean(tempAcc)
    stdAcc = np.std(tempAcc)
    meanLoss = np.mean(tempLoss)

    if verbose:
        # report the number of sets actually evaluated instead of reading the
        # notebook-level VALIDATION_SET_NUM global
        print('[Trained for {} epochs and tested on {} sets of {} images] '
              'Avg Acc: {:.2f} +- {:.2f} , Avg Loss: {:.2f}'.format(
                  epoch, len(tempAcc), valSampleNum, meanAcc, stdAcc, meanLoss))

    tableRow = getLatexRow(architecture=model_name,epoch=epoch,accuracy=meanAcc,optim=optim_name,
                           lr=lr,totalTestSamples=totalTestSamples,dataAug=dataAug,
                           seed=seed,title=False)

    latexTracker.append(tableRow)

    valTracker["allLoss"].extend(tempLoss)
    valTracker["allAcc"].extend(tempAcc)
    valTracker["meanLoss"].append(meanLoss)
    valTracker["meanAcc"].append(meanAcc)
    valTracker["stdAcc"].append(stdAcc)





In [3]:
def getLatexRow(architecture,epoch,accuracy,optim,lr,
                totalTestSamples,dataAug,seed,title=False):
    """
    Summary: formats one results-table row as LaTeX: cells joined with "&" and
    terminated by a LaTeX line break; accuracy is rounded to 3 decimals
    == params ==
    title: if True, prepend a header row with the column names, separated by a newline
    == output ==
    the row string (header + newline + row when title is True)
    """
    values = (architecture, epoch, round(accuracy, 3), optim, lr,
              totalTestSamples, dataAug, seed)
    lineBreak = "\\\\"
    body = "&".join(str(v) for v in values) + lineBreak
    if not title:
        return body
    headers = ("Model", "Epoch", "Accuracy", "Optimizer", "lr",
               "Test Sample Num", "data augmentation", "seed")
    return "&".join(headers) + lineBreak + "\n" + body
    
    
def plot(xlist,ylist,xlab,ylab,title,color,label,savedir=".",save=False):
    """
    Summary: plots ylist against xlist and optionally saves the figure
    == params ==
    xlist, ylist: x and y values
    xlab, ylab: axis labels
    title: figure title; also used as the saved file name
    color, label: line color and legend label
    savedir: directory for saved files (created, including parents, if missing)
    save: if True, write <savedir>/<title>.pdf and a 300-dpi <savedir>/<title>.png
    """
    fig = plt.figure()
    plt.title(title)
    plt.xlabel(xlab)
    plt.ylabel(ylab)
    plt.plot(xlist,ylist,color=color,marker=".",label=label)
    plt.legend()

    if save:
        # makedirs also creates missing parents; exist_ok makes it a no-op if present
        os.makedirs(savedir, exist_ok=True)
        filepath = os.path.join(savedir,"{}".format(title))
        fig.savefig(filepath+".pdf")
        # Render the high-resolution PNG directly. The former
        # `pdftoppm -png -r 300 X.pdf X.png` shell call treated "X.png" as an
        # output prefix (writing "X.png-1.png"), ignored its exit status, and
        # broke on titles containing spaces or shell metacharacters.
        fig.savefig(filepath+".png", dpi=300)

    plt.show()
    
    

In [4]:
def getModel(model_name):
    """
    Summary: builds the model selected by a case-insensitive substring of model_name
    == params ==
    model_name: contains "wide" -> WideResNet(28, 10, num_classes=10);
                contains "fix"  -> EfficientNet.from_pretrained(model_name)
    == output ==
    the model instance
    == raises ==
    ValueError if model_name matches neither option (previously None was
    returned silently and the failure surfaced later at first use)
    """
    name = model_name.lower()
    if "wide" in name:
        return WideResNet(28, 10, num_classes=10)
    if "fix" in name:
        # TODO: change to not be pretrained
        return EfficientNet.from_pretrained(model_name)
    raise ValueError("Unknown model name: {!r}".format(model_name))
    
    
def getOptimizer128(optimizer_name,parameters):
    """
    Summary: builds an optimizer with fixed hyper-parameters (presumably tuned for
    batch size 128, per the name)
    == params ==
    optimizer_name: case-insensitive; must contain "sgd" or "adam"
    parameters: parameters to optimize (e.g. model.parameters())
    == output ==
    (optimizer, lr):
      SGD:  lr=0.09, momentum=0.9, weight_decay=0.0005
      Adam: lr=0.001, weight_decay=0
    == raises ==
    ValueError for any other name (previously None was returned, which only
    failed later when the caller unpacked the result)
    """
    name = optimizer_name.lower()
    if "sgd" in name:
        lr = 0.09
        return torch.optim.SGD(parameters, lr=lr, momentum=0.9,
                               weight_decay=0.0005), lr
    if "adam" in name:
        lr = 0.001
        return torch.optim.Adam(parameters, lr=lr, weight_decay=0), lr
    raise ValueError("Unknown optimizer name: {!r}".format(optimizer_name))
        
    

In [5]:
torch.cuda.empty_cache()
gc.collect()

# --- experiment configuration ---
OPTIM = "Adam"
MODEL = "WideResNet28"
EPOCH_NUM = 2000
TRAIN_SAMPLE_NUM = 100
VAL_SAMPLE_NUM = 2000
BATCH_SIZE = 128
VALIDATION_SET_NUM = 1
AUGMENT = True
VAL_DISPLAY_DIVISOR = 25
CIFAR_TRAIN = True      # train and validation samples are both drawn from this CIFAR-10 split
REPLACE_PROB = 0.05     # probability of replacing a training sample with a GLICO sample
SEED = None             # NOTE(review): unused in this cell; ss.generateNewSet returns the seed it used
NUM_CLASSES = 10        # CIFAR-10
GLICO_EPOCHS = 750      # epochs for the GLICO generator (runGlico)

#cifar-10:
#mean = (0.4914, 0.4822, 0.4465)
#std = (0.247, 0.243, 0.261)

normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                  std=[0.247, 0.243, 0.261])

# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                   std=[0.229, 0.224, 0.225])
if AUGMENT:
    dataAugmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        AutoAugment(),
        Cutout()
    ]
    augment = "Crop,Flip,AutoAugment,Cutout"
else:
    dataAugmentation = []
    augment = "Nothing"


transform_train = transforms.Compose(dataAugmentation + [transforms.ToTensor(), normalize])
transform_val = transforms.Compose([transforms.ToTensor(), normalize]) #careful to keep this one same

# glico_train has no transform: GLICO receives the raw images of the same subset
glico_train = datasets.CIFAR10(root='.',train=CIFAR_TRAIN, download=True)
cifar_train = datasets.CIFAR10(root='.',train=CIFAR_TRAIN, transform=transform_train, download=True)
cifar_val = datasets.CIFAR10(root='.',train=CIFAR_TRAIN, transform=transform_val, download=True)

ss = SmallSampleController(numClasses=NUM_CLASSES,trainSampleNum=TRAIN_SAMPLE_NUM, # abstract the data-loading procedure
                           valSampleNum=VAL_SAMPLE_NUM, batchSize=BATCH_SIZE,
                           multiplier=VALIDATION_SET_NUM, trainDataset=cifar_train,
                           valDataset=cifar_val)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data, valSets, seed = ss.generateNewSet(device,valMultiplier = VALIDATION_SET_NUM) #Sample from datasets


train_labeled_dataset = Subset(glico_train, ss.trainSampler.indexes[0]) #get the same subset without transform
nagTrainer = runGlico(train_labeled_dataset=train_labeled_dataset, classes=NUM_CLASSES,epochs=GLICO_EPOCHS)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Generated new permutation of the CIFAR train dataset with                 seed:1620010424, train sample num: 100, test sample num: 2000



[*** NEW PRINT ***]nag_params: NAGParams(nz=512, force_l2=False, is_pixel=True, z_init='rndm', is_classifier=True, disc_net='conv', loss='ce', data_name='cifar-10', noise_proj=True, shot=0), rn: tester.py_cifar100_test_run, dataset <class 'list'>



[DEBUG] data name=cifar-10, data res = 32
[NEW CODE] image_params:ImageParams(sz=(32, 32), nc=3, n=100, mu=None, sd=None)
not eval | augment : TRUE
=>Generated data loader, res=32, workers=6 transform=Compose(
    ToTensor()
) sampler=None
[DEBUG] dataset name: <class 'torch.utils.data.dataset.subset'>,offset_labels=0
torch.Size([100, 512])
init_models_weights(self): init rndm
num classes = 10
dataset size = 100
not eval | augment : TRUE
=>Generated data loader, res=32, workers=6 transform=Co

  return self.softmax(x)


Epoch: [0][0/1]	Time  0.645 ( 0.645)	Data  0.260 ( 0.260)	Loss 8.55879 (8.55879)	d_Loss 2.32977 (2.32977)
=>NAG Epoch: 0 Error: 8.558786392211914, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_0.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_0.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_0.pkl
Show images...
=>epoch:1/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [1][0/1]	Time  0.300 ( 0.300)	Data  0.271 ( 0.271)	Loss 8.51396 (8.51396)	d_Loss 2.34086 (2.34086)
=>NAG Epoch: 1 Error: 8.513964653015137, Time:     0.01m
=>epoch:2/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [2][0/1]	Time  0.296 ( 0.296)	Data  0.267 ( 0.267)	Loss 8.42989 (8.42989)	d_Loss 2.35310 (2.35310)
=>NAG Epoch: 2 Error: 8.429893493652344, Time:     0.01m
=>epoch:3/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [3][0/1]	Time  0.294 ( 0.294)	Data  0.264 ( 0.264)	Loss 8.27003 (8.27003)	d_Loss 2.31783 (2.31783)
=>N

Epoch: [31][0/1]	Time  0.296 ( 0.296)	Data  0.266 ( 0.266)	Loss 7.07978 (7.07978)	d_Loss 1.89031 (1.89031)
=>NAG Epoch: 31 Error: 7.079784393310547, Time:     0.01m
=>epoch:32/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [32][0/1]	Time  0.293 ( 0.293)	Data  0.264 ( 0.264)	Loss 7.08693 (7.08693)	d_Loss 1.87127 (1.87127)
=>NAG Epoch: 32 Error: 7.086932182312012, Time:     0.01m
=>consecutive_loss: 1
=>epoch:33/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [33][0/1]	Time  0.295 ( 0.295)	Data  0.264 ( 0.264)	Loss 7.02147 (7.02147)	d_Loss 1.86036 (1.86036)
=>NAG Epoch: 33 Error: 7.021470546722412, Time:     0.01m
=>epoch:34/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [34][0/1]	Time  0.310 ( 0.310)	Data  0.280 ( 0.280)	Loss 7.02162 (7.02162)	d_Loss 1.83507 (1.83507)
=>NAG Epoch: 34 Error: 7.021623134613037, Time:     0.01m
=>consecutive_loss: 1
=>epoch:35/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [35][0/1]	Time  0.289 ( 0.289)	Data  0.259 ( 0.259)	Loss 6.95557 

Epoch: [63][0/1]	Time  0.292 ( 0.292)	Data  0.262 ( 0.262)	Loss 6.15942 (6.15942)	d_Loss 1.65037 (1.65037)
=>NAG Epoch: 63 Error: 6.159422874450684, Time:     0.01m
=>epoch:64/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [64][0/1]	Time  0.294 ( 0.294)	Data  0.262 ( 0.262)	Loss 6.13697 (6.13697)	d_Loss 1.62957 (1.62957)
=>NAG Epoch: 64 Error: 6.136966705322266, Time:     0.01m
=>epoch:65/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [65][0/1]	Time  0.302 ( 0.302)	Data  0.271 ( 0.271)	Loss 6.08325 (6.08325)	d_Loss 1.61898 (1.61898)
=>NAG Epoch: 65 Error: 6.083247184753418, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_65.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_65.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_65.pkl
=>epoch:66/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [66][0/1]	Time  0.307 ( 0.307)	Data  0.275 ( 0.275)	Loss 6.08051 (6.08051)	d_Loss 1.60427 (1.60427)
=>NAG

Epoch: [95][0/1]	Time  0.293 ( 0.293)	Data  0.263 ( 0.263)	Loss 5.38796 (5.38796)	d_Loss 1.55323 (1.55323)
=>NAG Epoch: 95 Error: 5.387956142425537, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_95.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_95.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_95.pkl
=>epoch:96/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [96][0/1]	Time  0.289 ( 0.289)	Data  0.258 ( 0.258)	Loss 5.36667 (5.36667)	d_Loss 1.55305 (1.55305)
=>NAG Epoch: 96 Error: 5.366673469543457, Time:     0.01m
=>epoch:97/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [97][0/1]	Time  0.292 ( 0.292)	Data  0.262 ( 0.262)	Loss 5.36712 (5.36712)	d_Loss 1.55164 (1.55164)
=>NAG Epoch: 97 Error: 5.3671183586120605, Time:     0.01m
=>consecutive_loss: 1
=>epoch:98/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [98][0/1]	Time  0.291 ( 0.291)	Data  0.261 ( 0.261)	Loss 5.34139 (5.34139)	d_Loss 

Epoch: [125][0/1]	Time  0.329 ( 0.329)	Data  0.297 ( 0.297)	Loss 5.01804 (5.01804)	d_Loss 1.46905 (1.46905)
=>NAG Epoch: 125 Error: 5.018039226531982, Time:     0.01m
=>consecutive_loss: 1
=>epoch:126/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [126][0/1]	Time  0.310 ( 0.310)	Data  0.279 ( 0.279)	Loss 4.99229 (4.99229)	d_Loss 1.46873 (1.46873)
=>NAG Epoch: 126 Error: 4.992288112640381, Time:     0.01m
=>epoch:127/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [127][0/1]	Time  0.295 ( 0.295)	Data  0.265 ( 0.265)	Loss 4.99217 (4.99217)	d_Loss 1.46913 (1.46913)
=>NAG Epoch: 127 Error: 4.992172718048096, Time:     0.01m
=>epoch:128/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [128][0/1]	Time  0.296 ( 0.296)	Data  0.265 ( 0.265)	Loss 4.97181 (4.97181)	d_Loss 1.46918 (1.46918)
=>NAG Epoch: 128 Error: 4.971814155578613, Time:     0.01m
=>epoch:129/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [129][0/1]	Time  0.294 ( 0.294)	Data  0.265 ( 0.265)	Loss 4.96634 (4.96634)

Epoch: [156][0/1]	Time  0.295 ( 0.295)	Data  0.266 ( 0.266)	Loss 4.63409 (4.63409)	d_Loss 1.46719 (1.46719)
=>NAG Epoch: 156 Error: 4.634086608886719, Time:     0.01m
=>epoch:157/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [157][0/1]	Time  0.303 ( 0.303)	Data  0.274 ( 0.274)	Loss 4.63254 (4.63254)	d_Loss 1.46756 (1.46756)
=>NAG Epoch: 157 Error: 4.632535457611084, Time:     0.01m
=>epoch:158/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [158][0/1]	Time  0.303 ( 0.303)	Data  0.270 ( 0.270)	Loss 4.62157 (4.62157)	d_Loss 1.46701 (1.46701)
=>NAG Epoch: 158 Error: 4.6215667724609375, Time:     0.01m
=>epoch:159/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [159][0/1]	Time  0.326 ( 0.326)	Data  0.296 ( 0.296)	Loss 4.61717 (4.61717)	d_Loss 1.46781 (1.46781)
=>NAG Epoch: 159 Error: 4.617172718048096, Time:     0.01m
=>epoch:160/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [160][0/1]	Time  0.297 ( 0.297)	Data  0.268 ( 0.268)	Loss 4.60848 (4.60848)	d_Loss 1.46672 (1.46

Epoch: [187][0/1]	Time  0.306 ( 0.306)	Data  0.274 ( 0.274)	Loss 4.44216 (4.44216)	d_Loss 1.46708 (1.46708)
=>NAG Epoch: 187 Error: 4.442156791687012, Time:     0.01m
=>epoch:188/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [188][0/1]	Time  0.308 ( 0.308)	Data  0.277 ( 0.277)	Loss 4.43301 (4.43301)	d_Loss 1.46609 (1.46609)
=>NAG Epoch: 188 Error: 4.433007717132568, Time:     0.01m
=>epoch:189/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [189][0/1]	Time  0.296 ( 0.296)	Data  0.265 ( 0.265)	Loss 4.42970 (4.42970)	d_Loss 1.46686 (1.46686)
=>NAG Epoch: 189 Error: 4.429704189300537, Time:     0.01m
=>epoch:190/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [190][0/1]	Time  0.294 ( 0.294)	Data  0.263 ( 0.263)	Loss 4.42305 (4.42305)	d_Loss 1.46642 (1.46642)
=>NAG Epoch: 190 Error: 4.42304801940918, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_190.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_190.pkl
Save check

Epoch: [218][0/1]	Time  0.288 ( 0.288)	Data  0.258 ( 0.258)	Loss 4.24063 (4.24063)	d_Loss 1.46568 (1.46568)
=>NAG Epoch: 218 Error: 4.240632057189941, Time:     0.01m
=>epoch:219/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [219][0/1]	Time  0.288 ( 0.288)	Data  0.259 ( 0.259)	Loss 4.23799 (4.23799)	d_Loss 1.46576 (1.46576)
=>NAG Epoch: 219 Error: 4.237991809844971, Time:     0.01m
=>epoch:220/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [220][0/1]	Time  0.290 ( 0.290)	Data  0.260 ( 0.260)	Loss 4.23094 (4.23094)	d_Loss 1.46573 (1.46573)
=>NAG Epoch: 220 Error: 4.230940818786621, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_220.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_220.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_220.pkl
=>epoch:221/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [221][0/1]	Time  0.288 ( 0.288)	Data  0.258 ( 0.258)	Loss 4.23052 (4.23052)	d_Loss 1.46570 (1

Epoch: [249][0/1]	Time  0.290 ( 0.290)	Data  0.261 ( 0.261)	Loss 4.12073 (4.12073)	d_Loss 1.46525 (1.46525)
=>NAG Epoch: 249 Error: 4.120730400085449, Time:     0.01m
=>epoch:250/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [250][0/1]	Time  0.296 ( 0.296)	Data  0.266 ( 0.266)	Loss 4.11594 (4.11594)	d_Loss 1.46522 (1.46522)
=>NAG Epoch: 250 Error: 4.115941047668457, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_250.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_250.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_250.pkl
Show images...
=>epoch:251/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [251][0/1]	Time  0.328 ( 0.328)	Data  0.297 ( 0.297)	Loss 4.11327 (4.11327)	d_Loss 1.46516 (1.46516)
=>NAG Epoch: 251 Error: 4.1132731437683105, Time:     0.01m
=>epoch:252/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [252][0/1]	Time  0.301 ( 0.301)	Data  0.272 ( 0.272)	Loss 4.11048 (4.11048)	d

Epoch: [280][0/1]	Time  0.298 ( 0.298)	Data  0.267 ( 0.267)	Loss 4.01472 (4.01472)	d_Loss 1.46492 (1.46492)
=>NAG Epoch: 280 Error: 4.014720439910889, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_280.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_280.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_280.pkl
=>epoch:281/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [281][0/1]	Time  0.293 ( 0.293)	Data  0.263 ( 0.263)	Loss 3.99687 (3.99687)	d_Loss 1.46486 (1.46486)
=>NAG Epoch: 281 Error: 3.9968743324279785, Time:     0.01m
=>epoch:282/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [282][0/1]	Time  0.294 ( 0.294)	Data  0.265 ( 0.265)	Loss 3.99227 (3.99227)	d_Loss 1.46485 (1.46485)
=>NAG Epoch: 282 Error: 3.9922726154327393, Time:     0.01m
=>epoch:283/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [283][0/1]	Time  0.296 ( 0.296)	Data  0.267 ( 0.267)	Loss 3.99017 (3.99017)	d_Loss 1.46471 

Epoch: [311][0/1]	Time  0.290 ( 0.290)	Data  0.260 ( 0.260)	Loss 3.91573 (3.91573)	d_Loss 1.46452 (1.46452)
=>NAG Epoch: 311 Error: 3.9157261848449707, Time:     0.01m
=>epoch:312/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [312][0/1]	Time  0.288 ( 0.288)	Data  0.258 ( 0.258)	Loss 3.91328 (3.91328)	d_Loss 1.46458 (1.46458)
=>NAG Epoch: 312 Error: 3.9132790565490723, Time:     0.01m
=>epoch:313/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [313][0/1]	Time  0.287 ( 0.287)	Data  0.257 ( 0.257)	Loss 3.91068 (3.91068)	d_Loss 1.46443 (1.46443)
=>NAG Epoch: 313 Error: 3.910681962966919, Time:     0.01m
=>epoch:314/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [314][0/1]	Time  0.288 ( 0.288)	Data  0.258 ( 0.258)	Loss 3.90652 (3.90652)	d_Loss 1.46461 (1.46461)
=>NAG Epoch: 314 Error: 3.9065167903900146, Time:     0.01m
=>epoch:315/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [315][0/1]	Time  0.289 ( 0.289)	Data  0.258 ( 0.258)	Loss 3.90409 (3.90409)	d_Loss 1.46455 (1.

Epoch: [342][0/1]	Time  0.290 ( 0.290)	Data  0.260 ( 0.260)	Loss 3.84218 (3.84218)	d_Loss 1.46448 (1.46448)
=>NAG Epoch: 342 Error: 3.8421757221221924, Time:     0.01m
=>epoch:343/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [343][0/1]	Time  0.296 ( 0.296)	Data  0.266 ( 0.266)	Loss 3.84105 (3.84105)	d_Loss 1.46431 (1.46431)
=>NAG Epoch: 343 Error: 3.8410491943359375, Time:     0.01m
=>epoch:344/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [344][0/1]	Time  0.292 ( 0.292)	Data  0.261 ( 0.261)	Loss 3.83956 (3.83956)	d_Loss 1.46431 (1.46431)
=>NAG Epoch: 344 Error: 3.8395557403564453, Time:     0.01m
=>epoch:345/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [345][0/1]	Time  0.296 ( 0.296)	Data  0.266 ( 0.266)	Loss 3.83684 (3.83684)	d_Loss 1.46435 (1.46435)
=>NAG Epoch: 345 Error: 3.836836814880371, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_345.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_345.pkl
Save c

Epoch: [373][0/1]	Time  0.309 ( 0.309)	Data  0.279 ( 0.279)	Loss 3.78389 (3.78389)	d_Loss 1.46407 (1.46407)
=>NAG Epoch: 373 Error: 3.7838902473449707, Time:     0.01m
=>epoch:374/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [374][0/1]	Time  0.289 ( 0.289)	Data  0.260 ( 0.260)	Loss 3.78253 (3.78253)	d_Loss 1.46401 (1.46401)
=>NAG Epoch: 374 Error: 3.7825264930725098, Time:     0.01m
=>epoch:375/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [375][0/1]	Time  0.296 ( 0.296)	Data  0.265 ( 0.265)	Loss 3.78282 (3.78282)	d_Loss 1.46413 (1.46413)
=>NAG Epoch: 375 Error: 3.7828168869018555, Time:     0.01m
=>consecutive_loss: 1
=>epoch:376/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [376][0/1]	Time  0.292 ( 0.292)	Data  0.262 ( 0.262)	Loss 3.77901 (3.77901)	d_Loss 1.46398 (1.46398)
=>NAG Epoch: 376 Error: 3.7790122032165527, Time:     0.01m
=>epoch:377/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [377][0/1]	Time  0.295 ( 0.295)	Data  0.265 ( 0.265)	Loss 3.77987 (3.77

Epoch: [404][0/1]	Time  0.292 ( 0.292)	Data  0.263 ( 0.263)	Loss 3.74074 (3.74074)	d_Loss 1.46375 (1.46375)
=>NAG Epoch: 404 Error: 3.740741014480591, Time:     0.01m
=>epoch:405/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [405][0/1]	Time  0.290 ( 0.290)	Data  0.260 ( 0.260)	Loss 3.73724 (3.73724)	d_Loss 1.46385 (1.46385)
=>NAG Epoch: 405 Error: 3.737241268157959, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_405.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_405.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_405.pkl
=>epoch:406/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [406][0/1]	Time  0.297 ( 0.297)	Data  0.267 ( 0.267)	Loss 3.73622 (3.73622)	d_Loss 1.46379 (1.46379)
=>NAG Epoch: 406 Error: 3.736215353012085, Time:     0.01m
=>epoch:407/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [407][0/1]	Time  0.308 ( 0.308)	Data  0.278 ( 0.278)	Loss 3.73579 (3.73579)	d_Loss 1.46386 (1

Epoch: [435][0/1]	Time  0.297 ( 0.297)	Data  0.266 ( 0.266)	Loss 3.70150 (3.70150)	d_Loss 1.46369 (1.46369)
=>NAG Epoch: 435 Error: 3.7015023231506348, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_435.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_435.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_435.pkl
=>epoch:436/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [436][0/1]	Time  0.304 ( 0.304)	Data  0.274 ( 0.274)	Loss 3.69863 (3.69863)	d_Loss 1.46365 (1.46365)
=>NAG Epoch: 436 Error: 3.698633909225464, Time:     0.01m
=>epoch:437/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [437][0/1]	Time  0.293 ( 0.293)	Data  0.263 ( 0.263)	Loss 3.69972 (3.69972)	d_Loss 1.46359 (1.46359)
=>NAG Epoch: 437 Error: 3.6997151374816895, Time:     0.01m
=>consecutive_loss: 1
=>epoch:438/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [438][0/1]	Time  0.293 ( 0.293)	Data  0.263 ( 0.263)	Loss 3.69948 (3.

Epoch: [465][0/1]	Time  0.289 ( 0.289)	Data  0.259 ( 0.259)	Loss 3.67273 (3.67273)	d_Loss 1.46351 (1.46351)
=>NAG Epoch: 465 Error: 3.6727278232574463, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_465.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_465.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_465.pkl
=>epoch:466/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [466][0/1]	Time  0.302 ( 0.302)	Data  0.272 ( 0.272)	Loss 3.67226 (3.67226)	d_Loss 1.46336 (1.46336)
=>NAG Epoch: 466 Error: 3.6722590923309326, Time:     0.01m
=>epoch:467/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [467][0/1]	Time  0.295 ( 0.295)	Data  0.264 ( 0.264)	Loss 3.67205 (3.67205)	d_Loss 1.46348 (1.46348)
=>NAG Epoch: 467 Error: 3.672053337097168, Time:     0.01m
=>epoch:468/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [468][0/1]	Time  0.296 ( 0.296)	Data  0.266 ( 0.266)	Loss 3.66986 (3.66986)	d_Loss 1.46335 

Epoch: [497][0/1]	Time  0.299 ( 0.299)	Data  0.269 ( 0.269)	Loss 3.64712 (3.64712)	d_Loss 1.46328 (1.46328)
=>NAG Epoch: 497 Error: 3.6471171379089355, Time:     0.01m
=>consecutive_loss: 1
=>epoch:498/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [498][0/1]	Time  0.289 ( 0.289)	Data  0.259 ( 0.259)	Loss 3.64733 (3.64733)	d_Loss 1.46324 (1.46324)
=>NAG Epoch: 498 Error: 3.6473255157470703, Time:     0.01m
=>consecutive_loss: 2
=>epoch:499/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [499][0/1]	Time  0.288 ( 0.288)	Data  0.259 ( 0.259)	Loss 3.64619 (3.64619)	d_Loss 1.46330 (1.46330)
=>NAG Epoch: 499 Error: 3.646191358566284, Time:     0.01m
=>consecutive_loss: 3
=>epoch:500/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [500][0/1]	Time  0.289 ( 0.289)	Data  0.259 ( 0.259)	Loss 3.64407 (3.64407)	d_Loss 1.46332 (1.46332)
=>NAG Epoch: 500 Error: 3.6440749168395996, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_500.pkl
Save checkpoints

Epoch: [528][0/1]	Time  0.288 ( 0.288)	Data  0.259 ( 0.259)	Loss 3.63132 (3.63132)	d_Loss 1.46316 (1.46316)
=>NAG Epoch: 528 Error: 3.6313247680664062, Time:     0.01m
=>consecutive_loss: 2
=>epoch:529/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [529][0/1]	Time  0.291 ( 0.291)	Data  0.260 ( 0.260)	Loss 3.63063 (3.63063)	d_Loss 1.46313 (1.46313)
=>NAG Epoch: 529 Error: 3.6306283473968506, Time:     0.01m
=>epoch:530/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [530][0/1]	Time  0.291 ( 0.291)	Data  0.261 ( 0.261)	Loss 3.62941 (3.62941)	d_Loss 1.46315 (1.46315)
=>NAG Epoch: 530 Error: 3.629409074783325, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_530.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_530.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_530.pkl
=>epoch:531/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [531][0/1]	Time  0.286 ( 0.286)	Data  0.256 ( 0.256)	Loss 3.62901 (3.

Epoch: [560][0/1]	Time  0.305 ( 0.305)	Data  0.274 ( 0.274)	Loss 3.61364 (3.61364)	d_Loss 1.46303 (1.46303)
=>NAG Epoch: 560 Error: 3.613635540008545, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_560.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_560.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_560.pkl
=>epoch:561/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [561][0/1]	Time  0.293 ( 0.293)	Data  0.263 ( 0.263)	Loss 3.61360 (3.61360)	d_Loss 1.46297 (1.46297)
=>NAG Epoch: 561 Error: 3.613603353500366, Time:     0.01m
=>epoch:562/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [562][0/1]	Time  0.301 ( 0.301)	Data  0.270 ( 0.270)	Loss 3.61384 (3.61384)	d_Loss 1.46295 (1.46295)
=>NAG Epoch: 562 Error: 3.6138358116149902, Time:     0.01m
=>consecutive_loss: 1
=>epoch:563/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [563][0/1]	Time  0.290 ( 0.290)	Data  0.260 ( 0.260)	Loss 3.61189 (3.6

Epoch: [592][0/1]	Time  0.291 ( 0.291)	Data  0.262 ( 0.262)	Loss 3.60584 (3.60584)	d_Loss 1.46277 (1.46277)
=>NAG Epoch: 592 Error: 3.6058402061462402, Time:     0.01m
=>consecutive_loss: 3
=>epoch:593/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [593][0/1]	Time  0.290 ( 0.290)	Data  0.261 ( 0.261)	Loss 3.60599 (3.60599)	d_Loss 1.46284 (1.46284)
=>NAG Epoch: 593 Error: 3.605985403060913, Time:     0.01m
=>consecutive_loss: 4
=>epoch:594/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [594][0/1]	Time  0.288 ( 0.288)	Data  0.259 ( 0.259)	Loss 3.60406 (3.60406)	d_Loss 1.46287 (1.46287)
=>NAG Epoch: 594 Error: 3.604062795639038, Time:     0.01m
=>epoch:595/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [595][0/1]	Time  0.300 ( 0.300)	Data  0.270 ( 0.270)	Loss 3.60455 (3.60455)	d_Loss 1.46281 (1.46281)
=>NAG Epoch: 595 Error: 3.6045451164245605, Time:     0.01m
=>consecutive_loss: 1
=>epoch:596/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [596][0/1]	Time  0.298 ( 0.29

Epoch: [626][0/1]	Time  0.300 ( 0.300)	Data  0.270 ( 0.270)	Loss 3.59499 (3.59499)	d_Loss 1.46268 (1.46268)
=>NAG Epoch: 626 Error: 3.5949912071228027, Time:     0.01m
=>epoch:627/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [627][0/1]	Time  0.299 ( 0.299)	Data  0.268 ( 0.268)	Loss 3.59770 (3.59770)	d_Loss 1.46261 (1.46261)
=>NAG Epoch: 627 Error: 3.5977025032043457, Time:     0.01m
=>consecutive_loss: 1
=>epoch:628/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [628][0/1]	Time  0.301 ( 0.301)	Data  0.271 ( 0.271)	Loss 3.59483 (3.59483)	d_Loss 1.46274 (1.46274)
=>NAG Epoch: 628 Error: 3.5948326587677, Time:     0.01m
=>epoch:629/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [629][0/1]	Time  0.302 ( 0.302)	Data  0.270 ( 0.270)	Loss 3.59701 (3.59701)	d_Loss 1.46262 (1.46262)
=>NAG Epoch: 629 Error: 3.597005844116211, Time:     0.01m
=>consecutive_loss: 1
=>epoch:630/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [630][0/1]	Time  0.310 ( 0.310)	Data  0.280 ( 0.280)	

Epoch: [659][0/1]	Time  0.290 ( 0.290)	Data  0.261 ( 0.261)	Loss 3.59238 (3.59238)	d_Loss 1.46248 (1.46248)
=>NAG Epoch: 659 Error: 3.5923776626586914, Time:     0.01m
=>consecutive_loss: 10
=>epoch:660/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [660][0/1]	Time  0.306 ( 0.306)	Data  0.274 ( 0.274)	Loss 3.59066 (3.59066)	d_Loss 1.46248 (1.46248)
=>NAG Epoch: 660 Error: 3.5906574726104736, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_660.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_660.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_660.pkl
=>epoch:661/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [661][0/1]	Time  0.298 ( 0.298)	Data  0.267 ( 0.267)	Loss 3.58993 (3.58993)	d_Loss 1.46253 (1.46253)
=>NAG Epoch: 661 Error: 3.5899322032928467, Time:     0.01m
=>epoch:662/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [662][0/1]	Time  0.313 ( 0.313)	Data  0.283 ( 0.283)	Loss 3.59125 (

Epoch: [690][0/1]	Time  0.290 ( 0.290)	Data  0.260 ( 0.260)	Loss 3.58800 (3.58800)	d_Loss 1.46243 (1.46243)
=>NAG Epoch: 690 Error: 3.5879971981048584, Time:     0.01m
=>consecutive_loss: 2
=>epoch:691/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [691][0/1]	Time  0.303 ( 0.303)	Data  0.273 ( 0.273)	Loss 3.58690 (3.58690)	d_Loss 1.46243 (1.46243)
=>NAG Epoch: 691 Error: 3.586901903152466, Time:     0.01m
=>consecutive_loss: 3
=>epoch:692/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [692][0/1]	Time  0.291 ( 0.291)	Data  0.261 ( 0.261)	Loss 3.58552 (3.58552)	d_Loss 1.46244 (1.46244)
=>NAG Epoch: 692 Error: 3.585519552230835, Time:     0.01m
=>epoch:693/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [693][0/1]	Time  0.289 ( 0.289)	Data  0.260 ( 0.260)	Loss 3.58700 (3.58700)	d_Loss 1.46235 (1.46235)
=>NAG Epoch: 693 Error: 3.58699893951416, Time:     0.01m
=>consecutive_loss: 1
=>epoch:694/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [694][0/1]	Time  0.290 ( 0.290)

Epoch: [724][0/1]	Time  0.296 ( 0.296)	Data  0.266 ( 0.266)	Loss 3.58417 (3.58417)	d_Loss 1.46233 (1.46233)
=>NAG Epoch: 724 Error: 3.5841710567474365, Time:     0.01m
=>consecutive_loss: 10
=>epoch:725/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [725][0/1]	Time  0.289 ( 0.289)	Data  0.260 ( 0.260)	Loss 3.58429 (3.58429)	d_Loss 1.46228 (1.46228)
=>NAG Epoch: 725 Error: 3.5842902660369873, Time:     0.01m
=>consecutive_loss: 11
=>epoch:726/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [726][0/1]	Time  0.290 ( 0.290)	Data  0.260 ( 0.260)	Loss 3.58444 (3.58444)	d_Loss 1.46227 (1.46227)
=>NAG Epoch: 726 Error: 3.5844411849975586, Time:     0.01m
=>consecutive_loss: 12
=>epoch:727/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [727][0/1]	Time  0.295 ( 0.295)	Data  0.265 ( 0.265)	Loss 3.58311 (3.58311)	d_Loss 1.46227 (1.46227)
=>NAG Epoch: 727 Error: 3.583106517791748, Time:     0.01m
=>consecutive_loss: 13
=>epoch:728/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [7

In [6]:
import importlib
import interpolate
importlib.reload(interpolate)
import interpolate
from interpolate import interpolate_points


def getGenImage(z1, z2, gen, steps, codeSize, noiseStd=0.15):
    """
    Summary: Generate images along the interpolation path between two latent codes
    == params ==
    z1, z2: latent code tensors to interpolate between
    gen: generator network, called as gen(interpolatedLatents, noiseCode)
    steps: number of interpolation points requested from interpolate_points;
           also the number of rows in the noise code
    codeSize: width of each noise-code row fed to the generator
    noiseStd: standard deviation of the Gaussian noise code (default 0.15)
    == output ==
    whatever gen returns for the interpolated latents (presumably one image per step)
    """
    interp = interpolate_points(z1, z2, n_steps=steps, slerp=True, print_mode=False)
    # Draw the noise on the latents' device instead of hard-coding the deprecated
    # torch.cuda.FloatTensor (which always targets the *current* CUDA device).
    device = z1.device if torch.is_tensor(z1) else torch.device("cuda")
    code = torch.empty(steps, codeSize, device=device).normal_(0, noiseStd)
    return gen(interp, code)

In [7]:
import time
import itertools
class GlicoLoader:
    """
    Summary: Serves GLICO interpolated images as replacements for real training images.

    At construction, every pair of latent codes belonging to the same class
    (taken from netZ.label2idx) is interpolated with getGenImage and the
    generated images are cached per class in self.interps. sample() then draws
    from a class's cache without replacement, reshuffling that class's order
    once all of its images have been returned.
    """

    def __init__(self, nagTrainer, steps, numClasses=10, seed=None):
        """
        == params ==
        nagTrainer: trainer whose .nag exposes netZ, netG and code_size
        steps: number of interpolation steps generated per pair of latent codes
        numClasses: number of classes (labels 0..numClasses-1) to build interpolations for
        seed: optional seed for the sampling order; None gives a fresh unpredictable seed
        """
        self.nagTrainer = nagTrainer
        self.nag = nagTrainer.nag
        self.netZ = self.nag.netZ
        self.netG = self.nag.netG
        self.steps = steps
        self.numClasses = numClasses
        self.getInterpolations()
        self._initSampler(seed)

    def getInterpolations(self):
        """Generate and cache (in self.interps[classNum]) interpolated images for every class."""
        self.interps = {}
        # The cached images are used as fixed inputs only (they are detached),
        # so skip building autograd graphs for the many generator passes.
        with torch.no_grad():
            for classNum in range(self.numClasses):
                pairs = list(itertools.combinations(self.netZ.label2idx[classNum], 2))
                if not pairs:
                    raise ValueError(
                        "class {} needs at least two latent codes to interpolate".format(classNum))
                zvecs = self.netZ(torch.tensor(pairs).cuda())
                images = [
                    getGenImage(zvecs[i, 0, :], zvecs[i, 1, :], self.netG,
                                steps=self.steps, codeSize=self.nag.code_size).detach()
                    for i in range(zvecs.size(0))
                ]
                self.interps[classNum] = torch.cat(images, dim=0)

    def _initSampler(self, seed=None):
        """Reset sampling state: one RNG, plus a shuffled draw order and a cursor per class."""
        self._prng = RandomState(seed)
        self.indices = [0] * len(self.interps)
        self.indexers = [self._shuffledOrder(c) for c in range(len(self.interps))]

    def _shuffledOrder(self, classNum):
        """Return a random permutation of the cached image indices of one class."""
        return list(self._prng.permutation(self.interps[classNum].size(0)))

    def sample(self, classNum):
        """
        Return the next cached interpolated image of class classNum.

        Images are drawn without replacement; once every image of the class has
        been returned, its order is reshuffled and drawing starts over.
        """
        # Each class has its own cache size, so compare against this class's count.
        if self.indices[classNum] >= self.interps[classNum].size(0):
            self.indices[classNum] = 0
            self.indexers[classNum] = self._shuffledOrder(classNum)
        pick = self.indexers[classNum][self.indices[classNum]]
        # Advance the cursor so successive calls walk through the whole permutation.
        self.indices[classNum] += 1
        return self.interps[classNum][int(pick), :, :, :]

    def replaceBatch(self, batch, targets, replaceProb):
        """
        Replace, in place, each image batch[i] with probability replaceProb by a
        generated image of the same class targets[i].
        == output ==
        the (modified) batch
        """
        for i in range(batch.size(0)):
            if np.random.rand() < replaceProb:
                batch[i, :, :, :] = self.sample(targets[i].item())
        return batch
        
        
        
        

In [8]:
# Build the GLICO replacement-image cache: 5 interpolation steps per latent pair.
glicoLoader = GlicoLoader(nagTrainer, steps=5)

In [9]:
import matplotlib

# CIFAR-10 class name -> label index.
c_to_idx = {"airplane": 0,
            "automobile": 1,
            "bird": 2,
            "cat": 3,
            "deer": 4,
            "dog": 5,
            "frog": 6,
            "horse": 7,
            "ship": 8,
            "truck": 9}

# Inverse lookup: label index -> class name.
idx_to_c = {v: k for k, v in c_to_idx.items()}


def saveClassSamples(loader, rounds=5, save=True, outdir="."):
    """
    Summary: Preview (or save as PDF) one GLICO sample per class, `rounds` times.
    Nothing runs unless this function is called explicitly.
    == params ==
    loader: object with .sample(classNum) returning a CHW image tensor (e.g. glicoLoader)
    rounds: how many samples to draw per class
    save: if True, write each image to <outdir>/<classname><round>.pdf; otherwise show it
    outdir: directory for the saved PDFs
    """
    for round_ in range(rounds):
        for classNum in range(len(idx_to_c)):
            image = loader.sample(classNum=classNum)
            fig, ax = plt.subplots()
            ax.imshow(image.cpu().permute(1, 2, 0))
            ax.axis("off")
            if save:
                fig.savefig(os.path.join(outdir, "{}{}.pdf".format(idx_to_c[classNum], round_)),
                            bbox_inches="tight", pad_inches=0)
            else:
                plt.show()
            # Close every figure so images are not drawn on top of each other
            # and open figures do not accumulate across the loop.
            plt.close(fig)

In [10]:

model = getModel(MODEL).cuda()
optimizer, LR = getOptimizer128(OPTIM, model.parameters())

# Only parameters with requires_grad=True are trainable; count those to match the label.
print(' => Total trainable parameters: %.2fM'
      % (sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000.0))

trainTracker = {"meanLoss": [], "accuracy": []}
valTracker = {"allLoss": [], "allAcc": [], "meanLoss": [], "meanAcc": [], "stdAcc": []}
latexTracker = []

print("Begin Train for {} epochs".format(EPOCH_NUM))
for epoch in range(EPOCH_NUM):
    # NOTE(review): assumes glicoTrain returns (accuracy, loss) in that order - confirm;
    # train() in the first cell documents its output as (mean loss, accuracy).
    acc, loss = glicoTrain(model, device, train_data[0], optimizer,
                           epoch + 1, glicoLoader, replaceProb=REPLACE_PROB, display=True)
    trainTracker["meanLoss"].append(loss)
    trainTracker["accuracy"].append(acc)

    # Validate after the first epoch, then every VAL_DISPLAY_DIVISOR epochs.
    if (epoch + 1) % VAL_DISPLAY_DIVISOR == 0 or epoch == 0:
        checkTest(model, device, valSets, valTracker, latexTracker, epoch + 1,
                  model_name=MODEL, optim_name=OPTIM, lr=LR,
                  totalTestSamples=VAL_SAMPLE_NUM * VALIDATION_SET_NUM,
                  seed=seed, verbose=True)
        
          
        
        

 => Total trainable parameters: 36.48M
Begin Train for 2000 epochs
Train Epoch: 1 [acc: 7%]	Loss: 2.310721
Train Epoch: 2 [acc: 13%]	Loss: 2.516140
Train Epoch: 3 [acc: 13%]	Loss: 2.499593
Train Epoch: 4 [acc: 13%]	Loss: 2.367805
Train Epoch: 5 [acc: 18%]	Loss: 2.383141
Train Epoch: 6 [acc: 17%]	Loss: 2.221243
Train Epoch: 7 [acc: 19%]	Loss: 2.216757
Train Epoch: 8 [acc: 24%]	Loss: 2.215537
Train Epoch: 9 [acc: 17%]	Loss: 2.174945
Train Epoch: 10 [acc: 15%]	Loss: 2.187501
Train Epoch: 11 [acc: 18%]	Loss: 2.144031
Train Epoch: 12 [acc: 23%]	Loss: 2.054556
Train Epoch: 13 [acc: 26%]	Loss: 2.151839
Train Epoch: 14 [acc: 17%]	Loss: 2.124151
Train Epoch: 15 [acc: 17%]	Loss: 2.195880
Train Epoch: 16 [acc: 24%]	Loss: 2.128168
Train Epoch: 17 [acc: 27%]	Loss: 2.124317
Train Epoch: 18 [acc: 30%]	Loss: 2.045443
Train Epoch: 19 [acc: 23%]	Loss: 2.133019
Train Epoch: 20 [acc: 27%]	Loss: 2.000544
Train Epoch: 21 [acc: 22%]	Loss: 1.996009
Train Epoch: 22 [acc: 25%]	Loss: 2.046394
Train Epoch: 23 [ac



[Trained for 25 epochs and tested on 5 sets of 2000 images]        Avg Acc: 16.11 +- 0.43 , Avg Loss: 4.98
Train Epoch: 26 [acc: 20%]	Loss: 2.116291
Train Epoch: 27 [acc: 23%]	Loss: 2.069545
Train Epoch: 28 [acc: 23%]	Loss: 2.121457
Train Epoch: 29 [acc: 27%]	Loss: 2.002287
Train Epoch: 30 [acc: 19%]	Loss: 2.124004
Train Epoch: 31 [acc: 21%]	Loss: 2.069081
Train Epoch: 32 [acc: 24%]	Loss: 2.072709
Train Epoch: 33 [acc: 21%]	Loss: 2.114552
Train Epoch: 34 [acc: 19%]	Loss: 2.056277
Train Epoch: 35 [acc: 21%]	Loss: 2.025362
Train Epoch: 36 [acc: 22%]	Loss: 1.961313
Train Epoch: 37 [acc: 23%]	Loss: 2.024331
Train Epoch: 38 [acc: 23%]	Loss: 1.964233
Train Epoch: 39 [acc: 25%]	Loss: 2.103429
Train Epoch: 40 [acc: 35%]	Loss: 1.966812
Train Epoch: 41 [acc: 32%]	Loss: 1.965144
Train Epoch: 42 [acc: 24%]	Loss: 2.013800
Train Epoch: 43 [acc: 28%]	Loss: 1.985408
Train Epoch: 44 [acc: 29%]	Loss: 1.889034
Train Epoch: 45 [acc: 35%]	Loss: 1.959948
Train Epoch: 46 [acc: 31%]	Loss: 1.979236
Train Epoch

[Trained for 200 epochs and tested on 5 sets of 2000 images]        Avg Acc: 28.62 +- 0.91 , Avg Loss: 3.03
Train Epoch: 201 [acc: 54%]	Loss: 1.424509
Train Epoch: 202 [acc: 50%]	Loss: 1.508255
Train Epoch: 203 [acc: 44%]	Loss: 1.575523
Train Epoch: 204 [acc: 48%]	Loss: 1.422364
Train Epoch: 205 [acc: 48%]	Loss: 1.377148
Train Epoch: 206 [acc: 49%]	Loss: 1.414764
Train Epoch: 207 [acc: 40%]	Loss: 1.554871
Train Epoch: 208 [acc: 48%]	Loss: 1.445502
Train Epoch: 209 [acc: 52%]	Loss: 1.471260
Train Epoch: 210 [acc: 51%]	Loss: 1.464094
Train Epoch: 211 [acc: 52%]	Loss: 1.395939
Train Epoch: 212 [acc: 54%]	Loss: 1.533380
Train Epoch: 213 [acc: 53%]	Loss: 1.476205
Train Epoch: 214 [acc: 53%]	Loss: 1.431048
Train Epoch: 215 [acc: 47%]	Loss: 1.399814
Train Epoch: 216 [acc: 48%]	Loss: 1.541440
Train Epoch: 217 [acc: 44%]	Loss: 1.533271
Train Epoch: 218 [acc: 45%]	Loss: 1.582577
Train Epoch: 219 [acc: 47%]	Loss: 1.486198
Train Epoch: 220 [acc: 45%]	Loss: 1.440551
Train Epoch: 221 [acc: 47%]	Loss

Train Epoch: 374 [acc: 75%]	Loss: 0.707633
Train Epoch: 375 [acc: 77%]	Loss: 0.777782
[Trained for 375 epochs and tested on 5 sets of 2000 images]        Avg Acc: 28.88 +- 0.58 , Avg Loss: 3.39
Train Epoch: 376 [acc: 69%]	Loss: 0.877320
Train Epoch: 377 [acc: 75%]	Loss: 0.729101
Train Epoch: 378 [acc: 75%]	Loss: 0.773058
Train Epoch: 379 [acc: 71%]	Loss: 0.849803
Train Epoch: 380 [acc: 79%]	Loss: 0.662570
Train Epoch: 381 [acc: 76%]	Loss: 0.800932
Train Epoch: 382 [acc: 77%]	Loss: 0.716508
Train Epoch: 383 [acc: 72%]	Loss: 0.803547
Train Epoch: 384 [acc: 76%]	Loss: 0.758541
Train Epoch: 385 [acc: 81%]	Loss: 0.665457
Train Epoch: 386 [acc: 82%]	Loss: 0.667425
Train Epoch: 387 [acc: 72%]	Loss: 0.879336
Train Epoch: 388 [acc: 73%]	Loss: 0.869117
Train Epoch: 389 [acc: 78%]	Loss: 0.822597
Train Epoch: 390 [acc: 76%]	Loss: 0.775286
Train Epoch: 391 [acc: 71%]	Loss: 0.850131
Train Epoch: 392 [acc: 73%]	Loss: 0.819240
Train Epoch: 393 [acc: 77%]	Loss: 0.694973
Train Epoch: 394 [acc: 62%]	Loss

Train Epoch: 547 [acc: 85%]	Loss: 0.455451
Train Epoch: 548 [acc: 85%]	Loss: 0.457053
Train Epoch: 549 [acc: 81%]	Loss: 0.630164
Train Epoch: 550 [acc: 90%]	Loss: 0.388346
[Trained for 550 epochs and tested on 5 sets of 2000 images]        Avg Acc: 32.48 +- 0.83 , Avg Loss: 3.49
Train Epoch: 551 [acc: 87%]	Loss: 0.425309
Train Epoch: 552 [acc: 76%]	Loss: 0.681576
Train Epoch: 553 [acc: 84%]	Loss: 0.437723
Train Epoch: 554 [acc: 88%]	Loss: 0.430569
Train Epoch: 555 [acc: 84%]	Loss: 0.480617
Train Epoch: 556 [acc: 82%]	Loss: 0.519431
Train Epoch: 557 [acc: 85%]	Loss: 0.490569
Train Epoch: 558 [acc: 85%]	Loss: 0.446392
Train Epoch: 559 [acc: 85%]	Loss: 0.444860
Train Epoch: 560 [acc: 89%]	Loss: 0.416115
Train Epoch: 561 [acc: 87%]	Loss: 0.497296
Train Epoch: 562 [acc: 81%]	Loss: 0.579522
Train Epoch: 563 [acc: 85%]	Loss: 0.498872
Train Epoch: 564 [acc: 82%]	Loss: 0.680264
Train Epoch: 565 [acc: 87%]	Loss: 0.451224
Train Epoch: 566 [acc: 86%]	Loss: 0.376970
Train Epoch: 567 [acc: 86%]	Loss

Train Epoch: 720 [acc: 86%]	Loss: 0.434215
Train Epoch: 721 [acc: 90%]	Loss: 0.344665
Train Epoch: 722 [acc: 83%]	Loss: 0.465253
Train Epoch: 723 [acc: 92%]	Loss: 0.263638
Train Epoch: 724 [acc: 89%]	Loss: 0.382008
Train Epoch: 725 [acc: 86%]	Loss: 0.465212
[Trained for 725 epochs and tested on 5 sets of 2000 images]        Avg Acc: 31.25 +- 0.94 , Avg Loss: 3.99
Train Epoch: 726 [acc: 86%]	Loss: 0.444160
Train Epoch: 727 [acc: 91%]	Loss: 0.248023
Train Epoch: 728 [acc: 93%]	Loss: 0.221045
Train Epoch: 729 [acc: 88%]	Loss: 0.353692
Train Epoch: 730 [acc: 87%]	Loss: 0.403658
Train Epoch: 731 [acc: 83%]	Loss: 0.528544
Train Epoch: 732 [acc: 89%]	Loss: 0.394739
Train Epoch: 733 [acc: 88%]	Loss: 0.380801
Train Epoch: 734 [acc: 90%]	Loss: 0.327715
Train Epoch: 735 [acc: 89%]	Loss: 0.349909
Train Epoch: 736 [acc: 89%]	Loss: 0.308402
Train Epoch: 737 [acc: 87%]	Loss: 0.369115
Train Epoch: 738 [acc: 88%]	Loss: 0.395132
Train Epoch: 739 [acc: 95%]	Loss: 0.214806
Train Epoch: 740 [acc: 89%]	Loss

Train Epoch: 893 [acc: 89%]	Loss: 0.305837
Train Epoch: 894 [acc: 94%]	Loss: 0.242398
Train Epoch: 895 [acc: 85%]	Loss: 0.410225
Train Epoch: 896 [acc: 87%]	Loss: 0.455984
Train Epoch: 897 [acc: 92%]	Loss: 0.223403
Train Epoch: 898 [acc: 86%]	Loss: 0.405915
Train Epoch: 899 [acc: 89%]	Loss: 0.337764
Train Epoch: 900 [acc: 91%]	Loss: 0.328697
[Trained for 900 epochs and tested on 5 sets of 2000 images]        Avg Acc: 33.58 +- 0.56 , Avg Loss: 3.80
Train Epoch: 901 [acc: 88%]	Loss: 0.365829
Train Epoch: 902 [acc: 91%]	Loss: 0.280500
Train Epoch: 903 [acc: 86%]	Loss: 0.400793
Train Epoch: 904 [acc: 91%]	Loss: 0.299073
Train Epoch: 905 [acc: 85%]	Loss: 0.388184
Train Epoch: 906 [acc: 91%]	Loss: 0.259754
Train Epoch: 907 [acc: 88%]	Loss: 0.294186
Train Epoch: 908 [acc: 91%]	Loss: 0.326948
Train Epoch: 909 [acc: 86%]	Loss: 0.426873
Train Epoch: 910 [acc: 90%]	Loss: 0.351987
Train Epoch: 911 [acc: 84%]	Loss: 0.432686
Train Epoch: 912 [acc: 89%]	Loss: 0.411080
Train Epoch: 913 [acc: 91%]	Loss

Train Epoch: 1065 [acc: 93%]	Loss: 0.225480
Train Epoch: 1066 [acc: 93%]	Loss: 0.233333
Train Epoch: 1067 [acc: 82%]	Loss: 0.446198
Train Epoch: 1068 [acc: 95%]	Loss: 0.179901
Train Epoch: 1069 [acc: 92%]	Loss: 0.377622
Train Epoch: 1070 [acc: 95%]	Loss: 0.215319
Train Epoch: 1071 [acc: 93%]	Loss: 0.227502
Train Epoch: 1072 [acc: 94%]	Loss: 0.246355
Train Epoch: 1073 [acc: 94%]	Loss: 0.289882
Train Epoch: 1074 [acc: 89%]	Loss: 0.391351
Train Epoch: 1075 [acc: 90%]	Loss: 0.237601
[Trained for 1075 epochs and tested on 5 sets of 2000 images]        Avg Acc: 33.54 +- 0.56 , Avg Loss: 3.72
Train Epoch: 1076 [acc: 88%]	Loss: 0.297678
Train Epoch: 1077 [acc: 97%]	Loss: 0.173317
Train Epoch: 1078 [acc: 92%]	Loss: 0.310626
Train Epoch: 1079 [acc: 92%]	Loss: 0.292943
Train Epoch: 1080 [acc: 91%]	Loss: 0.295999
Train Epoch: 1081 [acc: 91%]	Loss: 0.257612
Train Epoch: 1082 [acc: 87%]	Loss: 0.356668
Train Epoch: 1083 [acc: 92%]	Loss: 0.245187
Train Epoch: 1084 [acc: 90%]	Loss: 0.317259
Train Epoch

Train Epoch: 1234 [acc: 90%]	Loss: 0.302717
Train Epoch: 1235 [acc: 93%]	Loss: 0.175379
Train Epoch: 1236 [acc: 89%]	Loss: 0.335007
Train Epoch: 1237 [acc: 94%]	Loss: 0.208150
Train Epoch: 1238 [acc: 86%]	Loss: 0.472129
Train Epoch: 1239 [acc: 93%]	Loss: 0.247372
Train Epoch: 1240 [acc: 91%]	Loss: 0.211595
Train Epoch: 1241 [acc: 86%]	Loss: 0.348800
Train Epoch: 1242 [acc: 91%]	Loss: 0.190233
Train Epoch: 1243 [acc: 94%]	Loss: 0.198584
Train Epoch: 1244 [acc: 88%]	Loss: 0.359482
Train Epoch: 1245 [acc: 93%]	Loss: 0.210928
Train Epoch: 1246 [acc: 88%]	Loss: 0.297268
Train Epoch: 1247 [acc: 93%]	Loss: 0.238180
Train Epoch: 1248 [acc: 92%]	Loss: 0.289185
Train Epoch: 1249 [acc: 87%]	Loss: 0.391184
Train Epoch: 1250 [acc: 90%]	Loss: 0.256188
[Trained for 1250 epochs and tested on 5 sets of 2000 images]        Avg Acc: 29.55 +- 1.14 , Avg Loss: 4.54
Train Epoch: 1251 [acc: 95%]	Loss: 0.217475
Train Epoch: 1252 [acc: 88%]	Loss: 0.368473
Train Epoch: 1253 [acc: 88%]	Loss: 0.373994
Train Epoch

Train Epoch: 1403 [acc: 95%]	Loss: 0.191035
Train Epoch: 1404 [acc: 92%]	Loss: 0.217832
Train Epoch: 1405 [acc: 91%]	Loss: 0.198371
Train Epoch: 1406 [acc: 95%]	Loss: 0.195391
Train Epoch: 1407 [acc: 100%]	Loss: 0.060834
Train Epoch: 1408 [acc: 95%]	Loss: 0.171411
Train Epoch: 1409 [acc: 96%]	Loss: 0.147393
Train Epoch: 1410 [acc: 96%]	Loss: 0.126285
Train Epoch: 1411 [acc: 90%]	Loss: 0.260664
Train Epoch: 1412 [acc: 90%]	Loss: 0.282993
Train Epoch: 1413 [acc: 88%]	Loss: 0.284825
Train Epoch: 1414 [acc: 97%]	Loss: 0.134564
Train Epoch: 1415 [acc: 93%]	Loss: 0.235181
Train Epoch: 1416 [acc: 96%]	Loss: 0.114329
Train Epoch: 1417 [acc: 94%]	Loss: 0.286214
Train Epoch: 1418 [acc: 96%]	Loss: 0.104106
Train Epoch: 1419 [acc: 92%]	Loss: 0.284724
Train Epoch: 1420 [acc: 90%]	Loss: 0.328143
Train Epoch: 1421 [acc: 92%]	Loss: 0.213622
Train Epoch: 1422 [acc: 91%]	Loss: 0.206463
Train Epoch: 1423 [acc: 90%]	Loss: 0.349695
Train Epoch: 1424 [acc: 96%]	Loss: 0.175591
Train Epoch: 1425 [acc: 91%]	Lo

Train Epoch: 1575 [acc: 89%]	Loss: 0.277492
[Trained for 1575 epochs and tested on 5 sets of 2000 images]        Avg Acc: 35.06 +- 0.62 , Avg Loss: 4.42
Train Epoch: 1576 [acc: 92%]	Loss: 0.230349
Train Epoch: 1577 [acc: 90%]	Loss: 0.313298
Train Epoch: 1578 [acc: 99%]	Loss: 0.057142
Train Epoch: 1579 [acc: 95%]	Loss: 0.151535
Train Epoch: 1580 [acc: 95%]	Loss: 0.140721
Train Epoch: 1581 [acc: 96%]	Loss: 0.116118
Train Epoch: 1582 [acc: 96%]	Loss: 0.141786
Train Epoch: 1583 [acc: 94%]	Loss: 0.159263
Train Epoch: 1584 [acc: 93%]	Loss: 0.231313
Train Epoch: 1585 [acc: 93%]	Loss: 0.246126
Train Epoch: 1586 [acc: 97%]	Loss: 0.084991
Train Epoch: 1587 [acc: 96%]	Loss: 0.084852
Train Epoch: 1588 [acc: 94%]	Loss: 0.158712
Train Epoch: 1589 [acc: 93%]	Loss: 0.233070
Train Epoch: 1590 [acc: 95%]	Loss: 0.119156
Train Epoch: 1591 [acc: 89%]	Loss: 0.281254
Train Epoch: 1592 [acc: 94%]	Loss: 0.175119
Train Epoch: 1593 [acc: 95%]	Loss: 0.192424
Train Epoch: 1594 [acc: 89%]	Loss: 0.325479
Train Epoch

Train Epoch: 1744 [acc: 92%]	Loss: 0.258450
Train Epoch: 1745 [acc: 95%]	Loss: 0.200334
Train Epoch: 1746 [acc: 92%]	Loss: 0.349224
Train Epoch: 1747 [acc: 90%]	Loss: 0.394925
Train Epoch: 1748 [acc: 94%]	Loss: 0.185571
Train Epoch: 1749 [acc: 95%]	Loss: 0.199728
Train Epoch: 1750 [acc: 96%]	Loss: 0.093285
[Trained for 1750 epochs and tested on 5 sets of 2000 images]        Avg Acc: 32.90 +- 0.67 , Avg Loss: 4.37
Train Epoch: 1751 [acc: 99%]	Loss: 0.065901
Train Epoch: 1752 [acc: 89%]	Loss: 0.335816
Train Epoch: 1753 [acc: 93%]	Loss: 0.194855
Train Epoch: 1754 [acc: 97%]	Loss: 0.118738
Train Epoch: 1755 [acc: 93%]	Loss: 0.198381
Train Epoch: 1756 [acc: 91%]	Loss: 0.310581
Train Epoch: 1757 [acc: 95%]	Loss: 0.143106
Train Epoch: 1758 [acc: 97%]	Loss: 0.146292
Train Epoch: 1759 [acc: 92%]	Loss: 0.199741
Train Epoch: 1760 [acc: 97%]	Loss: 0.086256
Train Epoch: 1761 [acc: 94%]	Loss: 0.187009
Train Epoch: 1762 [acc: 95%]	Loss: 0.152203
Train Epoch: 1763 [acc: 92%]	Loss: 0.270757
Train Epoch

Train Epoch: 1913 [acc: 95%]	Loss: 0.229541
Train Epoch: 1914 [acc: 93%]	Loss: 0.249079
Train Epoch: 1915 [acc: 94%]	Loss: 0.171019
Train Epoch: 1916 [acc: 95%]	Loss: 0.153181
Train Epoch: 1917 [acc: 96%]	Loss: 0.164813
Train Epoch: 1918 [acc: 95%]	Loss: 0.145639
Train Epoch: 1919 [acc: 93%]	Loss: 0.277524
Train Epoch: 1920 [acc: 93%]	Loss: 0.182770
Train Epoch: 1921 [acc: 96%]	Loss: 0.121205
Train Epoch: 1922 [acc: 95%]	Loss: 0.133642
Train Epoch: 1923 [acc: 89%]	Loss: 0.251794
Train Epoch: 1924 [acc: 97%]	Loss: 0.208952
Train Epoch: 1925 [acc: 94%]	Loss: 0.244054
[Trained for 1925 epochs and tested on 5 sets of 2000 images]        Avg Acc: 35.60 +- 0.62 , Avg Loss: 4.03
Train Epoch: 1926 [acc: 93%]	Loss: 0.153874
Train Epoch: 1927 [acc: 99%]	Loss: 0.071766
Train Epoch: 1928 [acc: 93%]	Loss: 0.201775
Train Epoch: 1929 [acc: 93%]	Loss: 0.204459
Train Epoch: 1930 [acc: 94%]	Loss: 0.185598
Train Epoch: 1931 [acc: 94%]	Loss: 0.194224
Train Epoch: 1932 [acc: 94%]	Loss: 0.168477
Train Epoch

In [11]:
# Name the output directory after the final LaTeX row, with its trailing "\\" terminator removed.
if not latexTracker:
    raise RuntimeError("latexTracker is empty - run the training cell before exporting results")
dirname = latexTracker[-1][:-2]

def writeTex(latexTracker, dirname):
    """
    Summary: Write the collected LaTeX table rows to <dirname>/latexTable.txt
    == params ==
    latexTracker: list of LaTeX row strings collected during validation
    dirname: output directory; created (including parents) if it does not exist
    == output ==
    None; latexTable.txt is overwritten
    """
    os.makedirs(dirname, exist_ok=True)
    # `with` guarantees the file is closed even if a write fails.
    with open(os.path.join(dirname, "latexTable.txt"), "w") as f:
        for row in latexTracker:
            # Rows end with the LaTeX row terminator but no newline; add one so each
            # row lands on its own line instead of the table being one long line.
            f.write(row if row.endswith("\n") else row + "\n")

writeTex(latexTracker, dirname)

# Echo each table row to the notebook output as well.
for row in latexTracker:
    print(row)

WideResNet28&25&16.11&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&50&17.02&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&75&22.75&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&100&21.92&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&125&24.37&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&150&26.47&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&175&27.08&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&200&28.62&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&225&28.76&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&250&27.82&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&275&29.74&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&300&29.38&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&325&26.97&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&350&29.82&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&375&28.88&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&400&31.92&Adam&0.001&10000&Nothing&1620010424\\
WideResNet28&425&32.57&Adam

In [12]:

# x axis: 1-based epoch numbers, one per recorded training epoch.
epochList = list(range(1, len(trainTracker["meanLoss"]) + 1))

# xlist holds epochs, so "Epochs" is the x label and the metric goes on the y axis.
plot(xlist=epochList, ylist=trainTracker["meanLoss"], xlab="Epochs",
     ylab="Mean Train Loss", title="Mean Train Loss over Epochs",
     color="#243A92", label="mean train loss", savedir=dirname, save=True)

plot(xlist=epochList, ylist=trainTracker["accuracy"], xlab="Epochs",
     ylab="Train Accuracy", title="Train Accuracy Over Epochs",
     color="#34267E", label="Train Accuracy", savedir=dirname, save=True)

  plt.show()
  plt.show()


In [13]:

# The training cell runs checkTest after epoch 1 and then every VAL_DISPLAY_DIVISOR
# epochs. Rebuild that schedule so the x axis lines up with valTracker's entries
# (assumes checkTest appends one meanLoss/meanAcc/stdAcc entry per call).
valEpochs = [e for e in range(1, EPOCH_NUM + 1)
             if e % VAL_DISPLAY_DIVISOR == 0 or e == 1][:len(valTracker["meanLoss"])]

plot(xlist=valEpochs, ylist=valTracker["meanLoss"], xlab="Epochs",
     ylab="Mean Val Loss", title="Mean Val Loss over Epochs",
     color="#243A92", label="mean val loss", savedir=dirname, save=True)

plot(xlist=valEpochs, ylist=valTracker["meanAcc"], xlab="Epochs",
     ylab="Val Accuracy", title="Val Accuracy Over Epochs",
     color="#34267E", label="Val Accuracy", savedir=dirname, save=True)

plot(xlist=valEpochs, ylist=valTracker["stdAcc"], xlab="Epochs",
     ylab="Val Accuracy Standard Deviation", title="Val Accuracy Standard Deviation Over Epochs",
     color="#34267E", label="Val Accuracy SD", savedir=dirname, save=True)

# allLoss / allAcc hold one value per validation set per evaluation, so each
# evaluation epoch is repeated VALIDATION_SET_NUM times on the x axis.
valEpochsPerSet = [e for e in valEpochs for _ in range(VALIDATION_SET_NUM)]
evalSuffix = "({} every {} epochs)".format(VALIDATION_SET_NUM, VAL_DISPLAY_DIVISOR)

plot(xlist=valEpochsPerSet, ylist=valTracker["allLoss"], xlab="Epochs",
     ylab="Val Loss", title="Val loss over val set evaluations " + evalSuffix,
     color="#34267E", label="Val Loss", savedir=dirname, save=True)

plot(xlist=valEpochsPerSet, ylist=valTracker["allAcc"], xlab="Epochs",
     ylab="Val Accuracy", title="Val accuracy over val set evaluations " + evalSuffix,
     color="#34267E", label="Val Accuracy", savedir=dirname, save=True)

  plt.show()
  plt.show()
  plt.show()
  plt.show()
  plt.show()
