# Continual Learning Analysis - Dataset Agnostic

Compare a baseline model with a context-aware PSP model across configurable datasets and task transforms.

Features:
- Works with torchvision image datasets (MNIST, FashionMNIST, CIFAR-10, CIFAR-100)
- Supports multiple task transforms (Permutation, Rotation, Class-Incremental)
- Trains both baseline and context-aware models
- Visualizes accuracy curves and hidden representations via PCA


## Setup: Path Configuration

In [1]:
import os
import sys

repo_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
src_path = os.path.join(repo_root, "src")

if src_path not in sys.path:
    sys.path.insert(0, src_path)

print("Added to sys.path:", src_path)

Added to sys.path: /Users/lorenzleisner/Desktop/CogSci/Master/WI_SE_25/hands-on-neuroai/src


## Import Required Libraries

In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

from hands_on_neuroai.data.datasets import (
    DatasetConfig, get_image_shape, get_num_classes,
    build_task_datasets, build_dataloaders,
    PermutePixels, Rotate, ComposeTaskTransforms
)
from hands_on_neuroai.training.continual_learning import (
    train_model_on_continual_learning_tasks
)
from hands_on_neuroai.training.metrics import collect_hidden_activations
from hands_on_neuroai.models.factory import build_model_for_continual_learning

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cpu


## Configuration: Dataset, Tasks, and Model Parameters

Customize these settings to experiment with different datasets, task transforms, and models.

In [3]:
# ============ Dataset & Data Configuration ============
dataset_name = "mnist"  # Options: "mnist", "cifar10", "cifar100", "fashionmnist"
data_root = "data"

# ============ Task Configuration ============
num_tasks = 5
task_transform_type = "permutation"  # Options: "permutation", "rotation", "class_incremental"

# For permutation and rotation tasks
steps_per_task = 100
batch_size = 128

# ============ Model Configuration ============
hidden_dim = 128
context_type = "binary"  # Options: "none", "binary", "complex", "rotation"

# ============ Training Configuration ============
eval_interval = steps_per_task // 5
learning_rate = 1e-3
base_seed = 0

print(f"Dataset: {dataset_name}")
print(f"Task Transform: {task_transform_type}")
print(f"Num Tasks: {num_tasks}")
print(f"Model Context: {context_type}")
print(f"Device: {device}")

Dataset: mnist
Task Transform: permutation
Num Tasks: 5
Model Context: binary
Device: cpu


## Step 1: Build Data Loaders

Create data loaders with the specified dataset and task transforms.

In [None]:
# Create dataset config
config = DatasetConfig(root=data_root, download=False)

# Create task transforms based on task_transform_type
if task_transform_type == "permutation":
    task_transforms = [
        PermutePixels(seed=base_seed + i) 
        for i in range(num_tasks)
    ]
elif task_transform_type == "rotation":
    # Rotation angles from 0 to 180 degrees spread across tasks
    # max(..., 1) avoids a division by zero when num_tasks == 1
    angles = [int(180 * i / max(num_tasks - 1, 1)) for i in range(num_tasks)]
    task_transforms = [Rotate(degrees=angles[i]) for i in range(num_tasks)]
elif task_transform_type == "class_incremental":
    # For class-incremental, we use a special builder
    from hands_on_neuroai.data.datasets import build_class_incremental_tasks
    output_dim = get_num_classes(dataset_name)
    # Integer division: any remainder classes are left out of the splits
    classes_per_task = output_dim // num_tasks
    
    train_dsets, test_dsets = build_class_incremental_tasks(
        dataset_name=dataset_name,
        config=config,
        class_splits=[
            list(range(i * classes_per_task, (i + 1) * classes_per_task))
            for i in range(num_tasks)
        ]
    )
    train_loaders = build_dataloaders(train_dsets, config, batch_size=batch_size, shuffle=True)
    test_loaders = build_dataloaders(test_dsets, config, batch_size=batch_size, shuffle=False)
    task_transforms = None  # Not needed for class-incremental
else:
    raise ValueError(f"Unknown task_transform_type: {task_transform_type}")

# For permutation and rotation, build task datasets with transforms
if task_transforms is not None:
    train_dsets, test_dsets = build_task_datasets(
        dataset_name=dataset_name,
        config=config,
        task_transforms=task_transforms
    )
    train_loaders = build_dataloaders(train_dsets, config, batch_size=batch_size, shuffle=True)
    test_loaders = build_dataloaders(test_dsets, config, batch_size=batch_size, shuffle=False)

print(f"Train loaders: {len(train_loaders)}, Test loaders: {len(test_loaders)}")
print(f"Task transforms: {task_transform_type}")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:09<00:00, 1043936.12it/s]



Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 217083.22it/s]



Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1139267.68it/s]



Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1316917.51it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Train loaders: 5, Test loaders: 5
Task transforms: permutation
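
The permutation tasks rely on each task applying one fixed, seeded shuffle of the pixel positions. The sketch below illustrates that idea with a hypothetical `FixedPixelPermutation` class; it is not the `PermutePixels` implementation from `hands_on_neuroai`, whose internals are not shown in this notebook.

```python
import torch


class FixedPixelPermutation:
    """Hypothetical stand-in for a seeded pixel-permutation transform."""

    def __init__(self, num_pixels: int, seed: int) -> None:
        # A dedicated generator keeps the permutation reproducible per seed.
        generator = torch.Generator().manual_seed(seed)
        self.perm = torch.randperm(num_pixels, generator=generator)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        # Flatten, reorder the pixels, then restore the original shape.
        flat = image.reshape(-1)
        return flat[self.perm].reshape(image.shape)


# Toy image whose pixel values equal their original positions.
demo_image = torch.arange(28 * 28, dtype=torch.float32).reshape(1, 28, 28)
task_a = FixedPixelPermutation(num_pixels=28 * 28, seed=0)
task_b = FixedPixelPermutation(num_pixels=28 * 28, seed=1)

# The same task always applies the same shuffle.
same_task_equal = torch.equal(task_a(demo_image), task_a(demo_image))
# Different seeds give different shuffles.
different_tasks_equal = torch.equal(task_a(demo_image), task_b(demo_image))
# A permutation only moves pixels; the multiset of values is unchanged.
pixels_preserved = torch.equal(
    task_a(demo_image).reshape(-1).sort().values, demo_image.reshape(-1)
)
```

Because every task sees the same pixel values in a different arrangement, the tasks are equally hard for a fully connected network but require different input weights.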





## Step 2: Initialize Baseline and Context-Aware Models

In [9]:
# Get dataset-specific dimensions
x, y = get_image_shape(dataset_name)
input_dim = x * y
output_dim = get_num_classes(dataset_name)

print(f"Input dimension: {input_dim}, Output dimension: {output_dim}")

# Build baseline model (no task awareness)
baseline = build_model_for_continual_learning(
    context_type="none",
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    num_tasks=num_tasks,
    base_seed=base_seed,
    device=device,
)

# Build context-aware model (with specified context type)
context_model = build_model_for_continual_learning(
    context_type=context_type,
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    num_tasks=num_tasks,
    base_seed=base_seed,
    device=device,
)

print("Baseline model created")
print(f"Context model ({context_type}) created")

Input dimension: 784, Output dimension: 10
Baseline model created
Context model (binary) created


## Step 3: Train Models on Sequential Tasks

This cell trains both models sequentially across all tasks, recording accuracy on task 0 throughout training.
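
As a rough illustration of what sequential training with periodic evaluation on task 0 involves, here is a minimal sketch on synthetic data. It is not the code of `train_model_on_continual_learning_tasks`; the helper names (`make_task`, `accuracy`, `train_sequentially`), the toy tasks, and the reuse of the training data for evaluation are all assumptions made for the example.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def make_task(seed, n=256, dim=16):
    """Synthetic stand-in for one task: random inputs, task-specific labels."""
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(n, dim, generator=g)
    w = torch.randn(dim, generator=g)
    return TensorDataset(x, (x @ w > 0).long())


@torch.no_grad()
def accuracy(model, loader):
    """Fraction of correctly classified samples in `loader`."""
    model.eval()
    correct = total = 0
    for x, y in loader:
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.numel()
    model.train()
    return correct / total


def train_sequentially(model, train_loaders, test_loaders,
                       steps_per_task, eval_interval, lr):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    steps, task0_acc, global_step = [], [], 0
    for loader in train_loaders:  # tasks are visited one after another
        data_iter = iter(loader)
        for _ in range(steps_per_task):
            try:
                x, y = next(data_iter)
            except StopIteration:  # loader exhausted: start a new epoch
                data_iter = iter(loader)
                x, y = next(data_iter)
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
            global_step += 1
            if global_step % eval_interval == 0:
                # Always evaluate on the first task to track forgetting.
                steps.append(global_step)
                task0_acc.append(accuracy(model, test_loaders[0]))
    return steps, task0_acc


torch.manual_seed(0)
demo_datasets = [make_task(seed) for seed in range(3)]
demo_train = [DataLoader(d, batch_size=32, shuffle=True) for d in demo_datasets]
demo_test = [DataLoader(d, batch_size=64) for d in demo_datasets]  # simplification: same data
demo_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
demo_steps, demo_task0_acc = train_sequentially(
    demo_model, demo_train, demo_test, steps_per_task=20, eval_interval=10, lr=1e-3
)
```

The returned lists play the same role as `base_steps` and `base_acc` in the cell that follows: one global step and one task-0 accuracy per evaluation point.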

In [None]:
print("Training baseline model...")
base_steps, base_acc = train_model_on_continual_learning_tasks(
    model=baseline,
    train_loaders=train_loaders,
    test_loaders=test_loaders,
    num_tasks=num_tasks,
    steps_per_task=steps_per_task,
    lr=learning_rate,
    eval_interval=eval_interval,
    device=device,
    verbose=1
)

print("\nTraining context-aware model...")
ctx_steps, ctx_acc = train_model_on_continual_learning_tasks(
    model=context_model,
    train_loaders=train_loaders,
    test_loaders=test_loaders,
    num_tasks=num_tasks,
    steps_per_task=steps_per_task,
    lr=learning_rate,
    eval_interval=eval_interval,
    device=device,
    verbose=1
)

print("\nTraining complete!")

Training baseline model...


Tasks:   0%|          | 0/5 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
               ^^^^^^^^^^Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
^^^^^^^^^^^^^^^^
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    exitcode = _main

Unexpected exception formatting exception. Falling back to standard exception


libc++abi: terminating with uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating with uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating with uncaught exception of type std::__1::system_error: Broken pipe
Traceback (most recent call last):
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/5n/lggt_xf57s78jjmq538wdr6w0000gn/T/ipykernel_69040/1808906383.py", line 2, in <module>
    base_steps, base_acc = train_model_on_continual_learning_tasks(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/lorenzleisner/Desktop/CogSci/Master/WI_SE_25/hands-on-neuroai/src/hands_on_neuroai/training/continual_learning.py", line 118, in train_model_on_continual_learning_tasks
    imgs, labels = next(data_iter)
                   ^^^^^^^^^^^^^

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/asyncio/base_events.py", line 608, in run_forever
    self._run_once()
  File "/Users/lorenzleisner/opt/anaconda3/envs/anicog/lib/python3.11/asyn
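
The traceback above stops at `next(data_iter)` while child processes started through `multiprocessing/spawn.py` report `Broken pipe`. One plausible cause, stated here as an assumption because the internals of `build_dataloaders` are not shown, is that the data loaders use worker processes that fail to start from within a notebook on macOS. A common way to test this hypothesis is to load data in the main process with `num_workers=0`, as in this sketch on random tensors:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random stand-in data shaped like MNIST images and labels.
demo_dataset = TensorDataset(
    torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,))
)

# num_workers=0 loads batches in the main process, so no worker
# subprocesses are spawned.
demo_loader = DataLoader(demo_dataset, batch_size=16, shuffle=True, num_workers=0)

demo_imgs, demo_labels = next(iter(demo_loader))
```

If the error disappears with single-process loading, the worker setup is the likely culprit; whether `build_dataloaders` exposes such a setting would need to be checked in the repository.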


## Step 4: Visualize Task-Specific Accuracy

Compare how both models maintain accuracy on the first task as they learn new tasks.

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(base_steps, base_acc, marker='o', label="Baseline (no context)", linewidth=2)
plt.plot(ctx_steps, ctx_acc, marker='s', label=f"Context-Aware ({context_type})", linewidth=2)
plt.xlabel("Global Step", fontsize=12)
plt.ylabel("Accuracy on Task 0", fontsize=12)
plt.title(f"{dataset_name.upper()} - {task_transform_type} Tasks - Accuracy Evolution", fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
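
Beyond inspecting the curves visually, forgetting on task 0 can be summarized in a single number. The sketch below uses one simple definition, the drop from the best accuracy ever reached to the final accuracy; the `forgetting` helper is hypothetical and the two curves are made-up values for illustration, not results from this notebook.

```python
import numpy as np


def forgetting(acc_curve):
    """Drop from the peak accuracy to the final accuracy on one task."""
    acc = np.asarray(acc_curve, dtype=float)
    return float(acc.max() - acc[-1])


# Made-up accuracy-on-task-0 curves, for illustration only.
demo_base_curve = [0.60, 0.90, 0.70, 0.50, 0.40]
demo_ctx_curve = [0.60, 0.88, 0.85, 0.84, 0.83]

demo_base_forgetting = forgetting(demo_base_curve)
demo_ctx_forgetting = forgetting(demo_ctx_curve)
```

Applied to `base_acc` and `ctx_acc` after a successful training run, a smaller value would indicate that the model retained more of what it learned on the first task.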

## Step 5: Collect Hidden Activations

Extract hidden layer representations from both models across all tasks for analysis.

In [None]:
print("Collecting baseline hidden activations...")
A_baseline, t_baseline = collect_hidden_activations(
    model=baseline,
    hidden_module=baseline.fc1,
    loaders=train_loaders,
    num_tasks=num_tasks,
    num_samples_per_task=200,
    device=device,
)

print("Collecting context-aware model hidden activations...")
A_context, t_context = collect_hidden_activations(
    model=context_model,
    hidden_module=context_model.fc1,
    loaders=train_loaders,
    num_tasks=num_tasks,
    num_samples_per_task=200,
    device=device,
)

print(f"Baseline activations shape: {A_baseline.shape}")
print(f"Context model activations shape: {A_context.shape}")
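
Passing a `hidden_module` suggests that activations are captured from a specific layer. In PyTorch this is commonly done with a forward hook. The sketch below shows that mechanism on a hypothetical `TinyMLP`; it is an assumption about how such a collector could work, not the implementation of `collect_hidden_activations`.

```python
import torch
from torch import nn


class TinyMLP(nn.Module):
    """Hypothetical two-layer network with an `fc1` hidden layer."""

    def __init__(self, input_dim=784, hidden_dim=128, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x.flatten(start_dim=1))))


def capture_activations(model, module, inputs):
    """Run one forward pass and return the output of `module`."""
    captured = []

    def hook(_module, _inputs, output):
        # Store a detached CPU copy so no autograd graph is kept alive.
        captured.append(output.detach().cpu())

    handle = module.register_forward_hook(hook)
    try:
        model.eval()
        with torch.no_grad():
            model(inputs)
    finally:
        handle.remove()  # always unregister, even if the forward pass fails
    return torch.cat(captured, dim=0)


demo_net = TinyMLP()
demo_batch = torch.randn(200, 1, 28, 28)
# Hooking `fc1` yields the linear layer's output before the ReLU is applied.
demo_acts = capture_activations(demo_net, demo_net.fc1, demo_batch)
```

With 200 samples and a hidden size of 128, the captured matrix has one row per sample and one column per hidden unit, which is the layout PCA expects.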

## Step 6: Analyze Baseline Representation Space (PCA)

In [None]:
pca = PCA(n_components=2)
A_baseline_2d = pca.fit_transform(A_baseline)

plt.figure(figsize=(8, 6))
colors = plt.cm.tab10(np.linspace(0, 1, num_tasks))
for t in range(min(num_tasks, 10)):  # Show up to 10 tasks to avoid color conflicts
    mask = (t_baseline == t)
    plt.scatter(A_baseline_2d[mask, 0], A_baseline_2d[mask, 1], 
               s=20, alpha=0.6, label=f"task {t}", color=colors[t])

plt.xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.1%})", fontsize=11)
plt.ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.1%})", fontsize=11)
plt.title("Baseline Model - Hidden Layer Representations (PCA)", fontsize=12)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Explained variance: PC1={pca.explained_variance_ratio_[0]:.3f}, PC2={pca.explained_variance_ratio_[1]:.3f}")

## Step 7: Analyze Context-Aware Model Representation Space (PCA)

In [None]:
pca2 = PCA(n_components=2)
A_context_2d = pca2.fit_transform(A_context)

plt.figure(figsize=(8, 6))
colors = plt.cm.tab10(np.linspace(0, 1, num_tasks))
for t in range(min(num_tasks, 10)):  # Show up to 10 tasks to avoid color conflicts
    mask = (t_context == t)
    plt.scatter(A_context_2d[mask, 0], A_context_2d[mask, 1], 
               s=20, alpha=0.6, label=f"task {t}", color=colors[t])

plt.xlabel(f"PC1 ({pca2.explained_variance_ratio_[0]:.1%})", fontsize=11)
plt.ylabel(f"PC2 ({pca2.explained_variance_ratio_[1]:.1%})", fontsize=11)
plt.title(f"Context-Aware Model ({context_type}) - Hidden Layer Representations (PCA)", fontsize=12)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Explained variance: PC1={pca2.explained_variance_ratio_[0]:.3f}, PC2={pca2.explained_variance_ratio_[1]:.3f}")

## Summary

This notebook demonstrated:
1. **Dataset Agnostic**: Works with torchvision image datasets (MNIST, FashionMNIST, CIFAR-10, CIFAR-100)
2. **Task Flexibility**: Supports multiple task transform types (permutation, rotation, class-incremental)
3. **Model Comparison**: Compares baseline vs. context-aware models
4. **Analysis Tools**: Uses PCA to visualize how hidden-layer representations are organized by task after training

### Key Observations:
- **Baseline Model**: May show task overlap/interference in latent space
- **Context-Aware Model**: Typically shows better task separation in latent space
- **Accuracy Curves**: Context-aware models often better maintain performance on task 0 as new tasks are learned

### Next Steps:
- Experiment with different datasets (e.g., `cifar10`, `fashionmnist`)
- Try different context types (`binary`, `complex`, `rotation`)
- Vary task transform types to study different continual learning scenarios
- Investigate the relationship between representation separability and task performance
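
As a possible starting point for the last item, task separability can be quantified rather than judged by eye from the PCA plots. The sketch below uses scikit-learn's `silhouette_score` with task IDs as cluster labels; the choice of metric is an assumption, and the two activation matrices are synthetic stand-ins, not outputs of the models above.

```python
import numpy as np
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)
demo_num_tasks, per_task, dim = 5, 200, 128
demo_task_ids = np.repeat(np.arange(demo_num_tasks), per_task)

# Synthetic case 1: all tasks share one distribution (full overlap).
overlapping = rng.normal(size=(demo_num_tasks * per_task, dim))

# Synthetic case 2: each task is shifted to its own region of the space.
centers = rng.normal(scale=5.0, size=(demo_num_tasks, dim))
separated = centers[demo_task_ids] + rng.normal(
    size=(demo_num_tasks * per_task, dim)
)

# Scores near 0 mean tasks overlap; scores near 1 mean tight, distinct clusters.
score_overlapping = silhouette_score(overlapping, demo_task_ids)
score_separated = silhouette_score(separated, demo_task_ids)
```

The same call could be applied to `A_baseline` with `t_baseline` and to `A_context` with `t_context`, and the resulting scores compared against the task-0 accuracy curves.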