In [1]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn.modules.loss import *
from Loss.triplet import *
from session import *
from LR_Schedule.cyclical import Cyclical
from LR_Schedule.cos_anneal import CosAnneal
from LR_Schedule.lr_find import lr_find
from callbacks import *
from validation import *
from validation import _AccuracyMeter
import Datasets.ImageData as ImageData
from Transforms.ImageTransforms import *
import util
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from torch.utils.tensorboard import SummaryWriter
from session import LossMeter, EvalModel
# %matplotlib notebook

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
torch.cuda.set_device(0); torch.backends.cudnn.benchmark=True;

    Found GPU0 GeForce GTX 770 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability that we support is 3.5.
    


In [4]:
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

fulltrainset = torchvision.datasets.CIFAR10(root='/media/drake/MX500/Datasets/cifar-10/train', train=True,
                                        download=True, transform=transform)
trainset = torch.utils.data.dataset.Subset(fulltrainset, np.arange(3200))

fullvalset = torchvision.datasets.CIFAR10(root='/media/drake/MX500/Datasets/cifar-10/test', train=False,
                                       download=True, transform=transform)
valset = torch.utils.data.dataset.Subset(fullvalset, np.arange(3200))

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

In [6]:
class TripletRegularizedCrossEntropyLoss(nn.Module):
    def __init__(self, alpha, margin):     
        super().__init__()
        self.alpha = alpha
        self.margin = margin
        
    def forward(self, x, y):
        loss = F.cross_entropy(x[-1][0], y)
        triplet = 0
        if (self.alpha > 0):
            for layer in x[:-1]:
                # print(layer[0])
                triplet += batch_all_triplet_loss(layer[0].view(layer[0].size(0), -1), y, self.margin)

            triplet *= self.alpha
            
        return loss + triplet

In [7]:
class CustomOneHotAccuracy(OneHotAccuracy):
    def __init__(self):
        super().__init__()
        self.reset()

    def update(self, output, label):
        return super().update(output[-1][0], label)

