In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Data Preparation
def construct_split_mnist(task_labels, split='train'):
    """Split MNIST into one ``Subset`` per task, selected by class label.

    Args:
        task_labels: Iterable of label collections. Each inner collection
            lists the digit classes belonging to one task, e.g.
            ``[[0, 1], [2, 3]]``.
        split: ``'train'`` selects the training set; any other value selects
            the test set.

    Returns:
        A list of ``Subset`` objects, one per entry of ``task_labels`` and in
        the same order, all backed by one shared MNIST dataset instance.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Normalize(mean, std) for the single image channel
        # (presumably MNIST's dataset statistics).
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST(root='./data', train=(split == 'train'), download=True, transform=transform)

    # Read labels directly from the dataset's target tensor. Enumerating the
    # dataset itself would decode and transform every image once per task
    # only to inspect its label.
    all_labels = dataset.targets.tolist()

    task_datasets = []
    for labels in task_labels:
        # int() lets the inner collection hold plain ints or 0-dim tensors.
        wanted = {int(label) for label in labels}
        idxs = [i for i, label in enumerate(all_labels) if label in wanted]
        task_datasets.append(Subset(dataset, idxs))
    return task_datasets

In [None]:

# Neural Network Definition
class SimpleNN(nn.Module):
    """Fully connected classifier for inputs with 28*28 features per sample.

    Each sample is flattened to 784 features and passed through
    ``Linear(784, 256) -> ReLU -> Linear(256, 256) -> ReLU ->
    Linear(256, output_dim)``.

    Args:
        output_dim: Number of logits produced per sample.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        """Flatten ``x`` to ``(N, 784)`` and return logits of shape ``(N, output_dim)``.

        Args:
            x: Tensor whose total element count is a multiple of 784.

        Returns:
            Raw (un-normalised) logits, one row per flattened sample.
        """
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. transposed images), while reshape copies only when needed.
        x = x.reshape(-1, 28*28)
        return self.model(x)

In [None]:


# Masked Softmax
def masked_softmax(logits, mask):
    """Softmax over the last dimension, restricted to positions enabled by ``mask``.

    Args:
        logits: Tensor of scores.
        mask: Tensor broadcastable against ``logits``; non-zero / True entries
            mark positions that stay in the softmax.

    Returns:
        Probabilities along the last dimension. Disabled positions have their
        logit replaced by -1e32 before normalisation.
    """
    keep = mask.to(torch.bool)
    # A large finite negative stands in for the logit at disabled positions.
    suppressed = torch.tensor(-1e32, device=logits.device)
    return torch.where(keep, logits, suppressed).softmax(dim=-1)

# Training and Testing Loops
def train_and_evaluate(model, task_datasets, task_labels, device,
                       num_epochs=10, batch_size=64, lr=1e-3):
    """Train ``model`` on each task in sequence and report per-task test accuracy.

    A single Adam optimizer is shared across all tasks. For every task the
    model is trained for ``num_epochs`` epochs on the task's training set and
    then evaluated on the task's test set; progress is printed to stdout.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        task_datasets: Iterable of ``(train_dataset, test_dataset)`` pairs,
            one per task, in training order.
        task_labels: Class labels of each task. Accepted for interface
            compatibility; not used by the current implementation.
        device: Device that input and target batches are moved to.
        num_epochs: Training epochs per task.
        batch_size: Batch size for both the training and the test loader.
        lr: Adam learning rate.

    Returns:
        None. The model is left in eval mode after the last task.
    """
    # NOTE(review): the previous version built a per-task output mask from
    # task_labels but never applied it to the logits. Applying it would change
    # the loss and the reported accuracies, so it is left out until the
    # intended behaviour is confirmed.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for task_idx, (train_dataset, test_dataset) in enumerate(task_datasets):
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # Accuracy does not depend on sample order, so evaluation is not shuffled.
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Evaluation below switches to eval mode; switch back before training
        # so every task after the first is also trained in train mode.
        model.train()
        for epoch in range(num_epochs):
            total_loss = 0
            correct = 0
            total = 0
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                correct += predicted.eq(targets).sum().item()
                total += targets.size(0)
            # Report NaN instead of raising ZeroDivisionError on an empty loader.
            num_batches = len(train_loader)
            mean_loss = total_loss / num_batches if num_batches else float('nan')
            train_acc = 100. * correct / total if total else float('nan')
            print(f'Task {task_idx}, Epoch {epoch}, Loss: {mean_loss:.4f}, Train Acc: {train_acc:.2f}%')

        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                predicted = outputs.argmax(dim=1)
                correct += predicted.eq(targets).sum().item()
                total += targets.size(0)
            test_acc = 100. * correct / total if total else float('nan')
            print(f'Task {task_idx}, Test Acc: {test_acc:.2f}%')

In [None]:

# Setup
# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN(output_dim=10).to(device)

task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9]]
# construct_split_mnist returns one dataset per entry of task_labels, so a
# single call per split is enough; calling it once per task reloads MNIST
# each time.
training_datasets = construct_split_mnist(task_labels, 'train')
testing_datasets = construct_split_mnist(task_labels, 'test')
task_datasets = list(zip(training_datasets, testing_datasets))

train_and_evaluate(model, task_datasets, task_labels, device)