In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# ---- Data / model shape ----
image_sz = 32       # CIFAR-10 images are 32x32
batch_sz = 32
channel_sz = 3
patch_sz = 4        # (32 // 4) ** 2 = 64 patches per image
hidden_sz = 512
embed_sz = 512

# ---- Architecture / training ----
n_epochs = 25
n_heads = 8
n_layers = 4
n_classes = 10

learning_rate = 0.0005

dropout = 0.2

# Seed torch's RNGs so weight init, dropout and DataLoader shuffling are reproducible.
seed = 42
torch.manual_seed(seed)

# Fail fast on configurations the model cannot represent.
assert image_sz % patch_sz == 0, "patch_sz must divide image_sz"
assert embed_sz % n_heads == 0, "n_heads must divide embed_sz"

In [3]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Multi-head attention block

In [4]:
class MHAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        n_heads: Number of attention heads; must divide ``embed_sz``.
        embed_sz: Width of the query/key/value inputs and of the output.
        dropout: Dropout probability applied to the attention weights.
        batch_sz: Layout switch. When it equals 1, the first two axes of the
            projected q/k/v are swapped before attention and swapped back
            afterwards, i.e. inputs are treated as (seq, batch, embed). Any other
            value treats inputs as (batch, seq, embed).

    Raises:
        ValueError: If ``embed_sz`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, embed_sz, dropout, batch_sz):
        super(MHAttention, self).__init__()

        if embed_sz % n_heads != 0:
            raise ValueError(f"embed_sz ({embed_sz}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.embed_sz = embed_sz
        self.head_sz = embed_sz // n_heads
        self.batch_sz = batch_sz
        self.dropout = dropout

        self.Q = nn.Linear(self.embed_sz, self.embed_sz)
        self.K = nn.Linear(self.embed_sz, self.embed_sz)
        self.V = nn.Linear(self.embed_sz, self.embed_sz)
        self.output_layer = nn.Linear(self.embed_sz, self.embed_sz)
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` and project the result.

        Args:
            q: Query tensor, 3-D with last dim ``embed_sz``.
            k: Key tensor, 3-D with last dim ``embed_sz``.
            v: Value tensor, same shape as ``k``.
            mask: Optional tensor broadcastable to the (batch, heads, q_len, k_len)
                score tensor; entries where ``mask == 0`` are excluded.

        Returns:
            Tensor with the same leading two dims as ``q`` and last dim ``embed_sz``.
        """
        q_dim0, q_dim1, _ = q.size()
        k_dim0, k_dim1, _ = k.size()
        v_dim0, v_dim1, _ = v.size()

        q = self.Q(q).reshape(q_dim0, q_dim1, self.n_heads, self.head_sz)
        k = self.K(k).reshape(k_dim0, k_dim1, self.n_heads, self.head_sz)
        v = self.V(v).reshape(v_dim0, v_dim1, self.n_heads, self.head_sz)

        swap_axes = self.batch_sz == 1
        if swap_axes:
            q = q.transpose(0, 1)
            k = k.transpose(0, 1)
            v = v.transpose(0, 1)

        attention = self.attention(q, k, v, mask)
        if swap_axes:
            # Restore the caller's axis order before merging the heads.
            attention = attention.transpose(0, 1)

        # One output row per *query* position (not per key/value position).
        return self.output_layer(attention.reshape(q_dim0, q_dim1, self.embed_sz))

    def attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over per-head tensors.

        Args:
            q: (batch, q_len, heads, head_sz) queries.
            k: (batch, k_len, heads, head_sz) keys.
            v: (batch, k_len, heads, head_sz) values.
            mask: Optional tensor broadcastable to (batch, heads, q_len, k_len).

        Returns:
            (batch, q_len, heads, head_sz) attended values.
        """
        scores = torch.einsum("bqhe,bkhe->bhqk", q, k) / math.sqrt(self.head_sz)

        if mask is not None:
            # The dtype's most negative finite value zeroes masked weights after
            # softmax without overflowing in fp16/bf16 (unlike a literal -1e9).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout_layer(weights)
        return torch.einsum("bhqk,bkhd->bqhd", weights, v)

## Encoder

