In [1]:
import torch
torch.backends.cudnn.benchmark = True

from torch import nn
from torch.nn.functional import softmax,  log_softmax
from torchmetrics import Accuracy

from resnet_cifar import resnet32

import pytorch_lightning as pl

import wandb

import torchvision.transforms as T

import sys
sys.path.append('../')

from datasets.cifar100_datamodule import DataModule

from deepblocks.layer import MultiHeadAttention

## Training with the network-based strategy

In [2]:
class Attention(nn.Module):
    """Single-head self-attention that mixes the rows of its input.

    Scores between positions are computed as ``linear(x) @ x^T`` using one
    bias-free linear layer, normalised with a softmax over the last axis, and
    then applied to ``x`` itself (there is no separate value or output
    projection).
    """

    def __init__(self, input_dim):
        """Create the layer.

        Args:
            input_dim: size of the last axis of the tensors passed to ``forward``.
        """
        super().__init__()

        # Bias-free projection applied to every row before scoring.
        self.linear = nn.Linear(input_dim, input_dim, bias=False)
        # Turns each row of the score matrix into weights that sum to 1.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return the attention-weighted combination of the rows of ``x``.

        Args:
            x: tensor whose last two axes are (positions, input_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        projected = self.linear(x)
        scores = torch.matmul(projected, x.transpose(-1, -2))
        weights = self.softmax(scores)
        return torch.matmul(weights, x)

In [3]:
# Sanity check: torch.rand inputs are non-negative and softmax weights are
# non-negative, so the attention output must never go below zero, whatever
# the randomly initialised projection looks like.
n_trials = 1000
for _ in range(n_trials):
    att = Attention(10)
    x = torch.rand(100, 51, 10)
    out = att(x)
    assert bool((out >= 0).all())

In [4]:
# Presumably closes any W&B run left open by an earlier execution of this
# notebook before the model is (re)defined — TODO confirm.
wandb.finish()


def kl_div(x, y):
    """Mean over all elements of ``x * log(x / y)``.

    This is the point-wise KL(x || y) integrand averaged over every element of
    the (broadcast) inputs, i.e. it is *not* summed over the class axis.

    Both operands are floored at the smallest positive normal number of
    ``x``'s dtype before the logarithm is taken. For ordinary probabilities
    this changes nothing, but it keeps the value and its gradient finite when
    an entry underflows to exactly zero (the naive form yields ``0 * -inf =
    nan`` for ``x == 0`` and ``inf`` for ``y == 0``).

    Args:
        x: floating-point tensor of non-negative values (target distribution).
        y: floating-point tensor broadcastable against ``x``.

    Returns:
        Zero-dimensional tensor holding the mean.
    """
    floor = torch.finfo(x.dtype).tiny
    log_ratio = x.clamp(min=floor).log() - y.clamp(min=floor).log()
    return (x * log_ratio).mean()

class LitModel(pl.LightningModule):
    """Jointly trains three peer networks and one leader network.

    All four networks are built with ``resnet32()`` and receive the same
    input. Each is trained with cross-entropy against the labels. During
    training two extra terms are added, both computed with ``kl_div`` on
    temperature-softened softmax outputs:

    * the attention-mixed peer predictions against the peer predictions, and
    * the mean peer prediction against the leader's prediction.

    Accuracy is always reported for the leader only.
    """

    def __init__(self):
        super().__init__()

        self.student1 = resnet32()
        self.student2 = resnet32()
        self.student3 = resnet32()
        self.leader = resnet32()

        # Mixes the stacked peer predictions; input_dim=100 presumably equals
        # the number of classes produced by resnet32() — TODO confirm.
        self.mha = Attention(input_dim=100)

        # Softmax temperature used for the distillation terms.
        # NOTE(review): those terms are weighted by T below, whereas the usual
        # distillation recipe scales by T**2 — confirm this is intentional.
        self.T = 3.0

        self.celoss = nn.CrossEntropyLoss()
        self.acc = Accuracy(compute_on_step=True, top_k=1)

    def configure_optimizers(self):
        """SGD with Nesterov momentum and a step decay at milestones 150 and 255."""
        opt = torch.optim.SGD(self.parameters(), lr=.1, momentum=.9, nesterov=True, weight_decay=5e-4)
        step = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[150, 255], gamma=.1)
        return [opt], [step]

    def forward(self, x, optimize_first:bool=True):
        """Run every network on ``x``.

        Args:
            x: input batch passed unchanged to each network.
            optimize_first: not used by this method.
                NOTE(review): looks like a leftover; kept for interface
                compatibility.

        Returns:
            Tuple ``(student1_out, student2_out, student3_out, leader_out)``.
        """
        x1 = self.student1(x)
        x2 = self.student2(x)
        x3 = self.student3(x)
        xl = self.leader(x)
        return x1, x2, x3, xl

    def training_step(self, batch, batch_id):
        """Compute the combined training loss and log loss / leader accuracy."""
        x, y = batch
        xs = self(x)

        # Ground-truth loss: summed cross-entropy of all four networks.
        loss = torch.stack([self.celoss(_x, y) for _x in xs], dim=0).sum()

        # Peers loss: attention-mixed peer predictions vs. the peers themselves.
        t1, t2, t3, tl = [softmax(_x/self.T, dim=1) for _x in xs]
        peers = torch.stack((t1, t2, t3), dim=1)
        mha_peers = self.mha(peers)
        loss += self.T * kl_div(mha_peers, peers)

        # Leader loss: mean peer prediction vs. the leader's prediction.
        mean = peers.mean(dim=1)
        loss += self.T * kl_div(mean, tl)

        # Logging (accuracy is measured on the leader's softened prediction).
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', self.acc(tl, y), prog_bar=True)

        return loss

    def _shared_eval_step(self, batch, stage):
        """Evaluation logic shared by validation and test.

        Only the summed cross-entropy of the four networks is used as the
        loss; the distillation terms are not evaluated here.

        Args:
            batch: ``(inputs, labels)`` pair.
            stage: prefix for the logged metric names (``'val'`` or ``'test'``).

        Returns:
            The summed cross-entropy loss.
        """
        x, y = batch
        xs = self(x)

        # Ground-truth loss.
        loss = sum([self.celoss(_x, y) for _x in xs])

        # Only the leader's prediction (last output of forward) is scored.
        tl = softmax(xs[-1], dim=1)

        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', self.acc(tl, y), prog_bar=True)

        return loss

    def validation_step(self, batch, batch_id):
        """Log ``val_loss`` and ``val_acc`` for one validation batch."""
        return self._shared_eval_step(batch, 'val')

    def test_step(self, batch, *a):
        """Log ``test_loss`` and ``test_acc`` for one test batch."""
        return self._shared_eval_step(batch, 'test')

