### LeNet5 1998 Implementation

#### 1. Calculating the mean and variance of the MNIST data.

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
from PIL import Image
from data import splits, df_train, df_test

# Convert 'image' column to NumPy arrays by extracting pixel data
df_train['image'] = df_train['image'].apply(lambda x: np.array(x['bytes']) if isinstance(x, dict) and 'bytes' in x else np.array(x))
df_test['image'] = df_test['image'].apply(lambda x: np.array(x['bytes']) if isinstance(x, dict) and 'bytes' in x else np.array(x))

# Create a custom dataset class to handle this data
class CustomMNISTDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Get the image data (assuming it's stored in columns 'image')
        image = self.dataframe.iloc[idx]['image']
        label = self.dataframe.iloc[idx]['label']
        
        # Convert the image (which is assumed to be a 28x28 numpy array) to a PIL image
        image = Image.fromarray(image.astype(np.uint8), mode='L')
        
        # Apply transformations if any
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Define the transformation (normalize according to the given formula)
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert to tensor (scales to [0, 1])
    transforms.Lambda(lambda x: x * 1.275 - 0.1)  # Normalize as per the given formula
])

# Create dataset and dataloaders
train_dataset = CustomMNISTDataset(dataframe=df_train, transform=transform)
test_dataset = CustomMNISTDataset(dataframe=df_test, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)


all_pixels = []

# Use the train_loader (not digit_loader) since this is the dataset you trained on
for images, _ in train_loader:
    pixels = images.view(images.size(0), -1)  # Flatten images
    all_pixels.append(pixels)

all_pixels = torch.cat(all_pixels, dim=0)  # (N, 28*28) or (N, 32*32) if resized

# Compute mean and std
mean = all_pixels.mean()
std = all_pixels.std()

print(f"Recalculated Mean (train_loader): {mean.item():.6f}")
print(f"Recalculated Std (train_loader): {std.item():.6f}")


(28, 28)
Recalculated Mean (train_loader): 0.066592
Recalculated Std (train_loader): 0.392837


#### 2. Saving the MNIST data to data folder.

In [None]:
import os
from PIL import Image

# Create the data directory and subdirectories
os.makedirs("./data/train", exist_ok=True)
os.makedirs("./data/test", exist_ok=True)

# Save training images and labels
with open("./data/train_label.txt", "w") as train_label_file:
    for idx, row in df_train.iterrows():
        # Save the image
        image = Image.fromarray(row['image'].astype('uint8'))
        image.save(f"./data/train/{idx}.png")
        
        # Save the label
        train_label_file.write(f"{row['label']}\n")

# Save testing images and labels
with open("./data/test_label.txt", "w") as test_label_file:
    for idx, row in df_test.iterrows():
        # Save the image
        image = Image.fromarray(row['image'].astype('uint8')) 
        image.save(f"./data/test/{idx}.png")
        
        # Save the label
        test_label_file.write(f"{row['label']}\n")

#### 3. Building the LeNet5 1998 model.

In [146]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import random
import numpy as np
from collections import defaultdict

random.seed(42)

class LeNet5_S2Layer(nn.Module):
    def __init__(self, num_channels):
        super(LeNet5_S2Layer, self).__init__()
        self.coefficient = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        pooled = nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        pooled = pooled * self.coefficient.view(1, -1, 1, 1)
        pooled = pooled + self.bias.view(1, -1, 1, 1)
        return torch.sigmoid(pooled)

class ScaledTanh(nn.Module):
    def forward(self, x):
        return 1.7159 * torch.tanh(x * 2 / 3)

class SquashingFunction(nn.Module):
    def __init__(self, A=1.7159, S=2/3):
        super(SquashingFunction, self).__init__()
        self.A = A
        self.S = S

    def forward(self, x):
        x = self.S * x
        x = torch.clamp(x, min=-0.999, max=0.999)  # avoid NaNs from atanh
        return self.A * torch.atanh(x)

class C3PartialConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, connection_table):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.connection_table = connection_table

        # All connections are initialized, then masked
        self.weight = nn.Parameter(torch.zeros(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels))
        self.mask = torch.zeros_like(self.weight)

        # Build the binary mask
        for out_c, in_list in enumerate(connection_table):
            for in_c in in_list:
                self.mask[out_c, in_c, :, :] = 1.0

        # Initialize only the allowed weights
        self.reset_parameters()

    def reset_parameters(self):
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        nn.init.uniform_(self.weight, -2.4 / fan_in, 2.4 / fan_in)
        self.bias.data.fill_(2.4 / fan_in)

    def forward(self, x):
        # Apply mask before convolution to zero-out unwanted connections
        masked_weight = self.weight * self.mask.to(self.weight.device)
        return F.conv2d(x, masked_weight, self.bias, stride=1)

class DigitDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Loop through all class labels (0-9)
        for label in range(10):
            class_dir = os.path.join(image_dir, str(label))
            for image_name in os.listdir(class_dir):
                image_path = os.path.join(class_dir, image_name)
                self.image_paths.append(image_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load image
        image = Image.open(img_path).convert('L')  # Convert to grayscale ('L' mode)

        # Apply transformation (resize, tensor conversion, normalization)
        if self.transform:
            image = self.transform(image)

        return image, label
        
class EuclideanRBFOutput(nn.Module):
    def __init__(self, num_classes=10, input_dim=84):
        super(EuclideanRBFOutput, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, input_dim)) 

    def forward(self, x):
        x = x.unsqueeze(1)
        centers = self.centers.unsqueeze(0)
        distances = torch.sum((x - centers) ** 2, dim=2)
        return -distances

class LeNet5(nn.Module):
    def __init__(self, num_classes):
        super(LeNet5, self).__init__()

        # C1 and S2
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
            ScaledTanh(),
            LeNet5_S2Layer(6)
        )

        # C3 and S4
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            ScaledTanh(),
            LeNet5_S2Layer(16)
        )

        # C5
        self.layer3 = C3PartialConv(in_channels=16, out_channels=120, kernel_size=5, connection_table=[random.sample(range(16), 5) for _ in range(120)])

        # F6
        self.fc = nn.Linear(120, 84)
        self.squashing = SquashingFunction(A=1.0, S=1.0)

        # RBF Output layer
        self.rbf_output = EuclideanRBFOutput(num_classes=num_classes, input_dim=84)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.squashing(out)
        out = self.rbf_output(out)
        return out

    def extract_features(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.squashing(x)
        return x

def extract_features_from_digit_dataset(model, digit_loader, device):
    model.eval()
    features_by_class = defaultdict(list)
    
    with torch.no_grad():
        for images, labels in digit_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            feats = model.extract_features(images)
            
            for f, label in zip(feats, labels):
                features_by_class[label.item()].append(f.cpu().numpy())
    
    return features_by_class

# computing RBF parameters by finding the mean of the features for each class
def compute_rbf_centers(features_by_class, num_classes, input_dim):
    centers = np.zeros((num_classes, input_dim), dtype=np.float32)
    
    for cls in range(num_classes):
        class_feats = np.stack(features_by_class[cls])
        centers[cls] = np.mean(class_feats, axis=0)
    
    return torch.tensor(centers, dtype=torch.float32)
    
class SafeTensorResize:
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # If it's a PIL Image, convert it to tensor
        if isinstance(img, Image.Image):
            img = transforms.ToTensor()(img)

        # Resize
        return F.interpolate(img.unsqueeze(0), size=self.size, mode='bilinear', align_corners=False).squeeze(0)

# Final unified transform
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale (1 channel)
    SafeTensorResize((32, 32)),
    transforms.Normalize(mean=(0.066592,), std=(0.392837,))
])

class CustomLeNetLoss(nn.Module):
    def __init__(self, j=0.1):
        super(CustomLeNetLoss, self).__init__()
        self.j = j  # Scaling factor for incorrect classes

    def forward(self, logits, labels):
        # Apply softmax to logits to get probabilities
        probs = F.softmax(logits, dim=1)
        
        num_classes = logits.size(1)

        # Get the log probabilities (log(p))
        log_probs = F.log_softmax(logits, dim=1)
        
        # Extract the probabilities for the correct class
        correct_class_log_probs = log_probs.gather(1, labels.view(-1, 1))
        
        # Loss for the correct class (negative log probability)
        correct_class_loss = -correct_class_log_probs
        
        # Loss for the incorrect classes (scaled by j)
        incorrect_class_log_probs = log_probs.clone()
        incorrect_class_log_probs.scatter_(1, labels.view(-1, 1), 0)
        
        # Apply the scaling factor j to the incorrect class log probabilities
        incorrect_class_loss = self.j * torch.sum(torch.exp(incorrect_class_log_probs) * log_probs, dim=1, keepdim=True)
        
        # Combine the losses: correct class loss + incorrect class loss
        loss = correct_class_loss + incorrect_class_loss
        
        return loss.mean()

#### 4. Loading in the MNIST data to split into training and testing data sets, and loading in the Digits data to calculate the RBF parameters.

In [140]:
# Training and testing the model on MNIST data:

from mnist import MNIST

