#### RQ3: How does the induction of domain requirements impact the robustness of the Concept Bottleneck Model?
The aim of this RQ is to show that the model becomes more robust and logically consistent on images it has not been trained on. For this purpose we use (i) images from the GTSRB test set that have been morphed, (ii) images from the Belgium Traffic Sign (BTS) dataset that cannot be directly mapped to the GTSRB, i.e., unseen traffic signs, and (iii) images from a far-OOD dataset, CIFAR-10, to evaluate what impact the introduction of domain requirements has on images that have none of the characteristics of traffic signs.

#### Setting up configs and models

In [None]:
# imports
import os
from pathlib import Path
import torch

# Change the working directory to ../src before the imports that follow this
# call, so that they keep working. Every relative path used later in the
# notebook is therefore interpreted relative to that directory.
os.chdir("../src")

from models.architectures import CBMSequentialEfficientNetFCN
from config import load_config
from models.trainer.cbm_trainer import CBMTrainer
from rule_eval import construct_full_graph
import pandas as pd

# Import utility functions
from analysis_utils import (
    get_dataset_predictions,
    analyze_fuzzy_loss_single_model,
    compare_fuzzy_losses,
    analyze_rule_violations,
    compare_violations,
    print_fuzzy_loss_results,
    print_violation_results,
)

In [None]:
# Directory constants (relative paths, resolved against the current working directory).
# NOTE(review): the BTS and CIFAR-10 cells spell out their own "../data/raw/..."
# locations instead of building on data_path - confirm which root is intended.
data_path = Path("../../../data/raw/")
models_path = Path("../files/models/")
config_path = Path("../files/configs/")

In [None]:
# model configs and model loading
# Baseline CBM: read its YAML config from config_path and build the model object
# from that config. The trained weights are loaded into the model's
# concept_predictor / label_predictor in a later cell.
baseline_cbm_config = load_config(config_path / "GTSRB_CBM_config_loading.yaml")
baseline_cbm = CBMSequentialEfficientNetFCN(baseline_cbm_config)

# Fuzzy CBM: same model class, configured from the "best_trial" YAML file.
fuzzy_cbm_config = load_config(config_path / "GTSRB_CBM_config_best_trial_loading.yaml")
fuzzy_cbm = CBMSequentialEfficientNetFCN(fuzzy_cbm_config)

In [None]:
# Checkpoint files for the concept predictor and the label predictor of each model.
baseline_cbm_concept_predictor_path = Path("../notebooks/best_acc_models/20251016_224601_s907_baseline_concept_predictor_best_model.pt")
baseline_cbm_label_predictor_path = Path("../experiments/baseline_cbm/models/20251001_083717_label_predictor_best_model.pt")
fuzzy_cbm_concept_predictor_path = Path("../notebooks/best_acc_models/20251020_223819_s269_concept_predictor_best_model.pt")
fuzzy_cbm_label_predictor_path = Path("../experiments/fuzzy_CBM/models/20251001_113637_label_predictor_best_model.pt")

# Backward-compatible alias: the misspelled name is what the weight-loading cell
# references, so it must keep existing alongside the corrected spelling.
fuzzy_cbm_label_precitor_path = fuzzy_cbm_label_predictor_path

In [23]:
# Restore the saved state dict of every sub-network. The entries are processed
# in order: baseline concept predictor, baseline label predictor, then the two
# fuzzy-model components. Each checkpoint is mapped onto the device named in
# the owning model's config.
_weight_sources = (
    (baseline_cbm.concept_predictor, baseline_cbm_concept_predictor_path, baseline_cbm_config.device),
    (baseline_cbm.label_predictor, baseline_cbm_label_predictor_path, baseline_cbm_config.device),
    (fuzzy_cbm.concept_predictor, fuzzy_cbm_concept_predictor_path, fuzzy_cbm_config.device),
    (fuzzy_cbm.label_predictor, fuzzy_cbm_label_precitor_path, fuzzy_cbm_config.device),
)
for _module, _checkpoint_path, _device in _weight_sources:
    _state_dict = torch.load(_checkpoint_path, map_location=_device, weights_only=True)
    _module.load_state_dict(_state_dict)

# Put both models into evaluation mode
for _model in (baseline_cbm, fuzzy_cbm):
    _model.eval()

# Only the directory of each concept-predictor checkpoint is reported here
print(f"  Baseline CBM: {baseline_cbm_concept_predictor_path.parent}")
print(f"  Fuzzy CBM:    {fuzzy_cbm_concept_predictor_path.parent}")

  Baseline CBM: ../notebooks/best_acc_models
  Fuzzy CBM:    ../notebooks/best_acc_models


In [24]:
# Build the dataloaders from the dataset section of the baseline config,
# seeded with the baseline config's seed.
_gtsrb_dataset_cfg = baseline_cbm_config.dataset
dataset_factory = _gtsrb_dataset_cfg.factory(
    seed=baseline_cbm_config.seed,
    config=_gtsrb_dataset_cfg,
).set_dataloaders()

train_loader, val_loader, test_loader = (
    dataset_factory.train_dataloader,
    dataset_factory.val_dataloader,
    dataset_factory.test_dataloader,
)

In [25]:
# Both trainers receive the same three dataloaders; only config and model differ.
_shared_loaders = {
    "train_loader": train_loader,
    "val_loader": val_loader,
    "test_loader": test_loader,
}

baseline_cbm_trainer = CBMTrainer(config=baseline_cbm_config, model=baseline_cbm, **_shared_loaders)
fuzzy_cbm_trainer = CBMTrainer(config=fuzzy_cbm_config, model=fuzzy_cbm, **_shared_loaders)

# The criterion of the fuzzy trainer's concept-predictor trainer is the loss
# object passed to analyze_fuzzy_loss_single_model for BOTH models below.
neutral_fuzzy_loss = fuzzy_cbm_trainer.concept_predictor_trainer.criterion

