# iCaRL: An exemplar-buffer selection strategy for EWC in continual learning

In [None]:
# Import the libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from torchvision.models import resnet34 as torchvision_resnet34

from torch.utils.data import DataLoader, Subset, random_split, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
import datetime
from collections import defaultdict

# Select the compute device (CUDA GPU when available, otherwise CPU),
# then report the chosen device and the current wall-clock time.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
print(f"Notebook last modified at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

Using device: cuda
Notebook last modified at: 2025-07-27 22:38:25


## 1. Replay Buffer with per-class exemplars capturing the distribution of each class
- The buffer stores up to a fixed number of exemplars for each class.
- The per-class limit is set by the `max_per_class` parameter.
- The buffer is updated whenever new examples arrive (including examples of newly
  encountered classes); once a class exceeds its limit, its oldest exemplar is
  dropped (FIFO).

In [1]:
class ReplayBuffer:
    """Per-class FIFO store of exemplar tensors.

    Each class label maps to a list of at most ``max_per_class`` examples;
    when a class overflows, its oldest example is discarded.
    """

    def __init__(self, max_per_class=100):
        """
        Initialize the replay buffer with a maximum number of exemplars per class.

        Args:
            max_per_class (int): Maximum number of exemplars to store for each class.
        """
        self.max_per_class = max_per_class
        # class_id -> list of example tensors, oldest first
        self.buffer = defaultdict(list)

    def add_examples(self, x_batch, y_batch):
        """
        Add examples to the replay buffer.

        Examples are paired element-wise with their labels and appended to
        the list of the matching class. The stored tensors are the items
        yielded by iterating ``x_batch``; they are not copied or moved.

        Args:
            x_batch (torch.Tensor): Batch of input data, iterated along dim 0.
            y_batch (torch.Tensor): Corresponding labels for the input data.
        """
        for x, y in zip(x_batch, y_batch):
            cls = int(y.item())
            exemplars = self.buffer[cls]
            exemplars.append(x)
            # FIFO replacement if the class buffer exceeds the limit
            if len(exemplars) > self.max_per_class:
                exemplars.pop(0)

    def get_all_data(self):
        """
        Get all data from the replay buffer as a pair of tensors.

        Returns:
            tuple: ``(xs, ys)`` where ``xs`` stacks every stored example along
            a new leading dimension and ``ys`` is a ``torch.long`` tensor of
            shape ``(num_examples, 1)`` holding the class id of each row.
            Returns ``(None, None)`` when the buffer holds no examples.
        """
        xs, ys = [], []
        for cls, examples in self.buffer.items():
            # A class list can be empty (e.g. max_per_class == 0);
            # torch.stack would raise on an empty list, so skip it.
            if not examples:
                continue
            xs.append(torch.stack(examples))  # Collect all examples for the class
            ys.append(torch.full((len(examples), 1), cls, dtype=torch.long))
        if not xs:
            return None, None
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)
            

## 2. iCaRL buffer with a total memory budget of `memory_size`

