In [None]:
import sys, subprocess

# Make sure nbformat is importable before any figure is shown; if it is
# missing, install it into the interpreter that runs this kernel
# (sys.executable) and retry the import.
# NOTE(review): presumably needed by plotly's notebook renderers for
# fig.show() — confirm against the plotly version in use.
try:
    import nbformat  # noqa
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nbformat>=4.2.0"])
    import nbformat  # noqa

import pandas as pd
import numpy as np
from pathlib import Path
import plotly.graph_objects as go
import plotly.io as pio
from statsmodels.stats.outliers_influence import variance_inflation_factor

# Figures are displayed with plotly's "notebook_connected" renderer.
pio.renderers.default = "notebook_connected"

# ----- Load -----
# Merged factor/macro panel, indexed by the parsed "Date" column in
# ascending order.
PROC_DIR = Path("../data/processed")
MERGED_PATH = PROC_DIR.joinpath("merged_macro_factors_monthly.csv")
df = pd.read_csv(MERGED_PATH, parse_dates=["Date"])
df = df.set_index("Date")
df = df.sort_index()

# ----- Columns -----
# Requested column names; the ones actually present and numeric are
# resolved later with existing_numeric().
factor_cols_req = [
    "Mkt-RF",
    "SMB",
    "HML",
    "RMW",
    "CMA",
    "RF",
]
macro_cols_req = [
    "market",
    "yield_curve",
    "oil ($/bbl)",
    "copper ($/metric ton)",
    "monetary_policy",
    "volatility",
    "stock_bond_corr",
]

def existing_numeric(df, cols):
    """Split ``cols`` into usable and unusable column names for ``df``.

    A name is usable when it is a column of ``df`` and that column has a
    numeric dtype. The order of ``cols`` is preserved in both result lists.

    Returns:
        tuple[list, list]: ``(keep, miss)`` — the usable names, and the names
        that are either absent from ``df`` or not numeric.
    """
    def _usable(name):
        return name in df.columns and pd.api.types.is_numeric_dtype(df[name])

    checked = [(name, _usable(name)) for name in cols]
    keep = [name for name, ok in checked if ok]
    miss = [name for name, ok in checked if not ok]
    return keep, miss

factor_cols, miss_f = existing_numeric(df, factor_cols_req)
macro_cols, miss_m = existing_numeric(df, macro_cols_req)

# "RF" (presumably the risk-free rate) is kept out of the factor set that is
# analysed below.
factor_core = [name for name in factor_cols if name.upper() != "RF"]

# XF: factor columns, complete rows only (or an empty frame on df's index).
if factor_core:
    XF = df[factor_core].dropna().copy()
else:
    XF = pd.DataFrame(index=df.index)

# XM: macro columns, complete rows only (or an empty frame on df's index).
if macro_cols:
    XM = df[macro_cols].dropna().copy()
else:
    XM = pd.DataFrame(index=df.index)

# XC: factors and macros together, rows complete across both groups; falls
# back to a copy of XF unless both groups are available.
if factor_core and macro_cols:
    XC = df[factor_core + macro_cols].dropna().copy()
else:
    XC = XF.copy()

print("[INFO] Factors:", factor_core)
print("[INFO] Macros :", macro_cols)
if miss_f:
    print("[WARN] Missing/non-numeric (factors):", miss_f)
if miss_m:
    print("[WARN] Missing/non-numeric (macros):", miss_m)

# ----- Correlations -----
def safe_corr(df_, method="pearson"):
    """Return the column-by-column correlation matrix of ``df_``.

    ``method`` is forwarded to ``DataFrame.corr``. A frame without columns
    yields an empty ``DataFrame`` instead of being passed to ``corr``.
    """
    if df_.shape[1] > 0:
        return df_.corr(method=method)
    return pd.DataFrame()

# Within-group Pearson correlation matrices: factors first, then macros.
corrF = safe_corr(XF, method="pearson")
corrM = safe_corr(XM, method="pearson")

