In [1]:
import sys
import math
sys.path.append("..")

import torch

# from algorithms.deepinversion import distill_deep_inversion
# from algorithms.vanilla_distillation import distill
from distillation import run_distillation

from trainer import Trainer
from models import SimpleConvNet, SimpleResNet
from utils import get_cifar100_loader, get_cifar10_loader

# Prefer the GPU, but fall back to CPU when no CUDA device is visible so that
# every later `.to(device)` call still works on a CPU-only machine.
device = "cuda" if torch.cuda.is_available() else "cpu"

%load_ext autoreload
%autoreload 2

In [2]:
# Every loader used in this notebook is CIFAR-10 and the deepinversion run is
# configured with "n_classes": 10, so the classifier heads are sized for 10
# classes (they were previously 100-way, leaving 90 logits that never see a label).
N_CLASSES = 10
INPUT_SHAPE = (3, 32, 32)

# Teacher network that the distillation runs try to compress.
teacher = SimpleResNet(n_blocks=4, n_filters=16, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)

# Baseline student trained directly on labelled data (no teacher involved).
student_reference = SimpleConvNet(n_layers=2, n_filters=16, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)

# One student per distillation algorithm.
# NOTE(review): the reference/vanilla students use n_filters=16 while the three
# students below use n_filters=32, so they differ in width; confirm this is
# intentional before comparing accuracies across them.
student_vanilla = SimpleConvNet(n_layers=2, n_filters=16, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)
student_deepinversion = SimpleConvNet(n_layers=2, n_filters=32, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)
student_dfad = SimpleConvNet(n_layers=2, n_filters=32, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)
student_spec_features = SimpleConvNet(n_layers=2, n_filters=32, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)

In [3]:
# Shared hyper-parameters. The dict is handed to Trainer and is also read
# directly by later cells (loader batch size, AdamW learning rate, and the
# iteration budgets derived from dataset_size and max_epochs).
train_config = dict(
    lr=1e-3,
    batch_size=512,
    max_iters=-1,  # presumably "no iteration cap" -- confirm against Trainer
    max_epochs=8,
    evaluate_every=36,
    log_every=36,
    optimizer="AdamW",
    verbose=False,
    dataset_size=50_000,
)

In [4]:
# One Trainer built from train_config; the same instance is reused below to
# train the teacher and the reference student and to evaluate every model.
trainer = Trainer(train_config=train_config)

In [None]:
# Train the teacher on CIFAR-10 for a quarter of the configured epoch budget.
# The full budget is captured first so it can be restored by plain assignment
# afterwards: multiplying back by 4 compounds every time this cell is re-run
# and does not round-trip when max_epochs is not a multiple of 4.
full_max_epochs = train_config["max_epochs"]

try:
    teacher = trainer.train(
        model=teacher,
        train_loader=get_cifar10_loader(data_path="../../data", train=True, batch_size=train_config["batch_size"]),
        val_loader=get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"]),
        max_epochs=full_max_epochs // 4,  # just to prevent overfitting
    )
finally:
    # NOTE(review): train(max_epochs=...) presumably overwrites
    # trainer.train_config["max_epochs"]; put the full budget back even if
    # training is interrupted so later trainer.train calls are unaffected.
    trainer.train_config["max_epochs"] = full_max_epochs

Training iters:   0%|          | 0/100000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Evaluate the trained teacher on the CIFAR-10 loader built with train=False.
teacher_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(teacher_eval_loader, model=teacher)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.612}

In [None]:
# Baseline: train the reference student directly on CIFAR-10 with the shared
# trainer (no teacher is passed), using the trainer's current epoch setting.
reference_train_loader = get_cifar10_loader(
    data_path="../../data",
    train=True,
    batch_size=train_config["batch_size"],
)
reference_val_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
student_reference = trainer.train(
    model=student_reference,
    train_loader=reference_train_loader,
    val_loader=reference_val_loader,
)

Training iters:   0%|          | 0/400000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Evaluate the reference student on the CIFAR-10 loader built with train=False.
reference_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(reference_eval_loader, model=student_reference)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.5452}

In [None]:
# "vanilla" distillation of the teacher into student_vanilla, fed from the
# CIFAR-10 training loader.
vanilla_optimizer = torch.optim.AdamW(
    student_vanilla.parameters(),
    lr=train_config["lr"],
    betas=(0.8, 0.9),
)
vanilla_train_loader = get_cifar10_loader(
    data_path="../../data",
    train=True,
    batch_size=train_config["batch_size"],
)
# Iteration budget: batches per epoch (rounded up) times the number of epochs.
batches_per_epoch = math.ceil(train_config['dataset_size'] / train_config["batch_size"])
vanilla_iterations = batches_per_epoch * train_config["max_epochs"]

student_vanilla = run_distillation(
    algorithm_name="vanilla",
    teacher=teacher,
    student=student_vanilla,
    config={
        "optimizer": vanilla_optimizer,
        "train_loader": vanilla_train_loader,
        "iterations": vanilla_iterations,
    },
)

[INFO] Running algorithm: vanilla
[INFO] Using arguments for 'vanilla': ['teacher', 'student', 'test_loader', 'optimizer', 'train_loader', 'iterations']


Training - 784/784 [██████████████████████████████] ELP: 00:49, accuracy: 0.5171 - 


In [None]:
# Evaluate the vanilla-distilled student on the CIFAR-10 loader built with train=False.
vanilla_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(vanilla_eval_loader, model=student_vanilla)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.5514}

In [None]:
# "deepinversion" distillation into student_deepinversion. No data loader is
# passed in this config; the budget is dataset_size * max_epochs.
deepinversion_config = dict(
    total_iterations=train_config["dataset_size"] * train_config["max_epochs"],
    distill_k_times=16,
    batch_size=128,
    deep_inversion_batch_size=1024,
    n_classes=10,
)

student_deepinversion = run_distillation(
    algorithm_name="deepinversion",
    teacher=teacher,
    student=student_deepinversion,
    config=deepinversion_config,
)

[INFO] Running algorithm: deepinversion
[INFO] Using arguments for 'deepinversion': ['teacher', 'student', 'total_iterations', 'distill_k_times', 'batch_size', 'deep_inversion_batch_size', 'n_classes']




Pipeline iters:   0%|          | 0/390 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Evaluate the deepinversion student on the CIFAR-10 loader built with train=False.
deepinversion_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(deepinversion_eval_loader, model=student_deepinversion)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.0665}

In [None]:
# TODO: placeholder cell -- the "dfad" run has no real configuration yet.
# NOTE(review): `{...}` below is a set literal containing Ellipsis, not a dict
# of settings; replace it with the actual config before running this cell.
student_dfad = run_distillation(
    algorithm_name="dfad",
    teacher=teacher,
    student=student_dfad,
    config={
        ...
    }
)

In [None]:
# TODO: placeholder cell -- "?????" stands in for an algorithm name that has not
# been chosen yet (presumably not a name run_distillation recognises).
# NOTE(review): `{...}` below is a set literal containing Ellipsis, not a dict
# of settings; replace both placeholders before running this cell.
student_spec_features = run_distillation(
    algorithm_name="?????",
    teacher=teacher,
    student=student_spec_features,
    config={
        ...
    }
)