#### RQ1: How does the logical consistency of the CBM change when the requirements are introduced via the Fuzzy Loss?

- This RQ evaluates the logical consistency of the CBM and tests the claim that the CBM trained with the Fuzzy Loss has learned the rules and adheres to them.
- Load the models
- Load the GTSRB Dataset
- Load the Raw Fuzzy Loss
- Calculate Fuzzy Loss on the test set. (GTSRB, BTS)
- Calculate Rule violations on the test set. (GTSRB, BTS)
- Calculate the HD on the test set (GTSRB, BTS)
- Calculate CL-Matrix

In [1]:
import sys
from pathlib import Path

current_dir = Path.cwd()
parent_dir = current_dir.parent
sys.path.insert(0, str(f"{parent_dir}/src"))

from models.architectures import CBMSequentialEfficientNetFCN
from train_cbm import cbm_load_config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _build_cbm(config_file):
    """Run ``cbm_load_config`` on ``config_file`` and build a CBM from the result.

    Returns:
        tuple: ``(config, model)`` where ``model`` is a ``CBMSequentialEfficientNetFCN``
        constructed from ``config``.
    """
    loaded_config = cbm_load_config(Path(config_file))
    return loaded_config, CBMSequentialEfficientNetFCN(loaded_config)


# Baseline CBM first, then the CBM configured from the best fuzzy-loss trial.
baseline_cbm_config, baseline_cbm = _build_cbm("../files/configs/GTSRB_CBM_config_loading.yaml")
fuzzy_cbm_config, fuzzy_cbm = _build_cbm("../files/configs/GTSRB_CBM_config_best_trial_loading.yaml")

Directory 'experiments/20251010_150436' created successfully.
Directory 'experiments/20251010_150436' created successfully.


In [3]:
# Checkpoint locations of the trained model components.
# NOTE(review): for each model the concept predictor and the label predictor are read
# from different experiment directories - confirm these pairings are intended.
baseline_cbm_concept_predictor_path = Path("../experiments/20251008_155607/models/20251008_155607_concept_predictor_best_model.pt")
baseline_cbm_label_predictor_path = Path("../experiments/baseline_cbm/models/20251001_083717_label_predictor_best_model.pt")
fuzzy_cbm_concept_predictor_path = Path("../experiments/20251010_102822/models/20251010_102822_concept_predictor_best_model.pt")
fuzzy_cbm_label_predictor_path = Path("../experiments/fuzzy_CBM/models/20251001_113637_label_predictor_best_model.pt")
# Backward-compatible alias: the original, misspelled name is still the one read
# when the label-predictor weights are loaded.
fuzzy_cbm_label_precitor_path = fuzzy_cbm_label_predictor_path

In [4]:
import torch

# One row per sub-module: (module to fill, checkpoint file, device to map the tensors to).
# Baseline components come first, then the fuzzy components.
_checkpoint_plan = [
    (baseline_cbm.concept_predictor, baseline_cbm_concept_predictor_path, baseline_cbm_config.device),
    (baseline_cbm.label_predictor, baseline_cbm_label_predictor_path, baseline_cbm_config.device),
    (fuzzy_cbm.concept_predictor, fuzzy_cbm_concept_predictor_path, fuzzy_cbm_config.device),
    (fuzzy_cbm.label_predictor, fuzzy_cbm_label_precitor_path, fuzzy_cbm_config.device),
]
for _module, _checkpoint_file, _target_device in _checkpoint_plan:
    _state_dict = torch.load(_checkpoint_file, map_location=_target_device, weights_only=True)
    _module.load_state_dict(_state_dict)

# Both models are only evaluated in this notebook, so switch them to eval mode.
for _cbm in (baseline_cbm, fuzzy_cbm):
    _cbm.eval()

print("Model components loaded successfully!")
# Only the directory of the concept-predictor checkpoint is reported here.
print(f"Baseline CBM loaded from {baseline_cbm_concept_predictor_path.parent}")
print(f"Fuzzy CBM loaded from {fuzzy_cbm_concept_predictor_path.parent}")

Model components loaded successfully!
Baseline CBM loaded from ../experiments/20251008_155607/models
Fuzzy CBM loaded from ../experiments/20251010_102822/models


In [5]:
# The dataloaders are built from the baseline configuration; the same loaders are
# handed to both trainers further down.
gtsrb_dataset_config = baseline_cbm_config.dataset
dataset_factory = gtsrb_dataset_config.factory(
    seed=baseline_cbm_config.seed,
    config=gtsrb_dataset_config,
).set_dataloaders()
train_loader, val_loader, test_loader = (
    dataset_factory.train_dataloader,
    dataset_factory.val_dataloader,
    dataset_factory.test_dataloader,
)

In [6]:
# Wrap each model in a CBMTrainer. Nothing is trained or scored in this cell; the
# trainers are used later for their concept-predictor criterion and get_predictions().
from models.trainer.cbm_trainer import CBMTrainer

# Both trainers receive the very same dataloaders.
_shared_loaders = dict(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)
baseline_cbm_trainer = CBMTrainer(config=baseline_cbm_config, model=baseline_cbm, **_shared_loaders)
fuzzy_cbm_trainer = CBMTrainer(config=fuzzy_cbm_config, model=fuzzy_cbm, **_shared_loaders)

In [7]:
# getting the fuzzy loss function
# The criterion of the fuzzy CBM's concept-predictor trainer is reused below as the
# common loss with which both models are evaluated.
# NOTE(review): the criterion originates from the fuzzy CBM's trainer, so any settings
# it carries presumably come from fuzzy_cbm_config - confirm that this is what
# "neutral" is meant to express.
neutral_fuzzy_loss = fuzzy_cbm_trainer.concept_predictor_trainer.criterion

In [8]:
# Test-set outputs of the fuzzy CBM's concept predictor, as returned by the trainer.
(
    all_logits_fuzzy,
    concept_predictions_fuzzy,
    concept_ground_truth_fuzzy,
    concept_probabilities_fuzzy,
) = fuzzy_cbm_trainer.concept_predictor_trainer.get_predictions(dataloader=test_loader)

# Quick check: the loss over the whole test set in one call. This stays the last
# expression of the cell so that its value is displayed.
neutral_fuzzy_loss(
    torch.tensor(all_logits_fuzzy),
    torch.tensor(concept_ground_truth_fuzzy),
)

Getting predictions: 100%|██████████| 99/99 [01:54<00:00,  1.15s/it]


tensor(1.5155)

In [9]:
# Test-set outputs of the baseline CBM's concept predictor, as returned by the trainer.
(
    all_logits_baseline,
    concept_predictions_baseline,
    concept_ground_truth_baseline,
    concept_probabilities_baseline,
) = baseline_cbm_trainer.concept_predictor_trainer.get_predictions(dataloader=test_loader)

Getting predictions: 100%|██████████| 99/99 [02:47<00:00,  1.70s/it]


In [10]:
# calculating the fuzzy neutral fuzzy loss on the test set
import torch
import numpy as np
from tqdm import tqdm

