In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# --- input geometry --------------------------------------------------------
image_sz   = 32    # side length handed to DeiT as image_sz
channel_sz = 3     # image channels handed to DeiT
patch_sz   = 4     # side length of one patch
batch_sz   = 32    # batch size used by both DataLoaders

# --- model widths ----------------------------------------------------------
hidden_sz = 512
embed_sz  = 512

# --- architecture / schedule -----------------------------------------------
n_epochs  = 25
n_heads   = 8
n_layers  = 4
n_classes = 10

# --- optimisation / regularisation -----------------------------------------
learning_rate = 5e-4
dropout       = 0.2

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Multi-head attention block

In [None]:
class MHAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value inputs are projected by three separate linear
    layers, split into ``n_heads`` heads of width ``embed_sz // n_heads``,
    combined with softmax attention per head and merged again by a final
    linear layer.

    Args:
        n_heads: number of attention heads.
        embed_sz: embedding width; must be divisible by ``n_heads``.
        dropout: probability used to build ``self.dropout_lr``. The layer is
            constructed but not applied in ``forward`` (existing behaviour).
        batch_sz: when equal to 1, the first two axes of the projected
            tensors are swapped before attention (legacy switch, kept as-is).

    Raises:
        ValueError: if ``embed_sz`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, embed_sz, dropout, batch_sz):
        super(MHAttention, self).__init__()

        if embed_sz % n_heads != 0:
            raise ValueError(
                f"embed_sz ({embed_sz}) must be divisible by n_heads ({n_heads})"
            )

        # attention params
        self.n_heads  = n_heads
        self.embed_sz = embed_sz
        self.dropout  = dropout
        self.batch_sz = batch_sz
        self.head_sz  = embed_sz // n_heads

        self.dropout_lr = nn.Dropout(dropout)

        # independent projections for queries, keys and values
        self.Q = nn.Linear(self.embed_sz, self.embed_sz)
        self.K = nn.Linear(self.embed_sz, self.embed_sz)
        self.V = nn.Linear(self.embed_sz, self.embed_sz)

        # output projection applied after the heads are merged
        self.l = nn.Linear(self.embed_sz, self.embed_sz)

    def forward(self, q, k, v, mask = None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: 3-D tensors ``(batch, length, embed_sz)``.
            mask: optional tensor broadcastable to
                ``(batch, n_heads, q_len, k_len)``; positions where it equals
                0 are excluded from the softmax.

        Returns:
            Tensor of shape ``(batch, q_len, embed_sz)``.
        """
        q0, q1, _ = q.size()
        k0, k1, _ = k.size()
        v0, v1, _ = v.size()

        # Each input goes through its OWN projection (previously all three
        # used self.Q, leaving self.K / self.V untrained).
        q = self.Q(q).reshape(q0, q1, self.n_heads, self.head_sz)
        k = self.K(k).reshape(k0, k1, self.n_heads, self.head_sz)
        v = self.V(v).reshape(v0, v1, self.n_heads, self.head_sz)

        if self.batch_sz == 1:
            # NOTE(review): legacy path; swaps the batch and sequence axes.
            q = q.transpose(0,1)
            k = k.transpose(0,1)
            v = v.transpose(0,1)

        attention = self.attention(q, k, v, mask)

        # The attended output has one row per *query* position, so the merge
        # uses q1 (it used v1 before, which only matched for self-attention).
        return self.l( attention.reshape(-1, q1, self.embed_sz) )

    def attention(self, q, k, v, mask = None):
        """Scaled dot-product attention on ``(batch, len, heads, head_sz)`` tensors.

        Returns a tensor laid out as ``(batch, q_len, heads, head_sz)``.
        """
        # (batch, heads, q_len, k_len); scaled by sqrt of the per-head width
        scores = torch.einsum("bqhe,bkhe->bhqk", [q, k]) / (self.head_sz ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim = -1)

        return torch.einsum("bhql,blhd->bqhd", [weights, v])

## Encoder

In [None]:
class ViTEncoder(nn.Module):
    """One pre-norm transformer encoder layer (attention + MLP, both residual).

    Args:
        n_heads: number of attention heads for the MHAttention sub-layer.
        embed_sz: token embedding width.
        dropout: dropout probability used inside the MLP.
        hidden_sz: stored on the instance only; the MLP's inner width is
            ``4 * embed_sz`` regardless of this value.
    """

    def __init__(self, n_heads, embed_sz, dropout, hidden_sz):
        super(ViTEncoder, self).__init__()

        # attention params
        self.n_heads    = n_heads
        self.embed_sz   = embed_sz
        self.dropout    = dropout
        self.hidden_sz  = hidden_sz

        self.norm1 = nn.LayerNorm(self.embed_sz)
        self.norm2 = nn.LayerNorm(self.embed_sz)

        # batch_sz=0 keeps MHAttention's batch_sz == 1 transpose path disabled
        self.attention = MHAttention(n_heads, embed_sz, dropout, 0)
        self.MLP = nn.Sequential(
            nn.Linear(self.embed_sz, 4*self.embed_sz),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(4*self.embed_sz, self.embed_sz),
            nn.Dropout(self.dropout)
        )

    def forward(self, x):
        """Apply the layer to ``x`` and return a tensor of the same shape."""
        # Normalise only the attention *input*; the residual must add the
        # un-normalised x (previously x was overwritten by norm1(x), so the
        # skip connection carried the normalised tensor instead).
        normed = self.norm1(x)
        x = x + self.attention(normed, normed, normed)
        x = x + self.MLP(self.norm2(x))
        return x


## DeiT

In [None]:
class DeiT(nn.Module):
    """Vision-transformer student paired with a frozen teacher network.

    The image is cut into ``patch_sz x patch_sz`` patches, each patch is
    linearly embedded, a class token is prepended and a distillation token is
    appended, learned positional encodings are added and the sequence is run
    through ``n_layers`` ViTEncoder layers. The class-token output is
    normalised and classified.

    Args:
        image_sz: side length of the (square) input images.
        channel_sz: number of image channels.
        patch_sz: side length of one patch.
        hidden_sz: forwarded to every ViTEncoder.
        embed_sz: token embedding width.
        n_heads: attention heads per encoder layer.
        n_layers: number of encoder layers.
        n_classes: number of output classes.
        teacher: module called on the raw image batch; its parameters are
            frozen and it is kept in eval mode.
        dropout: dropout probability.
    """

    def __init__(self,image_sz,channel_sz,patch_sz,hidden_sz,embed_sz,n_heads,n_layers,n_classes,teacher,dropout):
        super(DeiT, self).__init__()

        self.image_sz        = image_sz
        self.channel_sz      = channel_sz
        self.patch_sz0       = patch_sz
        self.hidden_sz       = hidden_sz
        self.embed_sz        = embed_sz
        self.n_heads         = n_heads
        self.n_layers        = n_layers
        self.n_classes       = n_classes
        self.teacher         = teacher
        self.dropout         = dropout
        self.n_patches       = (image_sz // patch_sz) ** 2
        self.patch_sz1       = channel_sz * (patch_sz ** 2)
        self.dropout_lr      = nn.Dropout(self.dropout)
        self.norm1           = nn.LayerNorm(self.embed_sz)
        self.distillation_tk = nn.Parameter(torch.randn(1,1,self.embed_sz))
        self.class_tk        = nn.Parameter(torch.randn(1,1,self.embed_sz))
        self.embed           = nn.Linear(self.patch_sz1, self.embed_sz)
        # +2 positions: one for the class token, one for the distillation token
        self.pos_encoding    = nn.Parameter(torch.randn(1,self.n_patches + 2, self.embed_sz))

        # The teacher only provides targets: freeze it and disable its dropout.
        for parameter in self.teacher.parameters():
            parameter.requires_grad = False
        self.teacher.eval()

        self.encoders = nn.ModuleList([])
        for layer in range(self.n_layers):
            self.encoders.append(ViTEncoder(self.n_heads, self.embed_sz, self.dropout, self.hidden_sz))
        self.classifier = nn.Sequential(nn.Linear(self.embed_sz, self.n_classes))

    def train(self, mode = True):
        """Switch the student's mode while keeping the teacher in eval mode.

        The teacher is a registered sub-module, so the default ``train()``
        would re-enable its dropout every time ``model.train()`` is called.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def _patchify(self, x):
        """Cut ``(b, c, h, w)`` images into ``(b, n_patches, c * p * p)`` spatial patches."""
        b, c, h, w = x.size()
        p = self.patch_sz0
        # (b, c, h//p, w//p, p, p): non-overlapping p x p windows
        windows = x.unfold(2, p, p).unfold(3, p, p)
        # group by patch position first, then flatten channel + pixels
        return windows.permute(0, 2, 3, 1, 4, 5).reshape(b, (h // p) * (w // p), c * p * p)

    def forward(self, x, mask = None):
        """Classify a batch of images.

        Args:
            x: image batch of shape ``(b, c, h, w)``.
            mask: accepted for interface compatibility; not used.

        Returns:
            ``(student_logits, teacher_out)`` where ``student_logits`` has
            shape ``(b, n_classes)`` and ``teacher_out`` is whatever the
            teacher returns for ``x`` (computed without gradient tracking).
        """
        with torch.no_grad():
            teacher_out = self.teacher(x)

        x = self.embed(self._patchify(x))
        b,n,e = x.size()

        class_tk = self.class_tk.expand(b,1,e)
        # use the dedicated distillation parameter (was a copy of class_tk)
        distillation_tk = self.distillation_tk.expand(b,1,e)

        # sequence layout: [class, patch_0 ... patch_{n-1}, distillation]
        x = torch.cat((class_tk, x, distillation_tk), dim=1)
        x = self.dropout_lr(x+self.pos_encoding)

        for enc in self.encoders:
            x = enc(x)

        # Classify the class-token output exactly once (it was normalised and
        # classified twice before, which fails on the n_classes-wide logits).
        class_out = x[:, 0, :]
        return self.classifier(self.norm1(class_out)), teacher_out

## Hard distillation loss

In [None]:
class HardDistillationLoss(nn.Module):
    """Hard-label distillation loss.

    Averages two cross-entropy terms computed on the *student* logits: one
    against the ground-truth labels and one against the teacher's hard
    decision ``argmax(teacher_y)``.

    The previous second term was ``CE(teacher_y, y)``, which does not depend
    on the student at all and therefore transferred nothing from the teacher.
    """

    def __init__(self):
        super(HardDistillationLoss, self).__init__()

        self.teacher_cel = nn.CrossEntropyLoss()
        self.student_cel = nn.CrossEntropyLoss()

    def forward(self, teacher_y, student_y, y):
        """Compute the loss.

        Args:
            teacher_y: teacher logits, class scores on the last axis.
            student_y: student logits, same layout as ``teacher_y``.
            y: ground-truth class indices.

        Returns:
            Scalar tensor ``0.5 * (CE(student_y, y) + CE(student_y, argmax teacher_y))``.
        """
        # Hard pseudo-labels from the teacher; detached so no gradient can
        # ever flow into the teacher through this term.
        teacher_labels = teacher_y.detach().argmax(dim=-1)
        return 0.5 * ( (self.student_cel(student_y, y)) + (self.teacher_cel(student_y, teacher_labels)) )

## Teacher

In [None]:
from torch.autograd import Variable

import torchvision
import torchvision.models as models

import math

class VGG16_classifier(nn.Module):
    """VGG16 backbone with a replacement fully-connected classification head.

    Args:
        image_sz: stored on the instance; not otherwise used by this class.
        n_classes: width of the final linear layer.
        hidden_sz: base width of the head (layers are 4x, 2x and 1x this value).
        dropout: dropout probability between the head's linear layers.
    """

    def __init__(self,
                 image_sz,
                 n_classes,
                 hidden_sz,
                 dropout
                 ):
        # Required before any sub-module is assigned; without it the
        # ``self.vgg16 = ...`` line below raises AttributeError.
        super(VGG16_classifier, self).__init__()

        self.image_sz = image_sz
        self.n_classes = n_classes
        self.hidden_sz = hidden_sz
        self.dropout = dropout

        self.vgg16 = models.vgg16(pretrained=True)
        # keep the whole backbone trainable (fine-tuning)
        for parameter in self.vgg16.parameters():
            parameter.requires_grad = True
        # 25088 is the input width the new head expects from the backbone
        self.vgg16.classifier = nn.Sequential(
                nn.Linear(25088, self.hidden_sz * 4),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.hidden_sz * 4, self.hidden_sz * 2),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.hidden_sz * 2, self.hidden_sz),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.hidden_sz, self.n_classes)
            )

    def forward(self, x):
        """Return the network's class scores for the batch ``x``."""
        return self.vgg16(x)

In [None]:
# Mount Google Drive at /content/drive.
# google.colab is only importable inside a Colab runtime, so this cell fails elsewhere.
# NOTE(review): presumably needed so the teacher checkpoint can be read from Drive — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the pre-trained teacher. Replace the placeholder path with the real checkpoint.
# SECURITY: torch.load unpickles the file and can run arbitrary code — only load checkpoints you trust.
# NOTE(review): DeiT calls teacher.parameters()/teacher(x), so the file is expected to hold a whole
# pickled module (not a state_dict); on recent PyTorch releases that may additionally require
# weights_only=False — confirm against the installed version.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
teacher = torch.load("add/your/path/to/vgg16_cifar10.pth", map_location=device)

## Define model

In [None]:
# Build the student around the loaded teacher, then move everything to `device`.
model = DeiT(
    image_sz=image_sz, channel_sz=channel_sz, patch_sz=patch_sz,
    hidden_sz=hidden_sz, embed_sz=embed_sz,
    n_heads=n_heads, n_layers=n_layers, n_classes=n_classes,
    teacher=teacher, dropout=dropout,
)
model = model.to(device)

## Train/Test

In [None]:
def train(model, train_loader, loss_function, optimizer, device, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs and report per-epoch metrics.

    Args:
        model: callable returning ``(student_predictions, teacher_predictions)``.
        train_loader: iterable of ``(images, labels)`` batches.
        loss_function: called as ``loss_function(teacher, student, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        Dict with keys ``"accuracy"`` (percent) and ``"loss"`` (sum of the
        batch losses), each holding one entry per epoch.
    """
    history = {"accuracy": [], "loss": []}
    separator = "-" * 50

    for epoch in range(1, num_epochs + 1):
        model.train()

        # running totals for this epoch
        loss_sum = 0
        n_correct = 0
        n_seen = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            student_logits, teacher_logits = model(images)
            loss = loss_function(teacher_logits, student_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            guesses = student_logits.detach().argmax(dim=-1).tolist()
            targets = labels.detach().tolist()
            n_correct += sum(guess == target for guess, target in zip(guesses, targets))
            n_seen += len(guesses)

            loss_sum += loss.item()

        accuracy = n_correct * 100 / n_seen

        history["loss"].append(loss_sum)
        history["accuracy"].append(accuracy)

        print(separator)
        print(f"Epoch {epoch}/{num_epochs}")
        print(f"Train Loss      : {loss_sum:.6f}")
        print(f"Train Accuracy  : {accuracy:.2f}% ({n_correct}/{n_seen})")
        print(separator)

    return history

In [None]:
def test(model, test_loader, device):

    model.eval()
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for batch_index, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            predictions, _ = model(images)

            predicted_labels.extend(predictions.argmax(dim=-1).tolist())
            true_labels.extend(labels.tolist())

    total_correct = sum(pred == true for pred, true in zip(predicted_labels, true_labels))
    total_samples = len(predicted_labels)
    accuracy = total_correct * 100 / total_samples

    print(f"{'-' * 50}")
    print(f"Test Accuracy   : {accuracy:.2f}% ({total_correct}/{total_samples})")
    print(f"{'-' * 50}")

    return accuracy

## preprocess cifar10

In [None]:
from torchvision import datasets, transforms

# Convert images to tensors, then normalise with mean 0.5 and std 0.5.
# NOTE(review): a single mean/std value is given although CIFAR-10 images have 3 channels;
# presumably it is applied to every channel — confirm.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Training split: downloaded to ../data/CIFAR10/ if missing, reshuffled by the loader.
trainset = datasets.CIFAR10('../data/CIFAR10/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_sz, shuffle=True)

# Test split: same transform, fixed order.
testset = datasets.CIFAR10('../data/CIFAR10/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_sz, shuffle=False)

## exec

In [None]:
# Loss, optimizer and LR scheduler for the student.
criterion = HardDistillationLoss()
# model.parameters() also yields the teacher's parameters; DeiT sets requires_grad=False on them.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# NOTE(review): nothing visible in this notebook calls scheduler.step() or passes the scheduler
# to train(), so the learning rate stays constant as written — confirm whether that is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

In [None]:
# Train for n_epochs; the returned history dict is displayed as this cell's output.
train(model, trainloader, criterion, optimizer, device, n_epochs)

In [None]:
# Evaluate on the CIFAR-10 test split; the returned accuracy (percent) is displayed as this cell's output.
test(model, testloader, device)