In [26]:
# Load rule checker
# The object returned by construct_full_graph() (imported from rule_eval) is
# handed to analyze_rule_violations as its last argument in the analysis cells.
rule_checker = construct_full_graph()

#### Loading the GTSRB test-set

In [27]:
# GTSRB test split: the factory is built again with the baseline config's
# dataset section and seed, and test_loader is rebound to its test dataloader.
_gtsrb_cfg = baseline_cbm_config.dataset
dataset_factory = _gtsrb_cfg.factory(seed=baseline_cbm_config.seed, config=_gtsrb_cfg).set_dataloaders()
test_loader = dataset_factory.test_dataloader

#### Loading the morphed GTSRB

In [28]:
# Morphed GTSRB: same factory, seed and dataset config as the regular GTSRB
# loaders, but reicnn_transform_TRAIN is passed as both the train and the test
# transform, so the test images go through that transform as well.
# NOTE(review): if reicnn_transform_TRAIN is stochastic, results may vary
# between runs - confirm it is seeded.
from data_access.datasets.GTSRBFactory import reicnn_transform_TRAIN

_morphed_factory = baseline_cbm_config.dataset.factory(
    seed=baseline_cbm_config.seed,
    config=baseline_cbm_config.dataset,
)
dataset_factory = _morphed_factory.set_dataloaders(
    train_transform=reicnn_transform_TRAIN,
    test_transform=reicnn_transform_TRAIN,
)
morphed_gtsrb_test_loader = dataset_factory.test_dataloader

#### Loading the unmapped BTS

In [29]:
from config.dataset_config import ConceptDatasetConfig
from data_access.datasets import GTSRBFactory

# Belgium Traffic Sign (BTS) data: 62 labels, loaded unshuffled through
# GTSRBFactory. The samples are narrowed down by label in a later cell.
bts_config = ConceptDatasetConfig(
    name="bts",
    n_labels=62,
    batch_size=64,
    num_workers=4,
    shuffle_dataset=False,
    pin_memory=True,
    data_path=Path("../data/raw/BTSD/"),
    val_split=0.2,
)
# resolve() is called before the config is handed to the factory
# (presumably validates paths and loads the concept map - confirm in ConceptDatasetConfig).
bts_config.resolve()

# A fixed seed is passed to the factory.
# NOTE(review): the seed comes from fuzzy_cbm_config, whereas the GTSRB loaders
# use baseline_cbm_config.seed - confirm both configs carry the same seed.
bts_factory = GTSRBFactory(config=bts_config, seed=fuzzy_cbm_config.seed)
bts_factory.set_dataloaders()

# Expose the three dataloaders under notebook-level names
bts_test_loader, bts_train_loader, bts_val_loader = (
    bts_factory.test_dataloader,
    bts_factory.train_dataloader,
    bts_factory.val_dataloader,
)

In [30]:
# GTSRB -> BTS label table. The list position is the GTSRB label; the value is
# the BTS label when the sign has an unambiguous BTS counterpart, and -1 otherwise.
labels_id_GTS_to_BTS = [
    -1, -1, -1, -1, -1, -1, -1, -1, -1, 31,   # GTSRB labels 0-9
    -1, 17, 61, 19, 21, 28, 25, 22, 13, 3,    # GTSRB labels 10-19
    4, 5, 0, 2, 16, 10, -1, -1, -1, 8,        # GTSRB labels 20-29
    -1, -1, -1, -1, -1, 34, -1, -1, 35, -1,   # GTSRB labels 30-39
    37, -1, -1,                               # GTSRB labels 40-42
]

# Inverse lookup (BTS label -> GTSRB label) for the mapped signs only.
mapping_BTS_to_GTSRB = {
    bts_label: gtsrb_label
    for gtsrb_label, bts_label in enumerate(labels_id_GTS_to_BTS)
    if bts_label != -1
}

# Human-readable names of the 62 BTS classes.
Belgium_ID_to_Name = {
    0: "Bumpy road",
    1: "Bump",
    2: "Slippery road",
    3: "Bend to The left",
    4: "Bend to The right",
    5: "Double curves first to the left",
    6: "Double curves first to the right",
    7: "School zone",
    8: "Bikes can be cross",
    9: "Domestic animal crossing",
    10: "Roadworks",
    11: "Traffic light",
    12: "Gated railroad crossing ahead",
    13: "Caution",
    14: "Road narrows",
    15: "Road narrows on the left",
    16: "Road narrows on the right",
    17: "Intersection with priority",
    18: "Intersection with priority to the right",
    19: "Yield",
    20: "Yield to incoming traffic",
    21: "Stop",
    22: "No entry",
    23: "No entry for cyclists",
    24: "No vehicle over 2t",
    25: "No entry for trucks",
    26: "Width limit",
    27: "Height limit",
    28: "No vehicles",
    29: "No left turn",
    30: "No right turn",
    31: "No overtaking",
    32: "Speed limit",
    33: "Shared path for pedestrians and cyclists",
    34: "Ahead only",
    35: "Right only",
    36: "Ahead and right only",
    37: "Roundabout",
    38: "Cycleway",
    39: "Segregated path for pedestrians and cyclists",
    40: "No parking",
    41: "No stopping",
    42: "No parking from the 1st till 15th day of the month",
    43: "No parking from the 16th till last day of the month",
    44: "Priority over oncoming traffic",
    45: "Parking permitted",
    46: "Parking for disabled",
    47: "Parking reserved for motorcycles, cars, vans (< 3.5t) and minibusses",
    48: "Parking reserved for trucks",
    49: "Parking reserved for coaches",
    50: "Parking mandatory on the verge or sidewalk",
    51: "Start of a living street",
    52: "End of living street",
    53: "One-way road",
    54: "Dead end",
    55: "End of roadworks",
    56: "Pedestrian crossing",
    57: "Cyclist and moped crossing",
    58: "Parking lot",
    59: "Hump",
    60: "End of priority road",
    61: "Priority road",
}

