# Label-free CBM — Minimal Evaluation (Simplified)
**Auto-generated:** 2025-08-19T23:50:51.252330Z

This notebook is simplified to a *single* flow:
1) Load bundle
2) Force RN50 (1024-d) encoder
3) Define `cbm_infer`
4) Run inference

> Keep using RN50 so the encoder output matches `W_c`, whose shape is `(281, 1024)`. A 512-d encoder (e.g., ViT-B/16) will make `cbm_infer` raise a dimension-mismatch `RuntimeError`.


In [None]:
import torch, open_clip

# --- Paths ---
BUNDLE = "/kayla/saved_models/lf_cbm_cifar10/lf_cbm_minimal_bundle.pt"  # edit if needed

# --- Device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# --- Load bundle ---
# NOTE(review): torch.load deserializes via pickle (how strictly depends on the
# installed PyTorch version's `weights_only` default), so only point BUNDLE at a
# checkpoint you trust. Tensors are materialized on CPU first, then moved.
d = torch.load(BUNDLE, map_location="cpu")
sd = d["state_dict"]
W_c = sd["concept_layer.weight"].to(device)   # [K, 1024]
W_g = sd["final_layer.weight"].to(device)     # [C, K]
b_g = sd["final_layer.bias"].to(device)       # [C]

# Optional z-score statistics: stay None when the bundle does not provide them.
# torch.as_tensor accepts both plain sequences and tensors, so a bundle that
# stored these as tensors no longer triggers the copy-construct warning that
# torch.tensor(<Tensor>) emits. Each key is looked up only once.
_raw_concept_mean = d.get("concept_mean")
_raw_concept_std = d.get("concept_std")
c_mean = (
    None
    if _raw_concept_mean is None
    else torch.as_tensor(_raw_concept_mean, dtype=torch.float32, device=device)
)
c_std = (
    None
    if _raw_concept_std is None
    else torch.as_tensor(_raw_concept_std, dtype=torch.float32, device=device)
)

print("W_c:", tuple(W_c.shape), "W_g:", tuple(W_g.shape), "b_g:", tuple(b_g.shape))

# --- Force RN50 encoder (1024-d) ---
# create_model_and_transforms returns a 3-tuple; only the model and the last
# transform are kept under meaningful names (the middle item is bound to `_`).
_clip_bundle = open_clip.create_model_and_transforms("RN50", pretrained="openai")
_clip_model, _, preprocess = _clip_bundle
# Move to the selected device and switch to inference mode.
_clip_model = _clip_model.to(device)
_clip_model = _clip_model.eval()

@torch.no_grad()
def ENCODER(imgs_224: torch.Tensor) -> torch.Tensor:
    """Encode an image batch with the CLIP model and L2-normalize the result.

    The batch is moved to ``device``, passed through
    ``_clip_model.encode_image`` with gradients disabled, and each feature
    vector is divided by its norm along the last dimension (plus 1e-12 to
    avoid division by zero).

    Args:
        imgs_224: Image batch; the callers in this notebook pass
            [N, 3, 224, 224] tensors.

    Returns:
        The normalized image features.
    """
    embeddings = _clip_model.encode_image(imgs_224.to(device))
    lengths = embeddings.norm(dim=-1, keepdim=True)
    return embeddings / (lengths + 1e-12)  # [N, 1024]

# Smoke test: should be 1024
# Push a single all-zero image through the encoder and report the feature width.
_probe_features = ENCODER(torch.zeros(1, 3, 224, 224))
print("Encoder dim:", _probe_features.shape[-1])

