# ICaRL: The strategy of taking buffer for EWC in Continual Learning

In [4]:
# Import the libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from torchvision.models import resnet34 as torchvision_resnet34

from torch.utils.data import DataLoader, Subset, random_split, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
import datetime
from collections import defaultdict

# Checking status of GPU and time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Notebook last modified at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

Using device: cuda
Notebook last modified at: 2025-07-27 23:02:52


## 1. Replay Buffer with per-class exemplars capturing the distribution of each class
- The buffer is used to store a fixed number of exemplars for each class.
- The number of exemplars per class is determined by the `exemplars_per_class`
  parameter.
- The buffer is updated with new exemplars when the model encounters new classes.   

In [5]:
class ReplayBuffer:
    def __init__(self, max_per_class=100):
        """
        Initialize the replay buffer with a maximum number of exemplars per class.
        
        Args:
            max_per_class (int): Maximum number of exemplars to store for each class.
        """
        self.max_per_class = max_per_class
        self.buffer = defaultdict(list)
        
    def add_examples(self, x_batch, y_batch):
        """
        Add examples to the replay buffer.
        
        Args:
            x_batch (torch.Tensor): Batch of input data.
            y_batch (torch.Tensor): Corresponding labels for the input data.
        """
        for x, y in zip(x_batch, y_batch):
            cls = int(y.item())
            lst = self.buffer[cls]
            lst.append(x)
            # FIFO replacement if the class buffer exceeds the limit
            if len(lst) > self.max_per_class:
                lst.pop(0)
            self.buffer[cls] = lst
    
    def get_all_data(self):
        """
        Get all data from the replay buffer as a TensorDataset.
        
        Returns:
            TensorDataset: A dataset containing all exemplars in the buffer.
        """
        xs, ys = [], []
        for cls, examples in self.buffer.item():
            xs.append(torch.stack(examples)) # Collect all examples for the class
            ys.append(torch.full((len(examples), 1), cls, dtype=torch.long))
        if not xs:
            return None, None
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)
            

## 2. iCaRL buffer with memory budget memory_size

In [6]:
class ICaRLBuffer:
    def __init__(self, memory_size):
        """
        Initialize the iCaRL buffer with a memory budget.
        
        Args:
            memory_size (int): Total memory budget for the buffer.
        """
        self.memory_size = memory_size # maximum buffer size
        self.exemplar_set = defaultdict(list) # class_id -> list of exemplars tensors
        self.seen_classes = set() # keep track of seen classes
    
    def construct_exemplar_set(self, class_id, features, images, m):
        """
        Construct the exemplar set for a given class.
        
        Args:
            class_id (int): The class ID for which to construct the exemplar set.
            features (torch.Tensor): Features of the images.
            images (torch.Tensor): Corresponding images.
            m (int): Number of exemplars to select for this class.
        """
        features = F.normalize(features, dim=1) # normalize features
        class_mean = F.normalize(class_mean.unsqueeze(0), dim=1) # normalize class mean
        
        selected, exemplar_features = [], []
        used_indices = torch.zeros(features.size(0), dtype=torch.bool, device=features.device)
        
        for k in range(m):
            if k == 0:
                current_sum = 0
            else:
                current_sum = torch.stack(exemplar_features).sum(dim=0)
            mu = class_mean.squeeze(0)
            residual = mu * (k + 1) - current_sum
            distances = (features @ residual).squeeze()
            
            # Mask already used indices
            distances[used_indices] = float('-inf')
            idx = torch.argmax(distances).item()
            
            selected.append(images[idx].cpu())
            exemplar_features.append(features[idx].cpu())
            used_indices[idx] = True
            
        self.exemplar_set[class_id] = selected
        self.seen_classes.add(class_id)
    
    def reduce_exemplar_sets(self, m_per_class):
        """
        Reduce the exemplar sets to fit within the memory budget.
        
        Args:
            m_per_class (int): Number of exemplars to keep per class.
        """
        for cls in self.seen_classes:
            if len(self.exemplar_set[cls]) > m_per_class:
                self.exemplar_set[cls] = self.exemplar_set[cls][:m_per_class]
    
    def get_all_data(self):
        """
        Get all data from the iCaRL buffer as a TensorDataset.
        
        Returns:
            TensorDataset: A dataset containing all exemplars in the buffer.
        """
        xs, ys = [], []
        for cls, examples in self.exemplar_set.items():
            if examples:
                xs.append(torch.stack(examples))
                ys.append(torch.full((len(examples), 1), cls, dtype=torch.long))
        if not xs:
            return None, None
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)
    
    def get_all_data_for_task(self, class_list):
        """
        Get all data for a specific task from the iCaRL buffer.
        
        Args:
            class_list (list): List of class IDs for the task.
        
        Returns:
            TensorDataset: A dataset containing all exemplars for the specified classes.
        """
        xs, ys = [], []
        for cls in class_list:
            exemplars = self.exemplar_sets.get(cls, [])
            if len(exemplars) == 0:
                continue
            
            # Stack the list-of‐tensors → a single tensor of shape [num_exemplars_of_cls, ...]
            xs.append(torch.stack(exemplars))
            # Create a label‐tensor of shape [num_exemplars_of_cls] filled with “cls”
            ys.append(torch.full((len(exemplars),), cls, dtype=torch.long))

        if not xs:
            return None, None

        # Concatenate along the “batch” dimension
        all_x = torch.cat(xs, dim=0)
        all_y = torch.cat(ys, dim=0)
        return all_x, all_y

