In [11]:
import torch
from torch.utils.data import DataLoader
import torchvision

import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

# Define MLP with probes
class MLPWithProbes(nn.Module):
    def __init__(self, input_size=784, hidden_sizes=[512, 256, 128, 64, 32], num_classes=10):
        super(MLPWithProbes, self).__init__()
        self.layers = nn.ModuleList()
        self.probes = nn.ModuleList()
        
        # Create main layers
        prev_size = input_size
        for hidden_size in hidden_sizes:
            self.layers.append(nn.Linear(prev_size, hidden_size))
            self.layers.append(nn.ReLU())
            # Add probe for each layer
            self.probes.append(nn.Linear(hidden_size, num_classes))
            prev_size = hidden_size
            
        # Output layer
        self.layers.append(nn.Linear(prev_size, num_classes))

    def forward(self, x):
        probe_outputs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Add probe output after each hidden layer (before output layer)
            if isinstance(layer, nn.Linear) and i < len(self.layers) - 1:
                probe_outputs.append(self.probes[len(probe_outputs)](x))
        
        return x, probe_outputs

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize model and optimizer
model = MLPWithProbes()
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training loop
def train_epoch():
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 784)
        optimizer.zero_grad()
        
        # Forward pass
        main_output, probe_outputs = model(data)
        
        # Calculate losses
        main_loss = criterion(main_output, target)
        probe_losses = [criterion(probe_out, target) for probe_out in probe_outputs]
        
        # Total loss is main loss plus probe losses
        total_loss = main_loss + sum(probe_losses)
        
        # Backward pass
        total_loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}: Loss = {total_loss.item():.4f}')

# Train for one epoch
train_epoch()

Batch 0: Loss = 14.0095
Batch 100: Loss = 1.1597
Batch 200: Loss = 1.4180
Batch 300: Loss = 1.3825
Batch 400: Loss = 1.5685
Batch 500: Loss = 1.3616
Batch 600: Loss = 1.1849
Batch 700: Loss = 0.8189
Batch 800: Loss = 0.2472
Batch 900: Loss = 1.4644


In [12]:
def evaluate_probes():
    model.eval()
    correct_main = 0
    correct_probes = [0] * len(model.probes)
    total = 0
    
    with torch.no_grad():
        for data, target in train_loader:
            data = data.view(-1, 784)
            main_output, probe_outputs = model(data)
            
            # Main classifier accuracy
            pred = main_output.argmax(dim=1)
            correct_main += pred.eq(target).sum().item()
            
            # Probe accuracies
            for i, probe_output in enumerate(probe_outputs):
                pred = probe_output.argmax(dim=1)
                correct_probes[i] += pred.eq(target).sum().item()
            
            total += target.size(0)
    
    # Print accuracies
    print(f"Main classifier accuracy: {100. * correct_main / total:.2f}%")
    for i, correct in enumerate(correct_probes):
        print(f"Probe {i+1} accuracy: {100. * correct / total:.2f}%")

evaluate_probes()

Main classifier accuracy: 96.56%
Probe 1 accuracy: 90.06%
Probe 2 accuracy: 96.55%
Probe 3 accuracy: 96.57%
Probe 4 accuracy: 96.53%
Probe 5 accuracy: 96.51%
