In [37]:
import os
from datetime import datetime
from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv
from supabase import create_client
import pandas as pd

# Load environment variables from .env if present
load_dotenv()

url = os.environ.get("SUPABASE_URL")
key = os.environ.get("SUPABASE_KEY") or os.environ.get("SUPABASE_SERVICE_ROLE_KEY")

if not url or not key:
    print("Warning: SUPABASE_URL and/or SUPABASE_KEY/SUPABASE_SERVICE_ROLE_KEY are not set in environment.")
else:
    try:
        supabase = create_client(url, key)

        # Calculate date 24 months ago
        # SQL equivalent: date >= (current_date - interval '24 months')
        start_date = (datetime.now() - relativedelta(months=24)).strftime('%Y-%m-%d')
        print(f"Querying data since: {start_date}")

        # Execute Query
        response = supabase.table('macro_regime') \
            .select('date, regime_id, regime_label, stress_flag') \
            .gte('date', start_date) \
            .order('date', desc=False) \
            .execute()

        data = response.data

        if data:
            df = pd.DataFrame(data)
            # Reorder columns to match request if needed, though they come as dict
            cols = ['date', 'regime_id', 'regime_label', 'stress_flag']
            # Filter columns that actually exist in the response
            existing_cols = [c for c in cols if c in df.columns]
            df = df[existing_cols]
            
            display(df)
            print(f"\nTotal rows: {len(df)}")
        else:
            print("No data found.")

    except Exception as e:
        print(f"An error occurred: {e}")

Querying data since: 2024-01-20


Unnamed: 0,date,regime_id,regime_label,stress_flag
0,2024-01-31,3,Stagflation,True
1,2024-02-29,3,Stagflation,True
2,2024-03-31,3,Stagflation,True
3,2024-04-30,2,Reflation,True
4,2024-05-31,2,Reflation,True
5,2024-06-30,3,Stagflation,True
6,2024-07-31,3,Stagflation,True
7,2024-08-31,3,Stagflation,False
8,2024-09-30,3,Stagflation,False
9,2024-10-31,3,Stagflation,False



Total rows: 25


In [38]:
# Aggregation Query
# SQL Equivalent:
# select regime_id, regime_label, count(*) as n
# from public.macro_regime
# group by regime_id, regime_label
# order by regime_id;

try:
    if 'supabase' not in locals():
        print("Supabase client not initialized. Please run the previous cell first.")
    else:
        # Fetch all necessary columns for aggregation
        # We fetch all rows since no date filter was specified in the request
        response_agg = supabase.table('macro_regime') \
            .select('regime_id, regime_label') \
            .execute()
        
        data_agg = response_agg.data
        
        if data_agg:
            df_agg = pd.DataFrame(data_agg)
            # Perform aggregation in Pandas (PostgREST doesn't support direct GROUP BY without Views/RPC)
            result = df_agg.groupby(['regime_id', 'regime_label']).size().reset_index(name='n')
            result = result.sort_values(by='regime_id')
            
            display(result)
        else:
            print("No data found for aggregation.")

except Exception as e:
    print(f"An error occurred during aggregation: {e}")

Unnamed: 0,regime_id,regime_label,n
0,1,Goldilocks,46
1,2,Reflation,42
2,3,Stagflation,113
3,4,Recession,112


In [39]:
# Stress Flag Query
# SQL Equivalent:
# select date, regime_id, regime_label
# from public.macro_regime
# where stress_flag = true
# order by date desc;

try:
    if 'supabase' not in locals():
        print("Supabase client not initialized. Please run the previous cell first.")
    else:
        # Query for stress_flag = True
        response_stress = supabase.table('macro_regime') \
            .select('date, regime_id, regime_label') \
            .eq('stress_flag', True) \
            .order('date', desc=True) \
            .execute()
        
        data_stress = response_stress.data
        
        if data_stress:
            df_stress = pd.DataFrame(data_stress)
             # Reorder columns to match request
            cols = ['date', 'regime_id', 'regime_label']
            existing_cols = [c for c in cols if c in df_stress.columns]
            df_stress = df_stress[existing_cols]

            display(df_stress)
        else:
            print("No data found with stress_flag = true.")

except Exception as e:
    print(f"An error occurred during stress_flag query: {e}")

Unnamed: 0,date,regime_id,regime_label
0,2025-04-30,4,Recession
1,2024-07-31,3,Stagflation
2,2024-06-30,3,Stagflation
3,2024-05-31,2,Reflation
4,2024-04-30,2,Reflation
...,...,...,...
108,2000-06-30,4,Recession
109,2000-05-31,4,Recession
110,2000-04-30,4,Recession
111,2000-03-31,4,Recession


In [40]:
# Regime Change Query
# SQL Equivalent:
# with x as (
#   select
#     date,
#     regime_id,
#     lag(regime_id) over (order by date) as prev_regime_id,
#     stress_flag
#   from public.macro_regime
# )
# select *
# from x
# where prev_regime_id is not null
#   and regime_id <> prev_regime_id
# order by date desc;

