<a href="https://colab.research.google.com/github/bec2148/mnist-finalproject/blob/main/MNIST_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [404]:
from keras.datasets import mnist        # MNIST dataset is included in Keras
from keras.models import Sequential     # Model type to be used

from keras.layers import Dense, Dropout, Activation # Types of layers to be used in our model
from keras import utils                 # Keras utilities

In [405]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, Flatten
from tensorflow.keras.layers import BatchNormalization

In [406]:
TRAINING_SIZE = 60000

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train[:TRAINING_SIZE,:,:]
y_train = y_train[:TRAINING_SIZE]
print("X_train shape", X_train.shape)
print("y_train shape", y_train.shape)

X_train shape (60000, 28, 28)
y_train shape (60000,)


In [407]:
# We do not flatten each image into a 784-length vector because we want to perform convolutions first

X_train = X_train.reshape(TRAINING_SIZE, 28, 28, 1) #add an additional dimension to represent the single-channel
X_test = X_test.reshape(10000, 28, 28, 1)

X_train = X_train.astype('float32')         # change integers to 32-bit floating point numbers
X_test = X_test.astype('float32')

X_train /= 255                              # normalize each value for each pixel for the entire vector for each input
X_test /= 255

print("Training matrix shape", X_train.shape)
print("Testing matrix shape", X_test.shape)

Training matrix shape (60000, 28, 28, 1)
Testing matrix shape (10000, 28, 28, 1)


In [408]:
# one-hot format classes

nb_classes = 10 # number of unique digits

Y_train = utils.to_categorical(y_train, nb_classes)
Y_test = utils.to_categorical(y_test, nb_classes)

In [409]:
model = Sequential()                                 # Linear stacking of layers

# Convolution Layer 1
model.add(Conv2D(32, (3, 3), input_shape=(28,28,1), name="Conv01")) # 32 different 3x3 kernels -- so 32 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
convLayer01 = Activation('relu')                     # activation
model.add(convLayer01)

# Convolution Layer 2
model.add(Conv2D(32, (3, 3)))                        # 32 different 3x3 kernels -- so 32 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
model.add(Activation('relu'))                        # activation
convLayer02 = MaxPooling2D(pool_size=(2,2))          # Pool the max values over a 2x2 kernel
model.add(convLayer02)

# Convolution Layer 3
model.add(Conv2D(64,(3, 3)))                         # 64 different 3x3 kernels -- so 64 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
convLayer03 = Activation('relu')                     # activation
model.add(convLayer03)

# Convolution Layer 4
model.add(Conv2D(64, (3, 3)))                        # 64 different 3x3 kernels -- so 64 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
model.add(Activation('relu'))                        # activation
convLayer04 = MaxPooling2D(pool_size=(2,2))          # Pool the max values over a 2x2 kernel
model.add(convLayer04)
model.add(Flatten())                                 # Flatten final 4x4x64 output matrix into a 1024-length vector

# Fully Connected Layer 5
model.add(Dense(512))                                # 512 FCN nodes
model.add(BatchNormalization())                      # normalization
model.add(Activation('relu'))                        # activation

# Fully Connected Layer 6
model.add(Dropout(0.2))                              # 20% dropout of randomly selected nodes
model.add(Dense(10))                                 # final 10 FCN nodes
model.add(Activation('softmax'))                     # softmax activation

In [410]:
## Distillation
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
####import torch.nn.functional as F

NUM_EPOCHS = 1

# Check if GPU is available, and if not, use the CPU
print("torch.cuda.is_available()", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading the CIFAR-10 dataset:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data',
                                                          download=True,
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])),
                                           batch_size=10,
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data',
                                                          download=True,
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])),
                                           batch_size=10,
                                           shuffle=True)

torch.cuda.is_available() False


In [411]:
## Distillation
num_classes = 10
# neural network class to be used as teacher:
class TeacherNN(nn.Module):
    def __init__(self, num_classes = 10):
        super(TeacherNN, self).__init__()
        self.features = nn.Sequential(
          # Convolution Layer 1
          nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # 32 3x3 kernels
          nn.BatchNorm2d(32),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function

          # Convolution Layer 2
          nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # 32 3x3 kernels
          nn.BatchNorm2d(32),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function
          nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

          # Convolution Layer 3
          nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
          nn.BatchNorm2d(64),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function

          # Convolution Layer 4
          nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
          nn.BatchNorm2d(64),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function
          nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

          # Flatten the output
          nn.Flatten(),
        )

        self.classifier = nn.Sequential(
          # Fully Connected Layer 5
          nn.Linear(in_features=64 * 4 * 4, out_features=512),  # Input size from final convolution output
          nn.BatchNorm1d(512),  # Normalize FCN output
          nn.ReLU(inplace=True),  # Activation function

          # Fully Connected Layer 6
          nn.Dropout(0.2),  # Dropout with 20%
          nn.Linear(in_features=512, out_features=10),  # Final output layer with 10 nodes (one for each class)
          nn.Softmax(dim=1)  # Softmax activation for classification
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class StudentNN(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentNN, self).__init__()
        self.features = nn.Sequential(
            # Convolution Layer 1
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),

            # Convolution Layer 2
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Flatten the output
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            # Fully Connected Layer 3
            ## changed in_features=32 * 6 * 6 to in_features=32 * 12 * 12
            nn.Linear(in_features=32 * 12 * 12, out_features=128),  # Smaller FC layer
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),

            # Fully Connected Layer 4
            nn.Dropout(0.2),
            nn.Linear(in_features=128, out_features=num_classes),  # Output layer
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


