In [1]:
import gc
import os
import subprocess
import sys

import PIL
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt  
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import Subset
from IPython.core.display import display, HTML
from numpy.random import RandomState
from wide_resnet import WideResNet
from auto_augment import AutoAugment, Cutout
from efficientnet_pytorch import EfficientNet
from cifar_loader import SmallSampleController
import torchvision.models as models

sys.path.insert(0,'glico-learning-small-sample/glico_model')

from tester import runGlico



# display(HTML("<style>.container { width:40% !important; }</style>"))


In [2]:

def getAcc(preds,targets):
    """
    Summary: fraction of positions at which preds and targets agree
    == params ==
    preds: indexable sequence of predicted labels
    targets: indexable sequence of true labels, at least as long as preds
    == output ==
    numpy float in [0, 1]; nan (with a numpy warning) when preds is empty
    """
    total = len(preds)
    hits = np.sum([bool(preds[k] == targets[k]) for k in range(total)])
    return hits / total

def train(model, device, train_loader, optimizer, epoch, display=True):
    """
    Summary: trains model for one epoch over train_loader
    == params ==
    model: the model to train
    device: cuda or cpu
    train_loader: dataloader for our train data
    optimizer: the optimizer for our training
    epoch: epoch index, used only in the printed message
    display: output flag
    == output ==
    (train accuracy in [0, 1], mean of the per-batch losses);
    both are nan for an empty loader
    == notes ==
    accuracy is measured on the outputs computed before each optimizer step,
    i.e. while the weights are still changing during the epoch
    """
    batchLosses = []
    correctCounts = []
    seen = 0

    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        # Keep per-batch statistics on-device; converting to Python numbers
        # once after the loop avoids a host sync on every batch.
        with torch.no_grad():
            batchLosses.append(loss.detach())
            correctCounts.append((torch.argmax(output, 1) == target).sum())
        seen += target.numel()

    if seen == 0:
        # nothing was trained on: report nan explicitly instead of relying on
        # numpy's empty-mean warning
        accuracy, meanLoss = float("nan"), float("nan")
    else:
        meanLoss = np.mean([x.item() for x in batchLosses])
        accuracy = sum(c.item() for c in correctCounts) / seen

    if display:
        print('Train Epoch: {} [acc: {:.0f}%]\tLoss: {:.6f}'.format(
          epoch, 100. * accuracy, meanLoss))

    return accuracy, meanLoss


def glicoTrain(model, device, train_loader, optimizer, epoch, glicoLoader,replaceProb=0.5,display=True):
    """
    Summary: trains model for one epoch, letting glicoLoader swap samples of
    each batch for generated ones before the forward pass
    == params ==
    model: the model to train
    device: cuda or cpu
    train_loader: dataloader for our train data
    optimizer: the optimizer for our training
    epoch: epoch index, used only in the printed message
    glicoLoader: object exposing replaceBatch(data, target, replaceProb) that
        returns the batch of inputs to train on
    replaceProb: replacement probability forwarded to glicoLoader.replaceBatch
    display: output flag
    == output ==
    (train accuracy in [0, 1], mean of the per-batch losses);
    both are nan for an empty loader
    """
    batchLosses = []
    correctCounts = []
    seen = 0

    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        # Replace samples with GLICO samples with probability replaceProb.
        # target is not modified, so this assumes replaceBatch keeps each
        # replaced sample's class -- TODO confirm in the GLICO loader.
        data = glicoLoader.replaceBatch(data, target, replaceProb)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        # Keep per-batch statistics on-device; converting to Python numbers
        # once after the loop avoids a host sync on every batch.
        with torch.no_grad():
            batchLosses.append(loss.detach())
            correctCounts.append((torch.argmax(output, 1) == target).sum())
        seen += target.numel()

    if seen == 0:
        # nothing was trained on: report nan explicitly instead of relying on
        # numpy's empty-mean warning
        accuracy, meanLoss = float("nan"), float("nan")
    else:
        meanLoss = np.mean([x.item() for x in batchLosses])
        accuracy = sum(c.item() for c in correctCounts) / seen

    if display:
        print('Train Epoch: {} [acc: {:.0f}%]\tLoss: {:.6f}'.format(
          epoch, 100. * accuracy, meanLoss))

    return accuracy, meanLoss



def test(model, device, test_loader,verbose=True):
    """
    Summary: Implements the testing procedure for a given model
    == params ==
    model: the model to test
    device: cuda or cpu 
    test_loader: dataloader for our test data
    verbose: output flag
    == output ==
    the mean test loss, the test accuracy
    """
    
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            

    meanLoss = test_loss / len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    if verbose: print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        mean_test_loss, correct, len(test_loader.dataset),
        accuracy))
        
    return accuracy, meanLoss


def checkTest(model,device,valSets,valTracker,latexTracker,epoch,
              model_name,optim_name,lr,totalTestSamples,seed,verbose=True,
              dataAug="Nothing"):
    """
    Summary: evaluates model on every validation loader, prints a summary and
    records the statistics in valTracker plus one LaTeX row in latexTracker
    == params ==
    model, device: passed straight to test()
    valSets: iterable of validation dataloaders (each must expose .dataset)
    valTracker: dict whose "allLoss", "allAcc", "meanLoss", "meanAcc" and
        "stdAcc" entries are lists; appended to in place
    latexTracker: list receiving one LaTeX table row per call
    epoch, model_name, optim_name, lr, totalTestSamples, seed: written into
        the LaTeX row (epoch is also printed)
    verbose: output flag
    dataAug: data-augmentation description for the LaTeX row; defaults to
        the previously hard-coded "Nothing"
    """
    tempAcc = []
    tempLoss = []
    setSizes = []
    for val_loader in valSets:
        acc, loss = test(model, device, val_loader, verbose=False)
        tempAcc.append(acc)
        tempLoss.append(loss)
        setSizes.append(len(val_loader.dataset))

    meanAcc = np.mean(tempAcc)
    stdAcc = np.std(tempAcc)
    meanLoss = np.mean(tempLoss)

    if verbose:
        # Report the number and size of the sets actually evaluated instead of
        # the global VALIDATION_SET_NUM and a hard-coded 2000.
        sizeText = "/".join(str(s) for s in sorted(set(setSizes))) or "0"
        print('[Trained for {} epochs and tested on {} sets of {} images] '
              'Avg Acc: {:.2f} +- {:.2f} , Avg Loss: {:.2f}'.format(
                  epoch, len(tempAcc), sizeText, meanAcc, stdAcc, meanLoss))

    tableRow = getLatexRow(architecture=model_name, epoch=epoch, accuracy=meanAcc,
                           optim=optim_name, lr=lr, totalTestSamples=totalTestSamples,
                           dataAug=dataAug, seed=seed, title=False)
    latexTracker.append(tableRow)

    valTracker["allLoss"].extend(tempLoss)
    valTracker["allAcc"].extend(tempAcc)
    valTracker["meanLoss"].append(meanLoss)
    valTracker["meanAcc"].append(meanAcc)
    valTracker["stdAcc"].append(stdAcc)





In [3]:
def getLatexRow(architecture,epoch,accuracy,optim,lr,
                totalTestSamples,dataAug,seed,title=False):
    """
    Summary: formats one results-table row as LaTeX: the cells are joined by
    '&' and the row ends with a LaTeX line break (two backslashes). When
    title is True, a header row naming the columns is placed first.
    Accuracy is rounded to 3 decimals; every other value goes through str().
    """
    cells = [str(value) for value in (architecture, epoch, round(accuracy, 3),
                                      optim, lr, totalTestSamples, dataAug, seed)]
    body = "&".join(cells) + "\\\\"
    if not title:
        return body
    header = "&".join(("Model", "Epoch", "Accuracy", "Optimizer", "lr",
                       "Test Sample Num", "data augmentation", "seed"))
    return header + "\\\\\n" + body
    
    
def plot(xlist,ylist,xlab,ylab,title,color,label,savedir=".",save=False):
    """
    Summary: plots ylist against xlist as one labelled line and optionally
    saves the figure as a PDF plus a 300-dpi PNG rendering of it
    == params ==
    xlist, ylist: x and y values of the curve
    xlab, ylab: axis labels
    title: figure title; also the file-name stem when saving
    color: matplotlib colour of the line
    label: legend label of the line
    savedir: directory the files are written to (created, with parents, if missing)
    save: when True, write <savedir>/<title>.pdf and <savedir>/<title>.png
    """
    fig, ax = plt.subplots()
    ax.set_title(title)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.plot(xlist, ylist, color=color, marker=".", label=label)
    ax.legend()

    if save:
        os.makedirs(savedir, exist_ok=True)
        filepath = os.path.join(savedir, str(title))
        fig.savefig(filepath + ".pdf")
        # pdftoppm takes an output *root* and appends ".png" itself;
        # -singlefile stops it adding a "-1" page suffix. Passing an argument
        # list (no shell) keeps titles containing spaces or quotes intact.
        try:
            subprocess.run(["pdftoppm", "-png", "-r", "300", "-singlefile",
                            filepath + ".pdf", filepath], check=False)
        except FileNotFoundError:
            # best effort, like the previous os.system call: the PDF is saved
            print("pdftoppm not found; skipped PNG export of {}.pdf".format(filepath))

    plt.show()
    
    

