In [1]:
import torch.nn as nn 
import torch.nn.functional as F
import torch

from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:40% !important; }</style>"))



In [2]:
def train(model, device, train_loader, optimizer, epoch, display=True):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
    if display:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          epoch, batch_idx * len(data), len(train_loader.dataset),
          100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader,verbose=True):
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    
    if verbose: print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
        
    return 100. * correct / len(test_loader.dataset)

In [3]:
def getValSets(num,random_permute,cifar_data,cifar_data_val):
    sets = []
    for x in range(num):
        s = 200*x 
        indx_val = np.concatenate([np.where(np.array(cifar_data.targets) == class_)[0][random_permute[s:s+200]] for class_ in range(0, 10)])
        subset = Subset(cifar_data_val, indx_val)
        sets.append(torch.utils.data.DataLoader(subset,
                                           batch_size=128, 
                                           shuffle=False))
    return sets

def getLatexRow(seed,net,acc,epoch,lr,dataAug="Nothing"):
    categories = ["seed","network","Accuracy","Epochs","learning rate","data augmentation"]
    row = [str(seed),str(net),str(round(acc,3)),str(epoch),str(lr),str(dataAug)]
    
    c = "&".join(categories)
    r = "&".join(row)
    return "{}\\\\\n{}\\\\\n".format(c,r)

In [4]:
from numpy.random import RandomState
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import Subset
import PIL
import torchvision

from torchvision import datasets, transforms
from wide_resnet import WideResNet
from auto_augment import AutoAugment, Cutout
from models.EfficientNets import EfficientNet

results = {}



VALIDATION_SET_NUM = 1



for epochNum in [375]:
    results[epochNum] = 0
# Avoid cuda out of memory
    torch.cuda.empty_cache()

    import gc
    gc.collect()

    AUGMENT = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])


    if AUGMENT:
        dataAugmentation = [ 
#             torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
#             torchvision.transforms.RandomHorizontalFlip(),
#             torchvision.transforms.RandomRotation(20, resample=PIL.Image.BILINEAR)
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            AutoAugment(),
            Cutout()
        ]
        augment = "Color Jitter+HorizFlip+rotation+AutoAugment"
    else: 
        dataAugmentation = []
        augment = "Nothing"


    transform_val = transforms.Compose([transforms.ToTensor(), normalize]) #careful to keep this one same
    transform_train = transforms.Compose(dataAugmentation + [transforms.ToTensor(), normalize]) 

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    ##### Cifar Data
    cifar_data = datasets.CIFAR10(root='.',train=True, transform=transform_train, download=True)
    #We need two copies of this due to weird dataset api 
    cifar_data_val = datasets.CIFAR10(root='.',train=True, transform=transform_val, download=True)

    # Extract a subset of 100 (class balanced) samples per class

    accsGlobal = []

    for seed in range(1):
        accsLocal = []
        prng = RandomState(seed)
        random_permute = prng.permutation(np.arange(0, 5000))

        indx_train = np.concatenate([np.where(np.array(cifar_data.targets) == classe)[0][random_permute[0:10]] for classe in range(0, 10)])
        indx_val = np.concatenate([np.where(np.array(cifar_data_val.targets) == classe)[0][random_permute[0:10]] for classe in range(0, 10)])
        
        
        valSets = getValSets(VALIDATION_SET_NUM,random_permute,cifar_data,cifar_data_val)

        train_data = Subset(cifar_data, indx_train)
        val_data = Subset(cifar_data_val, indx_val)


        print('Num Samples For Training {} Num Samples For Val {}'.format(train_data.indices.shape[0],val_data.indices.shape[0]))

        train_loader = torch.utils.data.DataLoader(train_data,
                                                 batch_size=128, 
                                                 shuffle=True)

        val_loader = torch.utils.data.DataLoader(val_data,
                                               batch_size=128, 
                                               shuffle=False)



        model = WideResNet(28, 10, num_classes=10)
#         model = EfficientNet.from_name('efficientnet-b5',num_classes=10)
        model.to(device)
#         optimizer = torch.optim.SGD(model.parameters(), 
#                                   lr=0.09, momentum=0.9,
#                                   weight_decay=0.0005)
        print(model.parameters())
        optimizer = torch.optim.Adam(model.parameters(), 
                                  lr=0.0001, weight_decay=0)
        
        

        print("Begin Train for {} epochs".format(epochNum))
        for epoch in range(epochNum):
            
            train(model, device, train_loader, optimizer, epoch, display=True)
            
            if (epoch+1) % 25 == 0: 
                temp = []
                print("epoch {}".format(epoch+1))

        for val_loader in valSets:
            accsLocal.append(test(model, device, val_loader,verbose = False))

        temp = getLatexRow(seed,
                           net="default",
                           acc=round(float(np.mean(accsLocal)),3),
                           epoch=epochNum,
                           lr=0.2,
                           dataAug=augment)
        
        accsGlobal = accsGlobal + accsLocal
        accs = np.array(accsLocal)
        print('[Trained for {} epochs and tested on {} sets of 2000 images] Avg Acc: {:.2f} +- {:.2f}'.format(
            epochNum,VALIDATION_SET_NUM,accs.mean(),accs.std()))

    accsGlobal = np.array(accsGlobal)
  
    
    results[epochNum] = (accs.mean(),accs.std())

Files already downloaded and verified
Files already downloaded and verified
Num Samples For Training 100 Num Samples For Val 100
<generator object Module.parameters at 0x7f15e84cd190>
Begin Train for 375 epochs
epoch 25
epoch 50
epoch 75


KeyboardInterrupt: 

In [None]:
results

In [None]:
original3trainings = {10: (11.84, 0.3160696125855822),
 50: (21.75, 0.7968688725254613),
 100: (23.69, 0.7748548251124211),
 150: (27.059999999999995, 0.7873372847769876),
 200: (22.72, 0.9693296652842113)}

In [None]:
secondWith10TrainingsEach = {100: (24.979999999999997, 0.8562125904236637),
 125: (23.845000000000002, 0.7198784619642397),
 150: (23.574999999999996, 0.7413669806512837),
 175: (28.440000000000005, 1.0403845442911963),
 200: (22.69, 0.9350935782048763)}

In [None]:
thirdwith10 = {175: (24.009999999999998, 0.8882004278314662)}

In [5]:
import torch
import numpy as np

In [6]:
idx = torch.from_numpy(np.arange(64)).cuda()

In [8]:
idx

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63], device='cuda:0')