In [1]:
!pip install pytorch-lightning
!pip install torch_pruning



In [2]:
# Downalod ResNet18 weights
# !wget -nc http://cipizio.it/storage/Nas-ResNet/resnet18_net_e_199.pth
# !mkdir checkpoint
# !mv resnet18_net_e_199.pth ./checkpoint

# Download trained student with gates and stricsigmoid activation on 30 epochs with ADAM
!#wget -nc http://cipizio.it/storage/Nas-ResNet/student_e50_ADAM_stricsig_lre-2_lp_e-4_dw_8.pth
#!mkdir checkpoint
#!mv student_e50_ADAM_stricsig_lre-2_lp_e-4_dw_8.pth ./checkpoint

In [3]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import src.utils as utils
from src.KD import GatedKD
from src.models import ResNet18
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


import torch_pruning as tp

import os
import argparse

In [4]:
# Run-level bookkeeping and the shared hyper-parameter dict used by later cells.
best_acc = 0     # best test accuracy seen so far
start_epoch = 0  # 0, or the epoch restored from a checkpoint when resuming
args = dict(resume=False, lr=0.001, lambda_gating=0.1)

# Seed the global RNGs through Lightning (workers=True also covers dataloader
# workers); the returned seed is shown as the cell output.
seed_everything(42, workers=True)


Global seed set to 42


42

## Dataset CIFAR10

In [5]:
# Data
# Re-seed torch's global RNG so the random train/validation split below is reproducible.
torch.manual_seed(43)
print('==> Preparing data..')

# Normalization statistics shared by every pipeline. Keep them identical to the
# values the pretrained teacher was trained with (TODO confirm against the checkpoint).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random-crop and horizontal-flip augmentation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline (validation and test): no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# The CIFAR-10 training images are loaded twice: once with the augmenting
# transform (for training) and once with the evaluation transform (for validation).
# random_split alone returns subsets of a single dataset, so the validation
# images would also receive random crops/flips and val_loss, which drives
# EarlyStopping, would be noisy.
trainset_augmented = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainset_eval = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=transform_test)

val_size = 5000
train_size = len(trainset_augmented) - val_size

train_subset, val_subset = torch.utils.data.random_split(
    trainset_augmented, [train_size, val_size])
trainset = train_subset
# Same validation indices, but read through the un-augmented copy of the data.
validationset = torch.utils.data.Subset(trainset_eval, val_subset.indices)

train_dataloaders = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
val_dataloaders = torch.utils.data.DataLoader(
    validationset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
test_dataloaders = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2, pin_memory=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


## Model

In [6]:
# Hyper-parameters for the teacher/student pair.
args.update(
    lr=1e-3,              # learning rate given to both models
    lambda_penalty=1e-5,  # passed to the student as use_gates_with_penalty
    distill_weight=0.9,   # passed to GatedKD as distill_weight
)

# Teacher: pretrained ResNet18, optimized with SGD.
teacher_model = ResNet18(
    pretrained=True,
    remove_avg_pool_layer=True,
    lr=args['lr'],
    optim='sgd',
    scheduler_t_max=200,
)

# Student: gated ResNet18, optimized with Adam.
student_model = ResNet18(
    lr=args['lr'],
    optim='adam',
    scheduler_t_max=200,
    use_gates_with_penalty=args['lambda_penalty'],
)

# NOTE(review): the return value is not used below; presumably apply_gates()
# attaches the gating modules to the student in place — confirm in src.models.
a = student_model.apply_gates()

## Train Teacher

In [7]:
# Fine-tune the teacher. EarlyStopping watches val_loss, while min_steps/max_steps
# bound the number of optimizer steps to the 100..200 range.
teacher_early_stop = EarlyStopping(monitor="val_loss", mode="min")
teacher_trainer = Trainer(
    deterministic=True,
    callbacks=[teacher_early_stop],
    min_steps=100,
    max_steps=200,
)
teacher_trainer.fit(teacher_model, train_dataloaders, val_dataloaders)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ResNet           | 11.2 M
1 | criterion     | CrossEntropyLoss | 0     
2 | val_accuracy  | Accuracy         | 0     
3 | test_accuracy | Accuracy         | 0     
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 42
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: 0it [00:00, ?it/s]

torch.Size([128, 10]) torch.Size([128])


  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])
torch.Size([128, 10]) torch.Size([128])


## Train Student with Knowledge Distillation

In [None]:
# Knowledge distillation: GatedKD wraps the teacher and is the module that
# kd_trainer.fit() trains in the next cell.
# NOTE(review): student_model is never passed to GatedKD, yet later cells prune
# student_model and report distiller.calc_improvement(); confirm that GatedKD
# distils into (and measures) the same student network.
# NOTE(review): the third positional argument is test_dataloaders; if GatedKD uses
# it for validation or model selection, the test set leaks into training
# decisions — val_dataloaders may be what was intended.
# NOTE(review): nn.KLDivLoss() uses its default reduction='mean'; the PyTorch docs
# recommend reduction='batchmean' for the mathematically correct KL divergence.
kd_loss_fn = nn.KLDivLoss()
distiller = GatedKD(
    teacher_model,
    train_dataloaders,
    test_dataloaders,
    loss_fn=kd_loss_fn,
    distill_weight=args['distill_weight'],
    temp=20,  # presumably the distillation softmax temperature — confirm in src.KD
)

In [None]:
# Train the distiller. Callbacks and step limits are Trainer constructor options:
# Trainer.fit() only accepts the module and its dataloaders, so passing them to
# fit() raises a TypeError.
# NOTE(review): "val_kd_loss" must be logged by GatedKD's validation step,
# otherwise EarlyStopping cannot find the monitored metric — confirm in src.KD.
kd_trainer = Trainer(
    deterministic=True,
    callbacks=[EarlyStopping(monitor="val_kd_loss", mode="min")],
    min_steps=30,
    max_steps=100,
)
kd_trainer.fit(distiller, train_dataloaders, val_dataloaders)

In [None]:
# Test-set metrics from kd_trainer, taken before the pruning cell below; kept in
# a named variable so they can be compared with the post-pruning run.
kd_test_results_before_pruning = kd_trainer.test(dataloaders=test_dataloaders)
kd_test_results_before_pruning

In [None]:
# Prune the gated student: estimate the channels each layer needs (with verbose
# NAS logging on), then optimize the network.
use_mean = False
# Passed to optimize() as `amount`; presumably the fraction to prune — confirm in src.models.
prune_amount = 0.4
# Optional knob tried previously: NAS.gating_threshold = 1e-2

student_model.NAS.verbose = True
try:
    student_model.estimate_required_channels(use_mean=use_mean)
finally:
    # Restore quiet mode even if estimation raises, so later NAS calls stay silent.
    student_model.NAS.verbose = False
student_model.optimize(use_mean=use_mean, amount=prune_amount)

# NOTE(review): this reports the distiller's improvement, while the pruning above
# acts on student_model, which GatedKD was not given — confirm they are the same network.
print("Improvement model size:", distiller.calc_improvement())

In [None]:
# Print the student's module tree as it stands after the pruning cell above.
print(f"{student_model}")

In [None]:
%%time
kd_trainer.test(dataloaders=test_dataloaders)