In [5]:
# Per-channel mean / std handed to T.Normalize in both pipelines
# (presumably the CIFAR-100 training-set statistics — TODO confirm).
NORM_MEAN = (0.5071, 0.4865, 0.4409)
NORM_STD = (0.2673, 0.2564, 0.2762)

# Training pipeline: random 32x32 crop after 4-pixel padding and a random
# horizontal flip, followed by tensor conversion and normalisation.
train_transforms = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: no augmentation, same normalisation as training.
test_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(NORM_MEAN, NORM_STD),
])

In [None]:
# Presumably closes any W&B run left open by an earlier execution — TODO confirm.
wandb.finish()

lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
logger = pl.loggers.wandb.WandbLogger(project='distilled models', entity='blurry-mood')

trainer = pl.Trainer(callbacks=[lr_monitor], logger=logger,
                     gpus=-1, max_epochs=300,
                     val_check_interval=1., progress_bar_refresh_rate=0)

dm = DataModule('../datasets/cifar-100-python/', train_transform=train_transforms, test_transform=test_transforms,
                batch_size=128)

litmodel = LitModel()
trainer.fit(litmodel, dm)
# Hand the datamodule to the test run explicitly so it does not depend on
# whatever data source is still attached to the trainer after fit().
trainer.test(litmodel, datamodule=dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
import os

# Export only the leader network's weights.
leader_path = '../models/online knowledge distillation with diverse peers/leader_resnet32.pth'
# torch.save does not create missing parent directories; make sure the target
# folder exists so a finished training run cannot fail at the save step.
os.makedirs(os.path.dirname(leader_path), exist_ok=True)
torch.save(litmodel.leader.state_dict(), leader_path)