In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models.resnet import ResNet, BasicBlock, ResNet18_Weights

In [2]:
# define a simple mlp classfier for cifar 10
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(3*32*32, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 3*32*32)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
# define a simple cnn classfier for cifar 10
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*8*8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 64*8*8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
class ResNet18(ResNet):
    """Attention maps of ResNet-18.
    
    Overloaded ResNet model to return attention maps.
    """
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        g0 = self.layer1(x)
        g1 = self.layer2(g0)
        g2 = self.layer3(g1)
        g3 = self.layer4(g2)

        x = self.avgpool(g3)
        emb = torch.flatten(x, 1)
        x = self.fc(emb)
        
        return emb, x
    
mlp_clf = MLP().to('cuda')
# print the total number of trainable parameters
print('Total number of trainable parameters:', sum(p.numel() for p in mlp_clf.parameters() if p.requires_grad))

cnn_clf = CNN().to('cuda')
# print the total number of trainable parameters
print('Total number of trainable parameters:', sum(p.numel() for p in cnn_clf.parameters() if p.requires_grad))

baseresnet18 = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT).to('cuda')
res_clf = ResNet18(BasicBlock, [2, 2, 2, 2]).to('cuda')
res_clf.load_state_dict(baseresnet18.state_dict(), strict=False)
res_clf.fc = nn.Linear(512, 10).to('cuda')
print('Total number of trainable parameters:', sum(p.numel() for p in res_clf.parameters() if p.requires_grad))


Total number of trainable parameters: 1707274
Total number of trainable parameters: 2122186
Total number of trainable parameters: 11181642


In [3]:
# load the cifar 10 dataset
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

In [58]:
# define the loss function and optimizer
def train(model):
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 14], gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    best_accuracy = 0.0
    best_model_wts = model.state_dict()

    for epoch in range(20):
        model.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            optimizer.zero_grad()
            emb, outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print('[Epoch %d] loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))
        print('Learning rate:', optimizer.param_groups[0]['lr'])
        scheduler.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to('cuda'), labels.to('cuda')
                _, outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model_wts = model.state_dict()

    model.load_state_dict(best_model_wts)
    return model

teacher_model = train(res_clf)

[Epoch 1] loss: 0.598
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 72.29%
[Epoch 2] loss: 0.397
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 71.76%
[Epoch 3] loss: 0.372
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 74.47%
[Epoch 4] loss: 0.135
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 79.30%
[Epoch 5] loss: 0.050
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 79.30%
[Epoch 6] loss: 0.020
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 79.12%
[Epoch 7] loss: 0.011
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 78.90%
[Epoch 8] loss: 0.009
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 79.10%
[Epoch 9] loss: 0.009
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 79.33%
[Epoch 10] loss: 0.013
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 78.79%
[E

In [59]:
# define a simple cnn classfier for cifar 10
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*8*8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 64*8*8)
        emb = self.fc1(x)
        x = self.fc2(emb)
        return emb, x

In [60]:
cnn_clf = CNN().to('cuda')
train(cnn_clf)

[Epoch 1] loss: 1.958
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 37.54%
[Epoch 2] loss: 1.626
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 40.46%
[Epoch 3] loss: 1.562
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 43.93%
[Epoch 4] loss: 1.401
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 48.92%
[Epoch 5] loss: 1.361
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 49.34%
[Epoch 6] loss: 1.343
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 50.27%
[Epoch 7] loss: 1.328
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 51.53%
[Epoch 8] loss: 1.317
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 50.98%
[Epoch 9] loss: 1.309
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 51.40%
[Epoch 10] loss: 1.298
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 52.09%
[E

CNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=4096, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)

In [75]:
# train the student model so that it can mimic the teacher model
student_model = None
student_model = CNN().to('cuda')
optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.01)
# setup a cosine annealing scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 14], gamma=0.1)

ce_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

# initialize the student model weights for better convergence


def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

init_weights(student_model)

In [77]:
optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.01)
# setup a cosine annealing scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 14], gamma=0.1)

teacher_model.eval()
for epoch in range(20):
    student_model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()
        student_emb, student_outputs = student_model(inputs)
        teacher_emb, teacher_outputs = teacher_model(inputs)
        student_outputs = F.softmax(student_outputs, dim=1)
        teacher_outputs = F.softmax(teacher_outputs, dim=1)
        ce_loss_value = ce_loss(student_outputs, teacher_outputs)
        mse_loss_value = mse_loss(student_emb, teacher_emb)
        label_loss = ce_loss(student_outputs, labels)
        loss = ce_loss_value + label_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[Epoch %d] loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))
    print('Learning rate:', optimizer.param_groups[0]['lr'])
    scheduler.step()

    student_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to('cuda'), labels.to('cuda')
            _, outputs = student_model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')


[Epoch 1] loss: 4.609
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 10.00%
[Epoch 2] loss: 4.607
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 10.00%
[Epoch 3] loss: 4.607
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 10.00%
[Epoch 4] loss: 4.607
Learning rate: 0.01
Accuracy of the network on the 10000 test images: 10.00%
[Epoch 5] loss: 4.605
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 10.00%
[Epoch 6] loss: 4.605
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 10.00%
[Epoch 7] loss: 4.605
Learning rate: 0.001
Accuracy of the network on the 10000 test images: 10.00%


KeyboardInterrupt: 