In [8]:
class EmbeddingSpaceValidator(TrainCallback):
    def __init__(self, val_data, num_embeddings, accuracy_meter_fn):
        self.val_data = val_data
        self.val_accuracy_meter = accuracy_meter_fn()
        self.train_accuracy_meter = accuracy_meter_fn()
        self.num_embeddings=num_embeddings
        
        self.train_accuracies = []
        self.batch_train_accuracies = []
        self.val_accuracies = []
        
        self.train_losses = []
        self.batch_train_losses = []
        self.train_bce_losses = []
        self.val_losses = []
        self.val_bce_losses = []
        
        self.batch_train_embedding_losses = [[] for x in range(self.num_embeddings)]
        self.val_embedding_losses = [[] for x in range(self.num_embeddings)]
        
        self.num_batches = 0
        self.num_epochs = 0
        
        self.epochs = []

    def run(self, session, lossMeter=None):
        self.val_accuracy_meter.reset()
            
        val_loss = LossMeter()
        val_bce_loss = LossMeter()
        embedding_losses = [LossMeter() for x in range(self.num_embeddings)]
        
        with EvalModel(session.model):
            for input, label, *_ in tqdm(self.val_data, desc="Validating", leave=True):
                label = Variable(util.to_gpu(label))
                output = session.forward(input)
                
                step_loss = session.criterion(output, label).data.cpu()
                
                val_loss.update(step_loss, input.shape[0])
                
                val_bce_loss.update(F.cross_entropy(output[-1][0], label).data.cpu(), input.shape[0])
                
                self.val_accuracy_meter.update(output, label)
                    
                for layer, embedding_loss in zip(output[:-1], embedding_losses):
                    embedding_loss.update(batch_all_triplet_loss(layer[0].view(layer[0].size(0), -1), label, 1).data.cpu())
        
        self.val_losses.append(val_loss.raw_avg.item())
        self.val_bce_losses.append(val_bce_loss.raw_avg.item())
         
        accuracy = self.val_accuracy_meter.accuracy()
        
        self.val_accuracies.append(accuracy)
              
        for meter, loss in zip(embedding_losses, self.val_embedding_losses):
            loss.append(meter.raw_avg)     
        
    def on_epoch_begin(self, session):
        self.train_accuracy_meter.reset()     
        self.train_bce_loss_meter = LossMeter()
        
    def on_epoch_end(self, session, lossMeter): 
        self.train_accuracies.append(self.train_accuracy_meter.accuracy())
        self.train_losses.append(lossMeter.debias.data.cpu().item())
        
        self.train_bce_losses.append(self.train_bce_loss_meter.raw_avg.data.cpu().item())
        
        self.run(session, lossMeter) 
        self.epochs.append(self.num_batches)
        self.num_epochs += 1
        
        print("\nval accuracy: ", round(self.val_accuracies[-1], 4),
              "\ntrain loss: ", round(self.train_losses[-1], 4) , 
              " train BCE : ", round(self.train_bce_losses[-1], 4) ,       
              "\nvalid loss: ", round(self.val_losses[-1], 4), 
              " valid BCE : ", round(self.val_bce_losses[-1], 4))
    
    def on_batch_end(self, session, lossMeter, output, label):
        label = Variable(util.to_gpu(label))
        batch_accuracy = self.train_accuracy_meter.update(output, label)
        self.batch_train_accuracies.append(batch_accuracy)
        self.batch_train_losses.append(lossMeter.loss.data.cpu().item())   
        self.train_bce_loss_meter.update(F.cross_entropy(output[-1][0], label).data.cpu(), label.shape[0])
             
        for layer, embedding_loss in zip(output[:-1], self.batch_train_embedding_losses):
            embedding_loss.append(batch_all_triplet_loss(layer[0].view(layer[0].size(0), -1), label, 1).data.cpu().item())
            
        self.num_batches += 1
            
    def plot(self):
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, ncols=1, figsize=(15, 15))
        
        #ax.plot(np.arange(self.num_batches), self.batch_train_accuracies)
        #legend.append("Train accuracy per batch")
        
        #ax.plot(np.arange(self.num_batches), self.batch_train_losses)
        #legend.append("Train loss per batch")
            
        ax1.plot(self.epochs, self.train_accuracies, '-o', label="Training accuracy per epoch")

        ax1.plot(self.epochs, self.val_accuracies, '-o', label="Validation accuracy per epoch")
        
        ax2.plot(self.epochs, self.train_losses, '-o', label="Training loss per epoch")
        
        ax2.plot(self.epochs, self.val_losses, '-o', label="Validation loss per epoch")
        
        ax3.plot(self.epochs, self.train_bce_losses, '-o', label="Training BCE loss per epoch")
        
        ax3.plot(self.epochs, self.val_bce_losses, '-o', label="Validation BCE loss per epoch")
        
        for embedding in self.batch_train_embedding_losses:
            ax4.plot(np.arange(self.num_batches), embedding, label="Train embedding triplet loss per batch")
        
        for embedding in self.val_embedding_losses:
            ax4.plot(self.epochs, embedding, '-o', label="Validation embedding triplet loss per epoch")
            
        for ax in (ax1, ax2, ax3, ax4):
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))   

In [9]:
class SelectiveSequential(nn.Module):
    def __init__(self, to_select, modules_dict):
        super(SelectiveSequential, self).__init__()
        for key, module in modules_dict.items():
            self.add_module(key, module)
        self._to_select = to_select
    
    def forward(self, x):
        list = []
        for name, module in self._modules.items():
            x = module(x)
            if name in self._to_select:
                list.append((x, name))
        return list