In [13]:
def getModel(model_name):
    """
    Summary: builds the network selected by model_name
    Matching is a case-insensitive substring test evaluated in the order
    below, so the first key contained in model_name wins.
    == params ==
    model_name: architecture identifier, e.g. "wide", "wide_resnet50_2",
        "efficientnet-b0", "alexnet", "resnet18"; EfficientNet names are
        passed verbatim to EfficientNet.from_pretrained
    == output ==
    the constructed model
    == raises ==
    ValueError: if model_name matches none of the supported architectures
    """
    name = model_name.lower()
    # torchvision's wide_resnet50_2 must be tested before the generic "wide"
    # key; before this change its branch came after "wide" and was unreachable.
    if "wide_resnet50" in name or "wideresnet50" in name:
        return models.wide_resnet50_2(pretrained=True)
    if "wide" in name:
        # project WideResNet built with a 10-class head (28, 10 presumably
        # depth / widen factor -- confirm in wide_resnet.py)
        return WideResNet(28, 10, num_classes=10)
    if "efficient" in name:
        return EfficientNet.from_pretrained(model_name, num_classes=10)  # change to not be pretrained
    # NOTE(review): the pretrained torchvision models below keep their original
    # ImageNet classifier; 10-class targets still work with cross-entropy, but
    # the head is not resized to 10 outputs -- confirm this is intended.
    if "vgg16" in name:
        return models.vgg16(pretrained=True)
    if "alexnet" in name:
        return models.alexnet(pretrained=True)
    if "resnet18" in name:
        return models.resnet18(pretrained=True)
    if "resnet50" in name:
        return models.resnet50(pretrained=True)
    if "densenet161" in name:
        return models.densenet161(pretrained=True)
    if "resnext101" in name:
        return models.resnext101_32x8d(pretrained=True)
    if "inception_v3" in name:
        return models.inception_v3(pretrained=True, aux_logits=False)
    if "squeezenet" in name:
        model = models.squeezenet1_0(pretrained=True)
        # swap the final 1x1 conv so squeezenet emits 10 class scores
        model.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1, 1), stride=(1, 1))
        return model
    raise ValueError("Unknown model_name: {!r}".format(model_name))
    
    
def getOptimizer128(optimizer_name,parameters):
    """
    Summary: builds the optimizer (and its learning rate) used for the
    batch-size-128 runs
    == params ==
    optimizer_name: contains "sgd" or "adam" (case-insensitive)
    parameters: iterable of parameters to optimise, e.g. model.parameters()
    == output ==
    (optimizer, learning rate)
    == raises ==
    ValueError: if optimizer_name names neither SGD nor Adam
    """
    name = optimizer_name.lower()
    if "sgd" in name:
        LR = 0.01
        return torch.optim.SGD(parameters, lr=LR, momentum=0.9,
                               weight_decay=0.0005), LR
    if "adam" in name:
        LR = 0.001
        return torch.optim.Adam(parameters, lr=LR, weight_decay=0), LR
    # fail loudly: returning None made the caller's `opt, lr = ...` crash
    # with an unrelated unpacking error
    raise ValueError("Unknown optimizer_name: {!r}".format(optimizer_name))
        
    

In [5]:
# Release cached GPU memory and unreachable Python objects left over from
# earlier runs in this kernel before building new datasets/models.
torch.cuda.empty_cache()
gc.collect()

# ---- experiment configuration ----
# OPTIM, MODEL, EPOCH_NUM, VAL_DISPLAY_DIVISOR and REPLACE_PROB are not used
# in this cell; presumably consumed by later training cells.
OPTIM = "sgd"
MODEL = "alexnet"
EPOCH_NUM = 2000
TRAIN_SAMPLE_NUM = 100  # passed as trainSampleNum to SmallSampleController
VAL_SAMPLE_NUM = 2000  # passed as valSampleNum to SmallSampleController
BATCH_SIZE = 128
VALIDATION_SET_NUM = 5  # passed as multiplier / valMultiplier below
AUGMENT = True  # selects the augmentation list for transform_train
VAL_DISPLAY_DIVISOR = 25
CIFAR_TRAIN = True  # which CIFAR-10 split all three datasets below load
REPLACE_PROB = 0.1

#cifar-10:
#mean = (0.4914, 0.4822, 0.4465)
#std = (0.247, 0.243, 0.261)

# NOTE(review): these values differ from the CIFAR-10 statistics above; they
# look like the ImageNet statistics torchvision's pretrained models expect --
# presumably chosen because getModel loads pretrained weights; confirm.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
if AUGMENT:
    dataAugmentation = [ 
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        AutoAugment(),
        Cutout()
    ]
    # human-readable description of the augmentation, e.g. for result tables
    augment = "Crop,Flip,AutoAugment,Cutout"
else: 
    dataAugmentation = []
    augment = "Nothing"



# Augmentations run on the PIL image before ToTensor/normalize.
transform_train = transforms.Compose(dataAugmentation + [transforms.ToTensor(), normalize]) 
transform_val = transforms.Compose([transforms.ToTensor(), normalize]) #careful to keep this one same

# glico_train has no transform so GLICO receives raw PIL images.
# NOTE(review): cifar_val loads the same split as cifar_train
# (train=CIFAR_TRAIN); confirm SmallSampleController keeps train and
# validation indices disjoint.
glico_train = datasets.CIFAR10(root='.',train=CIFAR_TRAIN, download=True)
cifar_train = datasets.CIFAR10(root='.',train=CIFAR_TRAIN, transform=transform_train, download=True)
cifar_val = datasets.CIFAR10(root='.',train=CIFAR_TRAIN, transform=transform_val, download=True)

ss = SmallSampleController(numClasses=10,trainSampleNum=TRAIN_SAMPLE_NUM, # abstract the data-loading procedure
                           valSampleNum=VAL_SAMPLE_NUM, batchSize=BATCH_SIZE, 
                           multiplier=VALIDATION_SET_NUM, trainDataset=cifar_train, 
                           valDataset=cifar_val)
    
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data, valSets, seed = ss.generateNewSet(device,valMultiplier = VALIDATION_SET_NUM) #Sample from datasets


# Assumes ss.trainSampler.indexes[0] holds the train indices just drawn by
# generateNewSet -- TODO confirm in cifar_loader.
train_labeled_dataset = Subset(glico_train, ss.trainSampler.indexes[0]) #get the same subset without transform
# epochs=750 is a hard-coded GLICO training length (not a config constant).
nagTrainer = runGlico(train_labeled_dataset=train_labeled_dataset, classes=10,epochs=750)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Generated new permutation of the CIFAR train dataset with                 seed:1619996852, train sample num: 100, test sample num: 2000



[*** NEW PRINT ***]nag_params: NAGParams(nz=512, force_l2=False, is_pixel=True, z_init='rndm', is_classifier=True, disc_net='conv', loss='ce', data_name='cifar-10', noise_proj=True, shot=0), rn: tester.py_cifar100_test_run, dataset <class 'list'>



[DEBUG] data name=cifar-10, data res = 32
[NEW CODE] image_params:ImageParams(sz=(32, 32), nc=3, n=100, mu=None, sd=None)
not eval | augment : TRUE
=>Generated data loader, res=32, workers=6 transform=Compose(
    ToTensor()
) sampler=None
[DEBUG] dataset name: <class 'torch.utils.data.dataset.subset'>,offset_labels=0
torch.Size([100, 512])
init_models_weights(self): init rndm
num classes = 10
dataset size = 100
not eval | augment : TRUE
=>Generated data loader, res=32, workers=6 transform=Co

  return self.softmax(x)


Epoch: [0][0/1]	Time  0.821 ( 0.821)	Data  0.205 ( 0.205)	Loss 8.65532 (8.65532)	d_Loss 2.30658 (2.30658)
=>NAG Epoch: 0 Error: 8.655317306518555, Time:     0.02m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_0.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_0.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_0.pkl
Show images...
=>epoch:1/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [1][0/1]	Time  0.260 ( 0.260)	Data  0.196 ( 0.196)	Loss 8.56495 (8.56495)	d_Loss 2.30087 (2.30087)
=>NAG Epoch: 1 Error: 8.564947128295898, Time:     0.01m
=>epoch:2/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [2][0/1]	Time  0.259 ( 0.259)	Data  0.225 ( 0.225)	Loss 8.43421 (8.43421)	d_Loss 2.28885 (2.28885)
=>NAG Epoch: 2 Error: 8.434211730957031, Time:     0.01m
=>epoch:3/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [3][0/1]	Time  0.441 ( 0.441)	Data  0.379 ( 0.379)	Loss 8.32316 (8.32316)	d_Loss 2.28463 (2.28463)
=>N