def compute_fuzzy_loss_metrics(model, dataloader, fuzzy_loss_fn, device, normal=False):
    """
    Compute sample-weighted average fuzzy loss metrics over a dataloader.

    Args:
        model: CBM whose ``concept_predictor`` maps an input batch to concept logits.
        dataloader: Iterable yielding ``(idx, inputs, (concepts, _))`` batches.
        fuzzy_loss_fn: Callable loss exposing ``fuzzy_rules`` (a mapping keyed by rule
            name) and, after every call, ``last_standard_loss``, ``last_fuzzy_loss``
            and ``last_individual_losses`` (rule name -> loss).
        device: Device the inputs and concept targets are moved to.
        normal: Unused. Kept only so that existing calls remain valid.

    Returns:
        dict: ``standard_loss``, ``fuzzy_loss``, ``total_loss`` (their sum),
        ``rule_losses`` (per-rule averages), ``rule_contributions`` (per-rule share of
        ``fuzzy_loss``; empty when ``fuzzy_loss`` is not positive) and ``num_samples``.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()

    standard_loss_sum = 0.0
    fuzzy_loss_sum = 0.0
    rule_loss_sums = {name: 0.0 for name in fuzzy_loss_fn.fuzzy_rules.keys()}
    num_samples = 0

    with torch.no_grad():
        for _, inputs, (concepts, _) in tqdm(dataloader, desc="Computing fuzzy loss"):
            inputs = inputs.to(device)
            concepts = concepts.to(device)

            outputs = model.concept_predictor(inputs)

            # The returned value is not needed: the call refreshes the last_*
            # attributes that are read right below.
            fuzzy_loss_fn(outputs, concepts)

            batch_size = inputs.size(0)
            num_samples += batch_size

            # Weight each batch by its size so that a smaller final batch does not
            # count as much as a full one (assumes the last_* values are per-batch
            # means - TODO confirm against the loss implementation).
            standard_loss_sum += fuzzy_loss_fn.last_standard_loss.item() * batch_size
            fuzzy_loss_sum += fuzzy_loss_fn.last_fuzzy_loss.item() * batch_size
            for rule_name, loss_val in fuzzy_loss_fn.last_individual_losses.items():
                rule_loss_sums[rule_name] += loss_val.item() * batch_size

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute fuzzy loss metrics")

    avg_standard_loss = standard_loss_sum / num_samples
    avg_fuzzy_loss = fuzzy_loss_sum / num_samples
    avg_rule_losses = {name: total / num_samples for name, total in rule_loss_sums.items()}

    # Share of every rule in the total fuzzy loss; undefined (left empty) when the
    # fuzzy loss is not positive.
    rule_contributions = {}
    if avg_fuzzy_loss > 0:
        rule_contributions = {
            name: avg / avg_fuzzy_loss for name, avg in avg_rule_losses.items()
        }

    return {
        'standard_loss': avg_standard_loss,
        'fuzzy_loss': avg_fuzzy_loss,
        'total_loss': avg_standard_loss + avg_fuzzy_loss,
        'rule_losses': avg_rule_losses,
        'rule_contributions': rule_contributions,
        'num_samples': num_samples
    }

# Evaluate both models on the same test loader with the same loss function so that
# the resulting numbers are directly comparable.
print("Computing fuzzy loss metrics for Baseline CBM...")
baseline_fuzzy_metrics = compute_fuzzy_loss_metrics(
    baseline_cbm,
    test_loader,
    neutral_fuzzy_loss,
    baseline_cbm_config.device
)

print("\nComputing fuzzy loss metrics for Fuzzy CBM...")
fuzzy_cbm_metrics = compute_fuzzy_loss_metrics(
    fuzzy_cbm,
    test_loader,
    neutral_fuzzy_loss,
    fuzzy_cbm_config.device
)


def _pct_reduction(baseline_value, new_value):
    """Reduction from ``baseline_value`` to ``new_value`` in percent (0 if the baseline is not positive)."""
    if baseline_value > 0:
        return (baseline_value - new_value) / baseline_value * 100
    return 0


def _print_model_report(heading, metrics):
    """Print the loss summary and the per-rule breakdown of one metrics dict."""
    print(f"\n{heading}")
    print(f"  Standard BCE Loss:    {metrics['standard_loss']:.10f}")
    print(f"  Fuzzy Rules Loss:     {metrics['fuzzy_loss']:.10f}")
    print(f"  Total Loss:           {metrics['total_loss']:.10f}")
    print(f"\n  Individual Rule Violations:")
    for rule_name, loss in metrics['rule_losses'].items():
        contribution = metrics['rule_contributions'].get(rule_name, 0) * 100
        print(f"    {rule_name:30s}: {loss:.10f} ({contribution:5.2f}% of fuzzy loss)")


# Display results
print("\n" + "="*80)
print("FUZZY LOSS ANALYSIS - LOGICAL CONSISTENCY METRICS")
print("="*80)

_print_model_report("Baseline CBM (trained without fuzzy loss):", baseline_fuzzy_metrics)
_print_model_report("Fuzzy CBM (trained with fuzzy loss):", fuzzy_cbm_metrics)

# Calculate improvements
print("\n" + "="*80)
print("IMPROVEMENT ANALYSIS")
print("="*80)

fuzzy_loss_reduction = baseline_fuzzy_metrics['fuzzy_loss'] - fuzzy_cbm_metrics['fuzzy_loss']
# Guarded like the per-rule figures below: a baseline fuzzy loss of zero previously
# raised ZeroDivisionError here.
fuzzy_loss_reduction_pct = _pct_reduction(
    baseline_fuzzy_metrics['fuzzy_loss'], fuzzy_cbm_metrics['fuzzy_loss']
)

print(f"\nOverall Fuzzy Loss Reduction:")
print(f"  Absolute: {fuzzy_loss_reduction:.10f}")
print(f"  Relative: {fuzzy_loss_reduction_pct:.2f}%")

print(f"\nPer-Rule Improvements:")
rule_comparison = []
for rule_name in baseline_fuzzy_metrics['rule_losses'].keys():
    baseline_loss = baseline_fuzzy_metrics['rule_losses'][rule_name]
    fuzzy_loss = fuzzy_cbm_metrics['rule_losses'][rule_name]
    improvement = baseline_loss - fuzzy_loss
    improvement_pct = _pct_reduction(baseline_loss, fuzzy_loss)
    print(f"  {rule_name:30s}: {improvement:.10f} ({improvement_pct:+6.2f}%)")
    # The same numbers feed the comparison table printed in the next section.
    rule_comparison.append({
        'Rule': rule_name,
        'Baseline Loss': baseline_loss,
        'Fuzzy CBM Loss': fuzzy_loss,
        'Improvement': improvement,
        'Improvement %': improvement_pct
    })

# Visualization of rule contributions
print("\n" + "="*80)
print("RULE CONTRIBUTION COMPARISON")
print("="*80)

import pandas as pd

df_comparison = pd.DataFrame(rule_comparison)
df_comparison = df_comparison.sort_values('Improvement %', ascending=False)
print(df_comparison.to_string(index=False))

# Summary statistics
print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(f"Number of fuzzy rules evaluated: {len(baseline_fuzzy_metrics['rule_losses'])}")
print(f"Number of test samples: {baseline_fuzzy_metrics['num_samples']}")
print(f"\nLogical Consistency Score (lower is better):")
print(f"  Baseline CBM:  {baseline_fuzzy_metrics['fuzzy_loss']:.10f}")
print(f"  Fuzzy CBM:     {fuzzy_cbm_metrics['fuzzy_loss']:.10f}")
print(f"  Improvement:   {fuzzy_loss_reduction_pct:.2f}%")

Computing fuzzy loss metrics for Baseline CBM...


Computing fuzzy loss: 100%|██████████| 99/99 [02:49<00:00,  1.71s/it]



Computing fuzzy loss metrics for Fuzzy CBM...


Computing fuzzy loss: 100%|██████████| 99/99 [01:19<00:00,  1.25it/s]


FUZZY LOSS ANALYSIS - LOGICAL CONSISTENCY METRICS

Baseline CBM (trained without fuzzy loss):
  Standard BCE Loss:    0.0022116409
  Fuzzy Rules Loss:     1.5133473602
  Total Loss:           1.5155590011

  Individual Rule Violations:
    at_most_one_border_colour     : 0.0000008134 ( 0.00% of fuzzy loss)
    exactly_one_main_colour       : 0.7500880942 (49.56% of fuzzy loss)
    exactly_one_shape             : 0.7501135680 (49.57% of fuzzy loss)
    no_symbols_exactly_two_colours: 0.0005755201 ( 0.04% of fuzzy loss)

Fuzzy CBM (trained with fuzzy loss):
  Standard BCE Loss:    0.0024025973
  Fuzzy Rules Loss:     1.5130741719
  Total Loss:           1.5154767692

  Individual Rule Violations:
    at_most_one_border_colour     : 0.0000016019 ( 0.00% of fuzzy loss)
    exactly_one_main_colour       : 0.7500558726 (49.57% of fuzzy loss)
    exactly_one_shape             : 0.7500491422 (49.57% of fuzzy loss)
    no_symbols_exactly_two_colours: 0.0006275345 ( 0.04% of fuzzy loss)

IMPROV




In [11]:
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_prediction_level_accuracy(model, dataloader, device):
    """Evaluate concept predictions per sample (all concepts correct) and per concept.

    Args:
        model: CBM whose ``concept_predictor`` maps an input batch to concept logits.
        dataloader: Iterable yielding ``(idx, inputs, (concepts, _))`` batches.
        device: Device the inputs and concept targets are moved to.

    Returns:
        dict: ``prediction_accuracy`` (share of samples whose whole concept vector is
        correct), ``concept_accuracy`` (share of correct individual concept entries),
        ``precision`` / ``recall`` / ``f1`` (computed over the flattened concept
        entries) and ``sample_correct`` (one bool per sample, in dataloader order).

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()

    concept_true = []
    concept_pred = []
    sample_correct = []

    with torch.no_grad():
        for _, inputs, (concepts, _) in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            concepts = concepts.to(device)

            outputs = model.concept_predictor(inputs)

            # Binary predictions: sigmoid probability above 0.5.
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            concept_true.append(concepts.cpu().numpy())
            concept_pred.append(preds.cpu().numpy())

            # One row-wise reduction per batch replaces the former per-sample loop,
            # which called .item() for every sample and duplicated this computation.
            sample_correct.extend(torch.all(preds == concepts, dim=1).tolist())

    if not sample_correct:
        raise ValueError("dataloader yielded no samples; cannot evaluate accuracy")

    prediction_accuracy = sum(sample_correct) / len(sample_correct)

    concept_true = np.vstack(concept_true)
    concept_pred = np.vstack(concept_pred)

    # Accuracy over every individual concept entry.
    concept_accuracy = np.mean((concept_true == concept_pred).flatten())

    # Precision, recall and F1 pool all concept entries into one binary problem.
    flat_true = concept_true.flatten()
    flat_pred = concept_pred.flatten()
    precision = precision_score(flat_true, flat_pred, zero_division=0)
    recall = recall_score(flat_true, flat_pred, zero_division=0)
    f1 = f1_score(flat_true, flat_pred, zero_division=0)

    return {
        'prediction_accuracy': prediction_accuracy,
        'concept_accuracy': concept_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'sample_correct': sample_correct
    }

