In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
import numpy as np
from copy import deepcopy
from typing import Optional, cast
from tqdm import tqdm
import random

# ==========================================
# 1. Model Definition (CustomDino)
# ==========================================
class CustomDino(nn.Module):
    """Feature-extracting backbone followed by a single linear classifier.

    Args:
        num_classes: Number of output logits produced by the linear head.
        backbone: Module mapping an input batch to a ``[batch, feature_dim]``
            tensor. When ``None``, the pretrained DINO ViT-S/16 is fetched
            through ``torch.hub`` (network access required on first use).

    The classifier input width is read from the backbone's ``embed_dim``
    attribute when it exists and falls back to 384 otherwise, so a custom
    backbone with a different feature size can be plugged in.
    """

    # Feature width used when the backbone does not expose ``embed_dim``.
    DEFAULT_FEATURE_DIM = 384

    def __init__(self, num_classes: int = 100, backbone: Optional[nn.Module] = None):
        super().__init__()
        if backbone is None:
            backbone = cast(nn.Module, torch.hub.load(
                "facebookresearch/dino:main", "dino_vits16", pretrained=True
            ))
        self.backbone: nn.Module = backbone

        # Size the head from the backbone instead of hard-coding 384.
        feature_dim = int(getattr(backbone, "embed_dim", self.DEFAULT_FEATURE_DIM))
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, features)`` for the input batch ``x``."""
        features = self.backbone(x)         # [batch, feature_dim]
        logits = self.classifier(features)  # [batch, num_classes]
        return logits, features

# ==========================================
# 2. iCaRL Logic Class
# ==========================================
class iCaRL:
    """Class-incremental learner: representation update, exemplar memory, NME.

    Args:
        num_classes: Size of the (fixed) classifier output layer.
        memory_size: Total exemplar budget shared by all seen classes.
        feature_dim: Dimensionality of the backbone features.
        device: Device string the model and all batches are moved to.

    Attributes set here:
        model: The trainable ``CustomDino`` network.
        old_model: Frozen copy of ``model`` taken after the previous task
            (``None`` before the first task); used as distillation teacher.
        exemplar_sets: ``exemplar_sets[y]`` is the list of stored images for
            class index ``y``.
        exemplar_means: Cache of per-class prototypes used by
            ``classify_nme``; emptied whenever the model or the exemplar
            sets are changed through the methods of this class.
    """

    def __init__(self, num_classes=100, memory_size=2000, feature_dim=384, device='cuda'):
        self.device = device
        self.num_classes = num_classes
        self.memory_size = memory_size
        self.feature_dim = feature_dim

        # Initialize Model
        self.model = CustomDino(num_classes=num_classes).to(self.device)
        self.old_model = None  # Snapshot of model before current task

        # Memory (Exemplars)
        self.exemplar_sets = []   # List of lists (images per class)
        self.exemplar_means = []  # Cached class prototypes for NME

        # Training Parameters
        self.lr = 0.01
        self.weight_decay = 1e-5
        self.momentum = 0.9
        self.epochs = 20  # Reduced for demo speed (standard is often higher)

    def update_representation(self, train_loader, new_classes):
        """Step 1: train the model with classification + distillation loss.

        Args:
            train_loader: Iterable of ``(images, labels)`` batches. It is
                expected to already mix the new-task data with the stored
                exemplars; this method does not add exemplars itself.
            new_classes: Class indices introduced by the current task. All
                indices below ``min(new_classes)`` are treated as old classes
                for distillation.

        After training, ``old_model`` is replaced by a frozen copy of the
        updated model and the cached NME prototypes are discarded.
        """
        print(f"--- Updating Representation for classes {new_classes} ---")

        optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        # Scheduler helps convergence
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs)

        # Output columns [0, num_old) belong to previously learned classes.
        num_old = min(new_classes) if self.old_model is not None else 0
        distill = self.old_model is not None and num_old > 0

        self.model.train()
        if self.old_model is not None:
            self.old_model.eval()

        for epoch in range(self.epochs):
            for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.epochs}", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                logits, _ = self.model(images)

                # A. Classification loss (cross entropy over all output nodes)
                loss = F.cross_entropy(logits, labels)

                # B. Distillation loss on the old-class outputs only.
                if distill:
                    with torch.no_grad():
                        old_logits, _ = self.old_model(images)
                    soft_targets = torch.sigmoid(old_logits[:, :num_old])
                    # Same quantity as BCE(sigmoid(logits), soft_targets) but
                    # computed from raw logits, which avoids log(0) when the
                    # sigmoid saturates.
                    loss = loss + F.binary_cross_entropy_with_logits(
                        logits[:, :num_old], soft_targets
                    )

                loss.backward()
                optimizer.step()

            scheduler.step()

        # Update the frozen old model
        self.old_model = deepcopy(self.model)
        self.old_model.eval()
        for p in self.old_model.parameters():
            p.requires_grad = False

        # The feature extractor changed, so cached prototypes are stale.
        self.exemplar_means = []

    def reduce_exemplar_sets(self, m):
        """Step 2: keep only the first ``m`` exemplars of every stored class.

        ``m`` is normally ``memory_size // num_classes_seen_so_far``. The
        outer list object is kept (entries are replaced in place).
        """
        print(f"Reducing exemplars to {m} per class...")
        for y, stored in enumerate(self.exemplar_sets):
            self.exemplar_sets[y] = stored[:m]
        self.exemplar_means = []

    def _extract_normalized_features(self, images, batch_size=64):
        """Return L2-normalised features ``[N, D]`` for a list of image tensors.

        Images are pushed through the model in chunks of ``batch_size`` so a
        whole class does not have to fit on the device at once. The caller is
        responsible for putting the model in eval mode.
        """
        chunks = []
        with torch.no_grad():
            for start in range(0, len(images), batch_size):
                batch = torch.stack(images[start:start + batch_size]).to(self.device)
                _, feats = self.model(batch)
                chunks.append(F.normalize(feats, p=2, dim=1))
        return torch.cat(chunks, dim=0)

    def construct_exemplar_sets(self, images, m, transform):
        """Step 3: select exemplars for one new class by herding.

        At step ``k`` (1-based) the image whose feature brings the mean of
        the selected features closest to the class mean is added. Each image
        is selected at most once.

        Args:
            images: Non-empty list of image tensors of a single class.
            m: Target number of exemplars; capped at ``len(images)``.
            transform: Unused; kept for interface compatibility.

        Raises:
            ValueError: If ``images`` is empty.

        Appends the selected list to ``exemplar_sets`` and discards the
        cached NME prototypes.
        """
        print(f"Constructing exemplars (Herding)... target {m} per class")
        if len(images) == 0:
            raise ValueError("construct_exemplar_sets needs at least one image")
        self.model.eval()

        features = self._extract_normalized_features(images)  # [N, D]
        class_mean = features.mean(dim=0)                     # [D]

        num_select = min(m, len(images))
        running_sum = torch.zeros_like(class_mean)
        available = torch.ones(features.size(0), dtype=torch.bool, device=features.device)
        exemplar_set = []

        # k is the size of the exemplar set *after* this pick, so it starts
        # at 1 (dividing by 0 on the first step would yield inf/NaN).
        for k in range(1, num_select + 1):
            # Objective: minimise || class_mean - (running_sum + phi(x)) / k ||
            dists = torch.norm(class_mean - (running_sum + features) / k, dim=1)
            # Exclude already-selected images without touching `features`.
            dists[~available] = float('inf')
            best_idx = int(torch.argmin(dists).item())

            exemplar_set.append(images[best_idx])
            running_sum += features[best_idx]
            available[best_idx] = False

        self.exemplar_sets.append(exemplar_set)
        self.exemplar_means = []

    def _get_exemplar_means(self):
        """Return the per-class prototypes, recomputing them if the cache is stale.

        A prototype is the re-normalised mean of the L2-normalised exemplar
        features of one class. The cache is considered stale when its length
        differs from ``len(exemplar_sets)``; the mutating methods of this
        class also empty it explicitly. Code that edits ``exemplar_sets`` or
        ``model`` directly must reset ``exemplar_means`` to ``[]`` itself.
        """
        if len(self.exemplar_means) != len(self.exemplar_sets):
            means = []
            for stored in self.exemplar_sets:
                feats = self._extract_normalized_features(stored)
                means.append(F.normalize(feats.mean(dim=0), p=2, dim=0))
            self.exemplar_means = means
        return self.exemplar_means

    def classify_nme(self, x):
        """Step 4: classify a batch by Nearest Mean of Exemplars.

        Args:
            x: Batch of images; moved to ``self.device`` internally.

        Returns:
            Tensor of predicted class indices on ``self.device``, one per
            row of ``x``. With no exemplars stored, all predictions are 0.
        """
        self.model.eval()
        if len(self.exemplar_sets) == 0:
            # Same dtype/device as the normal path so callers can compare
            # the result against label tensors.
            return torch.zeros(x.size(0), dtype=torch.long, device=self.device)

        with torch.no_grad():
            _, query_features = self.model(x.to(self.device))
            query_features = F.normalize(query_features, p=2, dim=1)

            means = torch.stack(self._get_exemplar_means())  # [classes_seen, D]

            # y* = argmin_y || phi(x) - mu_y ||
            dists = torch.cdist(query_features, means)       # [batch, classes_seen]
            preds = torch.argmin(dists, dim=1)

        return preds

# ==========================================
# 3. Data Utilities
# ==========================================
class iCaRLDataset(Dataset):
    """Dataset that concatenates new-task samples with stored exemplars.

    Args:
        new_data: Sequence of ``(image, label)`` pairs for the current task.
            Any iterable is accepted; it is materialised into a list.
        exemplars: List of per-class image lists. The position of each inner
            list is used as the label of its images.
        transform: Optional callable applied to the image in ``__getitem__``.
            With the default ``None`` images are returned exactly as stored.
    """

    def __init__(self, new_data, exemplars, transform=None):
        self.new_data = new_data    # (image, label) pairs of the new task
        self.exemplars = exemplars  # List of lists of images
        self.transform = transform

        # Flatten exemplars into (img, label) pairs; label = class position.
        self.exemplar_data = [
            (img, label)
            for label, img_list in enumerate(exemplars)
            for img in img_list
        ]

        # list(...) lets callers pass tuples or other sequences, not only lists.
        self.all_data = list(self.new_data) + self.exemplar_data

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``, transformed if configured."""
        img, label = self.all_data[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of new samples plus exemplars."""
        return len(self.all_data)

def get_data_for_classes(dataset, classes):
    """Return a ``Subset`` of ``dataset`` holding only samples of ``classes``.

    Args:
        dataset: Object exposing a ``targets`` sequence with one label per
            sample (plain ints, NumPy scalars or 0-d tensors).
        classes: Iterable of class labels to keep; consumed exactly once.

    Returns:
        ``Subset(dataset, indices)`` with the matching indices in their
        original order.
    """
    def as_scalar(value):
        # Tensor / NumPy scalars expose .item(); unwrap them so hashing
        # compares by value instead of by object identity.
        return value.item() if hasattr(value, "item") else value

    # Build the lookup once: O(1) membership per sample, and safe even when
    # `classes` is a one-shot iterator.
    wanted = {as_scalar(c) for c in classes}
    indices = [
        i for i, label in enumerate(dataset.targets)
        if as_scalar(label) in wanted
    ]
    return Subset(dataset, indices)

def extract_images_from_subset(subset):
    """Collect the image component of every ``(image, label)`` item in ``subset``.

    Items are read one by one through indexing, which is slow for large
    subsets but keeps the helper independent of the dataset internals.
    """
    count = len(subset)
    return [subset[position][0] for position in range(count)]

# ==========================================
# 4. Main Experiment Loop
# ==========================================
def main():
    """Run the class-incremental CIFAR-100 experiment and print NME accuracy.

    The 100 classes are split into ``TASKS`` equally sized tasks. For each
    task the model is trained on the new classes plus stored exemplars, the
    exemplar memory is rebalanced, and NME accuracy over all classes seen so
    far is measured on the test split.
    """
    print("Preparing Data...")
    # Per-channel mean / std used for input normalisation.
    stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    transform = Compose([
        Resize(256), CenterCrop(224),
        ToTensor(),
        Normalize(*stats),
    ])

    # Load FULL Datasets
    train_ds = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    test_ds = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

    # Fall back to CPU instead of crashing when no GPU is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize iCaRL
    icarl = iCaRL(num_classes=100, memory_size=2000, device=device)

    # Define Tasks (e.g., 5 tasks of 20 classes each)
    TASKS = 5
    CLASSES_PER_TASK = 100 // TASKS

    accuracies = []

    for task_id in range(TASKS):
        # 1. Define Classes for this Task
        start_class = task_id * CLASSES_PER_TASK
        end_class = (task_id + 1) * CLASSES_PER_TASK
        new_classes = list(range(start_class, end_class))

        print(f"\n================ TASK {task_id+1}/{TASKS} : Classes {new_classes} ================")

        # 2. Prepare Training Data (New Data + Exemplars)
        task_data_subset = get_data_for_classes(train_ds, new_classes)

        # Materialise the transformed samples once so every epoch reuses
        # them (RAM intensive, but avoids re-running the transform pipeline).
        new_data_list = [task_data_subset[i] for i in range(len(task_data_subset))]

        # Hybrid dataset: new-task samples followed by stored exemplars.
        train_dataset = iCaRLDataset(new_data_list, icarl.exemplar_sets)
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

        # 3. Train (Update Representation)
        icarl.update_representation(train_loader, new_classes)

        # 4. Exemplar Management
        # A. Shrink old sets so all seen classes share the memory budget.
        m = icarl.memory_size // end_class
        icarl.reduce_exemplar_sets(m)

        # B. Construct new sets, one class at a time and in class order so
        #    that the position in exemplar_sets equals the class index.
        for c in new_classes:
            class_subset = get_data_for_classes(train_ds, [c])
            images_c = extract_images_from_subset(class_subset)
            icarl.construct_exemplar_sets(images_c, m, transform)

        # 5. Evaluate on ALL classes seen so far
        print("Evaluating...")
        seen_classes = list(range(0, end_class))
        test_subset = get_data_for_classes(test_ds, seen_classes)
        test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)

        correct = 0
        total = 0
        for imgs, lbls in tqdm(test_loader):
            imgs = imgs.to(icarl.device)
            lbls = lbls.to(icarl.device)
            preds = icarl.classify_nme(imgs)
            correct += preds.eq(lbls).sum().item()
            total += lbls.size(0)

        acc = 100. * correct / total
        accuracies.append(acc)
        print(f"Task {task_id+1} Accuracy (NME): {acc:.2f}%")

    print("\nFinal Accuracies per Task:", accuracies)

# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

100%|██████████| 169M/169M [00:03<00:00, 42.9MB/s]


Initializing CustomDino on cuda...
Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth


100%|██████████| 82.7M/82.7M [00:00<00:00, 303MB/s]
Epoch 1/50: 100%|██████████| 782/782 [08:29<00:00,  1.54it/s]


Epoch 1 | Loss: 7.7707 | Softmax Acc: 1.09% | LR: 0.00999


Epoch 2/50:  63%|██████▎   | 489/782 [05:20<03:11,  1.53it/s]


KeyboardInterrupt: 