Epoch: [31][0/1]	Time  0.265 ( 0.265)	Data  0.197 ( 0.197)	Loss 7.28584 (7.28584)	d_Loss 2.16363 (2.16363)
=>NAG Epoch: 31 Error: 7.285839557647705, Time:     0.01m
=>epoch:32/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [32][0/1]	Time  0.237 ( 0.237)	Data  0.208 ( 0.208)	Loss 7.15735 (7.15735)	d_Loss 2.02980 (2.02980)
=>NAG Epoch: 32 Error: 7.157351970672607, Time:     0.00m
=>epoch:33/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [33][0/1]	Time  0.276 ( 0.276)	Data  0.210 ( 0.210)	Loss 7.19607 (7.19607)	d_Loss 2.07454 (2.07454)
=>NAG Epoch: 33 Error: 7.196073532104492, Time:     0.01m
=>consecutive_loss: 1
=>epoch:34/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [34][0/1]	Time  0.249 ( 0.249)	Data  0.210 ( 0.210)	Loss 7.07715 (7.07715)	d_Loss 2.00619 (2.00619)
=>NAG Epoch: 34 Error: 7.077151298522949, Time:     0.01m
=>epoch:35/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [35][0/1]	Time  0.242 ( 0.242)	Data  0.213 ( 0.213)	Loss 7.10423 (7.10423)	d_Loss 2.017

Epoch: [64][0/1]	Time  0.270 ( 0.270)	Data  0.203 ( 0.203)	Loss 6.06760 (6.06760)	d_Loss 1.47508 (1.47508)
=>NAG Epoch: 64 Error: 6.067599773406982, Time:     0.01m
=>consecutive_loss: 1
=>epoch:65/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [65][0/1]	Time  0.260 ( 0.260)	Data  0.191 ( 0.191)	Loss 6.00143 (6.00143)	d_Loss 1.47351 (1.47351)
=>NAG Epoch: 65 Error: 6.001429557800293, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_65.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_65.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_65.pkl
=>epoch:66/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [66][0/1]	Time  0.230 ( 0.230)	Data  0.201 ( 0.201)	Loss 6.00200 (6.00200)	d_Loss 1.47344 (1.47344)
=>NAG Epoch: 66 Error: 6.002002239227295, Time:     0.00m
=>consecutive_loss: 1
=>epoch:67/750  resolution:32 batch:128, lr1.000000e-03
Epoch: [67][0/1]	Time  0.262 ( 0.262)	Data  0.194 ( 0.194)	Loss 5.94

Epoch: [95][0/1]	Time  0.244 ( 0.244)	Data  0.212 ( 0.212)	Loss 5.31075 (5.31075)	d_Loss 1.47017 (1.47017)
=>NAG Epoch: 95 Error: 5.310746192932129, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_95.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_95.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_95.pkl
=>epoch:96/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [96][0/1]	Time  0.314 ( 0.314)	Data  0.248 ( 0.248)	Loss 5.28442 (5.28442)	d_Loss 1.46939 (1.46939)
=>NAG Epoch: 96 Error: 5.284422397613525, Time:     0.01m
=>epoch:97/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [97][0/1]	Time  0.300 ( 0.300)	Data  0.232 ( 0.232)	Loss 5.28366 (5.28366)	d_Loss 1.47054 (1.47054)
=>NAG Epoch: 97 Error: 5.2836594581604, Time:     0.01m
=>epoch:98/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [98][0/1]	Time  0.304 ( 0.304)	Data  0.238 ( 0.238)	Loss 5.25785 (5.25785)	d_Loss 1.46955 (1.46955)
=>NAG E