In [31]:
from torch.utils.data import DataLoader, Subset


def _bts_label_of(label_data):
    """Return the plain label from the label field of a dataset item.

    The field may be the label itself or a tuple whose first element is the
    label; a value exposing ``.item()`` is unwrapped through that call.
    """
    if isinstance(label_data, tuple):
        label_data = label_data[0]
    return label_data.item() if hasattr(label_data, 'item') else label_data


# BTS labels that DO have an unambiguous GTSRB counterpart. Samples carrying one
# of these labels are dropped: this experiment keeps only the signs that cannot
# be mapped to GTSRB.
valid_bts_labels = [i for i in labels_id_GTS_to_BTS if i != -1]
_mapped_bts_labels = set(valid_bts_labels)  # set for O(1) membership tests

print("Filtering BTS dataset to keep only signs WITHOUT an unambiguous mapping to GTSRB...")
print(f"BTS labels with a GTSRB mapping (excluded): {sorted(valid_bts_labels)}")
print(f"Number of mapped labels: {len(valid_bts_labels)}")

# Walk the BTS train dataset and remember the indices of the unmapped samples.
# Each item is unpacked as a 3-tuple whose last element is the label field.
valid_indices = []
_failed_indices = []
for idx in range(len(bts_factory.train_dataset)):
    try:
        _, _, label_data = bts_factory.train_dataset[idx]
        if _bts_label_of(label_data) not in _mapped_bts_labels:
            valid_indices.append(idx)
    except Exception as e:
        # Best effort: an unreadable sample is reported and left out of the subset
        print(f"Warning: Could not process index {idx}: {e}")
        _failed_indices.append(idx)
        continue

if _failed_indices:
    # Make dropped samples visible in one place instead of only in scattered warnings
    print(f"Warning: {len(_failed_indices)} sample(s) could not be read and were skipped")

print(f"\nOriginal BTS train size: {len(bts_factory.train_dataset)}")
print(f"Unmapped samples kept: {len(valid_indices)}")

# Create filtered dataset using Subset
bts_train_filtered = Subset(bts_factory.train_dataset, valid_indices)
print(f"Filtered BTS train size: {len(bts_train_filtered)}")

# Dataloader over the filtered subset; every loader setting is taken from bts_config
bts_train_loader_filtered = DataLoader(
    bts_train_filtered,
    batch_size=bts_config.batch_size,
    shuffle=False,
    num_workers=bts_config.num_workers,
    pin_memory=bts_config.pin_memory,
)

print("\nFiltered dataloader created successfully!")
print("Dataloader ready: bts_train_loader_filtered")

Filtering BTS datasets to only include unambiguous mappings to GTSRB...
Valid BTS labels: [0, 2, 3, 4, 5, 8, 10, 13, 16, 17, 19, 21, 22, 25, 28, 31, 34, 35, 37, 61]
Number of valid labels: 20

Original BTS train size: 3660
Valid indices found: 2326
Filtered BTS train size: 2326

Filtered dataloader created successfully!
Dataloader ready: bts_train_loader_filtered


#### Loading the CIFAR-10 dataset

In [32]:
from data_access.datasets.CIFAR10Factory import CIFAR10Factory

# CIFAR-10 configuration: 10 labels, unshuffled, same batch/worker settings as
# the BTS configuration.
cifar10_config = ConceptDatasetConfig(
    name="cifar10",
    n_labels=10,
    batch_size=64,
    num_workers=4,
    shuffle_dataset=False,
    pin_memory=True,
    data_path=Path("../data/raw/CIFAR10/"),
    val_split=0.2,
)
cifar10_config.resolve()

# Factory seeded from the fuzzy model's config
cifar10_factory = CIFAR10Factory(config=cifar10_config, seed=fuzzy_cbm_config.seed)
cifar10_factory.set_dataloaders()

# Expose the three dataloaders under notebook-level names
cifar10_test_loader, cifar10_train_loader, cifar10_val_loader = (
    cifar10_factory.test_dataloader,
    cifar10_factory.train_dataloader,
    cifar10_factory.val_dataloader,
)

Files already downloaded and verified
Files already downloaded and verified


#### Printing analysis

In [33]:
# ============================================================================
# COMPLETE ANALYSIS: MORPHED GTSRB, FILTERED BTS
# ============================================================================

# Report banner, emitted as one newline-joined block
_banner = "="*80
print("\n".join([
    _banner,
    "RQ2/3: ROBUSTNESS ANALYSIS ON OUT-OF-DISTRIBUTION DATASETS",
    _banner,
    "\nAnalyzing three datasets:",
    "  0. GTSRB (ID - Baseline)",
    "  1. Morphed GTSRB (Near OOD - data augmentation)",
    "  2. Filtered BTS  (Near OOD - different traffic signs)",
    _banner,
]))

# Result stores filled by the per-dataset analysis cells
all_predictions, all_fuzzy_metrics, all_violations = {}, {}, {}

RQ2/3: ROBUSTNESS ANALYSIS ON OUT-OF-DISTRIBUTION DATASETS

Analyzing three datasets:
  0. GTSRB (ID - Baseline)
  1. Morphed GTSRB (Near OOD - data augmentation)
  2. Filtered BTS  (Near OOD - different traffic signs)


In [34]:
# ============================================================================
# 0. GTSRB ANALYSIS (In-Distribution - for comparison baseline)
# ============================================================================
# Statement order is kept as-is so the printed report reads in the same sequence.
_heavy_rule, _light_rule = "="*80, "-"*80

print("\n" + _heavy_rule)
print("ANALYZING GTSRB DATASET (In-Distribution)")
print(_heavy_rule)

# Run both models over the GTSRB test loader
print("\nGetting predictions from both models on GTSRB...")
baseline_preds_gtsrb = get_dataset_predictions(baseline_cbm, test_loader, baseline_cbm_config.device, "GTSRB (Baseline)")
fuzzy_preds_gtsrb = get_dataset_predictions(fuzzy_cbm, test_loader, fuzzy_cbm_config.device, "GTSRB (Fuzzy)")