# Define relevant variables
batch_size = 1
num_classes = 10
learning_rate = 0.001
num_epochs = 30

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load training and testing datasets
train_dataset = MNIST(split="train", transform=transform)
test_dataset = MNIST(split="test", transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("data loaded")

data loaded


In [147]:
model = LeNet5(num_classes).to(device)

image_dir = './digits updated/'

digit_dataset = DigitDataset(image_dir=image_dir, transform=transform)
digit_loader = DataLoader(digit_dataset, batch_size=32, shuffle=False)

features_by_class = extract_features_from_digit_dataset(model, digit_loader, device)

# Compute RBF centers using the features from DIGIT
rbf_centers = compute_rbf_centers(features_by_class, num_classes=10, input_dim=84)

# Set the computed centers into the model
model.rbf_output.centers.data = rbf_centers.to(device)

#### 5. Training the model on MNIST data subset.

In [109]:
# Training the model on MNIST data:

features_by_class = extract_features_from_digit_dataset(model, digit_loader, device)

# Initialize RBF centers from class means
rbf_centers = compute_rbf_centers(features_by_class, num_classes=10, input_dim=84)
model.rbf_output.centers.data = rbf_centers.to(device)

cost = CustomLeNetLoss(j=0.1)

# Setting the optimizer with the model parameters and learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_step = len(train_loader)

for epoch in range(num_epochs):
    model.train()
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = cost(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Tracking accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (i+1) % len(train_loader) == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    accuracy = 100 * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Accuracy: {accuracy:.2f}%')

Epoch [1/30], Step [157/157], Loss: 0.7502
Epoch [1/30], Training Accuracy: 29.25%
Epoch [2/30], Step [157/157], Loss: 0.6044
Epoch [2/30], Training Accuracy: 75.56%
Epoch [3/30], Step [157/157], Loss: 0.0930
Epoch [3/30], Training Accuracy: 84.66%
Epoch [4/30], Step [157/157], Loss: 0.7556
Epoch [4/30], Training Accuracy: 88.05%
Epoch [5/30], Step [157/157], Loss: 0.0198
Epoch [5/30], Training Accuracy: 90.83%
Epoch [6/30], Step [157/157], Loss: 0.4451
Epoch [6/30], Training Accuracy: 91.40%
Epoch [7/30], Step [157/157], Loss: 0.1116
Epoch [7/30], Training Accuracy: 92.93%
Epoch [8/30], Step [157/157], Loss: 0.0274
Epoch [8/30], Training Accuracy: 93.42%
Epoch [9/30], Step [157/157], Loss: 0.2799
Epoch [9/30], Training Accuracy: 94.18%
Epoch [10/30], Step [157/157], Loss: 0.0498
Epoch [10/30], Training Accuracy: 94.69%
Epoch [11/30], Step [157/157], Loss: 0.0962
Epoch [11/30], Training Accuracy: 94.78%
Epoch [12/30], Step [157/157], Loss: 0.1251
Epoch [12/30], Training Accuracy: 95.03

#### 6. Testing the model on MNIST data subset produces ~97% accuracy.

In [111]:
# Testing the model on MNIST data:
import importlib
import mnist
importlib.reload(mnist)
from mnist import MNIST

model.eval()  # Set the model to evaluation mode

with torch.no_grad():
    correct = 0
    total = 0
    idx = 0
    for images, labels in test_loader:
        idx += 1
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f'Epoch [{idx}/{num_epochs}], Testing Accuracy: {accuracy:.2f}%')

Epoch [1/30], Testing Accuracy: 96.88%
Epoch [2/30], Testing Accuracy: 96.09%
Epoch [3/30], Testing Accuracy: 97.40%
Epoch [4/30], Testing Accuracy: 97.27%
Epoch [5/30], Testing Accuracy: 97.50%
Epoch [6/30], Testing Accuracy: 96.88%
Epoch [7/30], Testing Accuracy: 96.88%
Epoch [8/30], Testing Accuracy: 97.07%
Epoch [9/30], Testing Accuracy: 96.53%
Epoch [10/30], Testing Accuracy: 96.25%
Epoch [11/30], Testing Accuracy: 96.02%
Epoch [12/30], Testing Accuracy: 95.70%
Epoch [13/30], Testing Accuracy: 95.91%
Epoch [14/30], Testing Accuracy: 95.87%
Epoch [15/30], Testing Accuracy: 95.62%
Epoch [16/30], Testing Accuracy: 95.51%
Epoch [17/30], Testing Accuracy: 95.59%
Epoch [18/30], Testing Accuracy: 95.40%
Epoch [19/30], Testing Accuracy: 95.31%
Epoch [20/30], Testing Accuracy: 95.00%
Epoch [21/30], Testing Accuracy: 95.01%
Epoch [22/30], Testing Accuracy: 95.10%
Epoch [23/30], Testing Accuracy: 95.11%
Epoch [24/30], Testing Accuracy: 95.05%
Epoch [25/30], Testing Accuracy: 95.00%
Epoch [26