def _print_accuracy_block(heading, results):
    """Print the five accuracy figures of one results dict under ``heading``."""
    print(heading)
    print(f"  Per-prediction accuracy (all concepts correct): {results['prediction_accuracy']:.4f}")
    print(f"  Per-concept accuracy (as reported before): {results['concept_accuracy']:.4f}")
    print(f"  Precision: {results['precision']:.4f}")
    print(f"  Recall: {results['recall']:.4f}")
    print(f"  F1 Score: {results['f1']:.4f}")


# Both models are scored on the same test loader.
print("Evaluating baseline CBM...")
baseline_results = evaluate_prediction_level_accuracy(baseline_cbm, test_loader, baseline_cbm_config.device)

print("Evaluating fuzzy CBM...")
fuzzy_results = evaluate_prediction_level_accuracy(fuzzy_cbm, test_loader, fuzzy_cbm_config.device)

print("\n===== RESULTS =====")
_print_accuracy_block("Baseline CBM:", baseline_results)
_print_accuracy_block("\nFuzzy CBM:", fuzzy_results)

# Differences are fuzzy minus baseline; the bracketed value is the same difference
# scaled by 100.
prediction_improvement = fuzzy_results['prediction_accuracy'] - baseline_results['prediction_accuracy']
concept_improvement = fuzzy_results['concept_accuracy'] - baseline_results['concept_accuracy']

print(f"\nImprovement with Fuzzy Loss:")
print(f"  Per-prediction accuracy: {prediction_improvement:.4f} ({prediction_improvement*100:.2f}%)")
print(f"  Per-concept accuracy: {concept_improvement:.4f} ({concept_improvement*100:.2f}%)")

Evaluating baseline CBM...


Evaluating: 100%|██████████| 99/99 [02:17<00:00,  1.38s/it]


Evaluating fuzzy CBM...


Evaluating: 100%|██████████| 99/99 [02:42<00:00,  1.64s/it]



===== RESULTS =====
Baseline CBM:
  Per-prediction accuracy (all concepts correct): 0.9880
  Per-concept accuracy (as reported before): 0.9995
  Precision: 0.9981
  Recall: 0.9970
  F1 Score: 0.9975

Fuzzy CBM:
  Per-prediction accuracy (all concepts correct): 0.9904
  Per-concept accuracy (as reported before): 0.9996
  Precision: 0.9985
  Recall: 0.9974
  F1 Score: 0.9980

Improvement with Fuzzy Loss:
  Per-prediction accuracy: 0.0025 (0.25%)
  Per-concept accuracy: 0.0001 (0.01%)


In [12]:
# loading the predefined rule checker that constructs a graph of the concept groups and assigns rules on top of them
# The returned object is used below through its check_concept_vector() method.
# Calling construct_full_graph() also prints the concept-group hierarchy shown as this cell's output.
from rule_eval import construct_full_graph
rule_checker = construct_full_graph()

- All concepts
   - All colors
      - Main colors
         - Border colors
         - Arrow symbols
            - All symbols
               - General symbols
               - Curve symbols
                  - All shapes
               - Regulatory signs


In [21]:
# Run every predicted concept vector of the fuzzy CBM through the rule checker and
# remember which samples break at least one rule, together with what was reported.
counter_graph_semantic_relation_fuzzy = 0
flagged_indices_graph_semantic_relation_fuzzy = []
constraints_violated_relation_fuzzy = []
for sample_pos, concept_vector in enumerate(concept_predictions_fuzzy):
    broken_rules = rule_checker.check_concept_vector(concept_vector, verbose=True, early_stop=False)
    if not broken_rules:
        continue
    counter_graph_semantic_relation_fuzzy += 1
    flagged_indices_graph_semantic_relation_fuzzy.append(sample_pos)
    constraints_violated_relation_fuzzy.append(broken_rules)

# Sanity check for semantic graph
print(f"Number of flagged indices in semantic graph: {len(flagged_indices_graph_semantic_relation_fuzzy)}")

Constraint violated: General concept invariant
Violating concept names: ['shape_round']
Constraint violated: General color constraint
Violating concept names: []
Constraint violated: No symbols => 2 colors
Violating concept names: []
Constraint violated: Main color constraint
Violating concept names: []
Violating concept names: ['symbol_animal', 'symbol_double_curve']
Constraint violated: General concept invariant
Violating concept names: ['main_color_blue', 'shape_round']
Constraint violated: No symbols => 2 colors
Violating concept names: ['main_color_blue']
Constraint violated: Main blue => arrow (possible OOD)
Violating concept names: ['main_color_blue']
Constraint violated: General concept invariant
Violating concept names: ['main_color_blue', 'shape_round']
Constraint violated: No symbols => 2 colors
Violating concept names: ['main_color_blue']
Constraint violated: Main blue => arrow (possible OOD)
Violating concept names: ['main_color_blue']
Constraint violated: General concept 

In [22]:
# Run every predicted concept vector of the baseline CBM through the rule checker and
# remember which samples break at least one rule, together with what was reported.
counter_graph_semantic_relation_baseline = 0
flagged_indices_graph_semantic_relation_baseline = []
constraints_violated_relation_baseline = []
for i, pred in enumerate(concept_predictions_baseline):
    violated = rule_checker.check_concept_vector(pred, verbose=True, early_stop=False)
    if violated:
        # Fix: this used to increment the *fuzzy* counter, which left the baseline
        # counter at 0 and inflated the fuzzy CBM's count with baseline violations.
        counter_graph_semantic_relation_baseline += 1
        flagged_indices_graph_semantic_relation_baseline.append(i)
        constraints_violated_relation_baseline.append(violated)

# Sanity check for semantic graph
print(f"Number of flagged indices in semantic graph: {len(flagged_indices_graph_semantic_relation_baseline)}")

