In [1]:
import os

import torch
import torch.cuda
import torch.nn as nn
import torch.optim as optim
import torchvision

# --- Configuration -----------------------------------------------------------
SEED = 0
BATCH_SIZE = 64
DATA_DIR = 'mnist_data'
# 16 workers was hard-coded before; cap at the available CPU count so smaller
# machines are not oversubscribed (os.cpu_count() may return None).
NUM_WORKERS = min(16, os.cpu_count() or 1)

# Seeds the randomly initialised Linear heads and DataLoader shuffling.
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- Models ------------------------------------------------------------------
# Both models are an ImageNet-pretrained ResNet backbone (1000 outputs)
# followed by a Linear(1000, 10) head for the 10 MNIST classes.
# model1 (ResNet-18) is the network optimised in the distillation cell;
# model2 (ResNet-50) is passed to distillate_epoch alongside it, presumably
# as the teacher.
model1 = nn.Sequential(
    torchvision.models.resnet18(pretrained=True),
    nn.Linear(1000, 10),
)
model1.to(device)  # nn.Module.to moves parameters in place
model2 = nn.Sequential(
    torchvision.models.resnet50(pretrained=True),
    nn.Linear(1000, 10),
)
model2.to(device)

# --- Data --------------------------------------------------------------------
# MNIST images are single-channel; replicate to 3 channels so they match the
# ResNet input stem. No resizing or normalisation is applied.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=3),
    torchvision.transforms.ToTensor(),
])

# download=True fetches MNIST on first run instead of failing on a fresh checkout.
mnist_train = torchvision.datasets.MNIST(DATA_DIR, train=True, transform=transforms, download=True)
mnist_train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
)

mnist_test = torchvision.datasets.MNIST(DATA_DIR, train=False, transform=transforms, download=True)
# Evaluation does not benefit from shuffling; a fixed order keeps runs reproducible.
mnist_test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
)

In [4]:
# Restore model2's weights from the 'resnet50.pth' checkpoint.
# map_location places every tensor on `device`, so a checkpoint saved from a GPU
# also loads on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
# on torch >= 1.13 consider passing weights_only=True.
model2.load_state_dict(torch.load('resnet50.pth', map_location=device))

<All keys matched successfully>

In [4]:
# tqdm.auto renders a widget in notebooks and falls back to a text bar elsewhere,
# whereas tqdm.notebook requires a notebook frontend with ipywidgets.
from tqdm.auto import tqdm
from utils import eval_accuracy, train_epoch

# Train model2 directly on MNIST with default Adam hyperparameters.
# train_epoch / eval_accuracy live in the project-local utils module; their
# return values are logged as accuracies (presumably fractions in [0, 1]).
adam = optim.Adam(model2.parameters())

n_epochs = 10
for epoch in tqdm(range(n_epochs), desc='model2 epochs'):
    print(f"epoch {epoch + 1}/{n_epochs}")
    print(f"train acc: {train_epoch(model2, adam, mnist_train_loader)}")
    print(f"test acc: {eval_accuracy(model2, mnist_test_loader)}")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [5]:
# tqdm.auto renders a widget in notebooks and falls back to a text bar elsewhere,
# whereas tqdm.notebook requires a notebook frontend with ipywidgets.
from tqdm.auto import tqdm
from utils import eval_accuracy, train_epoch, distillate_epoch

# Distil into model1: only model1's parameters are given to the optimizer, and
# model2 is passed to distillate_epoch (presumably as the frozen teacher —
# confirm in utils). The first printed value is labelled "transfer acc" and is
# whatever distillate_epoch returns.
student_lr = 1e-4
adam = optim.Adam(model1.parameters(), lr=student_lr)

n_epochs = 10
for epoch in tqdm(range(n_epochs), desc='distillation epochs'):
    print(f"epoch {epoch + 1}/{n_epochs}")
    print(f"transfer acc: {distillate_epoch(model1, model2, adam, mnist_train_loader)}")
    print(f"test acc: {eval_accuracy(model1, mnist_test_loader)}")

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

transfer acc: 0.9371833205223083
test acc: 0.960099995136261


  0%|          | 0/938 [00:00<?, ?it/s]

transfer acc: 0.9540166854858398
test acc: 0.96670001745224


  0%|          | 0/938 [00:00<?, ?it/s]

transfer acc: 0.9548666477203369
test acc: 0.9667999744415283


  0%|          | 0/938 [00:00<?, ?it/s]

transfer acc: 0.9564499855041504
test acc: 0.9596999883651733


  0%|          | 0/938 [00:00<?, ?it/s]

KeyboardInterrupt: 