# Causal Transparency Framework - MIMIC-III Example

This notebook demonstrates the application of the Causal Transparency Framework (CTF) to the MIMIC-III clinical dataset for mortality prediction.

## Overview

The CTF provides a structured approach to evaluating and enhancing model transparency through causal reasoning. In this example, we'll:

1. Load and preprocess the MIMIC-III dataset
2. Discover causal structure
3. Train predictive models (causal and standard)
4. Calculate transparency metrics
5. Generate visualizations and reports

Together, these steps let us examine the tradeoff between predictive performance and transparency in clinical prediction tasks.
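
To make the "causal versus standard" comparison in step 3 concrete, here is a minimal sketch that does not use the `ctf` package. It fits one logistic regression on every feature and another on only the features assumed to be direct causes of the outcome, then compares AUC. The toy data, the feature split, and the helper names (`fit_logistic`, `predict_proba`, `auc_score`) are illustrative assumptions; CTF's own feature selection and models may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000

# Toy data: two features drive the outcome, three are pure noise.
X = rng.normal(size=(n, 5))
logits = 1.5 * X[:, 0] - 1.0 * X[:, 1]
y = rng.binomial(1, 1 / (1 + np.exp(-logits)))

causal_cols = [0, 1]          # assumed direct causes (hypothetical split)
full_cols = list(range(5))    # every available feature

def fit_logistic(X, y, lr=0.1, n_iter=500):
    """Fit logistic regression by plain gradient descent; returns weights with intercept first."""
    Xb = np.column_stack([np.ones(len(X)), X])
    w = np.zeros(Xb.shape[1])
    for _ in range(n_iter):
        p = 1 / (1 + np.exp(-Xb @ w))
        w -= lr * Xb.T @ (p - y) / len(y)
    return w

def predict_proba(w, X):
    """Predicted probability of the positive class."""
    Xb = np.column_stack([np.ones(len(X)), X])
    return 1 / (1 + np.exp(-Xb @ w))

def auc_score(y_true, scores):
    """AUC as the probability that a random positive outranks a random negative."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return greater + 0.5 * ties

# Simple 80/20 split by position (rows are already in random order)
train, test = slice(0, 1600), slice(1600, n)

results = {}
for name, cols in [("causal", causal_cols), ("full", full_cols)]:
    w = fit_logistic(X[train][:, cols], y[train])
    scores = predict_proba(w, X[test][:, cols])
    results[name] = {"auc": auc_score(y[test], scores), "n_features": len(cols)}

for name, r in results.items():
    print(f"{name}: AUC = {r['auc']:.3f} using {r['n_features']} features")
```

In this toy setup the extra features carry no signal, so the two models should score almost identically. On real clinical data, non-causal correlates can carry predictive value, which is the tradeoff the rest of the notebook measures.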

In [None]:
import sys
from notebook_utils import add_ctf_to_path

In [None]:
# Add repository root to path
add_ctf_to_path()

In [None]:
# Import necessary libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Import CTF components
from ctf.framework import CausalTransparencyFramework
from ctf.causal_discovery import CausalDiscovery
from ctf.transparency_metrics import TransparencyMetrics

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

In [None]:
# Path to the processed MIMIC-III dataset
data_path = "../data/mimic_processed_for_ctf.csv"

In [None]:
# Check if the file exists
if not os.path.exists(data_path):
    print(f"Warning: {data_path} not found.")
    print("Please download the processed MIMIC-III dataset or update the path.")
    
    # For demonstration purposes, we'll create a small synthetic dataset
    print("Creating synthetic dataset for demonstration...")
    
    # Set random seed for reproducibility
    np.random.seed(42)
    
    # Create synthetic data
    n_samples = 1000
    
    # Create features
    age = np.random.normal(65, 15, n_samples)
    gender_m = np.random.binomial(1, 0.6, n_samples)
    
    # Create SOFA score (influenced by age)
    sofa = 0.2 * age + np.random.normal(0, 3, n_samples)
    sofa = np.clip(sofa, 0, 24).astype(int)
    
    # Create lab values (influenced by SOFA)
    lactate = 0.3 * sofa + np.random.normal(0, 1, n_samples)
    lactate = np.clip(lactate, 0, 15)
    
    creatinine = 0.1 * sofa + 0.01 * age + np.random.normal(0, 0.5, n_samples)
    creatinine = np.clip(creatinine, 0.3, 7)
    
    # Create mortality (target - influenced by age, SOFA, lactate)
    logits = -5 + 0.02 * age + 0.3 * sofa + 0.4 * lactate
    p_mortality = 1 / (1 + np.exp(-logits))
    mortality = np.random.binomial(1, p_mortality)
    
    # Create DataFrame
    df = pd.DataFrame({
        'age': age,
        'gender_m': gender_m,
        'sofa_score': sofa,
        'lactate': lactate,
        'creatinine': creatinine,
        'heart_rate': np.random.normal(85, 20, n_samples),
        'respiratory_rate': np.random.normal(18, 5, n_samples),
        'wbc': np.random.normal(10, 4, n_samples),
        'mortality': mortality
    })
    
    # Save synthetic data
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    df.to_csv(data_path, index=False)
    print(f"Synthetic dataset created and saved to {data_path}")
else:
    # Load the dataset (on reruns this may be the synthetic file saved above)
    df = pd.read_csv(data_path)
    print(f"Loaded dataset from {data_path} with {df.shape[0]} rows and {df.shape[1]} columns")

In [None]:
# Explore the dataset
print("Dataset columns:")
print(df.columns.tolist())

print("\nDataset summary:")
df.describe()

In [None]:
# Check target distribution
print("\nTarget (mortality) distribution:")
print(df['mortality'].value_counts(normalize=True))

In [None]:
# Visualize target distribution
plt.figure(figsize=(8, 5))
sns.countplot(x='mortality', data=df)
plt.title('Mortality Distribution')
plt.xlabel('Mortality')
plt.ylabel('Count')
plt.show()

In [None]:
# Initialize the framework
ctf = CausalTransparencyFramework(
    data_path=data_path,
    target_col="mortality",
    output_dir="../results/mimic_iii",
    random_state=42
)

In [None]:
# Add domain knowledge for clinical data
domain_knowledge = {
    "edges": [
        # Clinical knowledge about mortality predictors
        # Format: [source, target, weight]
        ["age", "mortality", 0.8],
        ["sofa_score", "mortality", 0.9],
        ["lactate", "mortality", 0.7],
        ["creatinine", "mortality", 0.6],
        
        # Feature relationships
        ["age", "creatinine", 0.3],
        ["sofa_score", "lactate", 0.5]
    ]
}
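
As a sanity check on a prior like this, the edge list can be compared against the equations that generate the synthetic fallback dataset earlier in the notebook. The sketch below uses only the standard library. `generator_edges` is read off those equations by hand, so it applies to the synthetic data only and says nothing about real MIMIC-III records.

```python
from graphlib import TopologicalSorter, CycleError

# Prior edges, copied from the domain_knowledge cell: (source, target, weight)
prior = [
    ("age", "mortality", 0.8),
    ("sofa_score", "mortality", 0.9),
    ("lactate", "mortality", 0.7),
    ("creatinine", "mortality", 0.6),
    ("age", "creatinine", 0.3),
    ("sofa_score", "lactate", 0.5),
]
prior_edges = {(src, dst) for src, dst, _ in prior}

# Edges implied by the synthetic generator's equations (read off by hand)
generator_edges = {
    ("age", "sofa_score"),
    ("sofa_score", "lactate"),
    ("sofa_score", "creatinine"),
    ("age", "creatinine"),
    ("age", "mortality"),
    ("sofa_score", "mortality"),
    ("lactate", "mortality"),
}

# A causal graph must be acyclic; TopologicalSorter raises CycleError otherwise.
parents = {}
for src, dst in prior_edges:
    parents.setdefault(dst, set()).add(src)
try:
    order = list(TopologicalSorter(parents).static_order())
    is_dag = True
except CycleError:
    order, is_dag = [], False

missing = sorted(generator_edges - prior_edges)  # generator edges the prior omits
extra = sorted(prior_edges - generator_edges)    # prior edges absent from the generator

print("Prior is a DAG:", is_dag)
print("Missing from prior:", missing)
print("Not in generator:", extra)
```

On the synthetic data, the prior asserts a creatinine-to-mortality edge that the generator does not contain, and it omits the age-to-SOFA and SOFA-to-creatinine edges, so some disagreement between prior and discovered structure is expected there.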

In [None]:
# Discover causal structure
G = ctf.discover_causal_structure(domain_knowledge=domain_knowledge)
print(f"Causal graph discovered with {len(G.nodes())} nodes and {len(G.edges())} edges")

In [None]:
# Train predictive models
models = ctf.train_models(test_size=0.2)
print(f"Trained {len(models)} models")

In [None]:
# Calculate transparency metrics
metrics = ctf.calculate_transparency_metrics()
print("Transparency metrics calculated")

In [None]:
# Generate report
report_path = ctf.generate_report()
print(f"CTF report generated at {report_path}")

In [None]:
# Examine Causal Influence Index (CII)
if 'cii' in metrics:
    # Sort explicitly so that "top" does not depend on the dict's insertion order
    cii_ranked = sorted(metrics['cii'].items(), key=lambda kv: kv[1], reverse=True)

    print("Top 5 features by Causal Influence Index (CII):")
    for i, (feature, score) in enumerate(cii_ranked[:5], start=1):
        print(f"{i}. {feature}: {score:.4f}")

    # Visualize CII
    plt.figure(figsize=(10, 6))

    features = [f for f, _ in cii_ranked[:10]]  # Top 10 features
    scores = [s for _, s in cii_ranked[:10]]
    
    sns.barplot(x=scores, y=features)
    plt.title('Top Features by Causal Influence Index')
    plt.xlabel('CII')
    plt.tight_layout()
    plt.show()

In [None]:
# Compare model performance
model_names = list(ctf.model_performance.keys())
accuracy = [ctf.model_performance[m]['accuracy'] for m in model_names]
auc = [ctf.model_performance[m]['auc'] for m in model_names]
f1 = [ctf.model_performance[m]['f1'] for m in model_names]
n_features = [ctf.model_performance[m]['n_features'] for m in model_names]

In [None]:
# Create DataFrame
performance_df = pd.DataFrame({
    'Model': model_names,
    'Accuracy': accuracy,
    'AUC': auc,
    'F1': f1,
    'Features': n_features,
    'Type': ['Causal' if m.startswith('causal_') else 'Full' for m in model_names]
})

In [None]:
# Display as table
performance_df.sort_values('AUC', ascending=False)

In [None]:
# Compare transparency metrics across models
if 'te' in metrics and 'cs' in metrics:
    transparency_df = pd.DataFrame({
        'Model': model_names,
        'TE': [metrics['te'][m].get('te', 0) for m in model_names],
        'CS': [metrics['cs'][m].get('overall', 0) for m in model_names],
        'Type': ['Causal' if m.startswith('causal_') else 'Full' for m in model_names]
    })
    
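    # Display as table (a bare expression inside an `if` block is not rendered)
    from IPython.display import display
    display(transparency_df.sort_values('TE', ascending=False))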

In [None]:
# Scatterplot of AUC vs. TE. Figure creation, plotting, and labels share one
# cell so that everything is drawn on the same figure.

# Merge on both keys so the shared 'Type' column is not suffixed
combined_df = pd.merge(performance_df, transparency_df, on=['Model', 'Type'])

plt.figure(figsize=(10, 6))
sns.scatterplot(data=combined_df, x='AUC', y='TE', hue='Type', size='Features',
                sizes=(100, 400), alpha=0.7)

# Add labels
for _, row in combined_df.iterrows():
    plt.text(row['AUC'] + 0.005, row['TE'] + 0.005, row['Model'])

plt.title('Performance vs. Transparency Tradeoff')
plt.xlabel('AUC (performance)')
plt.ylabel('Transparency Entropy (interpretability)')
plt.grid(alpha=0.3)
plt.show()

In [None]:
# Extract key findings
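# Sort by score so the top three do not depend on the dict's insertion order
top_cii_features = sorted(metrics['cii'].items(), key=lambda kv: kv[1], reverse=True)[:3]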
best_causal_model = max([m for m in model_names if m.startswith('causal_')], 
                        key=lambda m: ctf.model_performance[m]['auc'])
best_full_model = max([m for m in model_names if not m.startswith('causal_')], 
                      key=lambda m: ctf.model_performance[m]['auc'])

causal_auc = ctf.model_performance[best_causal_model]['auc']
full_auc = ctf.model_performance[best_full_model]['auc']

causal_te = metrics['te'][best_causal_model]['te']
full_te = metrics['te'][best_full_model]['te']

causal_cs = metrics['cs'][best_causal_model]['overall']
full_cs = metrics['cs'][best_full_model]['overall']

In [None]:
# Print key findings
print("### Key Clinical Findings ###\n")

print("1. Causal Drivers of Mortality:")
for feature, cii in top_cii_features:
    print(f"   - {feature} (CII: {cii:.4f})")

print(f"\n2. Model Performance Comparison:")
print(f"   - Best causal model ({best_causal_model}): AUC = {causal_auc:.4f}")
print(f"   - Best full model ({best_full_model}): AUC = {full_auc:.4f}")
print(f"   - Performance gap (full - causal): {(full_auc - causal_auc) * 100:.2f} percentage points of AUC")

print(f"\n3. Transparency Metrics:")
print(f"   - Transparency Entropy (TE): Causal = {causal_te:.4f}, Full = {full_te:.4f}")
print(f"   - Counterfactual Stability (CS): Causal = {causal_cs:.4f}, Full = {full_cs:.4f}")

print("\n4. Clinical Implications:")
if causal_auc >= 0.95 * full_auc:
    print("   - The causal model with fewer features performs nearly as well as the full model")
    print("   - This suggests that focusing on key causal factors may be sufficient for clinical use")
else:
    print("   - The full model substantially outperforms the causal model")
    print("   - This suggests that non-causal correlations provide important predictive value")
    
if causal_te > full_te:
    print("   - The causal model offers greater transparency, making it more interpretable for clinicians")
else:
    print("   - Despite using more features, the full model offers better interpretability")
    
if causal_cs > full_cs:
    print("   - The causal model provides more stable predictions under perturbations")
    print("   - This suggests greater reliability when input data has small variations")
else:
    print("   - The full model provides more stable predictions, possibly due to redundant features")
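
The stability comparison above depends on how CTF defines Counterfactual Stability, which this notebook does not show. As a rough intuition only, the sketch below measures how far a model's predicted probabilities move when each input is jittered by a small fraction of its standard deviation. The `perturbation_stability` helper, the 5% noise scale, and the toy logistic model (which reuses the coefficients of the synthetic generator) are assumptions for illustration, not the framework's implementation.

```python
import numpy as np

def perturbation_stability(predict, X, noise_scale=0.05, n_repeats=20, seed=0):
    """Return 1 minus the mean absolute change in predicted probability under Gaussian noise.

    Illustrative only: 1.0 means the predictions did not move at all.
    """
    rng = np.random.default_rng(seed)
    base = predict(X)
    sigma = X.std(axis=0) * noise_scale  # per-feature noise level
    shifts = []
    for _ in range(n_repeats):
        X_noisy = X + rng.normal(0.0, 1.0, size=X.shape) * sigma
        shifts.append(np.abs(predict(X_noisy) - base).mean())
    return 1.0 - float(np.mean(shifts))

def toy_model(X):
    """Hypothetical mortality model over columns (age, SOFA score, lactate)."""
    age, sofa, lactate = X[:, 0], X[:, 1], X[:, 2]
    logits = -5 + 0.02 * age + 0.3 * sofa + 0.4 * lactate
    return 1 / (1 + np.exp(-logits))

# Toy inputs with plausible scales; not drawn from MIMIC-III
rng = np.random.default_rng(42)
X = np.column_stack([
    rng.normal(65, 15, 500),   # age
    rng.normal(13, 4, 500),    # SOFA score
    rng.normal(4, 1.5, 500),   # lactate
])

s_small = perturbation_stability(toy_model, X)
s_large = perturbation_stability(toy_model, X, noise_scale=0.5)
print(f"Stability at 5% noise:  {s_small:.4f}")
print(f"Stability at 50% noise: {s_large:.4f}")
```

A score like this is only comparable between models when both are evaluated on the same inputs with the same noise scale.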

## Conclusion

In this notebook, we applied the Causal Transparency Framework to the MIMIC-III dataset for mortality prediction. The framework provided valuable insights into:

1. The causal structure underlying mortality prediction
2. The key causal drivers of mortality risk
3. The performance-transparency tradeoff between causal and full models
4. The interpretability and stability of different modeling approaches

These insights can guide clinicians and data scientists in developing more transparent and reliable clinical prediction models.