In [2]:
class ICaRLBuffer:
    """Exemplar memory that selects per-class exemplars by herding.

    Exemplars are kept per class in selection order, so truncating a class
    list keeps the exemplars that were picked first.
    """

    def __init__(self, memory_size):
        """
        Initialize the iCaRL buffer with a memory budget.

        Args:
            memory_size (int): Total memory budget for the buffer. It is only
                stored here; callers decide the per-class quota passed to
                ``construct_exemplar_set`` / ``reduce_exemplar_sets``.
        """
        self.memory_size = memory_size  # maximum buffer size
        self.exemplar_set = defaultdict(list)  # class_id -> list of exemplar tensors
        self.seen_classes = set()  # keep track of seen classes

    def construct_exemplar_set(self, class_id, features, images, m):
        """
        Construct the exemplar set for a given class.

        Greedily picks samples whose L2-normalized features have the largest
        dot product with the residual ``(k + 1) * mean - sum(selected)``,
        so the selected features track the normalized class mean. Any
        existing exemplar list for ``class_id`` is replaced.

        Args:
            class_id (int): The class ID for which to construct the exemplar set.
            features (torch.Tensor): Features of the images, one row per
                sample (2-D, samples along dim 0).
            images (torch.Tensor): Corresponding images, indexed along dim 0
                in the same order as ``features``.
            m (int): Number of exemplars to select for this class. Values
                larger than the number of samples are clamped so that no
                sample is selected twice.
        """
        # Selection needs no gradients; detach so no autograd graph is retained.
        features = F.normalize(features.detach(), dim=1)  # normalize features
        # Class mean of the normalized features, itself re-normalized.
        class_mean = F.normalize(features.mean(dim=0), dim=0)

        num_samples = features.size(0)
        m = max(0, min(m, num_samples))

        selected = []
        used_indices = torch.zeros(num_samples, dtype=torch.bool, device=features.device)
        # Running sum of the selected features, kept on the same device as
        # `features` so the residual never mixes CPU and GPU tensors.
        selected_sum = torch.zeros_like(class_mean)

        for k in range(m):
            residual = class_mean * (k + 1) - selected_sum
            scores = features @ residual  # shape: (num_samples,)

            # Mask already used indices
            scores[used_indices] = float('-inf')
            idx = int(torch.argmax(scores).item())

            selected.append(images[idx].cpu())
            selected_sum += features[idx]
            used_indices[idx] = True

        self.exemplar_set[class_id] = selected
        self.seen_classes.add(class_id)

    def reduce_exemplar_sets(self, m_per_class):
        """
        Reduce the exemplar sets to fit within the memory budget.

        Each seen class keeps only its first ``m_per_class`` exemplars.

        Args:
            m_per_class (int): Number of exemplars to keep per class.
        """
        for cls in self.seen_classes:
            if len(self.exemplar_set[cls]) > m_per_class:
                self.exemplar_set[cls] = self.exemplar_set[cls][:m_per_class]

    def get_all_data(self):
        """
        Get all data from the iCaRL buffer as a pair of tensors.

        Returns:
            tuple: ``(xs, ys)`` where ``xs`` stacks every exemplar along a new
            leading dimension and ``ys`` is a ``torch.long`` tensor of shape
            ``(num_exemplars, 1)``. Returns ``(None, None)`` when the buffer
            holds no exemplars.
        """
        xs, ys = [], []
        for cls, examples in self.exemplar_set.items():
            if examples:
                xs.append(torch.stack(examples))
                ys.append(torch.full((len(examples), 1), cls, dtype=torch.long))
        if not xs:
            return None, None
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)

    def get_all_data_for_task(self, class_list):
        """
        Get all data for a specific task from the iCaRL buffer.

        Args:
            class_list (list): List of class IDs for the task. Classes with
                no stored exemplars are skipped.

        Returns:
            tuple: ``(all_x, all_y)`` where ``all_x`` stacks the exemplars of
            the requested classes and ``all_y`` is a 1-D ``torch.long`` label
            tensor of shape ``(num_exemplars,)`` (note: unlike
            ``get_all_data``, labels here have no trailing dimension).
            Returns ``(None, None)`` when none of the classes has exemplars.
        """
        xs, ys = [], []
        for cls in class_list:
            # .get() avoids creating empty defaultdict entries for unknown classes
            exemplars = self.exemplar_set.get(cls, [])
            if not exemplars:
                continue

            # Stack the list of tensors into one tensor of shape [num_exemplars_of_cls, ...]
            xs.append(torch.stack(exemplars))
            # Create a label tensor of shape [num_exemplars_of_cls] filled with cls
            ys.append(torch.full((len(exemplars),), cls, dtype=torch.long))

        if not xs:
            return None, None

        # Concatenate along the batch dimension
        all_x = torch.cat(xs, dim=0)
        all_y = torch.cat(ys, dim=0)
        return all_x, all_y

## 3. Small ResNet-34 model with 2 blocks
- The model is a small ResNet-34 architecture with 2 blocks.
