In [None]:
# @title Global Multi-Asset Correlation Lab (Bloomberg tickers, robust parsing, residuals, dedup, locked + unlocked animations)
# Complete one-cell notebook.
# Flow: build an equity universe plus FX / crypto / commodity / ETF / custom groups,
# download adjusted daily closes from Yahoo Finance, compute log returns (optionally
# residualized), then report full-window and month-by-month cross-asset correlations.
# NOTE: the "# @param" trailers below are Colab form annotations — keep them intact.

# =======================
# User Parameters
# =======================
# UNIVERSE: index whose Wikipedia constituent table seeds the equity group;
#   "CUSTOM_EQUITY_ONLY" uses EQUITY_CUSTOM_TICKERS instead (no Wikipedia/market-cap ranking).
# TOP_N_EQUITIES: keep this many index members with the largest market caps.
# MAX_PER_SECTOR: per-sector cap applied after download; 0 disables it.
# START_DATE / END_DATE: download window passed to yfinance.
UNIVERSE = "SP500"          # @param ["SP500","NASDAQ100","DOW30","CUSTOM_EQUITY_ONLY"]
TOP_N_EQUITIES = 80         # @param {type:"integer"}
MAX_PER_SECTOR = 15         # @param {type:"integer"}  # cap equities per sector (0 = no cap)
START_DATE = "2022-01-01"   # @param {type:"date"}
END_DATE   = "2025-09-01"   # @param {type:"date"}

# TOP_K also sets the size of the per-month candidate pool for the unlocked animation
# (max(TOP_RANKS_PER_MONTH, TOP_K) pairs are kept per month).
TOP_K = 30                  # printed top-K (full window)
USE_ABSOLUTE = False        # rank by |rho| if True

# Residualization
# Equity residuals are computed first; the optional PCA step then runs on ALL columns.
USE_EQUITY_RESIDUALS = True     # regress EQ on market + sector ETF
MARKET_PROXY = "SPY"            # use ACWI/VEA/EEM for broader markets if you prefer
USE_PCA_RESIDUAL_K = 0          # remove first K PCs across all assets AFTER equity residuals

# Animation settings
# INTERP_STEPS: number of tweened frames inserted between two consecutive months.
# FPS: MP4 frame rate; the GIF is written at max(10, FPS // 2).
TOP_RANKS_PER_MONTH = 10        # N bars for both videos
FPS = 30
INTERP_STEPS = 12
SAVE_GIF = True

# Visuals
# PALETTE colors the locked animation / start-end PNGs by asset-group pair;
# VIDEO_DPI is the MP4 render DPI (GIFs use a fixed 160).
PALETTE = "tab20"
LABEL_FONTSIZE = 16
VALUE_FONTSIZE = 14
TITLE_FONTSIZE = 20
VIDEO_DPI = 200

# Include groups
# A symbol listed in several groups is kept once, in the first group of the order
# EQ -> ETF -> FX -> CRYPTO -> CMDTY -> CUSTOM.
INCLUDE_EQUITIES = True
INCLUDE_FX       = True
INCLUDE_CRYPTO   = True
INCLUDE_COMMODS  = True
INCLUDE_ETF      = True
INCLUDE_CUSTOM   = True

# Tickers (Yahoo or Bloomberg-style like "DTE GY"; separate with commas/semicolons)
# "DXY" is aliased to the Yahoo futures symbol "DX=F" before download.
FX_TICKERS      = "EURUSD=X, GBPUSD=X, USDJPY=X, DXY"
CRYPTO_TICKERS  = "BTC-USD, ETH-USD, SOL-USD"
COMMOD_TICKERS  = "GC=F, SI=F, CL=F, NG=F"
ETF_TICKERS     = "SPY, QQQ, IWM, TLT, HYG, GLD"
CUSTOM_TICKERS  = "TSLA, NVDA, AAPL, MSFT, META, AVGO, LLY, ^VIX, DTE GY, VOD LN"
# Only used when UNIVERSE == "CUSTOM_EQUITY_ONLY".
EQUITY_CUSTOM_TICKERS = "AAPL, MSFT, NVDA, AMZN, GOOGL, META, AVGO, LLY, TSLA, JPM"

# =======================
# Setup
# =======================
!pip -q install yfinance pandas lxml pillow ffmpeg

import pandas as pd, numpy as np, yfinance as yf, re, os, warnings, concurrent.futures, requests, zlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
from matplotlib.patches import Patch
from itertools import combinations
warnings.filterwarnings("ignore")

# ---------- HTTP helpers ----------
def _fetch_wiki_tables(url: str):
    """Download a Wikipedia page and parse every HTML <table> on it.

    A browser-like User-Agent is sent because Wikipedia may reject the default
    `requests` agent.

    Args:
        url: Page URL.

    Returns:
        list[pd.DataFrame]: one DataFrame per table, in page order.

    Raises:
        requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
    """
    import io  # local import keeps this helper self-contained

    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120 Safari/537.36"}
    r = requests.get(url, headers=headers, timeout=20)
    r.raise_for_status()
    # Passing literal HTML to read_html is deprecated (pandas >= 2.1); wrap it in a file-like buffer.
    return pd.read_html(io.StringIO(r.text))

def _ticker_column(table: pd.DataFrame):
    """Return the first column of *table* labelled "Ticker"/"Symbol" (case-insensitive), else None."""
    for c in table.columns:
        # str() guards against non-string (int / tuple) column labels that read_html can produce.
        if str(c).strip().lower() in ("ticker", "symbol"):
            return c
    return None

def wiki_equity_universe(universe: str) -> pd.DataFrame:
    """Return the constituent tickers of an index, scraped from Wikipedia.

    Args:
        universe: "SP500", "NASDAQ100" or "DOW30" (case-insensitive). Any other value
            returns an empty frame without touching the network.

    Returns:
        DataFrame with a single string column "ticker", de-duplicated. Tickers are
        returned exactly as Wikipedia prints them (e.g. "BRK.B").

    Raises:
        RuntimeError: if no table with a Ticker/Symbol column is found (Nasdaq-100 / Dow).
    """
    uni = universe.upper()
    if uni == "SP500":
        df = _fetch_wiki_tables("https://en.wikipedia.org/wiki/List_of_S%26P_500_companies")[0]
        df = df.rename(columns={"Symbol":"ticker"})
    elif uni in ("NASDAQ100", "DOW30"):
        url, not_found = {
            "NASDAQ100": ("https://en.wikipedia.org/wiki/NASDAQ-100",
                          "No Nasdaq-100 ticker table found."),
            "DOW30": ("https://en.wikipedia.org/wiki/Dow_Jones_Industrial_Average",
                      "No Dow components table found."),
        }[uni]
        df = None
        for t in _fetch_wiki_tables(url):
            col = _ticker_column(t)
            if col is not None:
                # Keep only the matched column so a table carrying both "Ticker" and
                # "Symbol" cannot produce two columns named "ticker".
                df = t[[col]].rename(columns={col: "ticker"})
                break
        if df is None: raise RuntimeError(not_found)
    else:
        return pd.DataFrame({"ticker":[]})
    return pd.DataFrame({"ticker": df["ticker"].astype(str)}).drop_duplicates()

