# 03 - Regime Classification

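This notebook explores the multi-scale path-state regime classification: where each regime sits in the volatility state space, how often each regime occurs, how regimes transition into one another, and how long regime episodes last.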

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from src.data.synthetic_data import SyntheticDataGenerator
from src.regimes.regime_classifier import RegimeClassifier
from src.regimes.path_states import PathStateClassifier
from src.visualization.styles import set_publication_style, PlotStyles

# Notebook-wide setup, run once before any data is generated or plotted.
# set_publication_style comes from src.visualization.styles; presumably it
# configures matplotlib defaults for the figures below (not visible here — confirm).
set_publication_style()
# Seed NumPy's legacy global RNG so any np.random.* draws are repeatable.
np.random.seed(42)

In [None]:
# Build the synthetic sample that every later cell works from.
generator = SyntheticDataGenerator(seed=42)
data = generator.generate(n_months=732)

# Pull out the pieces used below: volatility measures, factor returns, regime labels.
volatility, factors = data['volatility'], data['factors']

# The regime labels may arrive wrapped in a DataFrame; if so, keep only
# its 'regime' column so `regimes` is always a single labelled series.
regimes = data['regimes']
regimes = regimes['regime'] if isinstance(regimes, pd.DataFrame) else regimes

## 1. State Space Visualization

In [None]:
# State-space scatter: one-month volatility (x, in %) against rho_sigma (y),
# one colour per regime, overlaid with dashed reference lines.
fig, ax = plt.subplots(figsize=(10, 7))

for regime in PlotStyles.REGIME_ORDER:
    mask = regimes == regime
    if mask.sum() == 0:
        # No observations for this regime: draw nothing and add no legend entry.
        continue
    points = volatility.loc[mask, ['sigma_1m', 'rho_sigma']]
    ax.scatter(
        points['sigma_1m'] * 100,
        points['rho_sigma'],
        c=PlotStyles.get_regime_color(regime),
        label=regime,
        alpha=0.6,
        s=25,
        edgecolors='none',
    )

# Vertical lines at the 33rd / 67th percentiles of observed one-month volatility (in %).
sigma_1m_observed = volatility['sigma_1m'].dropna()
vol_33, vol_67 = np.percentile(sigma_1m_observed, [33, 67]) * 100

# Horizontal lines at fixed rho_sigma levels of 0.8 and 1.5.
# NOTE(review): presumably these mirror the classifier's ratio thresholds — confirm.
threshold_style = dict(color='gray', linestyle='--', linewidth=1.5, alpha=0.7)
for x_cut in (vol_33, vol_67):
    ax.axvline(x=x_cut, **threshold_style)
for y_cut in (0.8, 1.5):
    ax.axhline(y=y_cut, **threshold_style)

ax.set(
    xlabel='One-Month Realized Volatility (%)',
    ylabel='Volatility Ratio (σ_1w / σ_3m)',
    title='State Space with Regime Classification',
)
ax.legend(loc='upper right')

plt.tight_layout()
plt.show()

## 2. Regime Frequencies and Transitions

In [None]:
# Regime frequencies: observation count and share of the sample for each regime.
freq = regimes.value_counts()
pct = freq / len(regimes) * 100

freq_df = pd.DataFrame({
    'Observations': freq,
    'Frequency (%)': pct
})

print("Regime Frequencies:")
# Present rows in the canonical regime order. reindex is used instead of .loc
# because .loc raises KeyError when a listed regime never occurs in the sample;
# such regimes are shown here as explicit zero rows.
freq_df.reindex(PlotStyles.REGIME_ORDER, fill_value=0)

In [None]:
# Build the regime transition matrix from the volatility data.
# `classifier` stays at notebook scope so that later cells can reuse it.
classifier = RegimeClassifier()
trans_mat = classifier.compute_transition_matrix(volatility)

print("\nTransition Matrix (row = from, col = to):")
# Rounded to two decimals for display only; `trans_mat` itself is left unrounded.
trans_mat.round(decimals=2)

In [None]:
# Plot transition matrix as heatmap.
# seaborn (`sns`) is imported with the rest of the notebook's imports in the
# first code cell, so every dependency is visible in one place.
fig, ax = plt.subplots(figsize=(8, 6))

# Scale by 100 so the annotated cells read as percentages.
sns.heatmap(trans_mat * 100, annot=True, fmt='.1f', cmap='Blues', ax=ax)
ax.set_title('Regime Transition Probabilities (%)')
ax.set_xlabel('To Regime')
ax.set_ylabel('From Regime')
plt.tight_layout()
plt.show()

## 3. Regime Time Series

In [None]:
# Plot regime time series: three stacked panels sharing the date axis.
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

# Panel 1: Regime classification
ax1 = axes[0]
# Shade one span per run of consecutive identical regimes instead of one span
# per observation: the shaded area is the same, but far fewer patches are
# created. A run starts wherever the label differs from the previous one
# (the first observation always starts a run).
run_starts = np.flatnonzero((regimes != regimes.shift()).to_numpy())
# Each run ends where the next one begins; the final run ends at the last date.
run_ends = np.append(run_starts[1:], len(regimes) - 1)
for start, end in zip(run_starts, run_ends):
    if end > start:  # a run that is only the final observation has zero width
        ax1.axvspan(regimes.index[start], regimes.index[end],
                    alpha=0.8,
                    color=PlotStyles.get_regime_color(regimes.iloc[start]))
ax1.set_xlim([regimes.index[0], regimes.index[-1]])
ax1.set_yticks([])
ax1.set_title('Regime Classification')

# Panel 2: Volatility
ax2 = axes[1]
ax2.plot(volatility.index, volatility['sigma_1m'] * 100, 'b-', linewidth=1)
ax2.set_ylabel('Volatility (%)')
ax2.set_title('1-Month Realized Volatility')

# Panel 3: Momentum returns
ax3 = axes[2]
# Running (non-compounded) sum of the Momentum series, scaled to percent.
# NOTE(review): assumes the returns are additive (e.g. log returns) — confirm.
cum_mom = np.cumsum(factors['Momentum']) * 100
ax3.plot(cum_mom.index, cum_mom.values, 'r-', linewidth=1)
ax3.set_ylabel('Cumulative Return (%)')
ax3.set_xlabel('Date')
ax3.set_title('Momentum Cumulative Returns')

plt.tight_layout()
plt.show()

## 4. Regime Duration Analysis

In [None]:
# Identify regime episodes, then summarise episode duration per regime.
episodes = classifier.identify_regime_episodes(volatility)

# Aggregations to compute, mapped to the column labels used for display.
duration_labels = {
    'mean': 'Avg Duration',
    'std': 'Std Duration',
    'max': 'Max Duration',
    'count': 'Episodes',
}
duration_stats = (
    episodes
    .groupby('regime')['duration']
    .agg(list(duration_labels))
    .rename(columns=duration_labels)
)

print("Regime Duration Statistics (months):")
duration_stats.round(1)

## 5. Key Insights

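1. **Calm Trend** is the most frequent state (~40% of observations).
2. **Crash-Spike** is relatively rare (~8% of observations) but extremely impactful.
3. The transition matrix shows **high persistence**: the largest probabilities lie on its diagonal.
4. **Recovery** typically follows **Crash-Spike** states.
5. The regime classification aligns with major market events. (Note: the data in this notebook is generated synthetically, so this point should be confirmed against real market data.)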