## Step #1: Train a CBM

In [1]:
import numpy as np
import torch

from torch_concepts.nn import LinearZC
from torch_concepts.data.datasets.traffic import TrafficLights

# Fix seeds first
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

n_samples = 500

# Loading training dataset
dataset = TrafficLights(
    n_samples=n_samples,
    possible_starting_directions=['west'],
    resize_final_image=0.05,
    selected_concepts=[
        'green light on selected lane',
        'car in intersection',
        'ambulance seen',
        'ambulance approaching perpendicular to selected car',
    ],
    split='train',
)
concept_names, task_names = dataset.concept_names, dataset.task_names
n_concepts = len(concept_names)

# Loading testing dataset
# Generate the test dataset
test_dataset = TrafficLights(
    n_samples=n_samples,
    possible_starting_directions=['west'],
    resize_final_image=0.05,
    selected_concepts=[
        'green light on selected lane',
        'car in intersection',
        'ambulance seen',
        'ambulance approaching perpendicular to selected car',
    ],
    split='test',
)
print(
    f"Training set has {len(dataset)} samples while test set "
    f"has {len(test_dataset)} samples"
)

  from .modules.mid.base.model import BaseConstructor
INFO:root:We found a dataset previously generated with the same config that has been cached.
INFO:root:	If you wish to re-generate it, please use regenerate=True.
INFO:root:We found a dataset previously generated with the same config that has been cached.
INFO:root:	If you wish to re-generate it, please use regenerate=True.


Training set has 300 samples while test set has 100 samples


In [2]:
# -----------------------
# Dimensions
# -----------------------
latent_dims = 32
n_concepts = len(concept_names)

# -----------------------
# Encoder: X â†’ Z
# -----------------------
encoder = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),

    torch.nn.Conv2d(4, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),
    torch.nn.BatchNorm2d(4),

    torch.nn.Conv2d(4, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),

    torch.nn.Conv2d(4, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),
    torch.nn.BatchNorm2d(4),

    torch.nn.MaxPool2d((5, 5)),

    torch.nn.Flatten(start_dim=1),
    torch.nn.Linear(576, latent_dims),
    torch.nn.LeakyReLU(),
)

# -----------------------
# Concept layer: Z â†’ C
# -----------------------
c_layer = LinearZC(
    in_features=latent_dims,
    out_features=n_concepts,
)

# ðŸ”‘ Attach concept semantics manually
c_layer.annotations = list(concept_names)

# -----------------------
# Task predictor: C â†’ Y
# -----------------------
y_predictor = torch.nn.Sequential(
    torch.nn.Linear(n_concepts, latent_dims),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(latent_dims, 1),  # binary task
)

# -----------------------
# Full CBM
# -----------------------
model = torch.nn.Sequential(
    encoder,
    c_layer,
    y_predictor,
)

print(model)


