# Student and teacher

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

device = (
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [2]:
# Per-image preprocessing: convert to a tensor, then normalize with the
# mean/std constants (0.1307, 0.3081) used throughout this notebook.
transforms_mnist = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Load (downloading into ./data if necessary) the MNIST training split first,
# then the test split; both share the same preprocessing pipeline.
train_dataset, test_dataset = (
    datasets.MNIST(
        root="./data",
        train=is_train_split,
        download=True,
        transform=transforms_mnist,
    )
    for is_train_split in (True, False)
)

In [3]:
# Mini-batch loaders (128 samples per batch, 2 worker processes each).
# Only the training split is shuffled; the test split keeps its order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

In [None]:
class DeepNN(nn.Module):
    """Larger convolutional classifier (used as the teacher in this notebook).

    Four 3x3 convolutions (padding 1, each followed by ReLU) with a 2x2
    max-pool after the second and the fourth, then a two-layer fully
    connected head with dropout.

    The head's first linear layer takes 1568 features, i.e. 32 channels x 7 x 7,
    which is what the feature extractor yields for single-channel 28x28 input.

    Args:
        num_classes: Size of the output (logit) vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # (in_channels, out_channels, followed_by_pool) for each convolution.
        conv_plan = (
            (1, 128, False),
            (128, 64, True),
            (64, 64, False),
            (64, 32, True),
        )
        stages = []
        for in_channels, out_channels, followed_by_pool in conv_plan:
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            if followed_by_pool:
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*stages)

        head = [
            nn.Linear(1568, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the raw class scores for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension, collapse everything else for the linear head.
        flat = torch.flatten(feature_maps, 1)
        return self.classifier(flat)

class LightNN(nn.Module):
    """Small convolutional classifier (a lightweight counterpart to DeepNN).

    Two 3x3 convolutions with 16 output channels each (padding 1), each
    followed by ReLU and a 2x2 max-pool, then a two-layer fully connected
    head with dropout.

    The head's first linear layer takes 784 features, i.e. 16 channels x 7 x 7,
    which is what the feature extractor yields for single-channel 28x28 input.

    Args:
        num_classes: Size of the output (logit) vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        stages = []
        # Each stage is conv -> ReLU -> pool; only the input channel count differs.
        for in_channels in (1, 16):
            stages += [
                nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.features = nn.Sequential(*stages)

        head = [
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the raw class scores for a batch of images."""
        feature_maps = self.features(x)
        # Keep the batch dimension, collapse everything else for the linear head.
        flat = torch.flatten(feature_maps, 1)
        return self.classifier(flat)

In [8]:
def train(model, train_loader, epochs, learning_rate, device):
    """Train a classifier with cross-entropy loss and the Adam optimizer.

    The model is moved to ``device`` (matching what ``test`` does), switched
    to training mode and updated in place. After every epoch the mean of the
    per-batch losses is printed.

    Args:
        model: Network mapping a batch of inputs to ``batch_size x num_classes``
            scores.
        train_loader: Iterable yielding ``(inputs, labels)`` batches, where
            ``labels`` holds one integer class index per sample. It is
            iterated once per epoch and does not need to support ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to Adam.
        device: Device on which the model and every batch are placed.

    Returns:
        None. The model is modified in place and left in training mode.
    """
    # Place the parameters on the target device before the optimizer is built,
    # so callers may pass a model that still lives on the CPU.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            # inputs: a collection of batch_size images
            # labels: a vector of length batch_size with the class index of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: batch_size x num_classes scores compared against the integer labels
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches ourselves instead of calling len(train_loader): this
        # also works for loaders without __len__ and avoids a ZeroDivisionError
        # when the loader yields nothing (reported as NaN).
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

def test(model, test_loader, device):
    """Measure the classification accuracy of ``model`` on ``test_loader``.

    The model is moved to ``device`` and switched to evaluation mode (it is
    left in that mode). Predictions are the index of the highest score in
    each output row. The accuracy is printed and returned.

    Args:
        model: Network mapping a batch of inputs to ``batch_size x num_classes``
            scores.
        test_loader: Iterable yielding ``(inputs, labels)`` batches with
            integer class labels.
        device: Device on which the model and every batch are placed.

    Returns:
        Accuracy as a percentage in ``[0, 100]``; ``0.0`` when the loader
        yields no samples.
    """
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            predicted = model(inputs).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


# pytest collects any module-level callable whose name starts with "test";
# mark this evaluation helper so it is never mistaken for a test case.
test.__test__ = False

In [25]:
def preprocess_for_onn(inputs, target_size=200):
    """Convert image batches to the layout fed to the ONN in this notebook.

    Steps:
      1. Zero-pad the last two dimensions so the image sits centred in a
         ``target_size x target_size`` frame (28x28 -> 200x200 by default,
         i.e. 86 pixels on every side). When the required padding is odd,
         the extra pixel goes to the right/bottom.
      2. Drop dimension 1 if it has size 1 (the channel axis of a
         ``[batch, 1, H, W]`` batch).
      3. Append a trailing dimension of size 2 holding the image as the real
         part and zeros as the imaginary part.

    The inputs are used as given; no normalization is undone here.

    Args:
        inputs: Image tensor, ``[batch, 1, H, W]`` in this notebook.
        target_size: Side length of the padded square frame.

    Returns:
        Tensor of shape ``[batch, target_size, target_size, 2]`` for
        ``[batch, 1, H, W]`` input.

    Raises:
        ValueError: If the input is larger than ``target_size`` along either
            spatial dimension (negative padding would silently crop).
    """
    height, width = inputs.shape[-2:]
    extra_rows = target_size - height
    extra_cols = target_size - width
    if extra_rows < 0 or extra_cols < 0:
        raise ValueError(
            f"input of size {height}x{width} does not fit into "
            f"{target_size}x{target_size}"
        )

    top = extra_rows // 2
    left = extra_cols // 2
    # F.pad takes the padding for the last dimension first: (left, right, top, bottom).
    padded = torch.nn.functional.pad(
        inputs, (left, extra_cols - left, top, extra_rows - top)
    )

    # [batch, 1, S, S] -> [batch, S, S]
    padded = padded.squeeze(1)

    # Real part is the image, imaginary part is zeros: [batch, S, S, 2]
    complex_input = torch.stack((padded, torch.zeros_like(padded)), dim=-1)

    return complex_input

In [23]:
def train_onn(model, train_loader, epochs, learning_rate, device):
    """Fit the Optical Neural Network (ONN) on labelled image batches.

    Every batch is converted with ``preprocess_for_onn`` before the forward
    pass. The loss is the mean squared error between the model output and
    the one-hot encoded labels (10 classes); per the original author's note
    the ONN is expected to emit softmax probabilities directly. Running
    training accuracy and the mean per-batch loss are printed once per epoch.

    Args:
        model: The ONN model; it is put into training mode and updated in place.
        train_loader: DataLoader yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to Adam.
        device: Device each batch is moved to.

    Returns:
        The same ``model`` object that was passed in.
    """
    mse = nn.MSELoss()
    adam = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch_idx in range(epochs):
        loss_sum = 0.0
        hits = 0
        seen = 0

        for images, targets in train_loader:
            images = images.to(device)
            targets = targets.to(device)

            # ONN-specific input layout and one-hot regression targets.
            onn_batch = preprocess_for_onn(images)
            targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=10).float()

            adam.zero_grad()
            probs = model(onn_batch)
            batch_loss = mse(probs, targets_one_hot)
            batch_loss.backward()
            adam.step()

            loss_sum += batch_loss.item()

            # Training accuracy: the predicted class is the largest output entry.
            guesses = torch.max(probs.data, 1)[1]
            seen += targets.size(0)
            hits += (guesses == targets).sum().item()

        mean_loss = loss_sum / len(train_loader)
        accuracy = 100 * hits / seen
        print(f"Epoch {epoch_idx+1}/{epochs}, Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")

    return model

In [24]:
def test_onn(model, test_loader, device):
    """Evaluate the Optical Neural Network (ONN) and report its accuracy.

    The model is switched to evaluation mode (and left in it). Every batch is
    converted with ``preprocess_for_onn`` before the forward pass, and the
    predicted class is the index of the largest output entry.

    Args:
        model: The ONN model.
        test_loader: DataLoader yielding ``(inputs, labels)`` batches.
        device: Device each batch is moved to.

    Returns:
        Accuracy over all samples as a percentage; it is also printed.
    """
    model.eval()

    hits = 0
    seen = 0

    # No gradients are needed while scoring.
    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)

            scores = model(preprocess_for_onn(images))
            guesses = torch.max(scores, 1)[1]

            seen += targets.size(0)
            hits += (guesses == targets).sum().item()

    accuracy = 100 * hits / seen
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [26]:
# ONN model class from the local `onn` module, aliased to avoid clashing with
# the `onn` variable created below.
from onn import Net as ONNNet
# Fixed seed so the teacher's initialization is reproducible.
torch.manual_seed(42)
# Teacher: the deep CNN, trained with cross-entropy and then scored on the test set.
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the student (the ONN) from the same seed.
# NOTE: this instance is re-created in the next cell before it is trained.
torch.manual_seed(42)
onn = ONNNet().to(device)

Epoch 1/10, Loss: 0.14108995523657214
Epoch 2/10, Loss: 0.03995176364304319
Epoch 3/10, Loss: 0.025827583994208547
Epoch 4/10, Loss: 0.021019774001848294
Epoch 5/10, Loss: 0.015071856757891036
Epoch 6/10, Loss: 0.01298783389683561
Epoch 7/10, Loss: 0.01244963963113257
Epoch 8/10, Loss: 0.010715614126243082
Epoch 9/10, Loss: 0.00935895531034933
Epoch 10/10, Loss: 0.007573059605796194
Test Accuracy: 99.26%


In [29]:
# Baseline student: a freshly seeded ONN trained on the labels only (MSE against
# one-hot targets, see train_onn), without any help from the teacher.
torch.manual_seed(42)
onn= ONNNet().to(device)
train_onn(onn, train_loader, epochs=10, learning_rate=0.001, device=device)
# NOTE: despite the "_ce" suffix, this baseline is trained with MSE, not cross-entropy.
test_accuracy_light_ce = test_onn(onn, test_loader, device)

Epoch 1/10, Loss: 0.0889, Accuracy: 68.98%
Epoch 2/10, Loss: 0.0869, Accuracy: 70.92%
Epoch 3/10, Loss: 0.0858, Accuracy: 70.50%
Epoch 4/10, Loss: 0.0851, Accuracy: 70.45%
Epoch 5/10, Loss: 0.0846, Accuracy: 70.54%
Epoch 6/10, Loss: 0.0843, Accuracy: 70.59%
Epoch 7/10, Loss: 0.0841, Accuracy: 70.72%
Epoch 8/10, Loss: 0.0839, Accuracy: 70.76%
Epoch 9/10, Loss: 0.0838, Accuracy: 70.88%
Epoch 10/10, Loss: 0.0837, Accuracy: 70.83%
Test Accuracy: 72.09%


In [30]:
# Baseline comparison before distillation: deep CNN teacher vs. ONN student
# trained on the labels alone.
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

Teacher accuracy: 99.26%
Student accuracy: 72.09%


In [31]:
def train_knowledge_distillation_onn(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    """Train the ONN student with knowledge distillation from a teacher.

    For every batch the loss is

        soft_target_loss_weight * KL(teacher_soft || student) * T**2
        + ce_loss_weight * MSE(student, one_hot(labels))

    where ``teacher_soft = softmax(teacher_logits / T)``. The teacher sees the
    original inputs; the student sees them after ``preprocess_for_onn``. Only
    the student's parameters are optimized (Adam). The mean per-batch loss is
    printed after every epoch.

    NOTE(review): the student output is treated as a probability distribution
    (the original comments state the ONN applies softmax itself) and is used
    as-is, i.e. only the teacher side is temperature-softened. Confirm that
    this asymmetry is intended.

    Args:
        teacher: Trained network producing class logits; put into eval mode.
        student: ONN model producing class probabilities; put into train mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It is
            iterated once per epoch and does not need to support ``len()``.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to Adam.
        T: Distillation temperature applied to the teacher logits.
        soft_target_loss_weight: Weight of the KL (soft-target) term.
        ce_loss_weight: Weight of the hard-label term. Despite the name, that
            term is an MSE against one-hot labels, not a cross-entropy.
        device: Device each batch is moved to. Both models are expected to be
            on it already.

    Returns:
        None. The student is modified in place.
    """
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # The teacher only provides targets, so no gradients are tracked for it.
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Student forward pass on ONN-formatted inputs.
            student_probs = student(preprocess_for_onn(inputs))

            # Soften the teacher outputs.
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)

            # KL(soft_targets || student_probs), summed over classes and averaged
            # over the batch, scaled by T**2. kl_div treats zero-probability
            # targets as contributing 0, and the clamp keeps log() finite if a
            # student probability underflows to 0; otherwise the value equals
            # sum(p * (log p - log q)) / batch_size.
            tiny = torch.finfo(student_probs.dtype).tiny
            log_student = torch.log(student_probs.clamp_min(tiny))
            soft_targets_loss = nn.functional.kl_div(
                log_student, soft_targets, reduction="batchmean"
            ) * (T**2)

            # Hard-label loss: MSE between student probabilities and one-hot labels.
            label_one_hot = torch.nn.functional.one_hot(labels, num_classes=10).float()
            label_loss = torch.nn.functional.mse_loss(student_probs, label_one_hot)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: works for loaders without __len__ and avoids
        # a ZeroDivisionError when the loader yields nothing (reported as NaN).
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

In [None]:
# Re-seed before building the distilled student so that it starts from the same
# random state as the baseline student (which was created right after
# torch.manual_seed(42)). Without this, the with/without-teacher comparison
# mixes the effect of distillation with a different initialization and is not
# reproducible between runs.
torch.manual_seed(42)
new_onn = ONNNet().to(device)
train_knowledge_distillation_onn(teacher=nn_deep, student=new_onn, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test_onn(new_onn, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Test Accuracy: 77.18%
Teacher accuracy: 99.26%
Student accuracy without teacher: 72.09%
Student accuracy with CE + KD: 77.18%
