In [1]:
import numpy as np
import torch
from torch import nn, optim
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from Models.selective_sequential import *
from Loss.triplet_regularized import *
from session import *
from LR_Schedule.cyclical import Cyclical
from LR_Schedule.cos_anneal import CosAnneal
from LR_Schedule.lr_find import lr_find
from callbacks import *
from validation import *
import Datasets.ImageData as ImageData
from Transforms.ImageTransforms import *
import util
from session import LossMeter, EvalModel
from Layers.flatten import Flatten
from torch.utils.tensorboard import SummaryWriter

In [2]:
%load_ext autoreload
%autoreload 2

torch.cuda.set_device(0); torch.backends.cudnn.benchmark=True;
torch.cuda.get_device_name(torch.cuda.current_device())

    Found GPU0 GeForce GTX 770 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability that we support is 3.5.
    


'GeForce GTX 770'

In [3]:
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

trainset = torchvision.datasets.CIFAR10(root='/media/drake/MX500/Datasets/cifar-10/train', train=True,
                                        download=True, 
                                        transform=transforms.Compose([
                                            transforms.RandomCrop(32, padding=4),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                        ]))
partial_trainset = torch.utils.data.dataset.Subset(trainset, np.arange(3200))

valset = torchvision.datasets.CIFAR10(root='/media/drake/MX500/Datasets/cifar-10/test', train=False,
                                       download=True, 
                                      transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                        ]))
partial_valset = torch.utils.data.dataset.Subset(valset, np.arange(3200))

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
resnet = models.resnet18(pretrained=False)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Sequential()

select = ['act2', 'out']
model = SelectiveSequential(
    select,
    {
     'resnet': resnet,
     'fc1': nn.Linear(num_ftrs, 1024),
     'act1': nn.ReLU(True),
     'fc2': nn.Linear(1024, 512),
     'act2': nn.ReLU(True),
     'out': nn.Linear(512, 10)})

In [5]:
criterion = TripletRegularizedMultiMarginLoss(0.2, .5, select)
sess = Session(model, criterion, optim.AdamW, 1e-3)

In [6]:
num_epochs = 50
validator = EmbeddingSpaceValidator(valloader, select, CustomOneHotAccuracy, tensorboard_dir="./runs/")
lr_scheduler = CosAnneal(len(trainloader)*50, T_mult=2, lr_min=1e-6)
schedule = TrainingSchedule(trainloader, num_epochs, [lr_scheduler, validator])

In [7]:
sess.train(schedule)

HBox(children=(IntProgress(value=0, description='Epochs', max=50, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.4964 train accuracy:  0.4081 
train loss:  0.4664  train unreg loss :  0.3859 
valid loss:  0.427  valid unreg loss :  0.2774


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.567 train accuracy:  0.5225 
train loss:  0.3825  train unreg loss :  0.269 
valid loss:  0.3681  valid unreg loss :  0.2296


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.6303 train accuracy:  0.5824 
train loss:  0.3356  train unreg loss :  0.2241 
valid loss:  0.3152  valid unreg loss :  0.1924


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.6283 train accuracy:  0.6265 
train loss:  0.3064  train unreg loss :  0.1928 
valid loss:  0.3161  valid unreg loss :  0.1961


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.6708 train accuracy:  0.6541 
train loss:  0.2794  train unreg loss :  0.1742 
valid loss:  0.2845  valid unreg loss :  0.1634


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.701 train accuracy:  0.6771 
train loss:  0.2654  train unreg loss :  0.1599 
valid loss:  0.2668  valid unreg loss :  0.1482


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.7172 train accuracy:  0.6972 
train loss:  0.2538  train unreg loss :  0.1479 
valid loss:  0.2483  valid unreg loss :  0.136


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.7151 train accuracy:  0.7146 
train loss:  0.2498  train unreg loss :  0.1372 
valid loss:  0.2543  valid unreg loss :  0.1401


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.7471 train accuracy:  0.7276 
train loss:  0.2375  train unreg loss :  0.1293 
valid loss:  0.2279  valid unreg loss :  0.1187


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Validating', max=79, style=ProgressStyle(description_width='i…


val accuracy:  0.7378 train accuracy:  0.7447 
train loss:  0.2253  train unreg loss :  0.1203 
valid loss:  0.2406  valid unreg loss :  0.1262


HBox(children=(IntProgress(value=0, description='Steps', max=391, style=ProgressStyle(description_width='initi…

KeyboardInterrupt: 

In [None]:
validator.plot()