# ---------- Bloomberg → Yahoo ----------
# Bloomberg exchange code (or an already-Yahoo suffix) -> Yahoo Finance suffix.
# Bloomberg "CN" is the Canadian composite (e.g. "RY CN"); China is "CH".
BBG_TO_YF = {
    "GY":"DE","GR":"DE","DE":"DE", "LN":"L","L":"L", "FP":"PA","PA":"PA", "NA":"AS","AS":"AS",
    "IM":"MI","MI":"MI", "SM":"MC","MC":"MC", "SW":"SW","VX":"SW", "HK":"HK", "JP":"T","T":"T",
    "AU":"AX","AX":"AX", "CA":"TO","TO":"TO", "BR":"SA","BZ":"SA", "KS":"KS","KQ":"KQ",
    "CH":"SS","CN":"TO", "TW":"TW",
}

# Tokens that are ONLY Bloomberg exchange codes (never a Yahoo suffix). A bare one of these
# in a user list is almost certainly a split mistake ("DTE, GY") and is dropped. Tokens that
# double as Yahoo suffixes ("T", "L", "DE", ...) are also real US tickers (AT&T, Loews,
# Deere), so they are NOT dropped.
_BARE_EXCHANGE_TOKENS = frozenset(BBG_TO_YF) - frozenset(BBG_TO_YF.values())
# "BRK.B" / "BF.A" style share classes; Yahoo spells them with a dash.
_SHARE_CLASS_RE = re.compile(r"^([A-Z0-9]+)\.([AB])$")
# Bloomberg "NAME EX" (or "NAME.EX") with a 1-3 letter exchange code.
_BBG_SYMBOL_RE = re.compile(r"^([A-Z0-9\-]+)[\s\.]([A-Z]{1,3})$")

def parse_list(s: str):
    """
    Split ONLY on commas/semicolons; keep internal spaces so 'DTE GY' stays intact.
    Runs of whitespace are collapsed to one space and empty items are dropped.
    """
    if not s: return []
    parts = re.split(r"[;,]+", s)
    parts = [re.sub(r"\s+", " ", p).strip() for p in parts]
    return [p for p in parts if p]

TICKER_ALIASES = {"DXY":"DX=F","^DXY":"DX=F","USDOLLAR":"DX=F"}
def apply_aliases(symbols):
    """Replace known alias symbols (case-insensitive) by their Yahoo equivalent; others pass through."""
    return [TICKER_ALIASES.get(s.upper(), s) for s in symbols]

def normalize_symbol_to_yahoo(sym: str):
    """
    Convert Bloomberg-style 'NAME EX' to Yahoo 'NAME.SUF' (e.g. 'DTE GY' -> 'DTE.DE').

    - Returns None for empty input and for a bare Bloomberg-only exchange token
      (e.g. 'GY', 'LN').
    - Converts share classes 'BRK.B' -> 'BRK-B'.
    - Passes through strings that already look like Yahoo tickers ('VOD.L', 'EURUSD=X',
      'BTC-USD', 'GC=F', '^VIX').
    - 'NAME EX' with an unknown exchange code (e.g. 'T US') returns just 'NAME'.
    Input is upper-cased; '·' is turned into '-'.
    """
    s = (sym or "").strip().replace("·","-").upper()
    if not s: return None
    if s in _BARE_EXCHANGE_TOKENS:
        return None
    m = _SHARE_CLASS_RE.match(s)
    if m:
        return f"{m.group(1)}-{m.group(2)}"
    if "." in s or s.endswith(("=X","-USD","^VIX","=F")):
        return s
    m = _BBG_SYMBOL_RE.match(s)
    if m:
        base, ex = m.group(1), m.group(2)
        suf = BBG_TO_YF.get(ex)
        # Only the exchange suffix is appended; the suffix itself must never be rewritten
        # (".AS"/".AX" used to be mangled into "-AS"/"-AX").
        return f"{base}.{suf}" if suf else base  # unknown exchange code -> best effort
    return s

def normalize_list_to_yahoo(symbols):
    """Normalize each symbol to Yahoo form, dropping those the normalizer rejects (None)."""
    # No second BBG_TO_YF filter here: it used to drop valid results such as "T US" -> "T".
    return [y for y in map(normalize_symbol_to_yahoo, symbols) if y]

# ---------- Data helpers ----------
def get_market_cap(ticker: str) -> float:
    try:
        t = yf.Ticker(ticker)
        mc = getattr(t.fast_info, "market_cap", None)
        if mc and mc>0: return float(mc)
        info = t.info or {}; mc = info.get("marketCap", np.nan)
        if mc and mc>0: return float(mc)
    except Exception: pass
    return np.nan

def get_sector_industry(ticker: str):
    try:
        info = yf.Ticker(ticker).info or {}
        return info.get("sector","Unknown") or "Unknown", info.get("industry","Unknown") or "Unknown"
    except Exception: return "Unknown","Unknown"

def rank_by_market_cap(tickers: list, max_workers: int=20) -> pd.DataFrame:
    """Fetch market caps concurrently and return them sorted largest-first.

    Args:
        tickers: Yahoo symbols.
        max_workers: thread-pool size for the (I/O-bound) lookups.

    Returns:
        DataFrame with columns ["ticker", "market_cap"]. Unknown caps are NaN and sort
        last. Ties are broken alphabetically, so the order does not depend on which
        thread finished first. An empty input yields an empty frame with these columns.
    """
    tickers = list(tickers)
    if not tickers:
        return pd.DataFrame({"ticker": pd.Series(dtype=object), "market_cap": pd.Series(dtype=float)})
    rows = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
        jobs = [(tk, ex.submit(get_market_cap, tk)) for tk in tickers]
        for tk, fut in jobs:
            try: mc = fut.result()
            except Exception: mc = np.nan  # one failed lookup must not abort the ranking
            rows.append({"ticker": tk, "market_cap": mc})
    return (pd.DataFrame(rows)
            .sort_values(["market_cap", "ticker"], ascending=[False, True], na_position="last")
            .reset_index(drop=True))

def filter_symbols_with_prices(symbols, sample_period="5d"):
    """Keep only symbols for which Yahoo returns at least one recent bar.

    Args:
        symbols: Yahoo symbols (already normalized upstream).
        sample_period: history window used for the probe (yfinance period string).

    Returns:
        The symbols with data, in input order. If none pass (e.g. no network), the
        input is returned unchanged so the bulk download can still try.
    """
    # No BBG_TO_YF check here: by this point the symbols are Yahoo tickers, and that check
    # silently removed real constituents such as DE (Deere), T (AT&T) and L (Loews).
    ok = []
    for s in symbols:
        try:
            hist = yf.Ticker(s).history(period=sample_period, auto_adjust=True)
        except Exception:
            continue  # treat lookup errors as "no data"
        if not hist.empty:
            ok.append(s)
    return ok or symbols