In [5]:
class ViTEncoder(nn.Module):
    """Pre-norm Transformer encoder layer.

    Computes ``x = x + MHA(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.

    Args:
        n_heads: Number of attention heads.
        embed_sz: Token width.
        hidden_sz: Stored on the module but NOT used to size the MLP; the MLP's
            hidden width is fixed at ``4 * embed_sz``.
        dropout: Dropout probability used in attention and in the MLP.
    """

    def __init__(self, n_heads, embed_sz, hidden_sz, dropout):
        super(ViTEncoder, self).__init__()

        self.n_heads = n_heads
        self.embed_sz = embed_sz
        self.hidden_sz = hidden_sz
        self.dropout = dropout

        self.norm1 = nn.LayerNorm(self.embed_sz)
        self.norm2 = nn.LayerNorm(self.embed_sz)
        # batch_sz=0 keeps MHAttention in its (batch, seq, embed) layout.
        self.attention = MHAttention(n_heads, embed_sz, dropout, batch_sz=0)
        self.mlp = nn.Sequential(
            nn.Linear(self.embed_sz, 4 * self.embed_sz),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(4 * self.embed_sz, self.embed_sz),
            nn.Dropout(self.dropout),
        )

    def forward(self, x, mask=None):
        """Apply one encoder layer.

        Args:
            x: Token tensor, presumably (batch, seq, embed_sz).
            mask: Optional attention mask forwarded to ``MHAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalise once and reuse it for queries, keys and values.
        normed = self.norm1(x)
        x = x + self.attention(normed, normed, normed, mask)
        x = x + self.mlp(self.norm2(x))
        return x

## DeiT

In [6]:
class DeiT(nn.Module):
    """DeiT-style student Transformer that also runs a frozen teacher.

    Images are split into non-overlapping ``patch_sz x patch_sz`` patches. Each
    patch is flattened across all channels and linearly embedded. A learnable
    class token is prepended and a learnable distillation token appended before
    the encoder stack.

    NOTE(review): no head reads the distillation token's output, so it only
    takes part as an extra attention token.

    Args:
        image_sz: Side length of the square input images.
        channel_sz: Number of input channels.
        patch_sz: Side length of each square patch; must divide ``image_sz``.
        embed_sz: Token embedding width.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of encoder layers.
        n_classes: Number of output classes.
        hidden_sz: Forwarded to each ``ViTEncoder``.
        teacher_model: ``nn.Module`` whose outputs are returned alongside the
            student logits. Its parameters are frozen and it is kept in eval mode.
        dropout: Dropout probability.

    Raises:
        ValueError: If ``patch_sz`` does not divide ``image_sz``.
    """

    def __init__(self, image_sz, channel_sz, patch_sz, embed_sz, n_heads, n_layers, n_classes, hidden_sz, teacher_model, dropout):
        super(DeiT, self).__init__()

        if image_sz % patch_sz != 0:
            raise ValueError(f"patch_sz ({patch_sz}) must divide image_sz ({image_sz})")

        self.image_sz = image_sz
        self.channel_sz = channel_sz
        self.patch_sz = patch_sz
        self.embed_sz = embed_sz
        self.hidden_sz = hidden_sz
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_classes = n_classes
        self.dropout = dropout
        self.teacher_model = teacher_model

        self.num_patches = (image_sz // patch_sz) ** 2
        self.patch_sz_flat = channel_sz * (patch_sz ** 2)
        self.embedding_layer = nn.Linear(self.patch_sz_flat, self.embed_sz)
        self.class_token = nn.Parameter(torch.randn(1, 1, self.embed_sz))
        self.distillation_token = nn.Parameter(torch.randn(1, 1, self.embed_sz))
        # +2 positions: the class token and the distillation token.
        self.positional_encoding = nn.Parameter(torch.randn(1, self.num_patches + 2, self.embed_sz))
        self.dropout_layer = nn.Dropout(self.dropout)

        self.encoders = nn.ModuleList([
            ViTEncoder(self.n_heads, self.embed_sz, self.hidden_sz, self.dropout)
            for _ in range(self.n_layers)
        ])

        self.classifier = nn.Linear(self.embed_sz, self.n_classes)

        for param in self.teacher_model.parameters():
            param.requires_grad = False
        self.teacher_model.eval()

    def train(self, mode=True):
        """Set the student's train/eval mode while keeping the teacher in eval mode."""
        super(DeiT, self).train(mode)
        # nn.Module.train() recurses into every submodule, which would re-enable
        # dropout inside the frozen teacher; pin it back to eval mode.
        self.teacher_model.eval()
        return self

    def _patchify(self, x):
        """Split (b, c, h, w) images into (b, num_patches, c * patch_sz**2) patch rows."""
        b = x.size(0)
        p = self.patch_sz
        # (b, c, h/p, w/p, p, p): tile height, then width, into p-sized windows.
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # Move the patch-grid axes ahead of the channel axis so each row holds one
        # spatial patch (row-major over the grid) with all its channels.
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        return patches.reshape(b, self.num_patches, self.patch_sz_flat)

    def forward(self, x, mask=None):
        """Run the student and the teacher on a batch of images.

        Args:
            x: Images of shape (b, channel_sz, image_sz, image_sz).
            mask: Accepted for interface compatibility; currently unused.

        Returns:
            Tuple ``(student_logits, teacher_logits)``. ``student_logits`` has
            shape (b, n_classes) and comes from the class token.
            ``teacher_logits`` is whatever ``teacher_model(x)`` returns.

        Raises:
            ValueError: If ``x`` does not match the configured channel/image size.
        """
        b, c, h, w = x.size()
        if (c, h, w) != (self.channel_sz, self.image_sz, self.image_sz):
            raise ValueError(
                f"expected images of shape (*, {self.channel_sz}, {self.image_sz}, {self.image_sz}), "
                f"got {tuple(x.shape)}"
            )

        # The teacher is frozen; skip building an autograd graph for it.
        with torch.no_grad():
            teacher_logits = self.teacher_model(x)

        x = self.embedding_layer(self._patchify(x))

        class_tk = self.class_token.expand(b, -1, -1)
        distillation_tk = self.distillation_token.expand(b, -1, -1)

        x = torch.cat((class_tk, x, distillation_tk), dim=1)
        x = self.dropout_layer(x + self.positional_encoding)

        for encoder in self.encoders:
            x = encoder(x)

        return self.classifier(x[:, 0, :]), teacher_logits

