# Teacher Student Network Research - learning at layer 6

##### Teacher Student Network Research
Framework adapted from the official PyTorch Knowledge Distillation Tutorial.

Authors: 
Asad Amiruddin, 
Harrison Maximillian Rush, 
Huy N Ho

### Import library, datasets, loaders

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from time import time
from torchvision import models

import sys
# Colab is detected by checking whether its package is already loaded in this interpreter.
IN_COLAB = 'google.colab' in sys.modules

# When running on Colab, mount Google Drive under /content/drive.
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing for CIFAR-10. Images are resized/cropped to 224x224 and normalised with the
# mean/std commonly used for ImageNet-pretrained torchvision models, since the ResNets used
# in this notebook start from ImageNet weights.
# The batch size of 128 is an arbitrary choice; it is shared by both loaders.
BATCH_SIZE = 128

# Training pipeline: random crop + horizontal flip for augmentation.
transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# Evaluation pipeline: deterministic resize + centre crop (no augmentation).
transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# Loading the CIFAR-10 dataset.
# download=True only fetches the archive when it is missing from ./data, so the cell also
# works on a fresh machine / fresh Colab session instead of raising "Dataset not found".
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Dataloaders: shuffle only the training split so evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


### Define train and test function

In [3]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimizer.

    Args:
        model: Network to optimise; it is moved to ``device`` and updated in place.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to Adam.
        device: Device the model and every batch are moved to.

    Returns:
        None. The mean loss per epoch and the total training time are printed.
    """
    start = time()
    criterion = nn.CrossEntropyLoss()

    # Keep the model on the same device as the batches (mirrors what test() does),
    # so a model that is still on the CPU does not fail on the first forward pass.
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        batches_seen = 0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_seen += 1

        # Count batches while iterating instead of calling len(): this avoids a
        # ZeroDivisionError on an empty loader and works for loaders without __len__.
        mean_loss = running_loss / batches_seen if batches_seen else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")
    end = time()
    runtime = end - start
    print(f"Training Time: {runtime:.3f}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy
 

### Define deeper neural networks to be used as teachers. 
Can have multiple teachers for comparison/experiments


### Load the ResNet-50 model with fine-tuned weights as another teacher

In [4]:
# Teacher: a ResNet-50 whose final layer is replaced by a 10-way classifier.
teacher_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
teacher_resnet50.fc = nn.Linear(in_features=teacher_resnet50.fc.in_features, out_features=10)
teacher_resnet50 = teacher_resnet50.to(device)

# Restore the fine-tuned teacher checkpoint from the local trained_model folder.
# (On Colab the same file lives at
#  "/content/drive/MyDrive/Colab Notebooks/teacher_resnet50.pth".)
teacher_resnet50.load_state_dict(
    torch.load("./trained_model/teacher_resnet50.pth", map_location=device)
)

# Reference accuracy of the teacher on the test split.
test_accuracy_teacher = test(teacher_resnet50, test_loader, device)



Test Accuracy: 95.68%


### Define student network


In [25]:
# Define the student model:
# a ResNet-18 with a 10-class head, initialised from a fine-tuned checkpoint.
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
resnet18.fc = nn.Linear(resnet18.fc.in_features, 10)
resnet18 = resnet18.to(device)

# NOTE: the checkpoint file is named "teacher_resnet18.pth" even though the network
# plays the student role here. On Colab the same file lives at
# "/content/drive/MyDrive/Colab Notebooks/teacher_resnet18.pth".
resnet18.load_state_dict(torch.load("./trained_model/teacher_resnet18.pth", map_location=device))

# `device` is defined once in the setup cell; it is already used above, so it is not
# recomputed here.

# Baseline accuracy of the student before any knowledge transfer from the teacher.
test_accuracy_resnet18_b4_learning = test(resnet18, test_loader, device)

### Define knowledge distillation function

In [4]:

def _make_feature_projector(student_features, teacher_features):
    """Build a trainable layer mapping student features onto the teacher's channel count.

    Only the channel dimension (dim 1) may differ; batch size and any remaining
    dimensions must already agree. 4-D feature maps get a 1x1 convolution and
    2-D features get a linear layer.

    Raises:
        ValueError: If the two tensors cannot be aligned by changing dim 1 alone.
    """
    alignable = (
        student_features.dim() == teacher_features.dim()
        and student_features.shape[0] == teacher_features.shape[0]
        and student_features.shape[2:] == teacher_features.shape[2:]
    )
    if alignable and student_features.dim() == 4:
        layer = nn.Conv2d(student_features.shape[1], teacher_features.shape[1], kernel_size=1)
    elif alignable and student_features.dim() == 2:
        layer = nn.Linear(student_features.shape[1], teacher_features.shape[1])
    else:
        raise ValueError(
            "Cannot align intermediate features: student "
            f"{tuple(student_features.shape)} vs teacher {tuple(teacher_features.shape)}"
        )
    return layer.to(student_features.device)


def train_knowledge_distillation_inter_layer(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device,intermediate_layer_index=6):
    """Train ``student`` to reproduce one intermediate activation of ``teacher``.

    The activation is the output of the child module at position
    ``intermediate_layer_index`` in ``model.children()`` for both networks. It is
    captured with forward hooks; indexing the model output instead would select a
    sample of the batch rather than a layer. The loss is the mean squared error
    between the two activations, so only parameters that contribute to the
    selected student activation receive gradients.

    When the two activations differ in channel count, a small projection layer is
    created on the first batch and trained together with the student. It exists
    only for the duration of this call.

    Args:
        teacher: Frozen reference network; run in eval mode without gradients.
        student: Network being trained; updated in place.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. Labels are
            not used by this loss.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to Adam.
        T: Unused; accepted so existing callers keep working.
        soft_target_loss_weight: Unused; accepted so existing callers keep working.
        ce_loss_weight: Unused; accepted so existing callers keep working.
        device: Device each input batch is moved to.
        intermediate_layer_index: Position of the distilled layer among the
            direct children of each model.

    Returns:
        None. The mean loss per epoch and the total training time are printed.

    Raises:
        IndexError: If either model has no child at ``intermediate_layer_index``.
        ValueError: If the activations differ in more than their channel count.
    """
    print('Knowledge distillation - intermediate layers training')
    start = time()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    # Latest activation of the selected layer for each network, filled in by the hooks.
    captured = {}

    def _store_as(key):
        def hook(_module, _hook_inputs, output):
            captured[key] = output
        return hook

    teacher_layer = list(teacher.children())[intermediate_layer_index]
    student_layer = list(student.children())[intermediate_layer_index]
    hook_handles = [
        teacher_layer.register_forward_hook(_store_as("teacher")),
        student_layer.register_forward_hook(_store_as("student")),
    ]

    projector = None  # created lazily, once the activation shapes are known

    try:
        for epoch in range(epochs):
            running_loss = 0.0
            batches_seen = 0
            for inputs, _labels in train_loader:
                inputs = inputs.to(device)

                optimizer.zero_grad()

                # Single teacher forward pass; no gradients because the teacher is frozen.
                with torch.no_grad():
                    teacher(inputs)
                teacher_features = captured["teacher"]

                # Student forward pass; the hook keeps the activation attached to the graph.
                student(inputs)
                student_features = captured["student"]

                if student_features.shape != teacher_features.shape:
                    if projector is None:
                        projector = _make_feature_projector(student_features, teacher_features)
                        # Train the projection alongside the student.
                        optimizer.add_param_group({"params": projector.parameters()})
                    student_features = projector(student_features)

                # Feature-matching loss between the selected layers.
                loss = mse_loss(student_features, teacher_features)

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                batches_seen += 1

            mean_loss = running_loss / batches_seen if batches_seen else float("nan")
            print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")
    finally:
        # Always detach the hooks so later forward passes of either model are unaffected.
        for handle in hook_handles:
            handle.remove()

    end = time()
    runtime = end - start
    print(f"Training Time: {runtime:.3f}")

In [None]:
 
# Distil the teacher's intermediate representation into the student (updates resnet18 in place).
train_knowledge_distillation_inter_layer(teacher=teacher_resnet50, student=resnet18, 
                                 train_loader=train_loader, epochs=10, learning_rate=0.001, 
                                 T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)

# Accuracy of the student after the transfer.
test_accuracy_learning_student = test(resnet18, test_loader, device) 

# Compare the student test accuracy with and without the teacher, after distillation
# Result shouldn't be stellar because teacher's prediction can't beat ground truth here
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"resnet18 prior to transfer: {test_accuracy_resnet18_b4_learning:.2f}%")
# Report the measured accuracy; formatting the model object itself with ":.2f" raises TypeError.
print(f"resnet18 accuracy after intermediate layer knowledge distillation: {test_accuracy_learning_student:.2f}%")

### Save trained models - only run after training on Colab

In [61]:
  
# The destination is on Google Drive, which is only mounted when running on Colab.
# Guard the save so a full "Run All" outside Colab does not fail on a missing path.
if IN_COLAB:
    torch.save(resnet18.state_dict(), "/content/drive/MyDrive/Colab Notebooks/resnet18_after_intermediate_layer_transfer.pth") # student after KD
else:
    print("Not running on Colab - skipping save of the distilled student to Google Drive.")