def download_prices(tickers: list, start: str, end: str) -> pd.DataFrame:
    """Download daily auto-adjusted closes for *tickers* as one column per ticker.

    Handles both yfinance layouts: MultiIndex (ticker, field) columns, and the flat
    OHLCV frame some versions return for a single ticker. Tickers missing from the
    download are skipped.

    Returns:
        DataFrame indexed by date with rows that are entirely NaN removed, sorted by
        date. An empty ticker list returns an empty DataFrame without a network call.
    """
    tickers = list(tickers)
    if not tickers:
        return pd.DataFrame()
    data = yf.download(tickers=tickers, start=start, end=end, auto_adjust=True,
                       progress=False, group_by="ticker", threads=True, interval="1d")
    if isinstance(data.columns, pd.MultiIndex):
        closes=[]
        for tk in tickers:
            try: closes.append(data[tk]["Close"].rename(tk))
            except KeyError: pass  # ticker absent from the download
        prices = pd.concat(closes, axis=1) if closes else pd.DataFrame()
    elif len(tickers) == 1 and "Close" in data.columns:
        # Flat single-ticker OHLCV frame: keep only the close, named after the ticker
        # (previously the whole OHLCV frame leaked through as if it were prices).
        prices = data[["Close"]].rename(columns={"Close": tickers[0]})
    else:
        prices = data
    return prices.dropna(how="all").sort_index()

# ---------- Pair helpers ----------
def base_symbol(label: str):
    """Strip the "GROUP:" prefix from a label ("EQ:AAPL" -> "AAPL"); unprefixed labels pass through."""
    head, sep, tail = label.partition(":")
    return tail if sep else head

def drop_same_underlying_pairs(df_pairs: pd.DataFrame, col1="col_1", col2="col_2"):
    """Remove rows whose two labels refer to the same underlying symbol.

    Returns the input object itself when it is empty; otherwise a filtered copy with
    a fresh 0..n-1 index.
    """
    if df_pairs.empty:
        return df_pairs
    left = df_pairs[col1].map(base_symbol)
    right = df_pairs[col2].map(base_symbol)
    keep = (left != right).to_numpy()
    return df_pairs[keep].reset_index(drop=True)

# ---------- Correlations ----------
def top_k_correlations(corr: pd.DataFrame, k: int=20, use_abs: bool=False) -> pd.DataFrame:
    """Return the k strongest distinct pairs of a (square, symmetric) correlation matrix.

    Each unordered pair appears once: upper triangle, diagonal excluded, row-major order.
    NaN correlations and pairs sharing the same underlying symbol are discarded.
    Rows are sorted by rho, or by |rho| when *use_abs*, descending.

    Returns:
        DataFrame with columns ["col_1", "col_2", "correlation"] (at most k rows).
    """
    # Index the upper triangle directly instead of writing NaNs into corr.values:
    # under pandas Copy-on-Write that array is read-only. It also avoids relying on
    # DataFrame.stack() dropping NaNs, which newer pandas no longer does.
    rows, cols = np.triu_indices(len(corr.columns), k=1)
    pairs = pd.DataFrame({
        "col_1": corr.index[rows],
        "col_2": corr.columns[cols],
        "correlation": corr.to_numpy(dtype=float)[rows, cols],
    })
    pairs = pairs.dropna(subset=["correlation"]).reset_index(drop=True)
    pairs = drop_same_underlying_pairs(pairs, "col_1", "col_2")
    key = pairs["correlation"].abs() if use_abs else pairs["correlation"]
    return pairs.assign(_rk=key).sort_values("_rk", ascending=False).drop(columns="_rk").head(k)

def monthly_all_corrs(returns: pd.DataFrame) -> dict:
    """Map each calendar month (Timestamp of its first day) to that month's correlation matrix.

    Months with fewer than 5 observations are skipped. Keys are in chronological order.
    """
    return {
        period.to_timestamp(): block.corr()
        for period, block in returns.groupby(returns.index.to_period("M"))
        if len(block) >= 5
    }

def monthly_top_pairs(returns: pd.DataFrame, top_n: int=10, use_abs: bool=False) -> pd.DataFrame:
    """Collect each month's top_n correlation pairs into one long table.

    Months with fewer than 5 observations are skipped.

    Returns:
        DataFrame with columns ["month", "col_1", "col_2", "correlation"]. "month" is
        the Timestamp of the month's first day. The frame is empty (same columns)
        when no month qualifies.
    """
    columns = ["month", "col_1", "col_2", "correlation"]
    per_month = []
    for period, block in returns.groupby(returns.index.to_period("M")):
        if len(block) < 5:
            continue
        ranked = top_k_correlations(block.corr(), k=top_n, use_abs=use_abs)
        ranked["month"] = period.to_timestamp()
        per_month.append(ranked)
    if not per_month:
        return pd.DataFrame(columns=columns)
    stacked = drop_same_underlying_pairs(pd.concat(per_month, ignore_index=True), "col_1", "col_2")
    return stacked[columns]

# ---------- Groups & heatmap ----------
def get_label_group(label: str):
    """Return a label's group prefix ("EQ:AAPL" -> "EQ"), or "OTHER" when it has no ':'."""
    prefix, sep, _ = label.partition(":")
    return prefix if sep else "OTHER"

def get_pair_group(pair_label: str):
    """Return the alphabetically sorted group combination of an "A–B" pair label, e.g. "EQ×ETF"."""
    left, right = pair_label.split("–", 1)
    return "×".join(sorted((get_label_group(left), get_label_group(right))))

def by_group(label):
    """Sort key ordering labels EQ, ETF, FX, CRYPTO, CMDTY, CUSTOM, OTHER, unknown groups (rank 9); then by label."""
    rank = {"EQ":0,"ETF":1,"FX":2,"CRYPTO":3,"CMDTY":4,"CUSTOM":5,"OTHER":6}.get(get_label_group(label), 9)
    return (rank, label)

def plot_correlation_heatmap(corr: pd.DataFrame, col_order, groups, outfile="/content/correlation_heatmap.png"):
    """Save a coolwarm correlation heatmap with a labelled box around each asset group.

    Args:
        corr: correlation matrix indexed by the labels in *col_order*.
        col_order: labels to plot, in display order.
        groups: {group name: member labels}. Boxes are laid out cumulatively, so members
            are presumably contiguous in *col_order* and in the same group order.
            Empty groups are skipped.
        outfile: PNG path (saved at 240 dpi).

    Side effect: sets the global matplotlib rcParams font size to 12.
    """
    plt.rcParams.update({"font.size": 12})
    n = len(col_order)
    side = min(20, 2 + 0.35 * n)
    matrix = corr.loc[col_order, col_order].values
    fig, ax = plt.subplots(figsize=(side, side))
    image = ax.imshow(matrix, vmin=-1, vmax=1, cmap="coolwarm", interpolation="nearest")
    ax.set_xticks(range(n))
    ax.set_xticklabels(col_order, rotation=90, fontsize=10)
    ax.set_yticks(range(n))
    ax.set_yticklabels(col_order, fontsize=10)
    fig.colorbar(image, ax=ax, fraction=0.03, pad=0.02, label="Correlation")

    start = 0
    for name, members in groups.items():
        if not members:
            continue
        size = len(members)
        centre = start + size / 2 - 0.5
        ax.add_patch(plt.Rectangle((start - 0.5, start - 0.5), size, size, fill=False, ec="k", lw=1.5, alpha=0.9))
        ax.text(centre, -1.3, name, ha="center", va="top", fontsize=12, fontweight="bold")
        ax.text(-1.3, centre, name, ha="right", va="center", fontsize=12, fontweight="bold", rotation=90)
        start += size

    fig.tight_layout()
    fig.savefig(outfile, dpi=240, bbox_inches="tight")
    plt.close(fig)

