In [None]:
!pip install timm detectors wandb

In [2]:
import datetime as dt

import detectors
import timm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR100, ImageFolder
import torchvision.transforms as transforms

from tqdm import tqdm
import wandb



In [3]:
def save_model(path, num_epochs, model, optimizer, scheduler=None):
    '''Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``num_epochs``,
    ``model_state_dict``, ``optimizer_state_dict`` and
    ``scheduler_state_dict`` (``None`` when no scheduler is given).
    Tensors are stored on whatever device they currently live on, so a
    checkpoint written from a CUDA model contains CUDA tensors.
    '''
    scheduler_state = None if scheduler is None else scheduler.state_dict()
    checkpoint = dict(
        num_epochs=num_epochs,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler_state,
    )
    torch.save(checkpoint, path)


def load_model(path, device, model, optimizer=None, scheduler=None):
    '''Restore a checkpoint written by ``save_model`` into existing objects.

    Args:
        path: checkpoint file produced by ``save_model``.
        device: device the model is moved to; checkpoint tensors are also
            mapped there while loading.
        model: module whose weights are overwritten in place.
        optimizer: optional optimizer whose state is restored.
        scheduler: optional LR scheduler whose state is restored. If the
            checkpoint was saved without a scheduler, the scheduler is left
            unchanged and a warning is emitted.

    Returns:
        The ``num_epochs`` value stored in the checkpoint.
    '''
    # map_location lets a checkpoint saved from CUDA tensors be loaded on a
    # CPU-only machine (and vice versa) instead of failing inside torch.load.
    data = torch.load(path, map_location=device)
    model.load_state_dict(data['model_state_dict'])
    model.to(device)
    if optimizer is not None:
        optimizer.load_state_dict(data['optimizer_state_dict'])
    if scheduler is not None:
        scheduler_state = data.get('scheduler_state_dict')
        if scheduler_state is None:
            # save_model stores None when it was called without a scheduler;
            # load_state_dict(None) would crash, so keep the fresh scheduler.
            import warnings
            warnings.warn('checkpoint has no scheduler state; scheduler left unchanged')
        else:
            scheduler.load_state_dict(scheduler_state)
    return data['num_epochs']


@torch.no_grad()
def validation(model, test_loader, device, teacher_test_acc=0.7926):
    '''Evaluate ``model`` on ``test_loader`` and return wandb-ready metrics.

    Args:
        model: network to evaluate; it is switched to eval mode (callers that
            keep training must switch it back to train mode).
        test_loader: iterable of ``(batch, labels)`` pairs.
        device: device the batches and labels are moved to.
        teacher_test_acc: reference teacher accuracy logged next to the student
            metrics so both appear on the same chart. It is not measured here;
            the default is the value this notebook used for the pretrained
            teacher.

    Returns:
        dict with ``student_test_loss`` (mean cross-entropy, 0-d tensor),
        ``student_test_acc`` (top-1 accuracy, 0-d tensor) and
        ``teacher_test_acc``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    '''
    model.eval()

    loss_sum, n_correct, n_samples = 0.0, 0.0, 0
    for batch, labels in tqdm(test_loader):
        batch = batch.to(device)
        labels = labels.to(device)
        logits = model(batch)

        # Accumulate the summed per-sample loss; dividing once at the end gives
        # the exact dataset mean regardless of the last batch's size.
        loss_sum += F.cross_entropy(logits, labels, reduction='sum')
        n_correct += (logits.argmax(dim=1) == labels).sum()
        n_samples += labels.size(0)

    if n_samples == 0:
        raise ValueError('test_loader yielded no samples')

    return {
        'student_test_loss': loss_sum / n_samples,
        'student_test_acc': n_correct / n_samples,
        'teacher_test_acc': teacher_test_acc,
    }


