## Experiment Notebook

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, random_split
import numpy as np
import random
import matplotlib.pyplot as plt

In [2]:
# Seed every RNG this notebook can draw from so that a fresh
# "Restart Kernel -> Run All" starts from the same random state
# (data shuffling, random augmentation, new classifier-head weights).
# NOTE: some GPU kernels may still be non-deterministic, so results can
# differ slightly between runs even with a fixed seed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("You are using device: %s" % device)

You are using device: cuda


### Load data

In [3]:
# Per-channel mean / std handed to Normalize by both pipelines (the values
# commonly used with ImageNet-pretrained torchvision models).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random 224-pixel resized crop and random horizontal
# flip for augmentation, then tensor conversion and normalization.
train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: no randomness - resize to 256, take the central
# 224-pixel crop, then tensor conversion and normalization.
eval_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(eval_steps)

In [4]:
# NOTE: everything in this cell is commented out, so running it has no effect.
# It sketches a single shared pipeline (ToTensor -> Resize to 224x224 -> Normalize);
# the datasets in this notebook are built with transform_train / transform_test instead.
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Resize((224, 224)),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

In [5]:
# Three views of CIFAR-10, all stored under ./data (downloaded when missing):
#   train_dataset      - training split with the augmenting transform_train
#   train_dataset_val  - the same training split with transform_test, used to
#                        measure training-set metrics without augmentation
#   test_dataset       - the test split with transform_test
cifar_kwargs = dict(root='./data', download=True)

train_dataset = torchvision.datasets.CIFAR10(
    train=True, transform=transform_train, **cifar_kwargs
)
train_dataset_val = torchvision.datasets.CIFAR10(
    train=True, transform=transform_test, **cifar_kwargs
)
test_dataset = torchvision.datasets.CIFAR10(
    train=False, transform=transform_test, **cifar_kwargs
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 13040484.29it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Mini-batch size shared by all three loaders.
batch_size = 32

# Only the loader that feeds the optimizer is shuffled. The two loaders used
# purely for evaluation keep dataset order: shuffling them does not change
# the averaged metrics and only adds run-to-run variation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_loader_val = DataLoader(train_dataset_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Load ResNet18 Teacher model

In [7]:
# The teacher is a torchvision ResNet-18 initialised from the IMAGENET1K_V1
# checkpoint (downloaded on first use).
# Larger alternative that was tried/considered:
# teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
pretrained_weights = models.ResNet18_Weights.IMAGENET1K_V1
teacher = models.resnet18(weights=pretrained_weights)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 177MB/s]


In [8]:
# Optional stem change, currently disabled:
# teacher.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

# Swap the final fully connected layer for a fresh 10-output linear layer
# with the same input width, then move the whole network to `device`.
head_in_features = teacher.fc.in_features
teacher.fc = nn.Linear(head_in_features, 10)
teacher = teacher.to(device)

In [9]:
# Full fine-tuning: make sure every weight of the teacher (backbone and new
# head alike) receives gradients.
for weight in teacher.parameters():
    weight.requires_grad_(True)

### Fine-tune the teacher on CIFAR-10

In [10]:
# Classification loss applied to the teacher's raw logits.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay over all teacher parameters.
optimizer = optim.SGD(
    teacher.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=5e-4,
)

# Multiply the learning rate by 0.1 after every 5 scheduler steps
# (the training loop steps the scheduler once per epoch).
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.1,
)

