#### RQ1: How does the logical consistency of the CBM change when the requirements are introduced via Fuzzy Loss?

- This RQ aims to evaluate the logical consistency of the CBM and validate the claim that the FuzzyLoss CBM learned the rules and adheres to them.
- Load the models
- Load the GTSRB Dataset
- Load the Raw Fuzzy Loss
- Calculate Fuzzy Loss on the test set. (GTSRB, BTS)
- Calculate Rule violations on the test set. (GTSRB, BTS)
- Calculate the HD on the test set (GTSRB, BTS)
- Calculate CL-Matrix

In [21]:
import sys
from pathlib import Path

# Make the project's `src` directory importable from the notebook's working directory.
current_dir = Path.cwd()
parent_dir = current_dir.parent
src_dir = str(parent_dir / "src")
# Guard the insert so that re-running this cell does not pile up duplicate sys.path entries.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

from models.architectures import CBMSequentialEfficientNetFCN
from train_cbm import cbm_load_config

In [22]:
# Baseline CBM: read its YAML configuration and build the architecture from it.
# The trained weights are restored in a later cell.
baseline_cbm_config = cbm_load_config(Path("../files/configs/GTSRB_CBM_config_loading.yaml"))
baseline_cbm = CBMSequentialEfficientNetFCN(baseline_cbm_config)
# Fuzzy-loss CBM: same architecture class, built from the "best trial" configuration file.
fuzzy_cbm_config = cbm_load_config(Path("../files/configs/GTSRB_CBM_config_best_trial_loading.yaml"))
fuzzy_cbm = CBMSequentialEfficientNetFCN(fuzzy_cbm_config)

Directory 'experiments/20251009_105749' created successfully.
Directory 'experiments/20251009_105749' created successfully.


In [23]:
# Checkpoint locations of the trained sub-networks (concept predictor and label predictor) of each CBM.
# For both models the concept predictor and the label predictor come from different experiment directories.
baseline_cbm_concept_predictor_path = Path("../experiments/20251008_155607/models/20251008_155607_concept_predictor_best_model.pt")
baseline_cbm_label_predictor_path = Path("../experiments/baseline_cbm/models/20251001_083717_label_predictor_best_model.pt")
fuzzy_cbm_concept_predictor_path = Path("../experiments/20251009_120258/models/20251009_120258_concept_predictor_best_model.pt")
# NOTE(review): `precitor` is a typo for `predictor`; the weight-loading cell refers to this exact
# name, so a rename has to be applied in both places at once.
fuzzy_cbm_label_precitor_path = Path("../experiments/fuzzy_CBM/models/20251001_113637_label_predictor_best_model.pt")

In [24]:
import torch

# Restore the trained weights of every sub-network: baseline components first, then the fuzzy ones.
# Each entry is (module to fill, checkpoint file, config that supplies the target device).
_checkpoints = (
    (baseline_cbm.concept_predictor, baseline_cbm_concept_predictor_path, baseline_cbm_config),
    (baseline_cbm.label_predictor, baseline_cbm_label_predictor_path, baseline_cbm_config),
    (fuzzy_cbm.concept_predictor, fuzzy_cbm_concept_predictor_path, fuzzy_cbm_config),
    (fuzzy_cbm.label_predictor, fuzzy_cbm_label_precitor_path, fuzzy_cbm_config),
)
for _module, _weights_path, _cfg in _checkpoints:
    _module.load_state_dict(
        torch.load(_weights_path, map_location=_cfg.device, weights_only=True)
    )

# Put both CBMs into evaluation mode before any metric is computed.
baseline_cbm.eval()
fuzzy_cbm.eval()

# The messages report the directory of the concept-predictor checkpoint only.
print("Model components loaded successfully!")
print(f"Baseline CBM loaded from {baseline_cbm_concept_predictor_path.parent}")
print(f"Fuzzy CBM loaded from {fuzzy_cbm_concept_predictor_path.parent}")

Model components loaded successfully!
Baseline CBM loaded from ../experiments/20251008_155607/models
Fuzzy CBM loaded from ../experiments/20251009_120258/models


