# Label-free CBM — Evaluation (Adapted, PVC Defaults)

PVC-aware defaults for Nautilus:
- `DATA_ROOT=/kayla/dataset`
- `CKPT_PATH=/kayla/saved_models/cifar10_label_free/best.pt`
- `CONCEPTS_TXT=/kayla/cbm_library/concepts/main/outputs/label_free/cifar10_filtered.txt`

Each default can be overridden by setting the environment variable of the same name.
**Created:** 2025-08-15T22:26:27.448140Z

In [None]:
# --- Setup & imports ---
import os, sys, json, math, random, time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

import matplotlib.pyplot as plt
from typing import List, Tuple
from collections import defaultdict, Counter

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Repro
# Seed the Python, NumPy and PyTorch RNGs (and every CUDA device when available).
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# === PVC defaults (override with env if you like) ===
# Each path is read from the environment variable of the same name, with a PVC default.
DATA_ROOT = os.environ.get("DATA_ROOT", "/kayla/dataset")
CKPT_PATH = os.environ.get("CKPT_PATH", "/kayla/saved_models/cifar10_label_free/best.pt")
CONCEPTS_TXT = os.environ.get("CONCEPTS_TXT", "/kayla/cbm_library/concepts/main/outputs/label_free/cifar10_filtered.txt")

# DataLoader settings, overridable via EVAL_BATCH / EVAL_WORKERS.
BATCH_SIZE = int(os.environ.get("EVAL_BATCH", 64))
NUM_WORKERS = int(os.environ.get("EVAL_WORKERS", 2))
# NOTE(review): TOPK is printed below, but the metrics cell visible in this file passes a
# literal (1, 5) to topk_accuracy instead of this constant — confirm which is intended.
TOPK = (1, 5)
# Default number of concepts shown per explanation (default `top_n` of explain_sample).
TOP_CONTRIB = int(os.environ.get("TOP_CONTRIB", 10))

# Values starting with "train" (case-insensitive) select the training split; anything else
# selects train=False when the dataset is built.
EVAL_SPLIT = os.environ.get("EVAL_SPLIT", "test")  # 'test' or 'train'

# Echo the effective configuration so it is captured in the notebook output.
print("Paths:")
print("  DATA_ROOT  =", DATA_ROOT)
print("  CKPT_PATH  =", CKPT_PATH)
print("  CONCEPTS   =", CONCEPTS_TXT)
print("Eval opts: BATCH_SIZE =", BATCH_SIZE, " WORKERS =", NUM_WORKERS, " TOPK =", TOPK)

In [None]:
# --- Data loading (CIFAR-10 by default) ---
def build_cifar10(root, train=False):
    """Build the torchvision CIFAR-10 dataset with 224x224 normalised inputs.

    Args:
        root: Directory the dataset is stored in (``download=True`` is passed,
            so torchvision fetches the data when it is missing).
        train: ``True`` for the training split, ``False`` for the test split.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance whose samples are resized,
        converted to tensors and normalised with the constants below.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    steps = [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
    ]
    preprocessing = T.Compose(steps)
    return torchvision.datasets.CIFAR10(
        root=root,
        train=train,
        download=True,
        transform=preprocessing,
    )

# Any EVAL_SPLIT value beginning with "train" (case-insensitive) selects the training
# split; every other value results in train=False.
eval_is_train = EVAL_SPLIT.lower().startswith("train")
ds = build_cifar10(DATA_ROOT, train=eval_is_train)
# shuffle=False keeps batches in dataset index order; memory is pinned only when CUDA is available.
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())
# Class names come from the dataset object; their count defines NUM_CLASSES.
CLASS_NAMES = ds.classes
NUM_CLASSES = len(CLASS_NAMES)
print(f"Dataset: CIFAR-10 | split={EVAL_SPLIT} | size={len(ds)} | classes={NUM_CLASSES}")

In [None]:
# --- Utilities: accuracy & plotting ---
def topk_accuracy(logits: torch.Tensor, targets: torch.Tensor, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in *topk*.

    Args:
        logits: Score matrix of shape ``[N, num_classes]``.
        targets: Integer class labels of shape ``[N]``.
        topk: Iterable of k values to report.

    Returns:
        A list of Python floats, one percentage per requested k, in the same
        order as *topk*. A k larger than the number of classes is evaluated
        over all classes (i.e. 100% whenever every target is a valid class).
        An empty batch yields ``0.0`` for every k instead of dividing by zero.
    """
    batch_size = targets.size(0)
    if batch_size == 0:
        return [0.0 for _ in topk]

    # torch.topk raises when k exceeds the class dimension, so clamp it.
    num_classes = logits.size(1)
    maxk = min(max(topk), num_classes)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)

    # hits[n, j] is True when the j-th ranked prediction of sample n is the target.
    hits = pred.eq(targets.to(pred.device).view(-1, 1))

    res = []
    for k in topk:
        correct_k = hits[:, :min(k, maxk)].sum().item()
        res.append(correct_k * 100.0 / batch_size)
    return res