# Fuzzy loss: both models are scored with the same loss object (neutral_fuzzy_loss)
print("\n" + _light_rule)
print("FUZZY LOSS ANALYSIS - GTSRB")
print(_light_rule)

baseline_fuzzy_gtsrb = analyze_fuzzy_loss_single_model(
    baseline_preds_gtsrb['logits'], baseline_preds_gtsrb['predictions'], neutral_fuzzy_loss, 'Baseline CBM', 'GTSRB'
)
fuzzy_fuzzy_gtsrb = analyze_fuzzy_loss_single_model(
    fuzzy_preds_gtsrb['logits'], fuzzy_preds_gtsrb['predictions'], neutral_fuzzy_loss, 'Fuzzy CBM', 'GTSRB'
)
fuzzy_comparison_gtsrb = compare_fuzzy_losses(baseline_fuzzy_gtsrb, fuzzy_fuzzy_gtsrb)

print_fuzzy_loss_results(baseline_fuzzy_gtsrb)
print_fuzzy_loss_results(fuzzy_fuzzy_gtsrb, fuzzy_comparison_gtsrb)

# Rule violations of each model's predictions, checked with rule_checker
print("\n" + _light_rule)
print("RULE VIOLATION ANALYSIS - GTSRB")
print(_light_rule)

baseline_viols_gtsrb = analyze_rule_violations(baseline_preds_gtsrb['predictions'], 'GTSRB', 'Baseline CBM', rule_checker)
fuzzy_viols_gtsrb = analyze_rule_violations(fuzzy_preds_gtsrb['predictions'], 'GTSRB', 'Fuzzy CBM', rule_checker)
violation_comparison_gtsrb = compare_violations(baseline_viols_gtsrb, fuzzy_viols_gtsrb)

print_violation_results(baseline_viols_gtsrb)
print_violation_results(fuzzy_viols_gtsrb, violation_comparison_gtsrb)

# Per-rule table taken from the fuzzy-loss comparison
print("\n" + _heavy_rule)
print("PER-RULE FUZZY LOSS COMPARISON - GTSRB")
print(_heavy_rule)
print("\n" + fuzzy_comparison_gtsrb['rule_comparison'].to_string(index=False))

# Keep everything in the shared result stores
all_predictions.update({'gtsrb_baseline': baseline_preds_gtsrb, 'gtsrb_fuzzy': fuzzy_preds_gtsrb})
all_fuzzy_metrics.update({'GTSRB_baseline': baseline_fuzzy_gtsrb, 'GTSRB_fuzzy': fuzzy_fuzzy_gtsrb})
all_violations.update({'GTSRB_baseline': baseline_viols_gtsrb, 'GTSRB_fuzzy': fuzzy_viols_gtsrb})


ANALYZING GTSRB DATASET (In-Distribution)

Getting predictions from both models on GTSRB...


Getting GTSRB (Baseline) predictions: 100%|██████████| 99/99 [00:13<00:00,  7.19it/s]
Getting GTSRB (Fuzzy) predictions: 100%|██████████| 99/99 [00:13<00:00,  7.51it/s]



--------------------------------------------------------------------------------
FUZZY LOSS ANALYSIS - GTSRB
--------------------------------------------------------------------------------

Baseline CBM on GTSRB:
  Standard BCE Loss:    0.0003751071
  Fuzzy Rules Loss:     0.0137938075
  Total Loss:           0.0141689146

Fuzzy CBM on GTSRB:
  Standard BCE Loss:    0.0002976162
  Fuzzy Rules Loss:     0.0130780460
  Total Loss:           0.0133756623

  Improvement over Baseline:
    Absolute: 0.0007157614
    Relative: 5.19%

--------------------------------------------------------------------------------
RULE VIOLATION ANALYSIS - GTSRB
--------------------------------------------------------------------------------

Baseline CBM on GTSRB:
  Total samples:           12630
  Samples with violations: 17
  Violation rate:          0.13%

Fuzzy CBM on GTSRB:
  Total samples:           12630
  Samples with violations: 21
  Violation rate:          0.17%

  Improvement over Baseline:
   

In [35]:
# ============================================================================
# 1. MORPHED GTSRB ANALYSIS
# ============================================================================
# Statement order is kept as-is so the printed report reads in the same sequence.
_heavy_rule, _light_rule = "="*80, "-"*80

print("\n" + _heavy_rule)
print("ANALYZING MORPHED GTSRB DATASET (Near OOD - Augmented)")
print(_heavy_rule)

# Run both models over the morphed GTSRB test loader
print("\nGetting predictions from both models on Morphed GTSRB...")
baseline_preds_morphed = get_dataset_predictions(
    baseline_cbm, morphed_gtsrb_test_loader, baseline_cbm_config.device, "Morphed GTSRB (Baseline)"
)
fuzzy_preds_morphed = get_dataset_predictions(
    fuzzy_cbm, morphed_gtsrb_test_loader, fuzzy_cbm_config.device, "Morphed GTSRB (Fuzzy)"
)

# Fuzzy loss: both models are scored with the same loss object (neutral_fuzzy_loss)
print("\n" + _light_rule)
print("FUZZY LOSS ANALYSIS - MORPHED GTSRB")
print(_light_rule)

baseline_fuzzy_morphed = analyze_fuzzy_loss_single_model(
    baseline_preds_morphed['logits'], baseline_preds_morphed['predictions'], neutral_fuzzy_loss, 'Baseline CBM', 'Morphed GTSRB'
)
fuzzy_fuzzy_morphed = analyze_fuzzy_loss_single_model(
    fuzzy_preds_morphed['logits'], fuzzy_preds_morphed['predictions'], neutral_fuzzy_loss, 'Fuzzy CBM', 'Morphed GTSRB'
)
fuzzy_comparison_morphed = compare_fuzzy_losses(baseline_fuzzy_morphed, fuzzy_fuzzy_morphed)

