<a href="https://colab.research.google.com/github/bec2148/mnist-finalproject/blob/main/MNIST_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
## Distillation
# Base code from https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import time


NUM_EPOCHS = 1

# Check if GPU is available, and if not, use the CPU
print("torch.cuda.is_available()", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading the MNIST dataset:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data',
                                                          download=True,
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])),
                                           batch_size=10,
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data',
                                                          download=True,
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])),
                                           batch_size=10,
                                           shuffle=True)

torch.cuda.is_available() False


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
## Distillation
num_classes = 10
# neural network class to be used as teacher:
class TeacherNN(nn.Module):
    def __init__(self, num_classes = 10):
        super(TeacherNN, self).__init__()
        self.features = nn.Sequential(
          # Convolution Layer 1
          nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # 32 3x3 kernels
          nn.BatchNorm2d(32),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function

          # Convolution Layer 2
          nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # 32 3x3 kernels
          nn.BatchNorm2d(32),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function
          nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

          # Convolution Layer 3
          nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
          nn.BatchNorm2d(64),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function

          # Convolution Layer 4
          nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
          nn.BatchNorm2d(64),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function
          nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

          # Flatten the output
          nn.Flatten(),
        )

        self.classifier = nn.Sequential(
          # Fully Connected Layer 5
          nn.Linear(in_features=64 * 4 * 4, out_features=512),  # Input size from final convolution output
          nn.BatchNorm1d(512),  # Normalize FCN output
          nn.ReLU(inplace=True),  # Activation function

          # Fully Connected Layer 6
          nn.Dropout(0.2),  # Dropout with 20%
          nn.Linear(in_features=512, out_features=10),  # Final output layer with 10 nodes (one for each class)
          nn.Softmax(dim=1)  # Softmax activation for classification
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class StudentNN(nn.Module):
    def __init__(self, num_classes=10, conv2_feature_count = 4, linear_feature_count = 4):
        super(StudentNN, self).__init__()
        self.features = nn.Sequential(
            # Convolution Layer 1
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),

            # Convolution Layer 2
            nn.Conv2d(in_channels=8, out_channels=conv2_feature_count, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(conv2_feature_count),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Flatten the output
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            # Fully Connected Layer 3
            ## changed in_features=32 * 6 * 6 to in_features=16 * 12 * 12
            nn.Linear(in_features=conv2_feature_count * 12 * 12, out_features=linear_feature_count),  # changed 128 --> linear_feature_count
            nn.BatchNorm1d(linear_feature_count),
            nn.ReLU(inplace=True),

            # Fully Connected Layer 4
            nn.Dropout(0.2),
            nn.Linear(in_features=linear_feature_count, out_features=num_classes),  # Output layer
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


In [3]:
## Distillation

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [4]:
## Distillation

def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

In [5]:
## Distillation

torch.manual_seed(42)
nn_teacher = TeacherNN(num_classes=10).to(device)
train(nn_teacher, train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, device=device)
test_accuracy_teacher = test(nn_teacher, test_loader, device)

# Instantiate the Student network:
torch.manual_seed(42)
nn_student = StudentNN(num_classes=10, conv2_feature_count=4, linear_feature_count=4).to(device)

Epoch 1/1, Loss: 1.5091872396667798
Test Accuracy: 98.60%


In [6]:
## Distillation
torch.manual_seed(42)
new_nn_student = StudentNN(num_classes=10).to(device)

In [7]:
## Distillation

## Print the total number of parameters in each model:

total_params_teacher = "{:,}".format(sum(p.numel() for p in nn_teacher.parameters()))
print(f"TeacherNN parameters: {total_params_teacher}")
total_params_student = "{:,}".format(sum(p.numel() for p in nn_student.parameters()))
print(f"StudentNN parameters: {total_params_student}")

TeacherNN parameters: 596,330
StudentNN parameters: 2,762


In [8]:
## Distillation

## Train and test the lightweight network with cross entropy loss:

start_student_train = time.perf_counter()
###train(nn_student, train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, device=device)
end_student_train = time.perf_counter()
print(f"Student training time: {end_student_train - start_student_train:0.4f} seconds")

test_accuracy_student_ce = test(nn_student, test_loader, device)

Epoch 1/1, Loss: 1.9836894825696945
Student training time: 27.3455 seconds
Test Accuracy: 85.33%


In [9]:
## Distillation

print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy: {test_accuracy_student_ce:.2f}%")

Teacher accuracy: 98.60%
Student accuracy: 85.33%


In [10]:
## Distillation loss is calculated from the logits of the networks. It only returns gradients to the student:

# Apply ``train_knowledge_distillation`` with a temperature of 2, to smooth the probability curve and
# (hopefully) provide richer information about the teacher's behavior regarding the other 9/10 classes not selected each time
start_student_distill = time.perf_counter()
###train_knowledge_distillation(teacher=nn_teacher, student=new_nn_student, train_loader=train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
end_student_distill = time.perf_counter()
print(f"Student training time: {end_student_distill - start_student_distill:0.4f} seconds")

test_accuracy_student_ce_and_kd = test(new_nn_student, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")
print(f"Student accuracy with knowledge distillation: {test_accuracy_student_ce_and_kd:.2f}%")

Epoch 1/1, Loss: 1.5168607250849406
Student training time: 41.3128 seconds
Test Accuracy: 87.48%
Teacher accuracy: 98.60%
Student accuracy without teacher: 85.33%
Student accuracy with knowledge distillation: 87.48%


In [11]:
for i in range(2):
  trained_nn_student = StudentNN(num_classes=10, conv2_feature_count=4+i, linear_feature_count=4+i).to(device)
  start_student_distill = time.perf_counter()
  train_knowledge_distillation(teacher=nn_teacher, student=trained_nn_student, train_loader=train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
  end_student_distill = time.perf_counter()
  total_params_student = "{:,}".format(sum(p.numel() for p in trained_nn_student.parameters()))
  print(f"TeacherNN parameters: {total_params_teacher}")
  print(f"StudentNN parameters: {total_params_student}")
  print(f"Student training time: {end_student_distill - start_student_distill:0.4f} seconds")
  test_accuracy_student_kd = test(trained_nn_student, test_loader, device)
  print(f"Student accuracy with knowledge distillation: {test_accuracy_student_kd:.2f}%\n")


Epoch 1/1, Loss: 1.5045831643342973
TeacherNN parameters: 596,330
StudentNN parameters: 2,762
Test Accuracy: 85.82%
Student accuracy with knowledge distillation: 85.82%

Epoch 1/1, Loss: 1.4466477553844452
TeacherNN parameters: 596,330
StudentNN parameters: 4,146
Test Accuracy: 92.82%
Student accuracy with knowledge distillation: 92.82%

