# iCaRL: Combining an Exemplar Replay Buffer with EWC for Continual Learning

In [4]:
# Import the libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from torchvision.models import resnet34 as torchvision_resnet34

from torch.utils.data import DataLoader, Subset, random_split, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
import datetime
from collections import defaultdict

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU, then
# report the choice. NOTE: the timestamp printed below is the moment this cell
# runs, not the file's modification time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
print(f"Notebook last modified at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

Using device: cuda
Notebook last modified at: 2025-07-27 23:02:52


## 1. Replay buffer with per-class exemplars approximating each class's distribution
- The buffer stores at most a fixed number of exemplars for each class.
- The per-class limit is set by the `max_per_class` constructor argument.
- Every example passed to `add_examples` is appended to its class's list; once a class exceeds its limit, the oldest exemplar is evicted (FIFO).

In [5]:
class ReplayBuffer:
    """
    Per-class replay memory with first-in-first-out (FIFO) replacement.

    Exemplars are kept in ``self.buffer``, a ``defaultdict(list)`` mapping an
    integer class label to the list of stored example tensors for that class.
    """

    def __init__(self, max_per_class=100):
        """
        Initialize the replay buffer with a maximum number of exemplars per class.

        Args:
            max_per_class (int): Maximum number of exemplars to store for each class.
        """
        self.max_per_class = max_per_class
        self.buffer = defaultdict(list)

    def add_examples(self, x_batch, y_batch):
        """
        Add examples to the replay buffer, evicting the oldest ones per class.

        Args:
            x_batch (torch.Tensor): Batch of input data; iterated along dim 0.
            y_batch (torch.Tensor): Labels aligned with ``x_batch``; each element
                must hold a single value (``.item()`` is called on it).
        """
        for x, y in zip(x_batch, y_batch):
            examples = self.buffer[int(y.item())]
            examples.append(x)
            # FIFO replacement: drop the oldest exemplar once the class is over budget.
            if len(examples) > self.max_per_class:
                examples.pop(0)

    def get_all_data(self):
        """
        Get every stored exemplar together with its label.

        Returns:
            tuple: ``(x, y)`` where ``x`` stacks all exemplars along dim 0 and
            ``y`` is a ``torch.long`` tensor of shape ``(N, 1)`` holding the class
            labels, grouped by class in insertion order. Returns ``(None, None)``
            when the buffer holds no exemplars.
        """
        xs, ys = [], []
        for cls, examples in self.buffer.items():
            if not examples:
                # torch.stack cannot handle an empty list (e.g. max_per_class == 0).
                continue
            xs.append(torch.stack(examples))  # all examples of this class
            ys.append(torch.full((len(examples), 1), cls, dtype=torch.long))
        if not xs:
            return None, None
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)
            

## 2. iCaRL exemplar buffer with a total memory budget (`memory_size`)

In [6]:
class ICaRLBuffer:
    """
    Exemplar memory for iCaRL.

    Exemplars for each class are chosen by herding: greedily picking the samples
    whose running feature mean stays closest to the class prototype. They are
    stored in priority order, so truncating a list keeps the most representative
    exemplars.
    """

    def __init__(self, memory_size):
        """
        Initialize the iCaRL buffer with a memory budget.

        Args:
            memory_size (int): Total memory budget for the buffer.
        """
        self.memory_size = memory_size  # maximum buffer size
        self.exemplar_set = defaultdict(list)  # class_id -> list of exemplar tensors
        self.seen_classes = set()  # keep track of seen classes

    def construct_exemplar_set(self, class_id, features, images, m):
        """
        Construct the exemplar set for a given class by herding selection.

        Args:
            class_id (int): The class ID for which to construct the exemplar set.
            features (torch.Tensor): Features of the candidate images, shape (N, D).
            images (torch.Tensor): Candidate images aligned with ``features`` along dim 0.
            m (int): Number of exemplars to select for this class. It is capped at N
                so that every selected exemplar is distinct.

        Side effects:
            Replaces ``self.exemplar_set[class_id]`` with the selected images (moved
            to CPU, in selection order) and adds ``class_id`` to ``self.seen_classes``.
        """
        # Herding works on L2-normalised features; gradients are not needed here.
        features = F.normalize(features.detach(), dim=1)
        # Class prototype mu: the normalised mean of the normalised features.
        class_mean = F.normalize(features.mean(dim=0, keepdim=True), dim=1).squeeze(0)

        num_candidates = features.size(0)
        m = min(int(m), num_candidates)  # cannot pick more distinct exemplars than candidates

        selected = []
        running_sum = torch.zeros_like(class_mean)  # sum of the chosen features, kept on-device
        used = torch.zeros(num_candidates, dtype=torch.bool, device=features.device)

        for k in range(1, m + 1):
            # For unit-norm phi(x):
            #   argmin_x ||k*mu - (running_sum + phi(x))||  ==  argmax_x phi(x) . (k*mu - running_sum)
            scores = features @ (class_mean * k - running_sum)
            scores[used] = float('-inf')  # never pick the same sample twice
            idx = int(torch.argmax(scores))

            selected.append(images[idx].cpu())
            running_sum += features[idx]
            used[idx] = True

        self.exemplar_set[class_id] = selected
        self.seen_classes.add(class_id)

    def reduce_exemplar_sets(self, m_per_class):
        """
        Reduce the exemplar sets to fit within the memory budget.

        Because exemplars are stored in herding order, keeping the prefix keeps
        the highest-priority ones.

        Args:
            m_per_class (int): Number of exemplars to keep per class.
        """
        for cls in self.seen_classes:
            if len(self.exemplar_set[cls]) > m_per_class:
                self.exemplar_set[cls] = self.exemplar_set[cls][:m_per_class]

    def get_all_data(self):
        """
        Get all exemplars in the buffer together with their labels.

        Returns:
            tuple: ``(x, y)`` where ``x`` stacks every exemplar along dim 0 and ``y``
            is a ``torch.long`` tensor of shape ``(N, 1)``. Returns ``(None, None)``
            when the buffer is empty.
        """
        xs, ys = [], []
        for cls, examples in self.exemplar_set.items():
            if examples:
                xs.append(torch.stack(examples))
                ys.append(torch.full((len(examples), 1), cls, dtype=torch.long))
        if not xs:
            return None, None
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)

    def get_all_data_for_task(self, class_list):
        """
        Get the exemplars belonging to a specific set of classes.

        Args:
            class_list (list): List of class IDs for the task. IDs with no stored
                exemplars are skipped.

        Returns:
            tuple: ``(x, y)`` where ``x`` stacks the exemplars along dim 0 and ``y``
            is a ``torch.long`` tensor of shape ``(N,)`` (note: 1-D, unlike
            ``get_all_data``). Returns ``(None, None)`` when nothing matches.
        """
        xs, ys = [], []
        for cls in class_list:
            # .get() avoids inserting empty entries into the defaultdict.
            exemplars = self.exemplar_set.get(cls, [])
            if not exemplars:
                continue

            # Stack the list of tensors into a single tensor of shape [num_exemplars_of_cls, ...]
            xs.append(torch.stack(exemplars))
            # Label tensor of shape [num_exemplars_of_cls] filled with `cls`
            ys.append(torch.full((len(exemplars),), cls, dtype=torch.long))

        if not xs:
            return None, None

        # Concatenate along the batch dimension
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)

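## 3. Small ResNet-style model with 2 residual blocks
- A compact ResNet-style network (not ResNet-34): a 3×3 convolutional stem followed by two residual `BasicBlock`s, designed for single-channel inputs.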


In [None]:
class BasicBlock(nn.Module):
    """
    Residual block: conv3x3-BN-ReLU-conv3x3-BN, added to a shortcut, then ReLU.

    When the block changes the stride or the channel count, the shortcut is a
    1x1 convolution followed by BatchNorm. Otherwise the shortcut is the
    identity and ``downsample`` is None.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the block.
        stride (int): Stride of the first 3x3 convolution and of the projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        needs_projection = stride != 1 or in_channels != out_channels

        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path (None means identity)
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + shortcut)

class ResNetSmall(nn.Module):
    """
    Compact ResNet-style network: a 3x3 stem followed by two residual blocks.

    This is *not* ResNet-34. It is a small feature extractor for single-channel
    inputs (the stem is ``nn.Conv2d(1, 16, ...)``). Each residual block uses
    stride 2, and global average pooling yields a 64-dim embedding that feeds
    a linear classifier head.

    Args:
        num_classes (int): Initial number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = BasicBlock(16, 32, stride=2)
        self.layer2 = BasicBlock(32, 64, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_dim = 64
        self.fc = nn.Linear(self.feature_dim, num_classes)

    def extract_features(self, x):
        """
        Extract the pooled embedding from the input tensor.

        Args:
            x (torch.Tensor): Input of shape (batch_size, 1, height, width).

        Returns:
            torch.Tensor: Features of shape (batch_size, feature_dim).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.avgpool(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input of shape (batch_size, 1, height, width).

        Returns:
            tuple: ``(features, logits)`` with shapes (batch_size, feature_dim)
            and (batch_size, num_classes).
        """
        feats = self.extract_features(x)
        logits = self.fc(feats)
        return feats, logits

    def expand_output(self, new_num_classes):
        """
        Grow the output layer to ``new_num_classes`` outputs, keeping existing rows.

        The weights and biases of the existing classes are copied into the new
        layer. Rows for the new classes keep ``nn.Linear``'s default
        initialisation. The new layer is created on the old layer's device and
        with its dtype. ``self.fc`` is replaced by a new module, so an optimizer
        holding the old parameters must be rebuilt.

        Args:
            new_num_classes (int): The new number of classes for the output layer.

        Raises:
            ValueError: If ``new_num_classes`` is smaller than the current output size.
        """
        old_fc = self.fc
        old_out = old_fc.out_features
        if new_num_classes < old_out:
            raise ValueError(
                f"expand_output cannot shrink the head: {new_num_classes} < {old_out}"
            )
        new_fc = nn.Linear(
            self.feature_dim,
            new_num_classes,
            device=old_fc.weight.device,
            dtype=old_fc.weight.dtype,
        )
        with torch.no_grad():
            # Copy the old FC parameters into the leading rows of the expanded layer.
            new_fc.weight[:old_out].copy_(old_fc.weight)
            new_fc.bias[:old_out].copy_(old_fc.bias)
        self.fc = new_fc

## 4. ResNet-34 (stages of 3, 4, 6 and 3 basic blocks, for 3-channel inputs)

In [9]:
class BasicBlock(nn.Module):
    """
    ResNet basic block (two 3x3 convolutions) with an identity or projection shortcut.

    The shortcut is a 1x1 conv + BatchNorm projection whenever ``stride != 1``
    or the channel count changes; otherwise ``downsample`` is None and the
    input is added back unchanged.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the block.
        stride (int): Stride of the first 3x3 convolution and of the projection.
    """

    expansion = 1  # For BasicBlock, output channels = out_channels * 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        same_shape = stride == 1 and in_channels == out_channels
        self.downsample = None if same_shape else nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        skip = x if self.downsample is None else self.downsample(x)
        return F.relu(residual + skip)

class ResNet34(nn.Module):
    """
    ResNet-34 backbone (stages of 3, 4, 6 and 3 ``BasicBlock``s) with a linear head.

    Expects 3-channel inputs. The stem is a 7x7 stride-2 convolution followed by
    a stride-2 max-pool. ``forward`` returns ``(features, logits)``, where
    ``features`` is the 512-dim globally average-pooled embedding.

    Args:
        num_classes (int): Initial number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Running channel count; _make_layer reads and updates it stage by stage.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 3)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_dim = 512
        self.fc = nn.Linear(self.feature_dim, num_classes)

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one stage: only the first block may change stride/channels."""
        layers = []
        layers.append(BasicBlock(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def extract_features(self, x):
        """
        Compute the pooled embedding.

        Args:
            x (torch.Tensor): Input of shape (batch_size, 3, height, width).

        Returns:
            torch.Tensor: Features of shape (batch_size, feature_dim).
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Return ``(features, logits)`` for a batch of 3-channel images."""
        feats = self.extract_features(x)
        logits = self.fc(feats)
        return feats, logits

    def expand_output(self, new_num_classes):
        """
        Grow the classifier head to ``new_num_classes`` outputs, keeping existing rows.

        The old weights and biases are copied into the leading rows. The new
        layer is created on the old layer's device and with its dtype.
        ``self.fc`` is replaced by a new module, so an optimizer holding the old
        parameters must be rebuilt.

        Args:
            new_num_classes (int): The new number of classes for the output layer.

        Raises:
            ValueError: If ``new_num_classes`` is smaller than the current output size.
        """
        old_fc = self.fc
        old_out = old_fc.out_features
        if new_num_classes < old_out:
            raise ValueError(
                f"expand_output cannot shrink the head: {new_num_classes} < {old_out}"
            )
        new_fc = nn.Linear(
            self.feature_dim,
            new_num_classes,
            device=old_fc.weight.device,
            dtype=old_fc.weight.dtype,
        )
        with torch.no_grad():
            new_fc.weight[:old_out].copy_(old_fc.weight)
            new_fc.bias[:old_out].copy_(old_fc.bias)
        self.fc = new_fc

## 5. Elastic Weight Consolidation (EWC)

In [10]:
# ==== 5. Elastic Weight Consolidation (EWC) ====
class EWC:
    """
    Elastic Weight Consolidation regulariser.

    On construction, snapshots the model's current parameters and estimates a
    diagonal Fisher information term from ``dataloader``. ``penalty`` then
    returns ``lambda_ewc / 2 * sum(F * (theta - theta_old)^2)``.

    Args:
        model (nn.Module): Model whose ``forward`` returns ``(features, logits)``.
        dataloader (iterable): Yields ``(x, y)`` batches. ``y`` may have shape
            ``(N,)`` or ``(N, 1)``.
        device (torch.device): Device the model lives on; ``x`` and ``y`` are moved there.
        samples (int): Stop estimating once at least this many samples were processed.
    """

    def __init__(self, model, dataloader, device, samples=500):
        self.model = model
        self.device = device
        # Parameter snapshot taken at construction (the anchor for the penalty).
        self.params = {n: p.clone().detach() for n, p in model.named_parameters()}
        self.fisher = self._compute_fisher(dataloader, samples)

    def _compute_fisher(self, dataloader, samples):
        """
        Estimate per-parameter Fisher values from squared log-likelihood gradients.

        The model is temporarily switched to eval mode and restored afterwards,
        and its ``.grad`` fields are cleared before returning. If the loader
        yields nothing, all-zero tensors are returned.
        """
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
        was_training = self.model.training
        self.model.eval()
        count = 0
        try:
            for x, y in dataloader:
                x = x.to(self.device)
                # Flatten labels so (N, 1) targets don't broadcast into an (N, N) gather.
                y = y.to(self.device).reshape(-1).long()
                self.model.zero_grad()
                _, logits = self.model(x)

                log_probs = F.log_softmax(logits, dim=1)
                # Sum the negative log-likelihood over the batch.
                # NOTE: squaring the gradient of this batch sum (rather than per-sample
                # gradients) makes the estimate depend on batch size; lambda_ewc
                # absorbs the overall scale.
                rows = torch.arange(y.size(0), device=log_probs.device)
                loss_batch = -log_probs[rows, y].sum()
                loss_batch.backward()
                for n, p in self.model.named_parameters():
                    if p.grad is not None:  # e.g. frozen or unused parameters
                        fisher[n] += p.grad.detach().pow(2)
                count += x.size(0)  # count by number of *samples*
                if count >= samples:
                    break
        finally:
            # Don't leak these gradients into the caller's next optimizer step,
            # and leave the model in the mode we found it in.
            self.model.zero_grad()
            self.model.train(was_training)
        if count == 0:
            return fisher
        return {n: f / count for n, f in fisher.items()}

    def penalty(self, model, lambda_ewc):
        """
        Quadratic EWC penalty for ``model`` relative to the stored snapshot.

        Parameters not present in the snapshot are ignored. If the classifier head
        was expanded since the snapshot (``fc.weight`` / ``fc.bias`` with more
        rows), only the rows that existed at snapshot time are penalised. Any
        other shape mismatch is skipped.

        Args:
            model (nn.Module): Model to regularise.
            lambda_ewc (float): Regularisation strength.

        Returns:
            The scalar penalty (a tensor, or ``0.0`` if nothing was matched).
        """
        loss = 0
        for n, p in model.named_parameters():
            if n not in self.fisher:
                continue
            f, p0 = self.fisher[n], self.params[n]
            if p.shape == p0.shape:
                loss += (f * (p - p0).pow(2)).sum()
            elif ('fc.weight' in n and p.dim() == 2) or ('fc.bias' in n and p.dim() == 1):
                # Expanded head: only the first p0.size(0) rows have reference values.
                loss += (f * (p[:p0.size(0)] - p0).pow(2)).sum()
        return (lambda_ewc / 2) * loss