print_fuzzy_loss_results(baseline_fuzzy_morphed)
print_fuzzy_loss_results(fuzzy_fuzzy_morphed, fuzzy_comparison_morphed)

# Rule violations of each model's predictions, checked with rule_checker
print("\n" + _light_rule)
print("RULE VIOLATION ANALYSIS - MORPHED GTSRB")
print(_light_rule)

baseline_viols_morphed = analyze_rule_violations(baseline_preds_morphed['predictions'], 'Morphed GTSRB', 'Baseline CBM', rule_checker)
fuzzy_viols_morphed = analyze_rule_violations(fuzzy_preds_morphed['predictions'], 'Morphed GTSRB', 'Fuzzy CBM', rule_checker)
violation_comparison_morphed = compare_violations(baseline_viols_morphed, fuzzy_viols_morphed)

print_violation_results(baseline_viols_morphed)
print_violation_results(fuzzy_viols_morphed, violation_comparison_morphed)

# Per-rule table taken from the fuzzy-loss comparison
print("\n" + _heavy_rule)
print("PER-RULE FUZZY LOSS COMPARISON - MORPHED GTSRB")
print(_heavy_rule)
print("\n" + fuzzy_comparison_morphed['rule_comparison'].to_string(index=False))

# Keep everything in the shared result stores
all_predictions.update({'morphed_baseline': baseline_preds_morphed, 'morphed_fuzzy': fuzzy_preds_morphed})
all_fuzzy_metrics.update({'Morphed GTSRB_baseline': baseline_fuzzy_morphed, 'Morphed GTSRB_fuzzy': fuzzy_fuzzy_morphed})
all_violations.update({'Morphed GTSRB_baseline': baseline_viols_morphed, 'Morphed GTSRB_fuzzy': fuzzy_viols_morphed})


ANALYZING MORPHED GTSRB DATASET (Near OOD - Augmented)

Getting predictions from both models on Morphed GTSRB...


Getting Morphed GTSRB (Baseline) predictions: 100%|██████████| 99/99 [00:13<00:00,  7.15it/s]
Getting Morphed GTSRB (Fuzzy) predictions: 100%|██████████| 99/99 [00:14<00:00,  6.94it/s]



--------------------------------------------------------------------------------
FUZZY LOSS ANALYSIS - MORPHED GTSRB
--------------------------------------------------------------------------------

Baseline CBM on Morphed GTSRB:
  Standard BCE Loss:    0.0081922812
  Fuzzy Rules Loss:     0.0293253567
  Total Loss:           0.0375176379

Fuzzy CBM on Morphed GTSRB:
  Standard BCE Loss:    0.0070346040
  Fuzzy Rules Loss:     0.0172272548
  Total Loss:           0.0242618588

  Improvement over Baseline:
    Absolute: 0.0120981019
    Relative: 41.25%

--------------------------------------------------------------------------------
RULE VIOLATION ANALYSIS - MORPHED GTSRB
--------------------------------------------------------------------------------

Baseline CBM on Morphed GTSRB:
  Total samples:           12630
  Samples with violations: 732
  Violation rate:          5.80%

Fuzzy CBM on Morphed GTSRB:
  Total samples:           12630
  Samples with violations: 554
  Violation rat

In [36]:
# ============================================================================
# 2. FILTERED BTS ANALYSIS
# ============================================================================
# Statement order is kept as-is so the printed report reads in the same sequence.
_heavy_rule, _light_rule = "="*80, "-"*80

print("\n" + _heavy_rule)
print("ANALYZING FILTERED BTS DATASET (Near OOD - Different Traffic Signs)")
print(_heavy_rule)

# Run both models over the filtered BTS train loader
print("\nGetting predictions from both models on Filtered BTS...")
baseline_preds_bts = get_dataset_predictions(
    baseline_cbm, bts_train_loader_filtered, baseline_cbm_config.device, "Filtered BTS (Baseline)"
)
fuzzy_preds_bts = get_dataset_predictions(
    fuzzy_cbm, bts_train_loader_filtered, fuzzy_cbm_config.device, "Filtered BTS (Fuzzy)"
)

# Fuzzy loss: both models are scored with the same loss object (neutral_fuzzy_loss)
print("\n" + _light_rule)
print("FUZZY LOSS ANALYSIS - FILTERED BTS")
print(_light_rule)

baseline_fuzzy_bts = analyze_fuzzy_loss_single_model(
    baseline_preds_bts['logits'], baseline_preds_bts['predictions'], neutral_fuzzy_loss, 'Baseline CBM', 'Filtered BTS'
)
fuzzy_fuzzy_bts = analyze_fuzzy_loss_single_model(
    fuzzy_preds_bts['logits'], fuzzy_preds_bts['predictions'], neutral_fuzzy_loss, 'Fuzzy CBM', 'Filtered BTS'
)
fuzzy_comparison_bts = compare_fuzzy_losses(baseline_fuzzy_bts, fuzzy_fuzzy_bts)

print_fuzzy_loss_results(baseline_fuzzy_bts)
print_fuzzy_loss_results(fuzzy_fuzzy_bts, fuzzy_comparison_bts)

# Rule violations of each model's predictions, checked with rule_checker
print("\n" + _light_rule)
print("RULE VIOLATION ANALYSIS - FILTERED BTS")
print(_light_rule)

baseline_viols_bts = analyze_rule_violations(baseline_preds_bts['predictions'], 'Filtered BTS', 'Baseline CBM', rule_checker)
fuzzy_viols_bts = analyze_rule_violations(fuzzy_preds_bts['predictions'], 'Filtered BTS', 'Fuzzy CBM', rule_checker)
violation_comparison_bts = compare_violations(baseline_viols_bts, fuzzy_viols_bts)

print_violation_results(baseline_viols_bts)
print_violation_results(fuzzy_viols_bts, violation_comparison_bts)

