#  Crash-Regime LightGBM Training
## Renaissance Trading Bot -- Regime-Specific Models

**What this does:**
- Downloads BTC + ETH candles, macro data (SPX, VIX, DXY), Binance derivatives
- Filters to crash periods only (2018, 2021-22, 2025-26)
- Engineers 51 crash-specific features across 5 groups (BTC price/volume, daily macro, intraday macro, derivatives, cross-asset)
- Trains LightGBM on crash-only data with macro features
- Saves everything to Google Drive (survives runtime disconnects)

**Key insight:** BTC has 0.77-0.90 correlation with S&P 500 during crashes.
Our current models miss this signal entirely, achieving only 51% accuracy.
Training on crash-only data with macro features should improve this.

**Runtime:** Select GPU (T4) for faster training, though LightGBM trains in ~30 seconds either way.

## Cell 1: Setup & Google Drive Mount
Mount Drive first so all outputs survive runtime disconnects.

In [None]:
# ============================================================
# CELL 1: Setup & Google Drive Mount
# ============================================================
# Mount Drive FIRST -- all outputs save here to survive disconnects

from google.colab import drive
drive.mount('/content/drive')

import os
os.makedirs('/content/data', exist_ok=True)
os.makedirs('/content/models', exist_ok=True)

DRIVE_SAVE = '/content/drive/MyDrive/renaissance-bot-training/crash_models/'
os.makedirs(DRIVE_SAVE, exist_ok=True)

!pip install -q yfinance lightgbm pandas numpy scikit-learn

import numpy as np
import pandas as pd
import requests
import time
from datetime import datetime, timedelta

print("\u2705 Setup complete")
print(f"Drive save path: {DRIVE_SAVE}")

## Cell 2: Download BTC + ETH 5-Minute Candles from Binance
Full history from Sep 2017. Covers all three crash periods. Takes 5-10 minutes.

In [None]:
# ============================================================
# CELL 2: Download BTC + ETH 5m candles from Binance
# ============================================================
# Need crash periods:
#   Crash 1: Jan 2018 -- Dec 2018
#   Crash 2: Nov 2021 -- Nov 2022
#   Crash 3: Oct 2025 -- Feb 2026 (current)

def fetch_binance_klines(symbol, interval, start_ms, end_ms, limit=1000):
    """Page through Binance's public spot klines endpoint (no API key needed).

    Args:
        symbol: Trading pair, e.g. 'BTCUSDT'.
        interval: Kline interval string, e.g. '5m'.
        start_ms: Start of the window in epoch milliseconds (sent as startTime).
        end_ms: End of the window in epoch milliseconds (sent as endTime).
        limit: Number of rows requested per page.

    Returns:
        List of raw kline rows exactly as the API returned them, in fetch
        order. A 429 response sleeps 60s and retries without counting an
        attempt; after more than 5 consecutive failures (non-200 status or
        exception) the loop stops and returns whatever was collected.
    """
    endpoint = "https://api.binance.com/api/v3/klines"
    rows = []
    cursor = start_ms
    failures = 0

    while cursor < end_ms:
        query = {
            'symbol': symbol,
            'interval': interval,
            'startTime': cursor,
            'endTime': end_ms,
            'limit': limit,
        }
        try:
            response = requests.get(endpoint, params=query, timeout=30)

            if response.status_code == 429:
                print("  Rate limited, waiting 60s...")
                time.sleep(60)
                continue

            if response.status_code != 200:
                print(f"  HTTP {response.status_code}, retrying...")
                failures += 1
                if failures > 5:
                    break
                time.sleep(5)
                continue

            page = response.json()
            if not (page and isinstance(page, list)):
                break

            rows.extend(page)
            # Resume one millisecond after the newest candle's open time.
            cursor = page[-1][0] + 1
            failures = 0

            if len(rows) % 50000 == 0:
                print(f"  ... {len(rows):,} candles so far")

            # A short page means the newest available candle was reached.
            if len(page) < limit:
                break

            time.sleep(0.1)

        except Exception as exc:
            print(f"  Error: {exc}, retrying...")
            failures += 1
            if failures > 5:
                break
            time.sleep(5)

    return rows

def klines_to_df(raw):
    """Turn raw Binance kline rows into a typed, de-duplicated DataFrame.

    Args:
        raw: Sequence of 12-field kline rows as returned by the klines endpoint.

    Returns:
        DataFrame with the 12 kline columns (price/volume fields cast to
        float, ``trades`` cast to int) plus a ``timestamp`` column parsed from
        ``open_time`` in milliseconds. Rows sharing an ``open_time`` keep
        their first occurrence; the result is sorted by ``timestamp`` with a
        fresh RangeIndex.
    """
    kline_columns = [
        'open_time', 'open', 'high', 'low', 'close', 'volume',
        'close_time', 'quote_volume', 'trades', 'taker_buy_base',
        'taker_buy_quote', 'ignore',
    ]
    float_columns = ['open', 'high', 'low', 'close', 'volume', 'quote_volume',
                     'taker_buy_base', 'taker_buy_quote']

    frame = pd.DataFrame(raw, columns=kline_columns)
    frame[float_columns] = frame[float_columns].astype(float)
    frame['trades'] = frame['trades'].astype(int)
    frame['timestamp'] = pd.to_datetime(frame['open_time'], unit='ms')

    frame = frame.drop_duplicates(subset=['open_time'])
    frame = frame.sort_values('timestamp')
    return frame.reset_index(drop=True)

from datetime import timezone

# Download window in epoch milliseconds, built from timezone-aware UTC
# datetimes. A naive datetime's .timestamp() is interpreted in the host's
# local timezone, and datetime.utcnow() returns a naive value, so the naive
# form is only correct on a host whose local time happens to be UTC.
start_ms = int(datetime(2017, 9, 1, tzinfo=timezone.utc).timestamp() * 1000)
end_ms = int(datetime.now(timezone.utc).timestamp() * 1000)

# Download BTC
print("Downloading BTC 5m candles from Binance (this takes 5-10 minutes)...")
raw = fetch_binance_klines('BTCUSDT', '5m', start_ms, end_ms)
btc_5m = klines_to_df(raw)
if btc_5m.empty:
    # Every later cell is built on the BTC series; stop here with a clear
    # message instead of writing an empty CSV and failing further down.
    raise RuntimeError("Binance returned no BTCUSDT 5m candles -- check network / API status")
btc_5m.to_csv('/content/data/btc_5m_full.csv', index=False)
print(f"\u2705 BTC: {len(btc_5m):,} candles ({btc_5m['timestamp'].min().date()} \u2192 {btc_5m['timestamp'].max().date()})")

# Download ETH (for cross-asset features)
print("\nDownloading ETH 5m candles...")
raw = fetch_binance_klines('ETHUSDT', '5m', start_ms, end_ms)
eth_5m = klines_to_df(raw)
eth_5m.to_csv('/content/data/eth_5m_full.csv', index=False)
print(f"\u2705 ETH: {len(eth_5m):,} candles ({eth_5m['timestamp'].min().date()} \u2192 {eth_5m['timestamp'].max().date()})")

# Save to Drive as backup
btc_5m.to_csv(f'{DRIVE_SAVE}/btc_5m_full.csv', index=False)
eth_5m.to_csv(f'{DRIVE_SAVE}/eth_5m_full.csv', index=False)
print("\n\u2705 Backed up to Drive")

## Cell 3: Download Macro Data (Daily + 5-Minute Intraday)

In [None]:
# ============================================================
# CELL 3: Download macro data -- daily + 5-minute intraday
# ============================================================
# BTC has 0.77-0.90 correlation with S&P 500 in crash periods.
# Daily macro covers full history; intraday only last ~60 days (yfinance limit).

import yfinance as yf

# --- Part A: Daily macro (full history, same as v1) ---
macro_tickers = {
    'spx': '^GSPC',       # S&P 500
    'ndx': '^IXIC',       # Nasdaq Composite
    'vix': '^VIX',        # CBOE Volatility Index
    'dxy': 'DX-Y.NYB',    # US Dollar Index
    'us10y': '^TNX',      # 10-Year Treasury Yield
    'gold': 'GC=F',       # Gold Futures
}

macro_daily = {}
for name, ticker in macro_tickers.items():
    print(f"Downloading daily {name} ({ticker})...")
    try:
        data = yf.download(ticker, start='2017-01-01', progress=False)
        if len(data) > 0:
            # yfinance may return (field, ticker) MultiIndex columns; keep the field level.
            if isinstance(data.columns, pd.MultiIndex):
                data.columns = data.columns.get_level_values(0)
            macro_daily[name] = data[['Close']].rename(columns={'Close': name})
            print(f"  [OK] {name}: {len(data):,} days ({data.index.min().date()} to {data.index.max().date()})")
        else:
            print(f"  [WARN] {name}: no data returned")
    except Exception as e:
        print(f"  [ERR] {name}: {e}")

if not macro_daily:
    # pd.concat([]) would fail with an opaque "No objects to concatenate".
    raise RuntimeError("No daily macro series could be downloaded -- check yfinance connectivity")

macro_df = pd.concat(macro_daily.values(), axis=1)
macro_df.index = pd.to_datetime(macro_df.index)
macro_df = macro_df.ffill().bfill()
macro_df.to_csv('/content/data/macro_daily.csv')
print(f"\n[OK] Daily macro: {len(macro_df):,} days, columns: {list(macro_df.columns)}")

# --- Part B: 5-minute intraday macro (last ~60 days only) ---
# yfinance allows 5m data for the last 60 days.
# This gives us real-time SPX/VIX/NDX at BTC candle resolution
# for the most recent crash period (crash 3).

intraday_tickers = {
    'spx_5m': '^GSPC',
    'vix_5m': '^VIX',
    'ndx_5m': '^IXIC',
}

intraday_frames = {}
for name, ticker in intraday_tickers.items():
    print(f"Downloading 5m intraday {name} ({ticker})...")
    try:
        data = yf.download(ticker, period='60d', interval='5m', progress=False)
        if len(data) > 0:
            if isinstance(data.columns, pd.MultiIndex):
                data.columns = data.columns.get_level_values(0)
            col_name = name.replace('_5m', '_intraday')
            intraday_frames[col_name] = data[['Close']].rename(columns={'Close': col_name})
            print(f"  [OK] {name}: {len(data):,} bars ({data.index.min()} to {data.index.max()})")
        else:
            print(f"  [WARN] {name}: no data returned")
    except Exception as e:
        print(f"  [ERR] {name}: {e}")

if intraday_frames:
    intraday_df = pd.concat(intraday_frames.values(), axis=1)
    # yfinance returns intraday bars with a tz-aware index in the exchange's
    # timezone (America/New_York for these US tickers), while the BTC candles
    # are naive UTC. Convert to UTC *before* dropping the tz: calling
    # tz_localize(None) directly keeps New York wall-clock times, which would
    # stamp every macro bar 4-5 hours earlier than it actually printed and
    # leak future macro moves into the BTC features once merged.
    intraday_index = pd.to_datetime(intraday_df.index)
    if intraday_index.tz is not None:
        intraday_index = intraday_index.tz_convert('UTC').tz_localize(None)
    # Name the index 'timestamp' so the saved CSV's key column matches the
    # BTC candle column it is merged on.
    intraday_df.index = intraday_index.rename('timestamp')
    intraday_df = intraday_df.sort_index().ffill()
    intraday_df.to_csv('/content/data/macro_intraday_5m.csv')
    print(f"\n[OK] Intraday macro: {len(intraday_df):,} bars, columns: {list(intraday_df.columns)}")
    print(f"     Date range: {intraday_df.index.min()} to {intraday_df.index.max()}")
else:
    intraday_df = None
    print("\n[WARN] No intraday macro data available")

# --- Part C: Fear & Greed Index (daily) ---
print("\nDownloading Fear & Greed Index...")
try:
    fng_resp = requests.get("https://api.alternative.me/fng/?limit=0", timeout=30)
    # Surface HTTP errors explicitly instead of failing later on a missing 'data' key.
    fng_resp.raise_for_status()
    fng_data = fng_resp.json()['data']
    fng_df = pd.DataFrame(fng_data)
    fng_df['timestamp'] = pd.to_datetime(fng_df['timestamp'].astype(int), unit='s')
    fng_df['fng'] = fng_df['value'].astype(int)
    fng_df = fng_df[['timestamp', 'fng']].set_index('timestamp').sort_index()
    fng_df.to_csv('/content/data/fear_greed.csv')
    print(f"  [OK] Fear & Greed: {len(fng_df):,} days")
except Exception as e:
    print(f"  [ERR] Fear & Greed: {e}")
    fng_df = None

# Save to Drive
macro_df.to_csv(f'{DRIVE_SAVE}/macro_daily.csv')
if intraday_df is not None:
    intraday_df.to_csv(f'{DRIVE_SAVE}/macro_intraday_5m.csv')
if fng_df is not None:
    fng_df.to_csv(f'{DRIVE_SAVE}/fear_greed.csv')
print("[OK] Backed up to Drive")


## Cell 4: Download Binance Derivatives Data (1h Period)

In [None]:
# ============================================================
# CELL 4: Download Binance Futures derivatives data (1h period)
# ============================================================
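# Funding rate, Open Interest, Long/Short ratio, Taker buy/sell.
# IMPORTANT: the OI/LS/Taker statistics endpoints have limited retention.
# This cell requests the last 180 days at a 1h period, but Binance's API docs
# describe these endpoints as serving only the most recent ~30 days regardless
# of period -- check the printed date ranges to see what actually came back.
# Funding rate is always 8-hourly (no period param needed).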

def fetch_binance_futures(endpoint, symbol, period=None, limit=500,
                          start_ms=None, end_ms=None, max_records=50000):
    """Collect every page from a Binance USD-M Futures REST endpoint.

    Args:
        endpoint: Path under https://fapi.binance.com, e.g. '/fapi/v1/fundingRate'.
        symbol: Contract symbol, e.g. 'BTCUSDT'.
        period: Aggregation period such as '1h'; never sent to fundingRate endpoints.
        limit: Rows requested per page.
        start_ms: Optional start (epoch ms); advanced past each page.
        end_ms: Optional end bound (epoch ms) sent as endTime.
        max_records: Stop once at least this many rows have been collected.

    Returns:
        List of row dicts in fetch order. Pagination resumes one millisecond
        past the newest row's ``fundingTime`` (or ``timestamp``); if neither
        key is present, fetching stops after the current page. A 429 sleeps
        60s and retries without counting an attempt; more than 5 consecutive
        failures end the loop with whatever was gathered.
    """
    base_url = "https://fapi.binance.com"
    collected = []
    cursor = start_ms
    attempts = 0
    # The fundingRate endpoint takes no `period`; everything else may.
    include_period = bool(period) and 'fundingRate' not in endpoint

    while True:
        query = {'symbol': symbol, 'limit': limit}
        if include_period:
            query['period'] = period
        if cursor:
            query['startTime'] = cursor
        if end_ms:
            query['endTime'] = end_ms

        try:
            response = requests.get(f"{base_url}{endpoint}", params=query, timeout=30)

            if response.status_code == 429:
                print("    Rate limited, waiting 60s...")
                time.sleep(60)
                continue

            if response.status_code != 200:
                attempts += 1
                if attempts > 5:
                    print(f"    Giving up after {attempts} retries (status {response.status_code})")
                    break
                time.sleep(5)
                continue

            page = response.json()
            if not (page and isinstance(page, list)):
                break

            collected.extend(page)
            attempts = 0

            # Pick whichever time field this endpoint uses to advance the cursor.
            newest = page[-1]
            cursor_key = next(
                (key for key in ('fundingTime', 'timestamp') if key in newest),
                None,
            )
            if cursor_key is None:
                break
            cursor = newest[cursor_key] + 1

            if len(page) < limit:
                break

            if len(collected) >= max_records:
                print(f"    Hit max_records limit ({max_records})")
                break

            if len(collected) % 5000 == 0:
                print(f"    ... {len(collected):,} records so far")

            time.sleep(0.3)  # Slightly slower to avoid 429

        except Exception as exc:
            attempts += 1
            if attempts > 5:
                print(f"    Giving up after {attempts} retries: {exc}")
                break
            time.sleep(5)

    return collected


from datetime import timezone

# Epoch-ms window bounds built from timezone-aware UTC datetimes, so they do
# not depend on the host's local timezone (a naive datetime's .timestamp()
# assumes local time, and datetime.utcnow() returns a naive value).
_now_utc = datetime.now(timezone.utc)
end_ms = int(_now_utc.timestamp() * 1000)
# Focus on last 180 days for derivatives (1h period)
start_180d = int((_now_utc - timedelta(days=180)).timestamp() * 1000)
# Narrower fallback window used by _fetch_btc_stats_1h when 180 days yields nothing.
start_29d = int((_now_utc - timedelta(days=29)).timestamp() * 1000)
# Funding rate available from 2019, use full range
start_funding = int(datetime(2019, 9, 1, tzinfo=timezone.utc).timestamp() * 1000)


def _fetch_btc_stats_1h(endpoint):
    """Fetch 1h-period BTCUSDT rows from a /futures/data statistics endpoint.

    Tries the last 180 days first; if that returns nothing, retries with the
    last 29 days.
    NOTE(review): Binance documents these statistics endpoints as serving only
    the most recent ~30 days; a startTime older than that may be rejected,
    which is why the narrower fallback window exists.
    """
    rows = fetch_binance_futures(endpoint, 'BTCUSDT', period='1h',
                                 start_ms=start_180d, end_ms=end_ms)
    if not rows:
        print("    Nothing returned for 180 days; retrying with the last 29 days...")
        rows = fetch_binance_futures(endpoint, 'BTCUSDT', period='1h',
                                     start_ms=start_29d, end_ms=end_ms)
    return rows


derivatives = {}

# 1. Funding Rate (8-hourly, no period param)
print("Downloading BTC funding rate (8h intervals, full history)...")
raw = fetch_binance_futures('/fapi/v1/fundingRate', 'BTCUSDT',
                            start_ms=start_funding, end_ms=end_ms)
if raw:
    fr_df = pd.DataFrame(raw)
    fr_df['timestamp'] = pd.to_datetime(fr_df['fundingTime'], unit='ms')
    fr_df['funding_rate'] = fr_df['fundingRate'].astype(float)
    derivatives['funding_rate'] = fr_df[['timestamp', 'funding_rate']]
    fr_df[['timestamp', 'funding_rate']].to_csv('/content/data/btc_funding_rate.csv', index=False)
    print(f"  [OK] Funding rate: {len(fr_df):,} records ({fr_df['timestamp'].min().date()} to {fr_df['timestamp'].max().date()})")
else:
    print("  [WARN] No funding rate data")

# 2. Open Interest (1h period)
print("Downloading BTC open interest (1h, last 180 days)...")
raw = _fetch_btc_stats_1h('/futures/data/openInterestHist')
if raw:
    oi_df = pd.DataFrame(raw)
    oi_df['timestamp'] = pd.to_datetime(oi_df['timestamp'], unit='ms')
    oi_df['open_interest'] = oi_df['sumOpenInterest'].astype(float)
    oi_df['oi_value'] = oi_df['sumOpenInterestValue'].astype(float)
    derivatives['open_interest'] = oi_df[['timestamp', 'open_interest', 'oi_value']]
    oi_df[['timestamp', 'open_interest', 'oi_value']].to_csv('/content/data/btc_open_interest.csv', index=False)
    print(f"  [OK] Open interest: {len(oi_df):,} records ({oi_df['timestamp'].min().date()} to {oi_df['timestamp'].max().date()})")
else:
    print("  [WARN] No open interest data")

# 3. Long/Short Ratio (1h period)
print("Downloading BTC long/short ratio (1h, last 180 days)...")
raw = _fetch_btc_stats_1h('/futures/data/globalLongShortAccountRatio')
if raw:
    ls_df = pd.DataFrame(raw)
    ls_df['timestamp'] = pd.to_datetime(ls_df['timestamp'], unit='ms')
    ls_df['long_short_ratio'] = ls_df['longShortRatio'].astype(float)
    ls_df['long_account'] = ls_df['longAccount'].astype(float)
    ls_df['short_account'] = ls_df['shortAccount'].astype(float)
    derivatives['long_short'] = ls_df[['timestamp', 'long_short_ratio',
                                        'long_account', 'short_account']]
    ls_df[['timestamp', 'long_short_ratio', 'long_account', 'short_account']].to_csv(
        '/content/data/btc_long_short.csv', index=False)
    print(f"  [OK] Long/short ratio: {len(ls_df):,} records ({ls_df['timestamp'].min().date()} to {ls_df['timestamp'].max().date()})")
else:
    print("  [WARN] No long/short data")

# 4. Taker Buy/Sell Volume (1h period)
# USD-M futures serve taker buy/sell volume (buyVol, sellVol, buySellRatio)
# from /futures/data/takerlongshortRatio.
# NOTE(review): confirm against the current Binance USD-M market-data docs.
print("Downloading BTC taker buy/sell volume (1h, last 180 days)...")
raw = _fetch_btc_stats_1h('/futures/data/takerlongshortRatio')
if raw:
    tv_df = pd.DataFrame(raw)
    tv_df['timestamp'] = pd.to_datetime(tv_df['timestamp'], unit='ms')
    tv_df['taker_buy_vol'] = tv_df['buyVol'].astype(float)
    tv_df['taker_sell_vol'] = tv_df['sellVol'].astype(float)
    tv_df['taker_ratio'] = tv_df['taker_buy_vol'] / (tv_df['taker_sell_vol'] + 1e-10)
    derivatives['taker_vol'] = tv_df[['timestamp', 'taker_buy_vol',
                                       'taker_sell_vol', 'taker_ratio']]
    tv_df[['timestamp', 'taker_buy_vol', 'taker_sell_vol', 'taker_ratio']].to_csv(
        '/content/data/btc_taker_vol.csv', index=False)
    print(f"  [OK] Taker volume: {len(tv_df):,} records ({tv_df['timestamp'].min().date()} to {tv_df['timestamp'].max().date()})")
else:
    print("  [WARN] No taker volume data")

print(f"\n[OK] Derivatives data complete. Got {len(derivatives)} datasets.")
for name, df in derivatives.items():
    print(f"  {name}: {len(df):,} records")


## Cell 5: Label Crash Periods & Merge All Data
Merge BTC 5m candles with daily macro, 5-minute intraday macro (where available), derivatives (variable freq via merge_asof), ETH cross-asset, and Fear & Greed. Filter to crash periods only.

In [None]:
# ============================================================
# CELL 5: Label crash periods and merge everything
# ============================================================

# Load BTC 5m
btc = pd.read_csv('/content/data/btc_5m_full.csv')
btc['timestamp'] = pd.to_datetime(btc['timestamp'])
btc = btc.sort_values('timestamp').reset_index(drop=True)

# -- Define crash periods --
# Crash 1: Post-2017 bubble, Jan 2018 peak -> Dec 2018 bottom
# Crash 2: Post-2021 bubble, Nov 2021 ATH $69K -> Nov 2022 bottom $15.5K
# Crash 3: Post-2025 bubble, Oct 2025 ATH $126K -> ongoing (~$65K)
CRASH_PERIODS = [
    ('2018-01-07', '2018-12-15'),
    ('2021-11-10', '2022-11-21'),
    ('2025-10-06', '2026-02-28'),
]

btc['is_crash'] = False
for start, end in CRASH_PERIODS:
    mask = (btc['timestamp'] >= start) & (btc['timestamp'] <= end)
    btc.loc[mask, 'is_crash'] = True

crash_data = btc[btc['is_crash']].copy()
print(f"Total BTC candles: {len(btc):,}")
print(f"Crash candles: {len(crash_data):,} ({100*len(crash_data)/len(btc):.1f}%)")
for start, end in CRASH_PERIODS:
    n = len(btc[(btc['timestamp'] >= start) & (btc['timestamp'] <= end)])
    print(f"  {start} to {end}: {n:,} candles")

# -- Merge daily macro (via date join) --
macro = pd.read_csv('/content/data/macro_daily.csv', index_col=0, parse_dates=True)
crash_data['date'] = crash_data['timestamp'].dt.strftime('%Y-%m-%d')
macro['date'] = macro.index.strftime('%Y-%m-%d')
crash_data = crash_data.merge(macro, on='date', how='left')
for col in ['spx', 'ndx', 'vix', 'dxy', 'us10y', 'gold']:
    if col in crash_data.columns:
        crash_data[col] = crash_data[col].ffill().bfill()
print(f"\n[OK] Merged daily macro data")

# -- Merge 5-minute intraday macro (NEW in v2) --
# Only available for last ~60 days (crash 3 period).
# Rows without intraday data get NaN -- features will be 0-filled.
has_intraday_macro = False
try:
    intraday = pd.read_csv('/content/data/macro_intraday_5m.csv', index_col=0, parse_dates=True)
    intraday.index = pd.to_datetime(intraday.index)
    intraday = intraday.sort_index()
    crash_data_sorted = crash_data.sort_values('timestamp').reset_index(drop=True)
    crash_data_sorted['timestamp'] = pd.to_datetime(crash_data_sorted['timestamp'])
    crash_data = pd.merge_asof(
        crash_data_sorted, intraday.reset_index().rename(columns={'index': 'timestamp'}),
        on='timestamp', direction='backward',
        tolerance=pd.Timedelta('10min')  # Allow 10min tolerance for market hours gaps
    )
    for col in ['spx_intraday', 'vix_intraday', 'ndx_intraday']:
        if col in crash_data.columns:
            filled = crash_data[col].notna().sum()
            print(f"  [OK] {col}: {filled:,}/{len(crash_data):,} rows filled")
    has_intraday_macro = True
    print(f"[OK] Merged intraday macro (5-min resolution)")
except FileNotFoundError:
    print("[WARN] No intraday macro file found -- features will be zero-filled")
    crash_data['spx_intraday'] = np.nan
    crash_data['vix_intraday'] = np.nan
    crash_data['ndx_intraday'] = np.nan
except Exception as e:
    print(f"[WARN] Intraday macro merge failed: {e}")
    crash_data['spx_intraday'] = np.nan
    crash_data['vix_intraday'] = np.nan
    crash_data['ndx_intraday'] = np.nan

# -- Merge Fear & Greed (daily -> 5m via date join) --
try:
    fng = pd.read_csv('/content/data/fear_greed.csv', parse_dates=['timestamp'])
    fng['date'] = fng['timestamp'].dt.strftime('%Y-%m-%d')
    crash_data = crash_data.merge(fng[['date', 'fng']], on='date', how='left')
    crash_data['fng'] = crash_data['fng'].ffill().bfill().fillna(50)
    print(f"[OK] Merged Fear & Greed")
except Exception as e:
    crash_data['fng'] = 50
    print(f"[WARN] Fear & Greed failed, using default 50: {e}")

# -- Merge derivatives (variable freq -> 5m via merge_asof) --
crash_data = crash_data.sort_values('timestamp').reset_index(drop=True)

deriv_files = {
    'funding_rate': ('/content/data/btc_funding_rate.csv', ['funding_rate']),
    'open_interest': ('/content/data/btc_open_interest.csv', ['open_interest', 'oi_value']),
    'long_short': ('/content/data/btc_long_short.csv', ['long_short_ratio', 'long_account', 'short_account']),
    'taker_vol': ('/content/data/btc_taker_vol.csv', ['taker_buy_vol', 'taker_sell_vol', 'taker_ratio']),
}

for name, (path, cols) in deriv_files.items():
    try:
        deriv = pd.read_csv(path, parse_dates=['timestamp'])
        deriv = deriv.sort_values('timestamp').reset_index(drop=True)
        crash_data = pd.merge_asof(
            crash_data, deriv[['timestamp'] + cols],
            on='timestamp', direction='backward',
            tolerance=pd.Timedelta('8h')  # Funding rate is 8-hourly
        )
        filled = crash_data[cols[0]].notna().sum()
        print(f"[OK] Merged {name}: {filled:,}/{len(crash_data):,} rows filled")
    except Exception as e:
        for col in cols:
            crash_data[col] = np.nan
        print(f"[WARN] {name} merge failed: {e}")

# -- Merge ETH cross-asset data --
try:
    eth = pd.read_csv('/content/data/eth_5m_full.csv')
    eth['timestamp'] = pd.to_datetime(eth['timestamp'])
    eth = eth.sort_values('timestamp').reset_index(drop=True)
    eth_merge = eth[['timestamp', 'close', 'volume']].rename(
        columns={'close': 'eth_close', 'volume': 'eth_volume'}
    )
    crash_data = pd.merge_asof(
        crash_data, eth_merge,
        on='timestamp', direction='backward',
        tolerance=pd.Timedelta('5min')
    )
    filled = crash_data['eth_close'].notna().sum()
    print(f"[OK] Merged ETH: {filled:,}/{len(crash_data):,} rows filled")
except Exception as e:
    crash_data['eth_close'] = np.nan
    crash_data['eth_volume'] = np.nan
    print(f"[WARN] ETH merge failed: {e}")

# -- Clean and save --
crash_data = crash_data.dropna(subset=['close']).reset_index(drop=True)
crash_data.to_csv('/content/data/crash_dataset_raw.csv', index=False)
crash_data.to_csv(f'{DRIVE_SAVE}/crash_dataset_raw.csv', index=False)

print(f"\n[OK] Crash dataset: {len(crash_data):,} rows, {len(crash_data.columns)} columns")
print(f"has_intraday_macro = {has_intraday_macro}")
print(f"\nColumn overview:")
for col in sorted(crash_data.columns):
    non_null = crash_data[col].notna().sum()
    pct = 100 * non_null / len(crash_data)
    print(f"  {col:30s} {non_null:>8,} non-null ({pct:.0f}%)")


## Cell 6: Engineer Crash-Specific Features (51 Features)

In [None]:
# ============================================================
# CELL 6: Build crash-specific features (v2 -- 51 features)
# ============================================================
# 51 features across 5 groups, all scale-invariant
# v2 adds: 11 intraday macro features + 2-bar (10min) primary label

df = crash_data.copy()

# ============================================
# GROUP 1: BTC PRICE & VOLUME (15 features)
# ============================================

# Returns at multiple horizons (relative, not absolute)
df['return_1bar'] = df['close'].pct_change(1)          # 5 min
df['return_6bar'] = df['close'].pct_change(6)          # 30 min
df['return_12bar'] = df['close'].pct_change(12)        # 1 hour
df['return_48bar'] = df['close'].pct_change(48)        # 4 hours
df['return_288bar'] = df['close'].pct_change(288)      # 24 hours

# Volatility (rolling std of returns -- already scale-invariant)
df['vol_12bar'] = df['return_1bar'].rolling(12).std()
df['vol_48bar'] = df['return_1bar'].rolling(48).std()
df['vol_ratio'] = df['vol_12bar'] / (df['vol_48bar'] + 1e-10)

# Volume (relative to own history -- scale-invariant)
df['vol_sma_20'] = df['volume'].rolling(20).mean()
df['volume_surge'] = df['volume'] / (df['vol_sma_20'] + 1e-10)
df['volume_trend'] = df['volume'].rolling(12).mean() / (df['volume'].rolling(48).mean() + 1e-10)

# Consecutive red candles (count)
# NOTE(review): cumcount() is 0-based, so the first red candle of a run scores
# 0 (the same as a green candle) and k reds in a row score k-1; close == open
# also counts as "red". Left unchanged so the feature definition does not
# silently drift from anything else that computes it -- confirm intent.
df['candle_dir'] = (df['close'] > df['open']).astype(int)
groups = (df['candle_dir'] != df['candle_dir'].shift()).cumsum()
df['consecutive_red'] = df.groupby(groups)['candle_dir'].cumcount()
df.loc[df['candle_dir'] == 1, 'consecutive_red'] = 0

# Drawdown from rolling 24h high (relative)
df['rolling_high_24h'] = df['high'].rolling(288).max()
df['drawdown_24h'] = df['close'] / df['rolling_high_24h'] - 1

# RSI (normalized to [-1, 1])
delta = df['close'].diff()
gain = delta.where(delta > 0, 0).rolling(14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
rs = gain / (loss + 1e-10)
df['rsi_14_norm'] = (100 - (100 / (1 + rs)) - 50) / 50

# Bollinger Band position (already normalized)
bb_mid = df['close'].rolling(20).mean()
bb_std = df['close'].rolling(20).std()
df['bb_pct_b'] = (df['close'] - bb_mid) / (2 * bb_std + 1e-10)

# VWAP distance (relative)
# Pseudo-sessions of 288 consecutive rows (~24h of 5m bars); within each,
# cumulative quote volume / cumulative base volume is the running VWAP.
# Vectorised groupby-cumsum keeps the result aligned with df's index
# (groupby.apply here relied on its MultiIndex output shape).
df['session'] = np.arange(len(df)) // 288
by_session = df.groupby('session')
vwap = by_session['quote_volume'].cumsum() / (by_session['volume'].cumsum() + 1e-10)
df['vwap_distance'] = df['close'] / (vwap + 1e-10) - 1

# ============================================
# GROUP 2: MACRO CORRELATION (10 features)
# THE KEY SIGNAL -- BTC follows SPX in crashes
# ============================================

# S&P 500
if 'spx' in df.columns and df['spx'].notna().sum() > 100:
    spx_daily = df.groupby('date')['spx'].last()
    spx_returns = spx_daily.pct_change()
    spx_ret_map = spx_returns.to_dict()
    df['spx_return_1d'] = df['date'].map(spx_ret_map).fillna(0)
    spx_sma = df['spx'].rolling(288 * 5, min_periods=288).mean()
    df['spx_vs_sma'] = df['spx'] / (spx_sma + 1e-10) - 1
else:
    df['spx_return_1d'] = 0.0
    df['spx_vs_sma'] = 0.0

# VIX
if 'vix' in df.columns and df['vix'].notna().sum() > 100:
    df['vix_norm'] = (df['vix'] - 20) / 20
    vix_daily = df.groupby('date')['vix'].last()
    vix_change = vix_daily.pct_change()
    vix_chg_map = vix_change.to_dict()
    df['vix_change'] = df['date'].map(vix_chg_map).fillna(0)
    df['vix_extreme'] = (df['vix'] > 30).astype(float)
else:
    df['vix_norm'] = 0.0
    df['vix_change'] = 0.0
    df['vix_extreme'] = 0.0

# Dollar Index
if 'dxy' in df.columns and df['dxy'].notna().sum() > 100:
    dxy_daily = df.groupby('date')['dxy'].last()
    dxy_returns = dxy_daily.pct_change()
    dxy_ret_map = dxy_returns.to_dict()
    df['dxy_return_1d'] = df['date'].map(dxy_ret_map).fillna(0)
    dxy_sma = df['dxy'].rolling(288 * 20, min_periods=288).mean()
    df['dxy_trend'] = df['dxy'] / (dxy_sma + 1e-10) - 1
else:
    df['dxy_return_1d'] = 0.0
    df['dxy_trend'] = 0.0

# Treasury Yields
if 'us10y' in df.columns and df['us10y'].notna().sum() > 100:
    df['yield_level'] = (df['us10y'] - 3.0) / 2.0
    yield_daily = df.groupby('date')['us10y'].last()
    yield_diff = yield_daily.diff()
    yield_diff_map = yield_diff.to_dict()
    df['yield_change'] = df['date'].map(yield_diff_map).fillna(0)
else:
    df['yield_level'] = 0.0
    df['yield_change'] = 0.0

# Fear & Greed
if 'fng' in df.columns and df['fng'].notna().sum() > 100:
    df['fng_norm'] = (df['fng'] - 50) / 50
else:
    df['fng_norm'] = 0.0

# ============================================
# GROUP 2B: INTRADAY MACRO (11 features, NEW in v2)
# Only populated for last ~60 days (crash 3)
# Older crash periods get 0-filled
# ============================================

# SPX intraday features
if 'spx_intraday' in df.columns and df['spx_intraday'].notna().sum() > 100:
    df['spx_return_5m'] = df['spx_intraday'].pct_change(1).fillna(0)
    df['spx_return_30m'] = df['spx_intraday'].pct_change(6).fillna(0)
    df['spx_return_1h'] = df['spx_intraday'].pct_change(12).fillna(0)
else:
    df['spx_return_5m'] = 0.0
    df['spx_return_30m'] = 0.0
    df['spx_return_1h'] = 0.0

# VIX intraday features
if 'vix_intraday' in df.columns and df['vix_intraday'].notna().sum() > 100:
    df['vix_return_5m'] = df['vix_intraday'].pct_change(1).fillna(0)
    df['vix_return_30m'] = df['vix_intraday'].pct_change(6).fillna(0)
    df['vix_level_5m'] = ((df['vix_intraday'] - 20) / 20).fillna(0)
else:
    df['vix_return_5m'] = 0.0
    df['vix_return_30m'] = 0.0
    df['vix_level_5m'] = 0.0

# NDX (Nasdaq) intraday features
if 'ndx_intraday' in df.columns and df['ndx_intraday'].notna().sum() > 100:
    df['ndx_return_5m'] = df['ndx_intraday'].pct_change(1).fillna(0)
    df['ndx_return_30m'] = df['ndx_intraday'].pct_change(6).fillna(0)
else:
    df['ndx_return_5m'] = 0.0
    df['ndx_return_30m'] = 0.0

# Cross-market signals from intraday data
# SPX-VIX divergence: SPX up + VIX up = warning signal
df['spx_vix_diverge'] = (df['spx_return_5m'] * df['vix_return_5m']).fillna(0)

# Macro momentum composites
df['macro_momentum_5m'] = ((df['spx_return_5m'] + df['ndx_return_5m']) / 2).fillna(0)
df['macro_momentum_30m'] = ((df['spx_return_30m'] + df['ndx_return_30m']) / 2).fillna(0)

# ============================================
# GROUP 3: DERIVATIVES (9 features)
# ============================================

# Funding rate
# NOTE(review): Binance reports fundingRate as a per-interval fraction
# (0.0001 == 0.01%), so the +/-0.01 thresholds below mean +/-1% per interval
# and may almost never fire. Confirm the intended scale before relying on
# funding_extreme_long / funding_extreme_short.
if 'funding_rate' in df.columns and df['funding_rate'].notna().sum() > 100:
    df['funding_rate'] = df['funding_rate'].ffill().fillna(0)
    fr_mean = df['funding_rate'].rolling(288 * 7, min_periods=288).mean()
    fr_std = df['funding_rate'].rolling(288 * 7, min_periods=288).std()
    df['funding_z'] = (df['funding_rate'] - fr_mean) / (fr_std + 1e-10)
    df['funding_extreme_long'] = (df['funding_rate'] > 0.01).astype(float)
    df['funding_extreme_short'] = (df['funding_rate'] < -0.01).astype(float)
else:
    df['funding_z'] = 0.0
    df['funding_extreme_long'] = 0.0
    df['funding_extreme_short'] = 0.0

# Open Interest
if 'oi_value' in df.columns and df['oi_value'].notna().sum() > 100:
    df['oi_value'] = df['oi_value'].ffill().bfill()
    df['oi_change_1h'] = df['oi_value'].pct_change(12)
    df['oi_change_4h'] = df['oi_value'].pct_change(48)
    df['oi_spike'] = (df['oi_change_1h'].abs() > 0.05).astype(float)
else:
    df['oi_change_1h'] = 0.0
    df['oi_change_4h'] = 0.0
    df['oi_spike'] = 0.0

# Long/Short Ratio
if 'long_short_ratio' in df.columns and df['long_short_ratio'].notna().sum() > 100:
    df['ls_ratio_norm'] = df['long_short_ratio'].ffill().fillna(1.0) - 1.0
    df['ls_extreme_long'] = (df['long_short_ratio'] > 2.0).astype(float)
else:
    df['ls_ratio_norm'] = 0.0
    df['ls_extreme_long'] = 0.0

# Taker Buy/Sell
if 'taker_ratio' in df.columns and df['taker_ratio'].notna().sum() > 100:
    df['taker_imbalance'] = df['taker_ratio'].ffill().fillna(1.0) - 1.0
else:
    df['taker_imbalance'] = 0.0

# ============================================
# GROUP 4: CROSS-ASSET (6 features)
# ============================================

# BTC's own lagged 1-bar returns. They depend only on BTC, so they are
# computed whether or not ETH data is available.
df['btc_lead_1'] = df['return_1bar'].shift(1)
df['btc_lead_2'] = df['return_1bar'].shift(2)
df['btc_lead_3'] = df['return_1bar'].shift(3)

if 'eth_close' in df.columns and df['eth_close'].notna().sum() > 100:
    df['eth_return_1bar'] = df['eth_close'].pct_change(1)
    df['eth_return_6bar'] = df['eth_close'].pct_change(6)
    df['eth_btc_ratio'] = df['eth_close'] / (df['close'] + 1e-10)
    df['eth_btc_ratio_change'] = df['eth_btc_ratio'].pct_change(12)
else:
    df['eth_return_1bar'] = 0.0
    df['eth_return_6bar'] = 0.0
    df['eth_btc_ratio_change'] = 0.0

# ============================================
# LABELS: Multiple horizons (v2 change)
# Primary: 2-bar (10 min) -- matches Polymarket 15-min windows
# Also compute 1-bar (5 min) and 6-bar (30 min) for comparison
# ============================================


def _forward_return(close, ts, k, bar=pd.Timedelta(minutes=5)):
    """Return close[t+k] / close[t] - 1, or NaN where bar t+k is not exactly k bars later.

    The dataset stitches together separate crash periods (and can contain
    exchange-outage gaps), so a bare shift(-k) would pair a bar with a price
    from a different period -- possibly years later. Masking to contiguous
    bars also leaves the final k rows NaN, so they get dropped below instead
    of being labelled "down" via NaN > 0 == False.
    """
    contiguous = (ts.shift(-k) - ts) == k * bar
    return (close.shift(-k) / close - 1).where(contiguous)


_bar_ts = pd.to_datetime(df['timestamp'])

# Primary label: 2-bar forward return (10 min ahead)
df['forward_return_2'] = _forward_return(df['close'], _bar_ts, 2)
df['label_binary'] = (df['forward_return_2'] > 0).astype(int)  # 1=up, 0=down
df['label_soft'] = np.tanh(df['forward_return_2'] * 100)

# Comparison labels (trained separately in Cell 7)
df['forward_return_1'] = _forward_return(df['close'], _bar_ts, 1)
df['forward_return_6'] = _forward_return(df['close'], _bar_ts, 6)
df['label_binary_1bar'] = (df['forward_return_1'] > 0).astype(int)
df['label_binary_6bar'] = (df['forward_return_6'] > 0).astype(int)

# ============================================
# FEATURE LIST (51 features)
# ============================================

FEATURE_COLS = [
    # BTC price/volume (15)
    'return_1bar', 'return_6bar', 'return_12bar', 'return_48bar', 'return_288bar',
    'vol_12bar', 'vol_48bar', 'vol_ratio',
    'volume_surge', 'volume_trend',
    'consecutive_red', 'drawdown_24h',
    'rsi_14_norm', 'bb_pct_b', 'vwap_distance',
    # Daily macro (10)
    'spx_return_1d', 'spx_vs_sma',
    'vix_norm', 'vix_change', 'vix_extreme',
    'dxy_return_1d', 'dxy_trend',
    'yield_level', 'yield_change',
    'fng_norm',
    # Intraday macro (11) -- NEW in v2
    'spx_return_5m', 'spx_return_30m', 'spx_return_1h',
    'vix_return_5m', 'vix_return_30m', 'vix_level_5m',
    'ndx_return_5m', 'ndx_return_30m',
    'spx_vix_diverge',
    'macro_momentum_5m', 'macro_momentum_30m',
    # Derivatives (9)
    'funding_z', 'funding_extreme_long', 'funding_extreme_short',
    'oi_change_1h', 'oi_change_4h', 'oi_spike',
    'ls_ratio_norm', 'ls_extreme_long',
    'taker_imbalance',
    # Cross-asset (6)
    'eth_return_1bar', 'eth_return_6bar', 'eth_btc_ratio_change',
    'btc_lead_1', 'btc_lead_2', 'btc_lead_3',
]

assert len(FEATURE_COLS) == 51, f"Expected 51 features, got {len(FEATURE_COLS)}"

# Replace inf with nan, then drop incomplete rows. Rows lacking a valid
# forward return at any horizon are dropped as well: their integer labels
# were derived from NaN and would otherwise be silently labelled "down".
_label_source_cols = ['forward_return_1', 'forward_return_2', 'forward_return_6']
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna(subset=FEATURE_COLS + _label_source_cols).reset_index(drop=True)

# Count how many rows have real intraday macro data
intraday_cols = ['spx_return_5m', 'spx_return_30m', 'spx_return_1h',
                 'vix_return_5m', 'vix_return_30m', 'vix_level_5m',
                 'ndx_return_5m', 'ndx_return_30m']
has_real_intraday = (df[intraday_cols].abs().sum(axis=1) > 0).sum()

print(f"\n[OK] Feature engineering complete (v2)")
print(f"Clean dataset: {len(df):,} rows with {len(FEATURE_COLS)} features")
print(f"Label balance (2-bar): {df['label_binary'].mean():.3f} (1=up)")
print(f"Rows with real intraday macro: {has_real_intraday:,} ({100*has_real_intraday/max(len(df), 1):.1f}%)")
print(f"\nFeature groups:")
print(f"  BTC price/volume:   15 features")
print(f"  Daily macro:        10 features")
print(f"  Intraday macro:     11 features (NEW)")
print(f"  Derivatives:         9 features")
print(f"  Cross-asset:         6 features")
print(f"  Total:              51 features")

# Save feature dataset to Drive
save_cols = FEATURE_COLS + ['timestamp', 'close', 'label_binary', 'label_soft',
                             'forward_return_2', 'forward_return_1', 'forward_return_6',
                             'label_binary_1bar', 'label_binary_6bar']
df[save_cols].to_csv(f'{DRIVE_SAVE}/crash_features_v2.csv', index=False)
print(f"[OK] Feature dataset saved to Drive")


## Cell 7: Walk-Forward Split & Train LightGBM (Multi-Horizon)

In [None]:
# ============================================================
# CELL 7: Train crash-regime LightGBM (v2 -- multi-horizon)
# ============================================================
# Trains 3 models at different horizons, plus an intraday-only model.
# Primary: 2-bar (10 min), Comparison: 1-bar (5 min), 6-bar (30 min)

import lightgbm as lgb
from sklearn.metrics import accuracy_score, roc_auc_score
import pickle
import json
import shutil

# -- Split by crash period (walk-forward) --
# Chronological split: older crashes train, the most recent crash is held out.
ts = df['timestamp']
crash1 = df[ts < '2019-01-01']
crash2 = df[(ts >= '2021-11-01') & (ts < '2023-01-01')]
crash3 = df[ts >= '2025-10-01']

print(f"Crash 1 (2018):    {len(crash1):>8,} rows")
print(f"Crash 2 (2021-22): {len(crash2):>8,} rows")
print(f"Crash 3 (2025-26): {len(crash3):>8,} rows")

if len(crash3) > 1000:
    # Enough recent data: train on crash 1 + 2; crash 3 is halved into
    # validation (first half) and test (second half).
    train_data = pd.concat([crash1, crash2])
    split_idx = len(crash3) // 2
    val_data, test_data = crash3.iloc[:split_idx], crash3.iloc[split_idx:]
else:
    # Too little crash-3 data: carve val/test from the tail of crash 2 (70/15/15).
    _n_crash2 = len(crash2)
    split_70, split_85 = int(_n_crash2 * 0.7), int(_n_crash2 * 0.85)
    train_data = pd.concat([crash1, crash2.iloc[:split_70]])
    val_data = crash2.iloc[split_70:split_85]
    test_data = crash2.iloc[split_85:]

# Sample weights: the older crash-1 rows count half as much as everything else.
crash1_mask = train_data['timestamp'] < '2019-01-01'
train_weights = np.where(crash1_mask.values, 0.5, 1.0)

X_train, X_val, X_test = (
    part[FEATURE_COLS].values for part in (train_data, val_data, test_data)
)

print(f"\nTrain: {len(X_train):,} | Val: {len(X_val):,} | Test: {len(X_test):,}")

# LightGBM params (shared across all horizons). Key order is kept stable
# because this dict is written verbatim into the metadata JSON.
params = dict(
    objective='binary',
    metric='binary_logloss',
    boosting_type='gbdt',
    num_leaves=31,
    learning_rate=0.05,
    feature_fraction=0.8,
    bagging_fraction=0.8,
    bagging_freq=5,
    min_child_samples=100,
    lambda_l1=0.1,
    lambda_l2=0.1,
    verbose=-1,
)

# ============================================
# TRAIN MODELS AT 3 HORIZONS
# ============================================

# (result key, label column, human-readable description) per horizon.
horizon_configs = [
    ('2bar_10min', 'label_binary', 'PRIMARY (10 min)'),
    ('1bar_5min', 'label_binary_1bar', 'comparison (5 min)'),
    ('6bar_30min', 'label_binary_6bar', 'comparison (30 min)'),
]

# horizon key -> dict of model, metrics and test predictions (filled below).
results = dict()

# One model per horizon on identical train/val/test rows; only the label column
# changes. Early stopping watches the validation logloss (patience 30 rounds).
for horizon_name, label_col, description in horizon_configs:
    divider = '=' * 60
    print(f"\n{divider}")
    print(f"Training: {horizon_name} -- {description}")
    print(divider)

    y_train, y_val, y_test = (
        part[label_col].values for part in (train_data, val_data, test_data)
    )
    print(f"Train up%: {y_train.mean():.3f} | Val up%: {y_val.mean():.3f} | Test up%: {y_test.mean():.3f}")

    model = lgb.train(
        params,
        lgb.Dataset(X_train, label=y_train, weight=train_weights,
                    feature_name=FEATURE_COLS),
        num_boost_round=500,
        valid_sets=[lgb.Dataset(X_val, label=y_val, feature_name=FEATURE_COLS)],
        valid_names=['val'],
        callbacks=[lgb.early_stopping(30), lgb.log_evaluation(50)],
    )

    # Predicted probability of the positive (UP) class; 0.5 is the threshold.
    val_probs = model.predict(X_val)
    test_probs = model.predict(X_test)
    val_acc, test_acc = (
        accuracy_score(y_true, (p_up > 0.5).astype(int))
        for y_true, p_up in ((y_val, val_probs), (y_test, test_probs))
    )
    val_auc = roc_auc_score(y_val, val_probs)
    test_auc = roc_auc_score(y_test, test_probs)
    pred_std = np.std(test_probs)

    results[horizon_name] = dict(
        model=model,
        val_acc=val_acc,
        test_acc=test_acc,
        val_auc=val_auc,
        test_auc=test_auc,
        pred_std=pred_std,
        test_probs=test_probs,
        y_test=y_test,
        best_round=model.best_iteration,
    )

    print(f"\n  Val acc:  {val_acc:.4f} ({val_acc*100:.1f}%)")
    print(f"  Test acc: {test_acc:.4f} ({test_acc*100:.1f}%)")
    print(f"  Val AUC:  {val_auc:.4f}")
    print(f"  Test AUC: {test_auc:.4f}")
    print(f"  Pred std: {pred_std:.4f}")
    print(f"  Best round: {model.best_iteration}")

# ============================================
# HORIZON COMPARISON TABLE
# ============================================
_table_rule = '=' * 60
print(f"\n\n{_table_rule}")
print("HORIZON COMPARISON")
print(_table_rule)
print(f"{'Horizon':<16} {'Val Acc':>8} {'Test Acc':>9} {'Val AUC':>8} {'Test AUC':>9} {'Pred Std':>9}")
print('-' * 60)
for _horizon, _stats in results.items():
    # The 2-bar model is the designated primary horizon.
    _suffix = " <-- PRIMARY" if '2bar' in _horizon else ""
    _row = (f"{_horizon:<16} {_stats['val_acc']:>8.4f} {_stats['test_acc']:>9.4f} "
            f"{_stats['val_auc']:>8.4f} {_stats['test_auc']:>9.4f} {_stats['pred_std']:>9.4f}")
    print(_row + _suffix)

# ============================================
# INTRADAY-ONLY MODEL (crash 3 only, all 51 features)
# Uses only rows where intraday macro is populated
# ============================================
_intraday_rule = '=' * 60
print(f"\n\n{_intraday_rule}")
print("Training: INTRADAY-ONLY model (crash 3 subset)")
print(_intraday_rule)

intraday_cols = [
    'spx_return_5m', 'spx_return_30m', 'spx_return_1h',
    'vix_return_5m', 'vix_return_30m', 'vix_level_5m',
    'ndx_return_5m', 'ndx_return_30m',
]

# Keep crash-3 rows where at least one intraday macro column is non-zero.
df_intraday = crash3.loc[crash3[intraday_cols].abs().sum(axis=1) > 0].copy()

if len(df_intraday) < 500:
    print(f"Not enough intraday rows ({len(df_intraday)}) -- need >= 500. Skipping.")
    id_model = None
else:
    # Chronological 70 / 15 / 15 split of the intraday subset.
    _n_intraday = len(df_intraday)
    split_70, split_85 = int(_n_intraday * 0.7), int(_n_intraday * 0.85)
    id_train = df_intraday.iloc[:split_70]
    id_val = df_intraday.iloc[split_70:split_85]
    id_test = df_intraday.iloc[split_85:]

    id_X_train, id_y_train = id_train[FEATURE_COLS].values, id_train['label_binary'].values
    id_X_val, id_y_val = id_val[FEATURE_COLS].values, id_val['label_binary'].values
    id_X_test, id_y_test = id_test[FEATURE_COLS].values, id_test['label_binary'].values

    print(f"Intraday rows: {len(df_intraday):,}")
    print(f"Train: {len(id_X_train):,} | Val: {len(id_X_val):,} | Test: {len(id_X_test):,}")

    id_train_set = lgb.Dataset(id_X_train, label=id_y_train, feature_name=FEATURE_COLS)
    id_val_set = lgb.Dataset(id_X_val, label=id_y_val, feature_name=FEATURE_COLS)

    id_model = lgb.train(
        params,
        id_train_set,
        num_boost_round=300,
        valid_sets=[id_val_set],
        valid_names=['val'],
        callbacks=[lgb.early_stopping(20), lgb.log_evaluation(50)],
    )

    id_test_probs = id_model.predict(id_X_test)
    id_test_acc = accuracy_score(id_y_test, (id_test_probs > 0.5).astype(int))
    id_test_auc = roc_auc_score(id_y_test, id_test_probs)

    print(f"\n  Intraday-only test acc: {id_test_acc:.4f} ({id_test_acc*100:.1f}%)")
    print(f"  Intraday-only test AUC: {id_test_auc:.4f}")

    # Persist as both native text and pickle; copied to Drive with the rest later.
    id_model.save_model('/content/models/crash_lgbm_intraday.txt')
    with open('/content/models/crash_lgbm_intraday.pkl', 'wb') as fh:
        pickle.dump(id_model, fh)
    print(f"  Intraday model saved")

# ============================================
# PICK BEST MODEL & SAVE
# ============================================

# Best = highest VALIDATION AUC among the 3 horizon models.
# Choosing on test AUC would leak the test set into model selection, so the
# reported test metrics (and the accuracy-based deploy gate in Cell 8) would be
# optimistically biased. The test set stays untouched until final reporting.
# max() keeps the first maximum, so ties go to the earlier horizon_configs entry.
best_name = max(results, key=lambda k: results[k]['val_auc'])
best = results[best_name]
best_model = best['model']

print(f"\n\n{'='*60}")
print(f"BEST MODEL: {best_name} (selected on val AUC)")
print(f"  Val AUC:  {best['val_auc']:.4f}")
print(f"  Test acc: {best['test_acc']:.4f} | Test AUC: {best['test_auc']:.4f}")
print(f"{'='*60}")

# Feature importance (total split gain) for the best model, highest first.
importance = best_model.feature_importance(importance_type='gain')
feat_imp = sorted(zip(FEATURE_COLS, importance), key=lambda x: -x[1])

macro_feats = {'spx_return_1d','spx_vs_sma','vix_norm','vix_change','vix_extreme',
               'dxy_return_1d','dxy_trend','yield_level','yield_change','fng_norm'}
intraday_macro_feats = {'spx_return_5m','spx_return_30m','spx_return_1h',
                        'vix_return_5m','vix_return_30m','vix_level_5m',
                        'ndx_return_5m','ndx_return_30m',
                        'spx_vix_diverge','macro_momentum_5m','macro_momentum_30m'}
deriv_feats = {'funding_z','funding_extreme_long','funding_extreme_short',
               'oi_change_1h','oi_change_4h','oi_spike','ls_ratio_norm',
               'ls_extreme_long','taker_imbalance'}
cross_feats = {'eth_return_1bar','eth_return_6bar','eth_btc_ratio_change',
               'btc_lead_1','btc_lead_2','btc_lead_3'}

# Single group lookup shared by the top-20 table and the group totals (the
# original duplicated this if/elif chain). Order matters: the first matching
# set wins; anything unmatched is treated as a BTC price/volume feature.
_FEATURE_GROUPS = (
    ('INTRA', intraday_macro_feats),
    ('MACRO', macro_feats),
    ('DERIV', deriv_feats),
    ('CROSS', cross_feats),
)


def _feature_group(feat):
    """Return the group tag ('INTRA', 'MACRO', 'DERIV', 'CROSS' or 'PRICE') for *feat*."""
    for tag, members in _FEATURE_GROUPS:
        if feat in members:
            return tag
    return 'PRICE'


print(f"\nTop 20 features by gain:")
for i, (feat, imp) in enumerate(feat_imp[:20], 1):
    print(f"  {i:3d}. [{_feature_group(feat):5s}] {feat:30s} {imp:>12,.0f}")

# Group importance totals
group_imp = {'PRICE': 0, 'MACRO': 0, 'INTRA': 0, 'DERIV': 0, 'CROSS': 0}
for feat, imp in feat_imp:
    group_imp[_feature_group(feat)] += imp

total_imp = sum(group_imp.values())
print(f"\nFeature group importance:")
if total_imp > 0:
    for group, imp in sorted(group_imp.items(), key=lambda x: -x[1]):
        print(f"  {group:5s}: {imp:>12,.0f} ({100*imp/total_imp:.1f}%)")
else:
    # Zero total gain (no informative splits) would otherwise raise ZeroDivisionError.
    print("  (total gain is zero -- model found no useful splits)")

# -- Save best model --
# The best model is written as LightGBM text plus a pickle; every horizon model
# is additionally pickled under its horizon key.
_models_dir = '/content/models'
best_model.save_model(f'{_models_dir}/crash_lightgbm_model.txt')

_pickle_targets = {'crash_lightgbm_model.pkl': best_model}
_pickle_targets.update(
    {f'crash_lgbm_{hname}.pkl': res['model'] for hname, res in results.items()}
)
for _pkl_name, _pkl_obj in _pickle_targets.items():
    with open(f'{_models_dir}/{_pkl_name}', 'wb') as fh:
        pickle.dump(_pkl_obj, fh)

# datetime.utcnow() is deprecated since Python 3.12 and returns a naive value;
# an aware UTC timestamp serialises unambiguously (with a +00:00 suffix).
from datetime import timezone

# Training metadata for the best model, written next to the model artifacts.
meta = {
    'model_type': 'lightgbm_crash_regime_v2',
    'regime': 'CRASH',
    'best_horizon': best_name,
    'val_accuracy': float(best['val_acc']),
    'test_accuracy': float(best['test_acc']),
    'val_auc': float(best['val_auc']),
    'test_auc': float(best['test_auc']),
    'pred_std': float(best['pred_std']),
    'best_round': int(best['best_round']),
    'n_features': len(FEATURE_COLS),
    'feature_names': FEATURE_COLS,
    'feature_importance': {f: float(i) for f, i in feat_imp},
    'horizon_comparison': {
        name: {
            'val_acc': float(r['val_acc']),
            'test_acc': float(r['test_acc']),
            'val_auc': float(r['val_auc']),
            'test_auc': float(r['test_auc']),
        }
        for name, r in results.items()
    },
    'train_rows': int(len(X_train)),
    'val_rows': int(len(X_val)),
    'test_rows': int(len(X_test)),
    # NOTE(review): CRASH_PERIODS is defined in an earlier cell; json.dump
    # requires it to contain only JSON-serialisable values -- confirm.
    'crash_periods': CRASH_PERIODS,
    'crash1_weight': 0.5,
    'params': params,
    'trained_at': datetime.now(timezone.utc).isoformat(),
}
with open('/content/models/crash_lightgbm_meta.json', 'w') as f:
    json.dump(meta, f, indent=2)

# -- SAVE TO DRIVE IMMEDIATELY --
# Copy every artifact in /content/models/ to Drive so it survives a runtime
# disconnect. os.path.join avoids the double slash that f'{DRIVE_SAVE}/...'
# produced (DRIVE_SAVE already ends in '/'). Non-file entries are skipped
# because shutil.copy cannot copy a directory. sorted() gives a stable log order.
_local_models_dir = '/content/models'
for fname in sorted(os.listdir(_local_models_dir)):
    _src = os.path.join(_local_models_dir, fname)
    if not os.path.isfile(_src):
        continue
    shutil.copy(_src, os.path.join(DRIVE_SAVE, fname))
    sz = os.path.getsize(_src) / 1024
    print(f"[OK] Saved to Drive: {fname} ({sz:.1f} KB)")

print(f"\n[DONE] Crash-regime LightGBM v2 trained and saved!")


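## Cell 8: Confidence Calibration Analysis (Best Model)
Bins the best model's test-set probabilities to check whether its stated confidence matches realized accuracy. It then reports hit rates at the 85% Polymarket betting threshold.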

In [None]:
# ============================================================
# CELL 8: Confidence calibration analysis (v2)
# ============================================================
# Question answered here: when the model says "85% confident", is it right
# about 85% of the time? Works on the best horizon model's test predictions.

# Copies, so nothing below can mutate the arrays stored in `results`.
_best_entry = results[best_name]
probs = _best_entry['test_probs'].copy()
actuals = _best_entry['y_test'].copy()

# The model emits P(UP) in [0, 1]. "Confidence" is the distance from a coin
# flip rescaled to [0, 1]: 0.5 -> 0.0, and 0.0 or 1.0 -> 1.0.
confidence = 2 * np.abs(probs - 0.5)
predicted_up = (probs > 0.5).astype(int)
correct = (predicted_up == actuals).astype(int)

# Bin by model probability (UP direction)
print(f"Best model: {best_name}")
print(f"\n{'Probability':>12} {'Direction':>10} {'Count':>8} {'Accuracy':>10} {'Bet?':>8}")
print("-" * 55)

# UP bins span [0.50, 1.0]. The last upper edge is 1.01 so a probability of
# exactly 1.0 is still counted; the label is capped at 100% when printed.
prob_bins = [
    (0.50, 0.55, 'UP'),
    (0.55, 0.60, 'UP'),
    (0.60, 0.65, 'UP'),
    (0.65, 0.70, 'UP'),
    (0.70, 0.75, 'UP'),
    (0.75, 0.80, 'UP'),
    (0.80, 0.85, 'UP'),
    (0.85, 0.90, 'UP'),
    (0.90, 0.95, 'UP'),
    (0.95, 1.01, 'UP'),
]

for lo, hi, direction in prob_bins:
    mask = (probs >= lo) & (probs < hi)
    if mask.sum() > 0:
        # Hit rate of an UP call in this bin (label 1 = up), mirroring the DOWN
        # table below; this also scores prob == 0.5 as an UP call in this bin.
        acc = (actuals[mask] == 1).mean()
        bet = "[BET]" if lo >= 0.85 else "--"
        print(f"  {lo:.0%}-{min(hi, 1.0):.0%}      {direction:>10} {mask.sum():>8,} {acc:>10.1%} {bet:>8}")

print()
# Also check DOWN predictions (prob < 0.5). These bins mirror the UP side and
# start at 0.00, so the most confident DOWN calls (prob < 0.05) are reported
# too; the previous table started at 0.05 and silently dropped them.
down_bins = [(0.00, 0.05), (0.05, 0.10), (0.10, 0.15), (0.15, 0.20), (0.20, 0.25),
             (0.25, 0.30), (0.30, 0.35), (0.35, 0.40), (0.40, 0.45), (0.45, 0.50)]
for lo, hi in down_bins:
    mask = (probs >= lo) & (probs < hi)
    if mask.sum() > 0:
        down_correct = (actuals[mask] == 0).mean()
        bet = "[BET]" if hi <= 0.15 else "--"
        print(f"  {lo:.0%}-{hi:.0%}      {'DOWN':>10} {mask.sum():>8,} {down_correct:>10.1%} {bet:>8}")

# Summary for Polymarket thresholds
_poly_rule = '=' * 55
print(f"\n{_poly_rule}")
print(f"POLYMARKET DECISION THRESHOLDS ({best_name})")
print(_poly_rule)

# One-sided high-confidence calls: UP when P(UP) >= 0.85, DOWN when P(UP) <= 0.15.
mask_85_up = probs >= 0.85
mask_85_down = probs <= 0.15
for _prefix, _side_mask, _side_hits in (
    ("  85%+ UP:   ", mask_85_up, correct[mask_85_up]),
    ("  85%+ DOWN: ", mask_85_down, actuals[mask_85_down] == 0),
):
    _n_side = _side_mask.sum()
    if _n_side > 0:
        print(f"{_prefix}{_n_side:>6,} predictions, {_side_hits.mean():.1%} accuracy")
    else:
        print(f"{_prefix}0 predictions")

# Combined high confidence (either direction)
mask_85_any = mask_85_up | mask_85_down
if mask_85_any.sum() == 0:
    # Model never gets that confident: report the peak and lower thresholds instead.
    print(f"  85%+ ANY:  0 predictions -- model never reaches 85% confidence")
    max_conf = confidence.max()
    print(f"  Max confidence seen: {50 + max_conf*50:.1f}%")
    for thresh in [0.60, 0.65, 0.70, 0.75, 0.80]:
        mask = confidence >= (thresh - 0.5) * 2
        if mask.sum() <= 10:
            continue
        hc = np.where(probs[mask] >= 0.5, actuals[mask] == 1, actuals[mask] == 0)
        print(f"  {thresh:.0%}+ conf: {mask.sum():>6,} predictions, {hc.mean():.1%} accuracy")
else:
    hc_correct = np.where(probs[mask_85_any] >= 0.5,
                          actuals[mask_85_any] == 1,
                          actuals[mask_85_any] == 0)
    _hc_rate = hc_correct.mean()
    print(f"  85%+ ANY:  {mask_85_any.sum():>6,} predictions, {_hc_rate:.1%} accuracy")
    verdict = "BET" if _hc_rate > 0.60 else "WAIT -- models not calibrated yet"
    print(f"\n  -> Polymarket should {verdict}")

# Save calibration data to Drive: one row per test-set prediction.
cal_df = pd.DataFrame(dict(
    probability=probs,
    confidence=confidence,
    predicted_up=predicted_up,
    actual_up=actuals,
    correct=correct,
))
_calibration_csv = f'{DRIVE_SAVE}/calibration_analysis_v2.csv'
cal_df.to_csv(_calibration_csv, index=False)
print(f"\n[OK] Calibration data saved to Drive")

# ============================================
# FINAL SUMMARY
# ============================================
_summary_rule = '=' * 60
print(f"\n\n{_summary_rule}")
print("TRAINING COMPLETE -- CRASH REGIME LIGHTGBM v2")
print(_summary_rule)
for _summary_row in (
    f"  Best horizon:   {best_name}",
    f"  Test accuracy:  {best['test_acc']:.1%}",
    f"  Test AUC:       {best['test_auc']:.4f}",
    f"  Features:       {len(FEATURE_COLS)} (51)",
    f"  Training data:  {len(X_train):,} crash-period rows",
    f"  Model saved to: {DRIVE_SAVE}",
    _summary_rule,
):
    print(_summary_row)

print("\nHorizon comparison:")
for _hname, _hres in results.items():
    _tag = " <-- BEST" if _hname == best_name else ""
    print(f"  {_hname:16s}: acc={_hres['test_acc']:.4f}  auc={_hres['test_auc']:.4f}{_tag}")

# Deploy gate on the best model's test accuracy: >53% pass, >51% marginal.
_final_acc = best['test_acc']
if _final_acc > 0.53:
    print(f"\n[PASS] ACCURACY ABOVE 53% -- Ready to deploy!")
    print(f"   Download crash_lightgbm_model.pkl from Google Drive")
    print(f"   and place in models/trained/ on the VPS.")
elif _final_acc > 0.51:
    print(f"\n[MARGINAL] Consider adding more features or tuning params")
else:
    print(f"\n[FAIL] NO IMPROVEMENT -- Check feature quality and data coverage")
