# Knowledge Distillation

In [None]:
import os
# Set the http_proxy and https_proxy environment variables to the same
# proxy address for this process.
os.environ.update({
    'http_proxy': 'http://192.41.170.23:3128',
    'https_proxy': 'http://192.41.170.23:3128',
})

In [6]:
import torch.nn as nn
import torch, torchdata, torchtext

import random, math, time

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Fixed seed so that results stay comparable after a kernel restart.
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. ETL: Loading the dataset

## 2. EDA - simple investigation

## 3. Preprocessing 

### Tokenizing

## 4. Preparing the dataloader

## 5. Design the model and losses

### 5.1 Teacher Model & Student Model

#### Initialize Student Model

### 5.2 Loss function

#### Softmax

$$
P_i(\mathbf{z}_i, T) = \frac{\exp(\mathbf{z}_i / T)}{\sum_{q=0}^k \exp(\mathbf{z}_q / T)}
$$


#### KD and CE Loss
$$
    \mathcal{L}_\text{KD} = -\sum^N_{j=0}\sum_{i=0}^k P_i(\mathbf{z}_i^{(j)}, T) \log (P_i(\mathbf{v}_i^{(j)}, T))
$$
$$
    \mathcal{L}_\text{CE} = -\sum^N_{j=0}\sum_{i=0}^k \mathbf{y}_i^{(j)}\log(P_i(\mathbf{v}_i^{(j)}, 1))
$$
$$
    \mathcal{L} = \lambda \mathcal{L}_\text{KD} + (1-\lambda)\mathcal{L}_\text{CE}
$$

In [7]:
class DistillKL(nn.Module):
    """
    Temperature-scaled knowledge-distillation (KD) loss, after
    "Distilling the Knowledge in a Neural Network".

    Computes the KL divergence between the temperature-softened teacher
    distribution and the temperature-softened student distribution.
    The only hyperparameter is the temperature, passed to ``forward``;
    the module holds no parameters or state of its own.

    NOTE: PyTorch's KL divergence expects its *input* (the student side)
    to be log-probabilities and its *target* (the teacher side) to be
    probabilities, hence ``log_softmax`` for the student and ``softmax``
    for the teacher.
    """

    def __init__(self):
        super(DistillKL, self).__init__()

    def forward(self, output_student, output_teacher, temperature=1):
        '''
        Return the KD loss between student and teacher logits.

        Args:
            output_student: raw logits of the student model.
            output_teacher: raw logits of the teacher model; the softmax
                is taken over the last dimension of both tensors.
            temperature: softening temperature T (default 1, i.e. no
                softening).

        Returns:
            ``KL(teacher || student) * T * T`` reduced with
            ``reduction='batchmean'`` (sum divided by the size of the
            first dimension).
        '''
        T = temperature

        # `nn.functional` is reached through the file's existing
        # `import torch.nn as nn`; the file does not import a bare `F`.
        log_p_student = nn.functional.log_softmax(output_student / T, dim=-1)
        p_teacher = nn.functional.softmax(output_teacher / T, dim=-1)

        # Functional form of nn.KLDivLoss(reduction='batchmean'): avoids
        # building a throwaway loss module on every forward call.
        kd_loss = nn.functional.kl_div(
            log_p_student, p_teacher, reduction='batchmean'
        )

        # Rescale by T^2 so the loss magnitude stays comparable when the
        # temperature changes (as suggested in the distillation paper).
        return kd_loss * T * T

## 6. Train

In [11]:
import torch.optim as optim

# Learning rate handed to the optimizer below.
lr = 0.0005

# Training hyperparameters.
# Adam updates every parameter exposed by `model`.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# Hard-label loss: CrossEntropyLoss applies log-softmax internally and
# skips targets equal to TRG_PAD_IDX.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
# Distillation loss comparing student and teacher logits.
criterion_div = DistillKL()

In [8]:
def train(model, teacher_model, train_loader, optimizer, criterion, criterion_div, clip, train_loader_length, gamma=None):
    """
    Run one distillation training epoch for `model` and return the mean loss.

    For every batch the student's own loss (``outputs.loss``) is mixed with
    the distillation loss between student and teacher logits:

        loss = gamma * loss_div + (1 - gamma) * loss_cls

    Args:
        model: student model; called as ``model(**batch)`` and expected to
            return an object exposing ``.loss`` and ``.logits``.
        teacher_model: teacher model, called the same way under
            ``torch.no_grad()``; only its ``.logits`` are used.
        train_loader: iterable of batches, each a mapping that is unpacked
            into both models as keyword arguments.
        optimizer: optimizer over the student's parameters.
        criterion: unused here (the classification loss comes from
            ``outputs.loss``); kept so existing callers keep working.
        criterion_div: callable ``(student_logits, teacher_logits) -> loss``.
        clip: max gradient norm for clipping; ``None`` disables clipping.
        train_loader_length: number of batches used to average the loss;
            if falsy, the number of batches actually seen is used.
        gamma: weight of the distillation term. When ``None`` the value is
            read from the global ``args.gamma``, as the original code did.

    Returns:
        The average combined loss per batch as a Python float
        (``0.0`` when there were no batches).
    """
    if gamma is None:
        # Backward-compatible fallback to the notebook-level `args` object.
        gamma = args.gamma

    model.train()
    # The teacher only provides targets; keep it in inference mode.
    teacher_model.eval()

    epoch_loss = 0.0
    steps_seen = 0

    for batch in train_loader:
        optimizer.zero_grad()

        outputs = model(**batch)
        loss_cls = outputs.loss

        # compute teacher output without tracking gradients
        with torch.no_grad():
            output_teacher = teacher_model(**batch)

        # student and teacher must produce logits of the same shape
        assert outputs.logits.size() == output_teacher.logits.size(), \
            "student and teacher logits differ in shape"

        # distillation loss on the (softened) output distributions; uses the
        # criterion passed in by the caller instead of building a new one
        loss_div = criterion_div(outputs.logits, output_teacher.logits)

        loss = gamma * loss_div + (1. - gamma) * loss_cls

        loss.backward()
        if clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        steps_seen += 1

    num_batches = train_loader_length or steps_seen
    return epoch_loss / num_batches if num_batches else 0.0

In [None]:
def evaluate():
    """Placeholder for the evaluation step; does nothing and returns None."""
    return None