# Per-rule table taken from the fuzzy-loss comparison
print("\n" + _heavy_rule)
print("PER-RULE FUZZY LOSS COMPARISON - FILTERED BTS")
print(_heavy_rule)
print("\n" + fuzzy_comparison_bts['rule_comparison'].to_string(index=False))

# Keep everything in the shared result stores
all_predictions.update({'bts_baseline': baseline_preds_bts, 'bts_fuzzy': fuzzy_preds_bts})
all_fuzzy_metrics.update({'Filtered BTS_baseline': baseline_fuzzy_bts, 'Filtered BTS_fuzzy': fuzzy_fuzzy_bts})
all_violations.update({'Filtered BTS_baseline': baseline_viols_bts, 'Filtered BTS_fuzzy': fuzzy_viols_bts})



ANALYZING FILTERED BTS DATASET (Near OOD - Different Traffic Signs)

Getting predictions from both models on Filtered BTS...


Getting Filtered BTS (Baseline) predictions: 100%|██████████| 37/37 [00:02<00:00, 13.47it/s]
Getting Filtered BTS (Fuzzy) predictions: 100%|██████████| 37/37 [00:02<00:00, 13.55it/s]



--------------------------------------------------------------------------------
FUZZY LOSS ANALYSIS - FILTERED BTS
--------------------------------------------------------------------------------

Baseline CBM on Filtered BTS:
  Standard BCE Loss:    0.0133687975
  Fuzzy Rules Loss:     0.1319774687
  Total Loss:           0.1453462662

Fuzzy CBM on Filtered BTS:
  Standard BCE Loss:    0.0123623293
  Fuzzy Rules Loss:     0.1152333990
  Total Loss:           0.1275957283

  Improvement over Baseline:
    Absolute: 0.0167440698
    Relative: 12.69%

--------------------------------------------------------------------------------
RULE VIOLATION ANALYSIS - FILTERED BTS
--------------------------------------------------------------------------------

Baseline CBM on Filtered BTS:
  Total samples:           2326
  Samples with violations: 567
  Violation rate:          24.38%

Fuzzy CBM on Filtered BTS:
  Total samples:           2326
  Samples with violations: 353
  Violation rate:     

In [37]:
# ============================================================================
# CROSS-DATASET PER-RULE FUZZY LOSS COMPARISON
# ============================================================================

print("\n" + "="*80)
print("CROSS-DATASET PER-RULE FUZZY LOSS COMPARISON")
print("="*80)

# One entry per dataset, in the column order of the final table:
# (per-rule comparison frame, (baseline column, fuzzy column, improvement column))
_rule_tables = [
    (fuzzy_comparison_gtsrb['rule_comparison'], ('GTSRB Baseline', 'GTSRB Fuzzy', 'GTSRB Improvement %')),
    (fuzzy_comparison_morphed['rule_comparison'], ('Morphed Baseline', 'Morphed Fuzzy', 'Morphed Improvement %')),
    (fuzzy_comparison_bts['rule_comparison'], ('BTS Baseline', 'BTS Fuzzy', 'BTS Improvement %')),
]

# Union of the rule names that occur in any of the three comparisons
all_rules = set()
for _table, _ in _rule_tables:
    all_rules.update(_table['Rule'].values)

# One output row per rule; a rule missing from a dataset contributes 0.0 values
cross_dataset_rules = []
for rule in sorted(all_rules):
    row_data = {'Rule': rule}
    for _table, (_baseline_col, _fuzzy_col, _improvement_col) in _rule_tables:
        _matches = _table[_table['Rule'] == rule]
        if _matches.empty:
            _baseline_loss = _fuzzy_loss = _improvement = 0.0
        else:
            _first = _matches.iloc[0]
            _baseline_loss = _first['Baseline Loss']
            _fuzzy_loss = _first['Fuzzy CBM Loss']
            _improvement = _first['Improvement %']
        row_data[_baseline_col] = _baseline_loss
        row_data[_fuzzy_col] = _fuzzy_loss
        row_data[_improvement_col] = _improvement
    cross_dataset_rules.append(row_data)

df_cross_dataset_rules = pd.DataFrame(cross_dataset_rules)
print("\n" + df_cross_dataset_rules.to_string(index=False))


CROSS-DATASET PER-RULE FUZZY LOSS COMPARISON

                          Rule  GTSRB Baseline  GTSRB Fuzzy  GTSRB Improvement %  Morphed Baseline  Morphed Fuzzy  Morphed Improvement %  BTS Baseline    BTS Fuzzy  BTS Improvement %
     at_most_one_border_colour    1.156950e-05 0.000000e+00           100.000000      3.267445e-03   9.438582e-12             100.000000  1.701329e-04 0.000000e+00         100.000000
       exactly_one_main_colour    7.500465e-05 7.501285e-05            -0.010942      7.541812e-05   7.526257e-05               0.206254  7.609694e-05 7.607189e-05           0.032920
             exactly_one_shape    7.501225e-05 7.501248e-05            -0.000310      7.532208e-05   7.524533e-05               0.101892  7.555744e-05 7.524015e-05           0.419932
no_symbols_exactly_two_colours    8.708388e-04 3.486725e-04            59.961305      1.799850e-02   8.990700e-03              50.047512  7.186738e-02 3.551781e-02          50.578682


#### Analysing the rule violations

In [38]:
# ============================================================================
# RULE VIOLATION ANALYSIS WITH DETAILED PER-RULE COMPARISON - THREE DATASETS
# ============================================================================

print("\n" + "="*80)
print("RULE VIOLATION ANALYSIS - ALL DATASETS")
print("="*80)

# The six calls below repeat, with identical arguments, the violation analysis
# already run in the per-dataset cells; the results are rebound to the same names.
print("\nAnalyzing GTSRB (In-Distribution)...")
baseline_viols_gtsrb = analyze_rule_violations(baseline_preds_gtsrb['predictions'], 'GTSRB', 'Baseline CBM', rule_checker)
fuzzy_viols_gtsrb = analyze_rule_violations(fuzzy_preds_gtsrb['predictions'], 'GTSRB', 'Fuzzy CBM', rule_checker)

