In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from eval_cifar10 import Eval
from models import resnet18, resnext101_64x4d

In [2]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Mini-batch size shared by the training loader and the evaluator.
batch_size = 256

# Training-time preprocessing: random 32x32 crop after 4-pixel padding, random
# horizontal flip, conversion to a tensor, then per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# CIFAR-10 training split stored under ./data, served in shuffled mini-batches.
trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
)

# Project-local evaluation helper; later cells call evaluator.eval(model, device)
# and display the value it returns.
evaluator = Eval(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Pretrained ResNet18, moved to the target device and put in evaluation mode.
# Chaining relies on nn.Module.to() and nn.Module.eval() returning the module itself.
model = resnet18(pretrained=True, download=True).to(device).eval()

# Conventionally trained ResNet18: CIFAR-10 accuracy 92.95 %
evaluator.eval(model, device)

0.9295

In [4]:
# Student model: a ResNet18 created with pretrained=False, i.e. before any training.
model = resnet18(pretrained=False)
model.to(device)
# Switch to evaluation mode before measuring accuracy, as is done for the
# pretrained models in the neighbouring evaluation cells. In training mode,
# BatchNorm layers would use per-batch statistics and update their running
# statistics from the evaluation data. (Harmless if Eval already does this
# internally; the training loop calls model.train() again before training.)
model.eval()
# ResNet18 before training: CIFAR-10 accuracy is about 10 % (chance level for 10 classes).
evaluator.eval(model, device)



0.0995

In [5]:
# Teacher model: pretrained ResNeXt101 (64x4d), moved to the target device and
# put in evaluation mode. Chaining relies on nn.Module.to() and nn.Module.eval()
# returning the module itself.
teacher = resnext101_64x4d(pretrained=True, download=True).to(device).eval()

# Conventionally trained ResNeXt101: CIFAR-10 accuracy 94.07 %
evaluator.eval(teacher, device)

0.9407

In [6]:
# Training / distillation hyperparameters.
epoch = 100       # number of training epochs; also used as the cosine schedule length
temperature = 10  # temperature dividing the logits in the distillation term
lam = 0.5         # weight of the distillation term; (1 - lam) weights the label loss

# Cross-entropy between the student's outputs and the ground-truth labels.
criterion = nn.CrossEntropyLoss()

# SGD with Nesterov momentum over the student's parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.01,
    nesterov=True,
)

# Cosine learning-rate annealing; the training loop steps it once per epoch.
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)

In [7]:
# Knowledge-distillation training loop.
#
# For each mini-batch the student (`model`) is trained on a weighted sum of
#   * the cross-entropy against the ground-truth labels, and
#   * the KL divergence between the temperature-softened student and teacher
#     distributions, multiplied by temperature**2.
# After every epoch the learning-rate scheduler is stepped and the student's
# accuracy is reported by `evaluator`.
#
# The loop variable is deliberately NOT named `epoch`: reusing that name would
# overwrite the configured epoch count, so re-running this cell would train
# for a different number of epochs.
for epoch_idx in range(epoch):
    model.train()
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Student forward pass and hard-label loss.
        outputs = model(inputs)
        loss_hard = criterion(outputs, labels)

        # The teacher only supplies targets; no_grad() already keeps its
        # outputs out of the autograd graph, so no extra detach() is needed.
        with torch.no_grad():
            outputs_teacher = teacher(inputs)

        # F.kl_div expects log-probabilities as input and probabilities as
        # target. Multiplying by temperature**2 compensates for the 1/T**2
        # scaling that the softened softmax applies to the gradients.
        loss_distill = (
            F.kl_div(
                F.log_softmax(outputs / temperature, dim=1),
                F.softmax(outputs_teacher / temperature, dim=1),
                reduction="batchmean",
            )
            * temperature
            * temperature
        )

        loss = (1 - lam) * loss_hard + lam * loss_distill

        loss.backward()
        optimizer.step()

    lr_scheduler.step()

    # Evaluate in evaluation mode so BatchNorm uses its running statistics and
    # does not update them from the evaluation data; model.train() at the top
    # of the next epoch switches back.
    model.eval()
    acc = evaluator.eval(model, device)

    print(f"Epoch {epoch_idx + 1}, Accuracy: {acc}")

Epoch 1, Accuracy: 0.4753
Epoch 2, Accuracy: 0.5243
Epoch 3, Accuracy: 0.5632
Epoch 4, Accuracy: 0.6425
Epoch 5, Accuracy: 0.5799
Epoch 6, Accuracy: 0.6533
Epoch 7, Accuracy: 0.6964
Epoch 8, Accuracy: 0.6564
Epoch 9, Accuracy: 0.7242
Epoch 10, Accuracy: 0.726
Epoch 11, Accuracy: 0.7303
Epoch 12, Accuracy: 0.7408
Epoch 13, Accuracy: 0.7638
Epoch 14, Accuracy: 0.7352
Epoch 15, Accuracy: 0.7609
Epoch 16, Accuracy: 0.7016
Epoch 17, Accuracy: 0.6797
Epoch 18, Accuracy: 0.7139
Epoch 19, Accuracy: 0.7851
Epoch 20, Accuracy: 0.7824
Epoch 21, Accuracy: 0.7633
Epoch 22, Accuracy: 0.8087
Epoch 23, Accuracy: 0.7532
Epoch 24, Accuracy: 0.7615
Epoch 25, Accuracy: 0.7531
Epoch 26, Accuracy: 0.7296
Epoch 27, Accuracy: 0.7637
Epoch 28, Accuracy: 0.7652
Epoch 29, Accuracy: 0.7106
Epoch 30, Accuracy: 0.7749
Epoch 31, Accuracy: 0.8152
Epoch 32, Accuracy: 0.7893
Epoch 33, Accuracy: 0.7773
Epoch 34, Accuracy: 0.7849
Epoch 35, Accuracy: 0.7568
Epoch 36, Accuracy: 0.7988
Epoch 37, Accuracy: 0.8108
Epoch 38, A

In [8]:
# The training loop calls model.train() every epoch, so make sure the student
# is in evaluation mode before the final measurement (BatchNorm then uses its
# running statistics instead of per-batch statistics).
model.eval()
# ResNet18 distilled from ResNeXt101: CIFAR-10 accuracy 93.32 %
# NOTE(review): the recorded cell output is 0.932 (93.20 %) — confirm which figure is correct.
evaluator.eval(model, device)

0.932

In [9]:
# Persist the distilled student's parameters (state_dict only) in the working directory.
checkpoint_path = "resnet18_cifar10_distil_from_resnext101.pth"
torch.save(model.state_dict(), checkpoint_path)