In [None]:
def _evaluate(model, loader, loss_fn):
    """Score `model` on every batch of `loader` without tracking gradients.

    The model is switched to eval mode; batches are moved to the
    notebook-level `device` before the forward pass.

    Args:
        model: network to score.
        loader: iterable of batches whose first two items are inputs and labels.
        loss_fn: callable mapping (outputs, labels) to a scalar loss tensor.

    Returns:
        (mean_loss, accuracy_pct): the average of the per-batch losses over
        the batches of `loader` (divided by the length of the loader that was
        actually iterated), and the percentage of samples whose highest-scoring
        class equals the label.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch in loader:
            inputs, labels = batch[0].to(device), batch[1].to(device)

            outputs = model(inputs)
            loss_sum += loss_fn(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return loss_sum / len(loader), 100 * n_correct / n_seen


num_epochs = 10

# Per-epoch history consumed by the plotting cells below.
train_loss = []
train_accuracy = []
test_loss = []
test_accuracy = []

for epoch in range(num_epochs):
    # --- Train for one pass over the augmented training data ---
    teacher.train()
    for batch in train_loader:
        inputs, labels = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()

        outputs = teacher(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # --- Evaluate on the training split without augmentation ---
    epoch_train_loss, epoch_train_acc = _evaluate(teacher, train_loader_val, criterion)
    train_loss.append(epoch_train_loss)
    train_accuracy.append(epoch_train_acc)
    print('[%d] train_loss: %.3f, train_accuracy: %.2f %%' %
          (epoch + 1, epoch_train_loss, epoch_train_acc))

    # --- Evaluate on the test split ---
    epoch_test_loss, epoch_test_acc = _evaluate(teacher, test_loader, criterion)
    test_loss.append(epoch_test_loss)
    test_accuracy.append(epoch_test_acc)
    print('[%d] test_loss: %.3f, test_accuracy: %.2f %%' %
          (epoch + 1, epoch_test_loss, epoch_test_acc))

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()



[1] train_loss: 0.349, train_accuracy: 88.17 %
[1] test_loss: 0.400, test_accuracy: 86.05 %
[2] train_loss: 0.275, train_accuracy: 90.58 %
[2] test_loss: 0.333, test_accuracy: 88.98 %
[3] train_loss: 0.246, train_accuracy: 91.65 %
[3] test_loss: 0.302, test_accuracy: 89.84 %
[4] train_loss: 0.213, train_accuracy: 92.60 %
[4] test_loss: 0.272, test_accuracy: 90.90 %
[5] train_loss: 0.182, train_accuracy: 93.78 %
[5] test_loss: 0.241, test_accuracy: 92.00 %
[6] train_loss: 0.095, train_accuracy: 96.91 %
[6] test_loss: 0.158, test_accuracy: 94.60 %
[7] train_loss: 0.082, train_accuracy: 97.34 %
[7] test_loss: 0.152, test_accuracy: 94.81 %
[8] train_loss: 0.074, train_accuracy: 97.66 %
[8] test_loss: 0.143, test_accuracy: 95.07 %
[9] train_loss: 0.066, train_accuracy: 97.87 %
[9] test_loss: 0.141, test_accuracy: 95.11 %
[10] train_loss: 0.063, train_accuracy: 97.98 %
[10] test_loss: 0.138, test_accuracy: 95.26 %
[11] train_loss: 0.060, train_accuracy: 98.08 %
[11] test_loss: 0.136, test_ac

In [None]:
# Loss per epoch: training split (evaluated without augmentation) vs. test split.
fig_loss, ax_loss = plt.subplots()
ax_loss.plot(train_loss, label='Training Loss')
ax_loss.plot(test_loss, label='Test Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training and Test Loss')
ax_loss.legend()
ax_loss.grid()
plt.show()

In [None]:
# Accuracy per epoch: training split (evaluated without augmentation) vs. test split.
plt.plot(train_accuracy, label='Training Accuracy')
plt.plot(test_accuracy, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Test Accuracy')
plt.legend()  # a single call is enough; the original invoked it twice
plt.grid()
plt.show()

In [None]:
# Persist the fine-tuned teacher's weights to Google Drive.
# The import lives in this cell because google.colab is only available
# inside the Colab runtime.
# Local alternative:
# torch.save(teacher.state_dict(), 'teacher_finetuned_cifar10.pth')
from google.colab import drive

drive.mount('/content/drive')

checkpoint_path = "/content/drive/MyDrive/Colab Notebooks/resnet18_finetuned_cifar10_v2.pth"
teacher_weights = teacher.state_dict()
torch.save(teacher_weights, checkpoint_path)