try:
    if 'supabase' not in locals():
        print("Supabase client not initialized. Please run the previous cell first.")
    else:
        # Fetch data sorted by date for window calculations
        # We fetch all necessary columns
        response_change = supabase.table('macro_regime') \
            .select('date, regime_id, stress_flag') \
            .order('date', desc=False) \
            .execute()
        
        data_change = response_change.data
        
        if data_change:
            df_change = pd.DataFrame(data_change)
            
            # Create prev_regime_id using shift()
            df_change['prev_regime_id'] = df_change['regime_id'].shift(1)
            
            # Filter logic:
            # prev_regime_id is not null AND regime_id <> prev_regime_id
            cond = (df_change['prev_regime_id'].notna()) & (df_change['regime_id'] != df_change['prev_regime_id'])
            
            result_change = df_change[cond].copy()
            
            # Cast prev_regime_id to int since it becomes float after shift/NaN introduction
            result_change['prev_regime_id'] = result_change['prev_regime_id'].astype(int)
            
            # Order by date desc
            result_change = result_change.sort_values(by='date', ascending=False)
            
            display(result_change)
            print(f"\nTotal regime changes identified: {len(result_change)}")
        else:
            print("No data found for regime change analysis.")

except Exception as e:
    print(f"An error occurred during regime change query: {e}")

Unnamed: 0,date,regime_id,stress_flag,prev_regime_id
309,2025-10-31,4,False,3
307,2025-08-31,3,False,2
306,2025-07-31,2,False,4
305,2025-06-30,4,False,1
304,2025-05-31,1,False,4
...,...,...,...,...
43,2003-08-31,1,False,4
39,2003-04-30,4,False,1
36,2003-01-31,1,True,4
33,2002-10-31,4,True,1



Total regime changes identified: 99


In [41]:
# Visualization: Regime vs SPX (with Stress Flag)
try:
    import yfinance as yf
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    if 'df' not in locals() or 'date' not in df.columns:
        # Fallback if cell 1 wasn't run
        print("Please run the first cell to load 'df' with regime data.")
    else:
        # 1. Prepare Regime Data
        viz_df = df.copy()
        viz_df['date'] = pd.to_datetime(viz_df['date'])
        viz_df = viz_df.sort_values('date')
        
        start_date_viz = viz_df['date'].min()
        end_date_viz = viz_df['date'].max()

        # 2. Fetch SPX Data
        print(f"Fetching SPX data from {start_date_viz.date()} to {end_date_viz.date()}...")
        spx = yf.download("^GSPC", start=start_date_viz, end=end_date_viz + pd.Timedelta(days=1), progress=False)
        
        if spx.empty:
            print("Failed to fetch SPX data.")
        else:
            # Handle MultiIndex if present (yfinance update)
            if isinstance(spx.columns, pd.MultiIndex):
                spx_close = spx["Close"].iloc[:, 0] # Take first column if multi-level
            else:
                spx_close = spx["Close"]
            
            # 3. Create Plot
            fig = make_subplots(specs=[[{"secondary_y": True}]])

            # Add SPX trace
            fig.add_trace(
                go.Scatter(x=spx_close.index, y=spx_close, name="S&P 500", line=dict(color='black', width=1.5)),
                secondary_y=False,
            )

            # Add Background Colors for Regulation
            regime_colors = {
                1: "rgba(76, 175, 80, 0.3)",   # Green (Goldilocks)
                2: "rgba(33, 150, 243, 0.3)",  # Blue (Reflation)
                3: "rgba(255, 152, 0, 0.3)",   # Orange (Stagflation)
                4: "rgba(244, 67, 54, 0.3)"    # Red (Recession)
            }
            regime_labels = {
                1: "Goldilocks",
                2: "Reflation",
                3: "Stagflation",
                4: "Recession"
            }

             # Add invisible traces for Legend (Regimes)
            for r_id, color in regime_colors.items():
                fig.add_trace(
                    go.Scatter(
                        x=[None], y=[None],
                        mode='markers',
                        marker=dict(size=10, color=color, symbol='square'),
                        name=regime_labels.get(r_id, f"Regime {r_id}"),
                        legendgroup=f"group{r_id}",
                        showlegend=True
                    ),
                    secondary_y=False
                )

            # Draw Background Rectangles
            viz_df['group'] = (viz_df['regime_id'] != viz_df['regime_id'].shift()).cumsum()
            for _, group in viz_df.groupby('group'):
                r_id = group['regime_id'].iloc[0]
                g_start = group['date'].min()
                g_end = group['date'].max()
                if len(group) == 1:
                     g_end = g_start + pd.Timedelta(days=28) 
                
                if r_id in regime_colors:
                    fig.add_vrect(
                        x0=g_start,
                        x1=g_end,
                        fillcolor=regime_colors[r_id],
                        layer="below",
                        line_width=0,
                    )
            
            # Add Stress Flag Markers
            # Explicitly cast to boolean to handle potential string 'true'/'false' or objects
            # Convert non-boolean to boolean if necessary (e.g., string 'true' -> True)
            viz_df['stress_flag_bool'] = viz_df['stress_flag'].astype(str).str.lower() == 'true'
            
            stress_data = viz_df[viz_df['stress_flag_bool'] == True]
            print(f"Found {len(stress_data)} occurrences of stress_flag=True")
            
            if not stress_data.empty:
                # Align dates
                stress_dates = stress_data['date']
                # Reindex SPX to these dates, tolerating nearest match within say 3 days
                spx_locs = spx_close.index.get_indexer(stress_dates, method='nearest', tolerance=pd.Timedelta('3 days'))
                
                # Filter out -1s (no match found within tolerance)
                valid_locs = spx_locs != -1
                if valid_locs.any():
                    stress_y = spx_close.iloc[spx_locs[valid_locs]]
                    stress_x = spx_close.index[spx_locs[valid_locs]]

                    fig.add_trace(
                        go.Scatter(
                            x=stress_x,
                            y=stress_y,
                            mode='markers',
                            marker=dict(size=12, color='purple', symbol='x-thin', line=dict(width=2, color='purple')),
                            name="Stress Flag",
                            showlegend=True
                        ),
                        secondary_y=False
                    )
                else:
                    print("No matching SPX dates found for stress flags.")
            else:
                print("No stress flags found in the data.")

            fig.update_layout(
                title_text="Macro Regime vs S&P 500 (with Stress)",
                xaxis_title="Date",
                yaxis_title="S&P 500 Price",
                height=600,
                legend_title="Legend"
            )

            fig.show()

