In [1]:
import sys
import math
sys.path.append("..")

import torch

from algorithms.deepinversion import distill_deep_inversion
from algorithms.vanilla_distillation import distill
# TODO: replace these two imports with a single distillation entry point that selects the method from config.

from trainer import Trainer
from models import SimpleConvNet, SimpleResNet
from utils import get_cifar100_loader, get_cifar10_loader

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without a usable CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Enable IPython's autoreload extension in mode 2 so that edits to the
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Every loader used in this notebook is a CIFAR-10 loader and
# distill_deep_inversion is called with n_classes=10, so all classifier
# heads are built with 10 outputs.
N_CLASSES = 10
INPUT_SHAPE = (3, 32, 32)

# Teacher network that the students are distilled from.
teacher = SimpleResNet(n_blocks=4, n_filters=16, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)
# Reference student: trained with the plain Trainer, without a teacher.
student_reference = SimpleConvNet(n_layers=2, n_filters=16, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)

# Student trained with vanilla distillation on the real training data.
student_vanilla = SimpleConvNet(n_layers=2, n_filters=16, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)
# Student trained with DeepInversion-based distillation.
# NOTE(review): this student uses n_filters=32 while the other two students
# use 16 - confirm the wider model is intentional, since it makes the
# accuracy comparison between students uneven.
student_deepinversion = SimpleConvNet(n_layers=2, n_filters=32, input_shape=INPUT_SHAPE, n_classes=N_CLASSES).to(device)

In [3]:
# Hyperparameters shared by the Trainer and reused by the distillation cells
# (batch_size, lr, max_epochs and dataset_size are read again further down).
train_config = dict(
    lr=1e-3,
    batch_size=512,
    max_iters=-1,  # presumably "no iteration cap" - confirm against Trainer
    max_epochs=8,
    evaluate_every=36,
    log_every=36,
    optimizer="AdamW",
    verbose=False,
    dataset_size=50000,  # used below to derive the number of batches per epoch
)

In [4]:
# Single Trainer instance, reused below for supervised training of the
# teacher and the reference student and for every evaluation call.
trainer = Trainer(train_config=train_config)

In [5]:
# Remember the full epoch budget before the shortened teacher run so it can
# be restored exactly afterwards.
full_max_epochs = trainer.train_config["max_epochs"]

# Train the teacher on CIFAR-10 for a quarter of the configured epochs
# (a shorter schedule, just to prevent overfitting).
teacher = trainer.train(
    model=teacher,
    train_loader=get_cifar10_loader(data_path="../../data", train=True, batch_size=train_config["batch_size"]),
    val_loader=get_cifar10_loader(data_path="../../data", train=False, batch_size=train_config["batch_size"]),
    max_epochs=train_config["max_epochs"] // 4,
)

# Put the full epoch budget back for the runs that follow. Restoring the
# saved value (rather than multiplying the current one by 4) is exact even
# when max_epochs is not a multiple of 4, and stays correct if this cell is
# re-run.
trainer.train_config["max_epochs"] = full_max_epochs

Training iters:   0%|          | 0/100000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

In [6]:
# Score the trained teacher on the train=False CIFAR-10 split.
teacher_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(teacher_eval_loader, model=teacher)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.6091}

In [7]:
# Train the reference student directly on CIFAR-10 (no teacher involved),
# using the trainer's configured epoch budget.
reference_train_loader = get_cifar10_loader(
    data_path="../../data",
    train=True,
    batch_size=train_config["batch_size"],
)
reference_val_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
student_reference = trainer.train(
    model=student_reference,
    train_loader=reference_train_loader,
    val_loader=reference_val_loader,
)

Training iters:   0%|          | 0/400000 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

In [8]:
# Score the reference student on the train=False CIFAR-10 split.
reference_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(reference_eval_loader, model=student_reference)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.5391}

In [9]:
# Vanilla distillation of the teacher into student_vanilla on the real
# CIFAR-10 training data.

# Number of batches needed to cover dataset_size samples once.
batches_per_epoch = math.ceil(train_config['dataset_size'] / train_config["batch_size"])

vanilla_optimizer = torch.optim.AdamW(
    student_vanilla.parameters(),
    lr=train_config["lr"],
    betas=(0.8, 0.9),
)
vanilla_train_loader = get_cifar10_loader(
    data_path="../../data",
    train=True,
    batch_size=train_config["batch_size"],
)

student_vanilla = distill(
    teacher=teacher,
    student=student_vanilla,
    optimizer=vanilla_optimizer,
    train_loader=vanilla_train_loader,
    test_loader=None,  # presumably disables in-loop testing together with test_freq=-1 - confirm
    iterations=batches_per_epoch * train_config["max_epochs"],
    test_freq=-1,
    alpha=0.6,
    T=2.5,
)

Training - 784/784 [██████████████████████████████] ELP: 00:49, accuracy: 0.5010 - 


In [10]:
# Score the vanilla-distilled student on the train=False CIFAR-10 split.
vanilla_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(vanilla_eval_loader, model=student_vanilla)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.5437}

In [None]:
# DeepInversion-based distillation of the teacher into student_deepinversion.
# The learning rate and device are taken from the notebook-level settings
# instead of being repeated as literals, so they cannot drift apart from the
# values used by the other cells.
# NOTE(review): total_iterations is given a sample count
# (dataset_size * max_epochs) - confirm this is the unit
# distill_deep_inversion expects.
student_deepinversion = distill_deep_inversion(
    teacher=teacher,
    student=student_deepinversion,
    distill_config={
        "alpha": 0.6,
        "T": 2.5,
        "lr": train_config["lr"]
    },
    total_iterations=train_config["dataset_size"] * train_config["max_epochs"],
    distill_k_times=64,
    batch_size=256,
    deep_inversion_batch_size=2048,
    deep_inversion_iterations=100,
    n_classes=10,
    image_shape=(3, 32, 32),
    r_feature=0.01,
    first_bn_multiplier=10,
    tv_l1=0.0,
    tv_l2=0.0001,
    l2=0.00001,
    main_loss_multiplier=1.0,
    lr=0.03,
    adi_scale=1,
    device=device
)

Pipeline iters:   0%|          | 0/781 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

Getting_images:   0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Score the DeepInversion student on the train=False CIFAR-10 split.
# NOTE(review): the recorded run of the distillation cell above ended in a
# KeyboardInterrupt, so its assignment never completed; the accuracy shown
# here is presumably for a student that was not fully distilled - re-run
# before drawing conclusions.
deepinversion_eval_loader = get_cifar10_loader(
    data_path="../../data",
    train=False,
    batch_size=train_config["batch_size"],
)
trainer.evaluate(deepinversion_eval_loader, model=student_deepinversion)

eval batch:   0%|          | 0/20 [00:00<?, ?it/s]

{'Accuracy': 0.0665}