Violating concept names: ['symbol_animal', 'symbol_construction_site']
Violating concept names: ['symbol_animal', 'symbol_construction_site']
Constraint violated: General concept invariant
Violating concept names: ['main_color_blue', 'shape_round']
Constraint violated: No symbols => 2 colors
Violating concept names: ['main_color_blue']
Constraint violated: Main blue => arrow (possible OOD)
Violating concept names: ['main_color_blue']
Constraint violated: Shape constraint
Violating concept names: ['shape_round', 'shape_square']
Constraint violated: No symbols => 2 colors
Violating concept names: ['main_color_blue']
Constraint violated: Main blue => arrow (possible OOD)
Violating concept names: ['main_color_blue']
Constraint violated: General concept invariant
Violating concept names: ['shape_triangular']
Constraint violated: General color constraint
Violating concept names: []
Constraint violated: No symbols => 2 colors
Violating concept names: []
Constraint violated: Main color constra

#### Loading and testing on the BTS dataset as a near out-of-distribution dataset

In [15]:
# Fixed seed, handed to the BTS dataset factory below.
seed = 42
from config.dataset_config import ConceptDatasetConfig
from pathlib import Path

# Settings of the BTS dataset configuration.
bts_settings = dict(
    name="bts",
    n_labels=62,
    batch_size=64,
    num_workers=4,
    shuffle_dataset=False,  # evaluation only, so the order is kept fixed
    pin_memory=True,
    data_path=Path("../data/raw/BTSD/"),
    val_split=0.2,
)
bts_config = ConceptDatasetConfig(**bts_settings)
# Resolve the configuration to validate paths and load concept map
bts_config.resolve()

# Build the factory and let it create its dataloaders.
from data_access.datasets.BTSFactory import BTSFactory

bts_factory = BTSFactory(config=bts_config, seed=seed)
bts_factory.set_dataloaders()

# Keep the three (unfiltered) loaders under their own names.
bts_test_loader, bts_train_loader, bts_val_loader = (
    bts_factory.test_dataloader,
    bts_factory.train_dataloader,
    bts_factory.val_dataloader,
)

In [16]:
# One entry per GTSRB class id (the list position): the BTS label it maps to
# unambiguously, or -1 when there is no unambiguous counterpart.
labels_id_GTS_to_BTS = [-1,-1,-1,-1,-1,-1,-1,-1,-1,31,-1,17,61,19,21,28,25,22,13,3,4,5,0,2,16,10,-1,-1,-1,8,-1,-1,-1,-1,-1,34,-1,-1,35,-1,37,-1,-1,]

# Reverse lookup: BTS label -> GTSRB class id. setdefault keeps the first position
# of a BTS label, which is what list.index() returns.
mapping_BTS_to_GTSRB = {}
for gtsrb_id, bts_id in enumerate(labels_id_GTS_to_BTS):
    if bts_id != -1:
        mapping_BTS_to_GTSRB.setdefault(bts_id, gtsrb_id)
Belgium_ID_to_Name = {0: "Bumpy road",1: "Bump",2: "Slippery road",3: "Bend to The left",4: "Bend to The right",5: "Double curves first to the left",6: "Double curves first to the right",7: "School zone",8: "Bikes can be cross",9: "Domestic animal crossing",10: "Roadworks",11: "Traffic light",12: "Gated railroad crossing ahead",13: "Caution",14: "Road narrows",15: "Road narrows on the left",16: "Road narrows on the right",17: "Intersection with priority",18: "Intersection with priority to the right",19: "Yield",20: "Yield to incoming traffic",21: "Stop",22: "No entry",23: "No entry for cyclists",24: "No vehicle over 2t",25: "No entry for trucks",26: "Width limit",27: "Height limit",28: "No vehicles",29: "No left turn",30: "No right turn",31: "No overtaking",32: "Speed limit",33: "Shared path for pedestrians and cyclists",34: "Ahead only",35: "Right only",36: "Ahead and right only",37: "Roundabout",38: "Cycleway",39: "Segregated path for pedestrians and cyclists",40: "No parking",41: "No stopping",42: "No parking from the 1st till 15th day of the month",43: "No parking from the 16th till last day of the month",44: "Priority over oncoming traffic",45: "Parking permitted",46: "Parking for disabled",47: "Parking reserved for motorcycles, cars, vans (< 3.5t) and minibusses",48: "Parking reserved for trucks",49: "Parking reserved for coaches",50: "Parking mandatory on the verge or sidewalk",51: "Start of a living street",52: "End of living street",53: "One-way road",54: "Dead end",55: "End of roadworks",56: "Pedestrian crossing",57: "Cyclist and moped crossing",58: "Parking lot",59: "Hump",60: "End of priority road",61: "Priority road"}

In [17]:
from torch.utils.data import DataLoader, Subset

# Get labels that have unambiguous mapping to GTSRB
# (every entry of labels_id_GTS_to_BTS except the -1 placeholder)
valid_bts_labels = [i for i in labels_id_GTS_to_BTS if i != -1]

def filter_dataset_by_labels(dataset, valid_labels):
    """Return a ``Subset`` of ``dataset`` holding only samples whose label is in ``valid_labels``.

    Args:
        dataset: Indexable dataset supporting ``len()``; each item must unpack into
            three values, the last of which is the label.
        valid_labels: Container of accepted labels (checked with ``in``).

    Returns:
        torch.utils.data.Subset: View on ``dataset`` restricted to the matching
        positions, in their original order.
    """
    valid_indices = []

    for position in range(len(dataset)):
        # Only the label is needed. The first element of the item is deliberately
        # ignored: Subset expects positions into `dataset`, and the earlier code
        # overwrote the loop position with that element before appending it.
        _, _, label = dataset[position]
        if label in valid_labels:
            valid_indices.append(position)

    return Subset(dataset, valid_indices)

def _filter_split(split_name, split_dataset):
    """Filter one BTS split down to ``valid_bts_labels`` and print its size before and after."""
    kept = filter_dataset_by_labels(split_dataset, valid_bts_labels)
    print(f"Original BTS {split_name} size: {len(split_dataset)}")
    print(f"Filtered BTS {split_name} size: {len(kept)}")
    return kept


def _make_eval_loader(subset):
    """Wrap ``subset`` in an unshuffled DataLoader using the batch settings of ``bts_config``."""
    return DataLoader(
        subset,
        batch_size=bts_config.batch_size,
        shuffle=False,
        num_workers=bts_config.num_workers,
        pin_memory=bts_config.pin_memory
    )


print("Filtering BTS datasets to only include unambiguous mappings to GTSRB...")
print(f"Valid BTS labels: {sorted(valid_bts_labels)}")
print(f"Number of valid labels: {len(valid_bts_labels)}")

# Splits are processed in the order test, train, validation.
bts_test_filtered = _filter_split("test", bts_factory.test_dataset)
bts_train_filtered = _filter_split("train", bts_factory.train_dataset)
bts_val_filtered = _filter_split("val", bts_factory.val_dataset)

# New dataloaders over the filtered subsets.
bts_test_loader_filtered = _make_eval_loader(bts_test_filtered)
bts_train_loader_filtered = _make_eval_loader(bts_train_filtered)
bts_val_loader_filtered = _make_eval_loader(bts_val_filtered)

print("\nFiltered dataloaders created successfully!")
print(f"You can now use: bts_test_loader_filtered, bts_train_loader_filtered, bts_val_loader_filtered")

Filtering BTS datasets to only include unambiguous mappings to GTSRB...
Valid BTS labels: [0, 2, 3, 4, 5, 8, 10, 13, 16, 17, 19, 21, 22, 25, 28, 31, 34, 35, 37, 61]
Number of valid labels: 20
Original BTS test size: 2520
Filtered BTS test size: 1016
Original BTS train size: 3660
Filtered BTS train size: 1334
Original BTS val size: 915
Filtered BTS val size: 326

Filtered dataloaders created successfully!
You can now use: bts_test_loader_filtered, bts_train_loader_filtered, bts_val_loader_filtered