# ---------- Animations ----------
def ease_in_out_sine(t):
    """Ease-in-out sine curve on [0, 1] (0 -> 0, 0.5 -> 0.5, 1 -> 1); accepts scalars or arrays."""
    return (1 - np.cos(np.pi * t)) / 2

# (A) LOCKED: final-month Top-N set
def animate_monthly_race_locked_final_group_colors(
    monthly_corrs: dict, locked_labels: list, out_mp4: str, out_gif: str=None,
    fps: int=30, interp_steps: int=12, title_prefix="Top-10 Correlations (Locked Final Set)",
    palette="tab20", label_fs=16, value_fs=14, title_fs=20, dpi=200
):
    """Render a bar-chart race of a fixed ("locked") set of pair correlations.

    For every month, each locked pair's correlation is read from that month's matrix
    and the pairs are ranked (by |rho| when the global USE_ABSOLUTE is True). Between
    consecutive months, values and ranks are blended over ``interp_steps`` frames with an
    ease-in-out-sine curve. Bars are coloured by asset-group pair (e.g. "EQ×ETF").

    Args:
        monthly_corrs: {month Timestamp: correlation DataFrame}, e.g. from monthly_all_corrs().
        locked_labels: pair labels "A–B" (en dash); A and B index the monthly matrices.
            A pair that cannot be looked up in a month is drawn as 0.0.
        out_mp4: MP4 path; if FFmpeg encoding fails a message is printed instead.
        out_gif: optional GIF path (Pillow writer, fps=max(10, fps//2), 160 dpi).
        fps, interp_steps: frame rate and number of tween frames between two months.
        title_prefix, palette, label_fs, value_fs, title_fs, dpi: cosmetics.
    """
    if not monthly_corrs: print("No monthly data for animation."); return
    if not locked_labels: print("No locked pairs for animation."); return
    months = sorted(monthly_corrs.keys())
    n = len(locked_labels)

    def value_for_pair(cdf, pair):
        a, b = pair.split("–", 1)
        try: return float(cdf.loc[a, b])
        except Exception: return 0.0

    def slot_positions(ranks):
        """Map (possibly fractional) ranks to integer bar slots 0..n-1 (slot 0 = best rank)."""
        slots = np.empty(n, dtype=float)
        slots[np.argsort(ranks, kind="mergesort")] = np.arange(n)
        return slots

    # Per-month values and 1-based ranks (1 = strongest) for every locked pair, by label index.
    month_vals, month_ranks = [], []
    for m in months:
        vals = np.array([value_for_pair(monthly_corrs[m], L) for L in locked_labels], dtype=float)
        key = np.abs(vals) if USE_ABSOLUTE else vals
        order = np.argsort(-key, kind="mergesort")
        ranks = np.empty(n, dtype=float); ranks[order] = np.arange(1, n + 1)
        month_vals.append(vals); month_ranks.append(ranks)

    # Keyframe per month plus interp_steps tweens. The blend fraction t/(interp_steps+1) stays
    # strictly inside (0, 1), so the last tween no longer duplicates the next keyframe.
    frames_vals, frames_ranks, frames_midx = [], [], []
    for i in range(len(months) - 1):
        v0, v1 = month_vals[i], month_vals[i + 1]
        r0, r1 = month_ranks[i], month_ranks[i + 1]
        frames_vals.append(v0); frames_ranks.append(r0); frames_midx.append(i)
        for t in range(1, interp_steps + 1):
            frac = t / (interp_steps + 1)
            a = ease_in_out_sine(frac)
            frames_vals.append((1 - a) * v0 + a * v1)
            frames_ranks.append((1 - a) * r0 + a * r1)
            frames_midx.append(i + frac)
    frames_vals.append(month_vals[-1]); frames_ranks.append(month_ranks[-1]); frames_midx.append(len(months) - 1)

    # Colors by pair-group
    pair_groups = [get_pair_group(L) for L in locked_labels]
    uniq_groups = sorted(set(pair_groups))
    cmap = plt.get_cmap(palette, len(uniq_groups))
    group_to_color = {g: cmap(i) for i, g in enumerate(uniq_groups)}
    colors = [group_to_color[g] for g in pair_groups]
    legend_handles = [Patch(color=group_to_color[g], label=g) for g in uniq_groups]

    # Symmetric x-range from the largest finite |value| (falls back to 1.0 if none / all zero).
    magnitudes = np.abs(np.vstack(frames_vals))
    finite = magnitudes[np.isfinite(magnitudes)]
    max_abs = float(finite.max()) if finite.size and finite.max() > 0 else 1.0
    xlim = (-max_abs, max_abs)
    label_x = xlim[0] - 0.02 * (xlim[1] - xlim[0])

    def value_x(v): return v + (0.01 if v >= 0 else -0.01)

    plt.rcParams.update({"font.size": label_fs})
    fig, ax = plt.subplots(figsize=(16, 9))
    ax.set_xlim(*xlim); ax.axvline(0, color="k", lw=1, alpha=0.8)
    ax.grid(True, axis="x", linestyle="--", linewidth=0.6, alpha=0.5)
    ax.legend(handles=legend_handles, loc="lower right", frameon=True, fontsize=12)
    ax.set_title(f"{title_prefix} — {months[0].strftime('%Y-%m')}", fontsize=title_fs, pad=14)

    # One bar, one name text and one value text per LOCKED LABEL (index j). draw() updates
    # artist j with label j's data and moves all three together. Previously the bars were
    # created in rank order but updated by label index, which scrambled colours, and the
    # name texts never moved.
    slots0 = slot_positions(frames_ranks[0])
    bars, name_texts, value_texts = [], [], []
    for j, L in enumerate(locked_labels):
        v, yy = float(frames_vals[0][j]), float(slots0[j])
        bars.append(ax.barh(yy, v, color=colors[j])[0])
        name_texts.append(ax.text(label_x, yy, L, va="center", ha="left",
                                  fontsize=label_fs, fontweight="bold"))
        value_texts.append(ax.text(value_x(v), yy, f"{v:+.2f}", va="center",
                                   ha=("left" if v >= 0 else "right"), fontsize=value_fs))
    ax.set_yticks([])

    def draw(i):
        vals, slots = frames_vals[i], slot_positions(frames_ranks[i])
        for j in range(n):
            v, yy = float(vals[j]), float(slots[j])
            bars[j].set_width(v); bars[j].set_y(yy - bars[j].get_height() / 2)
            name_texts[j].set_y(yy)
            value_texts[j].set_position((value_x(v), yy))
            value_texts[j].set_text(f"{v:+.2f}")
            value_texts[j].set_ha("left" if v >= 0 else "right")
        m_idx = int(round(frames_midx[i]))
        ax.set_title(f"{title_prefix} — {months[m_idx].strftime('%Y-%m')}", fontsize=title_fs, pad=14)
        return bars + name_texts + value_texts

    anim = FuncAnimation(fig, draw, frames=len(frames_vals), blit=False, repeat=False)
    try:
        try:
            writer = FFMpegWriter(fps=fps, bitrate=6000); anim.save(out_mp4, writer=writer, dpi=dpi)
            print(f"Saved (LOCKED): {out_mp4}")
        except Exception as e:
            print("FFmpeg not available, skipping MP4:", e)
        if out_gif:
            try:
                writer = PillowWriter(fps=max(10, fps//2)); anim.save(out_gif, writer=writer, dpi=160)
                print(f"Saved GIF (LOCKED): {out_gif}")
            except Exception as e:
                print("GIF encode failed:", e)
    finally:
        # Close only after both encoders ran (the figure used to be closed before the GIF save).
        plt.close(fig)

# (B) UNLOCKED — each month’s Top-N; stable colors per PAIR across all months
def animate_monthly_race_unlocked_dynamic(
    monthly_pairs_df: pd.DataFrame, out_mp4: str, out_gif: str=None,
    top_n: int=10, fps: int=30, interp_steps: int=12,
    title_prefix="Top-10 Correlations (Unlocked per Month)",
    palette_base=("tab20","tab20b","tab20c"),
    label_fs=16, value_fs=14, title_fs=20, dpi=200, show_legend=False
):
    """Render a bar-chart race where every frame shows the current Top-N pairs.

    Each month contributes its own pair set (from *monthly_pairs_df*). Between months,
    values of the union of both months' pairs are blended with an ease-in-out-sine
    curve; a pair absent from a month counts as 0.0. Each frame shows the top_n pairs
    by value (by |value| when the global USE_ABSOLUTE is True). Every pair label gets a
    deterministic colour (adler32 hash into a concatenated palette), so it keeps its
    colour across months and runs.

    Args:
        monthly_pairs_df: columns "month", "col_1", "col_2", "correlation", e.g. from
            monthly_top_pairs().
        out_mp4, out_gif: output paths (GIF optional); encoder failures are printed.
        top_n, fps, interp_steps: bars per frame, frame rate, tween frames per month.
        title_prefix, palette_base, label_fs, value_fs, title_fs, dpi, show_legend: cosmetics.
    """
    import matplotlib.colors as mcolors

    if monthly_pairs_df.empty:
        print("No monthly pairs for dynamic animation."); return
    # DatetimeIndex guarantees Timestamp elements (with .strftime), whatever array type
    # Series.unique() returns in the installed pandas version.
    months = sorted(pd.DatetimeIndex(monthly_pairs_df["month"].unique()))

    def mk_label(c1, c2): return f"{c1}–{c2}"

    # Per-month {pair label: correlation} maps
    month_maps = []
    for m in months:
        d = monthly_pairs_df[monthly_pairs_df["month"] == m]
        month_maps.append({mk_label(a, b): float(v)
                           for a, b, v in zip(d["col_1"], d["col_2"], d["correlation"])})

    # Large palette + deterministic colour mapping (HSV fallback if no palette was given).
    palette_colors = []
    for pname in palette_base:
        cmap = plt.get_cmap(pname, 20)
        palette_colors.extend(cmap(i) for i in range(cmap.N))
    def color_for_pair(L):
        h = zlib.adler32(L.encode("utf-8"))
        if palette_colors:
            return palette_colors[h % len(palette_colors)]
        return mcolors.hsv_to_rgb(((h % 360) / 360.0, 0.65, 0.9))
    pair_to_color = {L: color_for_pair(L) for L in sorted({L for mp in month_maps for L in mp})}

    # Global symmetric x-range from the largest finite |value| (1.0 if none / all zero).
    all_vals = np.abs(np.array([v for mp in month_maps for v in mp.values()], dtype=float))
    if all_vals.size == 0: print("No values for animation."); return
    finite = all_vals[np.isfinite(all_vals)]
    max_abs = float(finite.max()) if finite.size and finite.max() > 0 else 1.0
    xlim = (-max_abs, max_abs)
    label_x = xlim[0] - 0.02 * (xlim[1] - xlim[0])

    # Frames: keyframe per month, then tweens. t/(interp_steps+1) keeps the last tween
    # strictly before the next keyframe instead of duplicating it.
    frames = []
    for i in range(len(months) - 1):
        cur, nxt = month_maps[i], month_maps[i + 1]
        union = sorted(set(cur) | set(nxt))
        frames.append(("static", i, cur))
        for t in range(1, interp_steps + 1):
            a = ease_in_out_sine(t / (interp_steps + 1))
            frames.append(("tween", i, {L: (1 - a) * cur.get(L, 0.0) + a * nxt.get(L, 0.0) for L in union}))
    frames.append(("static", len(months) - 1, month_maps[-1]))

    plt.rcParams.update({"font.size": label_fs})
    fig, ax = plt.subplots(figsize=(16, 9))
    ax.set_xlim(*xlim); ax.axvline(0, color="k", lw=1, alpha=0.8)
    ax.grid(True, axis="x", linestyle="--", linewidth=0.6, alpha=0.5)
    rank_key = (lambda kv: abs(kv[1])) if USE_ABSOLUTE else (lambda kv: kv[1])

    def draw(frame_idx):
        _, month_idx, vals_map = frames[frame_idx]
        top = sorted(vals_map.items(), key=rank_key, reverse=True)[:top_n]
        labels = [L for L, _ in top]
        vals = [v for _, v in top]

        ax.clear()
        ax.set_xlim(*xlim); ax.axvline(0, color="k", lw=1, alpha=0.8)
        ax.grid(True, axis="x", linestyle="--", linewidth=0.6, alpha=0.5)
        ax.set_title(f"{title_prefix} — {months[month_idx].strftime('%Y-%m')}", fontsize=title_fs, pad=14)

        bars = ax.barh(np.arange(len(labels)), vals, color=[pair_to_color[L] for L in labels])
        ax.set_yticks([])

        for b, L, v in zip(bars, labels, vals):
            y0 = b.get_y() + b.get_height() / 2
            ax.text(label_x, y0, L, va="center", ha="left", fontsize=label_fs, fontweight="bold")
            ax.text(v + (0.01 if v >= 0 else -0.01), y0, f"{v:+.2f}", va="center",
                    ha=("left" if v >= 0 else "right"), fontsize=value_fs)

        if show_legend and labels:
            leg_handles = [Patch(color=pair_to_color[L], label=L) for L in labels]
            ax.legend(handles=leg_handles, loc="lower right", frameon=True, fontsize=10)

        return bars

    anim = FuncAnimation(fig, draw, frames=len(frames), blit=False, repeat=False)
    try:
        try:
            writer = FFMpegWriter(fps=fps, bitrate=6000); anim.save(out_mp4, writer=writer, dpi=dpi)
            print(f"Saved (UNLOCKED, stable colors): {out_mp4}")
        except Exception as e:
            print("FFmpeg unavailable, skipping MP4:", e)
        if out_gif:
            try:
                writer = PillowWriter(fps=max(10, fps//2)); anim.save(out_gif, writer=writer, dpi=160)
                print(f"Saved GIF (UNLOCKED, stable colors): {out_gif}")
            except Exception as e:
                print("GIF encode failed:", e)
    finally:
        plt.close(fig)

# ---------- Residualization ----------
# Sector name -> SPDR sector ETF used as a regressor in residualize_equity_returns.
# Keys cover both GICS names and Yahoo Finance's own sector names (e.g. "Technology",
# "Healthcare", "Consumer Cyclical"). Sectors come from yfinance .info["sector"], so with
# GICS-only keys many stocks silently got no sector regressor.
SECTOR_ETFS = {
    # GICS naming
    "Communication Services":"XLC", "Consumer Discretionary":"XLY", "Consumer Staples":"XLP",
    "Energy":"XLE", "Financial Services":"XLF","Financials":"XLF", "Health Care":"XLV",
    "Industrials":"XLI", "Information Technology":"XLK", "Materials":"XLB",
    "Real Estate":"XLRE", "Utilities":"XLU",
    # Yahoo Finance naming
    "Technology":"XLK", "Healthcare":"XLV", "Consumer Cyclical":"XLY",
    "Consumer Defensive":"XLP", "Basic Materials":"XLB",
}
def regress_residual(y, X):
    """Return y minus its least-squares fit on [1, X] (an intercept column is prepended to X)."""
    design = np.column_stack((np.ones(len(X)), X))
    coef = np.linalg.lstsq(design, y, rcond=None)[0]
    fitted = design @ coef
    return y - fitted
def residualize_equity_returns(R, label_to_sym):
    """Replace each "EQ:" return column by its residual vs market and sector-ETF returns.

    For every equity, OLS (with intercept) is run over the full window on MARKET_PROXY
    log returns and, when the sector maps through SECTOR_ETFS, on that sector ETF's log
    returns. Helper prices are downloaded for START_DATE..END_DATE; helper returns
    missing on a date are treated as 0.0. Non-equity columns are left unchanged.

    Args:
        R: returns DataFrame with labelled columns ("EQ:AAPL", "ETF:SPY", ...).
        label_to_sym: label -> Yahoo symbol, used for sector lookups.

    Returns:
        A new DataFrame (R is not modified).
    """
    eq_labels = [c for c in R.columns if c.startswith("EQ:")]
    if not eq_labels:
        # Nothing to residualize: skip the sector lookups and helper download entirely.
        return R.copy()

    sector_by_label = {}
    needed = {MARKET_PROXY}
    for lab in eq_labels:
        sec, _ = get_sector_industry(label_to_sym[lab])
        sector_by_label[lab] = sec
        etf = SECTOR_ETFS.get(sec)
        if etf: needed.add(etf)

    H = download_prices(sorted(needed), START_DATE, END_DATE).dropna(axis=0, how="any")
    Hret = np.log(H).diff().dropna(how="any")

    R2 = R.copy()
    for lab in eq_labels:
        y = R2[lab].dropna(); Xcols = []
        if MARKET_PROXY in Hret.columns:
            Xcols.append(Hret[MARKET_PROXY].reindex(y.index).fillna(0.0).values)
        etf = SECTOR_ETFS.get(sector_by_label.get(lab, ""))
        if etf and etf in Hret.columns:
            Xcols.append(Hret[etf].reindex(y.index).fillna(0.0).values)
        if Xcols:
            R2.loc[y.index, lab] = regress_residual(y.values, np.column_stack(Xcols))
    return R2
def pca_residualize(R, k=0):
    """Remove the first k principal components (of standardized returns) from every column.

    Missing values are filled with 0. Columns are standardized to zero mean / unit
    variance, the rank-k PCA reconstruction is subtracted, and the residual is mapped
    back to the original units with the SAME mean/std used to standardize. The old code
    de-standardized with pandas' ddof=1, NaN-skipping stats, which rescaled every column.

    Args:
        R: returns DataFrame.
        k: number of leading components to remove; k <= 0 returns R itself unchanged.

    Returns:
        DataFrame with R's index and columns.
    """
    if k <= 0: return R
    X = R.fillna(0.0).to_numpy(dtype=float)
    mu = X.mean(axis=0)
    scale = X.std(axis=0) + 1e-12  # epsilon avoids division by zero for constant columns
    Z = (X - mu) / scale
    U, S, Vt = np.linalg.svd(Z, full_matrices=False)
    k = min(k, S.size)
    Z_res = Z - (U[:, :k] * S[:k]) @ Vt[:k, :]
    return pd.DataFrame(Z_res * scale + mu, index=R.index, columns=R.columns)

# ---------- Duplicate equity handling ----------
DUPLICATE_CORR_THRESHOLD = 0.995
PREFERRED_SHARECLASS = {"GOOGL":"GOOG", "BRK-B":"BRK-B", "BRK.B":"BRK-B"}
def drop_near_duplicate_assets(R, label_to_sym, market_caps):
    """Drop one member of each pair of near-identical equity return series.

    Two "EQ:" columns are duplicates when |corr| >= DUPLICATE_CORR_THRESHOLD (e.g. share
    classes of one company). The larger market cap is kept. If either cap is unknown,
    PREFERRED_SHARECLASS decides, and otherwise the shorter label is kept. A series that
    has already been dropped is not compared again, so it cannot knock out a third
    series that is not itself a duplicate of any kept one.

    Args:
        R: returns DataFrame with labelled columns; non-"EQ:" columns are untouched.
        label_to_sym: label -> symbol. Dropped labels are removed from it IN PLACE.
        market_caps: symbol -> market cap (NaN/absent = unknown).

    Returns:
        (R without the dropped columns, label_to_sym)
    """
    labels = [c for c in R.columns if c.startswith("EQ:")]
    if len(labels) < 2: return R, label_to_sym
    abs_corr = R[labels].corr().abs()
    cap_by = {lab: market_caps.get(label_to_sym[lab], np.nan) for lab in labels}
    to_drop = set()
    for i, j in combinations(labels, 2):
        if i in to_drop or j in to_drop:
            continue  # already resolved against a kept twin
        v = abs_corr.loc[i, j]
        if pd.isna(v) or v < DUPLICATE_CORR_THRESHOLD:
            continue
        ci, cj = cap_by.get(i, np.nan), cap_by.get(j, np.nan)
        if pd.notna(ci) and pd.notna(cj):
            keep, drop = ((i, j) if ci >= cj else (j, i))
        else:
            bi, bj = base_symbol(i), base_symbol(j)
            if bi in PREFERRED_SHARECLASS and PREFERRED_SHARECLASS[bi] in {bi, bj}:
                keep, drop = (i, j) if bi == PREFERRED_SHARECLASS[bi] else (j, i)
            elif bj in PREFERRED_SHARECLASS and PREFERRED_SHARECLASS[bj] in {bi, bj}:
                keep, drop = (i, j) if bj == PREFERRED_SHARECLASS[bj] else (j, i)
            else:
                keep, drop = ((i, j) if len(i) <= len(j) else (j, i))
        to_drop.add(drop)
    if to_drop:
        R = R.drop(columns=list(to_drop), errors="ignore")
        for d in to_drop: label_to_sym.pop(d, None)
    return R, label_to_sym

# =======================
# Pipeline
# =======================
print("Building instrument universe...")

# 1) Equities universe
#    - CUSTOM_EQUITY_ONLY: the user's explicit list (Bloomberg-style accepted)
#    - otherwise: the TOP_N_EQUITIES largest index members by market cap
equity_list = []
if INCLUDE_EQUITIES and UNIVERSE == "CUSTOM_EQUITY_ONLY":
    equity_list = normalize_list_to_yahoo(apply_aliases(parse_list(EQUITY_CUSTOM_TICKERS)))
elif INCLUDE_EQUITIES:
    index_members = wiki_equity_universe(UNIVERSE)["ticker"].astype(str).tolist()
    # Yahoo spells share classes with '-' (BRK.B -> BRK-B)
    yahoo_members = [tk.upper().replace(".", "-") for tk in index_members]
    by_cap = rank_by_market_cap(yahoo_members).dropna(subset=["market_cap"])
    equity_list = by_cap["ticker"].head(TOP_N_EQUITIES).tolist()

# 2) Other groups (accept Bloomberg-style)
def _yahoo_symbols(raw, enabled):
    """Parse, alias and Yahoo-normalize a comma/semicolon ticker string; [] if the group is disabled."""
    return normalize_list_to_yahoo(apply_aliases(parse_list(raw))) if enabled else []

fx_list     = _yahoo_symbols(FX_TICKERS, INCLUDE_FX)
crypto_list = _yahoo_symbols(CRYPTO_TICKERS, INCLUDE_CRYPTO)
cmd_list    = _yahoo_symbols(COMMOD_TICKERS, INCLUDE_COMMODS)
etf_list    = _yahoo_symbols(ETF_TICKERS, INCLUDE_ETF)
cus_list    = _yahoo_symbols(CUSTOM_TICKERS, INCLUDE_CUSTOM)

# 3) Label mapping & de-dup by underlying symbol (priority: EQ→ETF→FX→CRYPTO→CMDTY→CUSTOM)
#    A symbol is labelled only in the first group that lists it.
grouped = [("EQ",equity_list),("ETF",etf_list),("FX",fx_list),("CRYPTO",crypto_list),("CMDTY",cmd_list),("CUSTOM",cus_list)]
label_to_sym = {}
seen = set()
for group_name, members in grouped:
    for sym in members:
        if sym not in seen:
            seen.add(sym)
            label_to_sym[f"{group_name}:{sym}"] = sym

all_labels = list(label_to_sym.keys())
all_syms = list(label_to_sym.values())
if len(all_syms) < 3:
    raise RuntimeError("Universe too small after selection.")

# 4) Pre-filter & download
all_syms = filter_symbols_with_prices(all_syms, sample_period="5d")
print(f"Downloading {len(all_syms)} instruments {START_DATE} → {END_DATE} ...")
P_raw = download_prices(all_syms, START_DATE, END_DATE)
present_syms=[s for s in all_syms if s in P_raw.columns]

# The row-wise dropna below keeps only dates on which EVERY instrument has a price.
# A single late listing (IPO / spin-off inside the window) or an empty column would
# silently truncate or empty the whole sample. Drop such instruments instead, and say so.
MAX_LISTING_LAG_DAYS = 31
if present_syms:
    window_start, window_end = P_raw.index.min(), P_raw.index.max()
    lag = pd.Timedelta(days=MAX_LISTING_LAG_DAYS)
    partial = []
    for s in present_syms:
        first, last = P_raw[s].first_valid_index(), P_raw[s].last_valid_index()
        if first is None or first > window_start + lag or last < window_end - lag:
            partial.append(s)
    if partial and len(partial) < len(present_syms):
        print(f"Dropping {len(partial)} instrument(s) without full-window history: {', '.join(partial)}")
        partial_set = set(partial)
        present_syms = [s for s in present_syms if s not in partial_set]

sym2lab={v:k for k,v in label_to_sym.items()}
P = P_raw[present_syms].dropna(axis=0, how="any").rename(columns=sym2lab)
if P.shape[1] < 3: raise RuntimeError("Not enough instruments with data.")

# 5) Sector cap (proxy for GICS depth): keep at most MAX_PER_SECTOR of the largest
#    equities per sector; 0 disables the cap.
if MAX_PER_SECTOR and INCLUDE_EQUITIES:
    eq_labels=[c for c in P.columns if c.startswith("EQ:")]
    sec_map={lab:get_sector_industry(label_to_sym[lab])[0] for lab in eq_labels}
    # Only equity caps matter here; no need to query FX / crypto / ETF symbols.
    caps = {label_to_sym[lab]: get_market_cap(label_to_sym[lab]) for lab in eq_labels}

    def _cap_key(lab):
        """Sort key: market cap, with unknown (NaN) caps ranked last instead of scrambling the sort."""
        c = caps.get(label_to_sym[lab], np.nan)
        return c if pd.notna(c) else -np.inf

    keep = set()
    for sec in sorted(set(sec_map.values())):
        labs = [l for l in eq_labels if sec_map[l] == sec]
        keep.update(sorted(labs, key=_cap_key, reverse=True)[:MAX_PER_SECTOR])
    drop = [l for l in eq_labels if l not in keep]
    if drop:
        P = P.drop(columns=drop, errors="ignore")
        for d in drop: label_to_sym.pop(d, None)

# 6) Returns (daily log returns on the common-date panel)
R = np.log(P).diff().dropna(how="any")

# 7) Drop near-duplicate equity series.
#    drop_near_duplicate_assets only consults caps of "EQ:" columns, so only those are fetched.
caps = {label_to_sym[c]: get_market_cap(label_to_sym[c])
        for c in R.columns if c.startswith("EQ:") and c in label_to_sym}
R, label_to_sym = drop_near_duplicate_assets(R, label_to_sym, caps)

# 8) Residuals (equity residuals first, then optional PCA across all assets)
if USE_EQUITY_RESIDUALS: R = residualize_equity_returns(R, label_to_sym)
if USE_PCA_RESIDUAL_K>0: R = pca_residualize(R, k=USE_PCA_RESIDUAL_K)

# 9) Correlation matrix
C = R.corr()

print("\n=== Full-Window Summary ===")
print(f"Date window: {START_DATE} → {END_DATE}")
if len(R.index):
    # Dates with any missing price were dropped, so the usable sample can be much shorter
    # than the requested window; report the span the statistics are actually based on.
    print(f"Effective sample: {R.index.min():%Y-%m-%d} → {R.index.max():%Y-%m-%d}")
print(f"Instruments with data: {R.shape[1]}; Observations: {R.shape[0]} days")

display(top_k_correlations(C, k=TOP_K, use_abs=USE_ABSOLUTE))

# 10) Heatmap (assets ordered and boxed by group)
cols_order = sorted(C.columns.tolist(), key=by_group)
groups = {g:[c for c in cols_order if c.startswith(g+":")] for g in ["EQ","ETF","FX","CRYPTO","CMDTY","CUSTOM","OTHER"]}
os.makedirs("/content", exist_ok=True)
plot_correlation_heatmap(C, cols_order, groups, outfile="/content/correlation_heatmap.png")
print("Saved heatmap → /content/correlation_heatmap.png")

# 11) Monthly modules
monthly_pairs = monthly_top_pairs(R, top_n=max(TOP_RANKS_PER_MONTH, TOP_K), use_abs=USE_ABSOLUTE)
monthly_corrs = monthly_all_corrs(R)

if monthly_corrs:
    # Titles follow the configured bar count rather than a hard-coded "Top-10"
    # (output file names keep their "top10_" prefix so downstream paths don't change).
    top_label = f"Top-{TOP_RANKS_PER_MONTH}"

    # (A) LOCKED: final-month set
    first_m, last_m = min(monthly_corrs.keys()), max(monthly_corrs.keys())
    final_df = monthly_pairs[monthly_pairs["month"]==last_m].copy().sort_values(
        "correlation", key=(lambda s: s.abs() if USE_ABSOLUTE else s), ascending=False
    ).head(TOP_RANKS_PER_MONTH)
    locked_labels = (final_df["col_1"] + "–" + final_df["col_2"]).astype(str).tolist()

    def corr_for_month_pair(m, L):
        """True monthly correlation of pair label "A–B" in month m (0.0 if unavailable)."""
        a,b = L.split("–",1); cdf = monthly_corrs.get(m)
        try: return float(cdf.loc[a,b])
        except Exception: return 0.0

    def save_bar_chart(fname, title, labels, values, xlim, colors, legend_handles):
        """Save a static horizontal bar chart of pair correlations with value annotations."""
        plt.rcParams.update({"font.size": LABEL_FONTSIZE})
        fig, ax = plt.subplots(figsize=(12, 8))
        y=np.arange(len(labels)); bars=ax.barh(y, values, color=colors)
        ax.set_yticks([]); ax.set_title(title, fontsize=TITLE_FONTSIZE, pad=12)
        ax.set_xlim(*xlim); ax.axvline(0, color="k", lw=1)
        ax.grid(True, axis="x", linestyle="--", lw=0.6, alpha=0.5)
        ax.legend(handles=legend_handles, loc="lower right", frameon=True, fontsize=12)
        for b,L,v in zip(bars, labels, values):
            y0=b.get_y()+b.get_height()/2
            ax.text(xlim[0]-0.02*(xlim[1]-xlim[0]), y0, L, va="center", ha="left",
                    fontsize=LABEL_FONTSIZE, fontweight="bold")
            ax.text(v + (0.01 if v>=0 else -0.01), y0, f"{v:+.2f}", va="center",
                    ha=("left" if v>=0 else "right"), fontsize=VALUE_FONTSIZE)
        fig.tight_layout(); fig.savefig(fname, dpi=220); plt.close(fig)

    if locked_labels:
        # Start/End PNGs on locked set (true monthly values)
        start_vals=[corr_for_month_pair(first_m, L) for L in locked_labels]
        end_vals  =[corr_for_month_pair(last_m,  L) for L in locked_labels]
        lim = max(0.5, float(np.nanmax(np.abs(start_vals + end_vals))))
        xlim = (-lim, lim)
        pgrps=[get_pair_group(L) for L in locked_labels]; uniq=sorted(set(pgrps))
        cmap = plt.get_cmap(PALETTE, len(uniq)); group_to_color={g:cmap(i) for i,g in enumerate(uniq)}
        bar_colors=[group_to_color[g] for g in pgrps]
        legend_handles=[Patch(color=group_to_color[g], label=g) for g in uniq]

        save_bar_chart("/content/top10_start_month.png",
                       f"{top_label} Correlations — Start ({first_m.strftime('%Y-%m')}) [final set]",
                       locked_labels, start_vals, xlim, bar_colors, legend_handles)
        save_bar_chart("/content/top10_end_month.png",
                       f"{top_label} Correlations — End ({last_m.strftime('%Y-%m')}) [final set]",
                       locked_labels, end_vals, xlim, bar_colors, legend_handles)

        animate_monthly_race_locked_final_group_colors(
            monthly_corrs, locked_labels,
            out_mp4="/content/top10_correlations_race_LOCKED.mp4",
            out_gif="/content/top10_correlations_race_LOCKED.gif" if SAVE_GIF else None,
            fps=FPS, interp_steps=INTERP_STEPS, title_prefix=f"{top_label} Correlations (Locked Final Set)",
            palette=PALETTE, label_fs=LABEL_FONTSIZE, value_fs=VALUE_FONTSIZE, title_fs=TITLE_FONTSIZE, dpi=VIDEO_DPI
        )
    else:
        # An empty locked set would crash the colormap / bar-chart code above.
        print("Monthly module: no pairs in the final month; skipping LOCKED charts.")

    # (B) UNLOCKED: each month’s own Top-N with stable pair colors
    animate_monthly_race_unlocked_dynamic(
        monthly_pairs,
        out_mp4="/content/top10_correlations_race_UNLOCKED.mp4",
        out_gif="/content/top10_correlations_race_UNLOCKED.gif" if SAVE_GIF else None,
        top_n=TOP_RANKS_PER_MONTH, fps=FPS, interp_steps=INTERP_STEPS,
        title_prefix=f"{top_label} Correlations (Unlocked per Month)",
        palette_base=("tab20","tab20b","tab20c"),
        label_fs=LABEL_FONTSIZE, value_fs=VALUE_FONTSIZE, title_fs=TITLE_FONTSIZE, dpi=VIDEO_DPI,
        show_legend=False
    )
else:
    print("Monthly module: not enough data to build animations.")

# 12) Save outputs (prices, returns, full-window correlation matrix, top-K pair table)
for _frame, _path in ((P, "/content/prices_adjusted_multi.csv"),
                      (R, "/content/returns_log_multi.csv"),
                      (C, "/content/correlation_matrix_multi.csv")):
    _frame.to_csv(_path)
top_k_correlations(C, k=TOP_K, use_abs=USE_ABSOLUTE).to_csv("/content/top_k_correlations_multi.csv", index=False)

print("\nFiles in /content:")
for _entry in ("prices_adjusted_multi.csv",
               "returns_log_multi.csv",
               "correlation_matrix_multi.csv",
               "top_k_correlations_multi.csv",
               "correlation_heatmap.png",
               "top10_start_month.png (LOCKED set)",
               "top10_end_month.png   (LOCKED set)",
               "top10_correlations_race_LOCKED.mp4 / .gif",
               "top10_correlations_race_UNLOCKED.mp4 / .gif"):
    print(f" - {_entry}")


  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for ffmpeg (setup.py) ... [?25l[?25hdone
Building instrument universe...
Downloading 99 instruments 2022-01-01 → 2025-09-01 ...

=== Full-Window Summary ===
Date window: 2022-01-01 → 2025-09-01
Instruments with data: 92; Observations: 346 days


Unnamed: 0,col_1,col_2,correlation
3996,ETF:SPY,ETF:QQQ,0.96243
4088,ETF:GLD,CMDTY:GC=F,0.936018
1602,EQ:HD,EQ:LOW,0.852054
3997,ETF:SPY,ETF:IWM,0.836179
4141,CRYPTO:BTC-USD,CRYPTO:ETH-USD,0.800489
3999,ETF:SPY,ETF:HYG,0.790013
4095,FX:EURUSD=X,FX:GBPUSD=X,0.787169
4034,ETF:IWM,ETF:HYG,0.780123
4142,CRYPTO:BTC-USD,CRYPTO:SOL-USD,0.758929
4015,ETF:QQQ,ETF:IWM,0.75659


Saved heatmap → /content/correlation_heatmap.png
Saved (LOCKED): /content/top10_correlations_race_LOCKED.mp4
Saved GIF (LOCKED): /content/top10_correlations_race_LOCKED.gif
Saved (UNLOCKED, stable colors): /content/top10_correlations_race_UNLOCKED.mp4
Saved GIF (UNLOCKED, stable colors): /content/top10_correlations_race_UNLOCKED.gif

Files in /content:
 - prices_adjusted_multi.csv
 - returns_log_multi.csv
 - correlation_matrix_multi.csv
 - top_k_correlations_multi.csv
 - correlation_heatmap.png
 - top10_start_month.png (LOCKED set)
 - top10_end_month.png   (LOCKED set)
 - top10_correlations_race_LOCKED.mp4 / .gif
 - top10_correlations_race_UNLOCKED.mp4 / .gif