print("Analyzing Morphed GTSRB...")
baseline_viols_morphed = analyze_rule_violations(baseline_preds_morphed['predictions'], 'Morphed GTSRB', 'Baseline CBM', rule_checker)
fuzzy_viols_morphed = analyze_rule_violations(fuzzy_preds_morphed['predictions'], 'Morphed GTSRB', 'Fuzzy CBM', rule_checker)

print("Analyzing Filtered BTS...")
baseline_viols_bts = analyze_rule_violations(baseline_preds_bts['predictions'], 'Filtered BTS', 'Baseline CBM', rule_checker)
fuzzy_viols_bts = analyze_rule_violations(fuzzy_preds_bts['predictions'], 'Filtered BTS', 'Fuzzy CBM', rule_checker)

# ============================================================================
# OVERALL VIOLATION RATE SUMMARY - THREE DATASETS
# ============================================================================

print("\n" + "="*80)
print("OVERALL VIOLATION RATE SUMMARY - ALL DATASETS")
print("="*80)

# (dataset, type, model, violation results) - one tuple per summary row, in row order
_summary_spec = [
    ('GTSRB', 'In-Dist', 'Baseline CBM', baseline_viols_gtsrb),
    ('GTSRB', 'In-Dist', 'Fuzzy CBM', fuzzy_viols_gtsrb),
    ('Morphed GTSRB', 'Near OOD', 'Baseline CBM', baseline_viols_morphed),
    ('Morphed GTSRB', 'Near OOD', 'Fuzzy CBM', fuzzy_viols_morphed),
    ('Filtered BTS', 'Near OOD', 'Baseline CBM', baseline_viols_bts),
    ('Filtered BTS', 'Near OOD', 'Fuzzy CBM', fuzzy_viols_bts),
]

violation_summary = pd.DataFrame([
    {
        'Dataset': _dataset,
        'Type': _kind,
        'Model': _model_name,
        'Violation Rate (%)': _viols['violation_rate'],
        'Violations': _viols['total_violations'],
        'Total Samples': _viols['total_samples'],
    }
    for _dataset, _kind, _model_name, _viols in _summary_spec
])

print("\n" + violation_summary.to_string(index=False))

# ============================================================================
# PER-CONSTRAINT ANALYSIS FOR ALL DATASETS
# ============================================================================

# Collect every constraint recorded for EITHER model on ANY dataset.
# Looking only at the baseline results would drop a constraint that the fuzzy
# model alone violates from all per-constraint tables below.
# NOTE(review): if 'constraint_counts' always lists every constraint (including
# zero counts), this yields the same set as before - confirm in analyze_rule_violations.
all_constraints = set()
for _viols in (
    baseline_viols_gtsrb, fuzzy_viols_gtsrb,
    baseline_viols_morphed, fuzzy_viols_morphed,
    baseline_viols_bts, fuzzy_viols_bts,
):
    all_constraints.update(_viols['constraint_counts'].keys())

# Column order of every per-constraint table. Passing it to the DataFrame
# constructor keeps the columns present even when there are no constraints,
# so sorting by 'Improvement' cannot fail on an empty table.
_CONSTRAINT_COLUMNS = ['Constraint', 'Baseline Count', 'Fuzzy CBM Count', 'Improvement', 'Improvement %']


def _per_constraint_report(title, baseline_viols, fuzzy_viols, constraints):
    """Print and return the per-constraint comparison of two violation results.

    Args:
        title: Heading printed between two 80-character '=' rules.
        baseline_viols: Violation results of the baseline model; only its
            'constraint_counts' mapping is read.
        fuzzy_viols: Violation results of the fuzzy model; only its
            'constraint_counts' mapping is read.
        constraints: Iterable of constraint names to report (processed sorted).

    Returns:
        (rows, table): the list of row dicts in sorted-constraint order and the
        DataFrame built from them, sorted by 'Improvement' descending.
    """
    print("\n" + "="*80)
    print(title)
    print("="*80)

    rows = []
    for constraint in sorted(constraints):
        baseline_count = baseline_viols['constraint_counts'].get(constraint, 0)
        fuzzy_count = fuzzy_viols['constraint_counts'].get(constraint, 0)
        improvement = baseline_count - fuzzy_count
        # A relative improvement is undefined without baseline violations; report 0
        improvement_pct = (improvement / baseline_count * 100) if baseline_count > 0 else 0
        rows.append({
            'Constraint': constraint,
            'Baseline Count': baseline_count,
            'Fuzzy CBM Count': fuzzy_count,
            'Improvement': improvement,
            'Improvement %': improvement_pct
        })

    table = pd.DataFrame(rows, columns=_CONSTRAINT_COLUMNS)
    table = table.sort_values('Improvement', ascending=False)
    print("\n" + table.to_string(index=False))
    return rows, table


# GTSRB Per-Constraint
constraint_comparison_gtsrb, df_constraint_gtsrb = _per_constraint_report(
    "PER-CONSTRAINT VIOLATION COUNTS - GTSRB (In-Distribution)",
    baseline_viols_gtsrb, fuzzy_viols_gtsrb, all_constraints,
)

# Morphed GTSRB Per-Constraint
constraint_comparison_morphed, df_constraint_morphed = _per_constraint_report(
    "PER-CONSTRAINT VIOLATION COUNTS - MORPHED GTSRB",
    baseline_viols_morphed, fuzzy_viols_morphed, all_constraints,
)

# BTS Per-Constraint
constraint_comparison_bts, df_constraint_bts = _per_constraint_report(
    "PER-CONSTRAINT VIOLATION COUNTS - FILTERED BTS",
    baseline_viols_bts, fuzzy_viols_bts, all_constraints,
)

