# 🔍 Discover‑then‑Name: *Task‑Agnostic Concept Bottleneck Models*

**Hands‑On Tutorial Notebook – KDD 2025 ‘Beyond Feature Attribution’**  
*Based on*: Rao *et al.* (2024) *Discover‑then‑Name* (ECCV).  
*Official repo*: <https://github.com/neuroexplicit-saar/discover-then-name>

<a target="_blank" href="https://colab.research.google.com/github/cxai-mechint-htutorial-kdd2025/cxai-mechint-htutorial-kdd2025.github.io/blob/main/notebooks/03_discover_then_name.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---
### 🌟 What you will learn
1. **Discover** monosemantic latent concepts in a vision model using a Sparse Autoencoder (SAE).
2. **Name** those concepts automatically via CLIP text embeddings.

**Estimated runtime** (Colab, T4 GPU): ≈&nbsp;15 min using pretrained checkpoints.


## 🗺️ Notebook Roadmap
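1. **Setup & Dependencies**  
2. **Load CLIP ViT‑B/16**  
3. **Prepare CIFAR‑100 probe dataset**  
4. **Discover Concepts** (quick SAE training demo, then load the pretrained SAE and concept names)  
5. **Visualize the learned dictionary**  
6. **Extract concepts from the probe dataset**  
7. **Show top images for selected concepts**  
8. **Exercises & References**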

## 1 · Environment Setup
👉 **Run the cell below** on Colab (GPU runtime) to clone the repository, install OpenAI's `clip` package, and seed the random number generators. On a local machine, make sure you have CUDA‑enabled PyTorch ≥ 2.1 and the repository's dependencies installed.

In [None]:
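# Clone the repository (once) and move into it; safe to re-run
import os
if os.path.basename(os.getcwd()) != 'discover-then-name':
    if not os.path.exists('discover-then-name'):
        !git clone https://github.com/neuroexplicit-saar/discover-then-name.git
    %cd discover-then-name

# OpenAI CLIP provides `import clip`; skip if it is already installed
!pip install -q git+https://github.com/openai/CLIP.git

import random
import numpy as np
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)

# Fix seeds for reproducibility
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)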


## 2 · Load CLIP Encoder

In [None]:
import clip, torch
clip_model, preprocess = clip.load('ViT-B/16', device=device)
clip_model.eval(); print('✅ CLIP ViT‑B/16 loaded.')

## 3 · Prepare CIFAR‑100 Probe Dataset
We embed all **10 000** images of the CIFAR‑100 test split with CLIP's image encoder; these 512‑d embeddings are what the SAE will decompose into concepts. This step takes under a minute on a GPU.

In [None]:
# Load the probe dataset, CIFAR100
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

cifar100_dataset = datasets.CIFAR100(root='data', train=False, download=True, transform=preprocess)
cifar100_loader = DataLoader(cifar100_dataset, batch_size=64, shuffle=False, num_workers=4)

print(f'Loaded CIFAR100 dataset with {len(cifar100_dataset)} images.')

In [None]:
# Collect the probe encodings using CLIP (one 512-d embedding per image)
import numpy as np
probe_encodings = []
with torch.no_grad():
    for images, _ in cifar100_loader:
        images = images.to(device)
        # CLIP runs in fp16 on GPU; cast to fp32 before converting to NumPy
        encodings = clip_model.encode_image(images).float().cpu().numpy()
        probe_encodings.append(encodings)
probe_encodings = np.concatenate(probe_encodings, axis=0)  # shape (10000, 512)
print(f'Collected {len(probe_encodings)} probe encodings.')

## 4 · Discover Latent Concepts with Sparse Autoencoder
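This section proceeds in two steps. First we instantiate a sparse autoencoder (SAE) that maps each 512‑d CLIP embedding to 4 096 non‑negative concept activations and train it briefly on the CIFAR‑100 probe embeddings (low quality, but it shows the mechanics). We then overwrite those weights with a **pretrained SAE** trained on millions of CLIP embeddings, which the rest of the notebook uses; the checkpoint file must be available at the path used in the loading cell.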
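To make the "discover" step concrete, here is a minimal NumPy sketch of what an SAE of this shape computes. It assumes a common formulation (pre‑encoder bias subtraction, ReLU encoder, linear decoder); the `sparse_autoencoder` library's implementation may differ in details such as initialisation or decoder‑norm constraints. `sae_forward` is a hypothetical helper, and random weights and inputs stand in for a trained SAE and real CLIP embeddings.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Hypothetical one-layer SAE. x: (n, d) embeddings -> codes f: (n, k) -> x_hat: (n, d)."""
    f = np.maximum((x - b_dec) @ W_enc.T + b_enc, 0.0)  # ReLU keeps concept activations >= 0
    x_hat = f @ W_dec.T + b_dec                          # column i of W_dec = direction of concept i
    return f, x_hat

rng = np.random.default_rng(0)
d, k, n = 512, 4096, 8                      # embedding size, number of concepts, batch size
x = rng.standard_normal((n, d))             # stand-in for CLIP image embeddings
W_enc = rng.standard_normal((k, d)) / np.sqrt(d)
W_dec = W_enc.T.copy()                      # (d, k); tied initialisation, just for the demo
b_enc = np.zeros(k)
b_dec = np.zeros(d)

f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
print(f.shape, x_hat.shape)                 # (8, 4096) (8, 512)

# Each reconstruction is a non-negative mix of decoder columns (plus the bias)
manual = b_dec + sum(f[0, i] * W_dec[:, i] for i in np.flatnonzero(f[0]))
print(np.allclose(x_hat[0], manual))        # True
```

Because these weights are untrained, roughly half of the activations are non‑zero; it is the L1 penalty used during training that pushes most of them to exactly zero, so that each embedding is explained by a handful of decoder directions (the concepts).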

In [None]:
# Instantiate a (randomly initialised) sparse autoencoder:
# 512-d CLIP ViT-B/16 embeddings -> 4096 learned concepts (8x expansion)
from sparse_autoencoder import SparseAutoencoder

sae = SparseAutoencoder(n_input_features=512, n_learned_features=4096, n_components=1).to(device)

### 4.1 · Train the autoencoder on the probe dataset activations
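The `LossReducer` in the next cell sums two terms. Written out for an embedding $x$ with concept activations $f(x)$ and reconstruction $\hat{x}(x)$, the objective is (up to how the library averages over features and over the batch):

$$
\mathcal{L}(x) = \underbrace{\lVert x - \hat{x}(x) \rVert_2^2}_{\text{L2ReconstructionLoss}} + \lambda \, \underbrace{\lVert f(x) \rVert_1}_{\text{LearnedActivationsL1Loss}}, \qquad \lambda = 3 \times 10^{-4}
$$

Here $\lambda$ is `l1_coefficient`. Raising it drives more activations to exactly zero, giving sparser and typically more interpretable concepts at the cost of a higher reconstruction error.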

In [None]:
from sparse_autoencoder import (
    L2ReconstructionLoss,
    LearnedActivationsL1Loss,
    LossReducer,
    LossReductionType
)
loss = LossReducer(
    LearnedActivationsL1Loss(l1_coefficient=0.0003),
    L2ReconstructionLoss())

optim = torch.optim.Adam(sae.parameters(), lr=1e-3)
# Add the component dimension expected by the SAE: (N, 512) -> (N, 1, 512)
data = torch.tensor(probe_encodings, dtype=torch.float32).unsqueeze(1).to(device)
# Quick full-batch demo: 30 gradient steps over all 10 000 embeddings
for epoch in range(30):
    learned_activations, reconstructed_activations = sae(data)
    total_loss, loss_metrics = loss.scalar_loss_with_log(
        data, learned_activations, reconstructed_activations,
        component_reduction=LossReductionType.MEAN)
    optim.zero_grad()
    total_loss.backward()
    optim.step()
    print(f'Epoch {epoch+1} – loss: {total_loss.item():.4f}')

### 4.2 · Load the pretrained autoencoder

In [None]:
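# Overwrite the quick-trained weights with the pretrained checkpoint.
# The path is relative to the repository folder, i.e. ../Checkpoints/ sits next to it.
state_dict = torch.load("../Checkpoints/clip_ViT-B:16_sparse_autoencoder_final.pt", map_location=device)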
sae.load_state_dict(state_dict)
sae.eval()
print('✅ Pretrained Sparse Autoencoder loaded.')

### 4.3 · Load the concept names

In [None]:
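# Load the precomputed concept names: row i of the CSV names SAE concept i,
# with the name in the second column
import csv
with open("../Assigned Names/clip_ViT-B:16_concept_names.csv", "r", newline="") as f:
    concept_names = [row[1] for row in csv.reader(f)]
# Map name -> concept index (if several concepts share a name, the last one wins)
concept_indexes = dict(zip(concept_names, range(len(concept_names))))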
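The CSV above stores names that were computed offline. Following the idea described in the paper (naming each concept by the vocabulary word whose CLIP text embedding is closest to the concept's decoder direction), the assignment can be sketched as below. This is a simplified sketch, not the repository's implementation: `name_concepts` is a hypothetical helper, the toy 3‑d vectors replace real CLIP embeddings, and the repository's vocabulary and preprocessing may differ.

```python
import numpy as np

def name_concepts(decoder_directions, vocab_embeddings, vocab_words):
    """Hypothetical helper: give each concept the word with the highest cosine similarity.

    decoder_directions: (n_concepts, d), one dictionary direction per SAE concept
    vocab_embeddings:   (n_words, d), text embeddings of the vocabulary words
    """
    D = decoder_directions / np.linalg.norm(decoder_directions, axis=1, keepdims=True)
    T = vocab_embeddings / np.linalg.norm(vocab_embeddings, axis=1, keepdims=True)
    sims = D @ T.T                          # (n_concepts, n_words) cosine similarities
    best = sims.argmax(axis=1)              # index of the closest word per concept
    return [vocab_words[j] for j in best], sims.max(axis=1)

# Toy check with hand-made 3-d "embeddings" instead of real CLIP vectors
toy_vocab = ['fences', 'pupil', 'doors']
toy_text_embeddings = np.array([[1.0, 0.0, 0.0],
                                [0.0, 1.0, 0.0],
                                [0.0, 0.0, 1.0]])
toy_directions = np.array([[0.1, 0.9, 0.0],    # points mostly along "pupil"
                           [2.0, 0.0, 0.5]])   # points mostly along "fences"
names, scores = name_concepts(toy_directions, toy_text_embeddings, toy_vocab)
print(names)   # ['pupil', 'fences']
```

With the notebook's objects, the directions would be `sae.decoder.weight[0].T` (one 512‑d row per concept, as in the t‑SNE cell below) and the word embeddings could come from `clip_model.encode_text(clip.tokenize(words).to(device))` for a list of candidate `words`.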

In [None]:
# Concepts to inspect; their indices in the pretrained SAE:
# fences -> 1526, pupil -> 3955, doors -> 704, bed -> 2061
vocab = ['fences', 'pupil', 'doors', 'bed']
concept_idx = concept_indexes['bed']

## 5 · Visualize the Learned Dictionary

In [None]:
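import matplotlib.pyplot as plt
# Decoder weights: weights[0, :, i] is the 512-d direction of concept i in CLIP space
weights = sae.decoder.weight.detach().cpu()
# Alternative: the 6 concepts with the largest decoder norm
# neurons = weights[0].norm(p=2, dim=0).topk(6).indices.tolist()
neurons = [concept_indexes[name] for name in vocab if name in concept_indexes]
fig, axes = plt.subplots(2, 3, figsize=(9, 6))
for ax in axes.flatten():
    ax.axis('off')  # also hides panels left unused
for ax, idx in zip(axes.flatten(), neurons):
    # The 32x16 layout only displays 512 numbers; it has no spatial meaning
    ax.imshow(weights[0, :, idx].reshape(32, 16), cmap='viridis', aspect='auto')
    ax.set_title(f'Neuron {idx}: {concept_names[idx]}')
fig.suptitle('Decoder directions of selected concepts (512-d, shown as 32×16)')
plt.tight_layout()
plt.show()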

In [None]:
# Sanity check: one 512-d decoder direction per concept, should be (4096, 512)
weights.numpy()[0].T.shape

In [None]:
# Make a 2D projection of the decoder directions (one point per concept) using t-SNE.
# Note: scikit-learn < 1.5 calls the iteration argument `n_iter` instead of `max_iter`.
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, random_state=0, max_iter=1000, perplexity=30)
weights_2d = tsne.fit_transform(weights.numpy()[0].T)

In [None]:
# Label every other distinct concept name to keep the plot readable
neurons = [concept_indexes[name] for i, name in enumerate(concept_indexes) if i % 2 == 0]

plt.figure(figsize=(20, 20))
plt.scatter(weights_2d[:, 0], weights_2d[:, 1], s=5, alpha=0.5)
for idx in neurons:
    plt.annotate(concept_names[idx], (weights_2d[idx, 0], weights_2d[idx, 1]), fontsize=8)
plt.title('t-SNE projection of decoder weights')
plt.xlabel('t-SNE Component 1')
plt.ylabel('t-SNE Component 2')
plt.grid(True)
plt.show()
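t‑SNE preserves local neighbourhoods but not global distances, so clusters in the plot are only a rough guide. A more direct check is to compare decoder directions by cosine similarity. The sketch below uses a hypothetical helper `nearest_concepts` and toy random directions in place of the real decoder matrix.

```python
import numpy as np

def nearest_concepts(directions, query_idx, k=5):
    """Hypothetical helper: indices of the k concepts most cosine-similar to query_idx."""
    unit = directions / np.linalg.norm(directions, axis=1, keepdims=True)
    sims = unit @ unit[query_idx]           # cosine similarity of every concept to the query
    sims[query_idx] = -np.inf               # exclude the query itself
    order = np.argsort(-sims)[:k]           # most similar first
    return order, sims[order]

# Toy data: 6 random 512-d directions, with concept 4 a noisy copy of concept 1
rng = np.random.default_rng(0)
directions = rng.standard_normal((6, 512))
directions[4] = directions[1] + 0.1 * rng.standard_normal(512)
idx, sims = nearest_concepts(directions, query_idx=1, k=3)
print(idx[0])   # 4
```

With the notebook's objects, pass `weights.numpy()[0].T` as `directions` and, e.g., `concept_indexes['bed']` as `query_idx`, then look up the returned indices in `concept_names`.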

## 6 · Extract Concepts from the Probe Dataset

In [None]:
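# Encode all probe embeddings with the pretrained SAE
features = torch.tensor(probe_encodings, device=device, dtype=torch.float32)
sae.eval()
with torch.no_grad():
    # concepts has shape (10000, 1, 4096): one activation per image and concept
    concepts, reconstructions = sae(features)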
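A quick sanity check on the extracted codes is how sparse they are: the average number of active concepts per image (L0) and the fraction of concepts that never fire on this probe set. The sketch uses a hypothetical helper `sparsity_stats` on a toy array; "dead" here only means inactive on the data you pass in.

```python
import numpy as np

def sparsity_stats(codes, eps=0.0):
    """Hypothetical helper for SAE codes of shape (n_images, n_concepts).

    Returns the mean number of active concepts per image (L0) and the
    fraction of concepts that never activate on these images.
    """
    active = codes > eps
    mean_l0 = active.sum(axis=1).mean()
    dead_fraction = (~active.any(axis=0)).mean()
    return mean_l0, dead_fraction

toy = np.array([[0.0, 1.2, 0.0, 0.0],
                [0.3, 0.0, 0.0, 0.0],
                [0.0, 0.5, 0.0, 0.0]])
mean_l0, dead = sparsity_stats(toy)
print(f"{mean_l0:.1f} {dead:.2f}")   # 1.0 0.50
```

For the notebook's codes, drop the component axis first: `sparsity_stats(concepts[:, 0, :].cpu().numpy())`.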

## 7 · Show Top Images for Selected Concepts

In [None]:
# Reload CIFAR-100 without the CLIP transform to get the raw PIL images for display
cifar100_dataset_orig = datasets.CIFAR100(root='data', train=False, download=False)

In [None]:
concept_idx = concept_indexes['fences']
# concept_idx = concept_indexes['pupil']
# concept_idx = concept_indexes['doors']
# concept_idx = concept_indexes['bed']

concept_names[concept_idx]

In [None]:
# Activation of the selected concept on every probe image; keep the 10 strongest
concept_strengths = concepts[:, 0, concept_idx].cpu().numpy()
top_indices = concept_strengths.argsort()[::-1][:10]
top_images = [cifar100_dataset_orig[i][0] for i in top_indices]

In [None]:
# Plot the 10 most activating images for the selected concept
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 5, figsize=(8, 3))
for ax, img_idx, img in zip(axes.flatten(), top_indices, top_images):
    ax.imshow(img)  # PIL image from the untransformed dataset
    ax.axis('off')
    ax.set_title(f"Concept Strength: {concept_strengths[img_idx]:.4f}\nImage: {img_idx}", fontsize=8)
fig.suptitle(f'Top images for concept {concept_idx}: {concept_names[concept_idx]}')
plt.tight_layout()
plt.show()

## 8 · Exercises & Further Reading
1. **Improve concept naming** by providing a richer vocabulary (e.g., the ~49 k tokens of CLIP's BPE vocabulary) and measuring naming recall.
2. **Tune sparsity** (`l1_coefficient`) or **dictionary size** (`n_learned_features`) and observe the effects on interpretability vs. accuracy.
3. Replace CIFAR‑100 with your own dataset (or **ImageNet‑mini**) and retrain the DN‑CBM.
4. Combine **TCAV** with the discovered concepts to quantify their directional influence on model predictions.
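For exercise 4, recall that TCAV (Kim *et al.*, 2018) takes the directional derivative of a class logit along a concept direction and reports the fraction of class inputs for which it is positive. The NumPy sketch below assumes a hypothetical two‑layer ReLU head on top of CLIP embeddings (so the gradient can be written by hand) and plugs in an SAE decoder direction where TCAV would normally use a CAV learned by a linear probe. The full method also compares against random directions with a significance test, which is omitted here.

```python
import numpy as np

def mlp_logit_grad(z, W1, b1, w2):
    """Gradient of the hypothetical logit w2 . relu(W1 z + b1) w.r.t. the embedding z."""
    pre = W1 @ z + b1
    return W1.T @ (w2 * (pre > 0))          # chain rule through the ReLU mask

def tcav_score(Z, concept_direction, W1, b1, w2):
    """Fraction of inputs whose logit increases when moving along concept_direction."""
    grads = np.stack([mlp_logit_grad(z, W1, b1, w2) for z in Z])   # (n, d)
    sensitivities = grads @ concept_direction                        # directional derivatives
    return float(np.mean(sensitivities > 0))

rng = np.random.default_rng(0)
d, h, n = 512, 64, 200
W1 = rng.standard_normal((h, d)) / np.sqrt(d)   # hypothetical head weights
b1 = np.zeros(h)
w2 = rng.standard_normal(h)
Z = rng.standard_normal((n, d))                  # stand-in for class-k CLIP embeddings
v = rng.standard_normal(d)
v /= np.linalg.norm(v)                           # stand-in for an SAE decoder direction
score = tcav_score(Z, v, W1, b1, w2)
print(f"TCAV score: {score:.2f}")
```

In the notebook, `Z` would be the probe embeddings of one CIFAR‑100 class, `v` a column of the SAE decoder, and the head a classifier trained on CLIP embeddings.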

---
### 📑 References
- Rao *et al.* (2024) *Discover‑then‑Name: Task‑Agnostic Concept Bottlenecks via Automated Concept Discovery.* ECCV.
- Kim *et al.* (2018) *Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV).* ICML.
- Oikarinen *et al.* (2023) *Label‑Free Concept Bottleneck Models.* ICLR.
- Radford *et al.* (2021) *Learning Transferable Visual Models from Natural Language Supervision.* ICML.