In [25]:
# Build the dataloaders once from the baseline configuration; the same splits are handed
# to the trainers of both models.
dataset_factory = baseline_cbm_config.dataset.factory(
    seed=baseline_cbm_config.seed,
    config=baseline_cbm_config.dataset,
).set_dataloaders()
train_loader, val_loader, test_loader = (
    dataset_factory.train_dataloader,
    dataset_factory.val_dataloader,
    dataset_factory.test_dataloader,
)

In [26]:
# building a CBMTrainer for each model; both trainers share the dataloaders created from the baseline config
from models.trainer.cbm_trainer import CBMTrainer
# Both trainers receive the very same train/val/test dataloaders; only config and model differ.
_shared_loaders = dict(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)
baseline_cbm_trainer = CBMTrainer(config=baseline_cbm_config, model=baseline_cbm, **_shared_loaders)
fuzzy_cbm_trainer = CBMTrainer(config=fuzzy_cbm_config, model=fuzzy_cbm, **_shared_loaders)

In [27]:
# Take the criterion held by the fuzzy CBM's concept-predictor trainer. This one loss object is
# used to score both the baseline and the fuzzy CBM in the cells below.
# NOTE(review): presumably this is the CustomFuzzyLoss configured for the fuzzy CBM - confirm.
neutral_fuzzy_loss = fuzzy_cbm_trainer.concept_predictor_trainer.criterion

In [28]:
# Collect the fuzzy CBM's concept outputs on the test set (logits, hard predictions,
# ground truth and probabilities), then show the loss on the complete set as the cell result.
(
    all_logits_fuzzy,
    concept_predictions_fuzzy,
    concept_ground_truth_fuzzy,
    concept_probabilities_fuzzy,
) = fuzzy_cbm_trainer.concept_predictor_trainer.get_predictions(dataloader=test_loader)
neutral_fuzzy_loss(
    torch.tensor(all_logits_fuzzy),
    torch.tensor(concept_ground_truth_fuzzy),
)

Getting predictions: 100%|██████████| 99/99 [00:34<00:00,  2.83it/s]


tensor(1.5155)

In [29]:
# Same collection step for the baseline CBM's concept predictor on the test set.
(
    all_logits_baseline,
    concept_predictions_baseline,
    concept_ground_truth_baseline,
    concept_probabilities_baseline,
) = baseline_cbm_trainer.concept_predictor_trainer.get_predictions(dataloader=test_loader)


Getting predictions: 100%|██████████| 99/99 [00:38<00:00,  2.55it/s]


In [30]:
# calculating the fuzzy neutral fuzzy loss on the test set
import torch
import numpy as np
from tqdm import tqdm

def compute_fuzzy_loss_metrics(model, dataloader, fuzzy_loss_fn, device):
    """
    Compute sample-weighted average fuzzy loss metrics for logical consistency evaluation.

    Every batch is passed through ``model.concept_predictor`` and scored with
    ``fuzzy_loss_fn``. The per-batch values that the loss object exposes through its
    ``last_standard_loss``, ``last_fuzzy_loss`` and ``last_individual_losses``
    attributes are weighted by the batch size and averaged over all samples.

    Args:
        model: The CBM model to evaluate; only its ``concept_predictor`` is called.
        dataloader: Iterable yielding ``(idx, inputs, (concepts, labels))`` batches.
        fuzzy_loss_fn: The CustomFuzzyLoss function with fuzzy rules. It must expose
            ``fuzzy_rules`` and refresh the ``last_*`` attributes on every call.
        device: Device to run computations on.

    Returns:
        dict: ``standard_loss``, ``fuzzy_loss``, ``total_loss`` (sum of the two),
        ``rule_losses`` (per-rule averages), ``rule_contributions`` (per-rule share of
        the average fuzzy loss; empty when that loss is not positive) and
        ``num_samples``.

    Raises:
        ValueError: If the dataloader yields no samples, so no average can be formed.
    """
    model.eval()

    # Running sums of batch values, each weighted by the number of samples in the batch.
    # NOTE(review): the weighting assumes the loss reports per-batch means - confirm.
    standard_loss_sum = 0.0
    fuzzy_loss_sum = 0.0
    rule_loss_sums = {name: 0.0 for name in fuzzy_loss_fn.fuzzy_rules.keys()}
    num_samples = 0

    with torch.no_grad():
        for _idx, inputs, (concepts, _labels) in tqdm(dataloader, desc="Computing fuzzy loss"):
            inputs = inputs.to(device)
            concepts = concepts.to(device)

            outputs = model.concept_predictor(inputs)

            # The return value is not needed here; the call refreshes the `last_*`
            # attributes that are read below.
            fuzzy_loss_fn(outputs, concepts)

            batch_size = inputs.size(0)
            num_samples += batch_size

            standard_loss_sum += fuzzy_loss_fn.last_standard_loss.item() * batch_size
            fuzzy_loss_sum += fuzzy_loss_fn.last_fuzzy_loss.item() * batch_size

            # `.get` tolerates a rule that the loss reports but `fuzzy_rules` does not list.
            for rule_name, loss_val in fuzzy_loss_fn.last_individual_losses.items():
                rule_loss_sums[rule_name] = (
                    rule_loss_sums.get(rule_name, 0.0) + loss_val.item() * batch_size
                )

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot average fuzzy loss metrics")

    avg_standard_loss = standard_loss_sum / num_samples
    avg_fuzzy_loss = fuzzy_loss_sum / num_samples
    avg_rule_losses = {name: total / num_samples for name, total in rule_loss_sums.items()}

    # Share of each rule in the average fuzzy loss; undefined (left empty) when that loss is zero.
    rule_contributions = {}
    if avg_fuzzy_loss > 0:
        rule_contributions = {
            name: avg_loss / avg_fuzzy_loss for name, avg_loss in avg_rule_losses.items()
        }

    return {
        'standard_loss': avg_standard_loss,
        'fuzzy_loss': avg_fuzzy_loss,
        'total_loss': avg_standard_loss + avg_fuzzy_loss,
        'rule_losses': avg_rule_losses,
        'rule_contributions': rule_contributions,
        'num_samples': num_samples
    }