Sequential(
  (0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): LeakyReLU(negative_slope=0.01)
    (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (6): LeakyReLU(negative_slope=0.01)
    (7): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (8): LeakyReLU(negative_slope=0.01)
    (9): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): MaxPool2d(kernel_size=(5, 5), stride=(5, 5), padding=0, dilation=1, ceil_mode=False)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Linear(in_features=576, out_features=32, bias=True)
    (13): LeakyReLU(negative_slope=0.01)
  )
  (1): LinearZC(
    (encoder): Sequential(
      (0): Linear(in_features=32, out_features=4, bias=True)

## We train the model:

In [3]:
from torch.utils.data import DataLoader

n_epochs = 20
concept_loss_weight = 10
lr = 0.01
batch_size = 50

# Define optimizer and loss function
model = torch.nn.Sequential(encoder, c_layer, y_predictor)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
loss_fn = torch.nn.BCELoss()

# Make a batch dataset loader
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=2,
)

# Standard PyTorch learning cycle
model.train()
for epoch in range(n_epochs):
    for batch_idx, (x, y, c, _, _) in enumerate(dataloader):
        # Encode input, then predict concept and downstream tasks activations
        emb = encoder(x)
        c_pred = c_layer(emb).sigmoid()
        y_pred = y_predictor(c_pred).sigmoid().view(-1)

        # Double loss on concepts and tasks
        loss = loss_fn(y_pred, y) + concept_loss_weight * loss_fn(c_pred, c)

        # Perform the update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        task_acc = torch.mean(((y_pred > 0.5) == y).type(torch.float))
        task_acc = task_acc.detach().cpu().numpy()
        if ((epoch + 1) % 5 == 0) and (batch_idx == 0):
            print(
                f"Epoch [{epoch+1}/{n_epochs}], "
                f"Step [{batch_idx+1}/{len(dataloader)}], "
                f"Loss: {loss.item():.4f}, "
                f"Task Accuracy: {task_acc * 100:.2f}%, "
            )

Epoch [5/20], Step [1/6], Loss: 3.0957, Task Accuracy: 96.00%, 
Epoch [10/20], Step [1/6], Loss: 1.9443, Task Accuracy: 90.00%, 
Epoch [15/20], Step [1/6], Loss: 1.9977, Task Accuracy: 90.00%, 
Epoch [20/20], Step [1/6], Loss: 1.1628, Task Accuracy: 92.00%, 


In [4]:
# Load the train set to memory
x_train = []
c_train = []
y_train = []
for (x, y, c, _, _) in dataset:
    x_train.append(x.unsqueeze(0))
    y_train.append(y.unsqueeze(0))
    c_train.append(c.unsqueeze(0))
x_train = torch.concat(x_train, dim=0)
y_train = torch.concat(y_train, dim=0)
c_train = torch.concat(c_train, dim=0)

# Load the test set to memory
x_test = []
c_test = []
y_test = []
for (x, y, c, _, _) in test_dataset:
    x_test.append(x.unsqueeze(0))
    y_test.append(y.unsqueeze(0))
    c_test.append(c.unsqueeze(0))
x_test = torch.concat(x_test, dim=0)
y_test = torch.concat(y_test, dim=0)
c_test = torch.concat(c_test, dim=0)

In [5]:
model.eval()
c_pred = c_layer(encoder(x_test)).sigmoid()
y_pred = y_predictor(c_pred).sigmoid()
print("Average task prediction:", y_pred.mean(0).detach().cpu().numpy())
print("Average concept prediction:", c_pred.mean(0).detach().cpu().numpy())

Average task prediction: [0.5466324]
Average concept prediction: [0.6060981  0.28488654 0.10826973 0.0669183 ]


## Step #2: Compute task and concept performance

In [6]:
from sklearn.metrics import roc_auc_score

concept_performance = roc_auc_score(c_test, c_pred.detach())
task_performance = roc_auc_score(y_test, y_pred.detach())

print(f'Task performance: {task_performance*100:.2f}%')
print(f'Concept performance: {concept_performance*100:.2f}%')

Task performance: 94.58%
Concept performance: 92.12%


## Step #3: Compute intervention effectiveness

    Cannot find the intervention_score : 

    so we created a custom function 

In [8]:
from torch_concepts.nn import intervention_score
intervention_groups = [[], [0], [1], [0, 1]]

# Evaluate intervention effectiveness of each concept group individually
intervention_scores = intervention_score(
    y_predictor,
    c_pred,
    c_test,
    y_test,
    intervention_groups,
    auc=False,
)
print(f'Individual intervention scores: {intervention_scores}')

# Evaluate the global intervention effectiveness as the AUC
intervention_auc = intervention_score(
    y_predictor,
    c_pred,
    c_test,
    y_test,
    intervention_groups,
)
print(f'Intervention AUC: {intervention_auc:.4f}')

ImportError: cannot import name 'intervention_score' from 'torch_concepts.nn' (/home/user_cril/meher/CRIL/CBM/TrafficLights/traffic-env/lib/python3.10/site-packages/torch_concepts/nn/__init__.py)

In [11]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def intervention_score_cbm(
    predictor,
    c_pred,
    c_true,
    y_true,
    intervention_groups,
    auc=True,
):
    """
    Correct reimplementation of intervention score from the CBM literature.
    """

    y_true_np = y_true.view(-1).cpu().numpy()
    scores = []
    num_intervened = []

    for group in intervention_groups:
        c_intervened = c_pred.clone()

        if len(group) > 0:
            c_intervened[:, group] = c_true[:, group]

        y_logits = predictor(c_intervened).view(-1)
        y_probs = torch.sigmoid(y_logits).cpu().numpy()

        score = roc_auc_score(y_true_np, y_probs)

        scores.append(score)
        num_intervened.append(len(group))

    if not auc:
        return scores

    # ðŸ”‘ Sort by number of intervened concepts
    num_intervened = np.array(num_intervened)
    scores = np.array(scores)

    order = np.argsort(num_intervened)
    num_intervened = num_intervened[order]
    scores = scores[order]

    # ðŸ”‘ Normalize x-axis to [0, 1]
    x = num_intervened / num_intervened.max()

    # ðŸ”‘ Area under intervention curve
    intervention_auc = np.trapz(scores, x)

    return intervention_auc


In [12]:
intervention_groups = [
    [],
    [0],
    [1],
    [0, 1],
]

scores = intervention_score_cbm(
    predictor=y_predictor,
    c_pred=c_pred,
    c_true=c_test,
    y_true=y_test,
    intervention_groups=intervention_groups,
    auc=False,
)

print("Individual intervention scores:", scores)

auc = intervention_score_cbm(
    predictor=y_predictor,
    c_pred=c_pred,
    c_true=c_test,
    y_true=y_test,
    intervention_groups=intervention_groups,
    auc=True,
)

print(f"Intervention AUC: {auc:.4f}")


Individual intervention scores: [0.9458333333333333, 0.9420833333333334, 0.9691666666666667, 0.9570833333333333]
Intervention AUC: 0.9535


## Step #4: Compute concept completeness 
To compute concept completeness, we need a black-box baseline model that uses both raw features and concept labels, as this matches the information provided to the CBM. The following code implements a simple black-box model with a similar parameter count to the CBM for a fair comparison.

In [13]:
# Make maps containing the values of each concept
c_train_maps = c_train.unsqueeze(-1).unsqueeze(-1)
c_train_maps = c_train_maps.expand(-1, -1, 64, 64)
c_test_maps = c_test.unsqueeze(-1).unsqueeze(-1)
c_test_maps = c_test_maps.expand(-1, -1, 64, 64)

# Put them together with the input features
xc_train = torch.concat((x_train, c_train_maps), dim=1)
xc_test = torch.concat((x_test, c_test_maps), dim=1)

# Defining a balck box baseline
baseline = torch.nn.Sequential(
    # A 3x3 convolution with 4 output channels
    torch.nn.Conv2d(3 + n_concepts, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),

    # A 3x3 convolution with 4 output channels with a batch norm
    torch.nn.Conv2d(4, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),
    torch.nn.BatchNorm2d(4),

    # A 3x3 convolution with 4 output channels
    torch.nn.Conv2d(4, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),

    # A 3x3 convolution with 4 output channels with a batch norm
    torch.nn.Conv2d(4, 4, (3, 3), padding='same'),
    torch.nn.LeakyReLU(),
    torch.nn.BatchNorm2d(4),

    # A 5x5 max pooling layer
    torch.nn.MaxPool2d((5, 5)),

    # Finally, we flatten and map it to a known latent space size
    torch.nn.Flatten(start_dim=1, end_dim=-1),
    torch.nn.Linear(576, 1),
)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(baseline.parameters(), lr=lr)
loss_fn = torch.nn.BCELoss()

# Standard PyTorch learning cycle
baseline.train()
for epoch in range(n_epochs):
   optimizer.zero_grad()

   # Encode input, then predict concept and downstream tasks activations
   y_pred_baseline = baseline(xc_train).sigmoid().view(-1)

   # Double loss on concepts and tasks
   loss = loss_fn(y_pred_baseline, y_train)
   loss.backward()
   optimizer.step()

baseline.eval()
y_pred_baseline = baseline(xc_test).sigmoid()
task_performance_baseline = roc_auc_score(y_test, y_pred_baseline.detach())

In [14]:
from torch_concepts.metrics import completeness_score

concept_completeness = completeness_score(y_test, y_pred_baseline, y_pred)

print(f'Task performance: {task_performance*100:.2f}%')
print(f'Task performance baseline: {task_performance_baseline*100:.2f}%')
print(f'Concept completeness: {concept_completeness*100:.2f}%')

ModuleNotFoundError: No module named 'torch_concepts.metrics'

## Create a custom fct for completness score 
    can u verify it (Tanmoy)

In [17]:
def completeness_score_cbm(
    y_true,
    y_pred_baseline,
    y_pred_cbm,
    metric="auc",
):
    import numpy as np
    import torch
    from sklearn.metrics import roc_auc_score, accuracy_score

    # -----------------------
    # Detach tensors safely
    # -----------------------
    def to_numpy(x):
        if torch.is_tensor(x):
            return x.detach().cpu().numpy().ravel()
        return np.asarray(x).ravel()

    y_true = to_numpy(y_true)
    y_pred_baseline = to_numpy(y_pred_baseline)
    y_pred_cbm = to_numpy(y_pred_cbm)

    # -----------------------
    # Compute performance
    # -----------------------
    if metric == "auc":
        perf_baseline = roc_auc_score(y_true, y_pred_baseline)
        perf_cbm = roc_auc_score(y_true, y_pred_cbm)
    elif metric == "accuracy":
        perf_baseline = accuracy_score(y_true, y_pred_baseline > 0.5)
        perf_cbm = accuracy_score(y_true, y_pred_cbm > 0.5)
    else:
        raise ValueError("metric must be 'auc' or 'accuracy'")

    if perf_baseline >= 1.0:
        return 0.0

    completeness = (perf_cbm - perf_baseline) / (1.0 - perf_baseline)
    return max(0.0, completeness)


In [19]:
concept_completeness = completeness_score_cbm(
    y_true=y_test,
    y_pred_baseline=y_pred_baseline,
    y_pred_cbm=y_pred,
    metric="auc",
)

print(f"Task performance (CBM): {task_performance*100:.2f}%")
print(f"Task performance (baseline): {task_performance_baseline*100:.2f}%")
print(f"Concept completeness: {concept_completeness*100:.2f}%")


Task performance (CBM): 94.58%
Task performance (baseline): 97.46%
Concept completeness: 0.00%


## Quick diagnostic (recommended)
ðŸ”¹ Check unclipped completeness (for insight only)

In [20]:
raw_completeness = (
    task_performance - task_performance_baseline
) / (1 - task_performance_baseline)

print(f"Raw completeness: {raw_completeness:.3f}")


Raw completeness: -1.131