def cross_corr(A: pd.DataFrame, B: pd.DataFrame):
    """Correlate every column of ``A`` with every column of ``B``.

    Each cell is ``A[a].corr(B[b])``, so the pairing of observations is
    whatever ``Series.corr`` does for the two series.

    Returns:
        A float ``DataFrame`` with ``A``'s columns as rows and ``B``'s columns
        as columns, or an empty ``DataFrame`` when either input is empty.
    """
    if A.empty or B.empty:
        return pd.DataFrame()

    # One row of coefficients per column of A, in column order.
    grid = [
        [A[a_name].corr(B[b_name]) for b_name in B.columns]
        for a_name in A.columns
    ]
    return pd.DataFrame(grid, index=A.columns, columns=B.columns, dtype=float)
# Factor-by-macro correlation block (rows: factors, columns: macros). XF and
# XM had incomplete rows dropped independently, so their indexes may differ.
crossFM = cross_corr(A=XF, B=XM)

# ----- VIF & Condition number -----
def compute_vif_table(X: pd.DataFrame) -> pd.DataFrame:
    """Build a table of variance inflation factors for the columns of ``X``.

    Columns that are constant or entirely missing are removed, rows with any
    remaining missing value are dropped, and the surviving data is
    standardised (population std, ``ddof=0``) on exactly the rows that are
    used. Each column's VIF is then obtained from statsmodels'
    ``variance_inflation_factor`` on the standardised matrix (no constant
    column is added; the data is already centred).

    Args:
        X: Numeric design matrix, one column per variable.

    Returns:
        A ``DataFrame`` with columns ``variable``, ``VIF`` and ``Tolerance``
        (``1 / VIF``), sorted by ``VIF`` descending. The table is empty (same
        columns, no rows) when fewer than two usable columns remain, since a
        VIF needs at least one other regressor.
    """
    out_cols = ["variable", "VIF", "Tolerance"]
    if X.empty or X.shape[1] < 2:
        return pd.DataFrame(columns=out_cols)

    # NaN > 0 is False, so this drops constant and all-missing columns before
    # they can wipe out every row in the dropna() below.
    varying = X.loc[:, X.std(ddof=0) > 0]
    complete = varying.dropna()

    Xs = (complete - complete.mean()) / complete.std(ddof=0)
    # A column can become constant (-> NaN after scaling) once rows are gone.
    Xs = Xs.loc[:, Xs.std() > 0].copy()
    if Xs.shape[1] < 2:
        return pd.DataFrame(columns=out_cols)

    design = Xs.values
    vifs = [variance_inflation_factor(design, i) for i in range(design.shape[1])]
    tbl = pd.DataFrame({"variable": Xs.columns, "VIF": vifs})
    tbl["Tolerance"] = 1.0 / tbl["VIF"]
    return tbl.sort_values("VIF", ascending=False).reset_index(drop=True)

def condition_number(X: pd.DataFrame) -> float:
    """Return the 2-norm condition number of the standardised columns of ``X``.

    Columns that are constant or entirely missing are removed, rows with any
    remaining missing value are dropped, and the surviving data is
    standardised (population std, ``ddof=0``) on exactly the rows that are
    used. The result is the ratio of the largest to the smallest singular
    value of that matrix.

    Args:
        X: Numeric design matrix, one column per variable.

    Returns:
        ``nan`` when fewer than two usable columns or no complete rows remain,
        ``inf`` when the smallest singular value is exactly zero, otherwise
        ``s.max() / s.min()`` as a Python float.
    """
    if X.empty or X.shape[1] < 2:
        return np.nan

    # NaN > 0 is False, so constant and all-missing columns are dropped here.
    varying = X.loc[:, X.std(ddof=0) > 0]
    complete = varying.dropna()

    Xs = (complete - complete.mean()) / complete.std(ddof=0)
    # A column can become constant (-> NaN after scaling) once rows are gone.
    Xs = Xs.loc[:, Xs.std() > 0]
    if Xs.shape[0] == 0 or Xs.shape[1] < 2:
        return np.nan

    # Only the singular values are needed; skip computing U and Vh.
    s = np.linalg.svd(Xs.values, compute_uv=False)
    return np.inf if (s.min() == 0) else float(s.max() / s.min())