def distill(teacher, student, train_loader, test_loader, kd_loss, optimizer, scheduler, 
            n_epochs, valid_period, save_period, temp, device, wandb_init_data,
            start_epoch=0):
    '''Train ``student`` to imitate ``teacher``'s temperature-softened outputs.

    Every step feeds the same images to both networks and minimises
    ``kd_loss(log_softmax(student / temp), softmax(teacher / temp))``; the
    labels yielded by ``train_loader`` are ignored. Metrics and checkpoints go
    to the wandb run opened with ``wandb_init_data``.

    Args:
        teacher: network providing soft targets; put in eval mode here.
        student: network being trained.
        train_loader: iterable of ``(batch, labels)`` pairs.
        test_loader: loader passed to ``validation`` for student metrics.
        kd_loss: loss taking student log-probabilities and teacher
            probabilities (e.g. ``nn.KLDivLoss(reduction="batchmean")``).
        optimizer: optimizer over the student's parameters.
        scheduler: optional LR scheduler, stepped once per epoch; may be None.
        n_epochs: total epoch count at which training stops; epochs
            ``start_epoch + 1`` .. ``n_epochs`` are run.
        valid_period: validate and log every ``valid_period`` epochs.
        save_period: save and upload ``'<N>epochs.pt'`` every ``save_period``
            epochs, where N is the total epoch count.
        temp: softmax temperature applied to both networks' logits.
        device: device the training batches are moved to.
        wandb_init_data: keyword arguments for ``wandb.init``.
        start_epoch: epochs already completed (e.g. the value returned by
            ``load_model`` when resuming) so epoch numbers, logged epochs and
            checkpoint names continue instead of restarting. Defaults to 0.
    '''
    # Keep the teacher in eval mode so layers such as BatchNorm do not update
    # their running statistics while producing targets.
    teacher.eval()
    with wandb.init(**wandb_init_data) as run:
        print(f'Training started: {dt.datetime.now()}')
        for epoch in range(start_epoch, n_epochs):
            student.train()
            for batch, _ in train_loader:  # labels are not used for distillation
                batch = batch.to(device)

                optimizer.zero_grad()

                # Teacher outputs are only targets; no autograd graph is needed.
                with torch.inference_mode():
                    teacher_predictions = teacher(batch)
                student_predictions = student(batch)
                loss = kd_loss(F.log_softmax(student_predictions / temp, dim=1),
                               F.softmax(teacher_predictions / temp, dim=1))
                loss.backward()
                optimizer.step()

            # Advance the LR schedule once per epoch (skipped when there is none).
            if scheduler is not None:
                scheduler.step()

            epochs_done = epoch + 1
            if epochs_done % valid_period == 0:
                print(f'{epochs_done} training epochs finished\nValidation started: {dt.datetime.now()}')
                with torch.inference_mode():
                    metrics = validation(student, test_loader, device)
                # Log the epoch explicitly so curves stay aligned across resumes.
                run.log({**metrics, 'epoch': epochs_done})
                print(f'\nValidation finished: {dt.datetime.now()}')

            if epochs_done % save_period == 0:
                ckpt_filename = f'{epochs_done}epochs.pt'
                save_model(ckpt_filename, epochs_done, student, optimizer, scheduler)
                run.save(ckpt_filename)
                print(f'\nCheckpoint saved after {epochs_done} epochs\n')

In [4]:
# Fix the RNG so student initialisation and data shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The pretrained teacher is only used for inference: eval mode plus frozen weights.
teacher = timm.create_model("resnet18_cifar100", pretrained=True).to(device)
teacher.eval()
teacher.requires_grad_(False)
# The student has the same architecture but starts from random weights.
student = timm.create_model("resnet18_cifar100", pretrained=False).to(device)

# Train and test images must be normalised with the same statistics; otherwise
# the student (and the teacher's soft targets) see a different input
# distribution during training than during evaluation. These are CIFAR-100
# channel statistics. NOTE(review): presumably what the pretrained teacher
# expects — confirm against teacher.pretrained_cfg.
CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR100_STD = [0.2675, 0.2565, 0.2761]

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

batch_size = 512
num_workers = 2
pin_memory = device.type == 'cuda'  # faster host-to-GPU copies; pointless on CPU

# Single-image ("ameyoko") data, laid out as torchvision ImageFolder expects.
SINGLE_IMAGE_DATASET_DIR = '/kaggle/input/ameyoko'
single_image_dataset = ImageFolder(SINGLE_IMAGE_DATASET_DIR, transform=train_transform)
cifar_train_dataset = CIFAR100('.', train=True, transform=train_transform, download=True)
test_dataset = CIFAR100('.', train=False, transform=test_transforms, download=True)

loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
single_image_loader = DataLoader(single_image_dataset, shuffle=True, **loader_kwargs)
cifar_train_loader = DataLoader(cifar_train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
# "batchmean" divides the summed KL divergence by the batch size; the loss
# expects log-probabilities as input and probabilities as target.
kd_loss = nn.KLDivLoss(reduction="batchmean")
scheduler = None  # no LR schedule

# Softmax temperature used to soften teacher and student logits.
temp = 8.0

Downloading: "https://huggingface.co/edadaltocg/resnet18_cifar100/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet18_cifar100.pth
100%|██████████| 42.9M/42.9M [00:00<00:00, 76.0MB/s]


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:02<00:00, 78087317.52it/s]


Extracting ./cifar-100-python.tar.gz to .
Files already downloaded and verified


# ameyoko distillation

In [None]:
n_epochs = 300
valid_period = 1
save_period = 20

# Start from a freshly initialised student and a matching optimizer (same
# settings as the setup cell) so this run does not depend on what earlier cells
# did to `student`, and re-running the cell repeats the same experiment.
# NOTE: if a scheduler is configured in the setup cell, rebuild it here too,
# since it would still be bound to the old optimizer.
student = timm.create_model("resnet18_cifar100", pretrained=False).to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

wandb_init_data = {
    'project': 'one_image_distillation',
    'name': 'distillation on ameyoko',
    'save_code': True,
    'config': {
        'model': 'ResNet18',
        'optimizer': str(optimizer),  # readable description instead of the object
        'scheduler': scheduler,
        'valid_period': valid_period,
        'save_period': save_period,
        'dataset': 'ameyoko',
        'num_epochs': n_epochs,
        'temperature': temp,
        'batch_size': batch_size,
        'dataloader_num_workers': num_workers,
    }
}

distill(teacher, student, single_image_loader, test_loader, kd_loss, optimizer, scheduler, 
        n_epochs, valid_period, save_period, temp, device, wandb_init_data)

# cifar-train-split distillation

In [None]:
n_epochs = 300
valid_period = 1
save_period = 20

# Start from a freshly initialised student and a matching optimizer (same
# settings as the setup cell). Otherwise `student` and the optimizer state would
# carry over whatever an earlier training cell left in them (e.g. 300 epochs of
# ameyoko distillation), and this would not be an independent cifar-train-split run.
# NOTE: if a scheduler is configured in the setup cell, rebuild it here too,
# since it would still be bound to the old optimizer.
student = timm.create_model("resnet18_cifar100", pretrained=False).to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

wandb_init_data = {
    'project': 'one_image_distillation',
    'name': 'distillation on cifar-train-split',
    'save_code': True,
    'config': {
        'model': 'ResNet18',
        'optimizer': str(optimizer),  # readable description instead of the object
        'scheduler': scheduler,
        'valid_period': valid_period,
        'save_period': save_period,
        'dataset': 'cifar-train-split',
        'num_epochs': n_epochs,
        'temperature': temp,
        'batch_size': batch_size,
        'dataloader_num_workers': num_workers,
    }
}

distill(teacher, student, cifar_train_loader, test_loader, kd_loss, optimizer, scheduler, 
        n_epochs, valid_period, save_period, temp, device, wandb_init_data)

# resuming run

In [None]:
# Resume the existing wandb run `run_id` from the checkpoint it uploaded after
# 300 epochs and continue distilling on the single-image data.
project_name = 'one_image_distillation'
run_id = 'i0yw0j67'
n_epochs = 1000
valid_period = 1
save_period = 20

# load checkpoint from wandb
# wandb.restore fetches '300epochs.pt' from the given run; the `.name` of the
# returned object is used below as the local path of the downloaded file.
last_ckpt = wandb.restore('300epochs.pt', run_path=f"nik-fedorov/{project_name}/{run_id}")
# NOTE(review): load_model returns the epoch count stored in the checkpoint, but
# the value is discarded here, so distill() below numbers epochs from 1 again.
# Its checkpoints ('20epochs.pt', '40epochs.pt', ...) are then uploaded to the
# same run under names presumably already used before the restart, and
# n_epochs=1000 epochs are trained on top of the 300 already done.
# TODO confirm this is intended.
load_model(last_ckpt.name, device, student, optimizer)

# set wandb_init_data for resuming
# resume='must' asks wandb to continue run `run_id` rather than start a new run.
wandb_init_data = {
    'project': project_name,
    'id': run_id,
    'resume': 'must',
    'save_code': True
}

# resume training
distill(teacher, student, single_image_loader, test_loader, kd_loss, optimizer, scheduler, 
        n_epochs, valid_period, save_period, temp, device, wandb_init_data)