Knowledge Distillation 
===============================

**Author**: [Clara Martinez](https://github.com/moonblume/LIVIA.git)


Knowledge distillation is a technique that enables knowledge transfer
from large, computationally expensive models to smaller ones without
losing validity. This allows for deployment on less powerful hardware,
making evaluation faster and more efficient.

Librairies
================


In [340]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from scipy.signal import savgol_filter
from sklearn.metrics import mean_absolute_error, mean_squared_error

from typing import List, Union, Tuple, Any
import statistics

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading dataset
================

Fisrt, I focus on the physiological signals of the Biovid dataset. In one sample, we have access to 6 classes associated with 0 to 4 pain levels :  

Time: This could be the timestamp or time index when the signal was recorded.

GSR (Galvanic Skin Response): A measure of the electrical conductance of the skin, which varies with the moisture level of the skin. It's often associated with emotional arousal.

ECG (Electrocardiogram): A recording of the electrical activity of the heart over time. It typically consists of waves representing the depolarization and repolarization of the heart muscle during each heartbeat.

EMG (Electromyography) - Trapezius: Measures the electrical activity produced by skeletal muscles. The trapezius muscle is a large superficial muscle that extends longitudinally from the occipital bone to the lower thoracic vertebrae and laterally to the spine of the scapula.

EMG - Corrugator: Electromyography signal from the corrugator supercilii muscle, which is a small facial muscle involved in frowning and expressing negative emotions.

EMG - Zygomaticus: Electromyography signal from the zygomaticus major muscle, which is involved in smiling and expressing positive emotions.  
    
    
Our objective is to predict the pain level of input signals. One signal corresponds to one csv file.



In [341]:
# Define the directory containing the CSV files
biosignals_path = '/home/ens/AU59350/LIVIA/physio/physio_organised/'

# Initialize an empty list to store data for DataFrame
data = []

# Iterate over each pain level directory
for pain_level in os.listdir(biosignals_path):
    pain_level_dir = os.path.join(biosignals_path, pain_level)
    
    # Check if it's a directory
    if os.path.isdir(pain_level_dir):
        # Iterate over each CSV file in the pain level directory
        for csv_file in os.listdir(pain_level_dir):
            # Check if it's a CSV file
            if csv_file.endswith('.csv'):
                csv_path = os.path.join(pain_level_dir, csv_file)
                # Read the CSV file
                df = pd.read_csv(csv_path, sep='\t')
                # Extract GSR values
                gsr_signal = df['gsr'].values
                # Extract ECG values
                ecg_signal = df['ecg'].values
                # Extract EMG trapezius values
                emg_signal = df['emg_trapezius'].values
                # Extract time values
                time = df['time'].values
                # Append the CSV name, GSR signals, and Pain level to the data list
                data.append({'CSV name': csv_file,'Time': time, 'GSR signals': gsr_signal, 'ECG signals': ecg_signal, 'EMG signals': emg_signal,'Pain level': int(pain_level)})

# Create a DataFrame from the collected data
df = pd.DataFrame(data)

# Display the DataFrame
df.head()


Unnamed: 0,CSV name,Time,GSR signals,ECG signals,EMG signals,Pain level
0,072414_m_23-PA2-034_bio.csv,"[1641, 3594, 5547, 7500, 9453, 11406, 13359, 1...","[6.966839, 6.966161, 6.966, 6.966839, 6.966161...","[-246.3745, -248.5128, -247.0629, -248.6413, -...","[-0.001924584, -0.02534641, -0.3388469, -0.513...",2
1,081609_w_40-PA2-028_bio.csv,"[0, 1953, 3906, 5859, 7813, 9766, 11719, 13672...","[0.872, 0.872, 0.872, 0.872, 0.872, 0.872, 0.8...","[99.9569, 111.0614, 114.0062, 123.7483, 117.03...","[2.315527, -4.576343, -7.510249, -2.14524, 1.3...",2
2,081714_m_36-PA2-065_bio.csv,"[859, 2813, 4766, 6719, 8672, 10625, 12578, 14...","[6.089862, 6.091, 6.091432, 6.092, 6.092432, 6...","[186.0979, 187.6918, 190.7044, 193.3572, 195.2...","[3.878096e-29, -1.448683e-28, 5.406606e-28, -2...",2
3,102514_w_40-PA2-046_bio.csv,"[0, 1953, 3906, 5859, 7813, 9766, 11719, 13672...","[1.462, 1.462, 1.462, 1.462, 1.462, 1.462, 1.4...","[-107.6247, -95.28533, -108.069, -102.2526, -1...","[-1.5371, -0.8260319, 0.09676914, 0.09593942, ...",2
4,120514_w_56-PA2-019_bio.csv,"[234, 2188, 4141, 6094, 8047, 10000, 11953, 13...","[2.226, 2.226, 2.226, 2.226, 2.226, 2.226, 2.2...","[-262.4505, -253.6655, -228.2035, -202.7421, -...","[-0.005932733, 0.0208005, -0.07729835, 0.27021...",2


### Selection of number of pain level included in the classification task

In [342]:
# Filter the DataFrame to keep rows where pain level is not equal to 1, 2, or 3
df = df[~df['Pain level'].isin([1, 2, 3])]

# Replace label values in the DataFrame
df['Pain level'] = df['Pain level'].replace({4: 1})

# Print the filtered DataFrame
print(df)

                         CSV name  \
5220  080314_w_25-PA4-067_bio.csv   
5221  092009_m_54-PA4-042_bio.csv   
5222  071709_w_23-PA4-071_bio.csv   
5223  082809_m_26-PA4-005_bio.csv   
5224  112909_w_20-PA4-080_bio.csv   
...                           ...   
8695  082909_m_47-BL1-085_bio.csv   
8696  081609_w_40-BL1-090_bio.csv   
8697  091809_w_43-BL1-097_bio.csv   
8698  112016_m_25-BL1-091_bio.csv   
8699  083013_w_47-BL1-086_bio.csv   

                                                   Time  \
5220  [234, 2188, 4141, 6094, 8047, 10000, 11953, 13...   
5221  [156, 2109, 4063, 6016, 7969, 9922, 11875, 138...   
5222  [1406, 3359, 5313, 7266, 9219, 11172, 13125, 1...   
5223  [391, 2344, 4297, 6250, 8203, 10156, 12109, 14...   
5224  [1406, 3359, 5313, 7266, 9219, 11172, 13125, 1...   
...                                                 ...   
8695  [781, 2734, 4688, 6641, 8594, 10547, 12500, 14...   
8696  [547, 2500, 4453, 6406, 8359, 10313, 12266, 14...   
8697  [313, 2266, 4219, 

In [343]:
print(df['Pain level'].unique())

[1 0]


In [344]:
# Remove the 'Time' column from the DataFrame
df.drop(columns=['Time'], inplace=True)

# Remove the 'ECG' column from the DataFrame
df.drop(columns=['ECG signals'], inplace=True)

# Remove the 'EMG' column from the DataFrame
df.drop(columns=['EMG signals'], inplace=True)

# Remove the 'CSV name' column from the DataFrame
#df.drop(columns=['CSV name'], inplace=True)

Preprocessing
================

Preprocessing steps for GSR DataFrame include tasks such as handling missing values, smoothing the signal to reduce noise in the GSR signal 9(Savitzky-Golay filtering), removing outliers (z-score), and normalizing the data between a specified range, such as [0, 1] or [-1, 1] helping comparison across different subjects.

In [345]:
# Function to preprocess GSR signals
def preprocess_gsr_signal(gsr_signal):
    # Handle missing values (if any)
    gsr_signal = np.array(gsr_signal)  # Convert to NumPy array
    gsr_signal = gsr_signal[~np.isnan(gsr_signal)]  # Remove NaN values
    
    # Check if the length of the signal is sufficient for smoothing
    if len(gsr_signal) < 5:
        # If the signal is too short, return the original signal
        return gsr_signal
    
    try:
        # Smoothing using Savitzky-Golay filter
        gsr_signal_smooth = savgol_filter(gsr_signal, window_length=5, polyorder=2)
    except ValueError:
        # If an error occurs during smoothing, return the original signal
        return gsr_signal
    
    # Removing outliers based on Z-scores
    z_scores = (gsr_signal_smooth - gsr_signal_smooth.mean()) / gsr_signal_smooth.std()
    gsr_signal_smooth_no_outliers = gsr_signal_smooth[(z_scores < 3)]
    
    # Normalization
    if len(gsr_signal_smooth_no_outliers) > 0:
        gsr_signal_normalized = (gsr_signal_smooth_no_outliers - gsr_signal_smooth_no_outliers.min()) / \
                                 (gsr_signal_smooth_no_outliers.max() - gsr_signal_smooth_no_outliers.min())
    else:
        # If there are no valid values after removing outliers, return the original signal
        return gsr_signal
    
    return gsr_signal_normalized

# Apply preprocessing to each row in the DataFrame
df['GSR signals'] = df['GSR signals'].apply(preprocess_gsr_signal)

# Display the updated DataFrame
df.head()

Unnamed: 0,CSV name,GSR signals,Pain level
5220,080314_w_25-PA4-067_bio.csv,"[0.0, 0.000784955924601847, 0.0012131137016623...",1
5221,092009_m_54-PA4-042_bio.csv,"[0.21654547886192313, 0.21654547886192313, 0.2...",1
5222,071709_w_23-PA4-071_bio.csv,"[1.0, 0.9983642739356123, 0.9979709555643436, ...",1
5223,082809_m_26-PA4-005_bio.csv,"[0.3629402970779174, 0.36178416384520784, 0.36...",1
5224,112909_w_20-PA4-080_bio.csv,"[0.16994520796566095, 0.1741950993228739, 0.17...",1


In [346]:
# Sélectionnez les lignes où la valeur de la colonne "CSV name" commence par "080314_w_25"
#exemple_df = df[df['CSV name'].str.startswith('080314_w_25')]

# Affichez le DataFrame filtré
#print(exemple_df)

We observe that inside the dataframe each patients had multiple CSV file. In order to perform a Leave-One-Out Cross-Validation (LASO), we need to build annotated file per patients.

### Convert Pandas Dataframe into Pytorch tensor

In [347]:
# Check the unique data types in the 'GSR signals' column
#print(df['GSR signals'].apply(type).unique())

#print(df.dtypes)

In [348]:
import torch
from torch.utils.data import Dataset, DataLoader

class GSRDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        
    def __len__(self):
        return self.dataframe.shape[0]
        
    def __getitem__(self, index):
        row = self.dataframe.iloc[index, :]
        inputs = row['GSR signals']  # No need to convert to tensor
        label = torch.tensor(row['Pain level'])
        
        # Convert the NumPy array to a list of tensors
        inputs = [torch.tensor(signal) for signal in inputs]
        
        # Stack the list of tensors along a new dimension to form a single tensor
        inputs = torch.stack(inputs)
        
        return inputs, label

# Create a Dataset and DataLoader
gsrDataset = GSRDataset(df)
dataloader = DataLoader(gsrDataset, batch_size=1, shuffle=False)

# Use the DataLoader in a training loop
for inputs, labels in dataloader:
    print("Inputs:", inputs)
    print("Labels:", labels)


Inputs: tensor([[0.0000e+00, 7.8496e-04, 1.2131e-03,  ..., 9.9344e-01, 9.9344e-01,
         9.9344e-01]], dtype=torch.float64)
Labels: tensor([1])
Inputs: tensor([[0.2165, 0.2165, 0.2165,  ..., 0.9872, 0.9934, 1.0000]],
       dtype=torch.float64)
Labels: tensor([1])
Inputs: tensor([[1.0000, 0.9984, 0.9980,  ..., 0.3482, 0.3499, 0.3505]],
       dtype=torch.float64)
Labels: tensor([1])
Inputs: tensor([[0.3629, 0.3618, 0.3611,  ..., 0.9975, 0.9987, 1.0000]],
       dtype=torch.float64)
Labels: tensor([1])
Inputs: tensor([[0.1699, 0.1742, 0.1763,  ..., 0.9866, 0.9939, 1.0000]],
       dtype=torch.float64)
Labels: tensor([1])
Inputs: tensor([[0.9818, 0.9818, 0.9818,  ..., 0.0289, 0.0289, 0.0289]],
       dtype=torch.float64)
Labels: tensor([1])
Inputs: tensor([[0.9986, 0.9986, 0.9986,  ..., 0.0392, 0.0392, 0.0392]],
       dtype=torch.float64)
Labels: tensor([1])
Inputs: tensor([[0.0112, 0.0164, 0.0185,  ..., 0.9976, 0.9989, 1.0000]],
       dtype=torch.float64)
Labels: tensor([1])
Inputs

In [349]:
import torch

def check_tensor_sizes(batch):
    # Get the size of the first tensor in the batch
    first_size = batch[0].size()
    
    # Iterate through the remaining tensors in the batch
    for tensor in batch[1:]:
        # Compare the size of each tensor with the size of the first tensor
        if tensor.size() != first_size:
            return False  # Return False if sizes are not equal
    
    return True  # Return True if all sizes are equal

# Example usage:
# Assuming 'batch' is a list of PyTorch tensors representing a batch of inputs
# You can call the check_tensor_sizes function and pass the batch as an argument
if check_tensor_sizes(batch):
    print("All tensors in the batch are of equal size.")
else:
    print("Tensors in the batch have different sizes.")


Tensors in the batch have different sizes.


In [350]:
# Split the dataset into training and testing sets
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

train_dataset = GSRDataset(train_df)
test_dataset = GSRDataset(test_df)

train_dataset = GSRDataset(train_df)
test_dataset = GSRDataset(test_df)

Defining model classes and utility functions
================

In [351]:
# Neural network class to be used as teacher:

class Conv1D_T(nn.Module):
    def __init__(self, num_classes=2):
        super(Conv1D_T, self).__init__()
        # First Convolutional Layer
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, stride=2)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool1d(kernel_size=2)
        
        # Second Convolutional Layer
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool1d(kernel_size=2)
        
        # Fully Connected Layers
        self.fc1 = nn.Linear(22336, 512)  
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        
        x = x.view(x.size(0), -1)  # Flatten the tensor to 1D
        x = self.fc1(x)
        
        x = self.fc2(x)
        #x = self.sigmoid(x)
        return x

In [352]:
# Neural network class to be used as student:

class Conv1D_S(nn.Module):
    def __init__(self, num_classes=2):
        super(Conv1D_S, self).__init__()
        # First Convolutional Layer
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, stride=2)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool1d(kernel_size=2)
        
        # Second Convolutional Layer
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool1d(kernel_size=2)
        
        # Fully Connected Layers
        self.fc1 = nn.Linear(22336, 512)  
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        
        x = x.view(x.size(0), -1)  # Flatten the tensor to 1D
        x = self.fc1(x)
        
        x = self.fc2(x)
        #x = self.sigmoid(x)
        return x


Training Testing 
----------------

In [353]:
def train(model, train_loader, epochs, learning_rate, device):
    
    # Define learning parameters
    learning_rate = 0.0001
    epochs = 10
    batch_size = 1024
    num_classes = 2

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
   
    print("Batch structure:", next(iter(train_loader)))

    for epoch in tqdm(range(epochs)):
            
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            print("Inputs:", inputs)
            print("Labels:", labels)
            inputs = inputs.to(torch.float32)  # new Convert inputs to torch.float32
            labels = labels.to(device) # new
            optimizer.zero_grad()
            #outputs = model(inputs) enleve
            outputs = nn_t(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
    

            # Calculate training accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_accuracy = correct / total
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")


def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    predictions_list = []
    labels_list = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(torch.float32).to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            predictions_list.extend(predicted.cpu().numpy())
            labels_list.extend(labels.cpu().numpy())

    predictions = torch.tensor(predictions_list)
    labels = torch.tensor(labels_list)

    accuracy = torch.mean((predictions == labels).float()).item()
    mae = mean_absolute_error(labels, predictions)
    rmse = mean_squared_error(labels, predictions, squared=False)

    return accuracy, mae, rmse


def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        outputs = model(X_test_tensor)
        predictions = torch.argmax(outputs, dim=1)
        y_test_tensor_sampled = y_test_tensor[:len(predictions)]
        accuracy = torch.mean((predictions == y_test_tensor_sampled).float()).item()
        print(f"Test Accuracy: {accuracy:.4f}")
        
        # Calculate MAE and RMSE
        mae = mean_absolute_error(y_test_tensor_sampled, predictions)
        rmse = mean_squared_error(y_test_tensor_sampled, predictions, squared=False)
        print(f"MAE: {mae:.4f}")
        print(f"RMSE: {rmse:.4f}")

Cross-entropy runs
==================

For reproducibility, we need to set the torch manual seed. I train
networks using different methods, so to compare them fairly, it makes
sense to initialize the networks with the same weights. I start by
training the teacher network using cross-entropy:

In [354]:
# Define a function to pad tensors to the maximum length within each batch
def collate_fn(batch):
    inputs, labels = zip(*batch)
    max_length = max(len(tensor) for tensor in inputs)
    padded_inputs = []
    for tensor in inputs:
        if len(tensor) < max_length:
            padded_tensor = torch.nn.functional.pad(tensor, (0, max_length - len(tensor)))
            padded_inputs.append(padded_tensor)
        else:
            padded_inputs.append(tensor)
    padded_inputs = torch.stack(padded_inputs)
    return padded_inputs, torch.tensor(labels)

# Create DataLoader for training and testing sets with the custom collate function
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Train the teacher network
torch.manual_seed(42)
nn_t = Conv1D_T(num_classes=2).to(device)
train(nn_t, train_loader, epochs=10, learning_rate=0.0001, device=device)

# Test the teacher network
#test_accuracy_T = test(nn_t, test_loader, device)
test_accuracy, test_mae, test_rmse = test(nn_t, test_loader, device)
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"MAE: {test_mae:.4f}")
print(f"RMSE: {test_rmse:.4f}")
# Instantiate the lightweight network
torch.manual_seed(42)
nn_s = Conv1D_S(num_classes=2).to(device)

Batch structure: (tensor([[2.2761e-03, 2.2761e-03, 2.2761e-03,  ..., 8.2748e-01, 8.2748e-01,
         8.2748e-01],
        [9.9644e-01, 9.9644e-01, 9.9644e-01,  ..., 2.7024e-03, 2.7024e-03,
         2.7024e-03],
        [9.9763e-01, 9.9763e-01, 9.9763e-01,  ..., 4.2769e-03, 4.2769e-03,
         4.2769e-03],
        ...,
        [1.0000e+00, 9.8818e-01, 9.8313e-01,  ..., 8.9487e-04, 8.9487e-04,
         8.9487e-04],
        [1.6378e-01, 1.6378e-01, 1.6378e-01,  ..., 9.9888e-01, 9.9888e-01,
         9.9888e-01],
        [9.9823e-01, 1.0000e+00, 9.9954e-01,  ..., 5.2466e-04, 2.2123e-03,
         5.1772e-03]], dtype=torch.float64), tensor([0, 0, 0,  ..., 0, 1, 0]))


  0%|          | 0/10 [00:09<?, ?it/s]

Inputs: tensor([[1.2021e-02, 6.6119e-03, 3.9775e-03,  ..., 7.4349e-01, 7.4349e-01,
         7.4349e-01],
        [9.4121e-01, 8.2519e-01, 8.0954e-01,  ..., 3.6833e-01, 3.6833e-01,
         3.6833e-01],
        [3.6445e-01, 3.6445e-01, 3.6445e-01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [9.6724e-01, 9.6724e-01, 9.6724e-01,  ..., 2.2945e-03, 2.2945e-03,
         2.2945e-03],
        [9.9735e-01, 9.9735e-01, 9.9735e-01,  ..., 2.7501e-03, 2.7501e-03,
         2.7501e-03],
        [9.9928e-01, 9.9928e-01, 9.9928e-01,  ..., 5.0574e-04, 5.0574e-04,
         5.0574e-04]], dtype=torch.float64)
Labels: tensor([1, 1, 1,  ..., 0, 0, 0])





RuntimeError: Given groups=1, weight of size [32, 1, 5], expected input[1, 1024, 2816] to have 1 channels, but got 1024 channels instead

batch_size = 1024

# Create DataLoader for training and testing sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Train the teacher network
torch.manual_seed(42)
nn_t = Conv1D_T(num_classes=2).to(device)
train(nn_t, train_loader, epochs=10, learning_rate=0.0001, device=device)

# Test the teacher network
test_accuracy_T = test(nn_t, test_loader, device)

# Instantiate the lightweight network
torch.manual_seed(42)
nn_s = Conv1D_S(num_classes=2).to(device)


I instantiate one more lightweight network model to compare their
performances. Back propagation is sensitive to weight initialization, so
I need to make sure these two networks have the exact same
initialization.


In [None]:
torch.manual_seed(42)
new_Conv1D_S = Conv1D_S(num_classes=2).to(device)

To ensure I have created a copy of the first network, we inspect the
norm of its first layer. If it matches, then the networks are indeed the same.

In [None]:
# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(Conv1D_S.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_Conv1D_S.features[0].weight).item())

Print the total number of parameters in each model:

In [None]:
total_params_T = "{:,}".format(sum(p.numel() for p in Conv1D_T.parameters()))
print(f"DeepNN parameters: {total_params_T}")
total_params_S = "{:,}".format(sum(p.numel() for p in Conv1D_S.parameters()))
print(f"LightNN parameters: {total_params_S}")

Train and test the lightweight network with cross entropy loss:

In [None]:
train(Conv1D_S, train_loader, epochs=10, learning_rate=0.0001, device=device)
test_accuracy_S_ce = test(Conv1D_S, test_loader, device)

As we can see, based on test accuracy, I can now compare the deeper
network that is to be used as a teacher with the lightweight network
that is the supposed student. So far, the student has not intervened
with the teacher, therefore this performance is achieved by the student
itself. The metrics so far can be seen with the following lines:

In [None]:
print(f"Teacher accuracy: {test_accuracy_T:.2f}%")
print(f"Student accuracy: {test_accuracy_S_ce:.2f}%")

In [None]:
def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    # Define learning parameters
    learning_rate = 0.0001

    # Define loss function and optimizer
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.SGD(student.parameters(), lr=learning_rate, momentum=0.9)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=Conv1D_T, student=new_Conv1D_S, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_T:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_S_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")