# Getting Started with Stable-CART Methods

This notebook shows how to use stable-cart tree methods and how to measure their prediction stability with bootstrap variance. We compare two stable tree implementations (`LessGreedyHybridTree` and `RobustPrefixHonestTree`) against a standard CART baseline.

## Why Stability Matters
Standard decision trees are **notoriously unstable**: small changes in the training data can produce completely different trees and predictions. Stable-CART methods modify the tree-building process so that trees are **more consistent** across training runs, which is crucial for:
- **Model reliability** in production
- **Interpretability** and trust  
- **Reduced overfitting** to specific training samples
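
To make the instability claim concrete, here is a small sketch that is independent of stable-cart. It fits a one-level CART-style "stump" to bootstrap resamples of synthetic data in which two features are about equally informative, and counts which feature is chosen at the root. The helper names (`gini`, `best_root_feature`) and the synthetic data are assumptions made for illustration only.

```python
import numpy as np

def gini(y):
    """Gini impurity of a binary 0/1 label vector."""
    if y.size == 0:
        return 0.0
    p = y.mean()
    return 2.0 * p * (1.0 - p)

def best_root_feature(X, y):
    """Return the feature index whose best threshold split gives the
    lowest weighted Gini impurity (a one-level CART-style stump)."""
    n, d = X.shape
    best_feat, best_score = 0, np.inf
    for j in range(d):
        # Candidate thresholds: deciles of the feature
        for t in np.quantile(X[:, j], np.linspace(0.1, 0.9, 9)):
            left = X[:, j] <= t
            score = (left.sum() * gini(y[left]) + (~left).sum() * gini(y[~left])) / n
            if score < best_score:
                best_feat, best_score = j, score
    return best_feat

rng = np.random.default_rng(0)
n = 300
signal = rng.normal(size=n)
# Two noisy copies of the same signal: both are roughly equally useful for splitting
X = np.column_stack([signal + 0.5 * rng.normal(size=n),
                     signal + 0.5 * rng.normal(size=n)])
y = (signal > 0).astype(int)

root_choices = []
for _ in range(30):
    idx = rng.integers(0, n, size=n)  # bootstrap resample (with replacement)
    root_choices.append(best_root_feature(X[idx], y[idx]))

counts = np.bincount(root_choices, minlength=2)
print(f"Root split chose feature 0 in {counts[0]} and feature 1 in {counts[1]} of 30 resamples")
```

When the root feature changes between resamples, everything below it in a full tree changes too, which is the kind of instability the stable-cart methods aim to reduce.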

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler

from stable_cart import LessGreedyHybridTree, BootstrapVariancePenalizedTree, RobustPrefixHonestTree

plt.style.use('seaborn-v0_8')
np.random.seed(42)

print("🎯 Stable-CART Stability Analysis")
print("Comparing stable tree methods vs standard CART")

## Dataset: Digits Binary Classification

We use the digits dataset with binary classification (digit 0 vs all others) because:
- **High dimensionality** (64 features) where stability can be important
- **Real-world data** with natural complexity
- **Clear class separation** allowing measurement of stability effects

In [None]:
# Load and prepare data
data = load_digits()
X, y = data.data, (data.target == 0).astype(int)

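# Split first so the scaler never sees the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Standardize features for stable-cart methods (fit on training data only to avoid leakage)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)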

print(f"Dataset characteristics:")
print(f"  Features: {X.shape[1]}")
print(f"  Training samples: {X_train.shape[0]}")
print(f"  Test samples: {X_test.shape[0]}")
print(f"  Class distribution: {dict(zip(*np.unique(y_test, return_counts=True)))}")
print(f"  Task: Binary classification (digit 0 vs others)")

## Stability Measurement Methodology

We measure **prediction stability** using bootstrap variance:
1. Train the same model on multiple bootstrap samples of the training data
2. Make predictions on the same test set with each trained model
3. Calculate the variance of the predicted class labels for each test point, then summarize across test points (mean, median, max)
4. Lower variance = more stable predictions across training runs

This simulates the real-world scenario where slight differences in training data (new samples, missing values, etc.) can lead to different model behavior.
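
The steps above can be sketched on a toy prediction matrix. The numbers are made up for illustration. Because the predictions are hard 0/1 labels, the per-point variance equals `p * (1 - p)`, where `p` is the share of bootstrap models that predict class 1, so it can never exceed 0.25.

```python
import numpy as np

# Toy example: 5 bootstrap models (rows) x 4 test points (columns), hard 0/1 labels
predictions = np.array([
    [0, 1, 1, 0],
    [0, 1, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 0, 1],
    [0, 1, 1, 0],
])

# Step 3: variance across models for each test point
point_variances = predictions.var(axis=0)

# For binary labels this equals p * (1 - p), with p the share of models voting 1
p = predictions.mean(axis=0)
assert np.allclose(point_variances, p * (1 - p))

mean_variance = point_variances.mean()
print(point_variances)  # per-point variance
print(mean_variance)    # summary stability metric (lower = more stable)
```

The first two test points get the same label from every model, so their variance is zero. The third point, where the models split 3 to 2, has variance 0.6 × 0.4 = 0.24.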

In [None]:
def measure_prediction_stability(model_class, X_train, y_train, X_test, y_test, model_params, n_bootstrap=15):
    """
    Measure prediction stability via bootstrap variance.
    
    Returns:
        dict with variance statistics and performance metrics
    """
    n_test = X_test.shape[0]
    predictions = np.zeros((n_bootstrap, n_test))
    aucs = []
    
    for i in range(n_bootstrap):
        # Bootstrap sample
        n_samples = X_train.shape[0]
        bootstrap_idx = np.random.choice(n_samples, n_samples, replace=True)
        X_boot = X_train[bootstrap_idx]
        y_boot = y_train[bootstrap_idx]
        
        # Train model
        params_copy = model_params.copy()
        params_copy.pop('random_state', None)  # Avoid conflicts
        model = model_class(**params_copy, random_state=i)
        model.fit(X_boot, y_boot)
        
        # Predictions
        y_pred = model.predict(X_test)
        predictions[i] = y_pred
        
        # Performance
        if hasattr(model, 'predict_proba'):
            y_proba = model.predict_proba(X_test)[:, 1]
            aucs.append(roc_auc_score(y_test, y_proba))
    
    # Calculate stability metrics
    point_variances = np.var(predictions, axis=0)
    
    return {
        'mean_variance': np.mean(point_variances),
        'median_variance': np.median(point_variances),
        'max_variance': np.max(point_variances),
        'point_variances': point_variances,
        'predictions': predictions,
        'aucs': aucs,
        'auc_mean': np.mean(aucs) if aucs else np.nan,  # NaN when the model has no predict_proba
        'auc_std': np.std(aucs) if aucs else np.nan
    }

print("✅ Stability measurement function ready")
print("📊 Will use 15 bootstrap samples per model for analysis")

## Model Configurations

We test different stable-cart methods against a standard CART baseline:
- **CART Baseline**: Standard decision tree for comparison
- **LessGreedyHybrid**: Uses data partitioning for more cautious splitting
- **RobustPrefixHonest**: Uses consensus and honest estimation for stability
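
The `split_frac`, `val_frac`, and `est_frac` parameters in the configurations below suggest a three-way partition of the training data. "Honest" estimation generally means that leaf values are estimated on rows that were not used to choose the splits. The sketch below shows one way such a partition could be drawn. `partition_indices` is a hypothetical helper, not part of the stable-cart API, and the library's internal partitioning may differ.

```python
import numpy as np

def partition_indices(n_samples, split_frac, val_frac, est_frac, seed=0):
    """Hypothetical helper: shuffle row indices and cut them into three
    disjoint subsets for choosing splits, validating them, and
    estimating leaf values."""
    if abs(split_frac + val_frac + est_frac - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    n_split = int(round(split_frac * n_samples))
    n_val = int(round(val_frac * n_samples))
    split_idx = idx[:n_split]
    val_idx = idx[n_split:n_split + n_val]
    est_idx = idx[n_split + n_val:]  # remainder goes to estimation
    return split_idx, val_idx, est_idx

split_idx, val_idx, est_idx = partition_indices(1000, 0.95, 0.03, 0.02)
print(len(split_idx), len(val_idx), len(est_idx))
```

With fractions like 0.95 / 0.03 / 0.02, the estimation subset is small, so a training set of roughly 1,250 rows would leave only about 25 rows for estimating leaf values under this scheme.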

In [None]:
# Model configurations for comparison
model_configs = {
    'CART_Baseline': {
        'class': DecisionTreeClassifier,
        'params': {
            'max_depth': 10,
            'min_samples_leaf': 1,
            'random_state': 42
        },
        'description': 'Standard decision tree baseline'
    },
    
    'LessGreedy_Method': {
        'class': LessGreedyHybridTree,
        'params': {
            'task': 'classification',
            'max_depth': 12,
            'min_samples_leaf': 2,
            'split_frac': 0.95,     # 95% data for splitting
            'val_frac': 0.03,       # 3% for validation
            'est_frac': 0.02,       # 2% for estimation
            'enable_oblique_splits': False,
            'enable_lookahead': False,
            'random_state': 42
        },
        'description': 'Less greedy approach with data partitioning'
    },
    
    'RobustPrefix_Method': {
        'class': RobustPrefixHonestTree,
        'params': {
            'task': 'classification',
            'max_depth': 10,
            'min_samples_leaf': 2,
            'top_levels': 2,          # Robust prefix levels
            'consensus_samples': 3,   # Consensus mechanism
            'val_frac': 0.05,
            'est_frac': 0.03,
            'random_state': 42
        },
        'description': 'Robust prefix with consensus and honest estimation'
    }
}

print(f"📋 Configured {len(model_configs)} models for comparison:")
for name, config in model_configs.items():
    print(f"   {name}: {config['description']}")

## Running the Stability Analysis

We now run the stability analysis, measuring both:
- **Single-run performance** (accuracy, AUC)
- **Bootstrap stability** (variance across training runs)

⏱️ *This may take a few minutes as we train multiple models.*

In [None]:
print("🔄 Running stability analysis...")
print("This trains 15 bootstrap models per configuration")

results = {}

for name, config in model_configs.items():
    print(f"\n📊 Analyzing {name}...")
    
    # Single run performance
    model = config['class'](**config['params'])
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    
    # Get performance metrics
    accuracy = accuracy_score(y_test, y_pred)
    balanced_acc = balanced_accuracy_score(y_test, y_pred)
    
    if hasattr(model, 'predict_proba'):
        y_proba = model.predict_proba(X_test)[:, 1]
        auc = roc_auc_score(y_test, y_proba)
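    else:
        # No probabilities available: with hard 0/1 labels as scores,
        # ROC AUC equals balanced accuracy, so use that instead
        auc = balanced_acc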
    
    unique_preds = len(np.unique(y_pred))
    
    print(f"   Single run: AUC={auc:.3f}, Accuracy={accuracy:.3f}, Unique predictions={unique_preds}")
    
    if unique_preds > 1 and balanced_acc > 0.8:  # Working model
        # Measure stability
        print(f"   Measuring stability (15 bootstrap samples)...")
        stability = measure_prediction_stability(
            config['class'], X_train, y_train, X_test, y_test, config['params']
        )
        
        results[name] = {
            'auc': auc,
            'accuracy': accuracy,
            'balanced_acc': balanced_acc,
            'stability': stability,
            'description': config['description'],
            'working': True
        }
        
        print(f"   ✅ Variance: {stability['mean_variance']:.4f}, AUC std: ±{stability['auc_std']:.3f}")
    else:
        print("   ❌ Model not working properly")
        results[name] = {
            'working': False,
            'description': config['description']
        }

print(f"\n✅ Analysis complete! Found {sum(1 for r in results.values() if r.get('working', False))} working models")

## Results: Stability vs Accuracy Analysis

Let's analyze the trade-off between prediction stability and predictive performance (AUC), measured relative to the CART baseline.
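
The comparison uses two relative metrics: the percentage reduction in prediction variance and the AUC as a percentage of the baseline AUC. The sketch below restates both formulas as standalone functions with made-up inputs. The function names and the zero-variance guard are additions for illustration and are not part of the notebook's code.

```python
def variance_improvement(baseline_variance, variance):
    """Percent reduction in prediction variance relative to the baseline.
    Positive values mean the model is more stable than the baseline."""
    if baseline_variance == 0:
        return float("nan")  # baseline is already perfectly stable; ratio undefined
    return (1 - variance / baseline_variance) * 100

def performance_ratio(baseline_auc, auc):
    """AUC expressed as a percentage of the baseline AUC."""
    return auc / baseline_auc * 100

# Made-up numbers, purely for illustration
improvement = variance_improvement(0.020, 0.015)
ratio = performance_ratio(0.95, 0.93)
print(f"Variance improvement: {improvement:+.1f}%")
print(f"Performance ratio: {ratio:.1f}% of baseline")
```

In this made-up case the variance drops by a quarter while the AUC stays within about 2% of the baseline. A negative variance improvement would mean the model is less stable than the baseline.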

In [None]:
# Filter working models
working_results = {name: res for name, res in results.items() if res.get('working', False)}

if 'CART_Baseline' not in working_results or len(working_results) < 2:
    print("❌ Need a working CART baseline and at least one working stable model for comparison")
else:
    print("🏆 STABILITY vs ACCURACY ANALYSIS")
    print("=" * 60)
    
    # Baseline metrics
    baseline = working_results['CART_Baseline']
    baseline_auc = baseline['auc']
    baseline_variance = baseline['stability']['mean_variance']
    
    print("\n📊 BASELINE (CART):")
    print(f"   AUC: {baseline_auc:.3f}")
    print(f"   Prediction Variance: {baseline_variance:.4f}")
    
    # Stable model comparison
    stable_models = {name: res for name, res in working_results.items() if name != 'CART_Baseline'}
    
    print("\n🔄 STABLE-CART MODELS:")
    
    comparison_data = []
    
    for name, res in stable_models.items():
        variance = res['stability']['mean_variance']
        auc = res['auc']
        
        # Calculate improvements
        variance_improvement = (1 - variance / baseline_variance) * 100
        performance_ratio = (auc / baseline_auc) * 100
        
        comparison_data.append({
            'Model': name.replace('_', ' '),
            'AUC': auc,
            'Variance': variance,
            'Variance_Improvement': variance_improvement,
            'Performance_Ratio': performance_ratio,
            'Description': res['description']
        })
        
        print(f"\n   {name}:")
        print(f"      AUC: {auc:.3f} ({performance_ratio:.1f}% of baseline)")
        print(f"      Variance: {variance:.4f} ({variance_improvement:+.1f}% vs baseline)")
        print(f"      {res['description']}")
        
        if variance_improvement > 25:
            print("      🎉 OUTSTANDING variance reduction!")
        elif variance_improvement > 15:
            print("      🏆 EXCELLENT variance reduction!")
        elif variance_improvement > 5:
            print("      ✅ GOOD variance reduction")
        else:
            print("      ⚠️ Marginal improvement")
    
    # Create comparison DataFrame
    df_comparison = pd.DataFrame(comparison_data)
    df_comparison = df_comparison.sort_values('Variance_Improvement', ascending=False)
    
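    print("\n📊 SUMMARY TABLE:")
    summary_cols = ['Model', 'AUC', 'Variance_Improvement', 'Performance_Ratio']
    # Format per column: a single decimal would round every AUC to 0.9 or 1.0
    print(df_comparison[summary_cols].to_string(
        index=False,
        formatters={'AUC': '{:.3f}'.format,
                    'Variance_Improvement': '{:+.1f}'.format,
                    'Performance_Ratio': '{:.1f}'.format}
    ))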

## Visualization: Stability vs Accuracy Trade-off
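
The scatter plot in the next cell converts mean prediction variance into a "stability score" with `1 / (1 + variance)`, so that higher is better on both axes. A short sketch of that mapping follows. The function name `stability_score` is introduced here for illustration only.

```python
import numpy as np

def stability_score(mean_variance):
    """Map a mean prediction variance to a score where higher means more stable."""
    return 1.0 / (1.0 + np.asarray(mean_variance, dtype=float))

# Variance of 0/1 labels is p * (1 - p), so it can never exceed 0.25
variances = np.array([0.0, 0.01, 0.05, 0.25])
scores = stability_score(variances)
print(scores)
```

Since the variance of binary labels is bounded by 0.25, the score always falls between 0.8 and 1.0. Models can therefore look close together on the x-axis even when their variances differ by a large relative margin, which is why the bar chart of raw variances is shown alongside the scatter plot.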

In [None]:
# Create visualizations
if len(working_results) >= 2:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#92AA83']
    
    # 1. Stability vs Accuracy scatter plot
    ax1 = axes[0]
    
    for i, (name, res) in enumerate(working_results.items()):
        stability_score = 1 / (1 + res['stability']['mean_variance'])
        auc = res['auc']
        
        ax1.scatter(stability_score, auc, s=300, color=colors[i % len(colors)], 
                   alpha=0.8, edgecolors='black', linewidth=2, label=name.replace('_', ' '))
    
    ax1.set_xlabel('Stability Score (Higher = More Stable)', fontsize=12)
    ax1.set_ylabel('AUC Score (Higher = Better)', fontsize=12)
    ax1.set_title('Stability vs Accuracy Trade-off', fontsize=14, fontweight='bold')
    ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax1.grid(True, alpha=0.3)
    
    # 2. Variance comparison bar chart
    ax2 = axes[1]
    names = [name.replace('_', '\n') for name in working_results.keys()]
    variances = [res['stability']['mean_variance'] for res in working_results.values()]
    
    bars = ax2.bar(range(len(names)), variances, color=colors[:len(names)], alpha=0.8, edgecolor='black')
    ax2.set_ylabel('Prediction Variance (Lower = Better)', fontsize=12)
    ax2.set_title('Model Prediction Variance', fontsize=14, fontweight='bold')
    ax2.set_xticks(range(len(names)))
    ax2.set_xticklabels(names, rotation=45, ha='right')
    
    # Add values on bars
    for bar, var in zip(bars, variances):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(variances)*0.01,
                f'{var:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
    
    # 3. Performance comparison
    ax3 = axes[2]
    aucs = [res['auc'] for res in working_results.values()]
    
    bars = ax3.bar(range(len(names)), aucs, color=colors[:len(names)], alpha=0.8, edgecolor='black')
    ax3.set_ylabel('AUC Score (Higher = Better)', fontsize=12)
    ax3.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
    ax3.set_xticks(range(len(names)))
    ax3.set_xticklabels(names, rotation=45, ha='right')
    ax3.set_ylim(min(aucs) * 0.98, max(aucs) * 1.02)
    
    # Add values on bars
    for bar, auc in zip(bars, aucs):
        ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + (max(aucs) - min(aucs))*0.01,
                f'{auc:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
    
    plt.suptitle('Stable-CART: Stability vs Accuracy Analysis', 
                 fontsize=16, fontweight='bold', y=1.05)
    plt.tight_layout()
    plt.show()