In [1]:
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms

from neurodivergence.models import FCNN
from neurodivergence.compare import kl_divergence_between_models, sup_norm_outputs_between_models

In [2]:
def set_seed(seed):
    """Seed every RNG this notebook touches with the same value.

    Seeds, in this order: Python's ``random`` module, NumPy's legacy global
    RNG, PyTorch's CPU RNG, and PyTorch's RNGs on all CUDA devices.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

In [3]:
def get_mnist_data(root="data", download=True):
    """Load the MNIST train and test splits.

    Both splits use ``transforms.ToTensor()`` as their only transform.

    Args:
        root: directory where torchvision stores / looks for the dataset
            files. Defaults to ``"data"``, a path relative to the current
            working directory (the original hard-coded location).
        download: forwarded to ``datasets.MNIST``; when True, torchvision
            downloads the files if they are not already under ``root``.

    Returns:
        tuple: ``(train_ds, test_ds)`` torchvision ``MNIST`` datasets.
    """
    # The same ToTensor instance serves both splits; the transform object is
    # only called per sample, so it can be shared.
    to_tensor = transforms.ToTensor()
    train_ds = datasets.MNIST(root=root, train=True, download=download, transform=to_tensor)
    test_ds = datasets.MNIST(root=root, train=False, download=download, transform=to_tensor)

    return train_ds, test_ds


In [4]:
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the average batch loss.

    Puts ``model`` in training mode. For every ``(images, labels)`` batch
    yielded by ``train_loader``, it does the following:

    * moves the batch to ``device``;
    * computes ``criterion(model(images), labels)``;
    * backpropagates;
    * takes one ``optimizer`` step.

    Args:
        model: the network being trained (updated in place).
        train_loader: iterable of ``(images, labels)`` batches with ``len()``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        float: sum of the per-batch ``loss.item()`` values divided by
        ``len(train_loader)``, i.e. the mean of the batch losses.
    """
    model.train()
    epoch_loss = 0

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()

    return epoch_loss / len(train_loader)

In [5]:
def evaluate(model, test_loader, device):
    """Return the top-1 classification accuracy of ``model`` on ``test_loader``.

    Switches ``model`` to eval mode and leaves it there; callers that continue
    training must call ``model.train()`` again, as ``train`` does. Inference
    runs under ``torch.no_grad()``. The predicted class for each sample is the
    arg-max of the model output over dim 1.

    Args:
        model: network to evaluate.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        float: percentage of correctly classified samples, in ``[0, 100]``.

    Raises:
        ValueError: if ``test_loader`` yields no samples (accuracy undefined).
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # argmax over the class dimension. Inside no_grad the legacy
            # `.data` escape hatch is unnecessary. Ties resolve to the first
            # maximal index, as with torch.max.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    return 100 * correct / total

In [6]:
# Reproducible DataLoader workers, following
# https://pytorch.org/docs/stable/notes/randomness.html#dataloader
def seed_worker(worker_id):
    """``worker_init_fn`` deriving Python/NumPy seeds from the worker's torch seed.

    ``worker_id`` is unused: each worker's ``torch.initial_seed()`` already
    differs between workers.
    """
    derived_seed = torch.initial_seed() % 2**32
    np.random.seed(derived_seed)
    random.seed(derived_seed)


train_ds, test_ds = get_mnist_data()

# Dedicated generator that fixes the shuffle order of `train_loader`.
g = torch.Generator()
g.manual_seed(1)

train_loader = DataLoader(
    train_ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=g,
)

test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

In [7]:
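# Global seeding is intentionally left disabled. Model initialization is seeded
# via FCNN(seed=...) (presumably), and the training shuffle order via generator `g`.
# set_seed(42)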

In [8]:
# Train the first model, FCNN(seed=42).

# Set device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the neural network, criterion and optimizer
model = FCNN(seed=42).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# `train_loader` shuffles with the shared generator `g`. Re-seed it here so
# the batch order does not depend on how often the loader was iterated before
# (e.g. when this cell is re-run). The model2 cell re-seeds `g` to the same
# value, so both models train on identically ordered batches.
g.manual_seed(1)

# Train for `num_epochs` epochs, reporting mean batch loss and test accuracy.
num_epochs = 2
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(model, test_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")



Epoch 1/2, Loss: 0.5317, Test Accuracy: 93.61%
Epoch 2/2, Loss: 0.1969, Test Accuracy: 95.33%


In [9]:
# Train the second model: same shuffle seed as the first, different
# initialization seed, FCNN(seed=43).
model2 = FCNN(seed=43).to(device)

# Re-seed the shared generator so model2 sees the same shuffle order as the
# first model. Then rebuild the loader with the same settings as in the
# data-loading cell.
g.manual_seed(1)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2, worker_init_fn=seed_worker, generator=g)

# A separate optimizer, so the first model's `optimizer` (and its momentum
# state) is not overwritten. `criterion` is shared with the first model.
optimizer2 = optim.SGD(model2.parameters(), lr=0.01, momentum=0.9)

# Train for `num_epochs` epochs, reporting mean batch loss and test accuracy.
num_epochs = 2
for epoch in range(num_epochs):
    train_loss = train(model2, train_loader, criterion, optimizer2, device)
    test_accuracy = evaluate(model2, test_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

Epoch 1/2, Loss: 0.5482, Test Accuracy: 93.15%
Epoch 2/2, Loss: 0.1999, Test Accuracy: 95.23%


In [10]:
# KL divergence (kl_divergence_between_models) between model and model2 on the test set.
# Reference values from earlier runs:
#   -2.2126533139044115e-11  identical NN initializations (zero up to floating-point error)
#    0.017387873356044293    different NN initializations
kl_test = kl_divergence_between_models(model, model2, test_loader, device)
kl_test


0.0176932356223464

In [16]:
# Same KL comparison on the training set. If the helper iterates `train_loader`
# (presumably), this draws from and advances the shared shuffle generator `g`.
kl_train = kl_divergence_between_models(model, model2, train_loader, device)
kl_train

0.016638632949690025

In [11]:
# Sup-norm distance between the outputs of model and model2 on the test set,
# as computed by sup_norm_outputs_between_models.
sup_norm_test = sup_norm_outputs_between_models(model, model2, test_loader, device)
sup_norm_test

0.5211323499679565

In [15]:
# Same sup-norm comparison on the training set.
sup_norm_train = sup_norm_outputs_between_models(model, model2, train_loader, device)
sup_norm_train

0.6663831472396851

In [12]:
# Save the model
# model_path = 'mnist_fcnn.pth'
# torch.save(model.state_dict(), model_path)


In [13]:
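# Load the model
# (map_location lets a checkpoint saved on GPU load on a CPU-only machine;
#  weights_only=True avoids unpickling arbitrary objects from the file.)
# model_path = 'mnist_fcnn.pth'
# loaded_model = FCNN()
# loaded_model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# loaded_model.to(device)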
