In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# Pick the compute device (falling back to the CPU when CUDA is not available)
# and seed the torch / NumPy global RNGs for reproducibility.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
np.random.seed(0)

# ANALYSIS

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torch.linalg import svdvals
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
import seaborn as sns

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'

# Experiment-wide sizes.
NSAMPLES = 2000   # sample budget per dataset (used to size each cluster)
num_classes = 4   # Set the number of classes

# Dataset class
class ComplexGaussianToyDataset(Dataset):
    """Tensor-backed dataset pairing feature rows with integer class labels.

    ``X`` is copied into a float32 tensor and ``y`` into an int64 (long)
    tensor at construction time; indexing returns the matching
    ``(features, label)`` pair.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X = features
        self.y = targets

    def __len__(self):
        # Number of entries along the first dimension of the feature tensor.
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Create complex multi-cluster Gaussian data with variable number of classes
def create_complex_data(n_samples_per_cluster=100, n_clusters_per_class=3, std=0.5, num_classes=2, seed=None):
    """Generate a 2-D multi-cluster Gaussian classification dataset.

    For each class ``c`` the function draws ``n_clusters_per_class`` cluster
    centres (scaled by ``3.0 + 1.5 * c``) and then ``n_samples_per_cluster``
    points around every centre with standard deviation ``std + 0.1 * c``.

    Args:
        n_samples_per_cluster: Points drawn around each cluster centre.
        n_clusters_per_class: Gaussian clusters generated for every class.
        std: Base standard deviation of a cluster (grows with the class index).
        num_classes: Number of classes; labels are ``0 .. num_classes - 1``.
        seed: When not ``None``, NumPy's *global* RNG is reseeded with it, so
            random draws made by the caller afterwards are affected as well.

    Returns:
        ``(X, y)`` where ``X`` has shape
        ``(num_classes * n_clusters_per_class * n_samples_per_cluster, 2)`` and
        ``y`` holds the integer class label of each row.
    """
    if seed is not None:
        np.random.seed(seed)
    X = []
    y = []
    for c in range(num_classes):
        # Spread out classes more: later classes get more widely scattered centres.
        class_centers = np.random.randn(n_clusters_per_class, 2) * (3.0 + c * 1.5)
        for center in class_centers:
            # Vary std per class.
            points = np.random.randn(n_samples_per_cluster, 2) * (std + c * 0.1) + center
            X.append(points)
            y += [c] * n_samples_per_cluster

    X = np.vstack(X)
    y = np.array(y)
    return X, y

def nonlinear_warp(X, freq=2.0, amp=0.5):
    """Return a sinusoidally warped copy of the 2-D points in ``X``.

    Column 0 is shifted by ``amp * sin(freq * x1)`` and column 1 by
    ``amp * cos(freq * x0 + pi / 4)``; both shifts are computed from the
    original (un-warped) coordinates. ``X`` itself is left untouched.
    """
    shift_first = amp * np.sin(freq * X[:, 1])
    shift_second = amp * np.cos(freq * X[:, 0] + np.pi / 4)  # Different warp
    warped = X.copy()
    warped[:, 0] += shift_first
    warped[:, 1] += shift_second
    return warped

def rotate_data(X, angle_degrees):
    """Rotate the 2-D row vectors in ``X`` counter-clockwise by ``angle_degrees``."""
    theta = np.radians(angle_degrees)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    # Points are stored as rows, hence the transpose of the rotation matrix.
    return X @ rotation.T

# DATA CREATION
num_datasets = 9 # Toggle the number of datasets to use
num_complex_datasets = 15
# Slicing below would otherwise silently hand back fewer datasets than requested.
assert num_datasets <= num_complex_datasets, "num_datasets cannot exceed num_complex_datasets"
complex_datasets = []

# Seed NumPy before drawing the per-dataset seeds so this cell is reproducible
# on its own (same seed value as the notebook's first cell).
np.random.seed(0)
seeds = np.random.randint(0, 1000, num_complex_datasets)

for i in range(num_complex_datasets):
    # NOTE: from the second iteration on, these draws follow the reseed done
    # inside create_complex_data() for the previous dataset.
    n_clusters = np.random.randint(2, 5)
    std_dev = np.random.uniform(0.4, 0.8)
    # Integer division: the dataset holds at most NSAMPLES points.
    X_base, y = create_complex_data(NSAMPLES // (num_classes * n_clusters), n_clusters, std_dev, num_classes, seed=seeds[i])

    # Introduce more diversity between datasets
    if i % 4 == 0:
        X = X_base
    elif i % 4 == 1:
        X = nonlinear_warp(X_base, freq=1.5 + i * 0.2, amp=0.6 + i * 0.1)
    elif i % 4 == 2:
        X = rotate_data(X_base, angle_degrees=i * 15)
    else:
        X = nonlinear_warp(rotate_data(X_base, angle_degrees=-i * 10), freq=2.0 - i * 0.1, amp=0.5 + i * 0.05)

    # Introduce label shifts for more forgetting
    if i > 0 and i % 3 == 0:
        y = (y + 1) % num_classes # Shift labels

    complex_datasets.append(ComplexGaussianToyDataset(X, y))

# Select the number of datasets to use
selected_datasets = complex_datasets[:num_datasets]
train_loaders = [DataLoader(d, batch_size=64, shuffle=True) for d in selected_datasets]

# Define the prior and posterior distributions
def prior_distribution(model):
    """Snapshot the model's current parameters to act as the prior.

    Returns a list with one cloned tensor per parameter, in
    ``model.parameters()`` order.
    """
    snapshot = []
    for param in model.parameters():
        snapshot.append(param.data.clone())
    return snapshot

def posterior_distribution(model):
    """Snapshot the model's current parameters to act as the posterior.

    Returns a list with one cloned tensor per parameter, in
    ``model.parameters()`` order.
    """
    snapshot = []
    for param in model.parameters():
        snapshot.append(param.data.clone())
    return snapshot

def kl_divergence(prior, posterior, sigma_sq=1.0):
    """KL-style distance between two parameter snapshots.

    No real distributions are available, so prior and posterior are treated as
    sharing one isotropic covariance ``sigma_sq``; the divergence then reduces
    to ``0.5 / sigma_sq`` times the squared Euclidean distance between the
    two parameter lists.
    """
    squared_distance = sum(
        (torch.sum((post - pri) ** 2) for pri, post in zip(prior, posterior)),
        0.0,
    )
    scale = 0.5 / sigma_sq
    return scale * squared_distance

def pac_bayes_bound(prior, posterior, n_samples, empirical_loss, delta=0.05, sigma_sq=1.0):
    """Empirical loss plus a PAC-Bayes style complexity term.

    Computes ``empirical_loss + sqrt((KL + log(2 * sqrt(n) / delta)) / (2 * n))``
    where ``KL`` comes from ``kl_divergence(prior, posterior, sigma_sq)``.

    Args:
        prior, posterior: Lists of parameter tensors (matching order/shapes).
        n_samples: Number of training samples ``n``.
        empirical_loss: Observed training loss to which the term is added.
        delta: Confidence parameter of the bound.
        sigma_sq: Shared variance passed through to ``kl_divergence``.

    Returns:
        A float32 tensor holding the bound.
    """
    kl = kl_divergence(prior, posterior, sigma_sq=sigma_sq)
    bound_term = (kl + np.log(2 * np.sqrt(n_samples) / delta)) / (2 * n_samples)
    if torch.is_tensor(bound_term):
        # kl is already a tensor here; re-wrapping it in torch.tensor() would
        # emit a copy-construct UserWarning, so detach and cast instead.
        bound_term = bound_term.detach().to(torch.float32)
    else:
        # Empty parameter lists make kl a plain float.
        bound_term = torch.tensor(bound_term, dtype=torch.float32)
    return empirical_loss + torch.sqrt(bound_term)

class Net(nn.Module):
    """Small MLP classifier for 2-D inputs.

    Layout: ``Linear(2, 32) -> ReLU -> Linear(32, num_classes)``, stored as a
    ``nn.Sequential`` under ``self.net`` (so the linear layers are ``net[0]``
    and ``net[2]``).
    """

    def __init__(self, num_classes):
        super().__init__()
        hidden_layer = nn.Linear(2, 32)             # Increased hidden units
        output_layer = nn.Linear(32, num_classes)   # Output layer for num_classes
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        logits = self.net(x)
        return logits

class EWC:
    """Elastic Weight Consolidation memory for one finished task.

    Stores a detached copy of the model's trainable parameters together with a
    diagonal Fisher estimate computed on ``dataloader``; ``penalty`` returns the
    Fisher-weighted squared distance of another model from those parameters.

    Args:
        model: Network whose current parameters are to be protected. It is
            moved to ``device`` and switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches used for the
            Fisher estimate.
        device: Device to work on. ``None`` (default) keeps the model on the
            device its parameters already live on, so CPU-only machines work
            without passing anything.
    """

    def __init__(self, model: nn.Module, dataloader, device=None):
        if device is None:
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device('cpu')
        self.model = model.to(device)
        self.device = device
        self.model.eval()
        self.params = {n: p.clone().detach().to(self.device) for n, p in self.model.named_parameters() if p.requires_grad}
        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader):
        """Average the squared gradient of the batch cross-entropy loss over all batches.

        The gradient is taken of the *batch-mean* loss w.r.t. the true labels,
        i.e. this is a batch-level empirical Fisher approximation.
        """
        fisher = {n: torch.zeros_like(p, device=self.device) for n, p in self.model.named_parameters() if p.requires_grad}
        criterion = nn.CrossEntropyLoss()

        n_batches = 0
        for inputs, labels in dataloader:
            self.model.zero_grad()
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            outputs = self.model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.data.pow(2)
            n_batches += 1

        # Skip the normalisation for an empty loader: 0 / 0 would turn the
        # whole Fisher estimate into NaN.
        if n_batches:
            for n in fisher:
                fisher[n] /= n_batches

        return fisher

    def penalty(self, model: nn.Module):
        """Return ``sum_i F_i * (theta_i - theta*_i)^2`` over ``model``'s trainable parameters."""
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                _loss = self.fisher[n] * (p - self.params[n]).pow(2)
                loss += _loss.sum()
        return loss


