In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Subset
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
dataset_name = 'CIFAR10'

mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# Data normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

num_tasks = 5

# Load training and testing datasets
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
num_classes = 10
classes_per_subset = num_classes // num_tasks

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 49.9MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
def split_dataset_by_classes(dataset, classes):
    """
    Split the dataset to only include samples from the specified classes.

    Args:
        dataset: A dataset exposing a ``targets`` attribute with one integer
            label per sample (a Python list such as torchvision's CIFAR-10
            ``targets``, or a tensor).
        classes: An iterable of class indices to include (e.g., [0, 1] or (0, 1)).
    Returns:
        A Subset containing only the samples whose label is in ``classes``, in
        their original dataset order. Its ``indices`` attribute is the 1-D index
        tensor produced by ``torch.where``.
    """
    # as_tensor avoids a copy (and the copy-construct warning) when targets is already a tensor.
    targets = torch.as_tensor(dataset.targets)
    # Use the label dtype so that an empty `classes` does not become a float tensor.
    wanted = torch.as_tensor(list(classes), dtype=targets.dtype)
    # Keep the positions whose label is one of the requested classes.
    indices = torch.where(torch.isin(targets, wanted))[0]
    return Subset(dataset, indices)


# Create one train/test subset per task; each task covers `classes_per_subset`
# consecutive class ids.
classes_list = list(range(num_classes))
class_pairs = [tuple(classes_list[i:i + classes_per_subset]) for i in range(0, num_classes, classes_per_subset)]  # Class pairs (e.g., (0, 1), (2, 3), ...)
# The experiment code allocates exactly num_tasks rows per run, so a leftover
# partial group of classes must not silently become an extra task.
assert len(class_pairs) == num_tasks, f"expected {num_tasks} class groups, got {len(class_pairs)}"
train_subsets = {pair: split_dataset_by_classes(train_dataset, pair) for pair in class_pairs}
test_subsets = {pair: split_dataset_by_classes(test_dataset, pair) for pair in class_pairs}

# Example: access a specific subset
# subset_0_1 = train_subsets[(0, 1)]
# print(f"Train subset for classes (0, 1) size: {len(subset_0_1)}")

In [None]:
# ResNet18 adapted for small CIFAR images
class ResNet18ForCIFAR(nn.Module):
    """Randomly initialized torchvision ResNet-18 with a CIFAR-friendly stem and head.

    The stem convolution is replaced by a 3x3, stride-1, padding-1 convolution
    without bias, the stem max-pooling layer is dropped, and the final linear
    layer is resized to ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = resnet18(weights=None)
        # Smaller, non-downsampling first convolution for CIFAR-sized inputs.
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # No max pooling after the stem.
        backbone.maxpool = nn.Identity()
        # Classifier head sized from the backbone's feature width (512 for ResNet-18).
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        return self.model(x)

In [None]:
def train(model, train_loader, criterion, optimizer, device, epoch):
    """Train ``model`` for one epoch and record which samples it classified correctly.

    Training procedure for the resnet. The sample indexes yielded by the loader
    are used to keep track of per-sample learning speed.

    Args:
        model: network to train; it is switched to train mode here.
        train_loader: DataLoader yielding ``(indexes, (inputs, targets))`` batches,
            e.g. built from ``list(enumerate(subset))``; every index must lie in
            ``[0, len(train_loader.dataset))``.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        device: device the inputs and targets are moved to.
        epoch: epoch number, used only in the progress-bar label.

    Returns:
        ``(mean_batch_loss, accuracy_percent, is_correct)`` where ``is_correct`` is a
        ``(1, len(train_loader.dataset))`` float CPU tensor holding 1.0 at each sample
        index predicted correctly during this epoch and 0.0 elsewhere. Loss and
        accuracy are 0.0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    is_correct = torch.zeros((1, len(train_loader.dataset)))

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch} [Training]", leave=False)
    for indexes, (inputs, targets) in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        _, predicted = outputs.max(1)
        hits = predicted.eq(targets)
        total += targets.shape[0]
        correct += hits.sum().item()
        # Store correctness at each sample's dataset index (not its batch position),
        # so shuffled batches still map back to the right sample.
        is_correct[:, indexes] = hits.cpu().float()

        # Update progress bar
        progress_bar.set_postfix(loss=f"{batch_loss:.4f}", accuracy=f"{100. * correct / total:.2f}%")

    num_batches = len(train_loader)
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy, is_correct

In [None]:
"""
Test procedure of the resnet.
"""
def test(model, test_loader, criterion, device, epoch):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    # Wrap DataLoader with tqdm
    progress_bar = tqdm(test_loader, desc=f"Epoch {epoch} [Testing]", leave=False)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.shape[0]
            correct += predicted.eq(targets).sum().item()

            # Update progress bar with batch loss and accuracy
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", accuracy=f"{100. * correct / total:.2f}%")

    accuracy = 100. * correct / total
    return total_loss / len(test_loader), accuracy