except ImportError as e:
    print(f"Library missing: {e}. Please install plotly and yfinance.")
except Exception as e:
    print(f"An error occurred during visualization: {e}")

Fetching SPX data from 2024-01-31 to 2026-01-31...
Found 8 occurrences of stress_flag=True


In [42]:
# Visualization (fixed): Regime vs SPX Monthly Returns (with Stress)
import pandas as pd
import numpy as np

import yfinance as yf
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---- 0) Validate input ----
if 'df' not in locals() or 'date' not in df.columns:
    raise RuntimeError("Please load 'df' with regime data first (must include 'date').")

viz_df = df.copy()
viz_df['date'] = pd.to_datetime(viz_df['date'])
viz_df = viz_df.sort_values('date')

# Force month-end timestamp for alignment
viz_df['date_me'] = viz_df['date'].dt.to_period('M').dt.to_timestamp('M')
viz_df = viz_df.drop_duplicates(subset=['date_me']).set_index('date_me').sort_index()

start_date_viz = viz_df.index.min()
end_date_viz = viz_df.index.max()

# ---- 1) Fetch SPX daily, convert to month-end ----
print(f"Fetching SPX (^GSPC) daily from {start_date_viz.date()} to {end_date_viz.date()}...")
spx = yf.download("^GSPC", start=start_date_viz, end=end_date_viz + pd.Timedelta(days=1), progress=False)

if spx.empty:
    raise RuntimeError("Failed to fetch SPX data from yfinance.")

# Handle MultiIndex columns if present
if isinstance(spx.columns, pd.MultiIndex):
    spx_close_daily = spx["Close"].iloc[:, 0]
else:
    spx_close_daily = spx["Close"]

spx_close_me = spx_close_daily.resample("M").last()
spx_ret_1m = spx_close_me.pct_change()  # decimal returns
spx_ret_1m.name = "spx_ret_1m"

# Align with viz_df index (month-end)
common_idx = viz_df.index.intersection(spx_ret_1m.index)
viz_df = viz_df.loc[common_idx]
spx_close_me = spx_close_me.loc[common_idx]
spx_ret_1m = spx_ret_1m.loc[common_idx]

# Stress flag robust boolean parsing
if 'stress_flag' in viz_df.columns:
    stress_bool = viz_df['stress_flag'].astype(str).str.lower().isin(['true', '1', 't', 'yes', 'y'])
else:
    stress_bool = pd.Series(False, index=viz_df.index)

# ---- 2) Plot setup (2 rows: return + optional stress_score) ----
has_stress_score = 'stress_score' in viz_df.columns

fig = make_subplots(
    rows=2 if has_stress_score else 1,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.06,
    row_heights=[0.72, 0.28] if has_stress_score else [1.0],
    specs=[[{"secondary_y": False}], [{"secondary_y": False}]] if has_stress_score else [[{"secondary_y": False}]],
)

# ---- 3) Add monthly return line (preferred for regime eval) ----
fig.add_trace(
    go.Scatter(
        x=spx_ret_1m.index,
        y=(spx_ret_1m * 100.0),  # %
        name="S&P 500 1M Return (%)",
        line=dict(color="black", width=1.5),
    ),
    row=1, col=1
)

# Optional: add zero line
fig.add_hline(y=0, line_width=1, line_dash="dot", line_color="gray", row=1, col=1)

# ---- 4) Regime background rectangles (month-accurate) ----
regime_colors = {
    1: "rgba(76, 175, 80, 0.22)",   # Goldilocks
    2: "rgba(33, 150, 243, 0.22)",  # Reflation
    3: "rgba(255, 152, 0, 0.22)",   # Stagflation
    4: "rgba(244, 67, 54, 0.22)",   # Recession
}
regime_labels = {1: "Goldilocks", 2: "Reflation", 3: "Stagflation", 4: "Recession"}

# Legend entries for regimes (invisible)
for r_id, color in regime_colors.items():
    fig.add_trace(
        go.Scatter(
            x=[None], y=[None],
            mode="markers",
            marker=dict(size=10, color=color, symbol="square"),
            name=regime_labels.get(r_id, f"Regime {r_id}"),
            legendgroup=f"regime{r_id}",
            showlegend=True
        ),
        row=1, col=1
    )

# Create contiguous blocks of same regime
viz_df['_grp'] = (viz_df['regime_id'] != viz_df['regime_id'].shift()).cumsum()

for _, g in viz_df.groupby('_grp'):
    r_id = int(g['regime_id'].iloc[0])
    x0 = g.index.min()
    # end at next month-end + 1 day to cover the month, else too thin
    x1 = (g.index.max() + pd.offsets.MonthEnd(1))  # next month-end
    fig.add_vrect(
        x0=x0, x1=x1,
        fillcolor=regime_colors.get(r_id, "rgba(0,0,0,0.05)"),
        layer="below",
        line_width=0,
        row=1, col=1
    )
    if has_stress_score:
        fig.add_vrect(
            x0=x0, x1=x1,
            fillcolor=regime_colors.get(r_id, "rgba(0,0,0,0.05)"),
            layer="below",
            line_width=0,
            row=2, col=1
        )

