# Counterfactual RUL Estimation Using Causal Transformers
## A Structural Intervention Framework for Milling Tool Degradation (PHM 2010 Benchmark)

**Author:** Muhammad Umar  
**Affiliation:** University of Ulsan, South Korea

---

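This notebook predicts the remaining useful life (RUL) of milling cutters, measured in cuts, and implements:
1. **Baseline Transformer** — a conventional RUL regressor that serves as the accuracy reference
2. **Causal-Structural Transformer** — an RUL model whose prediction is decomposed into base, operating-condition, and health-indicator (HI) contributions, which enables counterfactual queries
3. **What-If Analysis** — estimates how RUL would change under an alternative operating condition or reduced tool wear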

## 1. Setup and Configuration

In [None]:
import itertools
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch.utils.data import DataLoader

# Import our modules
from phm2010_data_loader import PHM2010DataLoader
from causal_transformer_rul import (
    BaselineTransformer,
    CausalStructuralTransformer,
    PHM2010Dataset,
    train_model,
    evaluate_model,
    visualize_causal_decomposition,
    perform_counterfactual_analysis,
    visualize_counterfactual_results,
    plot_training_curves,
    plot_predictions_comparison,
    device
)

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
%matplotlib inline

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

In [None]:
# Configuration: every knob a reader may want to tune lives in this one dict.
CONFIG = {
    # Data parameters.
    # Set the PHM2010_DATA_PATH environment variable to point at the dataset on
    # another machine instead of editing this machine-specific default.
    'data_path': os.environ.get(
        'PHM2010_DATA_PATH',
        r'E:\Collaboration Work\With Farooq\phm dataset\PHM Challange 2010 Milling',
    ),
    'sequence_length': 20,
    'stride': 1,
    'train_ratio': 0.7,
    'val_ratio': 0.15,  # the test split receives the remaining fraction

    # Model parameters
    'd_model': 128,
    'nhead': 8,
    'num_layers': 4,
    'dropout': 0.1,

    # Training parameters
    'batch_size': 32,
    'num_epochs': 100,
    'learning_rate': 0.001,
    'patience': 15,  # passed to train_model as `patience` (presumably early-stopping epochs)

    # Reproducibility
    'seed': 42,
}

# The loader produces train/val/test splits, so the two explicit ratios must leave
# a non-empty share for the test split.
assert 0 < CONFIG['train_ratio'] + CONFIG['val_ratio'] < 1, (
    "train_ratio + val_ratio must be in (0, 1) to leave room for a test split"
)

# Seed every RNG the notebook touches so Restart & Run All gives repeatable
# weight initialisation and training-batch shuffling.
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['seed'])

print("Configuration loaded:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 2. Load and Prepare Data

In [None]:
# Build the PHM 2010 loader and have it produce the windowed train/val/test splits.
loader = PHM2010DataLoader(CONFIG['data_path'])

# Forward only the data-preparation settings from CONFIG, by keyword.
_prep_setting_names = ('sequence_length', 'stride', 'train_ratio', 'val_ratio')
data_dict = loader.prepare_data(**{setting: CONFIG[setting] for setting in _prep_setting_names})

In [None]:
# Each split is a (sequences, RUL labels, condition ids, health indicators) tuple.
train_seq, train_labels, train_cond, train_hi = data_dict['train']
val_seq, val_labels, val_cond, val_hi = data_dict['val']
test_seq, test_labels, test_cond, test_hi = data_dict['test']

input_dim = train_seq.shape[2]
num_conditions = len(data_dict['condition_mapping'])

# Fail fast on malformed splits instead of deep inside training. Every split must
# hold 3-D sequences (presumably windows x sequence_length x features). It must
# also carry one label / condition / HI entry per window, be non-empty, and share
# the feature width the models are built with.
for _split_name, _split_arrays in (
    ('train', (train_seq, train_labels, train_cond, train_hi)),
    ('val', (val_seq, val_labels, val_cond, val_hi)),
    ('test', (test_seq, test_labels, test_cond, test_hi)),
):
    _split_seq = _split_arrays[0]
    assert _split_seq.ndim == 3, (
        f"{_split_name} sequences must be 3-D, got shape {tuple(_split_seq.shape)}"
    )
    assert len(_split_seq) > 0, f"{_split_name} split is empty"
    assert len({len(part) for part in _split_arrays}) == 1, (
        f"{_split_name} split has mismatched sequence/label/condition/HI lengths"
    )
    assert _split_seq.shape[2] == input_dim, (
        f"{_split_name} feature width {_split_seq.shape[2]} != train width {input_dim}"
    )

print(f"Input dimension: {input_dim}")
print(f"Number of conditions: {num_conditions}")

# Create datasets
train_dataset = PHM2010Dataset(train_seq, train_labels, train_cond, train_hi)
val_dataset = PHM2010Dataset(val_seq, val_labels, val_cond, val_hi)
test_dataset = PHM2010Dataset(test_seq, test_labels, test_cond, test_hi)

# Create data loaders. Val/test are not shuffled, so loader order equals dataset order.
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False)

# drop_last=True discards the final partial batch, so a train split smaller than
# one batch would silently give zero optimisation steps.
assert len(train_loader) > 0, (
    f"train split has {len(train_dataset)} windows, fewer than batch_size={CONFIG['batch_size']}"
)

print("\nDataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

## 3. Train Baseline Transformer

In [None]:
# Baseline regressor. Unlike the causal model below, it is not told the number of
# operating conditions. Architecture sizes come from CONFIG.
_baseline_kwargs = dict(
    input_dim=input_dim,
    d_model=CONFIG['d_model'],
    nhead=CONFIG['nhead'],
    num_layers=CONFIG['num_layers'],
    dropout=CONFIG['dropout'],
)
baseline_model = BaselineTransformer(**_baseline_kwargs).to(device)

n_baseline_params = sum(param.numel() for param in baseline_model.parameters())
print(f"Baseline model parameters: {n_baseline_params:,}")

In [None]:
# Fit the baseline; train_model returns the train and validation loss histories.
baseline_history = train_model(
    baseline_model,
    train_loader,
    val_loader,
    model_type='baseline',
    learning_rate=CONFIG['learning_rate'],
    num_epochs=CONFIG['num_epochs'],
    patience=CONFIG['patience'],
)
baseline_train_losses, baseline_val_losses = baseline_history

In [None]:
# Score the trained baseline on the held-out test split.
baseline_eval = evaluate_model(baseline_model, test_loader, model_type='baseline')
baseline_preds, baseline_actuals, baseline_metrics = baseline_eval

## 4. Train Causal-Structural Transformer

In [None]:
# Causal-structural model. It additionally receives the number of operating
# conditions. enforce_physics=True is presumably the "wear reduces RUL" constraint
# listed in the summary; TODO confirm in causal_transformer_rul.
_causal_kwargs = dict(
    input_dim=input_dim,
    num_conditions=num_conditions,
    d_model=CONFIG['d_model'],
    nhead=CONFIG['nhead'],
    num_layers=CONFIG['num_layers'],
    dropout=CONFIG['dropout'],
    enforce_physics=True,
)
causal_model = CausalStructuralTransformer(**_causal_kwargs).to(device)

n_causal_params = sum(param.numel() for param in causal_model.parameters())
print(f"Causal model parameters: {n_causal_params:,}")

In [None]:
# Fit the causal model with the same loaders and training budget as the baseline.
causal_history = train_model(
    causal_model,
    train_loader,
    val_loader,
    model_type='causal',
    learning_rate=CONFIG['learning_rate'],
    num_epochs=CONFIG['num_epochs'],
    patience=CONFIG['patience'],
)
causal_train_losses, causal_val_losses = causal_history

In [None]:
# Score the trained causal model on the same held-out test split.
causal_eval = evaluate_model(causal_model, test_loader, model_type='causal')
causal_preds, causal_actuals, causal_metrics = causal_eval

## 5. Model Comparison

In [None]:
# Side-by-side test-set metrics for the two models, as returned by evaluate_model.
comparison_df = pd.DataFrame({
    'Model': ['Baseline Transformer', 'Causal-Structural Transformer'],
    'MAE (cuts)': [baseline_metrics['mae'], causal_metrics['mae']],
    'RMSE (cuts)': [baseline_metrics['rmse'], causal_metrics['rmse']],
    'MAPE (%)': [baseline_metrics['mape'], causal_metrics['mape']]
})

print("\nModel Performance Comparison:")
print(comparison_df.to_string(index=False))

# Plot only the two metrics measured in cuts. MAPE is a percentage, so it stays in
# the table above rather than sharing this y-axis. Everything goes through the
# explicit Figure/Axes objects rather than pyplot's implicit "current axes".
fig, ax = plt.subplots(figsize=(10, 6))
comparison_df.set_index('Model')[['MAE (cuts)', 'RMSE (cuts)']].plot(kind='bar', ax=ax, rot=0)
for bar_group in ax.containers:
    ax.bar_label(bar_group, fmt='%.2f')  # print exact values on top of each bar
ax.set_ylabel('Error (cuts)')
ax.set_title('Model Performance Comparison', fontsize=14, weight='bold')
ax.legend()
ax.grid(alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# Overlay the loss histories of both models.
baseline_history_pair = (baseline_train_losses, baseline_val_losses)
causal_history_pair = (causal_train_losses, causal_val_losses)
plot_training_curves(baseline_history_pair, causal_history_pair)

In [None]:
# Plot predictions comparison. Only baseline_actuals is passed as ground truth, so
# first confirm the causal evaluation saw identical targets. It should, because
# test_loader is not shuffled.
np.testing.assert_allclose(
    np.asarray(baseline_actuals),
    np.asarray(causal_actuals),
    err_msg="baseline and causal evaluations returned different test targets",
)
plot_predictions_comparison(baseline_preds, causal_preds, baseline_actuals)

## 6. Causal Decomposition Analysis

In [None]:
# Break the causal model's RUL prediction into base, condition and HI components
# for a handful of test samples, and plot them.
base_rul, cond_eff, hi_eff, total_rul, true_rul = visualize_causal_decomposition(
    causal_model, test_loader, num_samples=6
)

# Mean of each returned component.
print("\nAverage Causal Contributions:")
for _component_name, _component_values in (
    ("Base RUL", base_rul),
    ("Condition Effect", cond_eff),
    ("HI Effect", hi_eff),
    ("Total RUL", total_rul),
    ("True RUL", true_rul),
):
    print(f"  {_component_name}: {np.mean(_component_values):.2f} cuts")

## 7. Counterfactual Analysis - What-If Scenarios

In [None]:
# Run the batch of what-if interventions on 20 test samples.
cf_results = perform_counterfactual_analysis(causal_model, test_loader, num_samples=20)

# Show the first 15 result rows.
print("\nCounterfactual Analysis Results:")
_cf_preview = cf_results.head(15)
print(_cf_preview)

In [None]:
# Visualize counterfactual results: plots the cf_results DataFrame returned by
# perform_counterfactual_analysis in the previous cell.
visualize_counterfactual_results(cf_results)

In [None]:
# Summarise which interventions help most.
# Columns are guarded with `in cf_results.columns`. cf_results.get(col) returns
# None for a missing column, so `cf_results[None == ...]` becomes cf_results[False],
# which raises KeyError instead of skipping the section.
_reverse_mapping = data_dict['reverse_mapping']

# Rows where the operating condition was changed.
if 'new_condition' in cf_results.columns:
    condition_changes = cf_results[cf_results['new_condition'].notna()]
else:
    condition_changes = cf_results.iloc[0:0]

if len(condition_changes) > 0:
    avg_delta_by_condition = condition_changes.groupby('new_condition')['delta_rul'].mean()

    print("\nAverage RUL Change by Condition:")
    for cond, delta in avg_delta_by_condition.items():
        # Map the model's condition index back to the dataset's label when known.
        orig_cond = _reverse_mapping.get(int(cond), int(cond))
        print(f"  Condition {orig_cond}: {delta:+.2f} cuts")

    best_condition = avg_delta_by_condition.idxmax()
    best_gain = avg_delta_by_condition.max()
    orig_best = _reverse_mapping.get(int(best_condition), int(best_condition))

    print(f"\n→ Best Operating Condition: {orig_best}")
    # Signed, like the per-condition lines: the "best" change can still be a loss.
    print(f"  Average RUL gain: {best_gain:+.2f} cuts")

# Rows for the 20% health-indicator (wear) reduction intervention.
if 'intervention_type' in cf_results.columns:
    hi_reduction = cf_results[cf_results['intervention_type'] == 'HI_reduction_20%']
else:
    hi_reduction = cf_results.iloc[0:0]

if len(hi_reduction) > 0:
    avg_hi_benefit = hi_reduction['delta_rul'].mean()
    print("\n→ 20% Wear Reduction Benefit:")
    print(f"  Average RUL gain: {avg_hi_benefit:.2f} cuts")
    # Only draw the conclusion when the data actually supports it.
    if avg_hi_benefit > 0:
        print("  This demonstrates the value of improved maintenance")

## 8. Interactive Counterfactual Query

In [None]:
# Interactive example: Single sample counterfactual
def query_counterfactual(sample_idx=0, new_condition=None, hi_reduction_percent=0):
    """
    Query the causal model for one test sample under a hypothetical intervention.

    The factual prediction uses the sample's observed operating condition and
    health indicator (HI). The counterfactual optionally swaps the condition
    and/or scales the HI down. Both predictions and their difference are printed.

    Args:
        sample_idx: Index into the test split, 0 <= sample_idx < len(test dataset).
            test_loader is not shuffled, so this is also the sample's position in
            loader order.
        new_condition: Original (unmapped) operating-condition label to switch to,
            e.g. 1, 4 or 6. None keeps the sample's own condition.
        hi_reduction_percent: Percentage by which to scale the HI down, in [0, 100].
            0 leaves the HI unchanged.

    Returns:
        The dict returned by causal_model.counterfactual_predict. Its
        'factual_rul', 'counterfactual_rul' and 'delta_rul' entries are printed.

    Raises:
        IndexError: sample_idx is outside the test split.
        ValueError: new_condition is not a known condition label, or
            hi_reduction_percent is outside [0, 100].
    """
    n_test = len(test_loader.dataset)
    if not 0 <= sample_idx < n_test:
        raise IndexError(f"sample_idx must be in [0, {n_test}), got {sample_idx}")
    if not 0 <= hi_reduction_percent <= 100:
        raise ValueError(f"hi_reduction_percent must be in [0, 100], got {hi_reduction_percent}")

    # Fetch the batch that actually contains sample_idx. Always reading the first
    # batch broke for any index >= batch_size. The mapping is valid because
    # test_loader uses shuffle=False, so dataset order equals loader order.
    batch_idx, offset = divmod(sample_idx, test_loader.batch_size)
    batch = next(itertools.islice(iter(test_loader), batch_idx, None))
    seq = batch['sequence'][offset:offset + 1].to(device)
    orig_cond = batch['condition'][offset:offset + 1].to(device)
    hi = batch['health_indicator'][offset:offset + 1].to(device)
    true_rul = batch['label'][offset].item()

    # Map the user-facing condition label to the model's condition index. Reject
    # unknown labels instead of silently falling back to index 0.
    condition_mapping = data_dict['condition_mapping']
    if new_condition is not None:
        if new_condition not in condition_mapping:
            raise ValueError(
                f"Unknown condition {new_condition!r}; expected one of {list(condition_mapping)}"
            )
        new_cond_tensor = torch.tensor([condition_mapping[new_condition]], device=device)
    else:
        new_cond_tensor = None

    # Scale the health indicator down by the requested percentage.
    new_hi = hi * (1 - hi_reduction_percent / 100) if hi_reduction_percent > 0 else None

    # Inference only. eval() keeps the configured dropout inactive so repeated
    # queries agree, and no_grad() avoids building an autograd graph.
    causal_model.eval()
    with torch.no_grad():
        cf_result = causal_model.counterfactual_predict(
            seq, orig_cond, hi,
            new_condition=new_cond_tensor,
            new_hi=new_hi
        )

    # Fall back to the raw index if the condition has no label (matches the insight cell).
    orig_cond_idx = orig_cond.item()
    orig_cond_name = data_dict['reverse_mapping'].get(orig_cond_idx, orig_cond_idx)
    delta = cf_result['delta_rul'].item()

    print("="*60)
    print("COUNTERFACTUAL QUERY RESULTS")
    print("="*60)
    print("\nOriginal Scenario:")
    print(f"  Condition: {orig_cond_name}")
    print(f"  Health Indicator: {hi.item():.4f}")
    print(f"  Predicted RUL: {cf_result['factual_rul'].item():.2f} cuts")
    print(f"  True RUL: {true_rul:.2f} cuts")

    print("\nIntervention:")
    if new_condition is not None:
        print(f"  → Changed condition to: {new_condition}")
    if hi_reduction_percent > 0:
        print(f"  → Reduced wear by: {hi_reduction_percent}%")
    if new_condition is None and hi_reduction_percent == 0:
        print("  (none: counterfactual inputs equal the factual ones)")

    print("\nCounterfactual Scenario:")
    print(f"  Predicted RUL: {cf_result['counterfactual_rul'].item():.2f} cuts")
    print(f"  RUL Change: {delta:+.2f} cuts")

    # Three-way verdict: a zero change is neither an extension nor a reduction.
    if delta > 0:
        print("\n✓ This intervention would EXTEND tool life!")
    elif delta < 0:
        print("\n✗ This intervention would REDUCE tool life.")
    else:
        print("\n= This intervention would not change the predicted tool life.")

    print("="*60)

    return cf_result

# Two illustrative queries: a condition switch, then a wear reduction.
_example_rule = "=" * 60

print("Example 1: What if we switch to Condition 4?")
result1 = query_counterfactual(sample_idx=0, new_condition=4)

print(f"\n{_example_rule}\n")

print("Example 2: What if we reduce wear by 30%?")
result2 = query_counterfactual(sample_idx=1, hi_reduction_percent=30)

## 9. Summary and Conclusions

In [None]:
# Plain-text report of the results computed in the cells above.
_REPORT_WIDTH = 70


def _print_report_section(title):
    """Print a section title on a fresh line followed by a dashed underline."""
    print(f"\n{title}")
    print("-" * _REPORT_WIDTH)


print("=" * _REPORT_WIDTH)
print("SUMMARY OF RESULTS")
print("=" * _REPORT_WIDTH)

_print_report_section("1. MODEL PERFORMANCE")
print(comparison_df.to_string(index=False))

_print_report_section("2. CAUSAL INSIGHTS")
for _metric_name, _metric_values in (
    ("Base RUL", base_rul),
    ("Condition Effect", cond_eff),
    ("HI Effect", hi_eff),
):
    print(f"Average {_metric_name}: {np.mean(_metric_values):.2f} cuts")

_print_report_section("3. NOVEL CONTRIBUTIONS")
for _contribution in (
    "Structurally interpretable RUL decomposition",
    "Counterfactual 'What-If' analysis capability",
    "Physics-informed constraints (wear reduces RUL)",
    "Decision support for process optimization",
):
    print(f"✓ {_contribution}")

_print_report_section("4. PRACTICAL APPLICATIONS")
for _application in (
    "Optimize operating conditions for tool life extension",
    "Quantify benefits of maintenance interventions",
    "Support data-driven manufacturing decisions",
    "Enable proactive rather than reactive PHM",
):
    print(f"→ {_application}")

print("\n" + "=" * _REPORT_WIDTH)

## 10. Save Models and Results

In [None]:
# Save trained weights, per-sample test predictions and counterfactual scenarios.
# Files go to CONFIG['output_dir'] if that key is set, otherwise to the current
# working directory (the previous behaviour).
output_dir = CONFIG.get('output_dir', '.')
os.makedirs(output_dir, exist_ok=True)

# Save models
torch.save(baseline_model.state_dict(), os.path.join(output_dir, 'baseline_transformer_final.pth'))
torch.save(causal_model.state_dict(), os.path.join(output_dir, 'causal_transformer_final.pth'))

# Only one 'actual_rul' column is written, so both models must have been scored
# against identical targets (test_loader is not shuffled).
np.testing.assert_allclose(
    np.asarray(baseline_actuals).ravel(),
    np.asarray(causal_actuals).ravel(),
    err_msg="baseline and causal evaluations returned different test targets",
)

# Save results: one row per test sample. ravel() lets (N, 1)-shaped outputs fill
# a column (pandas rejects 2-D arrays as column values) and is a no-op for 1-D.
results_df = pd.DataFrame({
    'baseline_predictions': np.asarray(baseline_preds).ravel(),
    'causal_predictions': np.asarray(causal_preds).ravel(),
    'actual_rul': np.asarray(baseline_actuals).ravel()
})
results_df.to_csv(os.path.join(output_dir, 'rul_predictions.csv'), index=False)

# Save counterfactual results
cf_results.to_csv(os.path.join(output_dir, 'counterfactual_scenarios.csv'), index=False)

print("Models and results saved successfully!")