def barplot_contributions(concept_names: List[str], contrib: np.ndarray, title="Top concept contributions", top_n=10):
    """Draw a horizontal bar chart of the largest-magnitude concept contributions.

    Args:
        concept_names: Display names indexed by concept id; ids beyond the end
            of the list are labelled ``c<id>``.
        contrib: 1-D array of signed contributions, one per concept.
        title: Figure title.
        top_n: Number of concepts to show, ranked by absolute contribution.
    """
    order = np.argsort(-np.abs(contrib))[:top_n]

    n_named = len(concept_names)
    labels = []
    for concept_idx in order:
        if concept_idx < n_named:
            labels.append(concept_names[concept_idx])
        else:
            labels.append(f"c{concept_idx}")

    # Figure height grows with the requested number of bars, with a floor of 3.
    fig_height = max(3, top_n * 0.35)
    plt.figure(figsize=(8, fig_height))

    positions = np.arange(len(order))
    plt.barh(positions, contrib[order])
    plt.yticks(positions, labels)
    plt.title(title)
    plt.xlabel("signed contribution (w_i * c_i)")
    # Put the strongest contribution at the top of the chart.
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.show()

In [None]:
# --- Model loading for both original LF-CBM & cbm_library ---
def load_checkpoint_any(path: str):
    """Load concept-layer and final-layer weights from a checkpoint file.

    Lookup order: top-level ``W_c`` / ``W_y`` keys first; for whichever is
    still missing, ``concept_layer.weight`` / ``final_layer.weight`` inside the
    first non-empty nested dict stored under ``state_dict`` or ``model``.

    Args:
        path: Checkpoint file readable by ``torch.load``.

    Returns:
        ``(W_c, W_y, concept_names, ckpt)`` where both weights are contiguous
        float32 tensors, ``concept_names`` is the checkpoint's
        ``concept_names`` entry (``[]`` when absent or ``None``) and ``ckpt``
        is the raw loaded object.

    Raises:
        RuntimeError: If either weight matrix cannot be found.
    """
    # NOTE(security): torch.load unpickles the file; only open checkpoints
    # that come from a trusted source.
    ckpt = torch.load(path, map_location="cpu")
    W_c = None
    W_y = None
    concept_names = []

    if isinstance(ckpt, dict):
        # Direct keys
        W_c = ckpt.get("W_c")
        W_y = ckpt.get("W_y")
        concept_names = ckpt.get("concept_names", concept_names)

        # Nested common patterns
        if W_c is None or W_y is None:
            for key in ("state_dict", "model"):
                nested = ckpt.get(key)
                if isinstance(nested, dict) and nested:
                    W_c = nested.get("concept_layer.weight", W_c)
                    W_y = nested.get("final_layer.weight", W_y)
                    break

    if W_c is None or W_y is None:
        raise RuntimeError("Checkpoint missing W_c and/or W_y. Expected original LF-CBM or compatible cbm_library save.")

    # A stored None would break later len(concept_names) calls.
    if concept_names is None:
        concept_names = []

    def _as_float_matrix(weights):
        # Accept tensors, NumPy arrays and nested lists alike.
        if not isinstance(weights, torch.Tensor):
            weights = torch.tensor(weights, dtype=torch.float32)
        return weights.float().contiguous()

    W_c = _as_float_matrix(W_c)
    W_y = _as_float_matrix(W_y)
    print("Loaded shapes: W_c =", tuple(W_c.shape), " W_y =", tuple(W_y.shape))
    return W_c, W_y, concept_names, ckpt