In [None]:
  model = SelectiveSequential(
    ['fc1', 'fc2', 'out'],
    {'conv32a': nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
     'batch32a': nn.BatchNorm2d(32),
     'act32a': nn.ReLU(),
    
     'conv32b': nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
     'batch32b': nn.BatchNorm2d(32),
     'act32b': nn.ReLU(),
     
     'max1': nn.MaxPool2d(kernel_size=2, stride=2),
    
     'conv64a': nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
     'batch64a': nn.BatchNorm2d(64),
     'act64a': nn.ReLU(),
     
     'conv64b': nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
     'batch64b': nn.BatchNorm2d(64),
     'act64b': nn.ReLU(),
    
     'max2': nn.MaxPool2d(kernel_size=2, stride=2),
    
     'conv128a': nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
     'batch128a': nn.BatchNorm2d(128),
     'act128a': nn.ReLU(),
     
     'conv128b': nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
     'batch128b': nn.BatchNorm2d(128),
     'act128b': nn.ReLU(),
     
     'max3': nn.MaxPool2d(kernel_size=2, stride=2),
     
     'conv256a': nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
     'batch256a': nn.BatchNorm2d(256),
     'act256a': nn.ReLU(),
     
     'conv256b': nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
     'batch256b': nn.BatchNorm2d(256),
     'act256b': nn.ReLU(),
    
     'flatten': Flatten(),
    
     'fc1': nn.Linear(4 * 4 * 256, 512),
     'actLinear': nn.ReLU(),
     #'drop1': nn.Dropout(.05),
     'fc2': nn.Linear(512, 256),
     'actLin2': nn.ReLU(),
     #'drop1': nn.Dropout(.05),
     'out': nn.Linear(256, 10)})

In [None]:
criterion = TripletRegularizedCrossEntropyLoss(0, .5)

In [None]:
sess = Session(model, criterion, optim.AdamW, 1e-3)

In [None]:
validator = EmbeddingSpaceValidator(valloader, 2, CustomOneHotAccuracy)
lr_scheduler = CosAnneal(len(trainloader) * 15, T_mult=1, lr_min=1e-7)
schedule = TrainingSchedule(trainloader, [lr_scheduler, validator])
sess.train(schedule, 15)

In [None]:
validator.plot()

In [None]:
np.max(validator.val_accuracies), "Best accuracy without reg"

In [22]:
model2 = SelectiveSequential(
    ['act1', 'act2', 'out'],
    {'conv32a': nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
     # 'batch32a': nn.BatchNorm2d(32),
     'act32a': nn.ReLU(),
    
     'conv32b': nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
     # 'batch32b': nn.BatchNorm2d(32),
     'act32b': nn.ReLU(),
     
     'max1': nn.MaxPool2d(kernel_size=2, stride=2),
    
     'conv64a': nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
     # 'batch64a': nn.BatchNorm2d(64),
     'act64a': nn.ReLU(),
     
     'conv64b': nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
     # 'batch64b': nn.BatchNorm2d(64),
     'act64b': nn.ReLU(),
    
     'max2': nn.MaxPool2d(kernel_size=2, stride=2),
    
     'conv128a': nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
     # 'batch128a': nn.BatchNorm2d(128),
     'act128a': nn.ReLU(),
     
     'conv128b': nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
     # 'batch128b': nn.BatchNorm2d(128),
     'act128b': nn.ReLU(),
     
     'max3': nn.MaxPool2d(kernel_size=2, stride=2),
     
     # 'conv256a': nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
     # 'batch256a': nn.BatchNorm2d(256),
     # 'act256a': nn.ReLU(),
     
     # 'conv256b': nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
     # 'batch256b': nn.BatchNorm2d(256),
     # 'act256b': nn.ReLU(),
    
     'flatten': Flatten(),
    
     'fc1': nn.Linear(4 * 4 * 128, 512),
     'act1': nn.ReLU(),
     #'drop1': nn.Dropout(.05),
     'fc2': nn.Linear(512, 256),
     'act2': nn.ReLU(),
     #'drop1': nn.Dropout(.05),
     'out': nn.Linear(256, 10)})

In [23]:
criterion = TripletRegularizedCrossEntropyLoss(.25, .5)

In [24]:
sess = Session(model2, criterion, optim.AdamW, 1e-3)

In [25]:
# lr_find(sess, trainloader, start_lr=1e-12)

