<a href="https://colab.research.google.com/github/bec2148/mnist-finalproject/blob/main/MNIST_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [35]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, Flatten
from tensorflow.keras.layers import BatchNormalization

In [36]:
TRAINING_SIZE = 60000

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train[:TRAINING_SIZE,:,:]
y_train = y_train[:TRAINING_SIZE]
print("X_train shape", X_train.shape)
print("y_train shape", y_train.shape)

X_train shape (60000, 28, 28)
y_train shape (60000,)


In [37]:
# We do not flatten each image into a 784-length vector because we want to perform convolutions first

X_train = X_train.reshape(TRAINING_SIZE, 28, 28, 1) #add an additional dimension to represent the single-channel
X_test = X_test.reshape(10000, 28, 28, 1)

X_train = X_train.astype('float32')         # change integers to 32-bit floating point numbers
X_test = X_test.astype('float32')

X_train /= 255                              # normalize each value for each pixel for the entire vector for each input
X_test /= 255

print("Training matrix shape", X_train.shape)
print("Testing matrix shape", X_test.shape)

Training matrix shape (60000, 28, 28, 1)
Testing matrix shape (10000, 28, 28, 1)


In [38]:
# one-hot format classes

nb_classes = 10 # number of unique digits

Y_train = utils.to_categorical(y_train, nb_classes)
Y_test = utils.to_categorical(y_test, nb_classes)

In [39]:
model = Sequential()                                 # Linear stacking of layers

# Convolution Layer 1
model.add(Conv2D(32, (3, 3), input_shape=(28,28,1), name="Conv01")) # 32 different 3x3 kernels -- so 32 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
convLayer01 = Activation('relu')                     # activation
model.add(convLayer01)

# Convolution Layer 2
model.add(Conv2D(32, (3, 3)))                        # 32 different 3x3 kernels -- so 32 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
model.add(Activation('relu'))                        # activation
convLayer02 = MaxPooling2D(pool_size=(2,2))          # Pool the max values over a 2x2 kernel
model.add(convLayer02)

# Convolution Layer 3
model.add(Conv2D(64,(3, 3)))                         # 64 different 3x3 kernels -- so 64 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
convLayer03 = Activation('relu')                     # activation
model.add(convLayer03)

# Convolution Layer 4
model.add(Conv2D(64, (3, 3)))                        # 64 different 3x3 kernels -- so 64 feature maps
model.add(BatchNormalization(axis=-1))               # normalize each feature map before activation
model.add(Activation('relu'))                        # activation
convLayer04 = MaxPooling2D(pool_size=(2,2))          # Pool the max values over a 2x2 kernel
model.add(convLayer04)
model.add(Flatten())                                 # Flatten final 4x4x64 output matrix into a 1024-length vector

# Fully Connected Layer 5
model.add(Dense(512))                                # 512 FCN nodes
model.add(BatchNormalization())                      # normalization
model.add(Activation('relu'))                        # activation

# Fully Connected Layer 6
model.add(Dropout(0.2))                              # 20% dropout of randomly selected nodes
model.add(Dense(10))                                 # final 10 FCN nodes
model.add(Activation('softmax'))                     # softmax activation

In [40]:
## Distillation
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
####import torch.nn.functional as F

NUM_EPOCHS = 1

# Check if GPU is available, and if not, use the CPU
print("torch.cuda.is_available()", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading the CIFAR-10 dataset:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data',
                                                          download=True,
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])),
                                           batch_size=10,
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data',
                                                          download=True,
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])),
                                           batch_size=10,
                                           shuffle=True)

torch.cuda.is_available() False