# ---- 5) Stress markers on month-end (no “nearest day” needed) ----
stress_points = viz_df[stress_bool]

print(f"Found {len(stress_points)} months with stress_flag=True (out of {len(viz_df)})")

if not stress_points.empty:
    # Place marker on return line at that month
    stress_y = (spx_ret_1m.loc[stress_points.index] * 100.0)
    fig.add_trace(
        go.Scatter(
            x=stress_points.index,
            y=stress_y,
            mode="markers",
            marker=dict(size=10, color="purple", symbol="x"),
            name="Stress Flag (month-end)",
            showlegend=True
        ),
        row=1, col=1
    )

# ---- 6) Optional: stress_score panel ----
if has_stress_score:
    fig.add_trace(
        go.Scatter(
            x=viz_df.index,
            y=viz_df['stress_score'],
            name="stress_score",
            line=dict(width=1.5),
        ),
        row=2, col=1
    )
    # Threshold guide (example: 1.0)
    fig.add_hline(y=1.0, line_width=1, line_dash="dot", line_color="gray", row=2, col=1)

# ---- 7) Layout ----
fig.update_layout(
    title="Macro Regime vs S&P 500 (Monthly Returns) + Stress",
    height=750 if has_stress_score else 600,
    legend_title="Legend",
    hovermode="x unified",
)

fig.update_yaxes(title_text="S&P 500 1M Return (%)", row=1, col=1)
if has_stress_score:
    fig.update_yaxes(title_text="stress_score", row=2, col=1)

fig.update_xaxes(title_text="Date (Month-End)")

fig.show()


Fetching SPX (^GSPC) daily from 2024-01-31 to 2026-01-31...
Found 8 months with stress_flag=True (out of 25)



'M' is deprecated and will be removed in a future version, please use 'ME' instead.



In [43]:
# Macro Regime vs SPX (Monthly Returns) + Stress shading (FULL CODE)

import pandas as pd
import numpy as np

import yfinance as yf
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# ============================
# 0) INPUT CHECK
# ============================
# Assumes you already have a DataFrame named `df` that includes:
#   - date (YYYY-MM-DD or datetime)
#   - regime_id (1~4)
#   - stress_flag (bool or 'true'/'false' etc)
# Optional:
#   - regime_label
#   - stress_score
#
# Example: df = pd.read_csv("macro_regime.csv")

if "df" not in locals() or "date" not in df.columns:
    raise RuntimeError("먼저 df를 준비해주세요. df에는 최소 date, regime_id, stress_flag 컬럼이 필요합니다.")

required_cols = {"date", "regime_id", "stress_flag"}
missing = required_cols - set(df.columns)
if missing:
    raise RuntimeError(f"df에 필요한 컬럼이 없습니다: {sorted(list(missing))}")


# ============================
# 1) PREP REGIME DATA (MONTH-END ALIGN)
# ============================
viz_df = df.copy()
viz_df["date"] = pd.to_datetime(viz_df["date"])
viz_df = viz_df.sort_values("date")

# month-end timestamp index
viz_df["date_me"] = viz_df["date"].dt.to_period("M").dt.to_timestamp("M")
viz_df = viz_df.drop_duplicates(subset=["date_me"]).set_index("date_me").sort_index()

start_date_viz = viz_df.index.min()
end_date_viz = viz_df.index.max()

# robust boolean parsing (handles True/False, "true"/"false", 1/0, etc.)
stress_bool = viz_df["stress_flag"].astype(str).str.lower().isin(["true", "1", "t", "yes", "y"])


# ============================
# 2) FETCH SPX DAILY -> MONTH-END -> MONTHLY RETURNS
# ============================
print(f"Fetching SPX (^GSPC) daily from {start_date_viz.date()} to {end_date_viz.date()}...")
spx = yf.download("^GSPC", start=start_date_viz, end=end_date_viz + pd.Timedelta(days=1), progress=False)

if spx.empty:
    raise RuntimeError("yfinance에서 ^GSPC 데이터를 가져오지 못했습니다.")

# handle multiindex columns (yfinance changes)
if isinstance(spx.columns, pd.MultiIndex):
    spx_close_daily = spx["Close"].iloc[:, 0]
else:
    spx_close_daily = spx["Close"]

spx_close_me = spx_close_daily.resample("M").last()
spx_ret_1m = spx_close_me.pct_change()  # decimal
spx_ret_1m.name = "spx_ret_1m"

# align indexes
common_idx = viz_df.index.intersection(spx_ret_1m.index)
viz_df = viz_df.loc[common_idx]
stress_bool = stress_bool.loc[common_idx]
spx_close_me = spx_close_me.loc[common_idx]
spx_ret_1m = spx_ret_1m.loc[common_idx]

print(f"Aligned months: {len(common_idx)} | Stress months: {int(stress_bool.sum())}")


# ============================
# 3) PLOT SETTINGS
# ============================
has_stress_score = "stress_score" in viz_df.columns

fig = make_subplots(
    rows=2 if has_stress_score else 1,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.06,
    row_heights=[0.70, 0.30] if has_stress_score else [1.0],
)