def train(model, loader, optimizer, criterion, epochs, n_samples, ewc = None, use_ewc = True, ewc_lambda = 40):
    """Train ``model`` on ``loader``, optionally regularised by EWC penalties.

    Args:
        model: Network to optimise (updated in place).
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.
        epochs: Number of passes over ``loader``.
        n_samples: Unused; kept so the signature matches ``train_with_pacbayes``.
        ewc: Iterable of objects exposing ``penalty(model)`` (one per earlier
            task). ``None`` is treated as "no earlier tasks".
        use_ewc: Whether to add the EWC penalties to the loss.
        ewc_lambda: Weight of the summed EWC penalty.

    Every 10th epoch the loss of the last mini-batch is printed.
    """
    # Follow the model's own device instead of the notebook-level `device`
    # variable, which later cells rebind.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    ewc_terms = list(ewc) if (use_ewc and ewc is not None) else []

    # evaluate() switches the model to eval mode; make sure we train in train mode.
    model.train()
    for epoch in range(epochs):
        loss = None
        for x, y in loader:
            x, y = x.to(target_device), y.to(target_device)
            out = model(x)
            loss = criterion(out, y)
            if ewc_terms:
                ewc_penalty = 0
                for ewc_instance in ewc_terms:  # Iterate over EWC instances
                    ewc_penalty += ewc_instance.penalty(model)
                loss += ewc_lambda * ewc_penalty
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # `loss` stays None when the loader produced no batches.
        if loss is not None and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")