In [26]:
sess.set_lr(1e-3)

In [None]:
validator2 = EmbeddingSpaceValidator(valloader, 1, CustomOneHotAccuracy)
lr_scheduler2 = CosAnneal(len(trainloader) * 63, T_mult=1, lr_min=1e-7)
schedule2 = TrainingSchedule(trainloader, [validator2])
sess.train(schedule2, 63)

HBox(children=(IntProgress(value=0, description='Epochs', max=63, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.12it/s]


val accuracy:  0.2009 
train loss:  2.519  train BCE :  2.2057 
valid loss:  2.3905  valid BCE :  2.0986





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.34it/s]


val accuracy:  0.3178 
train loss:  2.3672  train BCE :  1.9122 
valid loss:  2.2205  valid BCE :  1.7912





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 34.70it/s]



val accuracy:  0.3856 
train loss:  2.1969  train BCE :  1.7018 
valid loss:  2.1263  valid BCE :  1.6359


HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 35.00it/s]


val accuracy:  0.4156 
train loss:  2.0647  train BCE :  1.572 
valid loss:  1.9779  valid BCE :  1.5928





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.83it/s]


val accuracy:  0.4544 
train loss:  1.9275  train BCE :  1.4485 
valid loss:  1.9087  valid BCE :  1.4742





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.92it/s]


val accuracy:  0.4272 
train loss:  1.8158  train BCE :  1.3452 
valid loss:  1.9329  valid BCE :  1.5424





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 37.28it/s]


val accuracy:  0.4431 
train loss:  1.6814  train BCE :  1.2118 
valid loss:  2.0235  valid BCE :  1.5647





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.60it/s]


val accuracy:  0.4906 
train loss:  1.569  train BCE :  1.1199 
valid loss:  1.805  valid BCE :  1.3752





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.67it/s]


val accuracy:  0.4997 
train loss:  1.4371  train BCE :  0.9897 
valid loss:  1.8281  valid BCE :  1.393





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 33.53it/s]


val accuracy:  0.4909 
train loss:  1.2893  train BCE :  0.8339 
valid loss:  1.891  valid BCE :  1.4401





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 33.68it/s]



val accuracy:  0.5266 
train loss:  1.158  train BCE :  0.7326 
valid loss:  1.9081  valid BCE :  1.3949


HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 31.68it/s]


val accuracy:  0.5203 
train loss:  1.0146  train BCE :  0.5643 
valid loss:  1.9708  valid BCE :  1.4614





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 33.89it/s]


val accuracy:  0.5256 
train loss:  0.8726  train BCE :  0.4332 
valid loss:  2.0793  valid BCE :  1.5263





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.24it/s]



val accuracy:  0.5178 
train loss:  0.7673  train BCE :  0.3638 
valid loss:  2.1607  valid BCE :  1.6212


HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 34.84it/s]


val accuracy:  0.5306 
train loss:  0.6816  train BCE :  0.292 
valid loss:  2.2287  valid BCE :  1.6531





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 32.75it/s]


val accuracy:  0.5184 
train loss:  0.5787  train BCE :  0.192 
valid loss:  2.4266  valid BCE :  1.8342





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 35.01it/s]


val accuracy:  0.5128 
train loss:  0.5127  train BCE :  0.1562 
valid loss:  2.5441  valid BCE :  1.9533





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 35.41it/s]


val accuracy:  0.5188 
train loss:  0.4879  train BCE :  0.1585 
valid loss:  2.4315  valid BCE :  1.8336





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

Validating: 100%|██████████| 50/50 [00:01<00:00, 36.04it/s]


val accuracy:  0.5272 
train loss:  0.4526  train BCE :  0.116 
valid loss:  2.7202  valid BCE :  2.0544





HBox(children=(IntProgress(value=0, description='Steps', max=50, style=ProgressStyle(description_width='initia…

In [None]:
validator2.plot()

In [None]:
print(np.max(validator2.val_accuracies), "Best accuracy with reg")
print(np.max(validator.val_accuracies), "Best accuracy without reg")