# regime colors
regime_colors = {
    1: "rgba(76, 175, 80, 0.22)",   # Goldilocks
    2: "rgba(33, 150, 243, 0.22)",  # Reflation
    3: "rgba(255, 152, 0, 0.22)",   # Stagflation
    4: "rgba(244, 67, 54, 0.22)",   # Recession
}
regime_labels = {
    1: "Goldilocks",
    2: "Reflation",
    3: "Stagflation",
    4: "Recession",
}

# optional toggles
SHOW_STRESS_MARKERS = True   # X marker on return line
SHOW_STRESS_SHADING = True   # gray shading on stress months


# ============================
# 4) MAIN TRACE: SPX MONTHLY RETURN
# ============================
fig.add_trace(
    go.Scatter(
        x=spx_ret_1m.index,
        y=spx_ret_1m * 100.0,
        name="S&P 500 1M Return (%)",
        line=dict(color="black", width=1.5),
    ),
    row=1, col=1
)
fig.add_hline(y=0, line_width=1, line_dash="dot", line_color="gray", row=1, col=1)


# ============================
# 5) LEGENDS (invisible markers for regime & stress shading)
# ============================
for r_id, color in regime_colors.items():
    fig.add_trace(
        go.Scatter(
            x=[None], y=[None],
            mode="markers",
            marker=dict(size=10, color=color, symbol="square"),
            name=regime_labels.get(r_id, f"Regime {r_id}"),
            showlegend=True
        ),
        row=1, col=1
    )

# stress shading legend entry (invisible)
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode="markers",
        marker=dict(size=10, color="rgba(120,120,120,0.25)", symbol="square"),
        name="Stress Month (shaded)",
        showlegend=True
    ),
    row=1, col=1
)

# stress marker legend entry (invisible)
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode="markers",
        marker=dict(size=10, color="purple", symbol="x"),
        name="Stress Flag (month-end)",
        showlegend=True
    ),
    row=1, col=1
)


# ============================
# 6) BACKGROUND: REGIME VRECTS (month-accurate)
# ============================
viz_df["_grp_regime"] = (viz_df["regime_id"] != viz_df["regime_id"].shift()).cumsum()

for _, g in viz_df.groupby("_grp_regime"):
    r_id = int(g["regime_id"].iloc[0])
    x0 = g.index.min()
    # cover until next month-end (exclusive 느낌)
    x1 = g.index.max() + pd.offsets.MonthEnd(1)

    # row 1
    fig.add_vrect(
        x0=x0, x1=x1,
        fillcolor=regime_colors.get(r_id, "rgba(0,0,0,0.05)"),
        layer="below",
        line_width=0,
        row=1, col=1
    )
    # row 2 (if stress_score panel exists)
    if has_stress_score:
        fig.add_vrect(
            x0=x0, x1=x1,
            fillcolor=regime_colors.get(r_id, "rgba(0,0,0,0.05)"),
            layer="below",
            line_width=0,
            row=2, col=1
        )


# ============================
# 7) STRESS SHADING (month-by-month gray overlay)
# ============================
if SHOW_STRESS_SHADING:
    stress_months = viz_df.index[stress_bool]
    for d in stress_months:
        x0 = d
        x1 = d + pd.offsets.MonthEnd(1)  # next month-end
        # overlay on top of regime shading but still below traces
        fig.add_vrect(
            x0=x0, x1=x1,
            fillcolor="rgba(120,120,120,0.18)",
            layer="below",
            line_width=0,
            row=1, col=1
        )
        if has_stress_score:
            fig.add_vrect(
                x0=x0, x1=x1,
                fillcolor="rgba(120,120,120,0.18)",
                layer="below",
                line_width=0,
                row=2, col=1
            )


# ============================
# 8) STRESS MARKERS (X on monthly return)
# ============================
if SHOW_STRESS_MARKERS:
    stress_points = viz_df.loc[stress_bool]
    if not stress_points.empty:
        stress_y = spx_ret_1m.loc[stress_points.index] * 100.0
        fig.add_trace(
            go.Scatter(
                x=stress_points.index,
                y=stress_y,
                mode="markers",
                marker=dict(size=10, color="purple", symbol="x"),
                name="Stress Flag (month-end)",
                showlegend=False  # legend already added as invisible entry
            ),
            row=1, col=1
        )


# ============================
# 9) OPTIONAL: STRESS SCORE PANEL
# ============================
if has_stress_score:
    fig.add_trace(
        go.Scatter(
            x=viz_df.index,
            y=viz_df["stress_score"],
            name="stress_score",
            line=dict(width=1.5),
        ),
        row=2, col=1
    )
    # threshold guide (adjust if your rule differs)
    fig.add_hline(y=1.0, line_width=1, line_dash="dot", line_color="gray", row=2, col=1)
    fig.update_yaxes(title_text="stress_score", row=2, col=1)


# ============================
# 10) LAYOUT
# ============================
fig.update_layout(
    title="Macro Regime vs S&P 500 (Monthly Returns) + Stress",
    height=780 if has_stress_score else 620,
    hovermode="x unified",
    legend_title="Legend",
)

fig.update_xaxes(title_text="Date (Month-End)")
fig.update_yaxes(title_text="S&P 500 1M Return (%)", row=1, col=1)

fig.show()


Fetching SPX (^GSPC) daily from 2024-01-31 to 2026-01-31...



'M' is deprecated and will be removed in a future version, please use 'ME' instead.



Aligned months: 25 | Stress months: 8


In [44]:
# Backtest (3): Regime base weight + Stress risk-cut, vs Buy&Hold

import pandas as pd
import numpy as np
import yfinance as yf
import plotly.graph_objects as go

