In [None]:
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# ---- Data ---------------------------------------------------------------
input_dim = 28 * 28   # pixels in a flattened MNIST digit image
output_dim = 10       # one logit per digit class (0-9)

# ---- Network ------------------------------------------------------------
n_hidden_units = 256
activation_fn = nn.ReLU()  # activation applied after each hidden layer of SimpleMLP

# ---- Optimisation -------------------------------------------------------
batch_size = 64
epochs_per_task = 10  # full passes over each task's training split

class SimpleMLP(nn.Module):
    """Fully connected classifier with two hidden layers.

    Architecture: Linear -> activation -> Linear -> activation -> Linear.
    The Linear layers act on the last dimension of the input, so callers
    must flatten images beforehand.

    Args:
        in_features: size of each input vector; defaults to the module-level
            ``input_dim``.
        hidden_units: width of both hidden layers; defaults to the
            module-level ``n_hidden_units``.
        out_features: number of output logits; defaults to the module-level
            ``output_dim``.
        activation: activation module used as a template; defaults to the
            module-level ``activation_fn``. Each hidden layer receives its
            own deep copy, so parametrised activations (e.g. ``nn.PReLU``)
            are never silently shared between layers or between models.
    """

    def __init__(self, in_features=None, hidden_units=None, out_features=None,
                 activation=None):
        super().__init__()
        # Defaults are resolved at construction time (not at class-definition
        # time) so the module-level hyper-parameters are read exactly when the
        # original implementation read them.
        if in_features is None:
            in_features = input_dim
        if hidden_units is None:
            hidden_units = n_hidden_units
        if out_features is None:
            out_features = output_dim
        if activation is None:
            activation = activation_fn

        # Layer creation order (fc1, fc2, fc3) is kept so random weight
        # initialisation consumes the RNG in the same sequence as before.
        self.fc1 = nn.Linear(in_features, hidden_units)
        self.relu1 = copy.deepcopy(activation)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.relu2 = copy.deepcopy(activation)
        self.fc3 = nn.Linear(hidden_units, out_features)

    def forward(self, x):
        """Map ``x`` (last dim == in_features) to unnormalised class logits."""
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)

def get_split_mnist(task_labels, split='train'):
    """Build one MNIST subset per task, each restricted to that task's digits.

    Args:
        task_labels: iterable of label groups, e.g. ``[[0, 1], [2, 3]]``; one
            ``Subset`` is produced per group.
        split: ``'train'`` selects the MNIST training split; any other value
            selects the test split.

    Returns:
        List of ``torch.utils.data.Subset`` objects over a single MNIST
        dataset, in the same order as ``task_labels``.

    Note: downloads MNIST into ``./data`` if it is not already present.
    """
    mnist = datasets.MNIST(
        root='./data',
        train=(split == 'train'),
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
    )

    # Vectorised label lookup: one C-level pass over all targets per task,
    # instead of a Python loop that compares a 0-d tensor for every sample.
    targets = np.asarray(mnist.targets)
    task_datasets = []
    for labels in task_labels:
        idx = np.flatnonzero(np.isin(targets, list(labels))).tolist()
        task_datasets.append(Subset(mnist, idx))
    return task_datasets

# Split MNIST: five tasks, each covering a pair of consecutive digits.
task_labels = [[digit, digit + 1] for digit in range(0, 10, 2)]
train_datasets, test_datasets = (
    get_split_mnist(task_labels, split=which) for which in ('train', 'test')
)

# One model/optimizer pair is reused across all tasks, trained sequentially.
criterion = nn.CrossEntropyLoss()
model = SimpleMLP()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training and evaluation function
def _train_epochs(model, loader, optimizer, criterion, num_epochs):
    """Run ``num_epochs`` passes of mini-batch optimisation over ``loader``."""
    model.train()
    for _ in range(num_epochs):
        for data, target in loader:
            optimizer.zero_grad()
            # flatten(1) collapses every non-batch dim, e.g. (B,1,28,28) -> (B,784).
            loss = criterion(model(data.flatten(1)), target)
            loss.backward()
            optimizer.step()


def _evaluate(model, loader, criterion):
    """Return ``(summed loss over samples, number of correct predictions)``."""
    model.eval()
    # A 'mean'-reduced criterion (the CrossEntropyLoss default) returns a
    # per-batch average; scale it back to a per-batch sum so dividing by the
    # sample count yields a true per-sample mean, even for a short last batch.
    scale_by_batch = getattr(criterion, 'reduction', 'mean') == 'mean'
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data.flatten(1))
            batch_loss = criterion(output, target).item()
            if scale_by_batch:
                batch_loss *= target.size(0)
            total_loss += batch_loss
            correct += (output.argmax(dim=1) == target).sum().item()
    return total_loss, correct


def train_and_evaluate(train_dataset, test_dataset, model, optimizer, criterion,
                       num_epochs=None, loader_batch_size=None):
    """Train ``model`` on one task, then print and return its test performance.

    Args:
        train_dataset: dataset of ``(input, label)`` pairs used for training.
        test_dataset: dataset of ``(input, label)`` pairs used for evaluation.
        model: network mapping flattened inputs ``(batch, features)`` to logits.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: loss module called as ``criterion(logits, target)``.
        num_epochs: passes over ``train_dataset``; defaults to the module-level
            ``epochs_per_task``.
        loader_batch_size: mini-batch size for both loaders; defaults to the
            module-level ``batch_size``.

    Returns:
        ``(test_loss, accuracy)``: per-sample average test loss and accuracy
        in percent.

    Raises:
        ZeroDivisionError: if ``test_dataset`` is empty.
    """
    if num_epochs is None:
        num_epochs = epochs_per_task
    if loader_batch_size is None:
        loader_batch_size = batch_size

    # One loader reused across epochs; shuffle=True reshuffles on each pass.
    train_loader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
    _train_epochs(model, train_loader, optimizer, criterion, num_epochs)

    test_loader = DataLoader(test_dataset, batch_size=loader_batch_size, shuffle=False)
    total_loss, correct = _evaluate(model, test_loader, criterion)

    n_test = len(test_dataset)
    test_loss = total_loss / n_test
    accuracy = 100. * correct / n_test
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
    return test_loss, accuracy

# Learn the tasks one after another, reporting test performance on each task
# right after it has been trained.
for task_pair in zip(train_datasets, test_datasets):
    train_and_evaluate(*task_pair, model, optimizer, criterion)