Epoch: [126][0/1]	Time  0.232 ( 0.232)	Data  0.197 ( 0.197)	Loss 4.92640 (4.92640)	d_Loss 1.46893 (1.46893)
=>NAG Epoch: 126 Error: 4.926395416259766, Time:     0.00m
=>epoch:127/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [127][0/1]	Time  0.236 ( 0.236)	Data  0.206 ( 0.206)	Loss 4.92384 (4.92384)	d_Loss 1.46914 (1.46914)
=>NAG Epoch: 127 Error: 4.923844814300537, Time:     0.00m
=>epoch:128/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [128][0/1]	Time  0.270 ( 0.270)	Data  0.205 ( 0.205)	Loss 4.90723 (4.90723)	d_Loss 1.46884 (1.46884)
=>NAG Epoch: 128 Error: 4.907231330871582, Time:     0.01m
=>epoch:129/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [129][0/1]	Time  0.245 ( 0.245)	Data  0.215 ( 0.215)	Loss 4.90462 (4.90462)	d_Loss 1.46874 (1.46874)
=>NAG Epoch: 129 Error: 4.904623508453369, Time:     0.01m
=>epoch:130/750  resolution:32 batch:128, lr5.000000e-04
Epoch: [130][0/1]	Time  0.270 ( 0.270)	Data  0.204 ( 0.204)	Loss 4.88785 (4.88785)	d_Loss 1.46851 (1.468

Epoch: [158][0/1]	Time  0.261 ( 0.261)	Data  0.230 ( 0.230)	Loss 4.54716 (4.54716)	d_Loss 1.46758 (1.46758)
=>NAG Epoch: 158 Error: 4.547164440155029, Time:     0.01m
=>epoch:159/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [159][0/1]	Time  0.279 ( 0.279)	Data  0.209 ( 0.209)	Loss 4.54289 (4.54289)	d_Loss 1.46760 (1.46760)
=>NAG Epoch: 159 Error: 4.542888641357422, Time:     0.01m
=>epoch:160/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [160][0/1]	Time  0.249 ( 0.249)	Data  0.217 ( 0.217)	Loss 4.53163 (4.53163)	d_Loss 1.46736 (1.46736)
=>NAG Epoch: 160 Error: 4.5316267013549805, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_160.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_160.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_160.pkl
=>epoch:161/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [161][0/1]	Time  0.293 ( 0.293)	Data  0.224 ( 0.224)	Loss 4.52758 (4.52758)	d_Loss 1.46770 (

Epoch: [189][0/1]	Time  0.296 ( 0.296)	Data  0.228 ( 0.228)	Loss 4.34822 (4.34822)	d_Loss 1.46724 (1.46724)
=>NAG Epoch: 189 Error: 4.348221778869629, Time:     0.01m
=>epoch:190/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [190][0/1]	Time  0.295 ( 0.295)	Data  0.230 ( 0.230)	Loss 4.33621 (4.33621)	d_Loss 1.46681 (1.46681)
=>NAG Epoch: 190 Error: 4.336205959320068, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_190.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_190.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_190.pkl
=>epoch:191/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [191][0/1]	Time  0.283 ( 0.283)	Data  0.221 ( 0.221)	Loss 4.33374 (4.33374)	d_Loss 1.46711 (1.46711)
=>NAG Epoch: 191 Error: 4.333739280700684, Time:     0.01m
=>epoch:192/750  resolution:32 batch:128, lr2.500000e-04
Epoch: [192][0/1]	Time  0.269 ( 0.269)	Data  0.202 ( 0.202)	Loss 4.32412 (4.32412)	d_Loss 1.46698 (1

Epoch: [220][0/1]	Time  0.275 ( 0.275)	Data  0.210 ( 0.210)	Loss 4.13842 (4.13842)	d_Loss 1.46622 (1.46622)
=>NAG Epoch: 220 Error: 4.138423919677734, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_220.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_220.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_220.pkl
=>epoch:221/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [221][0/1]	Time  0.264 ( 0.264)	Data  0.204 ( 0.204)	Loss 4.13802 (4.13802)	d_Loss 1.46593 (1.46593)
=>NAG Epoch: 221 Error: 4.1380228996276855, Time:     0.01m
=>epoch:222/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [222][0/1]	Time  0.238 ( 0.238)	Data  0.208 ( 0.208)	Loss 4.13294 (4.13294)	d_Loss 1.46590 (1.46590)
=>NAG Epoch: 222 Error: 4.132941722869873, Time:     0.01m
=>epoch:223/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [223][0/1]	Time  0.286 ( 0.286)	Data  0.221 ( 0.221)	Loss 4.12796 (4.12796)	d_Loss 1.46610 (

=>epoch:251/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [251][0/1]	Time  0.292 ( 0.292)	Data  0.221 ( 0.221)	Loss 4.01524 (4.01524)	d_Loss 1.46542 (1.46542)
=>NAG Epoch: 251 Error: 4.015237808227539, Time:     0.01m
=>epoch:252/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [252][0/1]	Time  0.234 ( 0.234)	Data  0.196 ( 0.196)	Loss 4.00994 (4.00994)	d_Loss 1.46581 (1.46581)
=>NAG Epoch: 252 Error: 4.00993537902832, Time:     0.00m
=>epoch:253/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [253][0/1]	Time  0.245 ( 0.245)	Data  0.214 ( 0.214)	Loss 4.00960 (4.00960)	d_Loss 1.46534 (1.46534)
=>NAG Epoch: 253 Error: 4.0096049308776855, Time:     0.00m
=>epoch:254/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [254][0/1]	Time  0.274 ( 0.274)	Data  0.207 ( 0.207)	Loss 4.00268 (4.00268)	d_Loss 1.46570 (1.46570)
=>NAG Epoch: 254 Error: 4.0026774406433105, Time:     0.01m
=>epoch:255/750  resolution:32 batch:128, lr1.250000e-04
Epoch: [255][0/1]	Time  0.253 ( 0.253)	Data  0

Epoch: [282][0/1]	Time  0.301 ( 0.301)	Data  0.231 ( 0.231)	Loss 3.88971 (3.88971)	d_Loss 1.46521 (1.46521)
=>NAG Epoch: 282 Error: 3.889713764190674, Time:     0.01m
=>epoch:283/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [283][0/1]	Time  0.323 ( 0.323)	Data  0.248 ( 0.248)	Loss 3.88675 (3.88675)	d_Loss 1.46518 (1.46518)
=>NAG Epoch: 283 Error: 3.8867502212524414, Time:     0.01m
=>epoch:284/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [284][0/1]	Time  0.290 ( 0.290)	Data  0.214 ( 0.214)	Loss 3.88199 (3.88199)	d_Loss 1.46514 (1.46514)
=>NAG Epoch: 284 Error: 3.881990909576416, Time:     0.01m
=>epoch:285/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [285][0/1]	Time  0.284 ( 0.284)	Data  0.207 ( 0.207)	Loss 3.88306 (3.88306)	d_Loss 1.46515 (1.46515)
=>NAG Epoch: 285 Error: 3.8830604553222656, Time:     0.01m
=>consecutive_loss: 1
=>epoch:286/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [286][0/1]	Time  0.253 ( 0.253)	Data  0.224 ( 0.224)	Loss 3.87702 (3.8770

Epoch: [314][0/1]	Time  0.247 ( 0.247)	Data  0.218 ( 0.218)	Loss 3.80536 (3.80536)	d_Loss 1.46474 (1.46474)
=>NAG Epoch: 314 Error: 3.805361270904541, Time:     0.00m
=>epoch:315/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [315][0/1]	Time  0.274 ( 0.274)	Data  0.205 ( 0.205)	Loss 3.80373 (3.80373)	d_Loss 1.46482 (1.46482)
=>NAG Epoch: 315 Error: 3.8037333488464355, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_315.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_315.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_315.pkl
=>epoch:316/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [316][0/1]	Time  0.243 ( 0.243)	Data  0.212 ( 0.212)	Loss 3.80292 (3.80292)	d_Loss 1.46488 (1.46488)
=>NAG Epoch: 316 Error: 3.80291748046875, Time:     0.01m
=>epoch:317/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [317][0/1]	Time  0.295 ( 0.295)	Data  0.227 ( 0.227)	Loss 3.79833 (3.79833)	d_Loss 1.46479 (1

Epoch: [346][0/1]	Time  0.252 ( 0.252)	Data  0.205 ( 0.205)	Loss 3.73493 (3.73493)	d_Loss 1.46478 (1.46478)
=>NAG Epoch: 346 Error: 3.734928607940674, Time:     0.01m
=>epoch:347/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [347][0/1]	Time  0.268 ( 0.268)	Data  0.199 ( 0.199)	Loss 3.73279 (3.73279)	d_Loss 1.46466 (1.46466)
=>NAG Epoch: 347 Error: 3.732786178588867, Time:     0.01m
=>epoch:348/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [348][0/1]	Time  0.249 ( 0.249)	Data  0.220 ( 0.220)	Loss 3.72865 (3.72865)	d_Loss 1.46452 (1.46452)
=>NAG Epoch: 348 Error: 3.728649616241455, Time:     0.01m
=>epoch:349/750  resolution:32 batch:128, lr6.250000e-05
Epoch: [349][0/1]	Time  0.269 ( 0.269)	Data  0.198 ( 0.198)	Loss 3.72818 (3.72818)	d_Loss 1.46451 (1.46451)
=>NAG Epoch: 349 Error: 3.728177070617676, Time:     0.01m
=>epoch:350/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [350][0/1]	Time  0.243 ( 0.243)	Data  0.214 ( 0.214)	Loss 3.72530 (3.72530)	d_Loss 1.46474 (1.464

Epoch: [378][0/1]	Time  0.419 ( 0.419)	Data  0.356 ( 0.356)	Loss 3.67528 (3.67528)	d_Loss 1.46410 (1.46410)
=>NAG Epoch: 378 Error: 3.675276517868042, Time:     0.01m
=>epoch:379/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [379][0/1]	Time  0.314 ( 0.314)	Data  0.246 ( 0.246)	Loss 3.67562 (3.67562)	d_Loss 1.46413 (1.46413)
=>NAG Epoch: 379 Error: 3.6756186485290527, Time:     0.01m
=>consecutive_loss: 1
=>epoch:380/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [380][0/1]	Time  0.326 ( 0.326)	Data  0.251 ( 0.251)	Loss 3.67339 (3.67339)	d_Loss 1.46420 (1.46420)
=>NAG Epoch: 380 Error: 3.6733884811401367, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_380.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_380.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_380.pkl
=>epoch:381/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [381][0/1]	Time  0.294 ( 0.294)	Data  0.227 ( 0.227)	Loss 3.67289 (3.

Epoch: [410][0/1]	Time  0.266 ( 0.266)	Data  0.199 ( 0.199)	Loss 3.63279 (3.63279)	d_Loss 1.46402 (1.46402)
=>NAG Epoch: 410 Error: 3.6327908039093018, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_410.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_410.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_410.pkl
=>epoch:411/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [411][0/1]	Time  0.279 ( 0.279)	Data  0.246 ( 0.246)	Loss 3.63095 (3.63095)	d_Loss 1.46404 (1.46404)
=>NAG Epoch: 411 Error: 3.630949020385742, Time:     0.01m
=>epoch:412/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [412][0/1]	Time  0.282 ( 0.282)	Data  0.212 ( 0.212)	Loss 3.62994 (3.62994)	d_Loss 1.46393 (1.46393)
=>NAG Epoch: 412 Error: 3.629943370819092, Time:     0.01m
=>epoch:413/750  resolution:32 batch:128, lr3.125000e-05
Epoch: [413][0/1]	Time  0.232 ( 0.232)	Data  0.202 ( 0.202)	Loss 3.62842 (3.62842)	d_Loss 1.46392 (

Epoch: [440][0/1]	Time  0.244 ( 0.244)	Data  0.215 ( 0.215)	Loss 3.59546 (3.59546)	d_Loss 1.46368 (1.46368)
=>NAG Epoch: 440 Error: 3.595459461212158, Time:     0.00m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_440.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_440.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_440.pkl
=>epoch:441/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [441][0/1]	Time  0.279 ( 0.279)	Data  0.211 ( 0.211)	Loss 3.59845 (3.59845)	d_Loss 1.46373 (1.46373)
=>NAG Epoch: 441 Error: 3.598452091217041, Time:     0.01m
=>consecutive_loss: 1
=>epoch:442/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [442][0/1]	Time  0.247 ( 0.247)	Data  0.204 ( 0.204)	Loss 3.59605 (3.59605)	d_Loss 1.46373 (1.46373)
=>NAG Epoch: 442 Error: 3.596047878265381, Time:     0.01m
=>consecutive_loss: 2
=>epoch:443/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [443][0/1]	Time  0.236 ( 0.236)	Data  0.207 ( 0.2

Epoch: [472][0/1]	Time  0.300 ( 0.300)	Data  0.238 ( 0.238)	Loss 3.56820 (3.56820)	d_Loss 1.46350 (1.46350)
=>NAG Epoch: 472 Error: 3.5682034492492676, Time:     0.01m
=>epoch:473/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [473][0/1]	Time  0.281 ( 0.281)	Data  0.214 ( 0.214)	Loss 3.56819 (3.56819)	d_Loss 1.46343 (1.46343)
=>NAG Epoch: 473 Error: 3.568190336227417, Time:     0.01m
=>epoch:474/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [474][0/1]	Time  0.282 ( 0.282)	Data  0.214 ( 0.214)	Loss 3.56701 (3.56701)	d_Loss 1.46346 (1.46346)
=>NAG Epoch: 474 Error: 3.5670089721679688, Time:     0.01m
=>epoch:475/750  resolution:32 batch:128, lr1.562500e-05
Epoch: [475][0/1]	Time  0.265 ( 0.265)	Data  0.199 ( 0.199)	Loss 3.56653 (3.56653)	d_Loss 1.46357 (1.46357)
=>NAG Epoch: 475 Error: 3.5665283203125, Time:     0.01m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_475.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_475.pkl
Save chec

Epoch: [503][0/1]	Time  0.231 ( 0.231)	Data  0.199 ( 0.199)	Loss 3.54583 (3.54583)	d_Loss 1.46320 (1.46320)
=>NAG Epoch: 503 Error: 3.545832395553589, Time:     0.00m
=>epoch:504/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [504][0/1]	Time  0.224 ( 0.224)	Data  0.194 ( 0.194)	Loss 3.54697 (3.54697)	d_Loss 1.46335 (1.46335)
=>NAG Epoch: 504 Error: 3.5469653606414795, Time:     0.00m
=>consecutive_loss: 1
=>epoch:505/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [505][0/1]	Time  0.225 ( 0.225)	Data  0.195 ( 0.195)	Loss 3.54582 (3.54582)	d_Loss 1.46336 (1.46336)
=>NAG Epoch: 505 Error: 3.5458202362060547, Time:     0.00m
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netZ_nag_505.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netG_nag_505.pkl
Save checkpoints...! runs/nets_tester.py_cifar100_test_run/netD_nag_505.pkl
=>epoch:506/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [506][0/1]	Time  0.225 ( 0.225)	Data  0.195 ( 0.195)	Loss 3.54294 (3.

Epoch: [535][0/1]	Time  0.318 ( 0.318)	Data  0.288 ( 0.288)	Loss 3.53147 (3.53147)	d_Loss 1.46320 (1.46320)
=>NAG Epoch: 535 Error: 3.531466007232666, Time:     0.01m
=>consecutive_loss: 1
=>epoch:536/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [536][0/1]	Time  0.295 ( 0.295)	Data  0.261 ( 0.261)	Loss 3.52914 (3.52914)	d_Loss 1.46318 (1.46318)
=>NAG Epoch: 536 Error: 3.529144048690796, Time:     0.01m
=>epoch:537/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [537][0/1]	Time  0.525 ( 0.525)	Data  0.493 ( 0.493)	Loss 3.53088 (3.53088)	d_Loss 1.46311 (1.46311)
=>NAG Epoch: 537 Error: 3.5308823585510254, Time:     0.01m
=>consecutive_loss: 1
=>epoch:538/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [538][0/1]	Time  0.239 ( 0.239)	Data  0.208 ( 0.208)	Loss 3.53016 (3.53016)	d_Loss 1.46308 (1.46308)
=>NAG Epoch: 538 Error: 3.5301623344421387, Time:     0.01m
=>consecutive_loss: 2
=>epoch:539/750  resolution:32 batch:128, lr7.812500e-06
Epoch: [539][0/1]	Time  0.549 ( 0.54

Epoch: [567][0/1]	Time  0.263 ( 0.263)	Data  0.193 ( 0.193)	Loss 3.51561 (3.51561)	d_Loss 1.46295 (1.46295)
=>NAG Epoch: 567 Error: 3.51560640335083, Time:     0.01m
=>consecutive_loss: 2
=>epoch:568/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [568][0/1]	Time  0.258 ( 0.258)	Data  0.229 ( 0.229)	Loss 3.51600 (3.51600)	d_Loss 1.46301 (1.46301)
=>NAG Epoch: 568 Error: 3.5160019397735596, Time:     0.01m
=>consecutive_loss: 3
=>epoch:569/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [569][0/1]	Time  0.283 ( 0.283)	Data  0.217 ( 0.217)	Loss 3.51475 (3.51475)	d_Loss 1.46295 (1.46295)
=>NAG Epoch: 569 Error: 3.5147507190704346, Time:     0.01m
=>consecutive_loss: 4
=>epoch:570/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [570][0/1]	Time  0.233 ( 0.233)	Data  0.199 ( 0.199)	Loss 3.51665 (3.51665)	d_Loss 1.46286 (1.46286)
=>NAG Epoch: 570 Error: 3.5166549682617188, Time:     0.00m
=>consecutive_loss: 5
=>epoch:571/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [571][0

Epoch: [600][0/1]	Time  0.241 ( 0.241)	Data  0.212 ( 0.212)	Loss 3.50810 (3.50810)	d_Loss 1.46274 (1.46274)
=>NAG Epoch: 600 Error: 3.5080995559692383, Time:     0.01m
=>consecutive_loss: 3
Show images...
=>epoch:601/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [601][0/1]	Time  0.267 ( 0.267)	Data  0.198 ( 0.198)	Loss 3.50700 (3.50700)	d_Loss 1.46282 (1.46282)
=>NAG Epoch: 601 Error: 3.506997585296631, Time:     0.01m
=>consecutive_loss: 4
=>epoch:602/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [602][0/1]	Time  0.246 ( 0.246)	Data  0.217 ( 0.217)	Loss 3.50484 (3.50484)	d_Loss 1.46276 (1.46276)
=>NAG Epoch: 602 Error: 3.5048398971557617, Time:     0.01m
=>epoch:603/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [603][0/1]	Time  0.275 ( 0.275)	Data  0.209 ( 0.209)	Loss 3.50838 (3.50838)	d_Loss 1.46280 (1.46280)
=>NAG Epoch: 603 Error: 3.5083820819854736, Time:     0.01m
=>consecutive_loss: 1
=>epoch:604/750  resolution:32 batch:128, lr3.906250e-06
Epoch: [604][0/1]	Ti

Epoch: [633][0/1]	Time  0.308 ( 0.308)	Data  0.245 ( 0.245)	Loss 3.49934 (3.49934)	d_Loss 1.46264 (1.46264)
=>NAG Epoch: 633 Error: 3.4993443489074707, Time:     0.01m
=>consecutive_loss: 1
=>epoch:634/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [634][0/1]	Time  0.309 ( 0.309)	Data  0.246 ( 0.246)	Loss 3.49861 (3.49861)	d_Loss 1.46265 (1.46265)
=>NAG Epoch: 634 Error: 3.498609781265259, Time:     0.01m
=>consecutive_loss: 2
=>epoch:635/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [635][0/1]	Time  0.257 ( 0.257)	Data  0.196 ( 0.196)	Loss 3.49881 (3.49881)	d_Loss 1.46271 (1.46271)
=>NAG Epoch: 635 Error: 3.4988129138946533, Time:     0.01m
=>consecutive_loss: 3
=>epoch:636/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [636][0/1]	Time  0.259 ( 0.259)	Data  0.195 ( 0.195)	Loss 3.49923 (3.49923)	d_Loss 1.46263 (1.46263)
=>NAG Epoch: 636 Error: 3.4992270469665527, Time:     0.01m
=>consecutive_loss: 4
=>epoch:637/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [637][

Epoch: [666][0/1]	Time  0.263 ( 0.263)	Data  0.196 ( 0.196)	Loss 3.49539 (3.49539)	d_Loss 1.46259 (1.46259)
=>NAG Epoch: 666 Error: 3.4953861236572266, Time:     0.01m
=>consecutive_loss: 4
=>epoch:667/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [667][0/1]	Time  0.242 ( 0.242)	Data  0.213 ( 0.213)	Loss 3.49515 (3.49515)	d_Loss 1.46256 (1.46256)
=>NAG Epoch: 667 Error: 3.4951491355895996, Time:     0.01m
=>consecutive_loss: 5
=>epoch:668/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [668][0/1]	Time  0.274 ( 0.274)	Data  0.204 ( 0.204)	Loss 3.49230 (3.49230)	d_Loss 1.46242 (1.46242)
=>NAG Epoch: 668 Error: 3.492302894592285, Time:     0.01m
=>epoch:669/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [669][0/1]	Time  0.224 ( 0.224)	Data  0.193 ( 0.193)	Loss 3.49532 (3.49532)	d_Loss 1.46252 (1.46252)
=>NAG Epoch: 669 Error: 3.49532413482666, Time:     0.00m
=>consecutive_loss: 1
=>epoch:670/750  resolution:32 batch:128, lr1.953125e-06
Epoch: [670][0/1]	Time  0.236 ( 0.236

Epoch: [699][0/1]	Time  0.237 ( 0.237)	Data  0.202 ( 0.202)	Loss 3.49183 (3.49183)	d_Loss 1.46238 (1.46238)
=>NAG Epoch: 699 Error: 3.4918336868286133, Time:     0.00m
=>consecutive_loss: 8
=>epoch:700/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [700][0/1]	Time  0.278 ( 0.278)	Data  0.242 ( 0.242)	Loss 3.49099 (3.49099)	d_Loss 1.46239 (1.46239)
=>NAG Epoch: 700 Error: 3.4909863471984863, Time:     0.01m
=>consecutive_loss: 9
Show images...
=>epoch:701/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [701][0/1]	Time  0.254 ( 0.254)	Data  0.225 ( 0.225)	Loss 3.48974 (3.48974)	d_Loss 1.46237 (1.46237)
=>NAG Epoch: 701 Error: 3.4897377490997314, Time:     0.01m
=>epoch:702/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [702][0/1]	Time  0.274 ( 0.274)	Data  0.207 ( 0.207)	Loss 3.49102 (3.49102)	d_Loss 1.46244 (1.46244)
=>NAG Epoch: 702 Error: 3.491016149520874, Time:     0.01m
=>consecutive_loss: 1
=>epoch:703/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [703][0/1]	Ti

Epoch: [733][0/1]	Time  0.298 ( 0.298)	Data  0.233 ( 0.233)	Loss 3.49031 (3.49031)	d_Loss 1.46226 (1.46226)
=>NAG Epoch: 733 Error: 3.490309238433838, Time:     0.01m
=>consecutive_loss: 10
=>epoch:734/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [734][0/1]	Time  0.287 ( 0.287)	Data  0.225 ( 0.225)	Loss 3.48904 (3.48904)	d_Loss 1.46225 (1.46225)
=>NAG Epoch: 734 Error: 3.4890365600585938, Time:     0.01m
=>consecutive_loss: 11
=>epoch:735/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [735][0/1]	Time  0.290 ( 0.290)	Data  0.222 ( 0.222)	Loss 3.48823 (3.48823)	d_Loss 1.46228 (1.46228)
=>NAG Epoch: 735 Error: 3.488226890563965, Time:     0.01m
=>consecutive_loss: 12
=>epoch:736/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [736][0/1]	Time  0.287 ( 0.287)	Data  0.222 ( 0.222)	Loss 3.48859 (3.48859)	d_Loss 1.46231 (1.46231)
=>NAG Epoch: 736 Error: 3.4885926246643066, Time:     0.01m
=>consecutive_loss: 13
=>epoch:737/750  resolution:32 batch:128, lr9.765625e-07
Epoch: [73

In [6]:
import importlib
import interpolate
importlib.reload(interpolate)
import interpolate
from interpolate import interpolate_points


def getGenImage(z1,z2,gen,steps,codeSize,noise_std=0.15):
    """
    Summary: Decodes images along an interpolation path between two latent codes
    == params ==
    z1, z2: the two endpoint latent codes (torch tensors when called from GlicoLoader)
    gen: the generator, called as gen(interpolated_latents, noise_code)
    steps: number of interpolation points, also the batch size of the noise code
    codeSize: width of the per-image noise code fed to gen
    noise_std: standard deviation of the Gaussian noise code (default 0.15,
        the value previously hard-coded)
    == output ==
    whatever gen returns for the batch of interpolated latents
    """
    interp = interpolate_points(z1,z2,n_steps=steps, slerp=True, print_mode=False)
    # Allocate the noise next to the latents instead of hard-coding CUDA through the
    # legacy torch.cuda.FloatTensor constructor; float32 matches that constructor.
    device = z1.device if torch.is_tensor(z1) else torch.device("cuda")
    code = torch.empty(steps, codeSize, dtype=torch.float32, device=device).normal_(0, noise_std)
    return gen(interp,code)

In [7]:
import time
import itertools
class GlicoLoader:
    """
    Summary: Serves pre-computed GLICO interpolations as replacement training images.

    On construction, every pair of latent codes that netZ.label2idx assigns to the
    same class is interpolated with getGenImage (`steps` images per pair) and decoded
    by the GLICO generator. The decoded images are cached per class in
    self.interps[classNum], concatenated along dim 0. `sample` walks through a
    class's cache in a shuffled order and reshuffles once every cached image of
    that class has been served.
    == params ==
    nagTrainer: the trained GLICO trainer; its .nag provides netZ, netG and code_size
    steps: number of interpolation steps per latent pair
    num_classes: number of classes to cache (default 10, the value previously hard-coded)
    seed: seed for the shuffling RNG; None seeds from the current time, as before
    """

    def getInterpolations(self):
        """Decode the interpolations of every same-class latent pair into self.interps."""
        self.interps = {}
        for classNum in range(self.num_classes):
            pairs = list(itertools.combinations(self.netZ.label2idx[classNum], 2))
            # The cached images are constants: skip autograd graph construction for
            # netZ/netG so activations for every pair are not kept alive on the GPU.
            with torch.no_grad():
                zvecs = self.netZ(torch.tensor(pairs).cuda())
                images = [
                    getGenImage(zvecs[idx,0,:], zvecs[idx,1,:], self.netG,
                                steps=self.steps, codeSize=self.nag.code_size).detach()
                    for idx in range(zvecs.size(0))
                ]
            self.interps[classNum] = torch.cat(images, dim=0)

    def __init__(self,nagTrainer,steps,num_classes=10,seed=None):
        self.nagTrainer = nagTrainer
        self.nag = nagTrainer.nag
        self.netZ = self.nag.netZ
        self.netG = self.nag.netG
        self.steps = steps
        self.num_classes = num_classes
        self.getInterpolations()

        # A single RNG for all shuffles: re-seeding from time.time() on every
        # reshuffle repeats the same order whenever two reshuffles share a second.
        self.prng = RandomState(int(time.time()) if seed is None else seed)
        # Per class: position of the next image within its current shuffled order.
        self.indices = [0] * num_classes
        self.indexers = [self._permutation(c) for c in range(num_classes)]

    def _permutation(self, classNum):
        """Return a fresh random ordering of the cached image indices of classNum."""
        return list(self.prng.permutation(self.interps[classNum].size(0)))

    def sample(self,classNum):
        """
        Return the next cached interpolated image of class classNum.

        Images are served without replacement in a shuffled order; after every
        image of the class has been returned, the class is reshuffled.
        """
        # Size the check by this class's own cache, not class 0's.
        if self.indices[classNum] >= len(self.indexers[classNum]):
            self.indices[classNum] = 0
            self.indexers[classNum] = self._permutation(classNum)

        pick = self.indexers[classNum][self.indices[classNum]]
        # Advance, otherwise every call would return the same image.
        self.indices[classNum] += 1
        return self.interps[classNum][pick,:,:,:]

    def replaceBatch(self,batch,targets,replaceProb):
        """
        Replace, in place, each image of batch with probability replaceProb by a
        cached interpolation of the same class (taken from targets). Returns batch.
        """
        for x in range(batch.size(0)):
            if np.random.rand() < replaceProb:
                batch[x,:,:,:] = self.sample(targets[x].item())
        return batch
        
        
        
        

In [8]:
# Cache GLICO interpolations (5 steps per same-class latent pair) for augmentation.
glicoLoader = GlicoLoader(nagTrainer, steps=5)

In [9]:
import matplotlib

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]
c_to_idx = {name: idx for idx, name in enumerate(CIFAR10_CLASSES)}
idx_to_c = {v: k for k, v in c_to_idx.items()}

save = True
SAMPLES_PER_CLASS = 5
SAMPLE_DIR = "."  # directory the exported PDFs are written to

# Non-interactive backend: figures are only written to disk.
# NOTE: this switches the backend for the rest of the kernel session too.
matplotlib.use('agg')

for repeat in range(SAMPLES_PER_CLASS):
    for class_ in range(len(idx_to_c)):
        gen = glicoLoader.sample(classNum=class_)
        # Use a fresh figure per image. Drawing into the implicit current axes would
        # stack every previously drawn image underneath, bloating each saved PDF.
        fig, ax = plt.subplots()
        # assumes gen is (C, H, W) with values imshow can render — TODO confirm range
        ax.imshow(gen.cpu().permute(1, 2, 0))
        ax.axis("off")
        if save:
            fig.savefig(os.path.join(SAMPLE_DIR, "{}{}.pdf".format(idx_to_c[class_], repeat)),
                        bbox_inches="tight", pad_inches=0)
        plt.close(fig)

In [None]:
# Free memory left over from earlier cells before building the classifier. Collect
# unreachable Python objects first so that the CUDA tensors they hold are actually
# released, then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

MODEL = "vgg16"

model = getModel(MODEL).cuda()
# Only the classifier head's parameters are handed to the optimizer.
optimizer,LR = getOptimizer128(OPTIM,model.classifier.parameters())

totalParams = sum(p.numel() for p in model.parameters())
optimizedParams = sum(p.numel() for p in model.classifier.parameters())
print(' => Total parameters: %.2fM, optimized (classifier) parameters: %.2fM'
      % (totalParams / 1000000.0, optimizedParams / 1000000.0))

trainTracker = {"meanLoss":[],"accuracy":[]}
valTracker = {"allLoss":[],"allAcc":[],"meanLoss":[],"meanAcc":[],"stdAcc":[]}
latexTracker = []

print("Begin Train for {} epochs".format(EPOCH_NUM))
for epoch in range(EPOCH_NUM):
    # NOTE(review): confirm glicoTrain returns (accuracy, loss) in this order; the
    # `train` helper in this notebook documents (mean loss, accuracy).
    acc, loss = glicoTrain(model, device, train_data[0], optimizer, epoch+1,
                           glicoLoader, replaceProb=REPLACE_PROB, display=True)
    trainTracker["meanLoss"].append(loss)
    trainTracker["accuracy"].append(acc)

    # Evaluate on all validation sets every VAL_DISPLAY_DIVISOR epochs.
    if (epoch+1) % VAL_DISPLAY_DIVISOR == 0:
        checkTest(model,device,valSets,valTracker,latexTracker,epoch+1,
                  model_name=MODEL,optim_name=OPTIM,lr=LR,
                  totalTestSamples=VAL_SAMPLE_NUM*VALIDATION_SET_NUM,
                  seed=seed,verbose=True)
        
          
        
        

 => Total trainable parameters: 138.36M
Begin Train for 2000 epochs
Train Epoch: 1 [acc: 0%]	Loss: 20.194527
Train Epoch: 2 [acc: 0%]	Loss: 14.043938
Train Epoch: 3 [acc: 2%]	Loss: 8.105101
Train Epoch: 4 [acc: 5%]	Loss: 6.198342
Train Epoch: 5 [acc: 14%]	Loss: 5.042351
Train Epoch: 6 [acc: 15%]	Loss: 3.781166
Train Epoch: 7 [acc: 26%]	Loss: 2.757268
Train Epoch: 8 [acc: 27%]	Loss: 2.742980
Train Epoch: 9 [acc: 29%]	Loss: 3.587361
Train Epoch: 10 [acc: 27%]	Loss: 4.209502
Train Epoch: 11 [acc: 22%]	Loss: 6.482852
Train Epoch: 12 [acc: 20%]	Loss: 5.758770
Train Epoch: 13 [acc: 32%]	Loss: 4.295445
Train Epoch: 14 [acc: 27%]	Loss: 3.356111
Train Epoch: 15 [acc: 26%]	Loss: 3.379940
Train Epoch: 16 [acc: 32%]	Loss: 3.305503
Train Epoch: 17 [acc: 33%]	Loss: 2.569059
Train Epoch: 18 [acc: 32%]	Loss: 2.555778
Train Epoch: 19 [acc: 37%]	Loss: 2.811233
Train Epoch: 20 [acc: 34%]	Loss: 3.041678
Train Epoch: 21 [acc: 33%]	Loss: 2.736672
Train Epoch: 22 [acc: 33%]	Loss: 2.744677
Train Epoch: 23 [ac

Train Epoch: 176 [acc: 61%]	Loss: 1.096308
Train Epoch: 177 [acc: 53%]	Loss: 1.427940
Train Epoch: 178 [acc: 61%]	Loss: 1.188694
Train Epoch: 179 [acc: 57%]	Loss: 1.253952
Train Epoch: 180 [acc: 57%]	Loss: 1.285730
Train Epoch: 181 [acc: 52%]	Loss: 1.405145
Train Epoch: 182 [acc: 66%]	Loss: 0.973487
Train Epoch: 183 [acc: 62%]	Loss: 1.116121
Train Epoch: 184 [acc: 58%]	Loss: 1.292201
Train Epoch: 185 [acc: 63%]	Loss: 1.167402
Train Epoch: 186 [acc: 62%]	Loss: 1.036981
Train Epoch: 187 [acc: 64%]	Loss: 1.118590
Train Epoch: 188 [acc: 60%]	Loss: 1.178627
Train Epoch: 189 [acc: 67%]	Loss: 1.051732
Train Epoch: 190 [acc: 61%]	Loss: 1.239565
Train Epoch: 191 [acc: 68%]	Loss: 1.090055
Train Epoch: 192 [acc: 64%]	Loss: 1.049262
Train Epoch: 193 [acc: 62%]	Loss: 1.128961
Train Epoch: 194 [acc: 66%]	Loss: 1.100367
Train Epoch: 195 [acc: 57%]	Loss: 1.358129
Train Epoch: 196 [acc: 54%]	Loss: 1.449995
Train Epoch: 197 [acc: 61%]	Loss: 1.214399
Train Epoch: 198 [acc: 66%]	Loss: 1.019150
Train Epoch

Train Epoch: 351 [acc: 70%]	Loss: 1.097855
Train Epoch: 352 [acc: 67%]	Loss: 1.047122
Train Epoch: 353 [acc: 67%]	Loss: 1.182599
Train Epoch: 354 [acc: 65%]	Loss: 1.163797
Train Epoch: 355 [acc: 64%]	Loss: 1.003300
Train Epoch: 356 [acc: 62%]	Loss: 1.161421
Train Epoch: 357 [acc: 67%]	Loss: 0.986032
Train Epoch: 358 [acc: 64%]	Loss: 0.936363
Train Epoch: 359 [acc: 64%]	Loss: 1.022990
Train Epoch: 360 [acc: 68%]	Loss: 0.965055
Train Epoch: 361 [acc: 70%]	Loss: 0.922876
Train Epoch: 362 [acc: 67%]	Loss: 1.086470
Train Epoch: 363 [acc: 66%]	Loss: 1.043359
Train Epoch: 364 [acc: 68%]	Loss: 1.022227
Train Epoch: 365 [acc: 62%]	Loss: 1.251088
Train Epoch: 366 [acc: 49%]	Loss: 1.543579
Train Epoch: 367 [acc: 69%]	Loss: 0.943268
Train Epoch: 368 [acc: 63%]	Loss: 0.988857
Train Epoch: 369 [acc: 60%]	Loss: 1.136465
Train Epoch: 370 [acc: 63%]	Loss: 1.060202
Train Epoch: 371 [acc: 73%]	Loss: 0.815846
Train Epoch: 372 [acc: 62%]	Loss: 1.101354
Train Epoch: 373 [acc: 60%]	Loss: 1.106478
Train Epoch

Train Epoch: 526 [acc: 73%]	Loss: 0.934983
Train Epoch: 527 [acc: 65%]	Loss: 1.076161
Train Epoch: 528 [acc: 70%]	Loss: 0.915563
Train Epoch: 529 [acc: 57%]	Loss: 1.232334
Train Epoch: 530 [acc: 69%]	Loss: 0.957576
Train Epoch: 531 [acc: 66%]	Loss: 1.133381
Train Epoch: 532 [acc: 68%]	Loss: 1.043014
Train Epoch: 533 [acc: 67%]	Loss: 1.015037
Train Epoch: 534 [acc: 64%]	Loss: 1.142419
Train Epoch: 535 [acc: 62%]	Loss: 1.119675
Train Epoch: 536 [acc: 61%]	Loss: 1.102306
Train Epoch: 537 [acc: 76%]	Loss: 0.867603
Train Epoch: 538 [acc: 70%]	Loss: 0.958546
Train Epoch: 539 [acc: 65%]	Loss: 0.981098
Train Epoch: 540 [acc: 59%]	Loss: 1.374499
Train Epoch: 541 [acc: 68%]	Loss: 1.106417
Train Epoch: 542 [acc: 68%]	Loss: 0.871956
Train Epoch: 543 [acc: 70%]	Loss: 0.919955
Train Epoch: 544 [acc: 68%]	Loss: 1.080217
Train Epoch: 545 [acc: 67%]	Loss: 1.016346
Train Epoch: 546 [acc: 72%]	Loss: 0.867897
Train Epoch: 547 [acc: 71%]	Loss: 0.830044
Train Epoch: 548 [acc: 67%]	Loss: 0.975339
Train Epoch

Train Epoch: 701 [acc: 71%]	Loss: 0.850260
Train Epoch: 702 [acc: 72%]	Loss: 0.819636
Train Epoch: 703 [acc: 72%]	Loss: 0.827658
Train Epoch: 704 [acc: 69%]	Loss: 0.963281
Train Epoch: 705 [acc: 73%]	Loss: 0.860606
Train Epoch: 706 [acc: 69%]	Loss: 0.941398
Train Epoch: 707 [acc: 76%]	Loss: 0.819529
Train Epoch: 708 [acc: 74%]	Loss: 0.967247
Train Epoch: 709 [acc: 78%]	Loss: 0.622154
Train Epoch: 710 [acc: 75%]	Loss: 0.865966
Train Epoch: 711 [acc: 67%]	Loss: 0.880913
Train Epoch: 712 [acc: 70%]	Loss: 0.860147
Train Epoch: 713 [acc: 66%]	Loss: 1.183740
Train Epoch: 714 [acc: 70%]	Loss: 0.783358
Train Epoch: 715 [acc: 68%]	Loss: 0.991358
Train Epoch: 716 [acc: 79%]	Loss: 0.696332
Train Epoch: 717 [acc: 78%]	Loss: 0.631067
Train Epoch: 718 [acc: 70%]	Loss: 0.949987
Train Epoch: 719 [acc: 74%]	Loss: 0.836529
Train Epoch: 720 [acc: 79%]	Loss: 0.583716
Train Epoch: 721 [acc: 82%]	Loss: 0.686173
Train Epoch: 722 [acc: 71%]	Loss: 0.882945
Train Epoch: 723 [acc: 74%]	Loss: 0.835628
Train Epoch

Train Epoch: 876 [acc: 74%]	Loss: 0.907522
Train Epoch: 877 [acc: 69%]	Loss: 0.843195
Train Epoch: 878 [acc: 72%]	Loss: 0.770099
Train Epoch: 879 [acc: 74%]	Loss: 0.790295
Train Epoch: 880 [acc: 75%]	Loss: 0.884152
Train Epoch: 881 [acc: 67%]	Loss: 0.936270
Train Epoch: 882 [acc: 68%]	Loss: 1.073496
Train Epoch: 883 [acc: 74%]	Loss: 0.686296
Train Epoch: 884 [acc: 73%]	Loss: 0.833042
Train Epoch: 885 [acc: 70%]	Loss: 0.927934
Train Epoch: 886 [acc: 65%]	Loss: 1.081127
Train Epoch: 887 [acc: 66%]	Loss: 0.998859
Train Epoch: 888 [acc: 70%]	Loss: 0.946380
Train Epoch: 889 [acc: 71%]	Loss: 0.776779
Train Epoch: 890 [acc: 80%]	Loss: 0.651798
Train Epoch: 891 [acc: 71%]	Loss: 0.870192
Train Epoch: 892 [acc: 74%]	Loss: 0.707620
Train Epoch: 893 [acc: 73%]	Loss: 0.810595
Train Epoch: 894 [acc: 78%]	Loss: 0.761646
Train Epoch: 895 [acc: 68%]	Loss: 1.001301
Train Epoch: 896 [acc: 69%]	Loss: 0.850196
Train Epoch: 897 [acc: 72%]	Loss: 0.903476
Train Epoch: 898 [acc: 74%]	Loss: 0.876906
Train Epoch

[Trained for 1050 epochs and tested on 5 sets of 2000 images]        Avg Acc: 36.54 +- 1.50 , Avg Loss: 2.53
Train Epoch: 1051 [acc: 79%]	Loss: 0.772821
Train Epoch: 1052 [acc: 73%]	Loss: 0.763609
Train Epoch: 1053 [acc: 74%]	Loss: 0.801019
Train Epoch: 1054 [acc: 74%]	Loss: 0.795360
Train Epoch: 1055 [acc: 66%]	Loss: 0.951094
Train Epoch: 1056 [acc: 76%]	Loss: 0.846877
Train Epoch: 1057 [acc: 69%]	Loss: 0.904248
Train Epoch: 1058 [acc: 67%]	Loss: 0.963460
Train Epoch: 1059 [acc: 76%]	Loss: 0.909132
Train Epoch: 1060 [acc: 65%]	Loss: 0.963713
Train Epoch: 1061 [acc: 75%]	Loss: 0.929165
Train Epoch: 1062 [acc: 72%]	Loss: 0.859142
Train Epoch: 1063 [acc: 69%]	Loss: 0.811398
Train Epoch: 1064 [acc: 68%]	Loss: 0.955605
Train Epoch: 1065 [acc: 74%]	Loss: 0.854418
Train Epoch: 1066 [acc: 76%]	Loss: 0.739295
Train Epoch: 1067 [acc: 68%]	Loss: 0.973077
Train Epoch: 1068 [acc: 67%]	Loss: 1.039064
Train Epoch: 1069 [acc: 65%]	Loss: 1.055751
Train Epoch: 1070 [acc: 69%]	Loss: 0.939855
Train Epoch

In [None]:
# Output directory for the LaTeX table and plots: the most recent LaTeX row with its
# final two characters dropped.
# NOTE(review): assumes latexTracker is non-empty (presumably filled by checkTest) and
# that each row ends with a 2-character terminator — confirm the row format.
lastLatexRow = latexTracker[-1]
dirname = lastLatexRow[:-2]

def writeTex(latexTracker,dirname):
    """
    Summary: Writes all LaTeX table rows to <dirname>/latexTable.txt
    == params ==
    latexTracker: iterable of row strings, written verbatim with no separator added
    dirname: output directory; created (with any missing parents) if needed
    == output ==
    None; an existing latexTable.txt in dirname is overwritten
    """
    os.makedirs(dirname, exist_ok=True)
    # `with` guarantees the file is closed even if a write fails.
    with open(os.path.join(dirname,"latexTable.txt"),"w") as f:
        f.writelines(latexTracker)

writeTex(latexTracker, dirname)

# Also echo every LaTeX row to the cell output.
for row in latexTracker:
    print(row)

In [None]:

# One x value per training epoch (1-based).
epochList = [x+1 for x in range(len(trainTracker["meanLoss"]))]

# x holds epoch numbers, so "Epochs" is the x-axis label and the metric the y-axis label.
plot(xlist=epochList,ylist=trainTracker["meanLoss"],xlab="Epochs",
     ylab="Mean Train Loss",title="Mean Train Loss over Epochs",
     color="#243A92",label="mean train loss",savedir=dirname,save=True)

plot(xlist=epochList,ylist=trainTracker["accuracy"],xlab="Epochs",
     ylab="Train Accuracy",title="Train Accuracy Over Epochs",
     color="#34267E",label="Train Accuracy",savedir=dirname,save=True)

In [None]:

# One x value per validation check (every VAL_DISPLAY_DIVISOR epochs).
epochList = [VAL_DISPLAY_DIVISOR*(x+1) for x in range(len(valTracker["meanLoss"]))]

plot(xlist=epochList,ylist=valTracker["meanLoss"],xlab="Epochs",
     ylab="Mean Val Loss",title="Mean Val Loss over Epochs",
     color="#243A92",label="mean val loss",savedir=dirname,save=True)

plot(xlist=epochList,ylist=valTracker["meanAcc"],xlab="Epochs",
     ylab="Val Accuracy",title="Val Accuracy Over Epochs",
     color="#34267E",label="Val Accuracy",savedir=dirname,save=True)

plot(xlist=epochList,ylist=valTracker["stdAcc"],xlab="Epochs",
     ylab="Val Accuracy Standard Deviation",title="Val Accuracy Standard Deviation Over Epochs",
     color="#34267E",label="Val Accuracy SD",savedir=dirname,save=True)

# NOTE(review): unused in this cell; the number of individual val-set evaluations is
# presumably (EPOCH_NUM // VAL_DISPLAY_DIVISOR) * VALIDATION_SET_NUM — confirm before use.
valSetEvalCount = VAL_DISPLAY_DIVISOR * EPOCH_NUM * VALIDATION_SET_NUM

# One x value per individual val-set evaluation: each validation epoch repeated once
# per validation set (assumes allLoss/allAcc hold one entry per set per check).
epochList = [VAL_DISPLAY_DIVISOR*(x+1)
             for x in range(len(valTracker["meanLoss"]))
             for setIdx in range(VALIDATION_SET_NUM)]

evalDesc = "({} every {} epochs)".format(VALIDATION_SET_NUM,VAL_DISPLAY_DIVISOR)

plot(xlist=epochList,ylist=valTracker["allLoss"],xlab="Epochs",
     ylab="Val Loss",title="Val loss over val set evaluations " + evalDesc,
     color="#34267E",label="Val Loss",savedir=dirname,save=True)

plot(xlist=epochList,ylist=valTracker["allAcc"],xlab="Epochs",
     ylab="Val Accuracy",title="Val accuracy over val set evaluations " + evalDesc,
     color="#34267E",label="Val Accuracy",savedir=dirname,save=True)