### LeNet5 1998 Implementation

#### 1. Calculating the mean and variance of the MNIST data.

In [None]:
import numpy as np
import torch
from torchvision import datasets, transforms

# MNIST data loader.
# ToTensor scales pixels to [0, 1]; the affine map that follows sends
# 0 -> -0.1 and 1 -> 1.175.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x * 1.275 - 0.1),
])

# Load the MNIST training split (downloaded into ./data on first use).
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Global statistics do not depend on sample order, so the loader is left
# unshuffled; a fixed order also keeps the floating-point reductions
# repeatable between runs.
data_loader = torch.utils.data.DataLoader(train_data, batch_size=100, shuffle=False)

# Flatten every image to one row and stack all batches into a single
# (num_images, num_pixels) tensor.
all_data = torch.cat(
    [images.view(images.size(0), -1) for images, _ in data_loader],
    dim=0,
)

mean = all_data.mean()
variance = all_data.var()
# transforms.Normalize divides by a standard deviation, not a variance, so
# report both to avoid passing the wrong one downstream.
std = all_data.std()

# Mean, variance and standard deviation of the transformed MNIST pixels:
print(f'Mean: {mean.item()}')
print(f'Variance: {variance.item()}')
print(f'Std: {std.item()}')

#### 2. Building the LeNet5 1998 model.

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import random
from collections import defaultdict

# Seed Python's built-in `random` module so values drawn from it are repeatable.
# NOTE(review): this does not seed torch's RNG; add torch.manual_seed(...) if
# reproducible weight initialization is also required.
random.seed(42)

class LeNet5_S2Layer(nn.Module):
    """Subsampling layer: 2x2 average pool, per-channel affine map, sigmoid.

    Args:
        num_channels: Number of feature maps; one trainable coefficient and
            one trainable bias are created per channel.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Coefficients start at 1 and biases at 0, so the layer initially
        # computes sigmoid(avg_pool(x)).
        self.coefficient = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        """Return sigmoid(coefficient * avg_pool2d(x) + bias), per channel."""
        # Reshape the per-channel parameters so they broadcast over the
        # batch and spatial dimensions of the pooled maps.
        gain = self.coefficient.view(1, -1, 1, 1)
        offset = self.bias.view(1, -1, 1, 1)
        averaged = nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return torch.sigmoid(averaged * gain + offset)

class ScaledTanh(nn.Module):
    """Element-wise activation computing ``1.7159 * tanh(2 * x / 3)``."""

    def forward(self, x):
        """Apply the scaled hyperbolic tangent to every element of ``x``."""
        squashed = torch.tanh(x * 2 / 3)
        return squashed * 1.7159

class SquashingFunction(nn.Module):
    """Element-wise activation ``A * atanh(clamp(S * x, -0.999, 0.999))``.

    NOTE(review): despite the name, this applies ``atanh`` (the inverse of
    ``tanh``), and every input with ``|S * x| > 0.999`` is clamped to the
    same value. Confirm that ``atanh`` rather than ``tanh`` is intended.

    Args:
        A: Output amplitude multiplier.
        S: Input slope multiplier.
    """

    def __init__(self, A=1.7159, S=2/3):
        super().__init__()
        self.A, self.S = A, S

    def forward(self, x):
        """Scale ``x`` by ``S``, clamp, apply ``atanh`` and scale by ``A``."""
        # Keep the argument strictly inside (-1, 1), where atanh is finite.
        bounded = (self.S * x).clamp(min=-0.999, max=0.999)
        return self.A * torch.atanh(bounded)

class C3PartialConv(nn.Module):
    """Stride-1, unpadded 2-D convolution with a fixed sparse channel wiring.

    A dense weight tensor is stored, but it is multiplied by a binary mask on
    every forward pass so that output channel ``i`` only receives the input
    channels listed in ``connection_table[i]``.

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of output feature maps.
        kernel_size: Side length (int) of the square kernel.
        connection_table: Sequence whose ``i``-th entry is an iterable of the
            input-channel indices that feed output channel ``i``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, connection_table):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.connection_table = connection_table

        # Dense parameters; disallowed connections are zeroed in forward().
        self.weight = nn.Parameter(torch.zeros(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels))

        # Registering the mask as a buffer makes it follow the module through
        # .to(), .cuda(), .double(), etc. It is non-persistent so that the
        # state_dict keeps exactly the keys it had before ('weight', 'bias').
        self.register_buffer('mask', self._build_mask(), persistent=False)

        self.reset_parameters()

    def _build_mask(self):
        """Return a 0/1 tensor shaped like ``weight`` encoding the wiring."""
        mask = torch.zeros(self.weight.shape, dtype=self.weight.dtype)
        for out_c, in_list in enumerate(self.connection_table):
            for in_c in in_list:
                mask[out_c, in_c, :, :] = 1.0
        return mask

    def reset_parameters(self):
        """Draw weights from U(-2.4/fan_in, 2.4/fan_in); set biases to 2.4/fan_in.

        NOTE: fan_in is the dense value ``in_channels * kernel_size**2``; it
        counts masked-out input channels as well.
        """
        fan_in = self.in_channels * self.kernel_size * self.kernel_size
        bound = 2.4 / fan_in
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.constant_(self.bias, bound)

    def forward(self, x):
        """Convolve ``x`` using only the connections allowed by the mask."""
        # Masking here (rather than once at init) also keeps the gradient of
        # disallowed weights at zero.
        return F.conv2d(x, self.weight * self.mask, self.bias, stride=1)