In [18]:
def get_bts_predictions(model, dataloader, device):
    """
    Run the concept predictor over a dataloader and collect its outputs and the labels.

    Args:
        model: CBM whose ``concept_predictor`` maps an input batch to concept logits.
        dataloader: Iterable yielding ``(idx, inputs, target)`` batches. ``target`` is
            either the batch of labels or a ``(concepts, labels)`` pair of tensors, in
            which case only the labels are kept.
        device: Device the inputs are moved to.

    Returns:
        tuple: ``(all_logits, concept_predictions, labels, concept_probabilities)`` as
        numpy arrays stacked over all batches; predictions are the sigmoid
        probabilities thresholded at 0.5.
    """
    model.eval()

    all_logits = []
    concept_predictions = []
    concept_probabilities = []
    all_labels = []

    with torch.no_grad():
        for _, inputs, target in tqdm(dataloader, desc="Getting BTS predictions"):
            # Both supported layouts are 3-tuples, so the number of elements cannot
            # tell them apart (the former len(...) == 3 test was always true). Look at
            # the target itself: a pair of batched tensors is (concepts, labels).
            is_concept_label_pair = (
                isinstance(target, (list, tuple))
                and len(target) == 2
                and all(isinstance(part, torch.Tensor) and part.dim() >= 1 for part in target)
            )
            labels = target[1] if is_concept_label_pair else target

            inputs = inputs.to(device)

            # Forward pass through concept predictor
            outputs = model.concept_predictor(inputs)

            # Probabilities and binary predictions
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            all_logits.append(outputs.cpu().numpy())
            concept_predictions.append(preds.cpu().numpy())
            concept_probabilities.append(probs.cpu().numpy())
            all_labels.append(labels.cpu().numpy() if isinstance(labels, torch.Tensor) else labels)

    # Concatenate all batches
    all_logits = np.vstack(all_logits)
    concept_predictions = np.vstack(concept_predictions)
    concept_probabilities = np.vstack(concept_probabilities)
    all_labels = np.concatenate(all_labels)

    return all_logits, concept_predictions, all_labels, concept_probabilities

# Run both concept predictors over the filtered BTS test loader.
print("Getting predictions for Fuzzy CBM on BTS...")
(
    all_logits_fuzzy_bts,
    concept_predictions_fuzzy_bts,
    labels_fuzzy_bts,
    concept_probabilities_fuzzy_bts,
) = get_bts_predictions(fuzzy_cbm, bts_test_loader_filtered, fuzzy_cbm_config.device)

print(f"Fuzzy CBM BTS predictions shape: {concept_predictions_fuzzy_bts.shape}")
print(f"Number of BTS samples: {len(labels_fuzzy_bts)}")

print("\nGetting predictions for Baseline CBM on BTS...")
(
    all_logits_baseline_bts,
    concept_predictions_baseline_bts,
    labels_baseline_bts,
    concept_probabilities_baseline_bts,
) = get_bts_predictions(baseline_cbm, bts_test_loader_filtered, baseline_cbm_config.device)

print(f"Baseline CBM BTS predictions shape: {concept_predictions_baseline_bts.shape}")
print(f"Number of BTS samples: {len(labels_baseline_bts)}")

# The two label arrays must agree, otherwise the models did not see the same samples
# in the same order.
assert np.array_equal(labels_fuzzy_bts, labels_baseline_bts), "Labels mismatch between models!"
print("\n✓ Both models processed the same BTS test set successfully")

Getting predictions for Fuzzy CBM on BTS...


Getting BTS predictions:   0%|          | 0/16 [00:00<?, ?it/s]

Getting BTS predictions: 100%|██████████| 16/16 [00:01<00:00,  8.85it/s]


Fuzzy CBM BTS predictions shape: (1016, 43)
Number of BTS samples: 1016

Getting predictions for Baseline CBM on BTS...


Getting BTS predictions: 100%|██████████| 16/16 [00:02<00:00,  6.30it/s]

Baseline CBM BTS predictions shape: (1016, 43)
Number of BTS samples: 1016

✓ Both models processed the same BTS test set successfully





In [19]:
# ============================================================================
# FUZZY LOSS ANALYSIS ON BTS DATASET (without ground truth concepts)
# ============================================================================

def _bts_loss_snapshot(logits_tensor, pseudo_target_tensor):
    """Call ``neutral_fuzzy_loss`` once and copy out the values it stores on itself.

    Returns:
        tuple: ``(returned_loss, standard_loss, fuzzy_loss, total_loss, rule_losses)``
        where ``total_loss`` is ``standard_loss + fuzzy_loss``.
    """
    returned_loss = neutral_fuzzy_loss(logits_tensor, pseudo_target_tensor)
    standard_loss = neutral_fuzzy_loss.last_standard_loss.item()
    fuzzy_loss_value = neutral_fuzzy_loss.last_fuzzy_loss.item()
    per_rule = {name: loss.item() for name, loss in neutral_fuzzy_loss.last_individual_losses.items()}
    return returned_loss, standard_loss, fuzzy_loss_value, standard_loss + fuzzy_loss_value, per_rule


def _print_bts_report(heading, standard_loss, fuzzy_loss_value, total_loss, per_rule):
    """Print the loss summary and per-rule breakdown of one model on BTS."""
    print(f"\n{heading}")
    print(f"  Standard BCE Loss:    {standard_loss:.10f}")
    print(f"  Fuzzy Rules Loss:     {fuzzy_loss_value:.10f}")
    print(f"  Total Loss:           {total_loss:.10f}")
    print(f"\n  Individual Rule Violations:")
    for rule_name, loss in per_rule.items():
        if fuzzy_loss_value > 0:
            contribution = (loss / fuzzy_loss_value) * 100
            print(f"    {rule_name:30s}: {loss:.10f} ({contribution:5.2f}% of fuzzy loss)")
        else:
            # No share can be given when the fuzzy loss is not positive.
            print(f"    {rule_name:30s}: {loss:.10f}")


print("="*80)
print("FUZZY LOSS ANALYSIS ON BTS DATASET")
print("="*80)

# No ground-truth concepts are used for BTS here: each model's own binary predictions
# are passed as the targets, so the loss is evaluated on the predictions themselves.
all_logits_fuzzy_bts_tensor = torch.tensor(all_logits_fuzzy_bts, dtype=torch.float32)
all_logits_baseline_bts_tensor = torch.tensor(all_logits_baseline_bts, dtype=torch.float32)
concept_predictions_fuzzy_bts_tensor = torch.tensor(concept_predictions_fuzzy_bts, dtype=torch.float32)
concept_predictions_baseline_bts_tensor = torch.tensor(concept_predictions_baseline_bts, dtype=torch.float32)

# Fuzzy CBM: compute, then print straight away - the loss object only keeps the
# values of its most recent call.
(
    fuzzy_cbm_bts_loss,
    fuzzy_cbm_bts_standard_loss,
    fuzzy_cbm_bts_fuzzy_loss,
    fuzzy_cbm_bts_total_loss,
    fuzzy_cbm_bts_rule_losses,
) = _bts_loss_snapshot(all_logits_fuzzy_bts_tensor, concept_predictions_fuzzy_bts_tensor)
_print_bts_report(
    "Fuzzy CBM on BTS Dataset:",
    fuzzy_cbm_bts_standard_loss,
    fuzzy_cbm_bts_fuzzy_loss,
    fuzzy_cbm_bts_total_loss,
    fuzzy_cbm_bts_rule_losses,
)

# Baseline CBM
(
    baseline_cbm_bts_loss,
    baseline_cbm_bts_standard_loss,
    baseline_cbm_bts_fuzzy_loss,
    baseline_cbm_bts_total_loss,
    baseline_cbm_bts_rule_losses,
) = _bts_loss_snapshot(all_logits_baseline_bts_tensor, concept_predictions_baseline_bts_tensor)
_print_bts_report(
    "Baseline CBM on BTS Dataset:",
    baseline_cbm_bts_standard_loss,
    baseline_cbm_bts_fuzzy_loss,
    baseline_cbm_bts_total_loss,
    baseline_cbm_bts_rule_losses,
)

