In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import ViTForImageClassification, ViTImageProcessor, MobileNetV2ForImageClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
teacher_id = 'nateraw/vit-base-patch16-224-cifar10'
teacher_model = ViTForImageClassification.from_pretrained(teacher_id)
teacher_model.to(device)
teacher_model.eval()

for param in teacher_model.parameters():
    param.requires_grad = False

processor = ViTImageProcessor.from_pretrained(teacher_id)

In [16]:
student_model = MobileNetV2ForImageClassification.from_pretrained('google/mobilenet_v2_1.0_224')

NUM_CLASSES = 10
student_model.classifier = torch.nn.Linear(student_model.classifier.in_features, NUM_CLASSES)
student_model.config.num_labels = NUM_CLASSES

student_model.to(device)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [5]:
dataset = load_dataset('uoft-cs/cifar10')

README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/120M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [10]:
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

def train_transforms_fn(examples):
    examples['pixel_values'] = [train_transforms(img.convert('RGB')) for img in examples['img']]
    del examples["img"]
    return examples

def val_transforms_fn(examples):
    examples['pixel_values'] = [val_transforms(img.convert('RGB')) for img in examples['img']]
    del examples["img"]
    return examples

dataset['train'].set_transform(train_transforms_fn)
dataset['test'].set_transform(val_transforms_fn)

In [11]:
train_loader = DataLoader(dataset['train'], batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(dataset['test'], batch_size=32, shuffle=False, num_workers=2)

In [12]:
TEMPERATURE = 4.0
ALPHA = 0.5
lr = 0.002
EPOCHS = 10

optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr, weight_decay=0.0001)

def distill_step(pixel_values, labels):
    pixel_values = pixel_values.to(device)
    labels = labels.to(device)

    student_model.train()
    with torch.no_grad():
        teacher_logits = teacher_model(pixel_values).logits

    student_logits = student_model(pixel_values).logits

    loss_ce = F.cross_entropy(student_logits, labels)
    loss_kl = F.kl_div(
        F.log_softmax(student_logits / TEMPERATURE, dim=-1),
        F.softmax(teacher_logits / TEMPERATURE, dim=-1),
        reduction='batchmean'
    ) * (TEMPERATURE ** 2)

    loss = (1 - ALPHA) * loss_ce + ALPHA * loss_kl

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

def evaluate(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['label'].to(device)

            outputs = model(pixel_values)
            preds = torch.argmax(outputs.logits, dim=-1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return accuracy_score(all_labels, all_preds)

In [18]:
from tqdm import tqdm

for epoch in range(EPOCHS):
    total_loss = 0.0
    steps = len(train_loader)

    for step, batch in tqdm(enumerate(train_loader)):
        pixel_values = batch['pixel_values']
        labels = batch['label']

        loss = distill_step(pixel_values, labels)
        total_loss += loss

        if step % 100 == 0:
            print(f'Epoch {epoch+1} [{step+1}/{steps}], Loss: {loss:.4f}')

    avg_loss = total_loss / steps

    val_acc = evaluate(student_model, val_loader)

    print('---------------------------------------')
    print(f'epoch {epoch+1} done:')
    print(f'avg train loss: {avg_loss:.4f}')
    print(f'val acc: {val_acc:.4f}')
    print('---------------------------------------')

1it [00:00,  1.50it/s]

Epoch 1 [1/1563], Loss: 3.7554


101it [00:48,  2.20it/s]

Epoch 1 [101/1563], Loss: 3.8640


201it [01:33,  2.16it/s]

Epoch 1 [201/1563], Loss: 3.9795


301it [02:19,  2.21it/s]

Epoch 1 [301/1563], Loss: 3.8701


401it [03:04,  2.19it/s]

Epoch 1 [401/1563], Loss: 3.8706


501it [03:50,  2.18it/s]

Epoch 1 [501/1563], Loss: 3.9780


601it [04:35,  2.20it/s]

Epoch 1 [601/1563], Loss: 3.9530


701it [05:20,  2.20it/s]

Epoch 1 [701/1563], Loss: 3.8982


801it [06:06,  2.18it/s]

Epoch 1 [801/1563], Loss: 3.7942


901it [06:52,  2.20it/s]

Epoch 1 [901/1563], Loss: 3.8749


1001it [07:37,  2.21it/s]

Epoch 1 [1001/1563], Loss: 3.9899


1101it [08:22,  2.20it/s]

Epoch 1 [1101/1563], Loss: 3.7644


1201it [09:08,  2.18it/s]

Epoch 1 [1201/1563], Loss: 3.9115


1301it [09:53,  2.20it/s]

Epoch 1 [1301/1563], Loss: 4.0273


1401it [10:39,  2.20it/s]

Epoch 1 [1401/1563], Loss: 3.8793


1501it [11:24,  2.20it/s]

Epoch 1 [1501/1563], Loss: 3.8545


1563it [11:52,  2.19it/s]


---------------------------------------
epoch 1 done:
avg train loss: 3.8894
val acc: 0.0855
---------------------------------------


1it [00:00,  1.73it/s]

Epoch 2 [1/1563], Loss: 3.9410


101it [00:46,  2.20it/s]

Epoch 2 [101/1563], Loss: 3.9551


201it [01:31,  2.20it/s]

Epoch 2 [201/1563], Loss: 3.9900


256it [01:56,  2.19it/s]


KeyboardInterrupt: 