import pandas as pd


def _relative_improvement_pct(baseline_value, new_value):
    """Return the reduction from ``baseline_value`` to ``new_value`` in percent (0 if the baseline is not positive)."""
    if baseline_value > 0:
        return (baseline_value - new_value) / baseline_value * 100
    return 0


def _print_section_header(title):
    """Print ``title`` framed by two 80-character rules, preceded by an empty line."""
    print("\n" + "="*80)
    print(title)
    print("="*80)


def _print_model_metrics(heading, metrics):
    """Print the aggregate losses and the per-rule losses of one model's metrics dict."""
    print(heading)
    print(f"  Standard BCE Loss:    {metrics['standard_loss']:.10f}")
    print(f"  Fuzzy Rules Loss:     {metrics['fuzzy_loss']:.10f}")
    print(f"  Total Loss:           {metrics['total_loss']:.10f}")
    print(f"\n  Individual Rule Violations:")
    for rule_name, loss in metrics['rule_losses'].items():
        contribution = metrics['rule_contributions'].get(rule_name, 0) * 100
        print(f"    {rule_name:30s}: {loss:.10f} ({contribution:5.2f}% of fuzzy loss)")


# Compute fuzzy loss metrics for both models on the test set
print("Computing fuzzy loss metrics for Baseline CBM...")
baseline_fuzzy_metrics = compute_fuzzy_loss_metrics(
    baseline_cbm,
    test_loader,
    neutral_fuzzy_loss,
    baseline_cbm_config.device
)

print("\nComputing fuzzy loss metrics for Fuzzy CBM...")
fuzzy_cbm_metrics = compute_fuzzy_loss_metrics(
    fuzzy_cbm,
    test_loader,
    neutral_fuzzy_loss,
    fuzzy_cbm_config.device
)

# Display results
_print_section_header("FUZZY LOSS ANALYSIS - LOGICAL CONSISTENCY METRICS")
_print_model_metrics("\nBaseline CBM (trained without fuzzy loss):", baseline_fuzzy_metrics)
_print_model_metrics("\nFuzzy CBM (trained with fuzzy loss):", fuzzy_cbm_metrics)

# Calculate improvements
_print_section_header("IMPROVEMENT ANALYSIS")

fuzzy_loss_reduction = baseline_fuzzy_metrics['fuzzy_loss'] - fuzzy_cbm_metrics['fuzzy_loss']
# Guarded like the per-rule percentages: a baseline fuzzy loss of zero must not raise ZeroDivisionError.
fuzzy_loss_reduction_pct = _relative_improvement_pct(
    baseline_fuzzy_metrics['fuzzy_loss'], fuzzy_cbm_metrics['fuzzy_loss']
)

