# GSS (Gradient-based Sample Selection) replay with EWC regularization

In [1]:
# Import the libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from torchvision.models import resnet34 as torchvision_resnet34

from torch.utils.data import DataLoader, Subset, random_split, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
import datetime
from collections import defaultdict

# Pick the compute device (GPU when CUDA is usable, CPU otherwise) and log
# the moment this cell was executed.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
print(f"Notebook last modified at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

Using device: cpu
Notebook last modified at: 2025-07-28 23:25:44


## GSS-Greedy replay buffer (used together with EWC)

In [None]:
class ReplayBuffer:
    """Per-class exemplar memory populated with GSS-Greedy selection.

    Exemplars live in ``self.buffer``, a ``defaultdict(list)`` mapping an
    integer class id to a list of ``(x, y)`` tensor pairs. When a class would
    exceed ``max_per_class`` entries, the subset whose normalized loss
    gradients are farthest apart (greedy max-min distance) is kept.
    """

    def __init__(self, max_per_class=100, model=None, loss_fn=None, device="cpu"):
        """
        Initialize the replay buffer with GSS-Greedy selection.

        Args:
            max_per_class (int): Maximum number of exemplars per class.
            model: Network used for gradient computation. Its forward pass
                must return a ``(features, logits)`` pair (e.g., ResNetSmall).
            loss_fn: Loss used for gradient computation. Defaults to
                ``nn.CrossEntropyLoss(reduction="mean")`` when omitted.
            device: Device to perform computations (cuda or cpu).
        """
        self.max_per_class = max_per_class
        self.buffer = defaultdict(list)
        self.model = model
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss(reduction="mean")
        self.device = device

    def compute_gradient(self, x, y):
        """
        Compute the L2-normalized loss gradient of a single sample.

        The model is evaluated in eval mode, but its previous train/eval mode
        is restored afterwards and the parameters' ``.grad`` fields are left
        untouched, so this can be called in the middle of a training step.

        Args:
            x (torch.Tensor): Input tensor (single sample, no batch dimension).
            y (torch.Tensor | int): Label of the sample.

        Returns:
            torch.Tensor: Flat, normalized gradient over all parameters that
            receive a gradient.
        """
        model = self.model
        was_training = model.training
        model.eval()
        try:
            inputs = x.unsqueeze(0).to(self.device)
            target = torch.as_tensor(y).unsqueeze(0).to(self.device)
            _, logits = model(inputs)
            loss = self.loss_fn(logits, target)
            params = [p for p in model.parameters() if p.requires_grad]
            # autograd.grad returns the gradients without writing to p.grad.
            grads = torch.autograd.grad(loss, params, allow_unused=True)
        finally:
            model.train(was_training)

        flat = torch.cat([g.reshape(-1) for g in grads if g is not None])
        return flat / (torch.norm(flat, p=2) + 1e-8)  # Normalize gradient

    @staticmethod
    def _own_copy(sample):
        """Return a detached copy of an ``(x, y)`` pair that owns its memory."""
        x, y = sample
        return x.detach().clone(), y.detach().clone()

    @staticmethod
    def _greedy_select(gradients, k):
        """
        Pick ``k`` indices by greedy max-min distance in gradient space.

        The first index is drawn with ``np.random.randint``; every further pick
        is the candidate whose distance to its nearest already-picked gradient
        is largest (first such candidate on ties). Nearest distances are kept
        as running minima, so each pick costs one pass over the candidates.
        """
        n = len(gradients)
        k = min(k, n)
        if k <= 0:
            return []

        chosen = [int(np.random.randint(0, n))]
        taken = {chosen[0]}
        nearest = [float("inf")] * n
        while len(chosen) < k:
            ref = gradients[chosen[-1]]
            best_idx, best_dist = -1, -1.0
            for i, grad in enumerate(gradients):
                if i in taken:
                    continue
                dist = torch.norm(grad - ref).item()
                if dist < nearest[i]:
                    nearest[i] = dist
                if nearest[i] > best_dist:
                    best_idx, best_dist = i, nearest[i]
            chosen.append(best_idx)
            taken.add(best_idx)
        return chosen

    def add_examples(self, x_batch, y_batch):
        """
        Add examples to the replay buffer using GSS-Greedy selection.

        Args:
            x_batch (torch.Tensor): Batch of input data (first dim = samples).
            y_batch (torch.Tensor): 1-D tensor of labels for the input data.

        Raises:
            ValueError: If no model (or loss function) is configured.
        """
        if self.model is None or self.loss_fn is None:
            raise ValueError("Model and loss function must be set for GSS-Greedy selection.")

        for cls in torch.unique(y_batch).tolist():
            cls = int(cls)
            rows = (y_batch == cls).nonzero(as_tuple=True)[0].tolist()

            stored = self.buffer[cls]
            candidates = stored + [(x_batch[i], y_batch[i]) for i in rows]

            if len(candidates) <= self.max_per_class:
                # Within the limit: keep everything.
                keep = range(len(candidates))
            else:
                gradients = [self.compute_gradient(x, y) for x, y in candidates]
                keep = self._greedy_select(gradients, self.max_per_class)

            # Newly admitted samples are copied so the buffer neither aliases the
            # caller's batch nor keeps (and later serializes) its whole storage.
            n_stored = len(stored)
            self.buffer[cls] = [
                candidates[i] if i < n_stored else self._own_copy(candidates[i])
                for i in keep
            ]

    def get_all_data(self):
        """
        Get all data from the replay buffer as stacked tensors.

        Returns:
            tuple: ``(xs, ys)`` where ``xs`` stacks every stored input and
            ``ys`` every stored label, or ``(None, None)`` if the buffer
            holds no exemplars.
        """
        xs, ys = [], []
        for examples in self.buffer.values():
            for x, y in examples:
                xs.append(x)
                ys.append(y)
        if not xs:
            return None, None
        return torch.stack(xs), torch.stack(ys)

In [None]:
# ==== 3. Elastic Weight Consolidation (EWC) ====  
class EWC:
    """Elastic Weight Consolidation penalty anchored at one parameter snapshot.

    On construction the current parameters are copied into ``self.params`` and
    a diagonal Fisher estimate is stored in ``self.fisher``; ``penalty`` then
    returns ``(lambda / 2) * sum(F * (theta - theta_star) ** 2)``.
    """

    def __init__(self, model, dataloader, device, samples=500):
        """
        Args:
            model: Network whose forward pass returns ``(features, logits)``.
            dataloader: Iterable of ``(x, y)`` batches used for the Fisher estimate.
            device: Device the batches are moved to.
            samples (int): Stop after at least this many samples were processed.
        """
        self.model = model
        self.device = device
        self.params = {n: p.clone().detach() for n, p in model.named_parameters()}
        self.fisher = self._compute_fisher(dataloader, samples)

    def _compute_fisher(self, dataloader, samples):
        """
        Estimate the diagonal Fisher information from labelled batches.

        For every batch the gradient of the summed negative log-likelihood is
        squared and accumulated; the total is divided by the number of samples
        seen. Parameters that receive no gradient keep a Fisher value of zero,
        and an empty dataloader yields all-zero (not NaN) tensors.

        NOTE(review): the square is taken of the per-batch summed gradient,
        not of per-sample gradients, so the estimate depends on batch size.
        """
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
        self.model.eval()
        count = 0
        for x, y in dataloader:
            x, y = x.to(self.device), y.to(self.device)
            self.model.zero_grad()
            _, logits = self.model(x)
            log_probs = F.log_softmax(logits, dim=1)
            # sum negative log-likelihood over batch
            loss_batch = -log_probs[range(len(y)), y].sum()
            loss_batch.backward()
            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.detach().pow(2)
            count += x.size(0)   # count by number of *samples*
            if count >= samples:
                break
        # Guard against an empty dataloader, which would otherwise give 0/0 = NaN.
        denom = max(count, 1)
        return {n: f / denom for n, f in fisher.items()}

    def penalty(self, model, lambda_ewc):
        """
        Compute the EWC penalty of ``model`` relative to the stored snapshot.

        Parameters unknown to the snapshot are ignored. If a classifier
        parameter (``fc.weight`` / ``fc.bias``) has grown since the snapshot,
        only its first ``old_size`` rows are penalized; any other parameter
        whose shape changed is skipped.

        Returns:
            ``(lambda_ewc / 2)`` times the Fisher-weighted squared distance.
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self.fisher:
                continue
            f, p0 = self.fisher[n], self.params[n]
            if p.shape != p0.shape:
                grown_head = ('fc.weight' in n and p.dim() == 2) or \
                             ('fc.bias' in n and p.dim() == 1)
                if not grown_head:
                    continue
                # only penalize the entries that existed at snapshot time
                p = p[:p0.size(0)]
            loss += (f * (p - p0).pow(2)).sum()
        return (lambda_ewc / 2) * loss

In [None]:
class PerTaskEWC:
    """
    Holds one (parameters, Fisher) snapshot per previously learned task and
    evaluates the sum of the corresponding EWC penalties during training.
    """

    def __init__(self, model, device, ewc_paths: list):
        """
        Args:
          - model (nn.Module):  The current model; its parameter names are
                                matched against the names stored in the files.
          - device:             Device the loaded tensors are moved to.
          - ewc_paths:          Snapshot files, e.g. ['ewc_task_1.pt', 'ewc_task_2.pt', ...].
                                Each file holds a dict with the keys 'params'
                                and 'fisher', both mapping name -> tensor.
        """
        self.device = device
        self.model = model

        self.past_task_params = []   # one dict (name -> anchor tensor) per task
        self.past_task_fishers = []  # one dict (name -> Fisher tensor) per task

        for snapshot_file in ewc_paths:
            # Deserialize on CPU first, then place every tensor on the target device.
            snapshot = torch.load(snapshot_file, map_location='cpu')
            anchors = {}
            for key, tensor in snapshot['params'].items():
                anchors[key] = tensor.to(self.device)
            importances = {}
            for key, tensor in snapshot['fisher'].items():
                importances[key] = tensor.to(self.device)
            self.past_task_params.append(anchors)
            self.past_task_fishers.append(importances)

    def penalty(self, model, lambda_ewc):
        """
        Sum, over every stored task snapshot and every parameter present in
        that snapshot's Fisher dict, the term F * (theta - theta_star)^2.

        A classifier parameter ('fc.weight' / 'fc.bias') that has grown since
        the snapshot is compared on its leading rows only; any other parameter
        whose shape no longer matches is skipped.

        Returns:  (lambda_ewc / 2) * accumulated sum
        """
        accumulated = 0.0

        for anchors, importances in zip(self.past_task_params, self.past_task_fishers):
            for name, current in model.named_parameters():
                # Parameters that did not exist at snapshot time carry no penalty.
                if name not in importances:
                    continue

                anchor = anchors[name]
                importance = importances[name]

                if current.shape != anchor.shape:
                    grown_weight = 'fc.weight' in name and current.dim() == 2
                    grown_bias = 'fc.bias' in name and current.dim() == 1
                    if not (grown_weight or grown_bias):
                        # Some other layer changed shape unexpectedly: skip it.
                        continue
                    # Restrict the comparison to the rows that existed back then.
                    current = current[:anchor.size(0)]

                accumulated += (importance * (current - anchor).pow(2)).sum()

        return (lambda_ewc / 2) * accumulated

In [None]:
class BasicBlock(nn.Module):
    """
    Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The shortcut is the identity unless the stride differs from 1 or the
    channel count changes, in which case a 1x1 conv + batch-norm projection
    is applied to the input.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            self.downsample = nn.Sequential(projection, nn.BatchNorm2d(out_channels))
        else:
            self.downsample = None

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(hidden))
        return F.relu(main + shortcut)

