In [1]:
import torch
from torch import nn
from torch.nn.functional import softmax,  log_softmax
from torchmetrics import Accuracy

import timm

import pytorch_lightning as pl

import wandb

import albumentations as T

import sys
sys.path.append('../')

from datasets.cifar100_datamodule import DataModule

In [2]:
def kl_div(x, y):
    """KL divergence KL(softmax(x) || softmax(y)) between two batches of logits.

    Args:
        x: logits defining the reference distribution; classes are on dim 1.
        y: logits defining the approximating distribution, same shape as ``x``.

    Returns:
        Scalar tensor: the divergence summed over dim 1 and averaged over the
        remaining dimensions.
    """
    lpx, lpy = log_softmax(x, dim=1), log_softmax(y, dim=1)
    # Sum over the class dimension first: a plain .mean() over every element
    # would divide the divergence by the number of classes.
    per_sample = (lpx.exp() * (lpx - lpy)).sum(dim=1)
    return per_sample.mean()

class LitModel(pl.LightningModule):
    """Two ResNet-18 networks trained side by side, one optimizer each.

    Each network is trained with cross-entropy on the labels plus a ``kl_div``
    term computed against the other network's detached logits.
    """

    def __init__(self, ):
        super().__init__()
        # NOTE(review): no num_classes is passed, so timm's default classifier
        # size is used -- confirm it matches the label count of the data module.
        self.resnet1 = timm.create_model('resnet18', pretrained=False)
        self.resnet2 = timm.create_model('resnet18', pretrained=False)

        self.celoss = nn.CrossEntropyLoss()
        self.acc = Accuracy(compute_on_step=True)

    def configure_optimizers(self):
        """Return one SGD optimizer and one StepLR scheduler per network.

        Returns:
            ``([opt1, opt2], [step1, step2])`` where index 0 belongs to
            ``resnet1`` and index 1 to ``resnet2``.
        """
        opt1 = torch.optim.SGD(self.resnet1.parameters(), lr=.1, momentum=.9, nesterov=True)
        opt2 = torch.optim.SGD(self.resnet2.parameters(), lr=.1, momentum=.9, nesterov=True)
        step1 = torch.optim.lr_scheduler.StepLR(opt1, step_size=60, gamma=.1)
        # Each scheduler must wrap its own optimizer; binding both to opt1
        # would leave resnet2's learning rate constant.
        step2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=60, gamma=.1)
        return [opt1, opt2], [step1, step2]

    def forward(self, x, first:bool):
        """Run both networks on ``x`` and detach one of the two outputs.

        Args:
            x: input batch fed to both networks.
            first: if True the output of ``resnet1`` is detached, otherwise
                the output of ``resnet2`` is detached.

        Returns:
            Tuple ``(x1, x2)`` of logits from ``resnet1`` and ``resnet2``.
        """
        x1 = self.resnet1(x)
        x2 = self.resnet2(x)
        if first:
            x1 = x1.detach()
        else:
            x2 = x2.detach()
        return x1, x2

    def training_step(self, batch, batch_id, optimizer_idx):
        """Build the loss for the network selected by ``optimizer_idx``.

        Args:
            batch: ``(x, y)`` pair of inputs and integer class targets.
            batch_id: batch index (unused).
            optimizer_idx: 0 builds the loss for ``resnet1``, anything else
                builds it for ``resnet2``.

        Returns:
            Cross-entropy of the selected network plus ``kl_div`` of the other
            (detached) network's logits against the selected network's logits.
        """
        x, y = batch
        # Detach the network that is NOT being optimised in this step.
        x1, x2 = self(x, optimizer_idx!=0)
        if optimizer_idx==0:
            loss = self.celoss(x1, y) + kl_div(x2, x1)
        else:
            loss = self.celoss(x2, y) + kl_div(x1, x2)

        self.log(f'train_loss{optimizer_idx}', loss, prog_bar=True,)

        return loss

    def _log_accuracies(self, batch, prefix):
        """Log ``{prefix}_acc1`` / ``{prefix}_acc2`` for both networks on ``batch``."""
        x, y = batch
        x1, x2 = self(x, True)
        x1, x2 = softmax(x1, dim=1), softmax(x2, dim=1)
        self.log(f'{prefix}_acc1', self.acc(x1, y), prog_bar=True)
        self.log(f'{prefix}_acc2', self.acc(x2, y), prog_bar=True)

    def validation_step(self, batch, *a):
        """Log the accuracy of both networks as ``val_acc1`` / ``val_acc2``."""
        self._log_accuracies(batch, 'val')

        return None

    def test_step(self, batch, *a):
        """Log the accuracy of both networks as ``test_acc1`` / ``test_acc2``.

        Lightning looks the hook up under the name ``test_step``.
        """
        self._log_accuracies(batch, 'test')

    # Backward-compatible alias for the previous (non-hook) method name.
    testing_step = test_step

In [3]:
# Augmentation pipeline handed to the data module; steps run in listed order:
# random horizontal flip, 28x28 random crop, then CropAndPad with px=4.
# NOTE(review): positive px in CropAndPad presumably pads, so with
# keep_size=False the output would be larger than the crop -- confirm that
# crop-then-pad (rather than pad-then-crop) is the intended order.
# NOTE(review): pad_mode=2 is presumably cv2.BORDER_REFLECT -- confirm.
transforms = T.Compose([
    T.HorizontalFlip(p=0.5),
    T.RandomCrop(height=28, width=28, always_apply=True),
    T.CropAndPad(px=4, pad_mode=2, keep_size=False),
])

In [None]:
# Close any W&B run left open by an earlier execution of this cell.
wandb.finish()

# Seed Python / NumPy / torch RNGs so weight init and shuffling are repeatable.
pl.seed_everything(42)

logger = pl.loggers.wandb.WandbLogger(name='distilled models', entity='blurry-mood')
trainer = pl.Trainer(gpus=-1, max_epochs=200, logger=logger, val_check_interval=.5, progress_bar_refresh_rate=0)
dm = DataModule('../datasets/cifar-100-python/', transforms, batch_size=64)

litmodel = LitModel()
trainer.fit(litmodel, dm)
# Pass the data module explicitly instead of relying on state left by fit().
trainer.test(litmodel, datamodule=dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]






[34m[1mwandb[0m: Currently logged in as: [33mblurry-mood[0m (use `wandb login --relogin` to force relogin)



  | Name    | Type             | Params
---------------------------------------------
0 | resnet1 | ResNet           | 11.7 M
1 | resnet2 | ResNet           | 11.7 M
2 | celoss  | CrossEntropyLoss | 0     
3 | acc     | Accuracy         | 0     
---------------------------------------------
23.4 M    Trainable params
0         Non-trainable params
23.4 M    Total params
93.516    Total estimated model params size (MB)