# ----------------------------
# 0) Input check
# ----------------------------
if "df" not in locals() or "date" not in df.columns:
    raise RuntimeError("먼저 df를 준비해주세요. df에는 최소 date, regime_id, stress_flag 컬럼이 필요합니다.")

need = {"date", "regime_id", "stress_flag"}
missing = need - set(df.columns)
if missing:
    raise RuntimeError(f"df에 필요한 컬럼이 없습니다: {sorted(list(missing))}")

viz_df = df.copy()
viz_df["date"] = pd.to_datetime(viz_df["date"])
viz_df = viz_df.sort_values("date")

# month-end index
viz_df["date_me"] = viz_df["date"].dt.to_period("M").dt.to_timestamp("M")
viz_df = viz_df.drop_duplicates(subset=["date_me"]).set_index("date_me").sort_index()

# parse stress_flag robustly
stress_bool = viz_df["stress_flag"].astype(str).str.lower().isin(["true", "1", "t", "yes", "y"])

start_me = viz_df.index.min()
end_me = viz_df.index.max()

# ----------------------------
# 1) Fetch SPX daily -> month-end
# ----------------------------
print(f"Fetching SPX (^GSPC) from {start_me.date()} to {end_me.date()}...")
spx = yf.download("^GSPC", start=start_me, end=end_me + pd.Timedelta(days=1), progress=False)
if spx.empty:
    raise RuntimeError("yfinance에서 ^GSPC 데이터를 가져오지 못했습니다.")

if isinstance(spx.columns, pd.MultiIndex):
    spx_close_daily = spx["Close"].iloc[:, 0]
else:
    spx_close_daily = spx["Close"]

spx_close_me = spx_close_daily.resample("M").last()
spx_ret_1m = spx_close_me.pct_change()  # decimal
spx_ret_1m.name = "spx_ret_1m"

# align
idx = viz_df.index.intersection(spx_ret_1m.index)
viz_df = viz_df.loc[idx]
stress_bool = stress_bool.loc[idx]
spx_close_me = spx_close_me.loc[idx]
spx_ret_1m = spx_ret_1m.loc[idx]

# drop first NaN return month
mask = spx_ret_1m.notna()
viz_df = viz_df.loc[mask]
stress_bool = stress_bool.loc[mask]
spx_ret_1m = spx_ret_1m.loc[mask]
spx_close_me = spx_close_me.loc[mask]

# ----------------------------
# 2) Strategy rules
# ----------------------------
# 레짐별 "기본 주식비중" (원하면 여기만 조절하면 됨)
base_weight = {
    1: 0.70,  # Goldilocks
    2: 0.80,  # Reflation
    3: 0.50,  # Stagflation
    4: 0.40,  # Recession
}

# stress일 때 주식비중 축소 배수
stress_mult = 0.50  # 예: 0.5면 절반으로 줄임 (나머지는 현금)

# 월말 신호 → 다음 달에 적용(룩어헤드 방지)
w_eq_raw = viz_df["regime_id"].map(base_weight).astype(float)
w_eq_raw = w_eq_raw.where(~stress_bool, w_eq_raw * stress_mult)
w_eq = w_eq_raw.shift(1)  # next month apply
w_eq = w_eq.fillna(0.0).clip(lower=0.0, upper=1.0)
w_cash = 1.0 - w_eq  # cash return assumed 0% (보수적)

# ----------------------------
# 3) Backtest
# ----------------------------
# Portfolio returns (cash 0%)
port_ret = w_eq * spx_ret_1m
bh_ret = spx_ret_1m.copy()

# equity curves
port_equity = (1.0 + port_ret).cumprod()
bh_equity = (1.0 + bh_ret).cumprod()

# turnover (optional): sum of abs weight changes
turnover = (w_eq.diff().abs()).fillna(0.0)
avg_turnover = turnover.mean()

# ----------------------------
# 4) Stats helpers
# ----------------------------
def max_drawdown(equity: pd.Series) -> float:
    peak = equity.cummax()
    dd = equity / peak - 1.0
    return float(dd.min())

def perf_stats(r: pd.Series, equity: pd.Series) -> dict:
    r = r.dropna()
    if len(r) == 0:
        return {}
    mean_1m = float(r.mean())
    med_1m = float(r.median())
    ann_ret = float((1.0 + mean_1m) ** 12 - 1.0)
    ann_vol = float(r.std(ddof=0) * np.sqrt(12))
    sharpe = float(ann_ret / ann_vol) if ann_vol > 0 else np.nan
    stats = {
        "n_months": int(len(r)),
        "mean_1m": mean_1m,
        "median_1m": med_1m,
        "ann_ret": ann_ret,
        "ann_vol": ann_vol,
        "sharpe": sharpe,
        "max_dd": max_drawdown(equity),
        "worst_1m": float(r.min()),
        "best_1m": float(r.max()),
    }
    return stats

st_port = perf_stats(port_ret, port_equity)
st_bh = perf_stats(bh_ret, bh_equity)

print("\n[Strategy] Regime base + Stress cut")
print(st_port)
print(f"avg_turnover(weight change abs): {avg_turnover:.4f}")

print("\n[Buy&Hold] SPX")
print(st_bh)

# ----------------------------
# 5) Plot
# ----------------------------
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=port_equity.index, y=port_equity,
    name="Strategy (Regime base + Stress cut)",
    line=dict(width=2)
))
fig.add_trace(go.Scatter(
    x=bh_equity.index, y=bh_equity,
    name="Buy & Hold (SPX)",
    line=dict(width=2, dash="dot")
))