# One record per rule; used for the per-rule printout and for the comparison table below.
rule_comparison = []
for rule_name, baseline_loss in baseline_fuzzy_metrics['rule_losses'].items():
    fuzzy_loss = fuzzy_cbm_metrics['rule_losses'][rule_name]
    rule_comparison.append({
        'Rule': rule_name,
        'Baseline Loss': baseline_loss,
        'Fuzzy CBM Loss': fuzzy_loss,
        'Improvement': baseline_loss - fuzzy_loss,
        'Improvement %': _relative_improvement_pct(baseline_loss, fuzzy_loss),
    })

print(f"\nOverall Fuzzy Loss Reduction:")
print(f"  Absolute: {fuzzy_loss_reduction:.10f}")
print(f"  Relative: {fuzzy_loss_reduction_pct:.2f}%")

print(f"\nPer-Rule Improvements:")
for row in rule_comparison:
    print(f"  {row['Rule']:30s}: {row['Improvement']:.10f} ({row['Improvement %']:+6.2f}%)")

# Table of the per-rule comparison, largest relative improvement first
_print_section_header("RULE CONTRIBUTION COMPARISON")

df_comparison = pd.DataFrame(rule_comparison)
df_comparison = df_comparison.sort_values('Improvement %', ascending=False)
print(df_comparison.to_string(index=False))

# Summary statistics
_print_section_header("SUMMARY")
print(f"Number of fuzzy rules evaluated: {len(baseline_fuzzy_metrics['rule_losses'])}")
print(f"Number of test samples: {baseline_fuzzy_metrics['num_samples']}")
print(f"\nLogical Consistency Score (lower is better):")
print(f"  Baseline CBM:  {baseline_fuzzy_metrics['fuzzy_loss']:.10f}")
print(f"  Fuzzy CBM:     {fuzzy_cbm_metrics['fuzzy_loss']:.10f}")
print(f"  Improvement:   {fuzzy_loss_reduction_pct:.2f}%")

Computing fuzzy loss metrics for Baseline CBM...


Computing fuzzy loss:   0%|          | 0/99 [00:00<?, ?it/s]

Computing fuzzy loss: 100%|██████████| 99/99 [00:47<00:00,  2.07it/s]



Computing fuzzy loss metrics for Fuzzy CBM...


Computing fuzzy loss: 100%|██████████| 99/99 [00:38<00:00,  2.56it/s]


FUZZY LOSS ANALYSIS - LOGICAL CONSISTENCY METRICS

Baseline CBM (trained without fuzzy loss):
  Standard BCE Loss:    0.0022116409
  Fuzzy Rules Loss:     1.5134086369
  Total Loss:           1.5156202778

  Individual Rule Violations:
    at_most_one_border_colour     : 0.0000008134 ( 0.00% of fuzzy loss)
    exactly_one_main_colour       : 0.7500880942 (49.56% of fuzzy loss)
    exactly_one_shape             : 0.7501135680 (49.56% of fuzzy loss)
    between_two_and_three_numbers : 0.0000612766 ( 0.00% of fuzzy loss)
    no_symbols_exactly_two_colours: 0.0005755201 ( 0.04% of fuzzy loss)

Fuzzy CBM (trained with fuzzy loss):
  Standard BCE Loss:    0.0019492534
  Fuzzy Rules Loss:     1.5135019425
  Total Loss:           1.5154511959

  Individual Rule Violations:
    at_most_one_border_colour     : 0.0000001020 ( 0.00% of fuzzy loss)
    exactly_one_main_colour       : 0.7501242044 (49.56% of fuzzy loss)
    exactly_one_shape             : 0.7501746979 (49.57% of fuzzy loss)
    bet