In [412]:
## Distillation

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [413]:
## Distillation

torch.manual_seed(42)
nn_teacher = TeacherNN(num_classes=10).to(device)
train(nn_teacher, train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, device=device)
test_accuracy_teacher = test(nn_teacher, test_loader, device)

# Instantiate the Student network:
torch.manual_seed(42)
nn_student = StudentNN(num_classes=10).to(device)

Epoch 1/1, Loss: 1.5091872396667798
Test Accuracy: 98.60%


In [414]:
## Distillation
torch.manual_seed(42)
new_nn_student = StudentNN(num_classes=10).to(device)

In [415]:
## Distillation

# Print the norm of the first layer of the initial student model
print("Norm of 1st layer of nn_student:", torch.norm(nn_student.features[0].weight).item())
# Print the norm of the first layer of the new student model
print("Norm of 1st layer of new_nn_student:", torch.norm(new_nn_student.features[0].weight).item())



Norm of 1st layer of nn_student: 2.3761110305786133
Norm of 1st layer of new_nn_student: 2.3761110305786133


In [416]:
## Distillation

## Print the total number of parameters in each model:

total_params_teacher = "{:,}".format(sum(p.numel() for p in nn_teacher.parameters()))
print(f"TeacherNN parameters: {total_params_teacher}")
total_params_student = "{:,}".format(sum(p.numel() for p in nn_student.parameters()))
print(f"StudentNN parameters: {total_params_student}")

TeacherNN parameters: 596,330
StudentNN parameters: 596,394


In [417]:
## Distillation

## Train and test the lightweight network with cross entropy loss:

train(nn_student, train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, device=device)
test_accuracy_student_ce = test(nn_student, test_loader, device)

Epoch 1/1, Loss: 1.5235334737698236
Test Accuracy: 98.14%


In [418]:
## Distillation

print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy: {test_accuracy_student_ce:.2f}%")

Teacher accuracy: 98.60%
Student accuracy: 98.14%


In [419]:
## Distillation loss is calculated from the logits of the networks. It only returns gradients to the student:

def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_teacher, student=new_nn_student, train_loader=train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_student_ce_and_kd = test(new_nn_student, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%")

Epoch 1/1, Loss: 1.142696574985981
Test Accuracy: 97.99%
Teacher accuracy: 98.60%
Student accuracy without teacher: 98.14%
Student accuracy with CE + KD: 97.99%


In [420]:
print(nn_teacher)

TeacherNN(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (14): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Seque

In [421]:
## Distillation

class ModifiedTeacherNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedTeacherNNCosine, self).__init__()
        self.features = nn.Sequential(
            # one input channel for MNIST
            nn.Conv2d(1, 32, kernel_size=3, padding=0),
            nn.BatchNorm2d(32),  # Normalize each feature map
            nn.ReLU(inplace=True),  # Activation function

            # Convolution Layer 2
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # 32 3x3 kernels
            nn.BatchNorm2d(32),  # Normalize each feature map
            nn.ReLU(inplace=True),  # Activation function
            nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

            # Convolution Layer 3
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
            nn.BatchNorm2d(64),  # Normalize each feature map
            nn.ReLU(inplace=True),  # Activation function

            # Convolution Layer 4
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
            nn.BatchNorm2d(64),  # Normalize each feature map
            nn.ReLU(inplace=True),  # Activation function
            nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

            # Flatten the output
            nn.Flatten()
        )

        self.classifier = nn.Sequential(
          # Fully Connected Layer 5
          nn.Linear(in_features=64 * 4 * 4, out_features=512),  # Input size from final convolution output
          nn.BatchNorm1d(512),  # Normalize FCN output
          nn.ReLU(inplace=True),  # Activation function

          # Fully Connected Layer 6
          nn.Dropout(0.2),  # Dropout with 20%
          nn.Linear(in_features=512, out_features=10),  # Final output layer with 10 nodes (one for each class)
          nn.Softmax(dim=1)  # Softmax activation for classification
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling

# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedStudentNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedStudentNNCosine, self).__init__()
        self.features = nn.Sequential(
            # Convolution Layer 1
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),

            # Convolution Layer 2
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Flatten the output
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            # Fully Connected Layer 3
            ## changed in_features=32 * 6 * 6 to in_features=32 * 12 * 12
            nn.Linear(in_features=32 * 12 * 12, out_features=128),  # Smaller FC layer
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),

            # Fully Connected Layer 4
            nn.Dropout(0.2),
            nn.Linear(in_features=128, out_features=num_classes),  # Output layer
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x, flattened_conv_output

# We do not have to train the modified teacher network from scratch. We just load its weights from the trained instance
modified_nn_teacher = ModifiedTeacherNNCosine(num_classes=10).to(device)
print(modified_nn_teacher)

ModifiedTeacherNNCosine(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (14): Flatten(start_dim=1, end_dim=-1)
  )
  (clas

In [422]:
## Distillation
# ... We do not have to train the modified teacher network from scratch. We just load its weights from the trained instance
modified_nn_teacher.load_state_dict(nn_teacher.state_dict())

# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for nn_teacher:", torch.norm(nn_teacher.features[0].weight).item())
print("Norm of 1st layer for modified_nn_teacher:", torch.norm(modified_nn_teacher.features[0].weight).item())

# Initialize a modified student (lightweight) network with the same seed as our other student instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_student = ModifiedStudentNNCosine(num_classes=10).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_student.features[0].weight).item())

print(nn_student)
print(modified_nn_student)

Norm of 1st layer for nn_teacher: 3.520979642868042
Norm of 1st layer for modified_nn_teacher: 3.520979642868042
Norm of 1st layer: 2.3761110305786133
StudentNN(
  (features): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (7): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Linear(in_features=4608, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=128, out_features=10, bias=True)
    (5): Sof

In [423]:
## Distillation

## The model returns a tuple (logits, hidden_representation). Using a sample input tensor we can print their shapes.

# Create a sample input tensor
sample_input = torch.randn(128, 1, 28, 28).to(device) # Batch size: 128, Filters: 1, Image size: 32x32

# Pass the input through the student
logits, hidden_representation = modified_nn_student(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", logits.shape) # batch_size x total_classes
print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

# Pass the input through the teacher
logits, hidden_representation = modified_nn_teacher(sample_input)

# Print the shapes of the tensors
print("Teacher logits shape:", logits.shape) # batch_size x total_classes
print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x512 and 1024x512)

In [None]:
## Distillation

## In Cosine Loss minimization we maximize the cosine similarity of the two representations by returning gradients to the student:

def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)

            # Forward pass with the student model
            student_logits, student_hidden_representation = student(inputs)

            # Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

In [None]:
## Distillation

## Here we ignore the hidden representation returned by the model.

def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
## Distillation

# Train and test the lightweight network with cross entropy loss
train_cosine_loss(teacher=modified_nn_teacher, student=modified_nn_student, train_loader=train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_student_ce_and_cosine_loss = test_multiple_outputs(modified_nn_student, test_loader, device)

In [None]:
## Distillation

# Pass the sample input only from the convolutional feature extractor
convolutional_fe_output_student = nn_student.features(sample_input)
convolutional_fe_output_teacher = nn_teacher.features(sample_input)

# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)

In [None]:
## Distillation

class ModifiedTeacherNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedTeacherNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        conv_feature_map = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class ModifiedStudentNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedStudentNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Include an extra regressor (in our case linear)
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output

In [None]:
## Distillation

def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Notice how our test function remains the same here with the one we used in our previous case. We only care about the actual outputs because we measure accuracy.

# Initialize a ModifiedStudentNNRegressor
torch.manual_seed(42)
modified_nn_student_reg = ModifiedStudentNNRegressor(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch.  We just load its weights from the trained instance
modified_nn_teacher_reg = ModifiedTeacherNNRegressor(num_classes=10).to(device)
modified_nn_teacher_reg.load_state_dict(nn_teacher.state_dict())

# Train and test once again
train_mse_loss(teacher=modified_nn_teacher_reg, student=modified_nn_student_reg, train_loader=train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_student_ce_and_mse_loss = test_multiple_outputs(modified_nn_student_reg, test_loader, device)

In [None]:
## Distillation

print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_student_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_student_ce_and_mse_loss:.2f}%")

In [None]:
model.summary()

In [None]:
#### model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

In [None]:
#### model.fit(X_train, Y_train, batch_size=128, epochs=1, verbose=1, validation_data=(X_test, Y_test))

In [None]:
### score = model.evaluate(X_test, Y_test)
print('Test score:', score[0])
print('Test accuracy:', score[1])