# Optional: show equity scale log (toggle)
USE_LOG_Y = False
fig.update_layout(
    title=f"Backtest: Regime base + Stress cut (stress_mult={stress_mult}) vs SPX Buy&Hold",
    xaxis_title="Month-End Date",
    yaxis_title="Equity (start=1.0)",
    hovermode="x unified",
    height=600,
)
if USE_LOG_Y:
    fig.update_yaxes(type="log")

fig.show()

# ----------------------------
# 6) (Optional) Weight plot
# ----------------------------
# Uncomment if you want to see weight dynamics
# fig2 = go.Figure()
# fig2.add_trace(go.Scatter(x=w_eq.index, y=w_eq, name="Equity weight (applied next month)"))
# fig2.update_layout(title="Equity Weight Over Time", xaxis_title="Month-End", yaxis_title="Weight", height=350)
# fig2.show()


Fetching SPX (^GSPC) from 2024-01-31 to 2026-01-31...

[Strategy] Regime base + Stress cut
{'n_months': 24, 'mean_1m': 0.004866687139000253, 'median_1m': 0.006731685856496666, 'ann_ret': 0.059989070468867034, 'ann_vol': 0.05432198749844062, 'sharpe': 1.104323925382684, 'max_dd': -0.06193864848757502, 'worst_1m': -0.0460357581451003, 'best_1m': 0.034724787627401076}
avg_turnover(weight change abs): 0.1375

[Buy&Hold] SPX
{'n_months': 24, 'mean_1m': 0.015526794370152914, 'median_1m': 0.020931723310620987, 'ann_ret': 0.20308590694906292, 'ann_vol': 0.1036599023185173, 'sharpe': 1.9591558780852203, 'max_dd': -0.07805105567413462, 'worst_1m': -0.05754469768137538, 'best_1m': 0.06152382614078289}



'M' is deprecated and will be removed in a future version, please use 'ME' instead.



In [None]:
# Backtest (2): Regime base weight + Continuous risk-cut by stress_score (vs Buy&Hold)
# - Loads df from Supabase (macro_regime) including stress_score
# - Runs backtest and plots

import os
import pandas as pd
import numpy as np
import yfinance as yf
import plotly.graph_objects as go

# ----------------------------
# 0) Load df from Supabase
# ----------------------------
def load_macro_regime_from_supabase(
    table: str = "macro_regime",
    schema: str = "public",
) -> pd.DataFrame:
    """
    Requires env:
      SUPABASE_URL
      SUPABASE_ANON_KEY  (read-only)  OR  SUPABASE_SERVICE_ROLE_KEY
    """
    try:
        from supabase import create_client
    except Exception as e:
        raise RuntimeError("supabase 패키지가 필요합니다: pip install supabase") from e

    url = os.environ.get("SUPABASE_URL")
    key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY") or os.environ.get("SUPABASE_KEY")
    if not url or not key:
        raise RuntimeError("SUPABASE_URL + (SUPABASE_ANON_KEY 또는 SUPABASE_SERVICE_ROLE_KEY) 환경변수가 필요합니다.")

    supabase = create_client(url, key)

    # IMPORTANT: stress_score를 반드시 select에 포함
    sel = "date,regime_id,regime_label,stress_flag,stress_score,stress_driver,updated_at"

    # Pagination (Supabase returns max rows per call depending on settings)
    # We'll page in chunks of 1000.
    data_all = []
    start = 0
    step = 1000

    while True:
        resp = (
            supabase.table(table)
            .select(sel)
            .order("date")
            .range(start, start + step - 1)
            .execute()
        )
        batch = resp.data or []
        data_all.extend(batch)
        if len(batch) < step:
            break
        start += step

    df = pd.DataFrame(data_all)

    if df.empty:
        raise RuntimeError("Supabase에서 macro_regime 데이터를 못 가져왔습니다(빈 결과).")

    # Type cleanup / guarantee columns
    df["date"] = pd.to_datetime(df["date"], errors="coerce")
    df["regime_id"] = pd.to_numeric(df["regime_id"], errors="coerce")
    df["stress_flag"] = df["stress_flag"].astype(bool)

    # 핵심: stress_score 보장
    if "stress_score" not in df.columns:
        df["stress_score"] = np.nan
    df["stress_score"] = pd.to_numeric(df["stress_score"], errors="coerce").fillna(0.0)

    return df


df = load_macro_regime_from_supabase()

print("df columns:", df.columns.tolist())
print(df.tail(3)[["date", "regime_id", "stress_score", "stress_flag"]])

# ----------------------------
# 1) Input check
# ----------------------------
need = {"date", "regime_id", "stress_score"}
missing = need - set(df.columns)
if missing:
    raise RuntimeError(f"df에 필요한 컬럼이 없습니다: {sorted(list(missing))}")

viz_df = df.copy()
viz_df = viz_df.dropna(subset=["date", "regime_id"]).copy()
viz_df["date"] = pd.to_datetime(viz_df["date"])
viz_df = viz_df.sort_values("date")

# month-end index
viz_df["date_me"] = viz_df["date"].dt.to_period("M").dt.to_timestamp("M")
viz_df = viz_df.drop_duplicates(subset=["date_me"]).set_index("date_me").sort_index()

start_me = viz_df.index.min()
end_me = viz_df.index.max()