In [31]:
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_prediction_level_accuracy(model, dataloader, device):
    """
    Evaluate concept predictions at sample level and at concept level.

    Logits from ``model.concept_predictor`` are passed through a sigmoid and
    thresholded at 0.5. A sample counts as correct only if all of its concept
    predictions equal the ground truth.

    Args:
        model: The CBM model to evaluate; only its ``concept_predictor`` is called.
        dataloader: Iterable yielding ``(idx, inputs, (concepts, labels))`` batches.
        device: Device to run computations on.

    Returns:
        dict: ``prediction_accuracy`` (share of fully correct samples),
        ``concept_accuracy`` (share of correct individual concept predictions),
        ``precision``, ``recall`` and ``f1`` over the flattened concept matrix, and
        ``sample_correct`` (one bool per sample, in dataloader order).

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()

    concept_true = []
    concept_pred = []
    sample_correct = []

    with torch.no_grad():
        for _idx, inputs, (concepts, _labels) in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            concepts = concepts.to(device)

            outputs = model.concept_predictor(inputs)

            # Binarise the sigmoid probabilities at 0.5.
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            concept_true.append(concepts.cpu().numpy())
            concept_pred.append(preds.cpu().numpy())

            # One vectorised comparison per batch instead of one `.item()` call per sample.
            sample_correct.extend(torch.all(preds == concepts, dim=1).tolist())

    all_predictions = len(sample_correct)
    if all_predictions == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy metrics")

    all_correct_predictions = sum(sample_correct)
    prediction_accuracy = all_correct_predictions / all_predictions

    concept_true = np.vstack(concept_true)
    concept_pred = np.vstack(concept_pred)

    # Per-concept accuracy over every individual (sample, concept) decision.
    concept_accuracy = np.mean((concept_true == concept_pred).flatten())

    # Precision, recall and F1 over the flattened concept matrix.
    flat_true = concept_true.flatten()
    flat_pred = concept_pred.flatten()
    precision = precision_score(flat_true, flat_pred, zero_division=0)
    recall = recall_score(flat_true, flat_pred, zero_division=0)
    f1 = f1_score(flat_true, flat_pred, zero_division=0)

    return {
        'prediction_accuracy': prediction_accuracy,
        'concept_accuracy': concept_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'sample_correct': sample_correct
    }

# Score both concept predictors on the shared test loader.
print("Evaluating baseline CBM...")
baseline_results = evaluate_prediction_level_accuracy(baseline_cbm, test_loader, baseline_cbm_config.device)

print("Evaluating fuzzy CBM...")
fuzzy_results = evaluate_prediction_level_accuracy(fuzzy_cbm, test_loader, fuzzy_cbm_config.device)

# Same five-line report for each model, baseline first.
print("\n===== RESULTS =====")
for _heading, _results in (("Baseline CBM:", baseline_results), ("\nFuzzy CBM:", fuzzy_results)):
    print(_heading)
    print(f"  Per-prediction accuracy (all concepts correct): {_results['prediction_accuracy']:.4f}")
    print(f"  Per-concept accuracy (as reported before): {_results['concept_accuracy']:.4f}")
    print(f"  Precision: {_results['precision']:.4f}")
    print(f"  Recall: {_results['recall']:.4f}")
    print(f"  F1 Score: {_results['f1']:.4f}")

# Absolute accuracy differences, fuzzy minus baseline.
prediction_improvement = fuzzy_results['prediction_accuracy'] - baseline_results['prediction_accuracy']
concept_improvement = fuzzy_results['concept_accuracy'] - baseline_results['concept_accuracy']

print(f"\nImprovement with Fuzzy Loss:")
print(f"  Per-prediction accuracy: {prediction_improvement:.4f} ({prediction_improvement*100:.2f}%)")
print(f"  Per-concept accuracy: {concept_improvement:.4f} ({concept_improvement*100:.2f}%)")

Evaluating baseline CBM...


Evaluating: 100%|██████████| 99/99 [00:58<00:00,  1.70it/s]


Evaluating fuzzy CBM...


Evaluating: 100%|██████████| 99/99 [00:36<00:00,  2.71it/s]



===== RESULTS =====
Baseline CBM:
  Per-prediction accuracy (all concepts correct): 0.9880
  Per-concept accuracy (as reported before): 0.9995
  Precision: 0.9981
  Recall: 0.9970
  F1 Score: 0.9975

Fuzzy CBM:
  Per-prediction accuracy (all concepts correct): 0.9869
  Per-concept accuracy (as reported before): 0.9995
  Precision: 0.9985
  Recall: 0.9968
  F1 Score: 0.9977

Improvement with Fuzzy Loss:
  Per-prediction accuracy: -0.0010 (-0.10%)
  Per-concept accuracy: 0.0000 (0.00%)


In [32]:
# Build the predefined rule checker: a graph of the concept groups with the rules assigned on top of them.
# Judging by the cell output, constructing the graph also prints the concept-group hierarchy.
from rule_eval import construct_full_graph
rule_checker = construct_full_graph()

- All concepts
   - All colors
      - Main colors
         - Border colors
         - Arrow symbols
   - All shapes
   - All symbols
      - Number symbols
      - General symbols
      - Curve symbols
      - Regulatory signs


In [None]:
# Run every fuzzy-CBM test prediction through the rule checker and remember which samples
# break a rule, together with whatever the checker returned for them.
flagged_indices_graph_semantic_relation_fuzzy = []
constraints_violated_relation_fuzzy = []
for i, pred in enumerate(concept_predictions_fuzzy):
    violated = rule_checker.check_concept_vector(pred, verbose=True, early_stop=True)
    if not violated:
        continue
    flagged_indices_graph_semantic_relation_fuzzy.append(i)
    constraints_violated_relation_fuzzy.append(violated)

# One flagged index per violating sample, so the counter is simply the list length.
counter_graph_semantic_relation_fuzzy = len(flagged_indices_graph_semantic_relation_fuzzy)

# Sanity check for semantic graph
print(f"Number of flagged indices in semantic graph: {len(flagged_indices_graph_semantic_relation_fuzzy)}")

Constraint violated: General concept invariant
Violating concept names: ['shape_round']
Constraint violated: One number
Violating concept names: ['number_0']
Constraint violated: General concept invariant
Violating concept names: ['shape_round', 'number_0']
Constraint violated: One number
Violating concept names: ['number_0']
Constraint violated: One number
Violating concept names: ['number_0']
Constraint violated: One number
Violating concept names: ['number_0']
Constraint violated: General concept invariant
Violating concept names: ['shape_round']
Constraint violated: General concept invariant
Violating concept names: ['shape_round']
Constraint violated: One number
Violating concept names: ['number_0']
Constraint violated: Symbols constraint
Violating concept names: ['number_0', 'number_8', 'symbol_car', 'symbol_diagonal_stripes']
Constraint violated: Shape constraint
Violating concept names: ['shape_round', 'shape_square']
Constraint violated: Shape constraint
Violating concept name

In [34]:
# Evaluate the rule violations of the baseline CBM on the test set: every prediction is passed
# through the rule checker, and violating samples are recorded by index together with the
# checker's result.
counter_graph_semantic_relation_baseline = 0
flagged_indices_graph_semantic_relation_baseline = []
constraints_violated_relation_baseline = []
for i, pred in enumerate(concept_predictions_baseline):
    violated = rule_checker.check_concept_vector(pred, verbose=True, early_stop=True)
    if violated:
        # Count on the baseline counter. Incrementing the fuzzy counter here would inflate the
        # fuzzy CBM's violation count and leave the baseline count at zero.
        counter_graph_semantic_relation_baseline += 1
        flagged_indices_graph_semantic_relation_baseline.append(i)
        constraints_violated_relation_baseline.append(violated)

# Sanity check for semantic graph
print(f"Number of flagged indices in semantic graph: {len(flagged_indices_graph_semantic_relation_baseline)}")

Violating concept names: ['symbol_animal', 'symbol_construction_site']
Constraint violated: General concept invariant
Violating concept names: ['main_color_blue', 'shape_round']
Constraint violated: Shape constraint
Violating concept names: ['shape_round', 'shape_square']
Constraint violated: General concept invariant
Violating concept names: ['shape_triangular']
Constraint violated: General concept invariant
Violating concept names: ['main_color_blue', 'shape_round']
Violating concept names: ['symbol_animal', 'symbol_double_curve']
Constraint violated: One number
Violating concept names: ['number_0']
Constraint violated: General concept invariant
Violating concept names: ['main_color_blue', 'shape_triangular']
Constraint violated: Main color constraint
Violating concept names: ['main_color_white', 'main_color_blue']
Constraint violated: Shape constraint
Violating concept names: ['shape_round', 'shape_triangular']
Constraint violated: One number
Violating concept names: ['number_0']
Co