class EuclideanRBFOutput(nn.Module):
    """Scores each class by the negative squared distance to its prototype.

    Args:
        num_classes: Number of prototype vectors (one per class).
        input_dim: Length of each prototype vector.
    """

    def __init__(self, num_classes=10, input_dim=84):
        super().__init__()
        # Trainable prototypes, one row per class, drawn from N(0, 1).
        self.centers = nn.Parameter(torch.randn(num_classes, input_dim))

    def forward(self, x):
        """Map a (batch_size, input_dim) input to (batch_size, num_classes) scores."""
        # Broadcast (batch_size, 1, input_dim) against (1, num_classes, input_dim).
        diff = x[:, None] - self.centers[None]
        sq_dist = diff.pow(2).sum(dim=2)
        # Negated so that the closest prototype gets the largest score, as
        # CrossEntropyLoss expects.
        return -sq_dist

class LeNet5(nn.Module):
    """LeNet-5 style network: two conv/subsample stages, C5, F6 and an RBF output.

    The flattened C5 output is fed to ``nn.Linear(120, 84)``, so C5 must
    produce a 1x1 map per channel; with the 5x5 kernels and 2x2 subsampling
    used here that corresponds to 1-channel 32x32 inputs.

    Args:
        num_classes: Number of output classes (RBF prototypes).
    """

    def __init__(self, num_classes):
        super().__init__()

        # C1 and S2
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
            ScaledTanh(),
            LeNet5_S2Layer(6),
        )

        # C3 and S4
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            ScaledTanh(),
            LeNet5_S2Layer(16),
        )

        # C5: each of the 120 output maps is wired to 5 of the 16 input maps.
        # The table is drawn from a dedicated, fixed-seed generator so that
        # every instance gets the same wiring regardless of the global
        # `random` state. The wiring mask is not part of the state dict, so
        # reloading saved weights is only meaningful if a new instance
        # reproduces the same table.
        table_rng = random.Random(42)
        connection_table = [table_rng.sample(range(16), 5) for _ in range(120)]
        self.layer3 = C3PartialConv(
            in_channels=16,
            out_channels=120,
            kernel_size=5,
            connection_table=connection_table,
        )

        # F6
        self.fc = nn.Linear(120, 84)
        self.squashing = SquashingFunction(A=1.0, S=1.0)

        # RBF Output layer
        self.rbf_output = EuclideanRBFOutput(num_classes=num_classes, input_dim=84)

    def forward(self, x):
        """Return per-class scores of shape (batch_size, num_classes)."""
        # Single source of truth for the feature path: reuse extract_features
        # so the two methods cannot drift apart.
        return self.rbf_output(self.extract_features(x))

    def extract_features(self, x):
        """Return the 84-dimensional F6 activations that feed the RBF layer."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 120)
        x = self.fc(x)
        return self.squashing(x)

def compute_rbf_centers(model, dataloader, device, num_classes, input_dim):
    """Compute the per-class mean of ``model.extract_features`` over a dataset.

    The model is run in eval mode under ``torch.no_grad()``; its previous
    train/eval mode is restored before returning, so calling this function
    does not silently leave the model in eval mode.

    Args:
        model: Module exposing ``extract_features(images)`` that returns a
            (batch_size, input_dim) tensor.
        dataloader: Iterable of ``(images, labels)`` batches with integer
            class labels.
        device: Device the images are moved to before the forward pass.
        num_classes: Number of classes; labels outside ``[0, num_classes)``
            are ignored.
        input_dim: Length of each feature vector.

    Returns:
        A float32 CPU tensor of shape (num_classes, input_dim) whose row
        ``c`` is the mean feature vector of class ``c``.

    Raises:
        ValueError: If some class in ``range(num_classes)`` has no samples.
    """
    was_training = model.training
    model.eval()

    # Running per-class sums and counts replace the per-sample NumPy copies;
    # float64 accumulation keeps the mean accurate over many samples.
    sums = torch.zeros(num_classes, input_dim, dtype=torch.float64)
    counts = torch.zeros(num_classes, dtype=torch.long)

    try:
        with torch.no_grad():
            for images, labels in dataloader:
                feats = model.extract_features(images.to(device))
                feats = feats.to(device='cpu', dtype=torch.float64)
                labels = labels.to(device='cpu', dtype=torch.long).view(-1)

                keep = (labels >= 0) & (labels < num_classes)
                labels = labels[keep]
                sums.index_add_(0, labels, feats[keep])
                counts += torch.bincount(labels, minlength=num_classes)
    finally:
        model.train(was_training)

    missing = (counts == 0).nonzero().flatten().tolist()
    if missing:
        raise ValueError(f'no samples found for class(es) {missing}')

    centers = sums / counts.unsqueeze(1)
    return centers.to(torch.float32)


#### 3. Loading in the MNIST data and splitting into training and testing data sets.

In [None]:
# Training and testing the model on MNIST data:

# Define relevant variables
batch_size = 64
num_classes = 10
learning_rate = 0.001
num_epochs = 10

# Device will determine whether to run the training on GPU or CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Taken to be the mean and variance printed by the statistics cell in
# section 1, which measures pixels after ToTensor and the rescaling
# x * 1.275 - 0.1 (on the original 28x28 images, before resizing).
pixel_mean = 0.06659211218357086
pixel_variance = 0.15432125329971313
# transforms.Normalize divides by the standard deviation, so convert the
# variance instead of passing it as `std`.
pixel_std = pixel_variance ** 0.5

# One preprocessing pipeline shared by both splits. The rescaling step is the
# same one the statistics were measured with, so that subtracting pixel_mean
# and dividing by pixel_std actually standardizes the inputs.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x * 1.275 - 0.1),
    transforms.Normalize(mean=(pixel_mean,), std=(pixel_std,)),
])

# Loading the dataset and preprocessing
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=preprocess,
    download=True,
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=preprocess,
    download=False,
)

# Reshuffle the training samples every epoch so mini-batches differ between
# epochs; evaluation order is irrelevant, so the test loader stays sequential.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

print("data loaded")

#### 4. Training the model on the MNIST training set.

In [None]:
# Training the model on MNIST data:

model = LeNet5(num_classes).to(device)

# Initialize RBF centers from the per-class mean features of the untrained network.
rbf_centers = compute_rbf_centers(model, train_loader, device, num_classes, input_dim=84)
# Copy into the existing parameter in place (copy_ also handles the device
# transfer) instead of rebinding `.data`.
with torch.no_grad():
    model.rbf_output.centers.copy_(rbf_centers)

# Make sure the model is in training mode before optimizing; the center
# initialization above runs it in eval mode.
model.train()

cost = nn.CrossEntropyLoss()

# Setting the optimizer with the model parameters and learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = cost(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 400 mini-batches.
        if (i + 1) % 400 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_step}], Loss: {loss.item():.4f}')

#### 5. Testing the model on the MNIST test set produces ~97% accuracy.

In [None]:
# Testing the model on MNIST data:

model.eval()  # Switch the model to evaluation mode

# Running tallies over the whole test loader.
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        # The predicted class is the index of the largest score in each row.
        predicted = outputs.max(dim=1).indices

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f} %')