# ============================================================================
# COMPARISON AND IMPROVEMENT ANALYSIS
# ============================================================================

def _relative_reduction_pct(baseline_value, new_value):
    """Reduction from ``baseline_value`` to ``new_value`` in percent (0 if the baseline is not positive)."""
    if baseline_value > 0:
        return (baseline_value - new_value) / baseline_value * 100
    return 0


print("\n" + "="*80)
print("BTS DATASET - IMPROVEMENT ANALYSIS")
print("="*80)

fuzzy_loss_reduction_bts = baseline_cbm_bts_fuzzy_loss - fuzzy_cbm_bts_fuzzy_loss
fuzzy_loss_reduction_pct_bts = _relative_reduction_pct(baseline_cbm_bts_fuzzy_loss, fuzzy_cbm_bts_fuzzy_loss)

print(f"\nOverall Fuzzy Loss Reduction on BTS:")
print(f"  Baseline CBM Fuzzy Loss:  {baseline_cbm_bts_fuzzy_loss:.10f}")
print(f"  Fuzzy CBM Fuzzy Loss:     {fuzzy_cbm_bts_fuzzy_loss:.10f}")
print(f"  Absolute Reduction:       {fuzzy_loss_reduction_bts:.10f}")
print(f"  Relative Reduction:       {fuzzy_loss_reduction_pct_bts:.2f}%")

print(f"\nPer-Rule Improvements on BTS:")
for rule_name in baseline_cbm_bts_rule_losses.keys():
    baseline_loss = baseline_cbm_bts_rule_losses[rule_name]
    fuzzy_loss = fuzzy_cbm_bts_rule_losses[rule_name]
    improvement = baseline_loss - fuzzy_loss
    improvement_pct = _relative_reduction_pct(baseline_loss, fuzzy_loss)
    print(f"  {rule_name:30s}: {improvement:.10f} ({improvement_pct:+6.2f}%)")

# ============================================================================
# CROSS-DATASET COMPARISON (GTSRB vs BTS)
# ============================================================================

print("\n" + "="*80)
print("CROSS-DATASET COMPARISON: GTSRB vs BTS")
print("="*80)

import pandas as pd

gtsrb_baseline_fuzzy_loss = baseline_fuzzy_metrics['fuzzy_loss']
gtsrb_fuzzy_cbm_fuzzy_loss = fuzzy_cbm_metrics['fuzzy_loss']

comparison_data = {
    'Dataset': ['GTSRB', 'BTS'],
    'Baseline Fuzzy Loss': [gtsrb_baseline_fuzzy_loss, baseline_cbm_bts_fuzzy_loss],
    'Fuzzy CBM Fuzzy Loss': [gtsrb_fuzzy_cbm_fuzzy_loss, fuzzy_cbm_bts_fuzzy_loss],
    'Absolute Reduction': [
        gtsrb_baseline_fuzzy_loss - gtsrb_fuzzy_cbm_fuzzy_loss,
        fuzzy_loss_reduction_bts
    ],
    'Relative Reduction (%)': [
        # Guarded like the BTS figure: this used to divide by the GTSRB baseline
        # loss unconditionally and failed when that loss was zero.
        _relative_reduction_pct(gtsrb_baseline_fuzzy_loss, gtsrb_fuzzy_cbm_fuzzy_loss),
        fuzzy_loss_reduction_pct_bts
    ]
}

df_cross_comparison = pd.DataFrame(comparison_data)
print("\n" + df_cross_comparison.to_string(index=False))

# Per-rule comparison across datasets
print("\n" + "="*80)
print("PER-RULE COMPARISON: GTSRB vs BTS")
print("="*80)

rule_cross_comparison = []
for rule_name in baseline_fuzzy_metrics['rule_losses'].keys():
    gtsrb_baseline_rule = baseline_fuzzy_metrics['rule_losses'][rule_name]
    gtsrb_fuzzy_rule = fuzzy_cbm_metrics['rule_losses'][rule_name]
    bts_baseline_rule = baseline_cbm_bts_rule_losses[rule_name]
    bts_fuzzy_rule = fuzzy_cbm_bts_rule_losses[rule_name]
    rule_cross_comparison.append({
        'Rule': rule_name,
        'GTSRB Baseline': gtsrb_baseline_rule,
        'GTSRB Fuzzy CBM': gtsrb_fuzzy_rule,
        'BTS Baseline': bts_baseline_rule,
        'BTS Fuzzy CBM': bts_fuzzy_rule,
        'GTSRB Improvement %': _relative_reduction_pct(gtsrb_baseline_rule, gtsrb_fuzzy_rule),
        'BTS Improvement %': _relative_reduction_pct(bts_baseline_rule, bts_fuzzy_rule)
    })

df_rule_cross_comparison = pd.DataFrame(rule_cross_comparison)
df_rule_cross_comparison = df_rule_cross_comparison.sort_values('BTS Improvement %', ascending=False)
print("\n" + df_rule_cross_comparison.to_string(index=False))

# ============================================================================
# SUMMARY
# ============================================================================

print("\n" + "="*80)
print("SUMMARY - FUZZY LOSS ON BTS DATASET")
print("="*80)

# `fuzzy_loss_reduction_pct` (GTSRB) is not computed in the visible part of
# this cell; it must already exist in the kernel when this runs.
print("\nKey Findings:")
print(f"  1. Fuzzy loss reduction on GTSRB: {fuzzy_loss_reduction_pct:.2f}%")
print(f"  2. Fuzzy loss reduction on BTS:   {fuzzy_loss_reduction_pct_bts:.2f}%")

if fuzzy_loss_reduction_pct_bts > 0:
    print("\n✓ The Fuzzy CBM maintains better logical consistency on OOD data (BTS)")
    print("  This suggests that fuzzy loss training improves generalization of logical rules")
else:
    # This branch is reached when the BTS reduction is zero or negative, i.e.
    # the Fuzzy CBM's fuzzy loss is not below the baseline's. The condition
    # does not compare against GTSRB, so the message must not claim that.
    print("\n⚠ The Fuzzy CBM does not lower the fuzzy loss on OOD data (BTS) relative to the baseline")

print("\nNote: BTS analysis uses predicted concepts as pseudo ground truth since")
print("BTS doesn't have annotated concepts. The fuzzy loss measures internal")
print("consistency of predictions with respect to the learned fuzzy rules.")

FUZZY LOSS ANALYSIS ON BTS DATASET

Fuzzy CBM on BTS Dataset:
  Standard BCE Loss:    0.0002277495
  Fuzzy Rules Loss:     1.5115638971
  Total Loss:           1.5117916467

  Individual Rule Violations:
    at_most_one_border_colour     : 0.0000383665 ( 0.00% of fuzzy loss)
    exactly_one_main_colour       : 0.7504361272 (49.65% of fuzzy loss)
    exactly_one_shape             : 0.7500138879 (49.62% of fuzzy loss)
    no_symbols_exactly_two_colours: 0.0103128115 ( 0.68% of fuzzy loss)

Baseline CBM on BTS Dataset:
  Standard BCE Loss:    0.0003441266
  Fuzzy Rules Loss:     1.5071085691
  Total Loss:           1.5074526957

  Individual Rule Violations:
    at_most_one_border_colour     : 0.0000000935 ( 0.00% of fuzzy loss)
    exactly_one_main_colour       : 0.7501800656 (49.78% of fuzzy loss)
    exactly_one_shape             : 0.7503016591 (49.78% of fuzzy loss)
    no_symbols_exactly_two_colours: 0.0059707984 ( 0.40% of fuzzy loss)

BTS DATASET - IMPROVEMENT ANALYSIS

Overall Fuz

In [20]:
# ============================================================================
# RULE VIOLATION ANALYSIS USING RULE CHECKER
# ============================================================================

# Banner for the cell output: an 80-character rule above and below the title.
print("=" * 80)
print("RULE VIOLATION ANALYSIS - GTSRB AND BTS DATASETS")
print("=" * 80)

from collections import defaultdict
import pandas as pd

