In [None]:
#!/usr/bin/env python3
# full_ce_pe_pipeline_use_ltp_enriched_final_restored.py
"""
Consolidated CE+PE pipeline using LTP where available.
Produces a merged enriched CSV with all requested columns,
plus auxiliary outputs: top calls/puts, alerts, reversals, IV-crush,
heatmap, auto-trade export, CE-only / PE-only lists, hold-breakouts.

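Drop this file next to your raw snapshot dump (see INPUT_FILE). The script
first flattens that dump into flattened_snapshots.csv and then runs the
pipeline on it:
    python full_ce_pe_pipeline_use_ltp_enriched_final_restored.py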
"""

import os
import sys
import json
from collections import defaultdict, Counter
from datetime import timedelta
import math

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import json
import pandas as pd
from pandas import json_normalize

# Raw snapshot dump read by load_snapshots().
INPUT_FILE = "19112025_BANK_PNL.txt"     # your pasted file
# Flattened table written by main().
OUTPUT_CSV = "flattened_snapshots.csv"


def parse_snapshot(obj):
    """
    Flatten one snapshot including nested Current, Previous, Next JSON blocks.

    Every top-level field except ``"Current"`` is copied unchanged.  The
    ``"Current"`` field may hold either a JSON string or an already-decoded
    dict with optional ``"Previous"``, ``"Current"`` and ``"Next"``
    sub-objects; each key of a sub-object is emitted as ``"<Section>_<key>"``.
    A section that is missing or is not a dict is recorded as
    ``flat["<Section>"] = None``.

    Args:
        obj: Mapping describing one snapshot.

    Returns:
        A flat dict (one future DataFrame row).
    """
    # 1. Copy all top-level fields except 'Current'
    flat = {k: v for k, v in obj.items() if k != "Current"}

    if "Current" not in obj:
        return flat

    # 2. Decode the nested payload.  Strings are parsed as JSON; a payload
    #    that is already a dict is used directly instead of being discarded.
    payload = obj["Current"]
    if isinstance(payload, (str, bytes, bytearray)):
        try:
            payload = json.loads(payload)
        except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
            payload = {}
    if not isinstance(payload, dict):
        # Anything else (None, list, scalar, decoded JSON string) has no sections.
        payload = {}

    # 3. Flatten Previous, Current, Next
    for section_name in ("Previous", "Current", "Next"):
        section = payload.get(section_name)
        if isinstance(section, dict):
            for key, val in section.items():
                flat[f"{section_name}_{key}"] = val
        else:
            # Add empty if missing
            flat[section_name] = None

    return flat


def load_snapshots(path):
    """
    Load snapshot objects from a text file.

    Supported layouts:
    - a single JSON document: a list of objects (single-line or
      pretty-printed) or one object;
    - JSON objects separated by newlines; lines that are not valid JSON are
      skipped, and a line holding a JSON list contributes its dict items.

    Args:
        path: Path of the file to read.

    Returns:
        list of decoded snapshots.

    Raises:
        ValueError: if no snapshot could be decoded from the file.
    """
    # utf-8-sig reads plain UTF-8 and also tolerates a leading BOM.
    with open(path, "r", encoding="utf-8-sig") as f:
        raw = f.read().strip()

    # Try the whole file first so a JSON array is returned as-is rather than
    # being wrapped in another list or truncated by per-line parsing.
    try:
        document = json.loads(raw)
    except ValueError:
        document = None
    else:
        if isinstance(document, list):
            return document
        if isinstance(document, dict):
            return [document]

    # Fall back to line-by-line JSON parsing (best effort).
    snapshots = []
    for line in raw.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            item = json.loads(line)
        except ValueError:
            continue
        if isinstance(item, dict):
            snapshots.append(item)
        elif isinstance(item, list):
            snapshots.extend(entry for entry in item if isinstance(entry, dict))

    if not snapshots:
        raise ValueError("Cannot parse file: Format looks invalid")

    return snapshots


def main():
    """Flatten the snapshots in INPUT_FILE and write them to OUTPUT_CSV."""
    raw_records = load_snapshots(INPUT_FILE)
    print(f"Loaded snapshots: {len(raw_records)}")

    # One flat dict per snapshot becomes one DataFrame row.
    table = pd.DataFrame([parse_snapshot(record) for record in raw_records])

    def _blank_empty_array(cell):
        # Cells equal to the literal string "[]" are stored as None.
        return None if cell == "[]" else cell

    for column in table.columns:
        table[column] = table[column].apply(_blank_empty_array)

    # Unparseable timestamps become NaT instead of raising.
    if "LTT" in table.columns:
        table["LTT"] = pd.to_datetime(table["LTT"], errors="coerce")

    table.to_csv(OUTPUT_CSV, index=False)
    print(f"Saved: {OUTPUT_CSV}")


# Run the flattening step when this code is executed directly
# (i.e. when __name__ is "__main__").
if __name__ == "__main__":
    main()



# ---------------------- CONFIG ---------------------- #
# Flattened snapshot table consumed by the pipeline below
# (same file name as OUTPUT_CSV written by main()).
INPUT_CSV = "flattened_snapshots.csv"

# Merged per-strike table and the per-action rankings.
OUT_MERGED = "MERGED_CE_PE_FORECAST.csv"
OUT_TOP_PUTS = "TOP_BUY_PUTS.csv"
OUT_TOP_CALLS = "TOP_BUY_CALLS.csv"
OUT_ALL_ACTIONS = "ALL_ACTIONS.csv"

# Rows with a call-side / put-side pct change, and HOLD rows ranked by breakout score.
OUT_CE_ONLY = "TOP_CE_ONLY.csv"
OUT_PE_ONLY = "TOP_PE_ONLY.csv"
OUT_HOLD_BREAKOUTS = "TOP_HOLD_BREAKOUTS.csv"

# Event lists.
OUT_REVERSALS = "REVERSALS.csv"
OUT_IV_CRUSH = "IV_CRUSH.csv"

# Figures and the JSON signal export.
OUT_PLOT = "SPOT_VS_MOMENTUM.png"
OUT_HEATMAP = "PREMIUM_HEATMAP.png"
OUT_AUTOTRADE = "AUTO_TRADE_SIGNALS.json"

# Look-back window in minutes, measured back from a strike's last observation.
REVERSAL_WINDOW_MIN = 5
REVERSAL_DROP_PCT = 12        # % drop from peak to last to mark reversal
IV_CRUSH_DROP = 15            # % drop in IV to mark IV crush

# ---------------------- LOAD ---------------------- #
if not os.path.exists(INPUT_CSV):
    print(f"ERROR: input file not found: {INPUT_CSV}")
    sys.exit(1)

# Parse LTT as datetime if present; if not, we'll proceed without times.
df = pd.read_csv(INPUT_CSV, low_memory=False)
if "LTT" in df.columns:
    # Coerce unparseable values to NaT (the same policy main() uses when it
    # writes the CSV).  Leaving raw strings in the column would make the
    # later `timestamp +/- timedelta` arithmetic raise TypeError.
    parsed_ltt = pd.to_datetime(df["LTT"], errors="coerce")
    n_unparsed = int(parsed_ltt.isna().sum() - df["LTT"].isna().sum())
    if n_unparsed > 0:
        print(f"WARNING: {n_unparsed} LTT values could not be parsed and were set to NaT")
    df["LTT"] = parsed_ltt

# ---------------------- COLUMN SELECTION (LTP preferred) ---------------------- #
# Preferred LTP column names for the call side of the
# Previous / Current / Next sections.
PREV_CALL_LTP = "Previous_Call_ltp"
CURR_CALL_LTP = "Current_Call_ltp"
NEXT_CALL_LTP = "Next_Call_ltp"

# Preferred LTP column names for the put side of the same sections.
PREV_PUT_LTP = "Previous_Put_ltp"
CURR_PUT_LTP = "Current_Put_ltp"
NEXT_PUT_LTP = "Next_Put_ltp"

# Strike columns (one per section; each section's values are keyed by its own strike)
PREV_STR = "Previous_Strikeprice"
CURR_STR = "Current_Strikeprice"
NEXT_STR = "Next_Strikeprice"

# fallback prefer _Premium if ltp missing
def pick(col_ltp, col_prem):
    """Return ``col_ltp`` if ``df`` has it, else ``col_prem`` if present, else None."""
    for candidate in (col_ltp, col_prem):
        if candidate in df.columns:
            return candidate
    return None

# Resolve the value column of each section: the LTP column when it exists,
# otherwise the matching *_Premium column, otherwise None.
prev_call_col, curr_call_col, next_call_col = (
    pick(PREV_CALL_LTP, "Previous_Call_Premium"),
    pick(CURR_CALL_LTP, "Current_Call_Premium"),
    pick(NEXT_CALL_LTP, "Next_Call_Premium"),
)

prev_put_col, curr_put_col, next_put_col = (
    pick(PREV_PUT_LTP, "Previous_Put_Premium"),
    pick(CURR_PUT_LTP, "Current_Put_Premium"),
    pick(NEXT_PUT_LTP, "Next_Put_Premium"),
)

print("Using Call columns:", prev_call_col, curr_call_col, next_call_col)
print("Using Put  columns:", prev_put_col, curr_put_col, next_put_col)
print("Using Strike columns:", PREV_STR, CURR_STR, NEXT_STR)

# ---------------------- BUILD TIMESERIES PER STRIKE ---------------------- #
def build_series(str_col_prev, str_col_curr, str_col_next, val_prev, val_curr, val_next):
    """
    Collect ``(timestamp, value)`` observations per strike from the global ``df``.

    Each of the three (strike column, value column) pairs is read from every
    row.  A pair is ignored when either column name is None or absent from
    ``df``.  A cell pair is skipped when either cell is NaN, or when the
    strike cannot be converted with ``int()`` or the value with ``float()``.

    Returns:
        defaultdict(list) mapping int strike -> list of (timestamp, float).
        The timestamp is the row's ``LTT``; when that is missing or NaN,
        ``pd.Timestamp(<row index label>)`` is used as a stand-in.
    """
    # Column availability does not change per row, so resolve it once.
    usable_pairs = [
        (strike_col, value_col)
        for strike_col, value_col in (
            (str_col_prev, val_prev),
            (str_col_curr, val_curr),
            (str_col_next, val_next),
        )
        if strike_col is not None
        and value_col is not None
        and strike_col in df.columns
        and value_col in df.columns
    ]

    by_strike = defaultdict(list)
    for row_label, record in df.iterrows():
        stamp = record.get("LTT", None)
        for strike_col, value_col in usable_pairs:
            raw_strike = record.get(strike_col)
            raw_value = record.get(value_col)
            if pd.isna(raw_strike) or pd.isna(raw_value):
                continue
            try:
                strike = int(raw_strike)
                value = float(raw_value)
            except Exception:
                continue
            # Fall back to an index-derived timestamp when LTT is unavailable.
            if stamp is None or pd.isna(stamp):
                when = pd.Timestamp(row_label)
            else:
                when = stamp
            by_strike[strike].append((when, value))
    return by_strike

# Same three strike columns for both sides; only the value columns differ.
call_series = build_series(
    str_col_prev=PREV_STR, str_col_curr=CURR_STR, str_col_next=NEXT_STR,
    val_prev=prev_call_col, val_curr=curr_call_col, val_next=next_call_col,
)
put_series = build_series(
    str_col_prev=PREV_STR, str_col_curr=CURR_STR, str_col_next=NEXT_STR,
    val_prev=prev_put_col, val_curr=curr_put_col, val_next=next_put_col,
)

print("Detected call strikes:", len(call_series))
print("Detected put strikes :", len(put_series))

# ---------------------- SUMMARY & FORECAST HELPERS ---------------------- #
def summarize(pairs):
    """
    Summarise one series of ``(timestamp, price)`` observations.

    The observations are ordered by timestamp first.  Returns None for an
    empty series, otherwise a dict with the first/last/peak/trough price, the
    absolute and percentage change from first to last (NaN percentage when
    the first price is 0), the observation count and the ordered series.
    """
    ordered = sorted(pairs, key=lambda obs: obs[0])
    values = [value for _, value in ordered]
    if not values:
        return None

    opening, closing = values[0], values[-1]
    delta = closing - opening
    if opening != 0:
        pct = delta / opening * 100
    else:
        pct = np.nan

    return {
        "first": opening,
        "last": closing,
        "peak": max(values),
        "trough": min(values),
        "abs_change": delta,
        "pct_change": pct,
        "n_obs": len(values),
        "series_sorted": ordered,
    }

def forecast_from_pct(last, pct):
    """
    Return ((5min_low, 5min_high), (10min_low, 10min_high)) around ``last``.

    The band offsets are chosen from the percentage move ``pct``
    (NaN is treated as 0): >= 25, >= 8, > 0, and everything else
    (neutral / negative).
    """
    move = 0.0 if pd.isna(pct) else pct

    # Offsets are (5min_low, 5min_high, 10min_low, 10min_high).
    if move >= 25:
        offsets = (15, 35, 25, 60)
    elif move >= 8:
        offsets = (6, 18, 12, 30)
    elif move > 0:
        offsets = (2, 8, 5, 15)
    else:
        # neutral/slight negative
        offsets = (-5, 2, -8, 5)

    lo5, hi5, lo10, hi10 = offsets
    return (last + lo5, last + hi5), (last + lo10, last + hi10)

# Per-side summaries keyed by strike; strikes whose summary is empty are dropped.
call_summary = {
    strike: stats
    for strike, stats in ((k, summarize(v)) for k, v in call_series.items())
    if stats
}

put_summary = {
    strike: stats
    for strike, stats in ((k, summarize(v)) for k, v in put_series.items())
    if stats
}

# All unique strikes from both sides, ascending.
all_strikes = sorted(set(call_summary) | set(put_summary))

# ---------------------- BUILD MERGED ENRICHED ROWS ---------------------- #
rows = []

# Loop-invariant: every strike row carries the same first-row "Unnamed: 0" value.
unnamed_first = (
    df["Unnamed: 0"].iloc[0]
    if "Unnamed: 0" in df.columns and not df.empty
    else np.nan
)

# Stand-in summary (all NaN) for a side that has no observations.
_SIDE_KEYS = ("first", "last", "peak", "trough", "abs_change", "pct_change")
_NO_SIDE = dict.fromkeys(_SIDE_KEYS, np.nan)

for s in all_strikes:
    cs = call_summary.get(s)
    ps = put_summary.get(s)

    # n_obs unified
    n_obs_call = int(cs["n_obs"]) if cs else 0
    n_obs_put = int(ps["n_obs"]) if ps else 0
    n_obs = n_obs_call + n_obs_put

    # Headline first/last/peak/trough come from the side with more
    # observations (call wins ties); the put side is the fallback.
    headline = cs if (cs and n_obs_call >= n_obs_put) else ps
    if headline:
        last_premium = headline["last"]
        f5, f10 = forecast_from_pct(last_premium, headline["pct_change"])
    else:
        headline = _NO_SIDE
        last_premium = np.nan
        f5 = f10 = (np.nan, np.nan)

    call_side = cs or _NO_SIDE
    put_side = ps or _NO_SIDE

    rows.append({
        "strike": s,
        "Unnamed: 0": unnamed_first,
        "n_obs": n_obs,
        "first_premium": headline["first"],
        "last_premium": last_premium,
        "peak_premium": headline["peak"],
        "trough_premium": headline["trough"],
        "abs_change": headline["abs_change"],
        "pct_change": headline["pct_change"],
        "5min_low": f5[0],
        "5min_high": f5[1],
        "10min_low": f10[0],
        "10min_high": f10[1],
        # Expected deltas relative to the last premium; NaN propagates
        # through the subtraction when either operand is NaN.
        "p5_expected_lo": f5[0] - last_premium,
        "p5_expected_hi": f5[1] - last_premium,
        "p10_expected_lo": f10[0] - last_premium,
        "p10_expected_hi": f10[1] - last_premium,
        # per-side details
        "n_obs_call": n_obs_call,
        "n_obs_put": n_obs_put,
        "first_call_ltp": call_side["first"],
        "last_call_ltp": call_side["last"],
        "peak_call_ltp": call_side["peak"],
        "trough_call_ltp": call_side["trough"],
        "abs_change_call": call_side["abs_change"],
        "pct_change_call": call_side["pct_change"],
        "first_put_ltp": put_side["first"],
        "last_put_ltp": put_side["last"],
        "peak_put_ltp": put_side["peak"],
        "trough_put_ltp": put_side["trough"],
        "abs_change_put": put_side["abs_change"],
        "pct_change_put": put_side["pct_change"],
    })

merged_df = pd.DataFrame(rows)

# ---------------------- EXTRACT TAGS & MONEYFLOW FOR EACH STRIKE ---------------------- #
tag_cols = ["Previous_StrategyTag","Current_StrategyTag","Next_StrategyTag"]
call_money_cols = ["Previous_CallMoneyFlow","Current_CallMoneyFlow","Next_CallMoneyFlow",
                   "Previous_TotalcallMoneyFlow","Current_TotalcallMoneyFlow","Next_TotalcallMoneyFlow"]
put_money_cols  = ["Previous_PutMoneyFlow","Current_PutMoneyFlow","Next_PutMoneyFlow",
                   "Previous_TotalputMoneyFlow","Current_TotalputMoneyFlow","Next_TotalputMoneyFlow"]

# initialize
strike_info = {int(s): {"tags":Counter(), "call_moneyflow":0.0, "put_moneyflow":0.0} for s in merged_df["strike"].astype(int)}


def _section_columns(prefix, candidates):
    """Return the candidate columns of one section (by name prefix) present in df."""
    return [c for c in candidates if c.startswith(prefix + "_") and c in df.columns]


def _add_flows(row, columns, info, key):
    """Add every numeric, non-NaN value of ``columns`` in ``row`` to ``info[key]``."""
    for col in columns:
        v = row.get(col)
        try:
            if not pd.isna(v):
                info[key] += float(v)
        except (TypeError, ValueError):
            # best effort: ignore cells that are not numeric
            pass


# Each strike column is paired with the tag / moneyflow columns of its OWN
# section (Previous_*, Current_*, Next_*).  Attributing all three sections
# to every strike of the row would give neighbouring strikes nearly
# identical totals.
section_layout = [
    (
        strike_col,
        _section_columns(prefix, tag_cols),
        _section_columns(prefix, call_money_cols),
        _section_columns(prefix, put_money_cols),
    )
    for strike_col, prefix in ((PREV_STR, "Previous"), (CURR_STR, "Current"), (NEXT_STR, "Next"))
    if strike_col in df.columns
]

# iterate source df once and aggregate
for idx, row in df.iterrows():
    for strike_col, sec_tag_cols, sec_call_cols, sec_put_cols in section_layout:
        scval = row.get(strike_col)
        if pd.isna(scval):
            continue
        try:
            s = int(scval)
        except (TypeError, ValueError, OverflowError):
            continue
        info = strike_info.get(s)
        if info is None:
            continue
        # tags: "|" and ";" both separate tokens
        for col in sec_tag_cols:
            v = row.get(col)
            if isinstance(v, str) and v.strip():
                for token in v.replace("|", ";").split(";"):
                    token = token.strip()
                    if token:
                        info["tags"][token] += 1
        _add_flows(row, sec_call_cols, info, "call_moneyflow")
        _add_flows(row, sec_put_cols, info, "put_moneyflow")

def tags_to_text(s):
    """Render the tag counts of strike ``s`` as ``"tag:count;tag:count"`` ("" if none)."""
    counts = strike_info.get(int(s), {}).get("tags", {})
    return ";".join(f"{tag}:{count}" for tag, count in counts.items()) if counts else ""

# Attach the aggregated tag text and moneyflow totals to every strike row.
merged_df["tags"] = merged_df["strike"].apply(tags_to_text)
for flow_key in ("call_moneyflow", "put_moneyflow"):
    merged_df[flow_key] = merged_df["strike"].apply(
        lambda s, key=flow_key: strike_info.get(int(s), {}).get(key, 0.0)
    )

# ---------------------- REASONS (derived) ---------------------- #
def build_reasons(row):
    """
    Build a "; "-joined list of reasons for one merged row.

    Reasons come from substring matches on the row's ``tags`` text, from
    positive ``call_moneyflow`` / ``put_moneyflow`` values and from the size
    of ``pct_change``.  Returns "No strong signals" when nothing matched.
    """
    tag_blob = str(row.get("tags", ""))
    lowered = tag_blob.lower()
    found = []

    # tag based
    if any(marker in lowered for marker in ("rsimacd", "rsi", "macd")):
        found.append("RSI/MACD momentum")
    if any(marker in tag_blob for marker in ("VWAP", "vwap")):
        found.append("VWAP divergence")
    if any(marker in tag_blob for marker in ("OI", "oi")):
        found.append("OI support/resistance")

    # moneyflow
    if row.get("call_moneyflow", 0) > 0:
        found.append("Call net buying")
    if row.get("put_moneyflow", 0) > 0:
        found.append("Put net buying")

    # pct-based; a value that cannot be converted is ignored
    try:
        move = float(row.get("pct_change", 0) or 0)
    except Exception:
        pass
    else:
        if move > 10:
            found.append("Strong premium move")
        elif move > 3:
            found.append("Moderate premium move")

    if not found:
        found = ["No strong signals"]

    # dict.fromkeys drops duplicates while keeping first-seen order
    return "; ".join(dict.fromkeys(found))

# Compute the "reasons" text row by row (axis=1 passes each row to build_reasons).
merged_df["reasons"] = merged_df.apply(build_reasons, axis=1)

# ---------------------- RECOMMENDED ACTION (dual-side) ---------------------- #
def decide_action(row):
    """
    Choose "BUY_PUT", "BUY_CALL" or "HOLD" for one merged row.

    Reads ``pct_change_call``, ``pct_change_put`` (missing/NaN counts as 0)
    and ``tags`` (anything that is not a string counts as no tags).
    The put rule is evaluated first, so it wins when both sides qualify.

    Returns:
        One of the strings "BUY_PUT", "BUY_CALL", "HOLD".
    """
    call_pct = row.get("pct_change_call")
    put_pct = row.get("pct_change_put")
    call_pct = 0 if pd.isna(call_pct) else float(call_pct)
    put_pct = 0 if pd.isna(put_pct) else float(put_pct)

    # NaN is truthy, so `(tags or "").lower()` would fail on a NaN cell.
    raw_tags = row.get("tags")
    tags = raw_tags.lower() if isinstance(raw_tags, str) else ""

    bull_boost = ("call buying" in tags) or ("oi_support_call" in tags) or ("bull" in tags)
    bear_boost = ("put buying" in tags) or ("call writing" in tags) or ("bear" in tags)

    # priority rules
    if (put_pct > 8) or (put_pct > 5 and bear_boost):
        return "BUY_PUT"
    if (call_pct > 8) or (call_pct > 5 and bull_boost):
        return "BUY_CALL"
    return "HOLD"

# Compute the action row by row (axis=1 passes each row to decide_action).
merged_df["recommended_action"] = merged_df.apply(decide_action, axis=1)

# ---------------------- HIGH-CONVICTION STATISTICS ---------------------- #
hc_col = "Current_IsHighConvictionSignal"

def compute_highconv_stats(strike, action):
    """
    Count high-conviction signals for ``strike`` and how many were followed
    by a higher premium within three minutes.

    A signal is a row of ``df`` whose ``Current_Strikeprice`` equals
    ``strike`` and whose ``hc_col`` flag is True.  It is a success when any
    row for the SAME strike with ``LTT`` in ``[t0, t0 + 3 min]`` has a
    premium above the signal row's premium.  The premium column follows the
    action (put column for "BUY_PUT", call column for "BUY_CALL", otherwise
    call column if available, else put column).

    Args:
        strike: Strike price to evaluate.
        action: Recommended action for that strike.

    Returns:
        (total, success, rate); ``rate`` is ``success / total`` or None when
        there are no signals.
    """
    if hc_col not in df.columns or "Current_Strikeprice" not in df.columns:
        return 0, 0, None

    # Restrict to this strike once, so the forward window below never mixes
    # in premiums that belong to a different current strike.
    same_strike = df[df["Current_Strikeprice"] == strike]
    hc_rows = same_strike[same_strike[hc_col] == True]  # noqa: E712 (element-wise)
    total = int(hc_rows.shape[0])
    if total == 0:
        return 0, 0, None

    # The premium column depends only on `action`, not on the row.
    if action == "BUY_PUT" and curr_put_col in df.columns:
        base_col = curr_put_col
    elif action == "BUY_CALL" and curr_call_col in df.columns:
        base_col = curr_call_col
    elif curr_call_col in df.columns:
        base_col = curr_call_col
    elif curr_put_col in df.columns:
        base_col = curr_put_col
    else:
        base_col = None

    success = 0
    if base_col is not None and "LTT" in df.columns:
        horizon = timedelta(minutes=3)
        for _, hr in hc_rows.iterrows():
            t0 = hr.get("LTT")
            p0 = hr.get(base_col)
            if pd.isna(t0) or pd.isna(p0):
                continue
            try:
                in_window = (same_strike["LTT"] >= t0) & (same_strike["LTT"] <= t0 + horizon)
                if same_strike.loc[in_window, base_col].max() > p0:
                    success += 1
            except (TypeError, ValueError):
                # non-datetime LTT or non-numeric premium: skip this signal
                continue

    return total, success, success / total

# One (total, success, rate) triple per merged row, in row order.
hc_results = [
    compute_highconv_stats(int(rec["strike"]), rec.get("recommended_action", "HOLD"))
    for _, rec in merged_df.iterrows()
]
hc_totals = [triple[0] for triple in hc_results]
hc_successes = [triple[1] for triple in hc_results]
hc_rates = [triple[2] for triple in hc_results]

merged_df["highconv_total"] = hc_totals
merged_df["highconv_success"] = hc_successes
merged_df["highconv_hit_rate"] = hc_rates

# Explicit copy of the strike under the source column's name.
merged_df["Current_Strikeprice"] = merged_df["strike"]

# ---------------------- REVERSAL DETECTION ---------------------- #
def detect_reversals_in_map(series_map, side_label, window_min=None, drop_pct=None):
    """
    Find strikes whose latest price has dropped from the recent peak.

    For every strike with at least three observations, the observations
    within ``window_min`` minutes of the strike's last timestamp are
    examined.  A reversal is reported when the drop from the window's peak
    to its last price is at least ``drop_pct`` percent.

    Args:
        series_map: Mapping strike -> iterable of (timestamp, price).
        side_label: Label stored in each event's ``"side"`` field.
        window_min: Look-back window in minutes; defaults to
            REVERSAL_WINDOW_MIN.
        drop_pct: Minimum percentage drop; defaults to REVERSAL_DROP_PCT.

    Returns:
        list of dicts with keys strike, side, peak, last, drop_pct.
    """
    if window_min is None:
        window_min = REVERSAL_WINDOW_MIN
    if drop_pct is None:
        drop_pct = REVERSAL_DROP_PCT
    lookback = timedelta(minutes=window_min)

    out = []
    for s, pairs in series_map.items():
        sr = sorted(pairs, key=lambda x: x[0])
        if len(sr) < 3:
            continue
        cutoff = sr[-1][0] - lookback
        prices = [p for t, p in sr if t >= cutoff]
        if not prices:
            continue
        peak = max(prices)
        last_price = prices[-1]
        # A non-positive peak cannot produce a meaningful percentage.
        observed_drop = (peak - last_price) / peak * 100 if peak > 0 else 0
        if observed_drop >= drop_pct:
            out.append({
                "strike": s,
                "side": side_label,
                "peak": peak,
                "last": last_price,
                "drop_pct": observed_drop,
            })
    return out

# Call-side events first, then put-side events.
rev_list = detect_reversals_in_map(call_series, "CALL")
rev_list = rev_list + detect_reversals_in_map(put_series, "PUT")
rev_df = pd.DataFrame(rev_list)
if rev_df.empty:
    # No events: still write a header-only file.
    pd.DataFrame(columns=["strike","side","peak","last","drop_pct"]).to_csv(OUT_REVERSALS, index=False)
else:
    rev_df.to_csv(OUT_REVERSALS, index=False)

# ---------------------- IV CRUSH DETECTION ---------------------- #
# NOTE(review): Previous_* and Current_* are handled elsewhere in this file as
# two different strikes of the same row (each with its own *_Strikeprice), so
# this appears to compare IV across neighbouring strikes within one snapshot
# rather than over time, and labels the event with the Current strike.
# Confirm this is the intended definition of an IV crush.

# Only (previous IV, current IV) pairs whose columns both exist are checked.
iv_column_pairs = [
    (before_col, after_col, side)
    for before_col, after_col, side in (
        ("Previous_Call_IV","Current_Call_IV","CALL"),
        ("Previous_Put_IV","Current_Put_IV","PUT"),
    )
    if before_col in df.columns and after_col in df.columns
]

iv_crush_events = []
for _, snapshot in df.iterrows():
    for before_col, after_col, side in iv_column_pairs:
        prev_iv = snapshot.get(before_col)
        curr_iv = snapshot.get(after_col)
        if not (pd.notna(prev_iv) and pd.notna(curr_iv) and prev_iv > 0):
            continue
        drop_pct = (prev_iv - curr_iv) / prev_iv * 100
        if drop_pct >= IV_CRUSH_DROP:
            iv_crush_events.append({
                "LTT": snapshot.get("LTT"),
                "strike": snapshot.get("Current_Strikeprice"),
                "side": side,
                "prev_iv": prev_iv,
                "curr_iv": curr_iv,
                "drop_pct": drop_pct,
            })

iv_crush_df = pd.DataFrame(iv_crush_events)
if iv_crush_df.empty:
    # No events: still write a header-only file.
    pd.DataFrame(columns=["LTT","strike","side","prev_iv","curr_iv","drop_pct"]).to_csv(OUT_IV_CRUSH, index=False)
else:
    iv_crush_df.to_csv(OUT_IV_CRUSH, index=False)

# ---------------------- HEATMAP (CALL/PUT pct change) ---------------------- #
# Prepare simple heatmap matrix: rows = strikes, cols = [pct_change_call, pct_change_put]
# Missing percentage changes are drawn as 0.
heatmap_df = merged_df[["pct_change_call","pct_change_put"]].fillna(0)
if not heatmap_df.empty:
    # Figure height grows with the number of strikes, with a minimum of 3.
    plt.figure(figsize=(8, max(3, len(heatmap_df)/4)))
    plt.imshow(heatmap_df.values, aspect='auto', interpolation='nearest')
    plt.colorbar(label="Premium % Change")
    plt.title("CALL/PUT Premium % Change Heatmap (rows=strikes)")
    # The y axis shows row positions in merged_df, not strike prices.
    plt.ylabel("strike index (not price)")
    plt.xlabel("0=CE_pct_change, 1=PE_pct_change")
    plt.tight_layout()
    plt.savefig(OUT_HEATMAP, dpi=150)
    plt.close()
else:
    # produce an empty placeholder
    plt.figure(figsize=(6,2)); plt.text(0.5,0.5,"No data"); plt.axis('off'); plt.savefig(OUT_HEATMAP); plt.close()

# ---------------------- SPOT VS MOMENTUM PLOT ---------------------- #
# Last non-null SpotPrice (NaN when the column is absent or entirely null).
last_spot = df["SpotPrice"].dropna().iloc[-1] if "SpotPrice" in df.columns and not df["SpotPrice"].dropna().empty else np.nan
# Mean call / put percentage change over all strikes, ignoring inf and NaN.
avg_ce_pct = merged_df["pct_change_call"].replace([np.inf,-np.inf],np.nan).dropna().mean() if "pct_change_call" in merged_df.columns else np.nan
avg_pe_pct = merged_df["pct_change_put"].replace([np.inf,-np.inf],np.nan).dropna().mean() if "pct_change_put" in merged_df.columns else np.nan

# NaN values are plotted as 0.
# NOTE(review): the spot price and the two percentages share one axis, so the
# bars are on different scales -- confirm this is the intended chart.
vals = [last_spot if not pd.isna(last_spot) else 0, avg_ce_pct if not pd.isna(avg_ce_pct) else 0, avg_pe_pct if not pd.isna(avg_pe_pct) else 0]
labels = ["Spot Last","Avg CE %","Avg PE %"]
plt.figure(figsize=(7,4))
plt.title("Spot (last) vs Avg CE/PE %change")
plt.bar(labels, vals)
plt.tight_layout()
plt.savefig(OUT_PLOT, dpi=150)
plt.close()

# ---------------------- AUTO-TRADE EXPORT ---------------------- #
# One JSON entry per strike whose recommended action is BUY_CALL or BUY_PUT.
auto_signals = []
for idx, r in merged_df.iterrows():
    action = r.get("recommended_action","HOLD")
    if action not in ("BUY_CALL","BUY_PUT"):
        continue
    # Strength is the percentage change of the side being traded.  NaN is
    # truthy, so an `a or b` chain would leak a NaN call value (or use the
    # call side for a BUY_PUT signal); NaN is mapped to 0.0 to keep the
    # exported JSON valid.
    pct_key = "pct_change_put" if action == "BUY_PUT" else "pct_change_call"
    raw_strength = r.get(pct_key)
    strength = 0.0 if pd.isna(raw_strength) else float(raw_strength)
    auto_signals.append({
        "strike": int(r["strike"]),
        "action": action,
        "strength": strength,
        "reason": r.get("reasons",""),
        "tags": r.get("tags","")
    })
with open(OUT_AUTOTRADE, "w") as fh:
    json.dump(auto_signals, fh, indent=2, default=str)

# ---------------------- WRITE OUTPUTS: main merged + slices ---------------------- #
# Save merged enriched table with all requested columns
# Ensure column order to match user's expectation
# Preferred column order of the merged CSV.
cols_order = [
    "strike", "Unnamed: 0", "n_obs", "n_obs_call", "n_obs_put",
    "first_premium", "last_premium", "peak_premium", "trough_premium",
    "abs_change", "pct_change",
    "5min_low", "5min_high", "10min_low", "10min_high",
    "p5_expected_lo", "p5_expected_hi", "p10_expected_lo", "p10_expected_hi",
    "call_moneyflow", "put_moneyflow",
    "tags", "reasons", "recommended_action",
    "highconv_total", "highconv_success", "highconv_hit_rate",
    "Current_Strikeprice",
    # call-side details
    "first_call_ltp", "last_call_ltp", "peak_call_ltp",
    "trough_call_ltp", "abs_change_call", "pct_change_call",
    # put-side details
    "first_put_ltp", "last_put_ltp", "peak_put_ltp",
    "trough_put_ltp", "abs_change_put", "pct_change_put",
]

# Ordered columns that actually exist, followed by any remaining columns
# in their current order so nothing is dropped.
cols_existing = []
for name in cols_order:
    if name in merged_df.columns:
        cols_existing.append(name)
other_cols = []
for name in merged_df.columns:
    if name not in cols_existing:
        other_cols.append(name)
final_cols = cols_existing + other_cols

merged_df.to_csv(OUT_MERGED, columns=final_cols, index=False)

# BUY_PUT rows, strongest put move first (at most 200 rows).
is_put_pick = merged_df["recommended_action"] == "BUY_PUT"
top_puts = merged_df[is_put_pick].copy()
if "pct_change_put" in top_puts.columns and not top_puts.empty:
    top_puts = top_puts.sort_values(by="pct_change_put", ascending=False).iloc[:200]
top_puts.to_csv(OUT_TOP_PUTS, index=False)

# BUY_CALL rows, strongest call move first (at most 200 rows).
is_call_pick = merged_df["recommended_action"] == "BUY_CALL"
top_calls = merged_df[is_call_pick].copy()
if "pct_change_call" in top_calls.columns and not top_calls.empty:
    top_calls = top_calls.sort_values(by="pct_change_call", ascending=False).iloc[:200]
top_calls.to_csv(OUT_TOP_CALLS, index=False)

# All actions sorted by strength (max of CE/PE pct)
def compute_strength(r):
    """
    Return the larger of the row's call and put percentage changes.

    Missing, None and NaN values count as 0.  NaN has to be handled
    explicitly: it is truthy (so ``value or 0`` keeps it) and ``max`` returns
    NaN when NaN is its first argument.  Returns 0 when the values cannot be
    compared.
    """
    try:
        sides = []
        for key in ("pct_change_call", "pct_change_put"):
            value = r.get(key, 0)
            if value is None or pd.isna(value):
                value = 0
            sides.append(value or 0)
        return max(sides)
    except (TypeError, ValueError, AttributeError):
        return 0

# "strength" is added after OUT_MERGED has been written, so it appears in
# the exports from here on but not in the merged CSV.
merged_df["strength"] = merged_df.apply(compute_strength, axis=1)
# Every strike, strongest first.
merged_df.sort_values("strength", ascending=False).to_csv(OUT_ALL_ACTIONS, index=False)

# CE-only / PE-only tables: rows that have a percentage change on that side.
# An empty frame is written when the column does not exist.
for pct_col, out_path in (("pct_change_call", OUT_CE_ONLY), ("pct_change_put", OUT_PE_ONLY)):
    if pct_col in merged_df.columns:
        side_rows = merged_df[merged_df[pct_col].notna()]
    else:
        side_rows = pd.DataFrame()
    side_rows.to_csv(out_path, index=False)

# HOLD breakout candidates: HOLD rows ranked by the sum of their positive
# call and put percentage changes (negative and missing values count as 0).
hold_df = merged_df[merged_df["recommended_action"]=="HOLD"].copy()
call_gain = hold_df.get("pct_change_call",0).clip(lower=0).fillna(0)
put_gain = hold_df.get("pct_change_put",0).clip(lower=0).fillna(0)
hold_df["breakout_score"] = call_gain + put_gain
hold_ranked = hold_df.sort_values("breakout_score", ascending=False)
hold_ranked.to_csv(OUT_HOLD_BREAKOUTS, index=False)

print("✔ Pipeline complete. Files generated:")
generated_files = [
    OUT_MERGED, OUT_TOP_PUTS, OUT_TOP_CALLS, OUT_ALL_ACTIONS,
    OUT_CE_ONLY, OUT_PE_ONLY, OUT_HOLD_BREAKOUTS,
    OUT_REVERSALS, OUT_IV_CRUSH, OUT_PLOT, OUT_HEATMAP, OUT_AUTOTRADE,
]
for out_name in generated_files:
    print("-", out_name)

# Print up to the first 50 column names of the merged table for a quick check.
print("\nMerged CSV columns (sample):")
print(merged_df.columns[:50].tolist())


Loaded snapshots: 22350
Saved: flattened_snapshots.csv
Using Call columns: Previous_Call_ltp Current_Call_ltp Next_Call_ltp
Using Put  columns: Previous_Put_ltp Current_Put_ltp Next_Put_ltp
Using Strike columns: Previous_Strikeprice Current_Strikeprice Next_Strikeprice
Detected call strikes: 7
Detected put strikes : 7
✔ Pipeline complete. Files generated:
- MERGED_CE_PE_FORECAST.csv
- TOP_BUY_PUTS.csv
- TOP_BUY_CALLS.csv
- ALL_ACTIONS.csv
- TOP_CE_ONLY.csv
- TOP_PE_ONLY.csv
- TOP_HOLD_BREAKOUTS.csv
- REVERSALS.csv
- IV_CRUSH.csv
- SPOT_VS_MOMENTUM.png
- PREMIUM_HEATMAP.png
- AUTO_TRADE_SIGNALS.json

Merged CSV columns (sample):
['strike', 'Unnamed: 0', 'n_obs', 'first_premium', 'last_premium', 'peak_premium', 'trough_premium', 'abs_change', 'pct_change', '5min_low', '5min_high', '10min_low', '10min_high', 'p5_expected_lo', 'p5_expected_hi', 'p10_expected_lo', 'p10_expected_hi', 'n_obs_call', 'n_obs_put', 'first_call_ltp', 'last_call_ltp', 'peak_call_ltp', 'trough_call_ltp', 'abs_change