# Causal Conditional Flow Matching: A Deep Dive Tutorial

This notebook provides a comprehensive walkthrough of the Causal Conditional Flow Matching (C-CFM) framework. We'll explore each component in detail, examining the mathematical foundations and visualizing intermediate results.

## Contents

1. **Introduction to Flow Matching**
2. **CTree-Lite: Statistical Regime Detection**
3. **Data Processing and Causal Discovery**
4. **The Masked Velocity Network**
5. **Training with Optimal Transport Paths**
6. **Scenario Generation and Stress Testing**
7. **Full Pipeline Demonstration**

In [None]:
# Setup and imports
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import chi2, rankdata

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Reproducibility
np.random.seed(42)
torch.manual_seed(42)

print("Setup complete!")

---
## 1. Introduction to Flow Matching

Flow Matching is a framework for training continuous normalizing flows (CNFs) without solving ODEs during training. The key insight is that a velocity field $v_\theta(x, t)$ can be fit by direct regression onto a known target velocity; integrating the learned field then transports samples from a simple prior (e.g., Gaussian noise) to the data distribution.

### The Optimal Transport Path

Given:
- $x_0 \sim \mathcal{N}(0, I)$ - noise
- $x_1 \sim p_{data}$ - real data

Pairing each noise sample with a data sample, the (conditional) OT interpolation path is the straight line between them:
$$x_t = (1-t) \cdot x_0 + t \cdot x_1$$

And the target velocity along this path is:
$$u_t = \frac{dx_t}{dt} = x_1 - x_0$$
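
As a quick numerical check of the two formulas above, the following sketch builds $x_t$ for one noise/data pair and confirms by finite differences that the path's time derivative equals $x_1 - x_0$ at several values of $t$. The helper name `ot_path` and the sample values are illustrative assumptions, not part of the C-CFM codebase.

```python
import numpy as np

def ot_path(x0, x1, t):
    """Return the point x_t on the straight-line path and its velocity u_t."""
    x_t = (1.0 - t) * x0 + t * x1
    u_t = x1 - x0  # constant in t for a straight-line path
    return x_t, u_t

rng = np.random.default_rng(0)
x0 = rng.standard_normal(2)     # noise sample
x1 = np.array([3.0, 3.0])       # stand-in "data" sample

# Central finite difference of x_t at a few times
h = 1e-6
for t in (0.1, 0.5, 0.9):
    x_plus, u_t = ot_path(x0, x1, t + h)
    x_minus, _ = ot_path(x0, x1, t - h)
    fd_velocity = (x_plus - x_minus) / (2 * h)
    assert np.allclose(fd_velocity, u_t, atol=1e-6)
```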

In [None]:
# Visualize the OT path
def visualize_ot_path(n_samples=5, n_steps=10):
    """Visualize how samples flow from noise to data."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Generate noise and "data" (for visualization, use 2D)
    x_0 = np.random.randn(n_samples, 2)  # Noise
    x_1 = np.random.randn(n_samples, 2) * 0.5 + np.array([3, 3])  # "Data"
    
    # Time steps
    t_values = np.linspace(0, 1, n_steps)
    
    # Plot paths
    ax = axes[0]
    colors = plt.cm.viridis(np.linspace(0, 1, n_samples))
    for i in range(n_samples):
        path = [(1-t) * x_0[i] + t * x_1[i] for t in t_values]
        path = np.array(path)
        ax.plot(path[:, 0], path[:, 1], 'o-', color=colors[i], alpha=0.7, markersize=4)
        ax.scatter([x_0[i, 0]], [x_0[i, 1]], s=100, c='blue', marker='o', zorder=5)
        ax.scatter([x_1[i, 0]], [x_1[i, 1]], s=100, c='red', marker='x', zorder=5)
    ax.set_title('OT Interpolation Paths\n(Blue=Noise, Red=Data)')
    ax.set_xlabel('Dimension 1')
    ax.set_ylabel('Dimension 2')
    
    # Plot velocity field
    ax = axes[1]
    for i in range(n_samples):
        velocity = x_1[i] - x_0[i]  # Constant along path
        for t in t_values[::2]:
            x_t = (1-t) * x_0[i] + t * x_1[i]
            ax.arrow(x_t[0], x_t[1], velocity[0]*0.1, velocity[1]*0.1,
                    head_width=0.1, head_length=0.05, fc=colors[i], ec=colors[i], alpha=0.5)
    ax.set_title('Target Velocity Field\n$u_t = x_1 - x_0$')
    ax.set_xlabel('Dimension 1')
    ax.set_ylabel('Dimension 2')
    
    # Plot marginals at different times
    ax = axes[2]
    x_0_many = np.random.randn(500, 2)
    x_1_many = np.random.randn(500, 2) * 0.5 + np.array([3, 3])
    
    for t, color in zip([0, 0.5, 1], ['blue', 'purple', 'red']):
        x_t = (1-t) * x_0_many + t * x_1_many
        ax.scatter(x_t[:, 0], x_t[:, 1], alpha=0.3, s=5, c=color, label=f't={t}')
    ax.legend()
    ax.set_title('Distribution Evolution\nt=0 (Blue) → t=1 (Red)')
    ax.set_xlabel('Dimension 1')
    ax.set_ylabel('Dimension 2')
    
    plt.tight_layout()
    plt.show()

visualize_ot_path()

---
## 2. CTree-Lite: Statistical Regime Detection

Unlike standard decision trees (CART), which split on impurity reduction, CTree-Lite splits only when a **statistical significance test** rejects independence between a candidate variable and the response. This guards against overfitting to noise, a common risk with financial data.

### The Strasser-Weber Framework

For each candidate variable $X_j$, we test:
$$H_0: X_j \perp Y \text{ (independence)}$$

The test statistic is computed using rank transformations for robustness:

1. Compute ranks: $h(X_j) = \text{rank}(X_j)$
2. Linear statistic: $T = Y^\top \cdot h(X_j)$
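3. Standardize: $S = (T - \mu_T)^\top \Sigma_T^{-1} (T - \mu_T)$, where $\mu_T$ and $\Sigma_T$ are the mean and covariance of $T$ under $H_0$
4. P-value: $p = 1 - F_{\chi^2_q}(S)$, where $q$ is the degrees of freedom of the statistic ($q = 1$ for a single ranked predictor and a scalar response)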

We only split if $p < \alpha$ (typically 0.05).
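
For a single predictor and a scalar response, the four steps above reduce to a closed form. The sketch below is one way to write that special case; it assumes the permutation null for $\mu_T$ and $\Sigma_T$, and `rank_independence_test` is a hypothetical helper, not the `core.ctree` implementation.

```python
import numpy as np
from scipy.stats import chi2, rankdata

def rank_independence_test(x, y):
    """Rank-based independence test for 1-D x and y (illustrative sketch).

    Returns the standardized statistic S and its chi-squared p-value.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    n = x.size

    h = rankdata(x)                    # step 1: rank transform
    T = h @ y                          # step 2: linear statistic
    mu = n * h.mean() * y.mean()       # mean of T when y is randomly permuted
    var = ((h - h.mean()) ** 2).sum() * ((y - y.mean()) ** 2).sum() / (n - 1)
    S = (T - mu) ** 2 / var            # step 3: standardized (quadratic) statistic
    p = chi2.sf(S, df=1)               # step 4: 1 - F(S) with one degree of freedom
    return S, p

rng = np.random.default_rng(42)
x = np.linspace(-2, 2, 300)
y = np.where(x < 0, -1.0, 1.0) + 0.3 * rng.standard_normal(300)
S, p = rank_independence_test(x, y)
print(f"S = {S:.1f}, p = {p:.2e}")
```

In this univariate case $S$ equals $(n-1)\,r^2$, where $r$ is the Pearson correlation between the ranks of $x$ and the values of $y$, which gives a simple sanity check.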

In [None]:
from core.ctree import CTree

# Generate data with clear regime structure
np.random.seed(42)
n = 300

# Create features
X = np.random.randn(n, 3)
X[:, 0] = np.linspace(-2, 2, n)  # Sorted for visualization

# Create response that depends on X[:, 0]
# Regime 1: X[:, 0] < 0 → Y centered at -1
# Regime 2: X[:, 0] >= 0 → Y centered at 1
Y = np.where(X[:, 0] < 0, -1 + 0.3*np.random.randn(n), 1 + 0.3*np.random.randn(n))
Y = Y.reshape(-1, 1)

# Fit CTree
ctree = CTree(alpha=0.05, min_split=20)
ctree.fit(X, Y, feature_names=['X1', 'X2', 'X3'])

# Get predictions
regimes = ctree.predict(X)

print("CTree-Lite Results:")
print(f"Number of regimes: {ctree.n_regimes}")
print(f"\nTree structure:")
print(ctree.print_tree())

In [None]:
# Visualize regime detection
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Data colored by true regime
ax = axes[0]
true_regime = (X[:, 0] >= 0).astype(int)
scatter = ax.scatter(X[:, 0], Y.ravel(), c=true_regime, cmap='coolwarm', alpha=0.5)
ax.axvline(x=0, color='black', linestyle='--', label='True boundary')
ax.set_xlabel('X1')
ax.set_ylabel('Y')
ax.set_title('Data with True Regimes')
ax.legend()

# Plot 2: Data colored by detected regime
ax = axes[1]
scatter = ax.scatter(X[:, 0], Y.ravel(), c=regimes, cmap='coolwarm', alpha=0.5)
ax.set_xlabel('X1')
ax.set_ylabel('Y')
ax.set_title(f'CTree Detected Regimes ({ctree.n_regimes} found)')

# Plot 3: Show independence test p-values
ax = axes[2]
# Compute p-values for each variable
p_values = [ctree._test_independence(X[:, j], Y) for j in range(3)]
bars = ax.bar(['X1', 'X2', 'X3'], p_values, color=['red' if p < 0.05 else 'blue' for p in p_values])
ax.axhline(y=0.05, color='red', linestyle='--', label='α = 0.05')
ax.set_ylabel('P-value')
ax.set_title('Independence Test P-values\n(Red = Significant)')
ax.legend()

plt.tight_layout()
plt.show()

print(f"\nP-values: X1={p_values[0]:.4f}, X2={p_values[1]:.4f}, X3={p_values[2]:.4f}")
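# Report which variables actually pass the test instead of asserting the outcome
significant = [name for name, p in zip(['X1', 'X2', 'X3'], p_values) if p < 0.05]
print(f"Significant at α = 0.05: {significant} (X1 drives Y by construction)")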

---
## 3. Data Processing and Causal Discovery

The ETL pipeline performs several critical steps:

1. **Stationarity Check** (ADF test) - Non-stationary series are differenced
2. **Imputation** (Cubic spline) - Missing values are interpolated
3. **Normalization** (Z-score) - Required for stable training
4. **Causal Discovery** (LiNGAM) - Determines variable ordering
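
Steps 2 and 3 can be illustrated on a single series with SciPy's `CubicSpline`. This is a minimal sketch under the assumption that gaps are marked with NaN; `impute_and_standardize` is a hypothetical name, and the actual `DataProcessor` may handle edge cases (leading or trailing gaps, constant series) differently.

```python
import numpy as np
from scipy.interpolate import CubicSpline

def impute_and_standardize(series):
    """Fill NaN gaps with a cubic spline, then z-score the result."""
    y = np.asarray(series, dtype=float)
    idx = np.arange(y.size)
    known = ~np.isnan(y)

    spline = CubicSpline(idx[known], y[known])   # fit on observed points only
    filled = np.where(known, y, spline(idx))     # keep observed values untouched

    z = (filled - filled.mean()) / filled.std()  # zero mean, unit variance
    return filled, z

# Smooth toy series with two missing observations
truth = np.sin(np.linspace(0, 3, 50))
observed = truth.copy()
observed[[10, 25]] = np.nan

filled, z = impute_and_standardize(observed)
print(f"max imputation error: {np.abs(filled - truth).max():.2e}")
```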

### Why Causal Ordering Matters

In our framework:
- **Slow variables** (macro: GDP, CPI) are upstream (cause)
- **Fast variables** (market: VIX, returns) are downstream (effect)

The masked network ensures slow variables cannot be influenced by fast variables!
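
The coarse slow-before-fast constraint can be written as a permutation of the columns. The sketch below is illustrative: the variable names are made up, `slow_first_order` is a hypothetical helper, and the ordering inside each block is simply the input order rather than anything LiNGAM would estimate.

```python
import numpy as np

def slow_first_order(names, slow_vars):
    """Return (perm, rank) placing slow variables ahead of fast ones.

    perm[k] is the original column placed at causal position k;
    rank[i] is the causal position of original column i.
    """
    slow = set(slow_vars)
    slow_idx = [i for i, name in enumerate(names) if name in slow]
    fast_idx = [i for i, name in enumerate(names) if name not in slow]
    perm = np.array(slow_idx + fast_idx)
    rank = np.empty_like(perm)
    rank[perm] = np.arange(perm.size)   # invert the permutation
    return perm, rank

names = ["SPX", "VIX", "HY_SPREAD", "GDP", "CPI", "FED_FUNDS"]  # hypothetical columns
perm, rank = slow_first_order(names, slow_vars=["GDP", "CPI", "FED_FUNDS"])
print([names[i] for i in perm])
```

Reordering a data matrix is then `X[:, perm]`, while `rank` is the per-variable position that a mask construction can compare against.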

In [None]:
from core.etl import DataProcessor, validate_for_ode
from examples.data_fetcher import create_sample_dataset

# Load sample data
X_raw, fast_vars, slow_vars = create_sample_dataset()

print(f"Raw data shape: {X_raw.shape}")
print(f"\nFast variables (market): {fast_vars}")
print(f"\nSlow variables (macro): {slow_vars}")

In [None]:
# Process data
processor = DataProcessor(
    adf_threshold=0.05,
    ctree_alpha=0.10,
    ctree_min_split=50
)

topology = processor.fit_transform(
    X_raw,
    fast_vars=fast_vars,
    slow_vars=slow_vars
)

print("\n" + "="*50)
print("Data Topology Results")
print("="*50)
print(f"Processed samples: {topology.X_processed.shape[0]}")
print(f"Number of variables: {topology.X_processed.shape[1]}")
print(f"Detected regimes: {topology.n_regimes}")
print(f"\nCausal ordering (first 10):")
for i, name in enumerate(topology.variable_names[:10]):
    var_type = "SLOW" if i < len(topology.slow_indices) else "FAST"
    print(f"  {i}: {name} ({var_type})")

In [None]:
# Visualize the data processing effects
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Raw data distribution
ax = axes[0, 0]
ax.hist(X_raw[:, 0], bins=50, alpha=0.7, label='Var 0 (raw)')
ax.hist(X_raw[:, -1], bins=50, alpha=0.7, label='Var -1 (raw)')
ax.set_xlabel('Value')
ax.set_ylabel('Count')
ax.set_title('Raw Data Distribution')
ax.legend()

# Plot 2: Processed (normalized) data
ax = axes[0, 1]
ax.hist(topology.X_processed[:, 0], bins=50, alpha=0.7, label='Var 0 (normalized)')
ax.hist(topology.X_processed[:, -1], bins=50, alpha=0.7, label='Var -1 (normalized)')
ax.set_xlabel('Value (Z-score)')
ax.set_ylabel('Count')
ax.set_title('Normalized Data Distribution')
ax.legend()

# Plot 3: Regime distribution
ax = axes[1, 0]
unique, counts = np.unique(topology.regimes, return_counts=True)
ax.bar(unique, counts, color=plt.cm.viridis(np.linspace(0, 1, len(unique))))
ax.set_xlabel('Regime')
ax.set_ylabel('Count')
ax.set_title(f'Regime Distribution ({topology.n_regimes} regimes)')

# Plot 4: Correlation heatmap (showing causal structure)
ax = axes[1, 1]
corr = np.corrcoef(topology.X_processed.T)
im = ax.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1)
ax.set_title('Correlation Matrix (Causal Order)')
ax.set_xlabel('Variable Index')
ax.set_ylabel('Variable Index')
# Mark slow/fast boundary
n_slow = len(topology.slow_indices)
ax.axhline(y=n_slow-0.5, color='yellow', linewidth=2, label='Slow/Fast boundary')
ax.axvline(x=n_slow-0.5, color='yellow', linewidth=2)
ax.legend(loc='upper right')  # Without this call the boundary label is never drawn
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

---
## 4. The Masked Velocity Network

The velocity network $v_\theta(x_t, t, \text{regime})$ has a special structure:

### Architecture
- **Input**: State $x_t$, time embedding, regime embedding
- **Backbone**: 4-layer Residual MLP with SiLU activation
- **Conditioning**: FiLM layers for regime/time modulation
- **Masking**: MADE-style masks enforce causal structure
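
FiLM conditioning applies a feature-wise scale and shift computed from the conditioning vector. The sketch below shows the operation in NumPy; the shapes, the `1 + ...` parameterization of the scale, and the name `film` are assumptions for illustration, and the layers in `core.network` may be parameterized differently.

```python
import numpy as np

def film(h, cond, W_gamma, W_beta):
    """Feature-wise linear modulation: scale and shift h using cond.

    h:    (batch, hidden) hidden activations
    cond: (batch, cond_dim) conditioning vector (e.g. time and regime embeddings)
    """
    gamma = 1.0 + cond @ W_gamma   # scale, centered at 1 so zero weights leave h unchanged
    beta = cond @ W_beta           # shift
    return gamma * h + beta

rng = np.random.default_rng(0)
h = rng.standard_normal((4, 8))
cond = rng.standard_normal((4, 3))
W_gamma = 0.1 * rng.standard_normal((3, 8))
W_beta = 0.1 * rng.standard_normal((3, 8))

out = film(h, cond, W_gamma, W_beta)
print(out.shape)
```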

### The Causal Mask

The mask $M$ ensures the Jacobian $\partial v / \partial x$ is lower-triangular once the variables are sorted by causal order:

$$M_{ij} = \begin{cases} 1 & \text{if } \text{order}(j) \leq \text{order}(i) \\ 0 & \text{otherwise} \end{cases}$$

This means:
- Slow variables can influence fast variables ✓
- Fast variables **cannot** influence slow variables ✗
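
The mask definition translates directly into a broadcasted comparison. The sketch below is a NumPy illustration, independent of `create_causal_mask` in `core.network`; it assumes `order[i]` gives the causal position of variable `i`.

```python
import numpy as np

def causal_mask(order):
    """M[i, j] = 1 if variable j is at or before variable i in causal order."""
    order = np.asarray(order)
    return (order[None, :] <= order[:, None]).astype(int)

order = np.array([3, 4, 5, 0, 1, 2])   # variables 3, 4, 5 (slow) come first
M = causal_mask(order)

# Sorting rows and columns by causal position exposes the lower-triangular structure
perm = np.argsort(order)
M_sorted = M[np.ix_(perm, perm)]
assert np.array_equal(M_sorted, np.tril(np.ones_like(M)))

# Entry [0, 3]: slow variable 3 feeding fast variable 0 (allowed)
# Entry [3, 0]: fast variable 0 feeding slow variable 3 (blocked)
print(M[0, 3], M[3, 0])
```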

In [None]:
from core.network import VelocityNetwork, create_causal_mask

# Create a small network for visualization
state_dim = 6
causal_order = np.array([3, 4, 5, 0, 1, 2])  # Slow (3,4,5) before Fast (0,1,2)

# Visualize the causal mask
mask = create_causal_mask(dim=state_dim, causal_order=causal_order)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Plot the mask
ax = axes[0]
im = ax.imshow(mask.numpy(), cmap='Blues')
ax.set_xticks(range(state_dim))
ax.set_yticks(range(state_dim))
ax.set_xticklabels([f'In {i}\n(order {causal_order[i]})' for i in range(state_dim)])
ax.set_yticklabels([f'Out {i}\n(order {causal_order[i]})' for i in range(state_dim)])
ax.set_title('Causal Mask\n(1 = connection allowed)')

for i in range(state_dim):
    for j in range(state_dim):
        color = 'white' if mask[i, j] > 0.5 else 'black'
        ax.text(j, i, int(mask[i, j].item()), ha='center', va='center', color=color)

# Show causal graph
ax = axes[1]
# Draw nodes
slow_x = [1, 2, 3]
fast_x = [1, 2, 3]
slow_y = [2, 2, 2]
fast_y = [0, 0, 0]

ax.scatter(slow_x, slow_y, s=500, c='blue', zorder=5)
ax.scatter(fast_x, fast_y, s=500, c='red', zorder=5)

# Labels
for i, x in enumerate(slow_x):
    ax.text(x, slow_y[0], f'Slow\n{i}', ha='center', va='center', color='white', fontweight='bold')
for i, x in enumerate(fast_x):
    ax.text(x, fast_y[0], f'Fast\n{i}', ha='center', va='center', color='white', fontweight='bold')

# Arrows (slow → fast only)
for sx in slow_x:
    for fx in fast_x:
        ax.annotate('', xy=(fx, 0.3), xytext=(sx, 1.7),
                   arrowprops=dict(arrowstyle='->', color='green', lw=2, alpha=0.5))

ax.set_xlim(0, 4)
ax.set_ylim(-0.5, 2.5)
ax.set_title('Causal Graph\n(Green arrows = allowed influence)')
ax.axis('off')

plt.tight_layout()
plt.show()

print("Key insight: Fast variables (bottom) receive information from Slow variables (top),")
print("but NOT the other way around!")

In [None]:
# Create and inspect the velocity network
net = VelocityNetwork(
    state_dim=topology.X_processed.shape[1],
    hidden_dim=128,
    n_regimes=topology.n_regimes,
    n_layers=4,
    causal_order=topology.causal_order
)

print("Velocity Network Architecture:")
print(net)
print(f"\nTotal parameters: {sum(p.numel() for p in net.parameters()):,}")

---
## 5. Training with Optimal Transport Paths

### The CFM Training Objective

$$\mathcal{L} = \mathbb{E}_{t, x_0, (x_1, r)} \left[ \| v_\theta(x_t, t, r) - u_t \|^2 \right]$$

Where:
- $t \sim \text{Uniform}(0, 1)$
- $x_0 \sim \mathcal{N}(0, I)$
- $(x_1, r) \sim p_{data}$, a data sample together with its regime label
- $x_t = (1-t) x_0 + t x_1$
- $u_t = x_1 - x_0$

This is a simple **regression problem** - no ODE solving during training!
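
Written out, one Monte Carlo estimate of this objective takes only a few lines. The sketch below uses NumPy and a plain callable in place of the network; `cfm_loss` and the toy velocity functions are illustrative assumptions, not the `FlowMatchingTrainer` internals.

```python
import numpy as np

def cfm_loss(v_fn, x1, r, rng):
    """Single-batch estimate of the CFM loss for data x1 with regime labels r."""
    n, d = x1.shape
    t = rng.uniform(0.0, 1.0, size=(n, 1))   # t ~ Uniform(0, 1)
    x0 = rng.standard_normal((n, d))         # x0 ~ N(0, I)
    x_t = (1.0 - t) * x0 + t * x1            # point on the straight-line path
    u_t = x1 - x0                            # regression target
    return np.mean((v_fn(x_t, t, r) - u_t) ** 2)

# Toy check: if every data point equals the same constant c, the exact
# velocity field is (c - x) / (1 - t), so its loss should vanish.
c = np.array([3.0, -1.0])
x1 = np.tile(c, (256, 1))
r = np.zeros(256, dtype=int)                 # a single regime

exact_v = lambda x, t, r: (c - x) / (1.0 - t)
zero_v = lambda x, t, r: np.zeros_like(x)

rng = np.random.default_rng(0)
loss_exact = cfm_loss(exact_v, x1, r, rng)
loss_zero = cfm_loss(zero_v, x1, r, rng)
print(f"exact field: {loss_exact:.2e}, zero field: {loss_zero:.2f}")
```

The constant-data case is useful only as a sanity check: for real data the target $x_1 - x_0$ is not a function of $(x_t, t)$ alone, so the minimum of the loss is strictly positive.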

In [None]:
from core.trainer import FlowMatchingTrainer, TrainingConfig

# Configure training
config = TrainingConfig(
    lr=1e-3,
    batch_size=64,
    n_epochs=50,  # Reduced for demo
    warmup_epochs=5,
    validate_every=10
)

# Create trainer
trainer = FlowMatchingTrainer(net, topology, config)

print("Training Configuration:")
print(f"  Learning rate: {config.lr}")
print(f"  Batch size: {config.batch_size}")
print(f"  Epochs: {config.n_epochs}")
print(f"  Device: {trainer.device}")

In [None]:
# Demonstrate a single training step
print("Single Training Step Demonstration:")
print("=" * 50)

# Get a batch
x1_batch, regime_batch = next(iter(trainer.train_loader))
x1_batch = x1_batch.to(trainer.device)
regime_batch = regime_batch.to(trainer.device)

print(f"\n1. Sample data batch: x1 shape = {x1_batch.shape}")

# Sample OT path
x_t, t, u_t, x_0 = trainer.sample_ot_path(x1_batch, len(x1_batch))

print(f"2. Sample noise: x0 shape = {x_0.shape}")
print(f"3. Sample time: t shape = {t.shape}, range = [{t.min():.3f}, {t.max():.3f}]")
print(f"4. Compute interpolation: x_t = (1-t)*x0 + t*x1")
print(f"5. Target velocity: u_t = x1 - x0, shape = {u_t.shape}")

# Forward pass
v_pred = net(x_t, t, regime_batch)
print(f"6. Predict velocity: v_pred shape = {v_pred.shape}")

# Compute loss
loss = ((v_pred - u_t) ** 2).mean()
print(f"7. MSE loss: {loss.item():.4f}")

In [None]:
# Train the model
print("\nStarting training...")
results = trainer.train()

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

ax = axes[0]
ax.plot(results['loss_history'], label='Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Curve')
ax.legend()

ax = axes[1]
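# Derive validation epochs from the recorded history so x and y always match in length
n_val = len(results['val_loss_history'])
val_epochs = [(k + 1) * config.validate_every - 1 for k in range(n_val)]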
ax.plot(val_epochs, results['val_loss_history'], 'o-', label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Validation Loss Curve')
ax.legend()

plt.tight_layout()
plt.show()

print(f"\nFinal validation loss: {results['final_metrics']['val_loss']:.4f}")

---
## 6. Scenario Generation and Stress Testing

At inference time, we generate samples by integrating the ODE from $t = 0$ to $t = 1$ and taking $x(1)$ as the generated scenario:

$$\frac{dx}{dt} = v_\theta(x, t, r), \quad x(0) \sim \mathcal{N}(0, I)$$

### Guided Generation (Stress Testing)

To steer a specific variable toward a target value, we modify the velocity:

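$$v_{guided}(x, t, r) = v_\theta(x, t, r) + \lambda \cdot \frac{\text{target} - x_i}{1 - t + \epsilon} \, e_i$$

Here $e_i$ is the unit vector for variable $i$, so the correction acts on that component only; $\lambda$ sets the guidance strength and $\epsilon$ keeps the denominator finite as $t \to 1$.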

This "nudges" variable $i$ toward the target while the causal mask ensures consistent propagation to other variables.
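
The guided ODE can be sketched with a fixed-step Euler integrator. This is an illustrative stand-in, not the `ODESolver` used in this notebook: the names `guided_velocity` and `euler_sample`, the step count, and the zero base velocity in the demo are all assumptions, the regime argument is omitted for brevity, and the guidance term is applied to component `i` only.

```python
import numpy as np

def guided_velocity(v_fn, x, t, i, target, lam=1.0, eps=1e-3):
    """Base velocity plus a pull on component i toward `target`."""
    v = v_fn(x, t).copy()
    v[:, i] += lam * (target - x[:, i]) / (1.0 - t + eps)
    return v

def euler_sample(v_fn, x0, n_steps=100, **guide):
    """Integrate dx/dt from t=0 to t=1 with fixed-step Euler."""
    x = x0.copy()
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        v = guided_velocity(v_fn, x, t, **guide) if guide else v_fn(x, t)
        x = x + dt * v
    return x

# Demo with a zero base field, so only the guidance term moves the samples
zero_field = lambda x, t: np.zeros_like(x)
rng = np.random.default_rng(0)
x0 = rng.standard_normal((500, 3))

x_guided = euler_sample(zero_field, x0, i=0, target=3.0)
print(np.abs(x_guided[:, 0] - 3.0).max())        # component 0 ends near the target
print(np.allclose(x_guided[:, 1:], x0[:, 1:]))   # other components unchanged in this demo
```

The other components stay put here only because the base field is zero. With a trained network, any response of the remaining variables would come through $v_\theta$ itself, subject to the causal mask.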

In [None]:
# Check if torchdiffeq is available
try:
    from core.solver import ODESolver, ScenarioGenerator
    SOLVER_AVAILABLE = True
except ImportError:
    print("Note: torchdiffeq not installed. Solver demos will be skipped.")
    print("Install with: pip install torchdiffeq")
    SOLVER_AVAILABLE = False

In [None]:
if SOLVER_AVAILABLE:
    # Create solver and generator
    solver = ODESolver(net, topology)
    generator = ScenarioGenerator(solver, topology)
    
    # Generate baseline scenarios
    print("Generating baseline scenarios...")
    baseline = generator.baseline(n_scenarios=500)
    
    # Generate stressed scenarios
    print("Generating stressed scenarios (3 std shock to first variable)...")
    stressed = generator.shock(
        variable=0,  # Index 0 in topology.variable_names (see the causal ordering printed above)
        magnitude=3.0,
        n_scenarios=500
    )
    
    print(f"\nBaseline shape: {baseline.shape}")
    print(f"Stressed shape: {stressed.shape}")

In [None]:
if SOLVER_AVAILABLE:
    # Visualize the difference
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    
    var_names = topology.variable_names
    
    for i, ax in enumerate(axes.flat):
        if i >= len(var_names):
            ax.axis('off')
            continue
            
        ax.hist(baseline[:, i], bins=30, alpha=0.5, label='Baseline', density=True)
        ax.hist(stressed[:, i], bins=30, alpha=0.5, label='Stressed', density=True)
        ax.set_xlabel('Value')
        ax.set_ylabel('Density')
        ax.set_title(var_names[i])
        ax.legend()
    
    plt.suptitle('Baseline vs Stressed Scenario Distributions', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    # Summary statistics
    print("\nMean Shift (Stressed - Baseline):")
    for i, name in enumerate(var_names[:10]):
        shift = stressed[:, i].mean() - baseline[:, i].mean()
        print(f"  {name}: {shift:+.3f}")

---
## 7. Full Pipeline Demonstration

Now let's use the high-level `CausalFlowMatcher` API to demonstrate the complete workflow.

In [None]:
from core import CausalFlowMatcher

# Create fresh instance
cfm = CausalFlowMatcher(
    hidden_dim=128,
    n_layers=4,
    dropout=0.0
)

print("CausalFlowMatcher initialized")
print(cfm)

In [None]:
# Fit to data
cfm.fit(
    X_raw,
    fast_vars=fast_vars,
    slow_vars=slow_vars
)

print(f"\nFitted! Detected {cfm.n_regimes} regimes")
print(f"Variables: {cfm.variable_names}")

In [None]:
# Train
results = cfm.train(
    n_epochs=100,
    lr=1e-3,
    batch_size=64,
    validate_every=20
)

print(f"\nTraining complete!")
print(f"Final validation loss: {results['final_metrics']['val_loss']:.4f}")

In [None]:
if SOLVER_AVAILABLE:
    # Generate and compare scenarios
    print("Generating scenarios...")
    
    baseline = cfm.sample(n_samples=1000)
    
    # Create stress scenarios
    scenarios = {}
    scenarios['baseline'] = baseline
    scenarios['vix_shock'] = cfm.shock(variable=4, magnitude=3.0, n_samples=1000)  # VIX +3std
    scenarios['rate_shock'] = cfm.shock(variable=-2, magnitude=2.0, n_samples=1000)  # Rate +2std
    
    # Visualize
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # Pairplot for first two fast variables
    ax = axes[0, 0]
    for name, data in scenarios.items():
        ax.scatter(data[:, 0], data[:, 1], alpha=0.3, s=10, label=name)
    ax.set_xlabel(cfm.variable_names[0])
    ax.set_ylabel(cfm.variable_names[1])
    ax.legend()
    ax.set_title('Fast Variable Scatter')
    
    # Distribution comparison for first variable
    ax = axes[0, 1]
    for name, data in scenarios.items():
        ax.hist(data[:, 0], bins=50, alpha=0.5, label=name, density=True)
    ax.set_xlabel(cfm.variable_names[0])
    ax.legend()
    ax.set_title(f'{cfm.variable_names[0]} Distribution')
    
    # Box plot comparison
    ax = axes[1, 0]
    data_for_box = [scenarios['baseline'][:, 0], 
                    scenarios['vix_shock'][:, 0],
                    scenarios['rate_shock'][:, 0]]
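    # Set tick labels separately: boxplot's `labels` keyword is deprecated in recent Matplotlib
    ax.boxplot(data_for_box)
    ax.set_xticklabels(['Baseline', 'VIX Shock', 'Rate Shock'])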
    ax.set_ylabel(cfm.variable_names[0])
    ax.set_title('Scenario Comparison')
    
    # Correlation differences
    ax = axes[1, 1]
    corr_base = np.corrcoef(scenarios['baseline'].T)
    corr_vix = np.corrcoef(scenarios['vix_shock'].T)
    corr_diff = corr_vix - corr_base
    im = ax.imshow(corr_diff[:10, :10], cmap='RdBu_r', vmin=-0.5, vmax=0.5)
    ax.set_title('Correlation Change (VIX Shock - Baseline)')
    plt.colorbar(im, ax=ax)
    
    plt.tight_layout()
    plt.show()

---
## Summary

In this tutorial, we covered:

1. **Flow Matching**: Learning to transport noise to data via velocity fields
2. **CTree-Lite**: Statistical regime detection using significance testing
3. **Data Processing**: Stationarity, normalization, and causal discovery
4. **Masked Networks**: Enforcing causal structure in the neural network
5. **Training**: Simulation-free CFM with OT paths
6. **Generation**: ODE integration with guided stress testing

### Key Takeaways

- C-CFM aims to generate economically consistent scenarios by respecting causal structure
- The mask lets macro shocks propagate to markets while blocking influence in the reverse direction
- Guided generation allows targeted stress testing
- The framework is simulation-free during training (efficient!)

### Next Steps

- Try with real data from FRED/Yahoo Finance
- Experiment with different network architectures
- Implement custom regime detection based on domain knowledge
- Use for risk management and scenario analysis