# GSS (Gradient-based Sample Selection) Strategy

In [1]:
# Import the libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from torchvision.models import resnet34 as torchvision_resnet34

from torch.utils.data import DataLoader, Subset, random_split, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
import datetime
from collections import defaultdict

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# NOTE: the timestamp printed below is the wall-clock time at which this cell
# is executed, not a file modification time.
print(f"Notebook last modified at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

Using device: cpu
Notebook last modified at: 2025-07-31 21:09:13


## Replay Buffer for EWC

In [None]:
class ReplayBuffer:
    """Class-balanced exemplar buffer with greedy gradient-diversity selection.

    At most ``max_per_class`` ``(x, y)`` pairs are kept per class. While a
    class is under that limit every new sample is stored. Once the limit
    would be exceeded, the per-sample loss gradients of all candidates
    (stored + new) are computed and a greedy farthest-point pass keeps the
    ``max_per_class`` samples whose normalized gradients are most spread out.

    The model is called as ``feats, logits = model(x)``, i.e. it must return
    a 2-tuple whose second element is the logits.
    """

    def __init__(self, max_per_class=100, model=None, loss_fn=None, device="cpu"):
        """
        Initialize the replay buffer with GSS-Greedy selection.

        Args:
            max_per_class (int): Maximum number of exemplars per class.
            model: Neural network used for gradient computation; its forward
                pass must return ``(feats, logits)``.
            loss_fn: Loss used for gradient computation. Defaults to
                ``nn.CrossEntropyLoss(reduction="mean")`` when falsy.
            device: Device on which gradients are computed (cuda or cpu).
        """
        self.max_per_class = max_per_class
        self.buffer = defaultdict(list)  # class id -> list of (x, y) tensors
        self.model = model
        self.loss_fn = loss_fn or nn.CrossEntropyLoss(reduction="mean")
        self.device = device

    def compute_gradient(self, x, y):
        """
        Compute the L2-normalized loss gradient for a single sample.

        The forward pass runs in eval mode; the model's previous
        train/eval mode is restored afterwards. Gradients are obtained with
        ``torch.autograd.grad`` so the ``.grad`` fields of the parameters
        (e.g. gradients accumulated by an ongoing training step) are left
        untouched.

        Args:
            x (torch.Tensor): Input tensor (single sample, no batch dim).
            y (torch.Tensor): Label tensor (single label, no batch dim).

        Returns:
            torch.Tensor: Flat gradient vector over all trainable parameters
            that take part in the loss, divided by its L2 norm (+1e-8).
        """
        was_training = self.model.training
        self.model.eval()
        try:
            x = x.unsqueeze(0).to(self.device)
            y = y.unsqueeze(0).to(self.device)
            _, logits = self.model(x)
            loss = self.loss_fn(logits, y)
            trainable = [p for p in self.model.parameters() if p.requires_grad]
            # allow_unused: parameters that do not influence the loss yield
            # None and are skipped below.
            grads = torch.autograd.grad(loss, trainable, allow_unused=True)
        finally:
            # Never leave the caller's model silently switched to eval mode.
            self.model.train(was_training)

        flat = torch.cat([g.flatten() for g in grads if g is not None])
        return flat / (torch.norm(flat, p=2) + 1e-8)  # epsilon avoids 0/0

    @staticmethod
    def _farthest_point_indices(gradients, k):
        """Greedily pick ``k`` indices whose gradients are mutually far apart.

        The first index is drawn with ``np.random.randint``; every following
        pick is the not-yet-chosen sample whose minimum Euclidean distance to
        the chosen set is largest (ties go to the lowest index). A running
        minimum is kept, so each round only measures distances to the newest
        pick instead of to every previously chosen sample.
        """
        n = len(gradients)
        k = min(k, n)
        if k <= 0:
            return []

        first = int(np.random.randint(0, n))
        chosen = [first]
        min_dist = torch.stack([torch.norm(g - gradients[first]) for g in gradients])
        min_dist[first] = float("-inf")  # chosen samples can never win argmax

        while len(chosen) < k:
            idx = int(torch.argmax(min_dist))
            chosen.append(idx)
            to_new = torch.stack([torch.norm(g - gradients[idx]) for g in gradients])
            # -inf entries of already chosen samples survive the minimum.
            min_dist = torch.minimum(min_dist, to_new)
            min_dist[idx] = float("-inf")
        return chosen

    def add_examples(self, x_batch, y_batch):
        """
        Add examples to the replay buffer using GSS-Greedy selection.

        Args:
            x_batch (torch.Tensor): Batch of input data.
            y_batch (torch.Tensor): Corresponding labels for the input data.

        Raises:
            ValueError: If no model (or loss function) is set.
        """
        if self.model is None or self.loss_fn is None:
            raise ValueError("Model and loss function must be set for GSS-Greedy selection.")

        for cls in set(y_batch.tolist()):
            cls = int(cls)
            cls_indices = (y_batch == cls).nonzero(as_tuple=True)[0]
            if not cls_indices.numel():
                continue

            candidates = self.buffer[cls] + [(x_batch[i], y_batch[i]) for i in cls_indices]
            if len(candidates) <= self.max_per_class:
                # Within the limit: keep everything, no gradients needed.
                self.buffer[cls] = candidates
                continue

            # Gradients are recomputed for stored samples as well, because
            # they depend on the current model parameters.
            gradients = [self.compute_gradient(x, y) for x, y in candidates]
            keep = self._farthest_point_indices(gradients, self.max_per_class)
            self.buffer[cls] = [candidates[i] for i in keep]

    def get_all_data(self):
        """
        Get all data from the replay buffer as stacked tensors.

        Returns:
            tuple: ``(xs, ys)`` where both are tensors stacked over every
            stored exemplar, or ``(None, None)`` when the buffer is empty.
        """
        xs, ys = [], []
        for examples in self.buffer.values():
            for x, y in examples:
                xs.append(x)
                ys.append(y)
        if not xs:
            return None, None
        return torch.stack(xs), torch.stack(ys)

# EWC class

In [None]:
class EWC:
    """Elastic Weight Consolidation penalty anchored at one model snapshot.

    On construction the current parameter values are copied and a diagonal
    Fisher-style importance estimate is computed from ``dataloader``.
    ``penalty`` then returns ``(lambda / 2) * sum(F * (theta - theta0)^2)``.

    The model is called as ``feats, logits = model(x)``, i.e. it must return
    a 2-tuple whose second element is the logits.
    """

    def __init__(self, model, dataloader, device, samples=500):
        """
        Args:
            model: Network whose parameters are snapshotted.
            dataloader: Iterable yielding ``(x, y)`` batches.
            device: Device the batches are moved to.
            samples (int): Stop accumulating once at least this many samples
                have been seen (the last batch is always used in full).
        """
        self.model = model
        self.device = device
        self.params = {n: p.clone().detach() for n, p in model.named_parameters()}
        self.fisher = self._compute_fisher(dataloader, samples)

    def _compute_fisher(self, dataloader, samples):
        """Accumulate squared gradients of the negative log-likelihood.

        NOTE(review): each batch contributes the square of the gradient of
        the *batch-summed* NLL, which is then divided by the number of
        samples seen. For batch sizes > 1 this is not the same as averaging
        per-sample squared gradients; kept as-is to preserve results.

        Returns:
            dict: parameter name -> tensor shaped like the parameter. All
            zeros for parameters that never receive a gradient, and for
            every parameter when ``dataloader`` yields nothing.
        """
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
        was_training = self.model.training
        self.model.eval()
        count = 0
        try:
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                self.model.zero_grad()
                _, logits = self.model(x)
                log_probs = F.log_softmax(logits, dim=1)
                rows = torch.arange(y.size(0), device=y.device)
                # Sum the negative log-likelihood over the batch.
                nll = -log_probs[rows, y].sum()
                nll.backward()
                for n, p in self.model.named_parameters():
                    # Frozen or unused parameters have no gradient.
                    if p.grad is not None:
                        fisher[n] += p.grad.detach().pow(2)
                count += x.size(0)  # count by number of *samples*
                if count >= samples:
                    break
        finally:
            # Do not leak these gradients or the eval mode into training.
            self.model.zero_grad()
            self.model.train(was_training)
        # max(..., 1) keeps an empty dataloader from producing NaNs (0 / 0).
        return {n: f / max(count, 1) for n, f in fisher.items()}

    def penalty(self, model, lambda_ewc):
        """Return the EWC regularization term for ``model``.

        Parameters whose shape differs from the snapshot are only penalized
        when they look like a grown classifier head (name contains
        ``fc.weight`` with 2 dims, or ``fc.bias`` with 1 dim); in that case
        only the leading ``p0.size(0)`` entries are compared. Any other
        shape mismatch is skipped.

        Args:
            model: Model whose current parameters are penalized.
            lambda_ewc: Regularization strength.

        Returns:
            ``(lambda_ewc / 2) * sum(F * (p - p0)^2)``; a plain ``0.0`` when
            no parameter name matches the snapshot.
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self.fisher:
                continue
            f, p0 = self.fisher[n], self.params[n]
            if p.shape != p0.shape:
                grown_weight = 'fc.weight' in n and p.dim() == 2
                grown_bias = 'fc.bias' in n and p.dim() == 1
                if not (grown_weight or grown_bias):
                    continue
                p = p[:p0.size(0)]  # only the rows/entries that existed before
            loss += (f * (p - p0).pow(2)).sum()
        return (lambda_ewc / 2) * loss

In [None]:
class PerTaskEWC:
    """
    Holds one (params, fisher) snapshot per past task and, at training
    time, returns the sum of the EWC penalties of all snapshots.
    """

    def __init__(self, model, device, ewc_paths: list):
        """
        Args:
          - model (nn.Module):  The current model; its parameter names are
                                matched against the names stored on disk.
          - device:             CPU / CUDA device the snapshots are moved to.
          - ewc_paths:          List of file paths, e.g.
                                ['ewc_task_1.pt', 'ewc_task_2.pt', ...].
                                Each file must hold a dict with the keys
                                'params' and 'fisher' (name -> tensor).
        """
        self.model = model
        self.device = device

        # Parallel lists, one entry per past task k:
        self.past_task_params = []   # name -> anchor value theta*^(k)
        self.past_task_fishers = []  # name -> importance F^(k)

        for path in ewc_paths:
            # Deserialize on the CPU first, then move to the target device.
            snapshot = torch.load(path, map_location='cpu')
            anchors = {
                key: tensor.to(self.device)
                for key, tensor in snapshot['params'].items()
            }
            importances = {
                key: tensor.to(self.device)
                for key, tensor in snapshot['fisher'].items()
            }
            self.past_task_params.append(anchors)
            self.past_task_fishers.append(importances)

    def penalty(self, model, lambda_ewc):
        """
        Accumulate F^(k) * (theta - theta*^(k))^2 over every past task k and
        every parameter of `model` that is present in that task's Fisher dict.

        Parameters whose shape differs from the snapshot are only penalized
        when they look like a grown classifier head (name contains
        'fc.weight' with 2 dims, or 'fc.bias' with 1 dim); then only the
        leading theta*.size(0) entries are compared. Any other layer that
        changed shape unexpectedly is skipped.

        Returns:  (lambda_ewc / 2) * [sum over tasks & params of F (theta - theta*)^2]
        """
        total_loss = 0.0
        current = list(model.named_parameters())

        for anchors, importances in zip(self.past_task_params, self.past_task_fishers):
            for name, param in current:
                # Parameter did not exist when this snapshot was taken.
                if name not in importances:
                    continue

                theta_star = anchors[name]
                fisher = importances[name]

                if param.shape != theta_star.shape:
                    grown_weight = 'fc.weight' in name and param.dim() == 2
                    grown_bias = 'fc.bias' in name and param.dim() == 1
                    if not (grown_weight or grown_bias):
                        continue
                    # Penalize only the slice that existed in the snapshot.
                    param = param[:theta_star.size(0)]

                total_loss += (fisher * (param - theta_star).pow(2)).sum()

        return (lambda_ewc / 2) * total_loss