## Hard distillation loss

In [7]:
class HardDistillationLoss(nn.Module):
    """Hard-label distillation loss (DeiT, Touvron et al. 2021).

    loss = 0.5 * CE(student_y, y) + 0.5 * CE(student_y, argmax(teacher_y))

    The teacher's argmax predictions act as a second set of targets for the
    student. A term like ``CE(teacher_y, y)`` would carry no gradient to the
    student, because the teacher is frozen.
    """

    def __init__(self):
        super(HardDistillationLoss, self).__init__()

        # Cross-entropy of the student against the teacher's hard labels.
        self.teacher_cel = nn.CrossEntropyLoss()
        # Cross-entropy of the student against the ground-truth labels.
        self.student_cel = nn.CrossEntropyLoss()

    def forward(self, teacher_y, student_y, y):
        """Compute the hard distillation loss.

        Args:
            teacher_y: Teacher logits, assumed (batch, n_classes).
            student_y: Student logits, (batch, n_classes).
            y: Ground-truth class indices, (batch,).

        Returns:
            Scalar loss tensor.
        """
        teacher_labels = teacher_y.detach().argmax(dim=-1)
        return 0.5 * (self.student_cel(student_y, y) + self.teacher_cel(student_y, teacher_labels))

## Teacher

In [8]:
from torch.autograd import Variable

import torchvision
import torchvision.models as models

import math

class VGG16_classifier(nn.Module):
    """VGG16 backbone (torchvision, pretrained) with a custom 4-layer MLP head.

    NOTE(review): presumably defined here so that ``torch.load`` can unpickle a
    teacher checkpoint saved as a whole ``VGG16_classifier`` instance; confirm.

    Args:
        image_sz: Stored only; not used by the network.
        n_classes: Number of output classes of the head.
        hidden_sz: Base width of the head (layers are 4x, 2x, 1x this size).
        dropout: Dropout probability between head layers.
    """

    def __init__(self,
                 image_sz,
                 n_classes,
                 hidden_sz,
                 dropout
                 ):
        # Must run before any submodule is assigned; nn.Module.__setattr__ raises
        # AttributeError for module attributes otherwise.
        super(VGG16_classifier, self).__init__()

        self.image_sz = image_sz
        self.n_classes = n_classes
        self.hidden_sz = hidden_sz
        self.dropout = dropout

        self.vgg16 = models.vgg16(pretrained=True)
        # Train the whole backbone (explicitly marks every parameter trainable).
        for parameter in self.vgg16.parameters():
            parameter.requires_grad = True
        # 25088 = 512 * 7 * 7, the flattened size of torchvision VGG16's pooled features.
        self.vgg16.classifier = nn.Sequential(
                nn.Linear(25088, self.hidden_sz * 4),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.hidden_sz * 4, self.hidden_sz * 2),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.hidden_sz * 2, self.hidden_sz),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.hidden_sz, self.n_classes)
            )

    def forward(self, x):
        """Return the class logits produced by the VGG16 backbone and custom head."""
        return self.vgg16(x)

In [9]:
# Colab only: mount Google Drive so the teacher checkpoint under MyDrive is readable.
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


### Set the path to the pretrained teacher model (download the teacher checkpoint from GitHub first)

