## ENVIRONMENT SETUP (Kaggle)

In [None]:
# ============================================================
# Clears workspace, clones LabelBench, installs dependencies
# ============================================================

!rm -rf /kaggle/working/*
!git clone https://github.com/EfficientTraining/LabelBench.git
!pip install -r /kaggle/working/LabelBench/requirements.txt
%cd /kaggle/working/LabelBench



In [2]:
# List the current working directory; the preceding cell ends by changing
# into the cloned LabelBench repository, so this shows the repo's top level.
ls

[0m[01;34mconfigs[0m/        [01;34mLabelBench[0m/  mp_eval_launcher.py  README.md
[01;34mdocs[0m/           LICENSE      mp_launcher.py       requirements.txt
example_run.sh  main.py      point_evaluation.py  [01;34mresults[0m/


## DATASET SETUP: CIFAR-10 STREAM FOR OPEN-WORLD LEARNING

In [3]:
import torch
import torch.nn.functional as F
from torchvision import datasets as tv_datasets, transforms
from torch.utils.data import Dataset
from LabelBench.skeleton.dataset_skeleton import register_dataset, LabelType, TransformDataset
import numpy as np

# Total number of incremental tasks. The shuffled CIFAR-10 training indices are
# cut into this many chunks, and one dataset "splitcifar10_<id>" is registered
# per task id in range(NUM_TASKS).
NUM_TASKS = 20   # Total number of incremental tasks


# ------------------------------------------------------------
# Dataset wrapper to expose only a subset of CIFAR-10 indices
# ------------------------------------------------------------
class CIFARStream(Dataset):
    """View over a base dataset restricted to a given sequence of indices.

    Item ``i`` of the stream is item ``indices[i]`` of ``base_ds``; the base
    dataset is expected to yield ``(sample, target)`` pairs.
    """

    def __init__(self, base_ds, indices):
        # Underlying dataset that actually holds the samples
        self.base_ds = base_ds
        # Positions in ``base_ds`` that belong to this stream, in stream order
        self.indices = indices

    def __len__(self):
        # The stream is exactly as long as its index list
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the stream position into a position in the base dataset
        source_position = self.indices[idx]
        sample, target = self.base_ds[source_position]
        return sample, target


# ------------------------------------------------------------
# One-hot encoding helper (required by LabelBench)
# ------------------------------------------------------------
def one_hot(y, n=10):
    """Return the float one-hot encoding of class id(s) ``y`` over ``n`` classes.

    Args:
        y: Integer class id (or anything ``torch.tensor`` turns into an
            integer tensor of class ids).
        n: Number of classes, i.e. the length of the one-hot axis.

    Returns:
        A float tensor with a trailing axis of size ``n``.
    """
    class_ids = torch.tensor(y)
    encoded = F.one_hot(class_ids, num_classes=n)
    return encoded.float()


# ------------------------------------------------------------
# Base dataset registry (intentionally disabled)
# ------------------------------------------------------------
@register_dataset("splitcifar10", LabelType.MULTI_CLASS)
def get_splitcifar10(_):
    """Placeholder entry for the un-suffixed dataset name; always raises.

    Only the per-task names ``splitcifar10_<id>`` are usable, so asking for
    the bare name fails loudly instead of returning some arbitrary split.

    Raises:
        RuntimeError: Always.
    """
    hint = "Use splitcifar10_<id>"
    raise RuntimeError(hint)


# ------------------------------------------------------------
# Build task stream ONCE (randomized split of CIFAR-10)
# ------------------------------------------------------------
# Seed for the task stream. Fixing it makes the assignment of samples to tasks
# identical after every "Restart Kernel -> Run All"; the original unseeded
# np.random.shuffle produced a different stream on each run.
STREAM_SEED = 0

# Downloaded once here only to learn the number of training samples
base_train_global = tv_datasets.CIFAR10(root="./data", train=True, download=True)

# Random permutation of all training indices, drawn from a dedicated generator
# so that other uses of NumPy's global RNG cannot perturb the split.
all_idx = np.random.default_rng(STREAM_SEED).permutation(len(base_train_global))

# NUM_TASKS near-equal chunks; chunk i is the index set served to task i
# (chunk 0 is not used by the task-0 dataset, which selects by class instead).
stream_splits = np.array_split(all_idx, NUM_TASKS)


# ------------------------------------------------------------
# Register each task dynamically
# ------------------------------------------------------------
for split_id in range(NUM_TASKS):

    @register_dataset(f"splitcifar10_{split_id}", LabelType.MULTI_CLASS)
    def _make_split(data_dir, split_id=split_id):
        """Build the datasets for one task of the CIFAR-10 stream.

        Args:
            data_dir: Directory where torchvision stores / finds CIFAR-10.
            split_id: Task id. Bound through a default argument so that every
                registered function keeps its own id instead of sharing the
                loop variable's final value.

        Returns:
            An 8-tuple ``(train_ds, test_ds, test_ds, None, None, None, 10,
            class_names)``. The same test set fills the second and third
            slots. NOTE(review): the slot meanings are presumably those of
            LabelBench's dataset registry; confirm against dataset_skeleton.
        """
        tf = transforms.Compose([transforms.ToTensor()])

        base_train = tv_datasets.CIFAR10(root=data_dir, train=True, download=True)
        base_test  = tv_datasets.CIFAR10(root=data_dir, train=False, download=True)

        # ----------------------------------------------------
        # TASK 0: Base session (only classes 0 and 1)
        # Fully supervised
        # ----------------------------------------------------
        if split_id == 0:
            # Read the labels straight from `targets`: iterating the dataset
            # itself would decode all 50k images just to look at their label.
            indices = [i for i, y in enumerate(base_train.targets) if y in (0, 1)]
        else:
            # ------------------------------------------------
            # TASKS 1–19: Unlabeled streaming data
            # (labels are still attached to the samples; whether they are
            # used is up to the consumer of this dataset)
            # ------------------------------------------------
            indices = stream_splits[split_id]

        train_ds = CIFARStream(base_train, indices)

        # Wrap dataset with transforms + one-hot labels (one_hot defaults to 10 classes)
        train_ds = TransformDataset(
            train_ds,
            transform=tf,
            target_transform=one_hot
        )

        test_ds = TransformDataset(
            base_test,
            transform=tf,
            target_transform=one_hot
        )

        return train_ds, test_ds, test_ds, None, None, None, 10, [str(i) for i in range(10)]

100%|██████████| 170M/170M [00:01<00:00, 101MB/s]


## Model Training with Novelty Detection and Clustering

In [4]:
# ============================================================
# MODEL DEFINITION (ResNet18 + Embedding Head)
# ============================================================

import torch.nn as nn
from collections import defaultdict, Counter
from torch.utils.data import DataLoader
import hdbscan
from torchvision.models import resnet18
from LabelBench.skeleton.dataset_skeleton import datasets as DATASET_REGISTRY


class CNN(nn.Module):
    """ResNet-18 encoder followed by a 128-d unit-norm embedding and a linear head.

    ``forward`` returns both the class logits and the normalized embedding,
    and the head can be widened in place with :meth:`expand_head`.
    """

    def __init__(self, num_classes):
        """
        Args:
            num_classes: Number of outputs of the initial classification head.
        """
        super().__init__()

        # Pretrained ResNet-18 backbone
        base = resnet18(pretrained=True)

        # Remove classification head → keep feature extractor
        # (all child modules except the last one)
        self.encoder = nn.Sequential(*list(base.children())[:-1])

        # Project features to a compact embedding space
        self.embed = nn.Linear(512, 128)

        # Classification head (dynamically expanded later)
        self.classifier = nn.Linear(128, num_classes)

    def expand_head(self, new_classes):
        """Replace the head with one that has ``new_classes`` outputs.

        The rows belonging to the existing classes are copied over, the added
        rows keep their fresh ``nn.Linear`` initialization. The new head is
        created on the same device and with the same dtype as the old one, so
        the model stays usable without a follow-up ``.to(device)``.

        Args:
            new_classes: Total number of outputs after expansion.

        Raises:
            ValueError: If ``new_classes`` is smaller than the current number
                of outputs (the old weights would not fit).
        """
        old_head = self.classifier
        old_n = old_head.out_features

        if new_classes < old_n:
            raise ValueError(
                f"expand_head cannot shrink the classifier from {old_n} to {new_classes} outputs"
            )

        new_head = nn.Linear(old_head.in_features, new_classes).to(
            device=old_head.weight.device, dtype=old_head.weight.dtype
        )

        # Copy old weights into the new head without recording the copy in autograd
        with torch.no_grad():
            new_head.weight[:old_n].copy_(old_head.weight)
            new_head.bias[:old_n].copy_(old_head.bias)

        self.classifier = new_head

    def forward(self, x):
        """Return ``(logits, z)`` where ``z`` is the L2-normalized embedding."""
        # Flatten everything after the batch axis. A bare .squeeze() would
        # also drop the batch axis for a batch of one and break dim=1 below.
        z = torch.flatten(self.encoder(x), 1)

        # Project to embedding
        z = self.embed(z)

        # Normalize embeddings (important for cosine geometry)
        z = F.normalize(z, dim=1)

        # Classification logits
        logits = self.classifier(z)

        return logits, z


# ============================================================
# MEMORY BUFFER (CLASS PROTOTYPES)
# ============================================================

class MemoryBuffer:
    """Per-class store of embeddings, capped at ``max_per_class`` entries each.

    When a class exceeds its cap, only the embeddings with the smallest
    cosine distance to the class's mean direction are kept.
    """

    def __init__(self, max_per_class=20):
        # Stores embeddings per class: {class id -> list of 1-D CPU tensors}
        self.data = defaultdict(list)
        self.max_per_class = max_per_class

    def add_batch(self, Z, y):
        """Add the embeddings ``Z`` to class ``y`` and enforce the size cap.

        Args:
            Z: Tensor of embeddings, one per row. A single 1-D embedding is
                accepted and treated as a batch of one.
            y: Class id (anything ``int()`` accepts).
        """
        label = int(y)

        # A lone embedding would otherwise be iterated element by element and
        # stored as a pile of scalars.
        if Z.dim() == 1:
            Z = Z.unsqueeze(0)

        # Nothing to add: return before touching the defaultdict, so no empty
        # class entry is created (consumers stack each class's list).
        if Z.shape[0] == 0:
            return

        # Add embeddings to memory
        self.data[label].extend(z.detach().cpu() for z in Z)

        # Reduce memory to fixed size
        self._reduce_class(label)

    def _reduce_class(self, y):
        """Trim class ``y`` to the ``max_per_class`` embeddings nearest its centroid."""
        Z = self.data[int(y)]

        # Do nothing if under capacity
        if len(Z) <= self.max_per_class:
            return

        Z = torch.stack(Z)

        # Compute class centroid (direction)
        mu = F.normalize(Z.mean(0), dim=0)

        # Cosine distance to centroid (a true cosine distance only if the
        # stored embeddings are unit-norm)
        d = 1 - torch.matmul(Z, mu)

        # Keep closest samples to centroid
        idx = torch.argsort(d)[:self.max_per_class]

        self.data[int(y)] = [Z[i] for i in idx]

    def get(self):
        """Return the underlying ``{class id -> list of embeddings}`` mapping (not a copy)."""
        return self.data


# ============================================================
# NOVELTY DETECTOR (HYPERSPHERE PER CLASS)
# ============================================================

class HypersphereNovelty:
    """Novelty scorer that models each known class as a centroid plus a radius.

    The centroid is the normalized mean of the class's stored embeddings and
    the radius is the ``q``-quantile of their ``1 - dot`` distances to it.
    """

    def __init__(self, q=0.95):
        self.q = q    # Quantile used to define class radius
        self.mu = {}  # Class centroids
        self.r = {}   # Class radii

    def update(self, memory):
        """Recompute centroid and radius for every class in ``memory``.

        Args:
            memory: Mapping ``{class id -> list of 1-D embedding tensors}``.
                Classes with an empty list are skipped, because there is
                nothing to stack or take a quantile of.
        """
        # Recompute centroid and radius for each known class
        self.mu, self.r = {}, {}

        for k, Z in memory.items():
            if len(Z) == 0:
                continue

            Z = torch.stack(Z)
            mu = F.normalize(Z.mean(0), dim=0)

            # Distance of samples from centroid
            d = 1 - torch.matmul(Z, mu)

            # Radius enclosing q% of samples
            r = torch.quantile(d, self.q)

            self.mu[k] = mu
            self.r[k] = r

    def score(self, z):
        """Return the smallest ``distance - radius`` of ``z`` over all known classes.

        A positive value means ``z`` lies outside every class's radius.

        Args:
            z: 1-D embedding tensor; it is moved to the CPU before comparison.

        Raises:
            ValueError: If no class has been fitted yet.
        """
        if not self.mu:
            raise ValueError("HypersphereNovelty.score called before any class was fitted")

        z = z.cpu()

        # Novelty score = how far sample is from nearest known class
        return min((1 - torch.dot(z, self.mu[k])) - self.r[k] for k in self.mu)


# ============================================================
# TRAINING FUNCTIONS
# ============================================================

def train_supervised(model, loader, device, epochs=15):
    """Train every parameter of ``model`` with cross-entropy (base task, Task 0).

    Args:
        model: Module whose forward pass returns ``(logits, embedding)``.
        loader: Iterable of ``(inputs, one_hot_targets)`` batches.
        device: Device the batches are moved to.
        epochs: Number of full passes over ``loader``.

    The model is put into train mode and is left in train mode on return.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()

    epoch = 0
    while epoch < epochs:
        for inputs, one_hot_targets in loader:
            inputs = inputs.to(device)
            # The targets arrive one-hot encoded; cross_entropy needs class ids
            class_ids = one_hot_targets.argmax(1).to(device)

            logits = model(inputs)[0]
            batch_loss = F.cross_entropy(logits, class_ids)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        epoch += 1


def finetune(model, memory, Z_new, new_label, device, epochs=3):
    """Fit ``model.classifier`` on stored embeddings plus a new class's embeddings.

    Only ``model.classifier`` is evaluated here, on precomputed and detached
    embeddings, so the loss does not depend on any other part of the model.

    Args:
        model: Module exposing a ``classifier`` head.
        memory: Mapping ``{class id -> list of embedding tensors}`` for the
            classes already known.
        Z_new: Embeddings of the class being added, one per row.
        new_label: Class id assigned to every row of ``Z_new``.
        device: Device the training tensors are moved to.
        epochs: Number of full-batch gradient steps.

    The model is put into train mode and is left in train mode on return.
    """
    # Replay set: every stored embedding with its class id ...
    features = [z for stored in memory.values() for z in stored]
    targets = [cls for cls, stored in memory.items() for _ in stored]

    # ... followed by the embeddings of the newly promoted class
    features += [z.cpu() for z in Z_new]
    targets += [new_label for _ in Z_new]

    features = torch.stack(features).detach().to(device)
    targets = torch.tensor(targets).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()

    for _ in range(epochs):
        batch_loss = F.cross_entropy(model.classifier(features), targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


# ============================================================
# MAIN OPEN-WORLD LEARNING LOOP
# ============================================================

# Use the GPU when PyTorch can see one, otherwise run on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Registry names of all tasks, in the order the main loop visits them
TASKS = [f"splitcifar10_{i}" for i in range(NUM_TASKS)]

# Two-output head to start with: the base task (Task 0) contains classes 0 and 1
model = CNN(num_classes=2).to(device)
# Keeps at most 20 embeddings per class
memory = MemoryBuffer(max_per_class=20)
# Class radius = 0.95-quantile of the stored embeddings' distance to the centroid
detector = HypersphereNovelty(q=0.95)

novelty_buffer = []   # Stores unresolved novel samples as (embedding, label) pairs
known_classes = 2     # Starts with classes {0,1}; also the id given to the next promoted class

# Tuned thresholds
COH_THR = 0.35       # promotion requires cluster cohesion score Scoh <= COH_THR
SEP_THR = 0.1        # promotion requires cluster separation sep >= SEP_THR
PURITY_THR = 0.35    # clusters with label purity below this are blocked
NOVELTY_THR = 0.3    # a sample is a novelty candidate when its detector score exceeds this

P = 4                        # Stage II (clustering) runs only on tasks with t % P == 0
MAX_NOVELTY_BUFFER = 300     # the novelty buffer keeps only its most recent entries up to this size
TOPK_NOVELTY = 120           # per task, only the highest-scoring candidates up to this count are buffered


# Ground-truth CIFAR-10 labels that already own a classifier output.
# Internal class ids are handed out in promotion order (2, 3, ...), so they do
# not coincide with CIFAR-10 ids once a class has been promoted. "Is this
# cluster an already-known class?" therefore has to be answered with this set;
# comparing a ground-truth label against `known_classes` is only correct
# before the first promotion.
known_semantic_labels = set(range(known_classes))

# HDBSCAN settings; the cluster-size minimum also guards against clustering a
# buffer that is too small to contain a cluster at all.
HDBSCAN_MIN_CLUSTER_SIZE = 5
HDBSCAN_MIN_SAMPLES = 3

for t, task in enumerate(TASKS):

    print(f"\n================ TASK {t} ================")

    _, dataset_fn = DATASET_REGISTRY[task]
    train_ds, _, _, _, _, _, _, _ = dataset_fn("./data")
    loader = DataLoader(train_ds, batch_size=64, shuffle=False)

    # --------------------------------------------------------
    # TASK 0: Base supervised training
    # --------------------------------------------------------
    if t == 0:
        train_supervised(model, loader, device)

        # train_supervised leaves the model in train mode. Switch to eval
        # before extracting the embeddings that seed the memory, so they are
        # computed the same way as the embeddings scored in later tasks
        # (which all run after model.eval()).
        model.eval()

        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.argmax(1)

                _, z = model(x)
                z = z.cpu()

                for cls in torch.unique(y):
                    # A boolean mask always yields a 2-D (k, dim) slice, even
                    # when the class occurs only once in the batch.
                    memory.add_batch(z[y == cls], cls.item())

        detector.update(memory.get())
        print("✅ [INFO] Task0 training complete")
        continue

    # --------------------------------------------------------
    # STAGE I: Novelty routing
    # --------------------------------------------------------
    model.eval()
    novelty_candidates = []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.argmax(1)

            _, z = model(x)
            # One device-to-host transfer per batch instead of one per sample
            z = z.cpu()

            for i in range(len(z)):
                s = float(detector.score(z[i]))

                if s > NOVELTY_THR:
                    # clone() so a kept candidate does not pin its whole batch in memory
                    novelty_candidates.append((s, z[i].clone(), y[i].item()))

    # Keep only the TOPK_NOVELTY most novel samples of this task
    novelty_candidates.sort(key=lambda x: x[0], reverse=True)
    novelty_candidates = novelty_candidates[:TOPK_NOVELTY]

    novelty_buffer.extend([(z, y) for _, z, y in novelty_candidates])

    # Bounded buffer: drop the oldest entries first
    if len(novelty_buffer) > MAX_NOVELTY_BUFFER:
        novelty_buffer = novelty_buffer[-MAX_NOVELTY_BUFFER:]

    print(f"[DEBUG] Task {t} novelty buffer size = {len(novelty_buffer)}")

    # --------------------------------------------------------
    # STAGE II: Clustering (every P tasks)
    # --------------------------------------------------------
    if (t % P) != 0:
        print("[DEBUG] Skipping Stage II")
        continue

    if len(novelty_buffer) < HDBSCAN_MIN_CLUSTER_SIZE:
        # Nothing to stack / too few points to form even one cluster
        print("[DEBUG] Novelty buffer too small to cluster; skipping Stage II")
        continue

    print("[DEBUG] Entering Stage II clustering")
    print("Novelty buffer label distribution:", Counter([y for _, y in novelty_buffer]))

    Z = np.stack([z.numpy() for z, _ in novelty_buffer])

    labels = hdbscan.HDBSCAN(
        metric='euclidean',
        min_cluster_size=HDBSCAN_MIN_CLUSTER_SIZE,
        min_samples=HDBSCAN_MIN_SAMPLES
    ).fit_predict(Z)

    print(f"[DEBUG] Clusters found: {set(labels)}")

    # Samples that are not promoted in this round are carried over
    new_buffer = []

    for cid in set(labels):
        idxs = np.where(labels == cid)[0]

        # ------------------------------
        # Noise cluster → keep samples
        # ------------------------------
        if cid == -1:
            print("Noise cluster kept")
            for i in idxs:
                new_buffer.append(novelty_buffer[i])
            continue

        cluster = [novelty_buffer[i] for i in idxs]
        Zc = torch.stack([z for z, _ in cluster])
        true_labels = [y for _, y in cluster]

        # Cohesion: 0.9-quantile of the members' distance to the cluster centroid
        mu_c = F.normalize(Zc.mean(0), dim=0)
        d = 1 - torch.matmul(Zc, mu_c)
        Scoh = torch.quantile(d, 0.9)

        # Separation: margin between the cluster centroid and the nearest known
        # class sphere. The detector is refreshed only after this loop, so
        # classes promoted earlier in the same round are not considered here.
        sep = min([(1 - torch.dot(mu_c, detector.mu[k])) - detector.r[k] for k in detector.mu])

        # Majority ground-truth label and its share (uses the true labels as an oracle)
        counter = Counter(true_labels)
        semantic_label, count = counter.most_common(1)[0]
        purity = count / len(true_labels)

        print(f"📊 [CLUSTER {cid}] size={len(cluster)} | Scoh={Scoh:.3f} | Sep={sep:.3f} | semantic_label={semantic_label} | purity={purity:.2f}")

        # ------------------------------
        # Reject but KEEP samples
        # ------------------------------
        if semantic_label in known_semantic_labels or purity < PURITY_THR:
            print("Blocked (kept)")
            for i in idxs:
                new_buffer.append(novelty_buffer[i])
            continue

        # ------------------------------
        # Promote new class
        # ------------------------------
        if Scoh <= COH_THR and sep >= SEP_THR:
            new_label = known_classes
            known_classes += 1
            known_semantic_labels.add(semantic_label)

            print(f"PROMOTED NEW CLASS {new_label}")

            model.expand_head(known_classes)
            model.to(device)

            # Fit the widened head before the new embeddings enter the memory,
            # so finetune sees them exactly once (via Zc).
            finetune(model, memory.get(), Zc, new_label, device)
            memory.add_batch(Zc, new_label)
        else:
            print("Blocked: cohesion/separation failed (kept)")
            for i in idxs:
                new_buffer.append(novelty_buffer[i])

    novelty_buffer = new_buffer
    detector.update(memory.get())
    torch.cuda.empty_cache()

  $\max \{ \mathrm{core}_k(a),\ \mathrm{core}_k(b),\ \tfrac{1}{\alpha}\, d(a,b) \}$.


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 191MB/s]



✅ [INFO] Task0 training complete

[DEBUG] Task 1 novelty buffer size = 120
[DEBUG] Skipping Stage II

[DEBUG] Task 2 novelty buffer size = 240
[DEBUG] Skipping Stage II

[DEBUG] Task 3 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 4 novelty buffer size = 300
[DEBUG] Entering Stage II clustering
Novelty buffer label distribution: Counter({3: 50, 5: 49, 7: 43, 4: 36, 8: 34, 6: 33, 9: 27, 2: 26, 1: 2})
[DEBUG] Clusters found: {np.int64(0), np.int64(1), np.int64(-1)}
📊 [CLUSTER 0] size=53 | Scoh=0.013 | Sep=0.736 | semantic_label=3 | purity=0.21
Blocked (kept)
📊 [CLUSTER 1] size=48 | Scoh=0.026 | Sep=0.831 | semantic_label=3 | purity=0.21
Blocked (kept)
Noise cluster kept





[DEBUG] Task 5 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 6 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 7 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 8 novelty buffer size = 300
[DEBUG] Entering Stage II clustering
Novelty buffer label distribution: Counter({5: 44, 3: 40, 6: 40, 7: 37, 8: 37, 4: 36, 9: 36, 2: 27, 0: 2, 1: 1})
[DEBUG] Clusters found: {np.int64(0), np.int64(1), np.int64(-1)}
📊 [CLUSTER 0] size=48 | Scoh=0.024 | Sep=0.834 | semantic_label=3 | purity=0.19
Blocked (kept)
📊 [CLUSTER 1] size=24 | Scoh=0.011 | Sep=0.759 | semantic_label=3 | purity=0.21
Blocked (kept)
Noise cluster kept





[DEBUG] Task 9 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 10 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 11 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 12 novelty buffer size = 300
[DEBUG] Entering Stage II clustering
Novelty buffer label distribution: Counter({5: 54, 6: 46, 7: 45, 3: 42, 4: 31, 9: 30, 8: 29, 2: 17, 1: 5, 0: 1})
[DEBUG] Clusters found: {np.int64(0), np.int64(1), np.int64(-1)}
📊 [CLUSTER 0] size=67 | Scoh=0.020 | Sep=0.813 | semantic_label=5 | purity=0.22
Blocked (kept)
📊 [CLUSTER 1] size=16 | Scoh=0.009 | Sep=0.751 | semantic_label=3 | purity=0.31
Blocked (kept)
Noise cluster kept





[DEBUG] Task 13 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 14 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 15 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 16 novelty buffer size = 300
[DEBUG] Entering Stage II clustering
Novelty buffer label distribution: Counter({6: 54, 5: 47, 3: 45, 7: 42, 9: 30, 4: 27, 8: 27, 2: 26, 1: 1, 0: 1})
[DEBUG] Clusters found: {np.int64(0), np.int64(1), np.int64(-1)}
📊 [CLUSTER 0] size=62 | Scoh=0.029 | Sep=0.845 | semantic_label=2 | purity=0.18
Blocked (kept)
📊 [CLUSTER 1] size=8 | Scoh=0.008 | Sep=0.734 | semantic_label=3 | purity=0.25
Blocked (kept)
Noise cluster kept





[DEBUG] Task 17 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 18 novelty buffer size = 300
[DEBUG] Skipping Stage II

[DEBUG] Task 19 novelty buffer size = 300
[DEBUG] Skipping Stage II