def analyze_rule_violations(concept_predictions, dataset_name, model_name, checker=None):
    """
    Count how many concept predictions violate at least one rule.

    Every prediction is cast to ``int`` and handed to the checker's
    ``check_concept_vector(vector, verbose=False, early_stop=False)``. A sample
    counts as violating when that call returns a non-empty result; each
    returned item is expected to be a mapping with a ``'constraint'`` key.

    Args:
        concept_predictions: numpy array of concept predictions, iterated one
            entry per sample
        dataset_name: name of the dataset (e.g., "GTSRB", "BTS")
        model_name: name of the model (e.g., "Baseline CBM", "Fuzzy CBM")
        checker: optional object providing ``check_concept_vector``; when
            omitted, the notebook-level ``rule_checker`` is used

    Returns:
        dict: Dictionary containing violation statistics with the keys
            'dataset', 'model', 'total_samples', 'total_violations'
            (number of samples with at least one violation),
            'violation_rate' (percentage of samples, 0.0 for empty input),
            'flagged_indices', 'violated_constraints' and 'constraint_counts'.
    """
    if checker is None:
        # Resolved at call time so the default keeps using the notebook global.
        checker = rule_checker

    flagged_indices = []
    violated_constraints = []
    constraint_counts = defaultdict(int)

    for i, pred in enumerate(concept_predictions):
        # NOTE(review): astype(int) truncates toward zero, so this presumably
        # expects already-binarised predictions -- confirm the thresholding
        # done before this function is called.
        violations = checker.check_concept_vector(pred.astype(int), verbose=False, early_stop=False)
        if not violations:
            continue

        flagged_indices.append(i)
        violated_constraints.append(violations)

        # A single sample can break several constraints; count each of them.
        for violation in violations:
            constraint_counts[violation['constraint']] += 1

    total_samples = len(concept_predictions)
    total_violations = len(flagged_indices)
    # An empty prediction set has no violations; report 0% instead of dividing by zero.
    violation_rate = (total_violations / total_samples) * 100 if total_samples else 0.0

    return {
        'dataset': dataset_name,
        'model': model_name,
        'total_samples': total_samples,
        'total_violations': total_violations,
        'violation_rate': violation_rate,
        'flagged_indices': flagged_indices,
        'violated_constraints': violated_constraints,
        'constraint_counts': dict(constraint_counts)
    }

# ============================================================================
# ANALYZE VIOLATIONS ON GTSRB
# ============================================================================

print(f"\n{'=' * 80}")
print("ANALYZING RULE VIOLATIONS ON GTSRB TEST SET")
print("=" * 80)

# Run the rule checker over each model's GTSRB concept predictions.
print("\nAnalyzing Baseline CBM on GTSRB...")
baseline_gtsrb_violations = analyze_rule_violations(concept_predictions_baseline, "GTSRB", "Baseline CBM")

print("Analyzing Fuzzy CBM on GTSRB...")
fuzzy_gtsrb_violations = analyze_rule_violations(concept_predictions_fuzzy, "GTSRB", "Fuzzy CBM")

# Both models are reported with the same three-line layout.
for _model_label, _model_stats in (
    ("Baseline CBM", baseline_gtsrb_violations),
    ("Fuzzy CBM", fuzzy_gtsrb_violations),
):
    print(f"\n{_model_label} on GTSRB:")
    print(f"  Total samples:        {_model_stats['total_samples']}")
    print(f"  Samples with violations: {_model_stats['total_violations']}")
    print(f"  Violation rate:       {_model_stats['violation_rate']:.2f}%")

# ============================================================================
# ANALYZE VIOLATIONS ON BTS
# ============================================================================

print(f"\n{'=' * 80}")
print("ANALYZING RULE VIOLATIONS ON BTS TEST SET")
print("=" * 80)

# Run the rule checker over each model's BTS concept predictions.
print("\nAnalyzing Baseline CBM on BTS...")
baseline_bts_violations = analyze_rule_violations(concept_predictions_baseline_bts, "BTS", "Baseline CBM")

print("Analyzing Fuzzy CBM on BTS...")
fuzzy_bts_violations = analyze_rule_violations(concept_predictions_fuzzy_bts, "BTS", "Fuzzy CBM")

# Both models are reported with the same three-line layout.
for _model_label, _model_stats in (
    ("Baseline CBM", baseline_bts_violations),
    ("Fuzzy CBM", fuzzy_bts_violations),
):
    print(f"\n{_model_label} on BTS:")
    print(f"  Total samples:        {_model_stats['total_samples']}")
    print(f"  Samples with violations: {_model_stats['total_violations']}")
    print(f"  Violation rate:       {_model_stats['violation_rate']:.2f}%")

# ============================================================================
# COMPARISON SUMMARY
# ============================================================================

print(f"\n{'=' * 80}")
print("VIOLATION RATE COMPARISON")
print("=" * 80)

# One table row per (dataset, model) pair, in the order GTSRB baseline,
# GTSRB fuzzy, BTS baseline, BTS fuzzy.
violation_summary = pd.DataFrame([
    {
        'Dataset': dataset_label,
        'Model': model_label,
        'Violation Rate (%)': result['violation_rate'],
        'Violations': result['total_violations'],
        'Total Samples': result['total_samples'],
    }
    for dataset_label, model_label, result in (
        ('GTSRB', 'Baseline CBM', baseline_gtsrb_violations),
        ('GTSRB', 'Fuzzy CBM', fuzzy_gtsrb_violations),
        ('BTS', 'Baseline CBM', baseline_bts_violations),
        ('BTS', 'Fuzzy CBM', fuzzy_bts_violations),
    )
])

print("\n" + violation_summary.to_string(index=False))

# Calculate improvements (percentage points; positive = Fuzzy CBM violates less often)
gtsrb_baseline_rate = baseline_gtsrb_violations['violation_rate']
bts_baseline_rate = baseline_bts_violations['violation_rate']

gtsrb_improvement = gtsrb_baseline_rate - fuzzy_gtsrb_violations['violation_rate']
bts_improvement = bts_baseline_rate - fuzzy_bts_violations['violation_rate']

# The relative improvement is undefined when the baseline never violates a
# rule; report 0.0 in that case instead of raising ZeroDivisionError.
gtsrb_relative_improvement = (gtsrb_improvement / gtsrb_baseline_rate * 100) if gtsrb_baseline_rate > 0 else 0.0
bts_relative_improvement = (bts_improvement / bts_baseline_rate * 100) if bts_baseline_rate > 0 else 0.0

print("\n" + "="*80)
print("IMPROVEMENT ANALYSIS")
print("="*80)

print("\nGTSRB Dataset:")
print(f"  Baseline violation rate:  {gtsrb_baseline_rate:.2f}%")
print(f"  Fuzzy CBM violation rate: {fuzzy_gtsrb_violations['violation_rate']:.2f}%")
print(f"  Improvement:              {gtsrb_improvement:.2f} percentage points")
print(f"  Relative improvement:     {gtsrb_relative_improvement:.2f}%")

print("\nBTS Dataset:")
print(f"  Baseline violation rate:  {bts_baseline_rate:.2f}%")
print(f"  Fuzzy CBM violation rate: {fuzzy_bts_violations['violation_rate']:.2f}%")
print(f"  Improvement:              {bts_improvement:.2f} percentage points")
print(f"  Relative improvement:     {bts_relative_improvement:.2f}%")

# ============================================================================
# PER-CONSTRAINT VIOLATION ANALYSIS
# ============================================================================

# Passing the columns explicitly keeps the DataFrame sortable even when no
# constraint was violated (a frame built from an empty list has no columns, so
# sort_values('Improvement') would raise KeyError).
CONSTRAINT_COMPARISON_COLUMNS = [
    'Constraint', 'Baseline Count', 'Fuzzy CBM Count', 'Improvement', 'Improvement %'
]


