In [1]:
import os
import time
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from Models import ResNet20, ResNet50

In [2]:
input_test = torch.randn([5,3,32,32])
resnet20 = ResNet20()
out = resnet20(input_test)
print(out.size())

resnet50 = ResNet50()
out = resnet50(input_test)
print(out.size())

torch.Size([5, 10])
torch.Size([5, 10])


## Define Loss

In [3]:
def studentLoss(teacher_pred, student_pred, targets, T, alpha):
    """
    Loss function for student network: Loss = alpha * (distillation loss with soft-target) + (1 - alpha) * (cross-entropy loss with true label)
    Return: loss
    """
    if alpha > 0:
        loss = F.kl_div(F.softmax(student_pred / T, dim=1), 
                        F.softmax(teacher_pred / T, dim=1), 
                        reduction='batchmean') * (T**2) * alpha + F.cross_entropy(student_pred, targets) * (1 - alpha)
    else:
        loss = F.cross_entropy(student_pred, targets)

    return loss

## Define Training and Testing Function

In [4]:
def distillationTraining(model, train_loader, save_name):
    print("Now T is {}".format(T))
    
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    
    best_val_acc = 0
    current_learning_rate = LEARNING_RATE
    softmax = torch.nn.Softmax(dim=1)
    print("==> Start training!")
    print("="*50)
    
    for i in range(0, EPOCHS):
        if i in DECAY_EPOCHS and i != 0:
            current_learning_rate = current_learning_rate * DECAY
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to {}".format(current_learning_rate))
        
        model.train()
        print("Epoch {}".format(i))
        
        total_examples = 0
        correct_examples = 0
        train_loss = 0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            pred = model(inputs)
            if alpha > 0:
                with torch.no_grad():
                    teacher_pred = teacher_model(inputs)
            else:
                teacher_pred = 0
            train_loss = studentLoss(teacher_pred, pred, targets, T, alpha)
            
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            
            total_examples += inputs.shape[0]

            out = softmax(pred)
            out = torch.max(out, 1)

            correct_examples += torch.sum(targets==out[1]).cpu().data.numpy().tolist()
            
        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: {}, training accuracy: {}".format(avg_loss, avg_acc))

        model.eval()

        total_examples = 0
        correct_examples = 0
        
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(val_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                pred = model(inputs)
                val_loss = criterion(pred, targets)
                total_examples += inputs.shape[0]
                out = softmax(pred)
                out = torch.max(out, 1)
                correct_examples += torch.sum(targets==out[1]).cpu().data.numpy().tolist()
                
        avg_loss = val_loss / len(val_loader)
        avg_acc = correct_examples / total_examples
        print("Val loss: {}, val accuracy: {}".format(avg_loss, avg_acc))

        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            print("Saving ...")
            state = {'state_dict': model.state_dict(),
                      'epoch': i,
                      'lr': current_learning_rate}
            
        torch.save(state, os.path.join(CHECKPOINT_PATH, save_name))
        print('')

    print("="*50)
    print("==> Finished Training! The best accuracy is {}".format(best_val_acc))

In [5]:
def test_model(model):
    model.to(device)
    model.eval()

    total_examples = 0
    correct_examples = 0
    softmax = torch.nn.Softmax(dim=1)

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            pred = model(inputs)
            total_examples += inputs.shape[0]

            out = softmax(pred)
            out = torch.max(out, 1)

            correct_examples += torch.sum(targets==out[1]).cpu().data.numpy().tolist()

    avg_acc = correct_examples / total_examples
    print("Total examples is {}, correct examples is {}; Test accuracy: {}".format(total_examples, correct_examples, avg_acc))

## Setup Training

In [6]:
DATA_ROOT = "./data"
CIFAR10_shape = (3, 32, 32)
pad_size = 2
BATCH_SIZE = 128


# Preprocessing
transform_train = transforms.Compose([torchvision.transforms.RandomHorizontalFlip(p=0.5),
                                      torchvision.transforms.RandomCrop((32,32), padding=4),
                                      transforms.ToTensor(), 
                                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_val = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                    (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                    (0.2023, 0.1994, 0.2010))])

train_CIFAR10 = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform = transform_train)

test_CIFAR10 = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=transform_test)

num_train = int(1.0 * len(train_CIFAR10) * 95 / 100)
num_val = len(train_CIFAR10) - num_train
train_CIFAR10, val_CIFAR10 = torch.utils.data.random_split(train_CIFAR10, [num_train, num_val])

train_loader = DataLoader(
    train_CIFAR10, 
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

val_loader = DataLoader(
    val_CIFAR10, 
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

test_loader = DataLoader(
    test_CIFAR10, 
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
LEARNING_RATE = 0.01

WEIGHT_DECAY = 1e-4

MOMENTUM = 0.9

criterion = nn.CrossEntropyLoss()

EPOCHS = 150
DECAY_EPOCHS = [75,110]
DECAY = 0.1

CHECKPOINT_PATH = "./saved_model"

# for student model
T = 1
alpha = 0.1

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("The model is deployed to", device)

# Teacher Model
teacher_model = ResNet20()
checkpoint = torch.load("./saved_model/resnet20_final.pth")
teacher_model.load_state_dict(checkpoint['state_dict'])
test_model(teacher_model)

# Student Model
student_model = ResNet50()

teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

The model is deployed to cuda
Total examples is 10000, correct examples is 8967; Test accuracy: 0.8967


In [21]:
distillationTraining(student_model, train_loader, 'reversed_distillation_resnet50.pth')

Now T is 1
==> Start training!
Epoch 0
Training loss: 0.0005555233219638467, training accuracy: 0.37509473684210526
Val loss: 0.07840928435325623, val accuracy: 0.4956
Saving ...

Epoch 1
Training loss: 0.0012481441954150796, training accuracy: 0.5819157894736842
Val loss: 0.0655677393078804, val accuracy: 0.564
Saving ...

Epoch 2
Training loss: -0.0003462386957835406, training accuracy: 0.668821052631579
Val loss: 0.05585915967822075, val accuracy: 0.6504
Saving ...

Epoch 3
Training loss: -0.0004385567153804004, training accuracy: 0.7205263157894737
Val loss: 0.06990563124418259, val accuracy: 0.6536
Saving ...

Epoch 4
Training loss: -0.0005477671511471272, training accuracy: 0.7549473684210526
Val loss: 0.03308708965778351, val accuracy: 0.7312
Saving ...

Epoch 5
Training loss: -2.280350418004673e-05, training accuracy: 0.7761263157894737
Val loss: 0.04936039447784424, val accuracy: 0.7304

Epoch 6
Training loss: -0.0010279422858729959, training accuracy: 0.7948631578947368
Val l

In [24]:
reverse_distilled_model = ResNet50()
checkpoint = torch.load("./saved_model/reversed_distillation_resnet50_final.pth")
reverse_distilled_model.load_state_dict(checkpoint['state_dict'])

test_model(reverse_distilled_model)

Total examples is 10000, correct examples is 9128; Test accuracy: 0.9128