In [41]:
## Distillation
num_classes = 10
# neural network class to be used as teacher:
class TeacherNN(nn.Module):
    def __init__(self, num_classes = 10):
        super(TeacherNN, self).__init__()
        self.features = nn.Sequential(
          # Convolution Layer 1
          nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # 32 3x3 kernels
          nn.BatchNorm2d(32),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function

          # Convolution Layer 2
          nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # 32 3x3 kernels
          nn.BatchNorm2d(32),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function
          nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

          # Convolution Layer 3
          nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
          nn.BatchNorm2d(64),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function

          # Convolution Layer 4
          nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=0, stride=1),  # 64 3x3 kernels
          nn.BatchNorm2d(64),  # Normalize each feature map
          nn.ReLU(inplace=True),  # Activation function
          nn.MaxPool2d(kernel_size=(2, 2)),  # Pooling with a 2x2 kernel

          # Flatten the output
          nn.Flatten(),
        )

        self.classifier = nn.Sequential(
          # Fully Connected Layer 5
          nn.Linear(in_features=64 * 4 * 4, out_features=512),  # Input size from final convolution output
          nn.BatchNorm1d(512),  # Normalize FCN output
          nn.ReLU(inplace=True),  # Activation function

          # Fully Connected Layer 6
          nn.Dropout(0.2),  # Dropout with 20%
          nn.Linear(in_features=512, out_features=10),  # Final output layer with 10 nodes (one for each class)
          nn.Softmax(dim=1)  # Softmax activation for classification
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class StudentNN(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentNN, self).__init__()
        self.features = nn.Sequential(
            # Convolution Layer 1
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),

            # Convolution Layer 2
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=0, stride=1),  # Fewer filters
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Flatten the output
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            # Fully Connected Layer 3
            ## changed in_features=32 * 6 * 6 to in_features=32 * 12 * 12
            nn.Linear(in_features=32 * 12 * 12, out_features=128),  # Smaller FC layer
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),

            # Fully Connected Layer 4
            nn.Dropout(0.2),
            nn.Linear(in_features=128, out_features=num_classes),  # Output layer
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


In [42]:
## Distillation

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [43]:
## Distillation

torch.manual_seed(42)
nn_teacher = TeacherNN(num_classes=10).to(device)
train(nn_teacher, train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, device=device)
test_accuracy_teacher = test(nn_teacher, test_loader, device)

# Instantiate the Student network:
torch.manual_seed(42)
nn_student = StudentNN(num_classes=10).to(device)

Epoch 1/1, Loss: 1.5086240475972494
Test Accuracy: 98.27%


In [44]:
## Distillation
torch.manual_seed(42)
new_nn_student = StudentNN(num_classes=10).to(device)

In [45]:
## Distillation

# Print the norm of the first layer of the initial student model
print("Norm of 1st layer of nn_student:", torch.norm(nn_student.features[0].weight).item())
# Print the norm of the first layer of the new student model
print("Norm of 1st layer of new_nn_student:", torch.norm(new_nn_student.features[0].weight).item())



Norm of 1st layer of nn_student: 2.3761110305786133
Norm of 1st layer of new_nn_student: 2.3761110305786133


In [46]:
## Distillation

## Print the total number of parameters in each model:

total_params_teacher = "{:,}".format(sum(p.numel() for p in nn_teacher.parameters()))
print(f"TeacherNN parameters: {total_params_teacher}")
total_params_student = "{:,}".format(sum(p.numel() for p in nn_student.parameters()))
print(f"StudentNN parameters: {total_params_student}")

TeacherNN parameters: 596,330
StudentNN parameters: 596,394


In [47]:
## Distillation

## Train and test the lightweight network with cross entropy loss:

train(nn_student, train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, device=device)
test_accuracy_student_ce = test(nn_student, test_loader, device)

Epoch 1/1, Loss: 1.5230646441578866
Test Accuracy: 98.12%


In [48]:
## Distillation

print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy: {test_accuracy_student_ce:.2f}%")

Teacher accuracy: 98.27%
Student accuracy: 98.12%


In [49]:
## Distillation loss is calculated from the logits of the networks. It only returns gradients to the student:

def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_teacher, student=new_nn_student, train_loader=train_loader, epochs=NUM_EPOCHS, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_student_ce_and_kd = test(new_nn_student, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_teacher:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_student_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_student_ce_and_kd:.2f}%")

Epoch 1/1, Loss: 1.1428454108436903
Test Accuracy: 98.24%
Teacher accuracy: 98.27%
Student accuracy without teacher: 98.12%
Student accuracy with CE + KD: 98.24%