def _compare_constraint_counts(constraints, baseline_counts, fuzzy_counts):
    """Build one comparison row per constraint, in sorted constraint order.

    Args:
        constraints: iterable of constraint names to report
        baseline_counts: mapping constraint name -> violation count (baseline model)
        fuzzy_counts: mapping constraint name -> violation count (Fuzzy CBM)

    Returns:
        list[dict]: rows keyed by CONSTRAINT_COMPARISON_COLUMNS. A constraint
        missing from a mapping counts as 0; 'Improvement %' is 0 when the
        baseline count is 0.
    """
    rows = []
    for constraint in sorted(constraints):
        baseline_count = baseline_counts.get(constraint, 0)
        fuzzy_count = fuzzy_counts.get(constraint, 0)
        improvement = baseline_count - fuzzy_count
        improvement_pct = (improvement / baseline_count * 100) if baseline_count > 0 else 0
        rows.append({
            'Constraint': constraint,
            'Baseline Count': baseline_count,
            'Fuzzy CBM Count': fuzzy_count,
            'Improvement': improvement,
            'Improvement %': improvement_pct
        })
    return rows


print("\n" + "="*80)
print("PER-CONSTRAINT VIOLATION ANALYSIS - GTSRB")
print("="*80)

# Combine all constraint names from both models
all_constraints_gtsrb = set(baseline_gtsrb_violations['constraint_counts'].keys()) | \
                        set(fuzzy_gtsrb_violations['constraint_counts'].keys())

constraint_comparison_gtsrb = _compare_constraint_counts(
    all_constraints_gtsrb,
    baseline_gtsrb_violations['constraint_counts'],
    fuzzy_gtsrb_violations['constraint_counts'],
)

df_constraint_gtsrb = pd.DataFrame(constraint_comparison_gtsrb, columns=CONSTRAINT_COMPARISON_COLUMNS)
df_constraint_gtsrb = df_constraint_gtsrb.sort_values('Improvement', ascending=False)
print("\n" + df_constraint_gtsrb.to_string(index=False))

print("\n" + "="*80)
print("PER-CONSTRAINT VIOLATION ANALYSIS - BTS")
print("="*80)

# Combine all constraint names from both models
all_constraints_bts = set(baseline_bts_violations['constraint_counts'].keys()) | \
                      set(fuzzy_bts_violations['constraint_counts'].keys())

constraint_comparison_bts = _compare_constraint_counts(
    all_constraints_bts,
    baseline_bts_violations['constraint_counts'],
    fuzzy_bts_violations['constraint_counts'],
)

df_constraint_bts = pd.DataFrame(constraint_comparison_bts, columns=CONSTRAINT_COMPARISON_COLUMNS)
df_constraint_bts = df_constraint_bts.sort_values('Improvement', ascending=False)
print("\n" + df_constraint_bts.to_string(index=False))

# ============================================================================
# CROSS-DATASET CONSTRAINT COMPARISON
# ============================================================================

print(f"\n{'=' * 80}")
print("CROSS-DATASET CONSTRAINT COMPARISON")
print("=" * 80)

# Report every constraint violated on either dataset; a constraint that one
# model never violated is counted as 0 for that model.
all_constraints = all_constraints_gtsrb | all_constraints_bts
cross_dataset_comparison = []

for constraint in sorted(all_constraints):
    gtsrb_baseline_count = baseline_gtsrb_violations['constraint_counts'].get(constraint, 0)
    gtsrb_fuzzy_count = fuzzy_gtsrb_violations['constraint_counts'].get(constraint, 0)
    bts_baseline_count = baseline_bts_violations['constraint_counts'].get(constraint, 0)
    bts_fuzzy_count = fuzzy_bts_violations['constraint_counts'].get(constraint, 0)

    cross_dataset_comparison.append({
        'Constraint': constraint,
        'GTSRB Baseline': gtsrb_baseline_count,
        'GTSRB Fuzzy': gtsrb_fuzzy_count,
        'BTS Baseline': bts_baseline_count,
        'BTS Fuzzy': bts_fuzzy_count,
        'GTSRB Improvement': gtsrb_baseline_count - gtsrb_fuzzy_count,
        'BTS Improvement': bts_baseline_count - bts_fuzzy_count
    })

df_cross_dataset = pd.DataFrame(cross_dataset_comparison)
print("\n" + df_cross_dataset.to_string(index=False))

# ============================================================================
# FINAL SUMMARY
# ============================================================================


def _print_top_improved(top_rows):
    """Print the rows of `top_rows` whose 'Improvement' is positive.

    Rows with zero or negative improvement are skipped so a regression is
    never listed as "most improved"; "(none)" is printed when nothing is left.
    """
    printed_any = False
    for _, row in top_rows.iterrows():
        if row['Improvement'] <= 0:
            continue
        print(f"  - {row['Constraint']}: {row['Improvement']} fewer violations ({row['Improvement %']:.1f}%)")
        printed_any = True
    if not printed_any:
        print("  (none)")


print("\n" + "="*80)
print("FINAL SUMMARY - RULE VIOLATION ANALYSIS")
print("="*80)

# Relative reductions are undefined when the baseline never violates a rule;
# report 0.0 in that case instead of raising ZeroDivisionError.
gtsrb_baseline_rate = baseline_gtsrb_violations['violation_rate']
bts_baseline_rate = baseline_bts_violations['violation_rate']
gtsrb_relative_improvement = (gtsrb_improvement / gtsrb_baseline_rate * 100) if gtsrb_baseline_rate > 0 else 0.0
bts_relative_improvement = (bts_improvement / bts_baseline_rate * 100) if bts_baseline_rate > 0 else 0.0

print("\nKey Findings:")
print(f"  1. GTSRB violation rate reduction: {gtsrb_improvement:.2f} percentage points ({gtsrb_relative_improvement:.2f}%)")
print(f"  2. BTS violation rate reduction:   {bts_improvement:.2f} percentage points ({bts_relative_improvement:.2f}%)")

if gtsrb_improvement > 0 and bts_improvement > 0:
    print("\n✓ Fuzzy CBM shows improved rule adherence on both GTSRB and BTS datasets")
    print("  This validates that fuzzy loss training successfully enforces logical constraints")
elif gtsrb_improvement > 0:
    print("\n✓ Fuzzy CBM shows improved rule adherence on GTSRB")
    # Here the BTS improvement is zero or negative, so say which one it is
    # rather than calling a regression a "limited improvement".
    if bts_improvement < 0:
        print("⚠ Rule adherence on OOD data (BTS) is worse than the baseline")
    else:
        print("⚠ No improvement on OOD data (BTS)")
else:
    print("\n⚠ Unexpected results - review constraint implementation")

# The per-constraint frames are sorted by 'Improvement' (descending) where they
# are built, so head(3) holds the three largest improvements.
print("\nMost improved constraints on GTSRB:")
top_improved_gtsrb = df_constraint_gtsrb.head(3)
_print_top_improved(top_improved_gtsrb)

print("\nMost improved constraints on BTS:")
top_improved_bts = df_constraint_bts.head(3)
_print_top_improved(top_improved_bts)

RULE VIOLATION ANALYSIS - GTSRB AND BTS DATASETS

ANALYZING RULE VIOLATIONS ON GTSRB TEST SET

Analyzing Baseline CBM on GTSRB...
Analyzing Fuzzy CBM on GTSRB...

Baseline CBM on GTSRB:
  Total samples:        12630
  Samples with violations: 15
  Violation rate:       0.12%

Fuzzy CBM on GTSRB:
  Total samples:        12630
  Samples with violations: 12
  Violation rate:       0.10%

ANALYZING RULE VIOLATIONS ON BTS TEST SET

Analyzing Baseline CBM on BTS...
Analyzing Fuzzy CBM on BTS...

Baseline CBM on BTS:
  Total samples:        1016
  Samples with violations: 9
  Violation rate:       0.89%

Fuzzy CBM on BTS:
  Total samples:        1016
  Samples with violations: 11
  Violation rate:       1.08%

VIOLATION RATE COMPARISON

Dataset        Model  Violation Rate (%)  Violations  Total Samples
  GTSRB Baseline CBM            0.118765          15          12630
  GTSRB    Fuzzy CBM            0.095012          12          12630
    BTS Baseline CBM            0.885827           9    