# --- Minimal cbm_infer ---
def cbm_infer(x_images_224: torch.Tensor, *, apply_zscore: bool = True, return_probs: bool = True):
    """Run the concept-bottleneck pipeline on a batch of images.

    Steps: ``ENCODER`` features -> concept activations (``W_c``) -> optional
    z-scoring with ``c_mean`` / ``c_std`` -> class logits (``W_g``, ``b_g``).
    Everything runs with gradients disabled.

    Args:
        x_images_224: Image batch handed straight to ``ENCODER``.
        apply_zscore: Standardize the concept activations. This is silently
            skipped unless both ``c_mean`` and ``c_std`` were loaded and
            ``c_mean`` has exactly one entry per concept.
        return_probs: Also return the softmax of the logits.

    Returns:
        ``(logits, concepts, probs)`` when ``return_probs`` is true,
        otherwise ``(logits, concepts)``.

    Raises:
        RuntimeError: If the encoder feature width differs from ``W_c.shape[1]``.
    """
    with torch.no_grad():
        feats = ENCODER(x_images_224)  # [N, 1024]
        if feats.shape[-1] != W_c.shape[1]:
            raise RuntimeError(f"Encoder D={feats.shape[-1]} != W_c expects D={W_c.shape[1]}. Use RN50.")

        concepts = torch.matmul(feats, W_c.T)  # [N, K]

        # Short-circuit order matters: numel() is only touched once both
        # statistics are known to exist.
        can_standardize = (
            apply_zscore
            and c_mean is not None
            and c_std is not None
            and c_mean.numel() == concepts.shape[1]
        )
        if can_standardize:
            concepts = (concepts - c_mean) / (c_std + 1e-6)

        logits = torch.matmul(concepts, W_g.T) + b_g  # [N, C]

        if not return_probs:
            return logits, concepts
        probs = torch.softmax(logits, dim=-1)
        return logits, concepts, probs

# Signal that this setup cell (bundle, encoder, cbm_infer) ran to completion.
print("Ready.")


In [None]:
# --- Demo: run on a zero batch (just to prove shapes flow) ---
# Two all-zero images created directly on the selected device.
_demo_shape = (2, 3, 224, 224)
x = torch.zeros(*_demo_shape, device=device)
_demo_outputs = cbm_infer(x, apply_zscore=True, return_probs=True)
logits, concepts, probs = _demo_outputs
print("logits:", logits.shape, "concepts:", concepts.shape, "probs:", probs.shape)


In [None]:
import torch
import os
import random
import utils
import data_utils
import json

import cbm
import plots

In [None]:
# change this to the correct model dir, everything else should be taken care of
load_dir = "saved_models/cifar10_cbm_2025_08_19_21_43"
# Pick the GPU when one is available and fall back to CPU otherwise, instead of
# hard-coding "cuda" (which fails on machines without CUDA).
device = "cuda" if torch.cuda.is_available() else "cpu"

# args.txt holds the training arguments as JSON; only "dataset" and "backbone"
# are read here.
with open(os.path.join(load_dir, "args.txt"), "r") as f:
    args = json.load(f)
dataset = args["dataset"]
# Only the preprocessing transform of the target model is kept; the first
# returned item is discarded.
_, target_preprocess = data_utils.get_target_model(args["backbone"], device)
model = cbm.load_cbm(load_dir, device)

In [None]:
# Identifier of the validation split for the dataset named in args.txt.
val_d_probe = "".join((dataset, "_val"))
# File listing the class names for this dataset.
cls_file = data_utils.LABEL_FILES[dataset]

# Same split loaded twice: once with the target model's preprocessing (fed to
# the model), once without it (used for displaying the images).
val_data_t = data_utils.get_data(val_d_probe, preprocess=target_preprocess)
val_pil_data = data_utils.get_data(val_d_probe)

In [None]:
# Read one name per line. splitlines() is used instead of split("\n") so that a
# file ending in a newline does not yield a spurious empty final entry; such an
# entry would inflate len(classes) / len(concepts), which later cells use to
# index the final-layer weights.
with open(cls_file, "r") as f:
    classes = f.read().splitlines()

with open(os.path.join(load_dir, "concepts.txt"), "r") as f:
    concepts = f.read().splitlines()

In [None]:
## Measure accuracy

In [None]:
# Evaluate the loaded CBM on the preprocessed validation split.
accuracy = utils.get_accuracy_cbm(model, val_data_t, device)
# Report as a percentage with two decimals.
_accuracy_pct = accuracy*100
print("Accuracy: {:.2f}%".format(_accuracy_pct))

In [None]:
## Show final layer weights for some classes

In [None]:
You can build a Sankey diagram of weights by copying the incoming weights printed below into https://sankeymatic.com/build/