# Collinearity diagnostics for the factor-only (XF) and the combined
# factor + macro (XC) design matrices.
vif_F, vif_C = compute_vif_table(XF), compute_vif_table(XC)
cn_F, cn_C = condition_number(XF), condition_number(XC)

# ----- Plot helpers (ONLY show()) -----
def heatmap_annot(z_df: pd.DataFrame, title: str):
    """Show an annotated heatmap of ``z_df`` with plotly.

    Every cell is labelled with its value to two decimals (blank when the
    value is missing). The colour scale is a reversed "RdBu" centred on 0 and
    the figure grows with the number of rows and columns. An empty ``z_df``
    prints a ``[SKIP]`` line and nothing is drawn.

    Returns:
        None. The figure is only displayed via ``fig.show()``.
    """
    if z_df.empty:
        print(f"[SKIP] Heatmap {title} (empty)")
        return

    def _label(value):
        return f"{value:.2f}" if pd.notna(value) else ""

    col_labels = list(z_df.columns)
    row_labels = list(z_df.index)
    grid = z_df.values
    cell_text = [[_label(value) for value in row] for row in grid]

    trace = go.Heatmap(
        z=grid,
        x=col_labels,
        y=row_labels,
        colorscale="RdBu",
        reversescale=True,
        zmid=0,
        colorbar={"title": "ρ"},
        text=cell_text,
        texttemplate="%{text}",
        hovertemplate="x=%{x}<br>y=%{y}<br>ρ=%{z:.3f}<extra></extra>",
    )
    fig = go.Figure(data=trace)
    fig.update_layout(
        title=title,
        xaxis={"tickangle": 45},
        width=max(700, 70 * len(col_labels) + 220),
        height=max(520, 32 * len(row_labels) + 220),
        margin={"l": 90, "r": 20, "t": 60, "b": 90},
    )
    fig.show()

def table_vif(tbl: pd.DataFrame, title: str):
    """Show a VIF table as a plotly ``Table`` figure.

    ``tbl`` must provide the columns ``variable``, ``VIF`` and ``Tolerance``;
    the two numeric columns are rounded to three decimals for display. An
    empty ``tbl`` prints a ``[SKIP]`` line and nothing is drawn.

    Returns:
        None. The figure is only displayed via ``fig.show()``.
    """
    if tbl.empty:
        print(f"[SKIP] {title} (empty)")
        return

    header = {"values": ["Variable", "VIF", "Tolerance"], "align": "left"}
    cells = {
        "values": [
            tbl["variable"],
            tbl["VIF"].round(3),
            tbl["Tolerance"].round(3),
        ],
        "align": "left",
    }
    fig = go.Figure(data=[go.Table(header=header, cells=cells)])
    fig.update_layout(title=title, width=720, height=420)
    fig.show()

def splom_factors(X: pd.DataFrame, title: str):
    """Show a scatterplot matrix (lower triangle plus diagonal) of ``X``.

    One dimension is drawn per column of ``X``; the square figure grows with
    the number of columns. With fewer than two columns a ``[SKIP]`` line is
    printed and nothing is drawn.

    Returns:
        None. The figure is only displayed via ``fig.show()``.
    """
    n_cols = X.shape[1]
    if n_cols < 2:
        print("[SKIP] SPLOM (need ≥2 factor columns)")
        return

    side = max(700, 180 * n_cols)
    dims = [{"label": col, "values": X[col]} for col in X.columns]
    splom = go.Splom(
        dimensions=dims,
        showupperhalf=False,
        diagonal_visible=True,
        marker={"size": 4, "opacity": 0.65},
    )
    fig = go.Figure(data=splom)
    fig.update_layout(title=title, dragmode="select", width=side, height=side)
    fig.show()