# ============================================================================
# CROSS-DATASET PER-CONSTRAINT COMPARISON - THREE DATASETS
# ============================================================================

# Banner for the side-by-side table covering every dataset.
print("\n" + "=" * 80)
print("CROSS-DATASET PER-CONSTRAINT COMPARISON - ALL DATASETS")
print("=" * 80)

# One entry per dataset: the three output column names (in display order)
# followed by the baseline and fuzzy per-constraint count lookups.
dataset_specs = (
    ('GTSRB Baseline', 'GTSRB Fuzzy', 'GTSRB Improvement',
     baseline_viols_gtsrb['constraint_counts'],
     fuzzy_viols_gtsrb['constraint_counts']),
    ('Morphed GTSRB Baseline', 'Morphed GTSRB Fuzzy', 'Morphed GTSRB Improvement',
     baseline_viols_morphed['constraint_counts'],
     fuzzy_viols_morphed['constraint_counts']),
    ('BTS Baseline', 'BTS Fuzzy', 'BTS Improvement',
     baseline_viols_bts['constraint_counts'],
     fuzzy_viols_bts['constraint_counts']),
)

cross_dataset_all = []
for constraint in sorted(all_constraints):
    row = {'Constraint': constraint}
    for base_col, fuzzy_col, gain_col, base_counts, fuzzy_counts in dataset_specs:
        # A constraint absent from a lookup counts as zero violations.
        n_base = base_counts.get(constraint, 0)
        n_fuzzy = fuzzy_counts.get(constraint, 0)
        row[base_col] = n_base
        row[fuzzy_col] = n_fuzzy
        row[gain_col] = n_base - n_fuzzy
    cross_dataset_all.append(row)

df_cross_dataset_all = pd.DataFrame(cross_dataset_all)
print("\n" + df_cross_dataset_all.to_string(index=False))

# ============================================================================
# IMPROVEMENT ANALYSIS - ALL DATASETS
# ============================================================================

print("\n" + "=" * 80)
print("IMPROVEMENT ANALYSIS - ALL DATASETS")
print("=" * 80)

# Absolute change in violation rate (baseline minus fuzzy) per dataset.
# These names are read again by the final summary further down.
gtsrb_improvement = baseline_viols_gtsrb['violation_rate'] - fuzzy_viols_gtsrb['violation_rate']
morphed_improvement = baseline_viols_morphed['violation_rate'] - fuzzy_viols_morphed['violation_rate']
bts_improvement = baseline_viols_bts['violation_rate'] - fuzzy_viols_bts['violation_rate']

# (heading, baseline results, fuzzy results, absolute improvement) per dataset.
improvement_sections = (
    ("\nGTSRB Dataset (In-Distribution):", baseline_viols_gtsrb, fuzzy_viols_gtsrb, gtsrb_improvement),
    ("\nMorphed GTSRB Dataset (Near OOD):", baseline_viols_morphed, fuzzy_viols_morphed, morphed_improvement),
    ("\nFiltered BTS Dataset (Near OOD):", baseline_viols_bts, fuzzy_viols_bts, bts_improvement),
)

for heading, base_result, fuzzy_result, abs_gain in improvement_sections:
    base_rate = base_result['violation_rate']
    print(heading)
    print(f"  Baseline violation rate:  {base_rate:.2f}%")
    print(f"  Fuzzy CBM violation rate: {fuzzy_result['violation_rate']:.2f}%")
    print(f"  Improvement:              {abs_gain:.2f} percentage points")
    # The relative figure divides by the baseline rate, so it is only printed
    # when that rate is positive.
    if base_rate > 0:
        print(f"  Relative improvement:     {(abs_gain / base_rate * 100):.2f}%")

# ============================================================================
# FINAL SUMMARY
# ============================================================================

print("\n" + "=" * 80)
print("FINAL SUMMARY - RQ3 ROBUSTNESS ANALYSIS")
print("=" * 80)

print("\nViolation Rate Reductions Across All Datasets:")

# (already formatted absolute line, absolute improvement, baseline rate)
summary_lines = (
    (f"  1. GTSRB (In-Dist):       {gtsrb_improvement:.2f} pp",
     gtsrb_improvement, baseline_viols_gtsrb['violation_rate']),
    (f"  2. Morphed GTSRB:         {morphed_improvement:.2f} pp",
     morphed_improvement, baseline_viols_morphed['violation_rate']),
    (f"  3. BTS:                   {bts_improvement:.2f} pp",
     bts_improvement, baseline_viols_bts['violation_rate']),
)

for line, abs_gain, base_rate in summary_lines:
    # Append the signed relative change only when the baseline rate is
    # positive; otherwise the line ends right after "pp".
    if base_rate > 0:
        line += f" ({(abs_gain / base_rate * 100):+.2f}%)"
    print(line)

print("\n" + "=" * 80)
print("RQ3 ANALYSIS COMPLETE")
print("=" * 80)


RULE VIOLATION ANALYSIS - ALL DATASETS

Analyzing GTSRB (In-Distribution)...
Analyzing Morphed GTSRB...
Analyzing Filtered BTS...

OVERALL VIOLATION RATE SUMMARY - ALL DATASETS

      Dataset     Type        Model  Violation Rate (%)  Violations  Total Samples
        GTSRB  In-Dist Baseline CBM            0.134600          17          12630
        GTSRB  In-Dist    Fuzzy CBM            0.166271          21          12630
Morphed GTSRB Near OOD Baseline CBM            5.795724         732          12630
Morphed GTSRB Near OOD    Fuzzy CBM            4.386382         554          12630
 Filtered BTS Near OOD Baseline CBM           24.376612         567           2326
 Filtered BTS Near OOD    Fuzzy CBM           15.176268         353           2326

PER-CONSTRAINT VIOLATION COUNTS - GTSRB (In-Distribution)

                    Constraint  Baseline Count  Fuzzy CBM Count  Improvement  Improvement %
no_symbols_exactly_two_colours              16               12            4      25.000