In [None]:
import sys
import math
sys.path.append("..")

import torch

# from algorithms.deepinversion import distill_deep_inversion
# from algorithms.vanilla_distillation import distill
from distillation import run_distillation

from trainer import Trainer
from models import SimpleConvNet, SimpleResNet
from utils import get_cifar100_loader, get_cifar10_loader

device = "cuda"

%load_ext autoreload
%autoreload 2

In [2]:
teacher = SimpleResNet(n_blocks=4, n_filters=16, input_shape=(3, 32, 32), n_classes=100).to(device)
student_reference = SimpleConvNet(n_layers=2, n_filters=16, input_shape=(3, 32, 32), n_classes=100).to(device)

student_vanilla = SimpleConvNet(n_layers=2, n_filters=16, input_shape=(3, 32, 32), n_classes=100).to(device)
student_deepinversion = SimpleConvNet(n_layers=2, n_filters=32, input_shape=(3, 32, 32), n_classes=100).to(device)

In [None]:
train_config = {
    'lr': 1e-3,
    'batch_size': 512,
    'max_iters': -1,
    'max_epochs': 8,
    'evaluate_every': 36,
    'log_every': 36,
    "optimizer": "AdamW",
    'verbose': False,
    'dataset_size': 50000
}

In [4]:
trainer = Trainer(train_config=train_config)

In [5]:
teacher = trainer.train(
    model=teacher,
    train_loader=get_cifar10_loader(data_path="../../data", train=True, batch_size=train_config["batch_size"]),
    val_loader=get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"]),
    max_epochs=train_config["max_epochs"] // 4 # just to prevent overfitting
)
trainer.train_config["max_epochs"] = trainer.train_config["max_epochs"] * 4

Training iters:   0%|          | 0/100000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.383e+00, Last validation accuracy: 0.4125, Epoch: 0.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.209e+00, Last validation accuracy: 0.5388, Epoch: 0.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.018e+00, Last validation accuracy: 0.5916, Epoch: 1.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 9.051e-01, Last validation accuracy: 0.5890, Epoch: 1.737
Loss: 8.518e-01, Last validation accuracy: 0.5890, Epoch: 2.000


In [6]:
trainer.evaluate(
    get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"]),
    model=teacher
)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.6102}

In [7]:
student_reference = trainer.train(
    model=student_reference,
    train_loader=get_cifar10_loader(data_path="../../data", train=True, batch_size=train_config["batch_size"]),
    val_loader=get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"])
)

Training iters:   0%|          | 0/400000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.755e+00, Last validation accuracy: 0.3755, Epoch: 0.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.610e+00, Last validation accuracy: 0.4225, Epoch: 0.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.546e+00, Last validation accuracy: 0.4406, Epoch: 1.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.479e+00, Last validation accuracy: 0.4600, Epoch: 1.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.386e+00, Last validation accuracy: 0.4851, Epoch: 2.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.366e+00, Last validation accuracy: 0.4855, Epoch: 2.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.303e+00, Last validation accuracy: 0.5080, Epoch: 3.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.296e+00, Last validation accuracy: 0.4999, Epoch: 3.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.265e+00, Last validation accuracy: 0.5179, Epoch: 4.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.258e+00, Last validation accuracy: 0.5101, Epoch: 4.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.235e+00, Last validation accuracy: 0.5244, Epoch: 5.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.229e+00, Last validation accuracy: 0.5188, Epoch: 5.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.208e+00, Last validation accuracy: 0.5307, Epoch: 6.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.202e+00, Last validation accuracy: 0.5229, Epoch: 6.737


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.180e+00, Last validation accuracy: 0.5362, Epoch: 7.369


eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

Loss: 1.172e+00, Last validation accuracy: 0.5277, Epoch: 7.737
Loss: 1.163e+00, Last validation accuracy: 0.5277, Epoch: 8.000


In [8]:
trainer.evaluate(
    get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"]),
    model=student_reference
)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.5391}

In [None]:
student_vanilla = run_distillation(
    algorithm_name="vanilla",
    teacher=teacher,
    student=student_vanilla,
    config={
        "optimizer": torch.optim.AdamW(student_vanilla.parameters(), lr=train_config["lr"], betas=(0.8, 0.9)),
        "train_loader": get_cifar10_loader(data_path="../../data", train=True, batch_size=train_config["batch_size"]),
        "iterations": math.ceil(train_config['dataset_size'] / train_config["batch_size"]) * train_config["max_epochs"],
    }
)

In [1]:
# student_vanilla = distill(
#     teacher=teacher,
#     student=student_vanilla,
#     optimizer=torch.optim.AdamW(student_vanilla.parameters(), lr=train_config["lr"], betas=(0.8, 0.9)),
#     train_loader=get_cifar10_loader(data_path="../../data", train=True, batch_size=train_config["batch_size"]),
#     test_loader=None,
#     iterations=math.ceil(train_config['dataset_size'] / train_config["batch_size"]) * train_config["max_epochs"],
#     test_freq=-1,
#     alpha=0.6,
#     T=2.5
# )

In [None]:
trainer.evaluate(
    get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"]),
    model=student_vanilla
)

In [None]:
student_deepinversion = run_distillation(
    algorithm_name="deepinversion",
    teacher=teacher,
    student=student_vanilla,
    config={
        "total_iterations": train_config["dataset_size"] * train_config["max_epochs"],
        "distill_k_times": 16,
        "batch_size": 128,
        "deep_inversion_batch_size": 1024,
        "n_classes": 10,
    }
)

In [2]:
# student_deepinversion = distill_deep_inversion(
#     teacher=teacher, 
#     student=student_deepinversion,
#     distill_config={
#         "alpha": 0.6,
#         "T": 2.5,
#         "lr": 0.001
#     },
#     total_iterations=train_config["dataset_size"] * train_config["max_epochs"],
#     distill_k_times=16,
#     batch_size=128,
#     deep_inversion_batch_size=1024,
#     deep_inversion_iterations=200,
#     n_classes=100,
#     image_shape=(3, 32, 32),
#     r_feature=0.01,
#     first_bn_multiplier=10,
#     tv_l1=0.0,
#     tv_l2=0.0001,
#     l2=0.00001,
#     main_loss_multiplier=1.0,
#     lr=0.03,
#     adi_scale=1,
#     device="cuda"
# )

In [None]:
trainer.evaluate(
    get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"]),
    model=student_deepinversion
)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.0665}