In [None]:
!git clone https://github.com/baiydaavi/vonenet.git
%cd vonenet/
!git checkout avi

Cloning into 'vonenet'...
remote: Enumerating objects: 184, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 184 (delta 13), reused 21 (delta 10), pack-reused 160[K
Receiving objects: 100% (184/184), 508.06 MiB | 35.32 MiB/s, done.
Resolving deltas: 100% (87/87), done.
Checking out files: 100% (27/27), done.
/content/vonenet
Already on 'avi'
Your branch is up to date with 'origin/avi'.


In [None]:
import glob
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from google.colab import drive, files
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from vonenet import ResNet18, VOneNet

# Log in to Weights & Biases to monitor training

In [None]:
%%capture
# Install/upgrade the Weights & Biases client and authenticate, so that
# TrainModel can log per-epoch metrics with wandb.log. %%capture suppresses
# this cell's output (e.g. the pip install log).
# NOTE(review): %%capture also hides anything wandb.login() prints; if no API
# key is cached yet, the prompt may be invisible — confirm, or drop %%capture
# on the first run.
!pip install wandb --upgrade
import wandb
wandb.login()

# Download the Tiny ImageNet data

In [None]:
if not os.path.exists('/content/tiny-imagenet-200/'):
    !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
    !unzip -q tiny-imagenet-200.zip

--2021-09-03 13:28:37--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.68.10
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’


2021-09-03 13:28:54 (13.4 MB/s) - ‘tiny-imagenet-200.zip’ saved [248100043/248100043]



# Define the training and test dataset classes (the test set is Tiny ImageNet's validation split)

In [None]:
class TrainTinyImageNetDataset(Dataset):
    """Tiny ImageNet training split as a map-style dataset.

    Images are discovered under ``<root>/train/<wnid>/*/*.JPEG``. Each image is
    labelled ``id[wnid]``, where ``wnid`` is the name of the class directory
    two levels above the image file.

    Args:
        id: mapping from WordNet id (class directory name) to integer label.
        transform: optional callable applied to the RGB PIL image.
        root: dataset root directory; defaults to the relative path the
            download cell extracts into.
    """

    def __init__(self, id, transform=None, root='tiny-imagenet-200'):
        # Sorted because glob's result order depends on the filesystem; this
        # makes the index -> file mapping reproducible.
        self.filenames = sorted(glob.glob(os.path.join(root, 'train', '*', '*', '*.JPEG')))
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        return len(self.filenames)

    def _label_for(self, img_path):
        """Return the integer label of ``img_path``, read from its class directory."""
        # <root>/train/<wnid>/<subdir>/<file>.JPEG -> <wnid>. Using dirname
        # works for any root depth and either path separator, unlike a fixed
        # split('/') index.
        wnid = os.path.basename(os.path.dirname(os.path.dirname(img_path)))
        return self.id_dict[wnid]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file."""
        img_path = self.filenames[idx]
        # The context manager releases the file handle once the image is decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self._label_for(img_path)
        if self.transform:
            image = self.transform(image)
        return image, label

class TestTinyImageNetDataset(Dataset):
    """Tiny ImageNet validation split, used in this notebook as the test set.

    Images are read from ``<root>/val/images/*.JPEG``. Their labels come from
    ``<root>/val/val_annotations.txt``, whose tab-separated lines hold the image
    file name and its WordNet id in the first two columns.

    Args:
        id: mapping from WordNet id to integer label.
        transform: optional callable applied to the RGB PIL image.
        root: dataset root directory; defaults to the relative path the
            download cell extracts into.
    """

    def __init__(self, id, transform=None, root='tiny-imagenet-200'):
        # Sorted because glob's order is filesystem-dependent; with a
        # non-shuffling loader this makes the evaluation order deterministic.
        self.filenames = sorted(glob.glob(os.path.join(root, 'val', 'images', '*.JPEG')))
        self.transform = transform
        self.id_dict = id
        # image file name -> integer label
        self.cls_dic = {}
        annotations_path = os.path.join(root, 'val', 'val_annotations.txt')
        with open(annotations_path, 'r') as annotations:  # closed once parsed
            for line in annotations:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                fields = line.split('\t')
                img, cls_id = fields[0], fields[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        return len(self.filenames)

    def _label_for(self, img_path):
        """Return the integer label of ``img_path``, looked up by its file name."""
        return self.cls_dic[os.path.basename(img_path)]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file."""
        img_path = self.filenames[idx]
        # The context manager releases the file handle once the image is decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self._label_for(img_path)
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
# Map each WordNet id listed in wnids.txt to an integer class label given by
# its (0-based) position in the file.
with open('tiny-imagenet-200/wnids.txt', 'r') as wnids_file:
    # strip() also removes '\r' and stray whitespace; blank lines are skipped.
    wnids = [line.strip() for line in wnids_file if line.strip()]
id_dict = {wnid: label for label, wnid in enumerate(wnids)}
# Tiny ImageNet has 200 classes (EnsembleModel also defaults to num_classes=200).
assert len(id_dict) == 200, f'expected 200 classes in wnids.txt, got {len(id_dict)}'

# Training: random affine jitter and horizontal flips, then ToTensor (values in
# [0, 1]) and Normalize with mean = std = 0.5, which maps each channel to [-1, 1].
transform_train = transforms.Compose([
    transforms.RandomAffine(degrees=30, translate=(0.05, 0.05), scale=(1.0, 1.2)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Ensemble (teacher) model


In [None]:
class EnsembleModel(nn.Module):
    """Teacher ensemble that sums the logits of several pretrained VOneNet-ResNet18 variants.

    Args:
        model_dict: mapping ``variant_name -> {'model_params': {...}, 'model_path': str}``.
            ``model_params`` must provide ``noise_mode``, ``poisson_scale``,
            ``sf_max``, ``sf_min``, ``simple_channels`` and ``complex_channels``.
            ``model_path`` points to a checkpoint whose ``'net'`` entry is the
            member's state dict. Members are built in the mapping's iteration order.
        num_classes: not used by this class; kept for interface compatibility.

    Raises:
        ValueError: if ``model_dict`` is empty (forward needs at least one member).
    """

    def __init__(self, model_dict, num_classes = 200):
        super(EnsembleModel, self).__init__()
        if not model_dict:
            raise ValueError('EnsembleModel needs at least one entry in model_dict')

        self.model_dict = model_dict
        self.models = nn.ModuleList()

        for spec in self.model_dict.values():
            model_params = spec['model_params']

            model = VOneNet(
                model_arch='resnet18',
                noise_mode=model_params['noise_mode'],
                noise_scale=0.286,
                poisson_scale=model_params['poisson_scale'],
                noise_level=0.071,
                image_size=64,
                visual_degrees=2,
                sf_max=model_params['sf_max'],
                sf_min=model_params['sf_min'],
                simple_channels=model_params['simple_channels'],
                complex_channels=model_params['complex_channels'],
                stride=2,
                ksize=25,
                k_exc=23.5
                )

            # Deserialize onto the CPU. Otherwise checkpoints saved from a GPU
            # run fail to load on a CPU-only runtime, or briefly take GPU memory.
            # load_state_dict then copies the tensors into the model's own
            # parameters wherever they live.
            checkpoint = torch.load(spec['model_path'], map_location='cpu')
            model.load_state_dict(checkpoint['net'])

            self.models.append(model)

    def forward(self, x):
        """Return the element-wise sum of every member's output for ``x``."""
        # NOTE(review): this is a sum, not a mean, so the ensemble logits grow
        # with the number of members. That sharpens the temperature-softened
        # distribution used for distillation — confirm it is intended.
        out = self.models[0](x)
        for model in self.models[1:]:
            out += model(x)

        return out

In [None]:
# Teacher variants used to create the ensemble. They all share one base
# VOneNet front-end configuration and differ only in the fields overridden
# below: spatial-frequency band, noise level, or simple/complex channel split.
_BASE_VARIANT = {'noise_mode': 'neuronal', 'poisson_scale': 1.0, 'sf_max': 11.3,
                 'sf_min': 0.0, 'simple_channels': 256, 'complex_channels': 256}
_VARIANT_OVERRIDES = {
    'sf_low': {'sf_max': 2.0, 'sf_min': 0},
    'sf_mid': {'sf_max': 5.6, 'sf_min': 2.0},
    'sf_high': {'sf_min': 5.6},
    'no_noise': {'noise_mode': None, 'poisson_scale': 0.0},
    'low_noise': {'poisson_scale': 0.5},
    'normal_noise': {},
    'only_simple': {'simple_channels': 512, 'complex_channels': 0},
    'only_complex': {'simple_channels': 0, 'complex_channels': 512},
}
# Merging onto the base keeps its key order, so every variant lists its
# fields in the same order; each variant gets its own fresh dict.
param_dict = {name: {**_BASE_VARIANT, **overrides}
              for name, overrides in _VARIANT_OVERRIDES.items()}

# Each teacher's best checkpoint is stored as <variant name>.pth in this directory.
_BEST_MODELS_DIR = '/content/drive/MyDrive/github/vonenet/checkpoint/best_models'
model_dict = {name: {'model_params': params,
                     'model_path': f'{_BEST_MODELS_DIR}/{name}.pth'}
              for name, params in param_dict.items()}

# Train the distilled (student) model

In [None]:
class TrainModel:
    """Distil the VOneNet teacher ensemble into a single noise-free VOneNet-ResNet18.

    The student is trained on Tiny ImageNet with a weighted sum of two terms:
    a KL-divergence term between the temperature-softened teacher and student
    distributions, and a cross-entropy term against the hard labels.

    Relies on notebook-level names defined in other cells: ``device``,
    ``id_dict``, ``transform_train``, ``transform_test``, ``VOneNet``,
    ``EnsembleModel``, the two Tiny ImageNet dataset classes, and ``wandb``.

    Args:
        model_dict: teacher specification forwarded to ``EnsembleModel``.
        batch_size: batch size of both the training and the test loader.
        save_dir: directory that receives a ``<epoch>.pth`` checkpoint each time
            validation accuracy improves. It is created (with any missing
            parents) if needed. When ``None``, no checkpoints are written.
    """

    def __init__(self, model_dict, batch_size=128, save_dir=None):

        self.EnsembleModel = EnsembleModel(model_dict)

        # Student: VOneNet-ResNet18 front end with noise disabled.
        self.DistilledModel = VOneNet(model_arch='resnet18',
                                      noise_mode=None,
                                      noise_scale=0.286,
                                      poisson_scale = 0.0,
                                      noise_level=0.071,
                                      image_size=64,
                                      visual_degrees=2,
                                      sf_max=11.3,
                                      stride=2,
                                      ksize=25,
                                      k_exc=23.5)

        # Both inputs to kl_div_loss are log-probabilities (log_target=True).
        # NOTE(review): the default reduction='mean' divides by batch size *and*
        # number of classes (PyTorch recommends 'batchmean' for the true KL).
        # soft_targets_weight=100 appears tuned against that scale, so change
        # the two together, if at all.
        self.kl_div_loss = nn.KLDivLoss(log_target=True)
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = 5.
        self.soft_targets_weight = 100.
        self.label_loss_weight = 0.5
        self.optimizer = optim.SGD(self.DistilledModel.parameters(), lr=0.1,
                                   momentum=0.9, weight_decay=5e-4)
        # Lower the LR once the mean validation loss has stopped improving for
        # more than `patience` epochs (stepped once per epoch in `train`).
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience = 5)

        self.trainset = TrainTinyImageNetDataset(id=id_dict, transform=transform_train)
        self.testset = TestTinyImageNetDataset(id=id_dict, transform=transform_test)
        self.trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=batch_size, shuffle=True)
        self.testloader = torch.utils.data.DataLoader(self.testset, batch_size=batch_size, shuffle=False)

        self.train_loss_vec = []
        self.train_acc_vec = []
        self.test_loss_vec = []
        self.test_acc_vec = []
        self.best_acc = 0  # best validation accuracy seen so far

        self.save_dir = save_dir
        if self.save_dir is not None:
            # makedirs (unlike mkdir) also creates missing parent directories;
            # exist_ok makes re-running the cell harmless.
            os.makedirs(self.save_dir, exist_ok=True)

    def _distillation_loss(self, outputs, ensemble_logits, targets):
        """Return the weighted sum of the soft-target KL term and the hard-label loss.

        Args:
            outputs: student logits for the batch.
            ensemble_logits: teacher logits for the same batch.
            targets: integer class labels.
        """
        soft_targets = F.log_softmax(ensemble_logits / self.temperature, dim=-1)
        soft_prob = F.log_softmax(outputs / self.temperature, dim=-1)
        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)
        label_loss = self.criterion(outputs, targets)
        return self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss

    def _load_pretrained(self, pretrained_path):
        """Load student weights from a checkpoint written by ``train``.

        Raises:
            ValueError: if ``pretrained_path`` is None.
        """
        if pretrained_path is None:
            raise ValueError('pretrained_path is required to load a pretrained model')
        print('loading pretrained model')
        # Deserialize on the CPU; load_state_dict copies into the student's
        # parameters wherever they currently live.
        trained_model_weights = torch.load(pretrained_path, map_location='cpu')
        self.DistilledModel.load_state_dict(trained_model_weights['net'])

    def train(self, epoch):
        """Run one training epoch, then validate, log to wandb, and checkpoint on improvement."""

        # training
        print('\nEpoch: %d' % epoch)
        self.EnsembleModel.eval()  # the teacher is only ever used for inference
        self.DistilledModel.train()
        train_loss = 0
        correct = 0
        total = 0
        for inputs, targets in self.trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            with torch.no_grad():
                ensemble_logits = self.EnsembleModel(inputs)
            self.optimizer.zero_grad()
            outputs = self.DistilledModel(inputs)
            loss = self._distillation_loss(outputs, ensemble_logits, targets)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        # Mean per-batch loss; max(..., 1) guards against an empty loader.
        training_loss = train_loss / max(len(self.trainloader), 1)
        training_acc = 100. * correct / max(total, 1)

        print(f'Training Loss: {training_loss} | Training Acc: {training_acc} ({correct}/{total})')
        self.train_loss_vec.append(training_loss)
        self.train_acc_vec.append(training_acc)

        # validation
        self.DistilledModel.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in self.testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                ensemble_logits = self.EnsembleModel(inputs)
                outputs = self.DistilledModel(inputs)
                loss = self._distillation_loss(outputs, ensemble_logits, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        validation_loss = test_loss / max(len(self.testloader), 1)
        validation_acc = 100. * correct / max(total, 1)

        # Step on the mean loss: the same quantity that is printed and logged.
        self.scheduler.step(validation_loss)

        print(f'Validation Loss: {validation_loss} | Validation Acc: {validation_acc} ({correct}/{total})')

        self.test_loss_vec.append(validation_loss)
        self.test_acc_vec.append(validation_acc)

        # logging to wandb
        wandb.log({"Epoch": epoch,
                   "Train Loss": training_loss,
                   "Train Acc": training_acc,
                   "Valid Loss": validation_loss,
                   "Valid Acc": validation_acc})

        # Save checkpoint when validation accuracy improves.
        if validation_acc > self.best_acc:
            if self.save_dir is not None:
                print('Saving..')
                state = {
                    'net': self.DistilledModel.state_dict(),
                    'acc': validation_acc,
                    'epoch': epoch,
                }
                torch.save(state, os.path.join(self.save_dir, f'{epoch}.pth'))
            self.best_acc = validation_acc

    def check_accuracy(self, pretrained_path = None):
        """Print the student's top-1 accuracy on the test loader.

        Args:
            pretrained_path: optional checkpoint to load into the student first.
                Otherwise the current in-memory weights are evaluated.
        """
        if pretrained_path:
            self._load_pretrained(pretrained_path)

        self.DistilledModel = self.DistilledModel.to(device)
        self.DistilledModel.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in self.testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.DistilledModel(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        test_acc = 100. * correct / max(total, 1)

        print(f'Test Accuracy: {test_acc} ({correct}/{total})')

    def run(self, pretrained = False, pretrained_path = None, start_epoch = 0, num_epochs = 80,
            run_name='distillation_no_noise_variant'):
        """Train for ``num_epochs`` epochs inside a wandb run, then print the best validation accuracy.

        Args:
            pretrained: if True, first load the student from ``pretrained_path``.
            pretrained_path: checkpoint to resume from (required when ``pretrained``).
            start_epoch: number of the first epoch (used for logging and checkpoint names).
            num_epochs: how many epochs to train.
            run_name: display name of the wandb run.

        Raises:
            ValueError: if ``pretrained`` is set without a ``pretrained_path``.
        """
        if pretrained:
            self._load_pretrained(pretrained_path)

        self.DistilledModel = self.DistilledModel.to(device)
        self.EnsembleModel = self.EnsembleModel.to(device)

        with wandb.init(name=run_name, project='VOneNet_1'):
            for epoch in range(start_epoch, start_epoch + num_epochs):
                self.train(epoch)

        print(self.best_acc)

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training
save_dir = f'/content/drive/MyDrive/github/vonenet/checkpoint/DistillNoNoise/'
Model = TrainModel(model_dict, batch_size = 256, save_dir = save_dir)
Model.run(start_epoch = 1, num_epochs = 81)

# Testing
Model.check_accuracy()

[34m[1mwandb[0m: Currently logged in as: [33mbaidyaavinash[0m (use `wandb login --relogin` to force relogin)



Epoch: 1


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Training Loss: 4.664870393245726 | Training Acc: 5.688 (5688/100000)
Validation Loss: 4.318284602104863 | Validation Acc: 9.39 (939/10000)
Saving..

Epoch: 2
Training Loss: 3.899117456982508 | Training Acc: 14.895 (14895/100000)
Validation Loss: 4.1448576027833965 | Validation Acc: 12.97 (1297/10000)
Saving..

Epoch: 3
Training Loss: 3.4901606512191656 | Training Acc: 21.436 (21436/100000)
Validation Loss: 3.4738124503365047 | Validation Acc: 21.27 (2127/10000)
Saving..

Epoch: 4
Training Loss: 3.2668826921516674 | Training Acc: 25.364 (25364/100000)
Validation Loss: 3.4677123389666593 | Validation Acc: 22.94 (2294/10000)
Saving..

Epoch: 5
Training Loss: 3.1041984183099265 | Training Acc: 28.219 (28219/100000)
Validation Loss: 3.188102260420594 | Validation Acc: 27.19 (2719/10000)
Saving..

Epoch: 6
Training Loss: 2.9813883676553323 | Training Acc: 30.672 (30672/100000)
Validation Loss: 3.0959047818485694 | Validation Acc: 29.35 (2935/10000)
Saving..

Epoch: 7
Training Loss: 2.8888375