In [10]:
# Path to the pretrained teacher checkpoint (a whole pickled nn.Module, not a state_dict).
teacher_path = "/content/drive/MyDrive/vgg16_cifar10.pth"
# weights_only=False is required to unpickle a full module; newer PyTorch releases
# default to True and would refuse this file. Unpickling can run arbitrary code,
# so only load checkpoints from a trusted source.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
teacher = torch.load(teacher_path, map_location=device, weights_only=False)

  teacher = torch.load("/content/drive/MyDrive/vgg16_cifar10.pth")


## Define the model

In [11]:
# Build the DeiT student around the frozen teacher. `.to` also moves the teacher,
# because it is registered as a submodule.
model = DeiT(
    image_sz=image_sz,
    channel_sz=channel_sz,
    patch_sz=patch_sz,
    embed_sz=embed_sz,
    n_heads=n_heads,
    n_layers=n_layers,
    n_classes=n_classes,
    hidden_sz=hidden_sz,
    teacher_model=teacher,
    dropout=dropout,
)
model = model.to(device)

## Train/Test

In [12]:
def train(model, train_loader, loss_function, optimizer, device, num_epochs, scheduler=None):
    """Train ``model`` for ``num_epochs`` epochs and return per-epoch metrics.

    Args:
        model: Module whose call returns ``(student_logits, teacher_logits)``.
        train_loader: Iterable of ``(images, labels)`` batches.
        loss_function: Callable ``(teacher_logits, student_logits, labels) -> loss``.
        optimizer: Optimizer over the model's trainable parameters.
        device: Device each batch is moved to.
        num_epochs: Number of passes over ``train_loader``.
        scheduler: Optional LR scheduler; ``scheduler.step()`` runs once per epoch.

    Returns:
        Dict with per-epoch lists ``"accuracy"`` (percent of correct student
        predictions; 0.0 for an empty loader) and ``"loss"`` (sum of the
        per-batch loss values).
    """
    training_history = {
        "accuracy": [],
        "loss": []
    }

    for epoch in range(1, num_epochs + 1):
        model.train()

        epoch_loss = 0.0
        total_correct = 0
        total_samples = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            student_predictions, teacher_predictions = model(images)
            loss = loss_function(teacher_predictions, student_predictions, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Count on-device instead of copying every prediction into Python lists.
            predicted = student_predictions.detach().argmax(dim=-1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            epoch_loss += loss.item()

        # An empty loader would otherwise raise ZeroDivisionError.
        accuracy = total_correct * 100 / total_samples if total_samples else 0.0

        training_history["loss"].append(epoch_loss)
        training_history["accuracy"].append(accuracy)

        print(f"{'-' * 50}")
        print(f"Epoch {epoch}/{num_epochs}")
        print(f"Train Loss      : {epoch_loss:.6f}")
        print(f"Train Accuracy  : {accuracy:.2f}% ({total_correct}/{total_samples})")
        print(f"{'-' * 50}")

        if scheduler is not None:
            scheduler.step()

    return training_history

In [13]:
def test(model, test_loader, device):

    model.eval()
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for batch_index, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            predictions, _ = model(images)

            predicted_labels.extend(predictions.argmax(dim=-1).tolist())
            true_labels.extend(labels.tolist())

    total_correct = sum(pred == true for pred, true in zip(predicted_labels, true_labels))
    total_samples = len(predicted_labels)
    accuracy = total_correct * 100 / total_samples

    print(f"{'-' * 50}")
    print(f"Test Accuracy   : {accuracy:.2f}% ({total_correct}/{total_samples})")
    print(f"{'-' * 50}")

    return accuracy

## Preprocess CIFAR-10

In [14]:
from torchvision import datasets, transforms

# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to
# [-1, 1]. The one-element mean/std is broadcast across the image channels.
# NOTE(review): confirm the teacher was trained with this same normalization.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split: reshuffled every epoch.
trainset = datasets.CIFAR10(root='../data/CIFAR10/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_sz, shuffle=True)

# Test split: fixed order for evaluation.
testset = datasets.CIFAR10(root='../data/CIFAR10/', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_sz, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/CIFAR10/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:13<00:00, 12.9MB/s]


Extracting ../data/CIFAR10/cifar-10-python.tar.gz to ../data/CIFAR10/
Files already downloaded and verified


## Train and evaluate

In [15]:
criterion = HardDistillationLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# NOTE(review): this scheduler is not passed to train() below and is never stepped.
# If it were stepped, step_size=1 with gamma=0.1 would divide the LR by 10 every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1, gamma=0.1)

In [None]:
train(
    model=model,
    train_loader=trainloader,
    loss_function=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=n_epochs,
)

In [None]:
test(model=model, test_loader=testloader, device=device)