class ResNetSmall(nn.Module):
    """
    Compact ResNet-style classifier for single-channel images.

    Architecture: a 3x3 stem convolution (1 -> 16 channels) with batch norm,
    two stride-2 ``BasicBlock`` stages (16 -> 32 -> 64 channels), global
    average pooling and a linear classifier on the 64-dimensional features.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = BasicBlock(16, 32, stride=2)
        self.layer2 = BasicBlock(32, 64, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.feature_dim = 64
        self.fc = nn.Linear(self.feature_dim, num_classes)

    def extract_features(self, x):
        """
            Extract features from the input tensor.
            Params:
                x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
            Returns:
                torch.Tensor: Extracted features of shape (batch_size, feature_dim).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.avgpool(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """
            Forward pass through the network.
            Params:
                x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
            Returns:
                tuple: (features, logits) where features has shape
                (batch_size, feature_dim) and logits (batch_size, num_classes).
        """
        feats = self.extract_features(x)
        logits = self.fc(feats)
        return feats, logits

    def expand_output(self, new_num_classes):
        """
        Expand the output layer to accommodate new classes.

        The weights and biases of the existing classes are copied into the
        first rows of the new layer; rows for new classes keep the default
        ``nn.Linear`` initialization. When the size is unchanged the existing
        layer is kept as-is, so optimizers holding its parameters stay valid.
        Whenever the layer is actually replaced, any optimizer built earlier
        must be recreated to pick up the new parameters.

        Args:
            new_num_classes (int): The new number of classes for the output layer.

        Raises:
            ValueError: If ``new_num_classes`` is smaller than the current
                number of outputs (the head can only grow).
        """
        old_fc = self.fc
        old_out = old_fc.out_features
        if new_num_classes < old_out:
            raise ValueError(
                f"Cannot shrink the output layer from {old_out} to {new_num_classes} classes."
            )
        if new_num_classes == old_out:
            return

        new_fc = nn.Linear(self.feature_dim, new_num_classes).to(
            device=old_fc.weight.device, dtype=old_fc.weight.dtype
        )
        with torch.no_grad():
            # copy old parameters of FC layer to newly expanded model
            new_fc.weight[:old_out].copy_(old_fc.weight)
            new_fc.bias[:old_out].copy_(old_fc.bias)
        self.fc = new_fc

## Implement GSS greedy

In [None]:
def compute_gradient(model, x, y, loss_fn, device):
    """
    Compute the L2-normalized loss gradient for a single sample.

    The model is evaluated in eval mode; its previous train/eval mode is
    restored afterwards and the parameters' ``.grad`` fields are not modified.

    Args:
        model: Network whose forward pass returns ``(features, logits)``.
        x (torch.Tensor): One input sample without batch dimension.
        y (torch.Tensor | int): Label of the sample; plain Python ints (as
            yielded by many datasets) are accepted as well as tensors.
        loss_fn: Callable ``loss_fn(logits, target)`` returning a scalar loss.
        device: Device the sample is moved to.

    Returns:
        torch.Tensor: Flat gradient over all parameters that receive a
        gradient, divided by its L2 norm (plus 1e-8 for numerical safety).
    """
    was_training = model.training
    model.eval()
    try:
        inputs = x.unsqueeze(0).to(device)
        target = torch.as_tensor(y).unsqueeze(0).to(device)
        _, logits = model(inputs)
        loss = loss_fn(logits, target)
        params = [p for p in model.parameters() if p.requires_grad]
        # autograd.grad hands back the gradients without touching p.grad.
        grads = torch.autograd.grad(loss, params, allow_unused=True)
    finally:
        model.train(was_training)

    flat = torch.cat([g.reshape(-1) for g in grads if g is not None])
    return flat / (torch.norm(flat, p=2) + 1e-8)  # Normalize gradient

def gss_greedy_selection(model, data_loader, num_exemplars_per_class, loss_fn, device):
    """
    GSS-Greedy exemplar selection over a whole dataset.

    For every class id in ``range(data_loader.dataset.num_classes)`` the
    normalized loss gradient of each sample is computed, one sample is drawn
    at random, and further samples are added greedily: each pick is the sample
    whose distance to its nearest already-picked gradient is largest.

    Args:
        model: Network passed to ``compute_gradient``.
        data_loader: Object whose ``dataset`` attribute is iterable as
            ``(x, y)`` pairs and exposes a ``num_classes`` attribute.
        num_exemplars_per_class (int): Upper bound of exemplars kept per class.
        loss_fn: Loss passed to ``compute_gradient``.
        device: Device passed to ``compute_gradient``.

    Returns:
        defaultdict(list): class id -> list of selected ``(x, y)`` samples.
        Classes without samples (or with a zero budget) have no entry.
    """
    model.eval()
    dataset = data_loader.dataset
    exemplars = defaultdict(list)
    for class_id in range(dataset.num_classes):
        # Get class-specific data
        samples = [(x, y) for x, y in dataset if y == class_id]
        if not samples:
            continue

        budget = min(num_exemplars_per_class, len(samples))
        if budget <= 0:
            # Nothing to select: skip the (expensive) gradient computation.
            continue

        gradients = [compute_gradient(model, x, y, loss_fn, device) for x, y in samples]

        # Pick first sample randomly
        chosen = [int(np.random.randint(0, len(samples)))]
        taken = {chosen[0]}
        # nearest[i]: distance from sample i to its closest chosen gradient,
        # maintained as a running minimum so each pick is a single pass.
        nearest = [float("inf")] * len(samples)
        while len(chosen) < budget:
            ref = gradients[chosen[-1]]
            best_idx, best_dist = -1, -1.0
            for i, grad in enumerate(gradients):
                if i in taken:
                    continue
                dist = torch.norm(grad - ref).item()
                if dist < nearest[i]:
                    nearest[i] = dist
                # strict '>' keeps the first candidate on ties
                if nearest[i] > best_dist:
                    best_idx, best_dist = i, nearest[i]
            chosen.append(best_idx)
            taken.add(best_idx)

        # Store selected exemplars
        exemplars[class_id].extend(samples[i] for i in chosen)

    return exemplars

## Implement train function

In [None]:
def compute_fisher(model, dataloader, device, samples=500):
    """
    Compute Fisher information for the model’s parameters.

    For every batch the gradient of the summed negative log-likelihood of the
    given labels is squared and accumulated; the total is divided by the
    number of samples processed. Parameters that receive no gradient keep a
    value of zero, and an empty dataloader yields all-zero tensors instead of
    NaN. The model is left in eval mode.

    NOTE(review): the square is taken of the per-batch summed gradient, not
    of per-sample gradients, so the estimate depends on the batch size.

    Args:
        model: Neural network model whose forward pass returns (features, logits).
        dataloader: DataLoader for the task.
        device: Device for computations.
        samples: Number of samples to use for Fisher computation (the batch
            that crosses this threshold is still counted in full).

    Returns:
        dict: Fisher information for each parameter.
    """
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    count = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        model.zero_grad()
        _, logits = model(x)
        log_probs = F.log_softmax(logits, dim=1)
        loss = -log_probs[range(len(y)), y].sum()
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach().pow(2)
        count += x.size(0)
        if count >= samples:
            break
    # With no data count stays 0; dividing by it would turn every entry into NaN.
    denom = max(count, 1)
    return {n: f / denom for n, f in fisher.items()}

def save_ewc_snapshot(model, dataloader, device, path, samples=500):
    """
    Write the EWC snapshot of a finished task to disk.

    The file holds a dict with two entries: 'params' (a detached copy of every
    named parameter, taken before the Fisher computation) and 'fisher' (the
    result of ``compute_fisher`` on the given dataloader).

    Args:
        model: Neural network model.
        dataloader: DataLoader for the task.
        device: Device for computations.
        path: File path to save the snapshot.
        samples: Number of samples for Fisher computation.
    """
    anchor_params = {}
    for name, tensor in model.named_parameters():
        anchor_params[name] = tensor.detach().clone()
    snapshot = {
        'params': anchor_params,
        'fisher': compute_fisher(model, dataloader, device, samples),
    }
    torch.save(snapshot, path)

def train_and_plot(train_loaders, test_loaders, ood_loader, device, args):
    """
    Train the model with GSS-based replay buffer and PerTaskEWC regularization.

    For every task the classifier head is grown, the model is trained on the
    task data mixed with all replay exemplars (plus the EWC penalty of all
    earlier tasks), the replay buffer is refreshed with GSS-Greedy selection,
    an EWC snapshot is written, and accuracy on all tasks seen so far is
    printed. Model weights and buffer contents are saved after each task.
    Despite its name, this function does not produce any plot.

    Args:
        train_loaders (list): List of DataLoader objects for training tasks.
        test_loaders (list): List of DataLoader objects for testing tasks.
        ood_loader (DataLoader): DataLoader for out-of-distribution data
            (currently unused).
        device (torch.device): Device for computations.
        args: Hyperparameter arguments. Reads ``num_cls_per_task``,
            ``max_per_class``, ``lr``, ``momentum``, ``weight_decay``,
            ``epochs``, ``lambda_ewc`` and ``savedir``.

    Returns:
        tuple: ``(model, accuracies)`` where ``accuracies`` holds the per-task
        accuracies measured after the last task (empty if there were no tasks).
    """
    import os  # local import: not part of the notebook's import cell

    # torch.save does not create missing directories.
    os.makedirs(args.savedir, exist_ok=True)

    # Initialize model and replay buffer
    loss_fn = nn.CrossEntropyLoss(reduction="mean")
    model = ResNetSmall(num_classes=args.num_cls_per_task).to(device)
    replay_buffer = ReplayBuffer(
        max_per_class=args.max_per_class, model=model, loss_fn=loss_fn, device=device
    )

    # Store paths for EWC snapshots
    ewc_paths = []
    num_classes_seen = 0
    accuracies = []

    # Training loop over tasks
    for task_id, train_loader in enumerate(train_loaders):
        print(f"Training on task {task_id + 1}/{len(train_loaders)}")

        # Expand output layer for new classes
        num_classes_seen += args.num_cls_per_task
        model.expand_output(num_classes_seen)
        model.to(device)

        # expand_output replaces model.fc with a new layer, so the optimizer
        # must be (re)built afterwards; otherwise it would keep updating the
        # discarded head and the live classifier would never be trained.
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )

        # Initialize PerTaskEWC for this task (with all past snapshots)
        ewc = PerTaskEWC(model, device, ewc_paths) if ewc_paths else None

        # The buffer only changes after the task has been trained, so the
        # replay tensors are fetched and moved to the device once per task.
        ex_x, ex_y = replay_buffer.get_all_data()
        if ex_x is not None:
            ex_x, ex_y = ex_x.to(device), ex_y.to(device)

        # Train for multiple epochs
        for epoch in range(args.epochs):
            model.train()
            total_loss = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)

                # Add exemplars from replay buffer
                if ex_x is not None:
                    x = torch.cat([x, ex_x], dim=0)
                    y = torch.cat([y, ex_y], dim=0)

                optimizer.zero_grad()
                feats, logits = model(x)
                loss = loss_fn(logits, y)

                # Apply PerTaskEWC penalty
                if ewc is not None:
                    loss += ewc.penalty(model, args.lambda_ewc)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            mean_loss = total_loss / max(len(train_loader), 1)
            print(f"Epoch {epoch + 1}/{args.epochs}, Loss: {mean_loss:.4f}")

        # Update replay buffer with GSS-Greedy.
        # Samples are *stacked* (a new leading sample dimension) rather than
        # concatenated: concatenation would merge the channel dimension of the
        # images and fails on scalar labels.
        xs, ys = [], []
        for x, y in train_loader.dataset:
            xs.append(torch.as_tensor(x))
            ys.append(torch.as_tensor(y))
        if xs:
            x_batch = torch.stack(xs)
            y_batch = torch.stack(ys).reshape(-1)
            replay_buffer.add_examples(x_batch, y_batch)

        # Save EWC snapshot for this task
        ewc_path = f"{args.savedir}/ewc_task_{task_id + 1}.pt"
        save_ewc_snapshot(model, train_loader, device, ewc_path, samples=500)
        ewc_paths.append(ewc_path)

        # Evaluate on all seen tasks
        model.eval()
        accuracies = []
        with torch.no_grad():
            for t, test_loader in enumerate(test_loaders[:task_id + 1]):
                correct, total = 0, 0
                for x, y in test_loader:
                    x, y = x.to(device), y.to(device)
                    _, logits = model(x)
                    preds = torch.argmax(logits, dim=1)
                    correct += (preds == y).sum().item()
                    total += y.size(0)
                acc = correct / total if total else 0.0
                accuracies.append(acc)
                print(f"Task {t + 1} Accuracy: {acc:.4f}")
        avg_acc = sum(accuracies) / len(accuracies) if accuracies else 0
        print(f"Average Accuracy across tasks: {avg_acc:.4f}")

        # Save model and buffer
        torch.save(model.state_dict(), f"{args.savedir}/model_task_{task_id + 1}.pth")
        torch.save(replay_buffer.buffer, f"{args.savedir}/buffer_task_{task_id + 1}.pth")

    return model, accuracies