def train_with_pacbayes(model, loader, optimizer, criterion, epochs, n_samples, prior):
    """Train ``model`` with plain loss minimisation and report a PAC-Bayes bound.

    Args:
        model: Network to optimise (updated in place).
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.
        epochs: Number of passes over ``loader``.
        n_samples: Sample count forwarded to ``pac_bayes_bound``.
        prior: Parameter snapshot (list of tensors) used as the prior.

    Every 10th epoch the last mini-batch loss and the bound evaluated at the
    epoch's mean loss are printed. The bound is only reported, not optimised.
    """
    # Follow the model's own device instead of the notebook-level `device`
    # variable, which later cells rebind.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    # evaluate() switches the model to eval mode; make sure we train in train mode.
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        n_batches = 0
        loss = None
        for x, y in loader:
            x, y = x.to(target_device), y.to(target_device)
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        # Nothing to report (and nothing to divide by) for an empty loader.
        if n_batches and (epoch + 1) % 10 == 0:
            posterior = posterior_distribution(model)
            epsilon = total_loss / n_batches
            bound = pac_bayes_bound(prior, posterior, n_samples, epsilon)
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}, PAC-Bayes Bound: {bound:.4f}")

def evaluate(model, loader):
    """Return the classification accuracy (in percent) of ``model`` on ``loader``.

    The model is run in eval mode without gradient tracking and is switched
    back to its previous train/eval mode afterwards. Batches are moved to the
    device the model's parameters live on. An empty loader yields ``nan``.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target_device), y.to(target_device)
                preds = model(x).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += len(y)
    finally:
        # Do not leave a model that is still being trained stuck in eval mode.
        model.train(was_training)
    if total == 0:
        return float('nan')
    return correct / total * 100

import copy

EPOCHS_PER_TASK = 100

torch.manual_seed(1984)
saved_models_list = []
# Evaluate on the training datasets for simplicity, but through *unshuffled*
# loaders: the analysis below compares activations of different models row by
# row (CKA, SVCCA, canonical angles), which needs a fixed sample order.
test_loaders_list = [DataLoader(d, batch_size=64, shuffle=False) for d in selected_datasets]

model = Net(num_classes).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
ewc_memory = []
ewc_enabled = True

# Sequentially train on every task. From the second task on, the EWC penalties
# of all earlier tasks are added to the loss (train_with_pacbayes is the
# PAC-Bayes alternative to train() and is not used in this run).
for i, train_loader in enumerate(train_loaders):
    print(f"Training on Task {i + 1}")
    # The first task has no earlier tasks to protect, so EWC is off there.
    train(model, train_loader, optimizer, criterion, EPOCHS_PER_TASK, NSAMPLES,
          use_ewc=ewc_enabled and i > 0, ewc=ewc_memory)
    saved_models_list.append(copy.deepcopy(model))

    # Store EWC data (on the training device, not EWC's own default device).
    ewc_memory.append(EWC(copy.deepcopy(model), train_loader, device=device))

    print(f"Evaluation after training on Task {i + 1}")
    for j, test_loader in enumerate(test_loaders_list):
        acc = evaluate(model, test_loader)
        print(f"  Accuracy on Task {j + 1}: {acc:.2f}%")
    print("-" * 30)

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.linalg import svdvals
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
import seaborn as sns

# ASSUMPTIONS: `train_loaders` and `saved_models_list` were produced by the
# training code above.
train_loader_task1 = train_loaders[0] if train_loaders else None
train_loader_ewc = train_loader_task1
criterion = nn.CrossEntropyLoss()
# `'cuda' or 'cpu'` always evaluates to 'cuda'; pick the device explicitly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move all models to CPU and eval mode for analysis
for m in saved_models_list:
    m.to('cpu').eval()

# 1+2+3: weight norms, effective rank, parameter‐space trajectory
# Each entry maps a layer name to a getter that picks that layer out of a model.
layers = [('fc1', lambda m: m.net[0]), ('fc2', lambda m: m.net[2])]
weights = {name: [] for name, _ in layers}
# `saved_model` (not `model`) so the trained network bound to `model` is not
# silently replaced by a CPU analysis copy.
for saved_model in saved_models_list:
    for name, getm in layers:
        W = getm(saved_model).weight.detach().cpu().numpy()
        weights[name].append(W)

# Compute metrics
norms = {name: [] for name, _ in layers}          # norm of each task-to-task weight update
eranks = {name: [] for name, _ in layers}         # effective rank of each weight snapshot
traj_lengths = {name: 0.0 for name, _ in layers}  # summed update norms per layer
for name in norms:
    for i in range(len(weights[name]) - 1):
        dW = weights[name][i+1] - weights[name][i]
        step_norm = np.linalg.norm(dW)
        norms[name].append(step_norm)
        traj_lengths[name] += step_norm
    for W in weights[name]:
        # Effective rank = exp(entropy of the normalised singular values).
        sv = np.linalg.svd(W, compute_uv=False)
        p = sv / sv.sum()
        eranks[name].append(np.exp(-np.sum(p * np.log(p + 1e-12))))

# 4: EWC Fisher diag on first model
def compute_fisher(model, dataloader, criterion):
    """Diagonal Fisher estimate of ``model`` on ``dataloader``.

    For every batch the gradient of ``criterion(model(inputs), labels)`` is
    computed and its element-wise square accumulated; the sum is divided by the
    number of batches.

    Args:
        model: Network to analyse (its ``.grad`` buffers are overwritten).
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        Dict mapping each parameter name to a tensor shaped like the parameter.
        Parameters that receive no gradient keep an all-zero entry, and an
        empty ``dataloader`` yields all zeros.
    """
    fishers = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    n_batches = 0
    for imgs, labels in dataloader:
        model.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        for n, p in model.named_parameters():
            # Frozen or unused parameters have no gradient.
            if p.grad is not None:
                fishers[n] += p.grad.detach().pow(2)
        n_batches += 1
    # Avoid 0 / 0 = NaN when the loader produced no batches.
    if n_batches:
        for n in fishers:
            fishers[n] /= n_batches
    return fishers

# Fisher information of the first saved model on the task-1 loader (None when no loader exists).
fisher = compute_fisher(saved_models_list[0].to('cpu'), train_loader_task1, criterion) if train_loader_task1 else None

# For every pair of consecutive snapshots, weight the squared parameter change
# by that Fisher estimate and sum it up.
ewc_overlap = []
if fisher is not None:
    for model_prev, model_next in zip(saved_models_list[:-1], saved_models_list[1:]):
        next_state = model_next.state_dict()
        overlap = 0.0
        for n, p in model_prev.named_parameters():
            delta = (next_state[n] - p.data).cpu()
            overlap += (fisher[n] * delta.pow(2)).sum().item()
        ewc_overlap.append(overlap)

# 5: SVCCA and 6: canonical angles
def svcca(X, Y, max_components=20, var_tol=1e-8):
    """Mean canonical correlation between two activation matrices (SVCCA-style).

    Both inputs are PCA-reduced to at most
    ``L = min(n_samples, X features, Y features, max_components)`` directions.
    Directions whose variance is below ``var_tol`` times the leading variance
    are dropped, so rank-deficient activations do not feed numerical noise
    into CCA. CCA is then run with as many components as both reduced
    matrices support.

    Args:
        X, Y: Arrays of shape ``(n_samples, n_features_*)`` with rows referring
            to the same samples in the same order.
        max_components: Upper limit on the PCA dimensions kept per input.
        var_tol: Relative variance threshold below which a PCA direction is
            considered numerically null.

    Returns:
        Mean correlation over the canonical component pairs, or ``nan`` when
        either input has no direction with non-zero variance.

    Raises:
        ValueError: If ``X`` and ``Y`` have different numbers of rows.
    """
    if X.shape[0] != Y.shape[0]:
        raise ValueError("svcca needs X and Y to have the same number of rows (samples)")
    n_samples = X.shape[0]
    L = min(n_samples, X.shape[1], Y.shape[1], max_components)

    def _pca_reduce(A):
        # 1) PCA-reduce, then keep only directions carrying real variance.
        pca = PCA(n_components=L)
        scores = pca.fit_transform(A)
        variances = pca.explained_variance_
        if variances.size == 0 or not variances[0] > 0:
            return scores[:, :0]
        n_keep = int(np.sum(variances > var_tol * variances[0]))
        return scores[:, :n_keep]

    Xr = _pca_reduce(X)
    Yr = _pca_reduce(Y)
    n_canonical = min(Xr.shape[1], Yr.shape[1])
    if n_canonical == 0:
        return float('nan')

    # 2) CCA on reduced dims
    cca = CCA(n_components=n_canonical)
    Xc, Yc = cca.fit_transform(Xr, Yr)
    # 3) average corr per component
    corrs = [np.corrcoef(Xc[:, i], Yc[:, i])[0, 1] for i in range(n_canonical)]
    return np.mean(corrs)

# ---- canonical‐angles helper ----
def canonical_angles(X, Y):
    """Principal angles between the column spaces of ``X`` and ``Y``.

    The singular values of ``Qx^T Qy`` are the cosines of the principal angles
    only when ``Qx`` and ``Qy`` are *orthonormal* bases, so each input is
    orthonormalised (via SVD, dropping numerically null directions) first.
    The inputs are not mean-centred.

    Args:
        X, Y: Arrays of shape ``(n_samples, n_features_*)`` with the same
            number of rows.

    Returns:
        1-D array of ``min(rank(X), rank(Y))`` angles in radians, in ascending
        order (empty when either input has rank 0).
    """
    def _orthonormal_basis(A):
        A = np.asarray(A)
        # Rank tolerance follows the precision the data was produced in.
        eps = np.finfo(A.dtype).eps if np.issubdtype(A.dtype, np.floating) else np.finfo(np.float64).eps
        U, s, _ = np.linalg.svd(A.astype(np.float64), full_matrices=False)
        if s.size == 0 or s[0] == 0:
            return U[:, :0]
        tol = s[0] * max(A.shape) * eps
        return U[:, :int(np.sum(s > tol))]

    Qx = _orthonormal_basis(X)
    Qy = _orthonormal_basis(Y)
    if Qx.shape[1] == 0 or Qy.shape[1] == 0:
        return np.array([])
    M = Qx.T.dot(Qy)
    s = np.linalg.svd(M, compute_uv=False)
    # clamp to [-1,1] to avoid numerical errors outside domain
    s = np.clip(s, -1.0, 1.0)
    angles = np.arccos(s)
    return angles


# Activation extraction helper
def extract_acts(model, loader, layer_specs=None):
    """Collect the outputs of selected layers for every sample in ``loader``.

    Args:
        model: Network to run.
        loader: Iterable of ``(inputs, labels)`` batches. Use an unshuffled
            loader when results of several calls are compared row by row.
        layer_specs: Iterable of ``(name, getter)`` pairs, where
            ``getter(model)`` returns the module to hook. Defaults to the
            notebook-level ``layers`` list.

    Returns:
        ``(acts, labels)``: ``acts[name]`` is an ``(N, features)`` array of the
        hooked module's (flattened) outputs and ``labels`` the ``(N,)`` array
        of labels, both in loader order.
    """
    specs = layers if layer_specs is None else layer_specs
    acts = {name: [] for name, _ in specs}
    labels = []
    hooks = []
    try:
        for name, getm in specs:
            # n=name binds the current layer name to this hook.
            hooks.append(getm(model).register_forward_hook(
                lambda m, inp, out, n=name: acts[n].append(out.detach().cpu().numpy())
            ))
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for imgs, lbls in loader:
                labels.append(lbls.cpu().numpy())
                _ = model(imgs)
    finally:
        # Always detach the hooks, even if a forward pass raised.
        for h in hooks:
            h.remove()
    # concatenate
    for n in acts:
        acts[n] = np.concatenate([a.reshape(a.shape[0], -1) for a in acts[n]], axis=0)
    labels = np.concatenate(labels, axis=0)
    return acts, labels

# Precompute activations for all models & loaders:
# acts[model_idx][loader_idx] = (activations_by_layer, labels)
acts = {
    mi: {
        ti: extract_acts(model, loader)
        for ti, loader in enumerate(test_loaders_list)
    }
    for mi, model in enumerate(saved_models_list)
}

# 7: manifold‐geometry & 8: cluster separability
def manifold_and_sep(X, labels):
    """Per-class manifold statistics and a scatter-based separability score.

    For every class found in ``labels`` this computes the centroid, the mean
    distance of the class points from that centroid (radius), and the number
    of principal components needed to reach 90% explained variance. The
    separability is ``trace(S_B) / trace(S_W)`` of the between-class and
    within-class scatter matrices.

    Returns:
        ``(centroids, radii, dims, sep)``; the first three are dicts keyed by
        class label.
    """
    class_ids = np.unique(labels)
    overall_mean = X.mean(0)
    n_features = X.shape[1]

    centroids, radii, dims = {}, {}, {}
    SW = np.zeros((n_features, n_features))   # within-class scatter
    SB = np.zeros_like(SW)                    # between-class scatter
    for c in class_ids:
        members = X[labels == c]
        centre = members.mean(0)
        centered = members - centre

        centroids[c] = centre
        radii[c] = np.linalg.norm(centered, axis=1).mean()
        cumulative = np.cumsum(PCA().fit(members).explained_variance_ratio_)
        dims[c] = np.searchsorted(cumulative, 0.9) + 1

        # separability
        SW += centered.T @ centered
        offset = centre - overall_mean
        SB += members.shape[0] * np.outer(offset, offset)

    sep = np.trace(SB) / np.trace(SW)
    return centroids, radii, dims, sep

# manifold_metrics[model_idx][task_idx][layer_name] = (centroids, radii, dims, sep)
manifold_metrics = {}
for mi, per_task in acts.items():                      # model index
    manifold_metrics[mi] = {}
    for ti, (acts_dict, labels) in per_task.items():   # task/input index
        manifold_metrics[mi][ti] = {
            # acts_dict[name] is (N, features); labels is (N,)
            name: manifold_and_sep(acts_dict[name], labels)
            for name, _ in layers
        }


# 9: representational drift (CKA) relative to first model on task1 inputs
def linear_CKA(X, Y):
    """Linear centered kernel alignment between two activation matrices.

    Both inputs are mean-centred per column; the result is
    ``||Xc^T Yc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F)``. Rows of ``X`` and
    ``Y`` must refer to the same samples in the same order.
    """
    def frobenius(A):
        return np.linalg.norm(A, 'fro')

    X_centered = X - X.mean(0)
    Y_centered = Y - Y.mean(0)
    cross_term = frobenius(X_centered.T.dot(Y_centered)) ** 2
    self_x = frobenius(X_centered.T.dot(X_centered))
    self_y = frobenius(Y_centered.T.dot(Y_centered))
    return cross_term / (self_x * self_y)

# drift[layer] = CKA of each later model's task-1 activations against model 0's
drift = {name: [] for name, _ in layers}
if 0 in acts and 0 in acts[0]:
    baseline_by_layer = acts[0][0][0]   # model 0, loader 0, activation dict
    for layer_name, _ in layers:
        if layer_name not in baseline_by_layer.keys():
            continue
        base_X = baseline_by_layer[layer_name]
        for mi in range(1, len(saved_models_list)):
            if mi not in acts or 0 not in acts[mi]:
                continue
            acts_dict_mi, _ = acts[mi][0]
            if layer_name in acts_dict_mi:
                drift[layer_name].append(linear_CKA(base_X, acts_dict_mi[layer_name]))

# 10: inter‐task transfer (accuracy matrix)
# acc_matrix[m, t] = accuracy (%) of the model saved after task m on task t's data.
acc_matrix = np.zeros((len(saved_models_list), len(test_loaders_list)))
# Inference only: skip autograd bookkeeping. The loop variable is not called
# `model` so the trained network bound to that name is left alone.
with torch.no_grad():
    for mi, saved_model in enumerate(saved_models_list):
        for ti, loader in enumerate(test_loaders_list):
            correct = total = 0
            for imgs, lbls in loader:
                pred = saved_model(imgs).argmax(1)
                correct += (pred == lbls).sum().item()
                total += lbls.size(0)
            # An empty loader has no defined accuracy.
            acc_matrix[mi, ti] = 100 * correct / total if total else float('nan')

# --- VISUALIZATIONS ---
# Weight norms & effective rank: one figure per layer, update norms on the
# left and effective rank on the right.
for name in norms:
    fig, (ax_norm, ax_rank) = plt.subplots(1, 2, figsize=(10, 5))

    ax_norm.plot(norms[name], marker='o')
    ax_norm.set_title(f'{name} update norms')
    ax_norm.set_xlabel('Update Step')
    ax_norm.set_ylabel('Norm')
    ax_norm.set_xticks(range(len(norms[name])))  # Show all task transitions

    ax_rank.plot(eranks[name], marker='o')
    ax_rank.set_title(f'{name} effective rank')
    ax_rank.set_xlabel('Task Index')
    ax_rank.set_ylabel('Effective Rank')
    ax_rank.set_xticks(range(len(eranks[name])))  # Show all tasks

    fig.suptitle(f'Layer {name}')
    fig.tight_layout()
    plt.show()

# EWC overlap: Fisher-weighted squared parameter change per task transition
if ewc_overlap:
    fig, ax = plt.subplots()
    ax.plot(ewc_overlap, marker='o')
    ax.set_title('EWC importance overlap per update')
    ax.set_xlabel('Update Step')
    ax.set_ylabel('Overlap')
    ax.set_xticks(range(len(ewc_overlap)))  # Show all task transitions
    plt.show()

# SVCCA & angles
# For every layer, compare each later model's task-1 activations with model 0's.
for name, _ in layers:
    # model_ids records which model each value belongs to, so the x-axis shows
    # real model indices (1, 2, ...) rather than list positions (0, 1, ...).
    sv_vals, ang_vals, model_ids = [], [], []
    # baseline activations: model 0 on loader 0
    if 0 in acts and 0 in acts[0]:
        base_acts, _ = acts[0][0]     # acts[model_idx][loader_idx] = (dict, labels)
        if name in base_acts:
            X0 = base_acts[name]         # shape = (n_samples, feature_dim)

            for mi in range(1, len(saved_models_list)):
                if mi in acts and 0 in acts[mi]:
                    comp_acts, _ = acts[mi][0]
                    if name in comp_acts:
                        Xi = comp_acts[name]
                        sv_vals.append(svcca(X0, Xi))
                        ang_vals.append(canonical_angles(X0, Xi).mean())
                        model_ids.append(mi)

            # plot side by side
            fig, axs = plt.subplots(1, 2, figsize=(12, 5))
            axs[0].plot(model_ids, sv_vals, marker='o')
            axs[0].set_title(f'SVCCA similarity — {name}')
            axs[0].set_xlabel('Model Index (vs Model 0)')
            axs[0].set_ylabel('Mean Canonical Corr')
            axs[0].set_xticks(model_ids)  # Show all model transitions

            axs[1].plot(model_ids, ang_vals, marker='o')
            axs[1].set_title(f'Mean canonical angle — {name}')
            axs[1].set_xlabel('Model Index (vs Model 0)')
            axs[1].set_ylabel('Angle (rad)')
            axs[1].set_xticks(model_ids)  # Show all model transitions

            plt.suptitle(f'Representation similarity for layer "{name}"', y=1.02)
            plt.tight_layout()
            plt.show()


# Manifold & separability for model 0 on task1 inputs
# (fc2 is only reported when the fc1 metrics exist, as before)
baseline_metrics = manifold_metrics.get(0, {}).get(0, {})
if 'fc1' in baseline_metrics:
    for layer in ('fc1', 'fc2'):
        if layer not in baseline_metrics:
            continue
        _centroids, class_radii, class_dims, separability = baseline_metrics[layer]
        print(f'Class radii ({layer}):', class_radii)
        print(f'Manifold dims ({layer}):', class_dims)
        print(f'Cluster separability ({layer}):', separability)

# Representational drift
for name in drift:
    # drift[name][k] compares model k+1 with model 0, so label the x-axis with
    # the real model indices starting at 1 (assumes no model was skipped when
    # the drift lists were built).
    model_ids = list(range(1, len(drift[name]) + 1))
    plt.figure()
    plt.plot(model_ids, drift[name], marker='o', label=name)
    plt.title(f'CKA-based drift from model0 {name} outputs')
    plt.xlabel('Model Index')
    plt.ylabel('CKA Similarity')
    plt.xticks(model_ids)  # Show all model transitions
    plt.legend()
    plt.show()

# Inter-task transfer heatmap: rows = model snapshot, columns = evaluated task
task_labels = [f'Task {j+1}' for j in range(acc_matrix.shape[1])]
model_labels = [f'Model {i+1}' for i in range(acc_matrix.shape[0])]

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(acc_matrix, annot=True, fmt=".1f", cmap="viridis",
            xticklabels=task_labels, yticklabels=model_labels, ax=ax)
ax.set_xlabel('Test Task')
ax.set_ylabel('Model Trained On')
ax.set_title('Inter-task Transfer Accuracy (%)')
fig.tight_layout()
plt.show()

In [None]:
# --- PARAMETER SPACE TRAJECTORY VISUALIZATION ---
def plot_weight_trajectory(weight_list, layer_name, ndim=2):
    """Plot how a layer's weights move across snapshots, projected with PCA.

    Every weight array in ``weight_list`` is flattened into one row; the rows
    are projected onto ``ndim`` principal components and drawn as a connected
    path. ``ndim=3`` gives a 3-D plot; any other value draws the first two
    components and annotates each point with its snapshot index.
    """
    snapshots = np.stack([w.flatten() for w in weight_list])  # (num_steps, num_weights)
    X_proj = PCA(n_components=ndim).fit_transform(snapshots)

    fig = plt.figure(figsize=(6, 6))
    if ndim != 3:
        ax = fig.add_subplot(111)
        pc1, pc2 = X_proj[:, 0], X_proj[:, 1]
        ax.plot(pc1, pc2, marker='o')
        for step, (px, py) in enumerate(zip(pc1, pc2)):
            ax.text(px, py, str(step), fontsize=8)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
    else:
        ax = fig.add_subplot(111, projection='3d')
        ax.plot(X_proj[:, 0], X_proj[:, 1], X_proj[:, 2], marker='o')
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
        ax.set_zlabel('PC3')

    plt.title(f'Parameter Trajectory: {layer_name}')
    plt.tight_layout()
    plt.show()

# Plot trajectories: one PCA projection per layer
for layer_name, layer_snapshots in weights.items():
    plot_weight_trajectory(layer_snapshots, layer_name, ndim=2)  # Set ndim=3 for 3D if preferred