def scatter_with_reg(x, y, xname, yname, title):
    """Show a scatter plot of ``y`` against ``x`` with a fitted straight line.

    Both inputs are converted to float arrays and only positions where both
    values are finite are used. The line is a degree-1 ``np.polyfit``; its
    slope and R² appear in the legend entry. With fewer than three usable
    points a ``[SKIP]`` line is printed and nothing is drawn.

    Args:
        x, y: Array-likes of equal length.
        xname, yname: Axis titles, also used in the hover text.
        title: Figure title.

    Returns:
        None. The figure is only displayed via ``fig.show()``.
    """
    xv = np.asarray(x, dtype=float)
    yv = np.asarray(y, dtype=float)
    usable = np.isfinite(xv) & np.isfinite(yv)
    if usable.sum() < 3:
        print(f"[SKIP] {title} (too few points)")
        return

    x_ok = xv[usable]
    y_ok = yv[usable]

    # polyfit returns the highest-degree coefficient first: slope, intercept.
    slope, intercept = np.polyfit(x_ok, y_ok, 1)
    fitted = intercept + slope * x_ok
    ss_res = ((y_ok - fitted) ** 2).sum()
    ss_tot = ((y_ok - y_ok.mean()) ** 2).sum()
    if ss_tot > 0:
        r2 = 1 - ss_res / ss_tot
    else:
        # y is constant on the usable points, so R² is undefined.
        r2 = np.nan

    points = go.Scatter(
        x=x_ok,
        y=y_ok,
        mode="markers",
        name="Data",
        hovertemplate=f"{xname}=%{{x:.4f}}<br>{yname}=%{{y:.4f}}<extra></extra>",
    )
    line_x = np.linspace(x_ok.min(), x_ok.max(), 100)
    line = go.Scatter(
        x=line_x,
        y=intercept + slope * line_x,
        mode="lines",
        name=f"OLS fit (β={slope:.3f}, R²={r2:.3f})",
        hoverinfo="skip",
    )

    fig = go.Figure()
    fig.add_trace(points)
    fig.add_trace(line)
    fig.update_layout(
        title=title,
        xaxis_title=xname,
        yaxis_title=yname,
        width=680,
        height=500,
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
    )
    fig.show()

# ----- Draw (inline only) -----
heatmap_annot(z_df=corrF, title="Pearson Correlation — Factors")
# The macro and cross heatmaps are skipped without a [SKIP] line when empty.
if not corrM.empty:
    heatmap_annot(z_df=corrM, title="Pearson Correlation — Macros")
if not crossFM.empty:
    heatmap_annot(z_df=crossFM, title="Pearson Cross-Correlation — Factors × Macros")

splom_factors(X=XF, title="Scatterplot Matrix — Factor Returns")

table_vif(tbl=vif_F, title="VIF — Factors only")
table_vif(tbl=vif_C, title="VIF — Combined (Factors + Macros)")

# Factor vs Macro (aligned dates)
# At most the first six macro columns are plotted, each against every factor.
max_macros_to_plot = min(6, len(macro_cols))
for m in macro_cols[:max_macros_to_plot]:
    for f in XF.columns:
        # Keep only the dates present in both series, then drop incomplete rows.
        pair = pd.concat(
            [XF[f].rename(f), XM[m].rename(m)],
            axis=1,
            join="inner",
        )
        pair = pair.dropna()
        if len(pair) >= 3:
            scatter_with_reg(
                x=pair[f],
                y=pair[m],
                xname=f,
                yname=m,
                title=f"{m} vs {f} — with OLS line",
            )
        else:
            print(f"[SKIP] {m} vs {f} (too few aligned points)")


[INFO] Factors: ['Mkt-RF', 'SMB', 'HML', 'RMW', 'CMA']
[INFO] Macros : ['market', 'yield_curve', 'oil ($/bbl)', 'copper ($/metric ton)', 'monetary_policy', 'volatility', 'stock_bond_corr']


In [12]:
!pip install plotly

Collecting plotly
  Downloading plotly-6.3.1-py3-none-any.whl.metadata (8.5 kB)
Collecting narwhals>=1.15.1 (from plotly)
  Downloading narwhals-2.10.0-py3-none-any.whl.metadata (11 kB)
Downloading plotly-6.3.1-py3-none-any.whl (9.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.8/9.8 MB 12.8 MB/s  0:00:00 eta 0:00:01
Downloading narwhals-2.10.0-py3-none-any.whl (418 kB)
Installing collected packages: narwhals, plotly
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2/2 [plotly]
Successfully installed narwhals-2.10.0 plotly-6.3.1

[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: pip install --upgrade pip