In [None]:
# Number of epochs per task used in our experiments.
num_epochs = 10
# Checkpoints (model, optimizer, scheduler and the M matrix) are written every
# `save_per_epochs` epochs of each task.
save_per_epochs = 10
# run_exp_ord loads the M file saved at epoch index num_epochs - 1, which is only
# written when the final epoch is a checkpoint epoch.
assert num_epochs % save_per_epochs == 0, "num_epochs must be a multiple of save_per_epochs"

In [None]:
def run_default_exp(perm, exp_name, num_runs):
    """Execute a single experiment (one task permutation) ``num_runs`` times.

    In this case the samples are not sorted based on learning speed: each task's
    training subset is reshuffled every epoch.

    Each run creates a fresh model, optimizer and scheduler. It then trains
    sequentially on the tasks in ``perm``, ``num_epochs`` epochs per task. The
    optimizer and scheduler are carried across tasks.

    Outputs:
      * ``ls_{exp_name}_{run}_default.pth``: a ``(num_tasks, samples_per_task)``
        tensor. Row ``task`` is indexed by the task id, not by its position in
        ``perm``. It holds, per training sample, the number of epochs in which
        that sample was classified correctly.
      * Checkpoints every ``save_per_epochs`` epochs. The
        model/optimizer/scheduler file names do not contain ``exp_name`` or
        ``run``, so later runs overwrite them. The ``M`` file contains
        ``exp_name`` but not ``run``, so it keeps the last run of each
        experiment.

    Args:
        perm: sequence of task ids (indices into ``class_pairs``) in training order.
        exp_name: experiment name used in the output file names.
        num_runs: number of independent runs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ds_name = dataset_name.lower()

    for run in range(num_runs):
        model = ResNet18ForCIFAR(num_classes).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
        # NOTE(review): the scheduler is stepped once per epoch of every task, i.e.
        # len(perm) * num_epochs times per run (50 with the defaults). With
        # T_max=200 the cosine schedule therefore never completes -- confirm intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

        # One row per task id. This assumes every task subset has the same size,
        # which holds for the equal class splits built above.
        all_samples_learning_speed = torch.zeros(num_tasks, len(train_dataset) // num_tasks)

        print(exp_name, run)
        for task in perm:
            pair = class_pairs[task]
            # enumerate() tags each sample with its position in the subset, so that
            # `train` can record per-sample correctness despite shuffling.
            train_loader = torch.utils.data.DataLoader(dataset=list(enumerate(train_subsets[pair])), batch_size=128, shuffle=True, num_workers=4)
            test_loader = torch.utils.data.DataLoader(dataset=test_subsets[pair], batch_size=128, shuffle=False, num_workers=4)

            # For learning speed tracking: row `epoch` receives the per-sample
            # correctness mask returned by `train` for that epoch.
            M = torch.zeros((num_epochs, len(train_loader.dataset)))

            print(f"Training task {task}...")
            for epoch in range(num_epochs):
                train_loss, train_acc, is_correct = train(model, train_loader, criterion, optimizer, device, epoch)
                M[epoch] = is_correct
                test_loss, test_acc = test(model, test_loader, criterion, device, epoch)
                scheduler.step()

                print(f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
                if (epoch + 1) % save_per_epochs == 0:
                    torch.save(model.state_dict(), f'resnet18_{ds_name}_task{task}_epoch{epoch}.pth')
                    torch.save(optimizer.state_dict(), f'optimizer_{ds_name}_task{task}_epoch{epoch}.pth')
                    torch.save(scheduler.state_dict(), f'scheduler_{ds_name}_task{task}_epoch{epoch}.pth')
                    # run_exp_ord reads the p0 file of the final epoch as its reference ordering.
                    torch.save(M, f'M_{ds_name}_{exp_name}_task{task}_epoch{epoch}.pth')
                    print(f"Model of task {task} saved at epoch {epoch}.")
            all_samples_learning_speed[task] = M.sum(dim=0)
        torch.save(all_samples_learning_speed, f'ls_{exp_name}_{run}_default.pth')

In [None]:
# Five task orderings (permutations of the task ids), one per named experiment.
exp_permutations = [[0,1,2,3,4],[4,3,1,2,0],[1,4,3,0,2],[3,2,0,4,1],[2,0,4,1,3]]
exp_names = ['p0','p1','p2','p3','p4']
num_runs = 5

for task_order, name in zip(exp_permutations, exp_names):
    run_default_exp(perm=task_order, exp_name=name, num_runs=num_runs)

In [None]:
def run_default_average(perm, exp_names, num_runs):
    """Compute, print and return learning-speed results of the default experiments.

    For each experiment and run, this prints the average learning speed per task
    and the average over tasks. The returned list contains the average over runs
    for each experiment.

    A sample's learning speed is the number of epochs in which it was classified
    correctly. Rows of the saved tensor are indexed by task id, so row ``i`` is
    task ``i``, trained at position ``perm[k].index(i)`` of experiment ``k``.

    Args:
        perm: list of task orderings; ``perm[k]`` is the order used by ``exp_names[k]``.
        exp_names: experiment names; files ``ls_{name}_{run}_default.pth`` are read.
        num_runs: number of runs per experiment.

    Returns:
        List of floats, one per experiment: the mean over runs of the per-run
        average learning speed.
    """
    total = []
    for order, exp_name in zip(perm, exp_names):
        print(f"experiment {exp_name}")
        averages = []

        for run in range(num_runs):
            print(f"Run: {run}")
            learn_speed = torch.load(f'ls_{exp_name}_{run}_default.pth')

            # average learning speed per task (row = task id)
            task_averages = []
            for task_id, speeds in enumerate(learn_speed):
                avg = speeds.int().float().mean().item()
                print(f"Task {task_id} (learned at position {order.index(task_id)}) avg. learning speed: {avg}")
                task_averages.append(avg)

            # compute speed average for this run
            averages.append(sum(task_averages) / len(task_averages))
        for run_idx, avg in enumerate(averages):
            print(f"Run {run_idx} average: {avg}")

        total.append(sum(averages) / len(averages))
    print(total)
    # could either store final res or connect with drive for easier use
    return total


In [None]:
exp_permutations = [[0,1,2,3,4],[4,3,1,2,0],[1,4,3,0,2],[3,2,0,4,1],[2,0,4,1,3]]
exp_names = ['p0','p1','p2','p3','p4']
# Defined here as well so this analysis cell can be re-run from the saved files
# without executing the (expensive) training cell that also defines it.
num_runs = 5

default_totals = run_default_average(exp_permutations, exp_names, num_runs)

In [None]:
def run_exp_ord(perm, exp_name, desc, num_runs):
    """Execute a single experiment (one task permutation) ``num_runs`` times on ordered data.

    In this case the samples are sorted based on learning speed: in descending
    order if ``desc`` is True, otherwise in ascending order. A sample's reference
    learning speed is its number of correct epochs in the ``M`` file saved by the
    default experiment ``p0`` at epoch ``num_epochs - 1``, i.e. the last run of
    ``run_default_exp`` for ``p0``, which must have been executed first. The
    order is kept fixed in every epoch (``shuffle=False``).

    Outputs, where ``suffix1`` is ``"desc"`` or ``"asc"``:
      * ``ls_{exp_name}_{suffix1}_run{run}.pth``: a ``(num_tasks, samples_per_task)``
        tensor. Row ``task`` is indexed by the task id, not by its position in
        ``perm``. It holds the per-sample number of correct epochs, with columns
        in the sorted order.
      * Checkpoints every ``save_per_epochs`` epochs. Model/optimizer/scheduler
        files do not contain ``exp_name`` or ``run`` and are overwritten by later
        runs. The ``M`` file contains ``exp_name``, so permutations no longer
        overwrite each other.

    Args:
        perm: sequence of task ids (indices into ``class_pairs``) in training order.
        exp_name: experiment name used in the output file names.
        desc: sort by descending (True) or ascending (False) learning speed.
        num_runs: number of independent runs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ds_name = dataset_name.lower()
    # Computed once up front; it used to be defined inside the epoch loop and read after it.
    suffix1 = "desc" if desc else "asc"

    for run in range(num_runs):
        model = ResNet18ForCIFAR(num_classes).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
        # NOTE(review): stepped len(perm) * num_epochs times per run (50 with the
        # defaults), so with T_max=200 the cosine schedule never completes -- confirm intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

        # One row per task id; assumes every task subset has the same size.
        all_samples_learning_speed = torch.zeros(num_tasks, len(train_dataset) // num_tasks)

        for task in perm:
            print(f"Exp: {desc}, Run: {run}, Task: {task}.")

            pair = class_pairs[task]
            train_subset = train_subsets[pair]
            test_subset = test_subsets[pair]

            # Reference learning speed per sample (columns = positions in train_subset).
            reference_M = torch.load(f'M_{ds_name}_p0_task{task}_epoch{num_epochs-1}.pth')
            learning_speed = reference_M.sum(dim=0)
            if learning_speed.numel() != len(train_subset):
                raise ValueError(f"reference M for task {task} has {learning_speed.numel()} samples, expected {len(train_subset)}")
            # stable=True: samples with equal speed (very common, the values are small
            # integer counts) keep their subset order, making the ordering reproducible.
            _, sorted_indices = torch.sort(learning_speed, descending=desc, stable=True)

            # Map subset positions to full-dataset indices, in learning-speed order.
            ordered_indices = [int(train_subset.indices[i]) for i in sorted_indices.tolist()]
            # enumerate() tags samples with their position in the sorted order for `train`.
            train_subset_ord = list(enumerate(Subset(train_dataset, ordered_indices)))
            # shuffle=False keeps the learning-speed order within and across batches.
            train_loader = torch.utils.data.DataLoader(dataset=train_subset_ord, batch_size=128, shuffle=False, num_workers=4)
            test_loader = torch.utils.data.DataLoader(dataset=test_subset, batch_size=128, shuffle=False, num_workers=4)

            # For learning speed tracking: row `epoch` receives `train`'s correctness mask.
            M = torch.zeros((num_epochs, len(train_loader.dataset)))

            print(f"Training task {task}...")
            for epoch in range(num_epochs):
                train_loss, train_acc, is_correct = train(model, train_loader, criterion, optimizer, device, epoch)
                M[epoch] = is_correct
                test_loss, test_acc = test(model, test_loader, criterion, device, epoch)
                scheduler.step()

                print(f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
                if (epoch + 1) % save_per_epochs == 0:
                    torch.save(model.state_dict(), f'resnet18_{ds_name}_task{task}_epoch{epoch}_{suffix1}.pth')
                    torch.save(optimizer.state_dict(), f'optimizer_{ds_name}_task{task}_epoch{epoch}_{suffix1}.pth')
                    torch.save(scheduler.state_dict(), f'scheduler_{ds_name}_task{task}_epoch{epoch}_{suffix1}.pth')
                    torch.save(M, f'M_{ds_name}_{exp_name}_task{task}_epoch{epoch}_{suffix1}.pth')
                    print(f"Model of task {task} saved at epoch {epoch}.")
            all_samples_learning_speed[task] = M.sum(dim=0)
        torch.save(all_samples_learning_speed, f'ls_{exp_name}_{suffix1}_run{run}.pth')

In [None]:
# Same five task orderings as the default experiments.
exp_permutations = [[0,1,2,3,4],[4,3,1,2,0],[1,4,3,0,2],[3,2,0,4,1],[2,0,4,1,3]]
exp_names = ['p0','p1','p2','p3','p4']
num_runs = 5

# First every permutation with descending learning-speed order, then ascending.
for desc in (True, False):
    for task_order, name in zip(exp_permutations, exp_names):
        run_exp_ord(perm=task_order, exp_name=name, desc=desc, num_runs=num_runs)

In [None]:
def run_average(perm, exp_names, desc, num_runs):
    """Compute, print and return learning-speed results of the sorted experiments.

    For each experiment and run, this prints the average learning speed per task
    and the average over tasks. The returned list contains the average over runs
    for each experiment. The ``desc`` parameter selects which sorting is
    summarized: ``"desc"`` or ``"asc"``, as used in the file names.

    Rows of the saved tensor are indexed by task id, so row ``i`` is task ``i``,
    trained at position ``perm[k].index(i)`` of experiment ``k``.

    Args:
        perm: list of task orderings; ``perm[k]`` is the order used by ``exp_names[k]``.
        exp_names: experiment names; files ``ls_{name}_{desc}_run{run}.pth`` are read.
        desc: ``"desc"`` or ``"asc"``.
        num_runs: number of runs per experiment.

    Returns:
        List of floats, one per experiment: the mean over runs of the per-run
        average learning speed.
    """
    total = []
    for order, exp_name in zip(perm, exp_names):
        print(f"{desc}, experiment {exp_name}")
        averages = []

        for run in range(num_runs):
            print(f"Run: {run}")
            learn_speed = torch.load(f'ls_{exp_name}_{desc}_run{run}.pth')

            # average learning speed per task (row = task id)
            task_averages = []
            for task_id, speeds in enumerate(learn_speed):
                avg = speeds.int().float().mean().item()
                print(f"Task {task_id} (learned at position {order.index(task_id)}) avg. learning speed: {avg}")
                task_averages.append(avg)

            # compute speed average for this run
            averages.append(sum(task_averages) / len(task_averages))
        for run_idx, avg in enumerate(averages):
            print(f"Run {run_idx} average: {avg}")

        total.append(sum(averages) / len(averages))
    print(total)
    # could either store final res or connect with drive for easier use
    return total


In [None]:
exp_permutations = [[0,1,2,3,4],[4,3,1,2,0],[1,4,3,0,2],[3,2,0,4,1],[2,0,4,1,3]]
exp_names = ['p0','p1','p2','p3','p4']
# Defined here as well so this analysis cell can be re-run from the saved files
# without executing the (expensive) training cell that also defines it.
num_runs = 5
order = "desc"  # which sorted experiment to summarize: "desc" or "asc"

ordered_totals = run_average(exp_permutations, exp_names, order, num_runs)