In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm

In [20]:
class SimpleNN(nn.Module):
    def __init__(self, input_size=784, hidden_size=100, num_classes=10):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = x.view(-1, 784)  #Flatten the input
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# EWC Implementation
class EWC:
    def __init__(self, model, dataloader, device='cpu'):
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {n: p.clone().detach() for n, p in self.params.items()}
        self._fisher = self._compute_fisher_information()

    def _compute_fisher_information(self):
        fisher = {n: torch.zeros_like(p) for n, p in self.params.items()}
        self.model.eval()
        for x, y in self.dataloader:
            x, y = x.to(self.device), y.to(self.device)
            self.model.zero_grad()
            output = self.model(x)
            loss = nn.functional.cross_entropy(output, y)
            loss.backward()
            for n, p in self.params.items():
                fisher[n] += p.grad ** 2 / len(self.dataloader)
        fisher = {n: p.clone().detach() for n, p in fisher.items()}
        return fisher

    def penalty(self, model):
        loss = 0
        for n, p in model.named_parameters():
            if n in self._fisher:
                _loss = self._fisher[n] * (p - self._means[n]) ** 2
                loss += _loss.sum()
        return loss

# Training function with EWC
def train(model, dataloader, optimizer, ewc=None, ewc_lambda=0):
    model.train()
    total_loss = 0
    for x, y in tqdm(dataloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = nn.functional.cross_entropy(output, y)
        if ewc:
            loss += ewc_lambda * ewc.penalty(model)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluation function
def evaluate(model, dataloader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            _, pred = output.max(1)
            correct += pred.eq(y).sum().item()
    return correct / len(dataloader.dataset)

# Dataset and model setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
fashion_mnist = datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
mnist_loader = DataLoader(mnist, batch_size=64, shuffle=True)
fashion_loader = DataLoader(fashion_mnist, batch_size=64, shuffle=True)

# Model, optimizer, and EWC setup
model = SimpleNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

def reset_model_weights(model):
    for layer in model.children():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()

# Train on Task 1 (MNIST) without EWC
print("Training on MNIST (Task 1)...")
reset_model_weights(model)  # Ensure model starts from scratch
train(model, mnist_loader, optimizer)
mnist_acc_initial = evaluate(model, mnist_loader)
print(f"Initial Accuracy on MNIST after training: {mnist_acc_initial:.2f}")

# Save the EWC data for MNIST for later use
ewc_mnist = EWC(model, mnist_loader, device=device)

# Train on Task 2 (Fashion-MNIST) without EWC
print("\nTraining on Fashion-MNIST (Task 2) without EWC...")
reset_model_weights(model)  #Reset weights
optimizer = optim.SGD(model.parameters(), lr=0.01)
train(model, fashion_loader, optimizer)  
fashion_mnist_acc_no_ewc = evaluate(model, fashion_loader)
mnist_acc_no_ewc = evaluate(model, mnist_loader)
print(f"Accuracy on Fashion-MNIST (Task 2) without EWC: {fashion_mnist_acc_no_ewc:.2f}")
print(f"Retained Accuracy on MNIST without EWC: {mnist_acc_no_ewc:.2f}")

# Train on Task 2 (Fashion-MNIST) with EWC
ewc_lambda = 0.4
print("\nTraining on Fashion-MNIST (Task 2) with EWC...")
reset_model_weights(model)  #Reset weights
optimizer = optim.SGD(model.parameters(), lr=0.01)
train(model, mnist_loader, optimizer)  #First, train on MNIST to set up EWC
train(model, fashion_loader, optimizer, ewc=ewc_mnist, ewc_lambda=ewc_lambda)  #Train with EWC
fashion_mnist_acc_with_ewc = evaluate(model, fashion_loader)
mnist_acc_with_ewc = evaluate(model, mnist_loader)
print(f"Accuracy on Fashion-MNIST with EWC: {fashion_mnist_acc_with_ewc:.2f}")
print(f"Retained Accuracy on MNIST with EWC: {mnist_acc_with_ewc:.2f}")

# Summary of Results
print("\n--- Comparison Summary ---")
print(f"MNIST Accuracy after Fashion-MNIST (without EWC): {mnist_acc_no_ewc:.2f}")
print(f"MNIST Accuracy after Fashion-MNIST (with EWC): {mnist_acc_with_ewc:.2f}")
print(f"Fashion-MNIST Accuracy (without EWC): {fashion_mnist_acc_no_ewc:.2f}")
print(f"Fashion-MNIST Accuracy (with EWC): {fashion_mnist_acc_with_ewc:.2f}")


Training on MNIST (Task 1)...


100%|██████████| 938/938 [00:02<00:00, 389.56it/s]


Initial Accuracy on MNIST after training: 0.88

Training on Fashion-MNIST (Task 2) without EWC...


100%|██████████| 938/938 [00:02<00:00, 405.91it/s]


Accuracy on Fashion-MNIST (Task 2) without EWC: 0.80
Retained Accuracy on MNIST without EWC: 0.09

Training on Fashion-MNIST (Task 2) with EWC...


100%|██████████| 938/938 [00:02<00:00, 313.48it/s]
100%|██████████| 938/938 [00:02<00:00, 333.21it/s]


Accuracy on Fashion-MNIST with EWC: 0.81
Retained Accuracy on MNIST with EWC: 0.52

--- Comparison Summary ---
MNIST Accuracy after Fashion-MNIST (without EWC): 0.09
MNIST Accuracy after Fashion-MNIST (with EWC): 0.52
Fashion-MNIST Accuracy (without EWC): 0.80
Fashion-MNIST Accuracy (with EWC): 0.81
