In [None]:
import torch
torch.backends.cudnn.benchmark = True

from torch import nn
from torch.nn.functional import softmax,  log_softmax
from torchmetrics import Accuracy

from resnet_cifar import resnet32

import pytorch_lightning as pl

import wandb

import torchvision.transforms as T

import os
import sys
sys.path.append('../')

from datasets.cifar100_datamodule import DataModule

In [None]:
# CIFAR-100 per-channel normalisation statistics (mean, std), shared by both pipelines.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# Training pipeline: pad-and-crop + random horizontal flip augmentation, then normalise.
train_transforms = T.Compose(
    [
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and the same normalisation.
test_transforms = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)

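## Training with the Deep Mutual Learning (DML) strategy

Two ResNet-32 students are trained together: each one minimises cross-entropy on the labels plus the KL divergence from its peer's predictions to its own.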

In [None]:
def kl_div(x, y):
    """KL divergence KL(P_x || P_y) between the softmax distributions of two logit tensors.

    Args:
        x: logits of the "target" distribution P_x, with class scores along dim 1.
        y: logits of the distribution being fitted, P_y; same shape as ``x``.

    Returns:
        Scalar tensor: the per-sample KL (summed over classes, dim 1) averaged
        over the remaining elements. For 2-D ``(batch, classes)`` inputs this
        matches ``torch.nn.functional.kl_div(..., reduction='batchmean')``.
    """
    px = softmax(x, dim=1)
    lpx, lpy = log_softmax(x, dim=1), log_softmax(y, dim=1)
    # Sum over classes before averaging: a plain .mean() would also divide by
    # the number of classes, silently shrinking the mutual-learning term
    # (by 100x on CIFAR-100) relative to the cross-entropy term.
    return (px * (lpx - lpy)).sum(dim=1).mean()

class LitModel(pl.LightningModule):
    """Deep Mutual Learning with two ResNet-32 students trained side by side.

    Each student minimises cross-entropy on the labels plus ``kl_div`` from its
    peer's (detached) logits to its own. Optimisation is manual so that the two
    students can be updated one after the other inside a single training step.
    """

    def __init__(self):
        super().__init__()

        # Two optimizer steps per batch, so Lightning's automatic loop is disabled.
        self.automatic_optimization = False

        self.resnet1 = resnet32()
        self.resnet2 = resnet32()

        self.celoss = nn.CrossEntropyLoss()
        # NOTE(review): one metric object is shared by both students and by the
        # val/test stages; only its per-batch return value is logged.
        self.acc = Accuracy(compute_on_step=True)

    def configure_optimizers(self):
        """Build one SGD optimizer and one StepLR schedule per student.

        Separate optimizers guarantee that stepping for one student never moves
        the other. With a single shared SGD optimizer using momentum,
        ``zero_grad()`` (which zero-fills rather than ``None``-s grads on older
        PyTorch) leaves the idle student with zero gradients, and SGD would still
        apply a momentum-only update to it on every step.

        Returns:
            ``([opt1, opt2], [sched1, sched2])`` for resnet1 and resnet2 respectively.
        """
        optimizers, schedulers = [], []
        for net in (self.resnet1, self.resnet2):
            opt = torch.optim.SGD(net.parameters(), lr=.1, momentum=.9, nesterov=True)
            optimizers.append(opt)
            schedulers.append(torch.optim.lr_scheduler.StepLR(opt, step_size=60, gamma=.1))
        return optimizers, schedulers

    def forward(self, x, optimize_first: bool = True):
        """Return the logits ``(x1, x2)`` of resnet1 and resnet2 for input ``x``.

        The logits of the student that is *not* being optimised are detached,
        so a loss built from both outputs back-propagates into one network only:
        ``optimize_first=True`` detaches ``x2``; ``False`` detaches ``x1``.
        """
        x1 = self.resnet1(x)
        x2 = self.resnet2(x)
        if not optimize_first:
            x1 = x1.detach()
        else:
            x2 = x2.detach()
        return x1, x2

    def training_step(self, batch, batch_id):
        """Update student 1, then student 2 against the freshly updated student 1."""
        x, y = batch
        opt1, opt2 = self.optimizers()

        # Student 1: CE on labels + KL(p2 || p1); x2 is detached by forward().
        x1, x2 = self(x)
        loss1 = self.celoss(x1, y) + kl_div(x2, x1)
        opt1.zero_grad()
        self.manual_backward(loss1)
        opt1.step()
        self.log('train_loss1', loss1, prog_bar=True)

        # Student 2: recompute predictions so it learns from the updated student 1;
        # x1 is detached by forward(self, x, False).
        x1, x2 = self(x, False)
        loss2 = self.celoss(x2, y) + kl_div(x1, x2)
        opt2.zero_grad()
        self.manual_backward(loss2)
        opt2.step()
        self.log('train_loss2', loss2, prog_bar=True)

        # Manual optimization does not step LR schedulers; do it once per epoch.
        if self.trainer.is_last_batch:
            for scheduler in self.lr_schedulers():
                scheduler.step()

    def _log_accuracies(self, batch, stage):
        """Log the batch accuracy of both students as ``{stage}_acc1`` / ``{stage}_acc2``."""
        x, y = batch
        x1, x2 = self(x, True)
        x1, x2 = softmax(x1, dim=1), softmax(x2, dim=1)
        self.log(f'{stage}_acc1', self.acc(x1, y), prog_bar=True)
        self.log(f'{stage}_acc2', self.acc(x2, y), prog_bar=True)

    def validation_step(self, batch, *a):
        self._log_accuracies(batch, 'val')

    def test_step(self, batch, *a):
        self._log_accuracies(batch, 'test')

In [None]:
# Close any W&B run left open by a previous execution so this one logs to a fresh run.
wandb.finish()

# Seed Python / NumPy / torch RNGs so weight init and augmentation are reproducible.
SEED = 0
pl.seed_everything(SEED)

lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
logger = pl.loggers.wandb.WandbLogger(project='distilled models', entity='blurry-mood')

trainer = pl.Trainer(callbacks=[lr_monitor], logger=logger,
                     gpus=-1, max_epochs=200,
                     val_check_interval=1., progress_bar_refresh_rate=0)
dm = DataModule('../datasets/cifar-100-python/', train_transform=train_transforms, test_transform=test_transforms,
                batch_size=64)
litmodel = LitModel()
trainer.fit(litmodel, dm)
# Pass the datamodule explicitly: without it, test() relies on the trainer still
# holding the one attached during fit(), which newer Lightning versions do not do.
trainer.test(litmodel, datamodule=dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mblurry-mood[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name    | Type             | Params
---------------------------------------------
0 | resnet1 | ResNet           | 470 K 
1 | resnet2 | ResNet           | 470 K 
2 | celoss  | CrossEntropyLoss | 0     
3 | acc     | Accuracy         | 0     
---------------------------------------------
940 K     Trainable params
0         Non-trainable params
940 K     Total params
3.760     Total estimated model params size (MB)


In [None]:
# Persist student 1's weights; create the target directory on first run.
os.makedirs('../models/deep mutual learning', exist_ok=True)
torch.save(litmodel.resnet1.state_dict(), '../models/deep mutual learning/dml_s1_resnet32.pth')

In [None]:
# Persist student 2's weights; create the target directory on first run.
os.makedirs('../models/deep mutual learning', exist_ok=True)
torch.save(litmodel.resnet2.state_dict(), '../models/deep mutual learning/dml_s2_resnet32.pth')