In [None]:
# Draw one class index at random.
to_show = random.choices(range(len(classes)), k=1)

for i in to_show:
    print("Output class:{} - {}".format(i, classes[i]))
    print("Incoming weights:")
    # List every concept whose weight into this class exceeds 0.05 in magnitude,
    # one per line in "<concept> [<weight>] <class>" form.
    for j, concept_name in enumerate(concepts):
        w_ij = model.final.weight[i, j]
        if torch.abs(w_ij) > 0.05:
            print("{} [{:.4f}] {}".format(concept_name, w_ij, classes[i]))

In [None]:
# Draw two class indices at random (with replacement).
to_show = random.choices(range(len(classes)), k=2)

# Per class (row), the 5 largest and the 5 smallest final-layer weights.
top_weights, top_weight_ids = torch.topk(model.final.weight, k=5, dim=1)
bottom_weights, bottom_weight_ids = torch.topk(model.final.weight, k=5, dim=1, largest=False)

for i in to_show:
    print("Class {} - {}".format(i, classes[i]))

    out = "Highest weights: " + "".join(
        "{}:{:.3f}, ".format(concepts[int(top_weight_ids[i, col].cpu())], top_weights[i, col])
        for col in range(top_weights.shape[1])
    )
    print(out)

    out = "Lowest weights: " + "".join(
        "{}:{:.3f}, ".format(concepts[int(bottom_weight_ids[i, col].cpu())], bottom_weights[i, col])
        for col in range(bottom_weights.shape[1])
    )
    print(out + "\n")

In [None]:
# Some features may not have any non-zero outgoing weights,
# i.e. these are not used by the model and should be deleted for better performance
# Total absolute outgoing weight per concept (summed over the class dimension).
weight_contribs = model.final.weight.abs().sum(dim=0)
_num_used = torch.sum(weight_contribs > 1e-5)
print("Num concepts with outgoing weights:{}/{}".format(_num_used, len(weight_contribs)))

In [None]:
## Explain model reasoning for random inputs

In [None]:
# Pick four distinct validation indices at random.
to_display = random.sample(range(len(val_pil_data)), k=4)

with torch.no_grad():
    for i in to_display:
        # Un-preprocessed sample for display, preprocessed sample for the model.
        image, label = val_pil_data[i]
        x, _ = val_data_t[i]
        x = x.unsqueeze(0).to(device)  # add a batch dimension of 1
        display(image.resize([320,320]))

        outputs, concept_act = model(x)
        sample_logits = outputs[0]
        sample_concepts = concept_act[0]

        # Two highest logits and their class indices, plus softmax confidences.
        top_logit_vals, top_classes = torch.topk(sample_logits, dim=0, k=2)
        conf = torch.nn.functional.softmax(sample_logits, dim=0)
        summary = "Image:{} Gt:{}, 1st Pred:{}, {:.3f}, 2nd Pred:{}, {:.3f}".format(
            i, classes[int(label)],
            classes[top_classes[0]], top_logit_vals[0],
            classes[top_classes[1]], top_logit_vals[1],
        )
        print(summary)

        # Only the top prediction is explained (range(1)).
        for k in range(1):
            pred_class = top_classes[k]
            # Per-concept contribution to the predicted class's logit.
            contributions = sample_concepts*model.final.weight[pred_class, :]
            # Prefix a concept with "NOT " when its activation is negative.
            feature_names = [
                ("NOT " if sample_concepts[c_idx] < 0 else "") + concepts[c_idx]
                for c_idx in range(len(concepts))
            ]
            values = contributions.cpu().numpy()
            # Show the contributions above 0.005 in magnitude (+1), capped at 8 bars.
            num_significant = int(sum(abs(values)>0.005))
            max_display = min(num_significant+1, 8)
            title = "Pred:{} - Conf: {:.3f} - Logit:{:.2f} - Bias:{:.2f}".format(
                classes[pred_class], conf[pred_class], top_logit_vals[k], model.final.bias[pred_class]
            )
            plots.bar(values, feature_names, max_display=max_display, title=title, fontsize=16)