# ----------------------------
# 2) Fetch SPX daily -> month-end
# ----------------------------
print(f"Fetching SPX (^GSPC) from {start_me.date()} to {end_me.date()}...")
spx = yf.download("^GSPC", start=start_me, end=end_me + pd.Timedelta(days=1), progress=False)
if spx.empty:
    raise RuntimeError("yfinance에서 ^GSPC 데이터를 가져오지 못했습니다.")

if isinstance(spx.columns, pd.MultiIndex):
    spx_close_daily = spx["Close"].iloc[:, 0]
else:
    spx_close_daily = spx["Close"]

spx_close_me = spx_close_daily.resample("M").last()
spx_ret_1m = spx_close_me.pct_change()
spx_ret_1m.name = "spx_ret_1m"

# align
idx = viz_df.index.intersection(spx_ret_1m.index)
viz_df = viz_df.loc[idx]
spx_close_me = spx_close_me.loc[idx]
spx_ret_1m = spx_ret_1m.loc[idx]

# drop first NaN return month
mask = spx_ret_1m.notna()
viz_df = viz_df.loc[mask]
spx_ret_1m = spx_ret_1m.loc[mask]
spx_close_me = spx_close_me.loc[mask]

# ----------------------------
# 3) Strategy rules
# ----------------------------
base_weight = {
    1: 0.70,  # Goldilocks
    2: 0.80,  # Reflation
    3: 0.50,  # Stagflation
    4: 0.40,  # Recession
}

threshold = 1.0
k = 0.50
floor = 0.20
cap = 1.00

base = viz_df["regime_id"].map(base_weight).astype(float)

stress_score = pd.to_numeric(viz_df["stress_score"], errors="coerce").fillna(0.0)
excess = (stress_score - threshold).clip(lower=0.0)
risk_cut_factor = (1.0 - k * excess).clip(lower=floor, upper=cap)

w_eq_raw = (base * risk_cut_factor).clip(lower=0.0, upper=1.0)

# month-end signal -> apply next month
w_eq = w_eq_raw.shift(1).fillna(0.0).clip(0.0, 1.0)

# ----------------------------
# 4) Backtest
# ----------------------------
port_ret = w_eq * spx_ret_1m
bh_ret = spx_ret_1m.copy()

port_equity = (1.0 + port_ret).cumprod()
bh_equity = (1.0 + bh_ret).cumprod()

turnover = (w_eq.diff().abs()).fillna(0.0)
avg_turnover = float(turnover.mean())

def max_drawdown(equity: pd.Series) -> float:
    peak = equity.cummax()
    dd = equity / peak - 1.0
    return float(dd.min())

def perf_stats(r: pd.Series, equity: pd.Series) -> dict:
    r = r.dropna()
    if len(r) == 0:
        return {}
    mean_1m = float(r.mean())
    med_1m = float(r.median())
    ann_ret = float((1.0 + mean_1m) ** 12 - 1.0)
    ann_vol = float(r.std(ddof=0) * np.sqrt(12))
    sharpe = float(ann_ret / ann_vol) if ann_vol > 0 else np.nan
    return {
        "n_months": int(len(r)),
        "mean_1m": mean_1m,
        "median_1m": med_1m,
        "ann_ret": ann_ret,
        "ann_vol": ann_vol,
        "sharpe": sharpe,
        "max_dd": max_drawdown(equity),
        "worst_1m": float(r.min()),
        "best_1m": float(r.max()),
    }

st_port = perf_stats(port_ret, port_equity)
st_bh = perf_stats(bh_ret, bh_equity)

print("\n[Strategy] Regime base + Continuous risk-cut by stress_score")
print(f"params: threshold={threshold}, k={k}, floor={floor}, cap={cap}")
print(st_port)
print(f"avg_turnover(abs weight change): {avg_turnover:.4f}")

print("\n[Buy&Hold] SPX")
print(st_bh)

# ----------------------------
# 5) Plot equity curves
# ----------------------------
fig = go.Figure()
fig.add_trace(go.Scatter(
    x=port_equity.index, y=port_equity,
    name="Strategy (Regime base + stress_score cut)",
    line=dict(width=2),
))
fig.add_trace(go.Scatter(
    x=bh_equity.index, y=bh_equity,
    name="Buy & Hold (SPX)",
    line=dict(width=2, dash="dot"),
))
fig.update_layout(
    title="Backtest: Regime base + Continuous risk-cut (stress_score) vs SPX Buy&Hold",
    xaxis_title="Month-End Date",
    yaxis_title="Equity (start=1.0)",
    hovermode="x unified",
    height=600,
)
fig.show()

# ----------------------------
# 6) Plot weight & stress (diagnostic)
# ----------------------------
fig2 = go.Figure()
fig2.add_trace(go.Scatter(x=w_eq.index, y=w_eq, name="Equity weight (applied)"))
fig2.add_trace(go.Scatter(x=risk_cut_factor.index, y=risk_cut_factor, name="risk_cut_factor", yaxis="y2"))
fig2.add_trace(go.Scatter(x=stress_score.index, y=stress_score, name="stress_score", yaxis="y2", line=dict(dash="dot")))

fig2.update_layout(
    title="Diagnostics: equity weight vs stress_score / risk_cut_factor",
    xaxis_title="Month-End",
    yaxis=dict(title="Equity weight", rangemode="tozero"),
    yaxis2=dict(title="stress_score / factor", overlaying="y", side="right"),
    hovermode="x unified",
    height=420,
)
fig2.show()


RuntimeError: SUPABASE_URL + (SUPABASE_ANON_KEY 또는 SUPABASE_SERVICE_ROLE_KEY) 환경변수가 필요합니다.