## 3. Small ResNet-34 model with 2 blocks
- The model is a small ResNet-34 architecture with 2 blocks.


In [None]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample:
            identity = self.downsample(x)
        out += identity
        return F.relu(out)

class ResNetSmall(nn.Module):
    """
    Small ResNet-34 model with 2 blocks.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = BasicBlock(16, 32, stride=2)
        self.layer2 = BasicBlock(32, 64, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.feature_dim = 64
        self.fc = nn.Linear(self.feature_dim, num_classes)

    def extract_features(self, x):
        """
            Extract features from the input tensor.
            Params:
                x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
            Returns:
                torch.Tensor: Extracted features of shape (batch_size, feature_dim).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.avgpool(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """
            Forward pass through the network.
            Params:
                x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
        """
        feats = self.extract_features(x)
        logits = self.fc(feats)
        return feats, logits

    def expand_output(self, new_num_classes):
        """
        Expand the output layer to accommodate new classes.
        Args:
            new_num_classes (int): The new number of classes for the output layer.
        """
        old_fc = self.fc
        new_fc = nn.Linear(self.feature_dim, new_num_classes)
        with torch.no_grad():
            # copy old parameters of FC layer to newly expanded model
            new_fc.weight[:old_fc.out_features] = old_fc.weight
            new_fc.bias[:old_fc.out_features] = old_fc.bias
        self.fc = new_fc.to(old_fc.weight.device)

## 4. ResNet34

In [9]:
class BasicBlock(nn.Module):
    expansion = 1  # For BasicBlock, output channels = out_channels * 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample:
            identity = self.downsample(x)
        out += identity
        return F.relu(out)

class ResNet34(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 3)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_dim = 512
        self.fc = nn.Linear(self.feature_dim, num_classes)

    def _make_layer(self, out_channels, blocks, stride=1):
        layers = []
        layers.append(BasicBlock(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def extract_features(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        feats = self.extract_features(x)
        logits = self.fc(feats)
        return feats, logits

    def expand_output(self, new_num_classes):
        old_fc = self.fc
        new_fc = nn.Linear(self.feature_dim, new_num_classes)
        with torch.no_grad():
            new_fc.weight[:old_fc.out_features] = old_fc.weight
            new_fc.bias[:old_fc.out_features] = old_fc.bias
        self.fc = new_fc.to(old_fc.weight.device)

## 5. Elastic Weight Consolidation (EWC)

In [10]:
# ==== 3. Elastic Weight Consolidation (EWC) ====  
class EWC:
    def __init__(self, model, dataloader, device, samples=500):
        self.model = model
        self.device = device
        self.params = {n: p.clone().detach() for n, p in model.named_parameters()}
        self.fisher = self._compute_fisher(dataloader, samples)

    def _compute_fisher(self, dataloader, samples):
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
        self.model.eval()
        count = 0
        for x, y in dataloader:
            x = x.to(self.device)
            self.model.zero_grad()
            feats, logits = self.model(x)
            # prob = F.softmax(logits, dim=1)

            # log_prob = F.log_softmax(logits, dim=1)[range(len(y)), y].mean()
            # log_prob.backward()

            log_probs = F.log_softmax(logits, dim=1)
            # sum negative log‐likelihood over batch
            loss_batch = -log_probs[range(len(y)), y].sum()
            loss_batch.backward()
            for n, p in self.model.named_parameters():
                fisher[n] += p.grad.data.pow(2)
            count += x.size(0)   # count by number of *samples*
            if count >= samples:
                break
        return {n: f / count for n, f in fisher.items()}

    def penalty(self, model, lambda_ewc):
        loss = 0
        for n, p in model.named_parameters():
            if n not in self.fisher: continue
            f, p0 = self.fisher[n], self.params[n]
            if p.shape == p0.shape:
                loss += (f * (p - p0).pow(2)).sum()
            else:
                # assume this is the expanded fc.weight or fc.bias
                # only penalize the first p0.shape[...] entries
                if 'fc.weight' in n and p.dim()==2:
                    loss += (f * (p[:p0.size(0)] - p0).pow(2)).sum()
                elif 'fc.bias' in n and p.dim()==1:
                    loss += (f * (p[:p0.size(0)] - p0).pow(2)).sum()
        return (lambda_ewc / 2) * loss

## 6. Fine-tuning ResNet model and EWC 

In [11]:
class PerTaskEWC:
    """
    Collects multiple (params, fisher) snapshots—one per past task—
    and, at training‐time, computes the sum of all EWC penalties.
    """

    def __init__(self, model, device, ewc_paths: list):
        """
        Args:
          - model (nn.Module):  The “current” model (whose parameter names
                                must match those stored on disk).
          - device:            CPU / CUDA device.
          - ewc_paths:         List of file‐paths: ['ewc_task_1.pt', 'ewc_task_2.pt', ...].
        """
        self.device = device
        self.model = model

        self.past_task_params = []  # list of dict: each dict maps name→tensor (θ^{*(k)})
        self.past_task_fishers = [] # list of dict: each dict maps name→tensor (F^{(k)})

        # Load all saved EWC files:
        for path in ewc_paths:
            data = torch.load(path, map_location='cpu')
            # data['params'] and data['fisher'] are both dict(name→cpu_tensor)
            # Move them to the correct device now:
            params_k = {name: param.to(self.device) for name, param in data['params'].items()}
            fisher_k = {name: fisher.to(self.device) for name, fisher in data['fisher'].items()}
            self.past_task_params.append(params_k)
            self.past_task_fishers.append(fisher_k)


    def penalty(self, model, lambda_ewc):
        """
        Loops over each past task k, then each parameter name,
        and accumulates F^{(k)}_i * (θ_i - θ^{*(k)}_i)^2.

        Returns:  (λ/2) * [sum over tasks & params of F (θ - θ*)^2]
        """
        total_loss = 0.0

        # Iterate over each past‐task snapshot:
        for params_k, fisher_k in zip(self.past_task_params, self.past_task_fishers):
            for name, param in model.named_parameters():
                # If this parameter existed when snapshot_k was taken:
                if name not in fisher_k:
                    continue

                θ_star = params_k[name]      # θ^{*(k)}
                Fk      = fisher_k[name]     # F^{(k)}

                if param.shape == θ_star.shape:
                    total_loss += (Fk * (param - θ_star).pow(2)).sum()
                else:
                    # If some layers were expanded (e.g. classifier head grew),
                    # only penalize the “old” slice [0:θ_star.shape[...]].
                    if 'fc.weight' in name and param.dim() == 2:
                        total_loss += (Fk * (param[:θ_star.size(0)] - θ_star).pow(2)).sum()
                    elif 'fc.bias' in name and param.dim() == 1:
                        total_loss += (Fk * (param[:θ_star.size(0)] - θ_star).pow(2)).sum()
                    # else: if other layers changed shape unexpectely, you may skip them.

        # Multiply by λ/2:
        return (lambda_ewc / 2) * total_loss

# Mahalanobis-distance based Detector

In [12]:
# ==== 4. Mahalanobis-distance based Detector ====  
class MahalanobisDetector:
    def __init__(self):
        self.class_means = {}
        self.precision = None
        self.tau = None

    def fit(self, model, buffer, device, reg_cov=1e-5):
        model.eval()
        X, Y = buffer.get_all_data()
        if X is None:
            return
        X, Y = X.to(device), Y.to(device)
        feats = []
        with torch.no_grad():
            for i in range(0, len(X), 256):
                batch = X[i:i+256].float()
                f, _ = model(batch)
                feats.append(f)
        feats = torch.cat(feats, dim=0)
        for c in torch.unique(Y):
            self.class_means[int(c.item())] = feats[Y==c].mean(dim=0)
        centered = feats - torch.stack([self.class_means[int(y.item())] for y in Y])
        cov = (centered.t() @ centered) / (len(Y)-1)
        cov += reg_cov * torch.eye(cov.size(0)).to(device)
        self.precision = torch.inverse(cov)

    def score(self, model, x, device):
        model.eval()
        with torch.no_grad():
            f, _ = model(x)
        dists = []
        for mu in self.class_means.values():
            diff = f - mu.unsqueeze(0)
            dists.append(torch.sum(diff @ self.precision * diff, dim=1))
        dists = torch.stack(dists, dim=1)
        # print(f"dists={dists}")
        return dists.min(dim=1)[0]

    def detect(self, model, x, device):
        return self.score(model, x, device) > self.tau

    def set_threshold(self, model, buffer, device, false_positive_rate=0.2):
        X_id, _ = buffer.get_all_data()
        if X_id is None:
            return
        d_id = []
        for i in range(0, len(X_id), 256):
            d_id.append(self.score(model, X_id[i:i+256].to(device), device))
        d_id = torch.cat(d_id)
        self.tau = torch.quantile(d_id, 1 - false_positive_rate).item()

# Training with detector gating, replay, EWC and plotting

In [13]:
import os
def train_and_plot(
    train_loaders, test_loaders, ood_loader, device, args
):
    epochs_per_task = args.epochs
    num_tasks = len(train_loaders) 
    total_epochs = epochs_per_task * args.num_tasks 
    history = {t: [np.nan] * total_epochs for t in range(num_tasks)}

    # use pretrained model
    if args.dataset not in ['mnist']:
        model = ResNet34(num_classes=args.num_cls_per_task)
        if args.use_pretrained:
            pretrained = torchvision_resnet34(pretrained=True)
            # Grab their state_dicts
            pre_sd = pretrained.state_dict()
            model_sd  = model.state_dict()
            # Filter out weights that don’t match (e.g. fc.weight, fc.bias)
            filtered_pre_sd = {
                k: v
                for k, v in pre_sd.items()
                if k in model_sd and v.shape == model_sd[k].shape
            }
            # Overwrite matching keys in our model’s dict
            model_sd.update(filtered_pre_sd)
            # Load back into our custom model
            model.load_state_dict(model_sd)
    else:
        model = ResNetSmall(num_classes=args.num_cls_per_task)
    model = model.to(device)

    log_file_path = os.path.join(args.savedir, 'output_log.txt')
    plot_file_path = os.path.join(args.savedir, 'accuracy.png')
    current_classes = 0
    buffer = iCaRLBuffer(memory_size=args.memory_size)  # set e.g., 2000
    detector = MahalanobisDetector()
    ewc = None
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    #–– Prepare a cycling OOD iterator ––
    ood_iter = iter(ood_loader)

    detector.fit(model, buffer, device)
    detector.set_threshold(model, buffer, device, args.fpr_rate)

    for t, base_loader in enumerate(train_loaders):
        # ───────────────────────────────────────────────────────────────────────────
        # (A) Gate / filter new task’s data via Mahalanobis detector (if t > 0)
        if t > 0:
            model.expand_output(current_classes + args.num_cls_per_task)
            current_classes += args.num_cls_per_task
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.momentum, weight_decay=args.weight_decay)
    
            indices = []
            dataset = base_loader.dataset
            for idx in range(len(dataset)):
                x_img, y_img = dataset[idx]
                x_tensor = x_img.unsqueeze(0).to(device)
                if detector.detect(model, x_tensor, device):
                    indices.append(idx)
            with open(log_file_path, "a") as file:
                print(f"New defect found for task {t}: {len(indices)}/{len(dataset)}", file=file)
            inbound = Subset(dataset, indices)
            new_task_dataset = inbound
        else:
            new_task_dataset = base_loader.dataset
    
        # ───────────────────────────────────────────────────────────────────────────
        # (B) Wrap new_task_dataset so that its labels are Tensors, not plain ints.
        new_task_dataset = LabelTransformDataset(new_task_dataset)
    
        #  (C) Build “combined” = {new_task_dataset} ∪ {all exemplars from buffer}
        buf_x, buf_y = buffer.get_all_data()
        if buf_x is not None:
            # Exemplar labels are already torch.LongTensor, since buf_y is a tensor.
            exemplar_dataset = TensorDataset(buf_x, buf_y)
            combined_dataset = ConcatDataset([new_task_dataset, exemplar_dataset])
        else:
            # If buffer is empty, just train on new_task_dataset
            combined_dataset = new_task_dataset
    
        combined_loader = DataLoader(
            combined_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        # -------------------------------------
        # Build a list of previous ewc paths
        # -------------------------------------
        prev_ewc_paths = [f'{args.savedir}/ewc_task_{k}.pt' for k in range(0, t)]
        print(f"len(prev_ewc_paths)={len(prev_ewc_paths)}")
        if len(prev_ewc_paths) > 0:
            per_task_ewc = PerTaskEWC(model, device, prev_ewc_paths)
        else:
            per_task_ewc = None
    
        # ───────────────────────────────────────────────────────────────────────────
        # (D) Now train for all epochs on combined_loader
        for e in range(epochs_per_task):
            global_epoch = t * epochs_per_task + e
            print(f"Global epoch: {global_epoch}")
            model.train()
    
            for idx, (x_batch, y_batch) in enumerate(combined_loader):
                x = x_batch.to(device)
                y = y_batch.to(device)
    
                feats, logits = model(x)
                loss_task = F.cross_entropy(logits, y)
    
                if per_task_ewc is not None:
                    loss_penalty = per_task_ewc.penalty(model, args.lambda_ewc)
                    loss = loss_task + loss_penalty
                else:
                    loss = loss_task

                if t > 0:
                    try:
                        x_ood, _ = next(ood_iter)
                    except StopIteration:
                        ood_iter = iter(ood_loader)
                        x_ood, _ = next(ood_iter)
                    x_ood = x_ood.to(device)
    
                    scores_in  = detector.score(model, x, device)
                    scores_out = detector.score(model, x_ood, device)
                    margin_in  = torch.clamp(scores_in  - detector.tau, min=0.0).mean()
                    margin_out = torch.clamp(detector.tau - scores_out, min=0.0).mean()
                    loss = loss + args.lambda_ood * (margin_in + margin_out)
    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
            # Refit detector & threshold
            detector.fit(model, buffer, device)
            detector.set_threshold(model, buffer, device, args.fpr_rate)
            # Evaluate all seen tasks
            model.eval()
            with torch.no_grad():
                for tt in range(t+1):
                    correct, total = 0, 0
                    for x_test, y_test in test_loaders[tt]:
                        x_test, y_test = x_test.to(device), y_test.to(device)
                        _, logits = model(x_test)
                        preds = logits.argmax(dim=1)
                        correct += (preds == y_test).sum().item()
                        total += y_test.size(0)
                    history[tt][global_epoch] = correct / total
        # Save model after task t
        torch.save(model.state_dict(), f'{args.savedir}/model_task_{t}.pt')
        print(f"Model saved for task {t} as '{args.savedir}/model_task_{t}.pt'")
       
       # ───────────────────────────────────────────────────────────────────────────
        #  (1) Compute “cls_start” and “cls_end” for task t
        #      Example: if num_cls_per_task = 2:
        #         task 0  → cls_start = 0*2 = 0, cls_end = 2 → classes [0,1]
        #         task 1  → cls_start = 1*2 = 2, cls_end = 4 → classes [2,3]
        #         etc.
        num_per_task = args.num_cls_per_task
        cls_start = t * num_per_task
        cls_end   = cls_start + num_per_task  # (exclusive upper bound)
    
        # ───────────────────────────────────────────────────────────────────────────
        #  (2) Build “features_all” & “labels_all” from the UNFILTERED task‐t data
        #      i.e. from base_loader.dataset, which has *all* images of these classes
        features_all, labels_all = [], []
        model.eval()
        with torch.no_grad():
            # We iterate over the *entire* base_loader.dataset to get its features+labels
            for x_batch, y_batch in DataLoader(base_loader.dataset, batch_size=64, shuffle=False):
                x_batch = x_batch.to(device)
                feats, _ = model(x_batch)        # feats.shape = [batch_size, feat_dim]
                features_all.append(feats.cpu()) # store on CPU so we can index easily
                labels_all.append(y_batch)       # y_batch is on CPU by default
    
        features_all = torch.cat(features_all, dim=0)  # shape: [N_task_samples, feat_dim]
        labels_all   = torch.cat(labels_all, dim=0)    # shape: [N_task_samples]
    
        # ───────────────────────────────────────────────────────────────────────────
        #  (3) Compute how many exemplars PER class we should keep
        #      Because the buffer must hold at most memory_size total exemplars,
        #      and we now have seen (t+1)*num_per_task distinct classes.
        num_seen_classes = (t + 1) * num_per_task
        m_per_class = buffer.memory_size // num_seen_classes
    
        # ───────────────────────────────────────────────────────────────────────────
        #  (4) For each *new* class cls in [cls_start, cls_end), pick m_per_class exemplars
        for cls in range(cls_start, cls_end):
            # Boolean mask: selects features_all[i] whose label == cls
            cls_mask = (labels_all == cls)            # shape: [N_task_samples], True for indices of class `cls`
            idxs = torch.nonzero(cls_mask, as_tuple=True)[0].tolist()
            # Now `idxs` indexes INTO base_loader.dataset. Since base_loader.dataset
            # has ALL samples of class `cls`, idxs can never be empty here (unless you truly had 0 samples)
            if len(idxs) == 0:
                # (This check is purely defensive; normally base_loader.dataset must have at least 1 sample per class.)
                print(f"[Warning] No samples found for class {cls} in base_loader.dataset!")
                continue
    
            # Build a list of raw images (tensors) for these indices
            imgs_list = [base_loader.dataset[i][0] for i in idxs]
            # Stack them to a single tensor of shape [N_cls_samples, C, H, W]
            imgs_cls = torch.stack(imgs_list, dim=0)
    
            # Also extract the corresponding feature vectors from features_all
            feats_cls = features_all[cls_mask]       # shape [N_cls_samples, feat_dim]
    
            # Now finally construct the exemplar set for class `cls` via herding
            buffer.construct_exemplar_set(cls, feats_cls, imgs_cls, m=m_per_class)
        # (4b) After we add new exemplars, we must *shrink* older classes’ exemplars
        buffer.reduce_exemplar_sets(m_per_class)
    
        # Update how many classes we have processed so far
        current_classes = num_seen_classes
    
        # ───────────────────────────────────────────────────────────────────────────
        #  (5) Now that `buffer` is non‐empty, rebuild EWC & refit the Mahalanobis detector
        # all_x, all_y = buffer.get_all_data()
        all_x, all_y = buffer.get_all_data_for_task([i for i in range(t * num_per_task, (t+1) * num_per_task)])  # or raw Task t data
        if all_x is not None:
            # (5a) Recompute EWC Fisher/info on the *entire* exemplar set
            dataset_all = TensorDataset(all_x.float(), all_y)
            ewc = EWC(model, DataLoader(dataset_all, batch_size=32, shuffle=True), device)
    
            # (5b) Refit the Mahalanobis detector on all exemplars
            detector.fit(model, buffer, device)
            detector.set_threshold(model, buffer, device, args.fpr_rate)
    
            # (5c) (Optional) Save EWC state to disk
            ewc_data = {
                'params': {k: v.cpu() for k, v in ewc.params.items()},
                'fisher': {k: v.cpu() for k, v in ewc.fisher.items()}
            }
            torch.save(ewc_data, f'{args.savedir}/ewc_task_{t}.pt')
            print(f"EWC saved for task {t} as '{args.savedir}/ewc_task_{t}.pt'")

    sum_acc = 0
    for i in range(len(train_loaders)):
        sum_acc += history[i][global_epoch] * 100
        with open(log_file_path, "a") as file:
            print(f"Final test accuracy for task {i}: {history[i][global_epoch] * 100}", file=file)
    with open(log_file_path, "a") as file:
        print(f"Average accuracy across {len(train_loaders)} tasks: {sum_acc / len(train_loaders)}", file=file)
    # Plot
    plt.figure(figsize=(8,6))
    labels = [f'Task {t}' for t in range(num_tasks)]
    for t in range(num_tasks):
        arr = np.array(history[t])
        idxs = ~np.isnan(arr)
        plt.plot(np.arange(total_epochs)[idxs], arr[idxs], label=labels[t])
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy')
    plt.legend()
    plt.grid(True)
    plt.ylim(0.2, 1.02)
    plt.savefig(plot_file_path)
    plt.show()

In [14]:
from torch.utils.data import DataLoader, ConcatDataset, Dataset, TensorDataset, Subset
class LabelTransformDataset(Dataset):
    """
    Wrap any dataset whose __getitem__(idx) returns (img_tensor, int_label)
    and turn int_label → LongTensor(label).
    """
    def __init__(self, base_dataset):
        self.dataset = base_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        return x, torch.tensor(y, dtype=torch.long)
from torchvision.datasets import CIFAR10, STL10, MNIST, ImageFolder
import itertools
from types import SimpleNamespace
import random
import argparse
import os
# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Define your hyperparameter search space
hyperparams = {
    # 'lr': [0.001, 0.003],
    'lambda_ewc': [0.5],
    'lambda_ood': [1.0],
    # 'batch_size': [128],
    'fpr_rate': [0.1],
    # 'max_per_class': [100],
    'dataset': ['cifar10']
}

# Generate all combinations of hyperparameters
keys, values = zip(*hyperparams.items())
experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]


# Loop through each hyperparameter combination
for i, exp in enumerate(experiments):
    print(f"\n=== Running experiment {i+1}/{len(experiments)}: {exp} ===")
    
    args = SimpleNamespace(
        use_pretrained=True,
        num_tasks=4,
        num_cls_per_task=2,
        lambda_ewc=exp['lambda_ewc'],
        lambda_ood=exp['lambda_ood'],
        fpr_rate=exp['fpr_rate'],
        max_per_class=50,
        lr=0.003,
        momentum=0.9,
        weight_decay=1e-4,
        epochs=100,
        batch_size=128,
        dataset=exp['dataset'],
        datadir='',
        num_ood_cls=2,
        memory_size=1600
    )
    tasks = [list(range(i,i+args.num_cls_per_task)) for i in range(0,args.num_cls_per_task*args.num_tasks,args.num_cls_per_task)]
    if args.dataset != 'mnist':
        train_transform = transforms.Compose([
            # transforms.Resize(size=(224,224)),
            # transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
            transforms.RandomHorizontalFlip(),              
            transforms.RandomApply([
                    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)
                ], p=0.8),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
    else:
        transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    dataset_path = os.path.join(os.getcwd(), args.dataset)
    
    # Check if the subdirectory exists, and create it if it doesn't
    if not os.path.exists(dataset_path):
        os.mkdir(dataset_path)
    
    args.savedir = os.path.join(dataset_path, f'{args.num_tasks}task_{args.num_cls_per_task}pertask_ewc{args.lambda_ewc}_ood{args.lambda_ood}_fpr{args.fpr_rate}_lr{args.lr}_{args.max_per_class}percls_{args.epochs}epochs_batch{args.batch_size}')
    if not os.path.exists(args.savedir):
        os.mkdir(args.savedir)
    
    if args.dataset == 'cifar10':
        train_set = CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        test_set  = CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    elif args.dataset == 'mnist':
        train_set = MNIST(root='./data', train=True, download=True, transform=transform)
        test_set = MNIST(root='./data', train=False, download=True, transform=transform)
    elif args.dataset == 'leather':
        train_set = ImageFolder('/kaggle/input/leather-defect-classification/Leather Defect Classification/train', transform=train_transform)
        test_set = ImageFolder('/kaggle/input/leather-defect-classification/Leather Defect Classification/valid', transform=test_transform)
    elif args.dataset == 'texture':
        train_set = ImageFolder('/kaggle/input/texture-classification/Texture/train', transform=train_transform)
        test_set = ImageFolder('/kaggle/input/texture-classification/Texture/valid', transform=test_transform)
    elif args.dataset == 'neu':
        train_set = ImageFolder('/kaggle/input/neu-surface-defect-database/NEU-DET/train/images', transform=train_transform)
        test_set = ImageFolder('/kaggle/input/neu-surface-defect-database/NEU-DET/validation/images', transform=test_transform)
    elif args.dataset == 'gc10':
        train_set = ImageFolder('/kaggle/input/gc10-det/GC10-DET/train', transform=train_transform)
        test_set = ImageFolder('/kaggle/input/gc10-det/GC10-DET/valid', transform=test_transform)
    elif args.dataset == 'ncat12':
        train_set = ImageFolder('/kaggle/input/ncat12-det/NCAT12-DET/train', transform=train_transform)
        test_set = ImageFolder('/kaggle/input/ncat12-det/NCAT12-DET/valid', transform=test_transform)
    else:
        train_set = ImageFolder(args.datadir + '/train', transform=transform)
        test_set = ImageFolder(args.datadir + '/valid', transform=transform)
    
    def make_loader(dataset, task, batch_size=64, shuffle=True):
        indices = [i for i, (_, y) in enumerate(dataset) if y in task]
        subset = Subset(dataset, indices)
        return DataLoader(subset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)
    
    class_indices = list(range(len(train_set.classes)))
    train_loaders = [make_loader(train_set, task, batch_size=args.batch_size) for task in tasks]
    test_loaders  = [make_loader(test_set, task, batch_size=args.batch_size)  for task in tasks]
    num_ood_cls = args.num_ood_cls if args.num_ood_cls is not None else len(train_set.classes) - args.num_tasks * args.num_cls_per_task
    ood_classes = class_indices[-num_ood_cls:]
    ood_loader = make_loader(train_set, ood_classes, shuffle=False)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"device: {device}")
    # Run training
    train_and_plot(train_loaders, test_loaders, ood_loader, device, args)


=== Running experiment 1/1: {'lambda_ewc': 0.5, 'lambda_ood': 1.0, 'fpr_rate': 0.1, 'dataset': 'cifar10'} ===


100%|██████████| 170M/170M [00:28<00:00, 5.90MB/s] 


device: cuda




Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /home/dikhang_hcmut/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:02<00:00, 42.2MB/s]


NameError: name 'iCaRLBuffer' is not defined