# Load the real weights; on any failure fall back to small random stand-ins so the
# remaining cells still execute. Metrics computed with the fallback are meaningless.
# NOTE(review): the visible code never reads the file at CONCEPTS_TXT, so CONCEPT_NAMES
# is empty whenever the checkpoint has no "concept_names" entry — confirm whether the
# names file should be loaded here.
try:
    W_c, W_y, CONCEPT_NAMES, CKPT_META = load_checkpoint_any(CKPT_PATH)
except Exception as e:
    print("⚠️ Failed to load checkpoint:", e)
    print("Using small random demo weights so you can run the notebook structure (not for real eval).")
    # NOTE(review): 100 concepts x 512 features; presumably 512 matches the backbone's
    # feature width — confirm against the feature extractor.
    W_c = torch.randn(100, 512) * 0.01
    W_y = torch.randn(NUM_CLASSES, 100) * 0.01
    CONCEPT_NAMES = [f"concept_{i}" for i in range(W_c.size(0))]
    CKPT_META = {}

In [None]:
# --- Feature extractor (ResNet-18 default) ---
from torchvision import models
def build_backbone():
    """Return a ResNet-18 feature extractor on DEVICE in eval mode.

    The network is created with torchvision's default ResNet-18 weights and
    its ``fc`` layer is replaced by ``nn.Identity`` so the forward pass
    returns the features that would have fed that layer.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    net.fc = nn.Identity()
    net = net.to(DEVICE)
    net.eval()
    return net

# Shared feature extractor used by the cells below.
backbone = build_backbone()

@torch.no_grad()
def extract_features(dataloader):
    """Run the global ``backbone`` over every batch and stack the outputs on CPU.

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` pairs; the labels
            are ignored.

    Returns:
        A single CPU tensor with the per-batch backbone outputs concatenated
        along dim 0 (outputs with more than two dims are flattened from dim 1).
    """
    def _embed(images):
        out = backbone(images.to(DEVICE, non_blocking=True))
        if out.ndim > 2:
            out = torch.flatten(out, 1)
        return out.detach().cpu()

    chunks = [_embed(images) for images, _ in dataloader]
    return torch.cat(chunks, dim=0)

print("Extracting features...")
X = extract_features(dl)  # [N, D]
# Read labels from the dataset's `targets` attribute when it exists (torchvision's CIFAR10
# provides one). Iterating `ds` would decode, resize and normalise every image a second
# time just to obtain its label; that path is kept only as a fallback.
ys = torch.tensor(ds.targets if hasattr(ds, "targets") else [y for _, y in ds], dtype=torch.long)
print("Feature matrix:", tuple(X.shape))

In [None]:
# --- Forward & metrics ---
@torch.no_grad()
def forward_logits(X_cpu: torch.Tensor, W_c: torch.Tensor, W_y: torch.Tensor):
    """Apply the concept projection followed by the class projection.

    Args:
        X_cpu: Feature matrix.
        W_c: Concept-layer weights; features are multiplied by its transpose.
        W_y: Final-layer weights; concept activations are multiplied by its
            transpose.

    Returns:
        ``(concept_activations, class_logits)``, both computed without
        gradient tracking.
    """
    concept_acts = torch.matmul(X_cpu, W_c.t())          # [N, C]
    class_logits = torch.matmul(concept_acts, W_y.t())   # [N, K]
    return concept_acts, class_logits

# Apply the concept layer and the final layer to the extracted features.
C_acts, logits = forward_logits(X, W_c, W_y)

# NOTE(review): the k values are hard-coded here rather than taken from the TOPK constant.
top1, top5 = topk_accuracy(logits, ys, topk=(1,5))
print(f"Top-1: {top1:.2f}%   Top-5: {top5:.2f}%")

# The confusion matrix is best-effort: any exception raised here (including a missing
# sklearn install) is printed and the cell carries on.
try:
    from sklearn.metrics import confusion_matrix
    cm = confusion_matrix(ys.numpy(), logits.argmax(dim=1).numpy(), labels=list(range(NUM_CLASSES)))
    print("Confusion matrix shape:", cm.shape)
except Exception as e:
    print("sklearn not available for confusion matrix:", e)

In [None]:
# --- Sparsity & concept usage ---
import numpy as np
# NumPy array of the final-layer weights and their magnitudes.
Wy = W_y.cpu().numpy()
abs_Wy = np.abs(Wy)

# Weights with magnitude <= 1e-12 are treated as zero.
# per_class_nnz: number of non-zero weights in each row of Wy.
# global_nnz: number of columns with a non-zero weight in at least one row.
per_class_nnz = (abs_Wy > 1e-12).sum(axis=1)
global_nnz = (abs_Wy > 1e-12).any(axis=0).sum()

print("Per-class nonzero concepts:", per_class_nnz.tolist())
print("Mean nonzeros per class:", float(np.mean(per_class_nnz)))
print("Global #effective concepts:", int(global_nnz), "/", Wy.shape[1])

def top_concepts_for_class(k=10, class_idx=0):
    """Return the *k* largest-magnitude final-layer weights of one class.

    Args:
        k: Number of concepts to return.
        class_idx: Row of the global ``Wy`` matrix to inspect.

    Returns:
        A list of ``(concept_name, weight)`` pairs ordered by decreasing
        absolute weight. Concepts without an entry in ``CONCEPT_NAMES`` are
        named ``c<index>``.
    """
    row = Wy[class_idx]
    ranked = np.argsort(-np.abs(row))[:k]

    labels = []
    for concept_idx in ranked:
        if concept_idx < len(CONCEPT_NAMES):
            labels.append(CONCEPT_NAMES[concept_idx])
        else:
            labels.append(f'c{concept_idx}')

    return list(zip(labels, row[ranked]))

print("Top concepts for class 0:", top_concepts_for_class(10, 0))

In [None]:
# --- Per-sample explanation ---
def show_image(timg):
    """Undo the input normalisation of a CHW image tensor and display it.

    The inverse transform uses the same mean/std constants as the dataset
    preprocessing; values are clamped to [0, 1] before plotting.
    """
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)
    # Normalize(-mean/std, 1/std) inverts Normalize(mean, std).
    undo_norm = T.Normalize(
        mean=[-m / s for m, s in zip(norm_mean, norm_std)],
        std=[1 / s for s in norm_std],
    )
    restored = undo_norm(timg).clamp(0, 1)
    pixels = restored.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
    plt.figure()
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()

def explain_sample(idx: int, top_n: int = TOP_CONTRIB):
    """Print, show and plot the explanation for one dataset sample.

    The sample is passed through the backbone, the concept layer ``W_c`` and
    the final layer ``W_y``. For the predicted class, each concept's signed
    contribution is its activation times the corresponding ``W_y`` weight.

    Args:
        idx: Index into the global dataset ``ds``.
        top_n: Number of concepts to display in the bar chart.
    """
    x, y = ds[idx]
    with torch.no_grad():
        f = backbone(x.unsqueeze(0).to(DEVICE))
        if f.ndim > 2:
            f = torch.flatten(f, 1)
        c = (f.cpu() @ W_c.t()).squeeze(0)           # [C]
        ylog = (c @ W_y.t()).cpu().numpy()           # [K]
    pred = int(np.argmax(ylog))
    contrib = (c.numpy() * W_y[pred].cpu().numpy())  # signed contribution per concept

    print(f"idx={idx} | GT={CLASS_NAMES[int(y)]} ({int(y)}) | Pred={CLASS_NAMES[pred]} ({pred})")
    show_image(x)
    # Label bars with the concept names, not with CONCEPTS_TXT: that is a file path
    # string, and indexing it would label each bar with one character of the path.
    barplot_contributions(CONCEPT_NAMES, contrib,
                          title=f"Top concept contributions → {CLASS_NAMES[pred]}", top_n=top_n)

# Quick demo
# Explain the samples at dataset indices 0, 1 and 2 with the default number of concepts.
for i in [0, 1, 2]:
    explain_sample(i, top_n=TOP_CONTRIB)

In [None]:
# --- Optional: export explanations ---
# Destination JSONL file and number of concepts kept per record; both overridable via env.
EXPORT_PATH = os.environ.get("EXPORT_JSONL", "/kayla/logs/eval_explanations.jsonl")
K = int(os.environ.get("EXPORT_TOPK", 10))
# Create the directory that will actually receive the file, so an EXPORT_JSONL override
# pointing outside /kayla/logs works and no unrelated directory is created.
Path(EXPORT_PATH).parent.mkdir(parents=True, exist_ok=True)

def topk_concepts_for_pred(contrib: np.ndarray, k: int) -> list:
    """Describe the *k* concepts with the largest absolute contribution.

    Args:
        contrib: 1-D array of signed per-concept contributions.
        k: Number of concepts to keep.

    Returns:
        A list of dicts with keys ``concept``, ``contribution`` and ``index``,
        ordered by decreasing absolute contribution. Concepts without an entry
        in ``CONCEPT_NAMES`` are named ``c<index>``.
    """
    ranked = np.argsort(-np.abs(contrib))[:k]
    n_named = len(CONCEPT_NAMES)

    entries = []
    for concept_idx, value in zip(ranked, contrib[ranked].tolist()):
        label = CONCEPT_NAMES[concept_idx] if concept_idx < n_named else f'c{concept_idx}'
        entries.append({"concept": label, "contribution": value, "index": int(concept_idx)})
    return entries

@torch.no_grad()
def export_explanations(jsonl_path: str, num_samples: int = 50):
    """Write one JSON explanation record per sample to *jsonl_path*.

    For each of the first ``min(num_samples, len(ds))`` samples, the record
    holds the sample index, ground-truth and predicted class (id and name),
    and the top ``K`` concept contributions for the predicted class.

    Args:
        jsonl_path: Output file; it is opened in write mode, replacing any
            existing content.
        num_samples: Maximum number of samples to export.
    """
    with open(jsonl_path, "w") as sink:
        n_records = min(num_samples, len(ds))
        for sample_idx in range(n_records):
            image, label = ds[sample_idx]

            feats = backbone(image.unsqueeze(0).to(DEVICE))
            if feats.ndim > 2:
                feats = torch.flatten(feats, 1)

            concept_acts = (feats.cpu() @ W_c.t()).squeeze(0).numpy()
            class_scores = concept_acts @ W_y.t().cpu().numpy()
            winner = int(np.argmax(class_scores))
            # Signed contribution of each concept to the predicted class.
            contributions = concept_acts * W_y[winner].cpu().numpy()

            label_id = int(label)
            record = {
                "index": sample_idx,
                "gt": label_id,
                "gt_name": CLASS_NAMES[label_id],
                "pred": winner,
                "pred_name": CLASS_NAMES[winner],
                "top_concepts": topk_concepts_for_pred(contributions, K),
            }
            sink.write(json.dumps(record) + "\n")
    print(f"Wrote {jsonl_path}")

# To run:
# export_explanations(EXPORT_PATH, num_samples=100)