In [19]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
VectorBT – Bullish Candle + (Optional) RSI – Robust Loader with Synthetic Fallback
==================================================================================

- Bulk (vectorbt.YFData) + per-symbol yfinance with retries.
- Accepts OHLC even if Volume missing.
- Smart pruning & intersection to avoid index/column warnings.
- **Synthetic fallback** (GBM) if all downloads fail (toggleable).

Strategy:
- Entries: bullish candle mask (+ optional RSI>50; optional trend checks; optional 52w filter)
- Throttle: VOLAR top-K per date (optional)
- Sizing: MVO via PyPortfolioOpt (optional) else equal-weight, scaled by deploy_cash_frac
- Exits: SL/TP (+ optional EMA fast<slow)
- Costs: fees & slippage
"""

import os, math, json, time, logging, warnings, random
from dataclasses import dataclass , field
from typing import Optional, List, Dict, Tuple
from functools import reduce

import numpy as np
import pandas as pd
import vectorbt as vbt
import yfinance as yf

warnings.filterwarnings("ignore", category=FutureWarning)

# ========= Logging =========
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("vbt_candle_rsi_final")

# ========= Config =========
@dataclass
class Config:
    start_date: str = "2018-01-01"
    end_date: str   = "2025-01-01"

    # Provide either a file OR a list; list is easiest for first run
    static_symbols: Optional[List[str]] = field(
        default_factory=lambda: ["RELIANCE.NS","HDFCBANK.NS","INFY.NS","TCS.NS","ICICIBANK.NS"]
    )
    static_symbols_path: Optional[str] = None

    out_dir: str = "outputs_vbt"
    write_csv: bool = True
    plot: bool = False  # set True if you want an interactive plot

    # Data hygiene / loader behavior
    min_rows: int = 120         # required AFTER alignment
    min_rows_symbol: int = 60   # required per symbol BEFORE alignment
    max_universe: int = 200
    accept_missing_volume: bool = True
    bulk_first: bool = True
    bulk_chunk: int = 80
    yf_retries: int = 3
    yf_retry_base_sleep: float = 0.8

    # <<< New >>> Synthetic fallback (if zero symbols downloaded)
    fallback_to_synthetic: bool = True
    synthetic_days: int = 600   # ~2.5y business days
    synthetic_seed: int = 42

    # Strategy toggles (you’ll enable RSI + candle)
    use_rsi_confirm: bool   = True
    use_trend_fast_slow: bool = False
    use_htf_trend: bool       = False
    use_indicator_exit: bool  = True

    # Candle patterns (bullish)
    enable_patterns: List[str] = (
        "ENGULFING","PIERCING","MORNING_STAR","HARAMI","HARAMI_CROSS","HAMMER","INVERTED_HAMMER"
    )

    # Filters & ranking
    use_52w_filter: bool     = True
    filter_52w_window: int   = 252
    within_pct_of_52w_high: float = 0.50
    use_volar_ranking: bool  = True
    volar_lookback: int      = 252
    top_k_daily: int         = 3

    # Sizing
    use_mvo_sizing: bool     = False
    deploy_cash_frac: float  = 0.25

    # Trend params
    ema_fast: int = 10
    ema_slow: int = 20
    ema_htf:  int = 200

    # Stops
    stop_loss_pct: float = 0.05
    target_pct:    float = 0.10

    # Costs
    fees: float     = 0.0005
    slippage: float = 0.0005

CFG = Config()

# ========= Helpers =========
def ensure_dir(p): os.makedirs(p, exist_ok=True)

def load_universe(cfg: Config) -> List[str]:
    if cfg.static_symbols:
        syms = list(cfg.static_symbols)
    elif cfg.static_symbols_path and os.path.exists(cfg.static_symbols_path):
        with open(cfg.static_symbols_path) as f:
            syms = [ln.strip() for ln in f if ln.strip()]
    else:
        raise ValueError("Provide CFG.static_symbols or CFG.static_symbols_path.")

    cleaned = []
    for s in syms[:cfg.max_universe]:
        s = s.strip().upper()
        if s.startswith("^"):     # drop indices
            continue
        if s.endswith(".BO"):     # drop BSE for now
            continue
        if not s.endswith(".NS"):
            s = s + ".NS"
        cleaned.append(s)

    out, seen = [], set()
    for s in cleaned:
        if s not in seen:
            out.append(s); seen.add(s)
    if not out:
        raise ValueError("Universe empty after cleaning.")
    return out

def _normalize_df(df: pd.DataFrame, accept_missing_volume: bool) -> Optional[pd.DataFrame]:
    if df is None or df.empty:
        return None
    df = df.rename(columns=str.title)
    needed = ['Open','High','Low','Close']
    if not set(needed).issubset(df.columns):
        return None
    cols = needed + (['Volume'] if ('Volume' in df.columns and not df['Volume'].isna().all()) or not accept_missing_volume else [])
    if 'Volume' in cols and 'Volume' not in df.columns:
        cols = needed
    df = df[cols].dropna(subset=needed)
    df.index = pd.to_datetime(df.index).tz_localize(None)
    return df

def yf_download_symbol(sym: str, start: str, end: str, accept_missing_volume: bool, retries: int, base_sleep: float) -> Optional[pd.DataFrame]:
    for a in range(retries):
        try:
            df = yf.download(sym, start=start, end=end, interval='1d', auto_adjust=True, progress=False, threads=False)
            df = _normalize_df(df, accept_missing_volume)
            if df is not None and df.shape[0] >= CFG.min_rows_symbol:
                return df
        except Exception:
            pass
        time.sleep(base_sleep * (2 ** a) + random.uniform(0, 0.3))
    return None

def bulk_download_vbt(tickers: List[str], start: str, end: str, accept_missing_volume: bool, chunk: int) -> Dict[str, pd.DataFrame]:
    kept = {}
    for i in range(0, len(tickers), chunk):
        batch = tickers[i:i+chunk]
        try:
            data = vbt.YFData.download(
                batch, start=start, end=end, auto_adjust=True, interval='1d',
                missing_index='drop', missing_columns='drop'
            )
            for s in data.symbols:
                O = data.select_symbol(s).get('Open')
                H = data.select_symbol(s).get('High')
                L = data.select_symbol(s).get('Low')
                C = data.select_symbol(s).get('Close')
                V = data.select_symbol(s).get('Volume')
                if accept_missing_volume and (V is None or (hasattr(V,'isna') and V.isna().all())):
                    df = pd.DataFrame({'Open': O, 'High': H, 'Low': L, 'Close': C})
                else:
                    if V is None:
                        continue
                    df = pd.DataFrame({'Open': O, 'High': H, 'Low': L, 'Close': C, 'Volume': V})
                df = df.dropna(subset=['Open','High','Low','Close'])
                df.index = pd.to_datetime(df.index).tz_localize(None)
                if df.shape[0] >= CFG.min_rows_symbol:
                    kept[s] = df
        except Exception:
            continue
    return kept

def download_clean_ohlcv(cfg: Config, tickers: List[str]) -> Dict[str, pd.DataFrame]:
    kept: Dict[str, pd.DataFrame] = {}

    if cfg.bulk_first:
        log.info("Bulk downloading with vectorbt.YFData (drop mismatches) ...")
        kept = bulk_download_vbt(tickers, cfg.start_date, cfg.end_date, cfg.accept_missing_volume, cfg.bulk_chunk)

    missing = [s for s in tickers if s not in kept]
    if missing:
        log.info("Per-symbol fallback via yfinance for %d tickers ...", len(missing))
        for j, s in enumerate(missing, 1):
            df = yf_download_symbol(s, cfg.start_date, cfg.end_date, cfg.accept_missing_volume, cfg.yf_retries, cfg.yf_retry_base_sleep)
            if df is not None:
                kept[s] = df
            if j % 25 == 0:
                log.info("  processed %d/%d in fallback; kept=%d", j, len(missing), len(kept))

    dropped = [s for s in tickers if s not in kept]
    if dropped:
        log.warning("Dropped %d tickers (bad/short/missing): %s", len(dropped), ", ".join(dropped[:20]) + (" ..." if len(dropped)>20 else ""))

    return kept

def synth_gbm_df(days: int, seed: int, s0: float = 100.0, mu: float = 0.12, sigma: float = 0.25) -> pd.DataFrame:
    np.random.seed(seed)
    idx = pd.bdate_range(end=pd.Timestamp.today().normalize(), periods=days)
    dt = 1/252
    shocks = np.random.normal((mu - 0.5*sigma*sigma)*dt, sigma*np.sqrt(dt), size=len(idx))
    logp = np.log(s0) + np.cumsum(shocks)
    close = np.exp(logp)
    high = close * (1 + np.random.uniform(0.0, 0.01, size=len(idx)))
    low  = close * (1 - np.random.uniform(0.0, 0.01, size=len(idx)))
    open_ = close * (1 + np.random.uniform(-0.005, 0.005, size=len(idx)))
    vol  = np.random.randint(100000, 5000000, size=len(idx))
    df = pd.DataFrame({"Open": open_, "High": high, "Low": low, "Close": close, "Volume": vol}, index=idx)
    return df

def build_common_index_smart(data_map: Dict[str, pd.DataFrame], min_rows: int) -> Tuple[pd.DatetimeIndex, Dict[str, pd.DataFrame]]:
    lengths = sorted(((k, len(v)) for k, v in data_map.items()), key=lambda x: x[1], reverse=True)
    symbols = [k for k, _ in lengths]
    current = symbols.copy()

    def intersection_of(symbols_list):
        idxs = [data_map[s].index for s in symbols_list]
        if not idxs:
            return pd.DatetimeIndex([])
        inter = reduce(lambda a, b: a.intersection(b), idxs)
        inter = pd.DatetimeIndex(sorted(inter.unique()))
        return inter

    inter = intersection_of(current)
    while len(inter) < min_rows and len(current) > 1:
        shortest = min(current, key=lambda s: len(data_map[s]))
        current.remove(shortest)
        inter = intersection_of(current)

    kept_map = {s: data_map[s] for s in current}
    return inter, kept_map

def align_to_index(data_map: Dict[str, pd.DataFrame], idx: pd.DatetimeIndex) -> Dict[str, pd.DataFrame]:
    aligned = {}
    for s, df in data_map.items():
        d = df.reindex(idx).dropna()
        aligned[s] = d
    return aligned

# ========= Candle patterns =========
def _body(O, C): return (C - O).abs()
def _is_bull(O, C): return C > O
def _is_bear(O, C): return C < O

def patt_bull_engulfing(O,H,L,C):
    Op, Cp = O.shift(1), C.shift(1)
    return _is_bear(Op,Cp) & _is_bull(O,C) & (C >= Op) & (O <= Cp)

def patt_piercing(O,H,L,C):
    Op, Hp, Lp, Cp = O.shift(1), H.shift(1), L.shift(1), C.shift(1)
    mid_prev = (Op + Cp) / 2.0
    return _is_bear(Op,Cp) & _is_bull(O,C) & (O < Lp) & (C > mid_prev) & (C < Op)

def patt_morning_star(O,H,L,C):
    O1, C1 = O.shift(1), C.shift(1)
    O2, C2 = O.shift(2), C.shift(2)
    bear1 = _is_bear(O2,C2)
    small2 = (_body(O1,C1) <= (_body(O2,C2)*0.6))
    bull3 = _is_bull(O,C)
    mid1 = (O2 + C2) / 2.0
    retrace = C > mid1
    return bear1 & small2 & bull3 & retrace

def patt_harami_bull(O,H,L,C):
    O1, C1 = O.shift(1), C.shift(1)
    prev_bear = _is_bear(O1,C1)
    body_small = (_body(O,C) <= _body(O1,C1) * 0.75)
    inside = (np.maximum(O,C) <= O1) & (np.minimum(O,C) >= C1)
    return prev_bear & body_small & inside & (C >= O)

def patt_harami_cross_bull(O,H,L,C, doji_pct=0.1):
    O1, C1 = O.shift(1), C.shift(1)
    rng = (H - L).replace(0, np.nan)
    doji = (_body(O,C) <= (rng * doji_pct))
    prev_bear = _is_bear(O1,C1)
    inside = (np.maximum(O,C) <= O1) & (np.minimum(O,C) >= C1)
    return prev_bear & doji & inside

def patt_hammer(O,H,L,C, shadow_mult=2.0):
    body = _body(O,C)
    lower_shadow = (np.minimum(O,C) - L).abs()
    upper_shadow = (H - np.maximum(O,C)).abs()
    return (lower_shadow >= shadow_mult*body) & (upper_shadow <= body) & (C >= O)

def patt_inverted_hammer(O,H,L,C, shadow_mult=2.0):
    body = _body(O,C)
    upper_shadow = (H - np.maximum(O,C)).abs()
    lower_shadow = (np.minimum(O,C) - L).abs()
    return (upper_shadow >= shadow_mult*body) & (lower_shadow <= body) & (C >= O)

def bullish_pattern_mask(O: pd.DataFrame, H: pd.DataFrame, L: pd.DataFrame, C: pd.DataFrame, enabled: List[str]) -> pd.DataFrame:
    masks = []
    if "ENGULFING" in enabled:        masks.append(patt_bull_engulfing(O,H,L,C))
    if "PIERCING" in enabled:         masks.append(patt_piercing(O,H,L,C))
    if "MORNING_STAR" in enabled:     masks.append(patt_morning_star(O,H,L,C))
    if "HARAMI" in enabled:           masks.append(patt_harami_bull(O,H,L,C))
    if "HARAMI_CROSS" in enabled:     masks.append(patt_harami_cross_bull(O,H,L,C))
    if "HAMMER" in enabled:           masks.append(patt_hammer(O,H,L,C))
    if "INVERTED_HAMMER" in enabled:  masks.append(patt_inverted_hammer(O,H,L,C))
    if not masks:
        return pd.DataFrame(False, index=C.index, columns=C.columns)
    out = masks[0]
    for m in masks[1:]:
        out = out | m
    return out.fillna(False)

# ========= Indicators / util =========
def ema(df: pd.DataFrame, span: int) -> pd.DataFrame:
    return df.ewm(span=span, adjust=False, min_periods=span).mean()

def rsi(series: pd.DataFrame, length: int = 14) -> pd.DataFrame:
    delta = series.diff()
    gain = (delta.where(delta > 0, 0.0)).rolling(length).mean()
    loss = (-delta.where(delta < 0, 0.0)).rolling(length).mean()
    rs = gain / loss.replace(0.0, np.nan)
    out = 100 - (100 / (1 + rs))
    return out.fillna(50.0)

def volar_scores(close: pd.DataFrame, bench: pd.Series, dt: pd.Timestamp, lookback: int) -> pd.Series:
    s_ret = close.loc[:dt].pct_change().dropna().iloc[-lookback:]
    b_ret = bench.loc[:dt].pct_change().dropna().iloc[-lookback:]
    scores = {}
    for col in close.columns:
        if col not in s_ret:
            scores[col] = np.nan; continue
        s = s_ret[col]
        common = pd.concat([s, b_ret], axis=1, keys=["s","b"]).dropna()
        if common.shape[0] < max(20, int(0.4*lookback)):
            scores[col] = np.nan; continue
        excess = common["s"] - common["b"]
        vol = common["s"].std(ddof=0)
        scores[col] = 0.0 if vol <= 1e-8 else float((excess.mean() / vol) * math.sqrt(252.0))
    return pd.Series(scores)

def try_mvo_weights(returns_window: pd.DataFrame) -> pd.Series:
    try:
        from pypfopt import expected_returns, risk_models, EfficientFrontier
        if returns_window.empty:
            return pd.Series(dtype=float)
        prices_win = (1 + returns_window).cumprod()
        mu = expected_returns.mean_historical_return(prices_win, frequency=252)
        S  = risk_models.CovarianceShrinkage(prices_win).ledoit_wolf()
        ef = EfficientFrontier(mu, S)
        w  = ef.max_sharpe()
        w  = pd.Series(ef.clean_weights())
        w = w.clip(lower=0)
        return w / w.sum() if w.sum() > 0 else pd.Series(dtype=float)
    except Exception:
        return pd.Series(dtype=float)

# ========= Main =========
def main(cfg: Config):
    ensure_dir(cfg.out_dir)
    tickers = load_universe(cfg)
    log.info("Loaded %d tickers (after cleaning)", len(tickers))

    # 1) Try to download real data
    data_map = download_clean_ohlcv(cfg, tickers)

    # 2) If nothing came back and fallback allowed, synthesize GBM data
    if (not data_map) and cfg.fallback_to_synthetic:
        log.error("No valid symbols after cleaning. Using SYNTHETIC OHLC (GBM) so you can run end-to-end.")
        np.random.seed(cfg.synthetic_seed)
        synth_map = {}
        cols = tickers[:min(5, len(tickers))]  # keep it small
        for i, s in enumerate(cols):
            df = synth_gbm_df(cfg.synthetic_days, seed=cfg.synthetic_seed + i, s0=100+10*i)
            synth_map[s] = df
        data_map = synth_map

    if not data_map:
        raise RuntimeError("No data available (and synthetic disabled). Check network/Yahoo access.")

    # 3) Build intersection index (auto-prune worst offenders) to hit min_rows
    inter_idx, kept_map = build_common_index_smart(data_map, cfg.min_rows)
    if len(inter_idx) < cfg.min_rows:
        # As a last resort with real data, relax requirement:
        log.warning("Intersection < min_rows; relaxing to len(intersection)=%d", len(inter_idx))
        if len(inter_idx) == 0:
            # If synthetic, inter_idx will be fine; if still zero, bail
            raise RuntimeError("Common intersection is empty after pruning.")
    aligned = align_to_index(kept_map, inter_idx if len(inter_idx)>0 else list(kept_map.values())[0].index)

    # Build wide OHLC
    Open  = pd.concat({k: v['Open']  for k, v in aligned.items()}, axis=1)
    High  = pd.concat({k: v['High']  for k, v in aligned.items()}, axis=1)
    Low   = pd.concat({k: v['Low']   for k, v in aligned.items()}, axis=1)
    Close = pd.concat({k: v['Close'] for k, v in aligned.items()}, axis=1)

    cols = Close.columns
    Open, High, Low = (df.reindex(columns=cols).astype(float) for df in (Open, High, Low))
    Close = Close.astype(float)

    log.info("Final universe: %d symbols | Date range: %s -> %s | Points: %d",
             len(cols), Close.index[0].date(), Close.index[-1].date(), len(Close))

    # Benchmark for VOLAR
    bench_price = Close.pct_change().median(axis=1).add(1).cumprod()
    bench_price.index = Close.index

    # Indicators
    ema_fast = ema(Close, cfg.ema_fast)
    ema_slow = ema(Close, cfg.ema_slow)
    ema_htf  = ema(Close, cfg.ema_htf)
    rsi14    = rsi(Close, 14)

    # Candles
    bull_candle = bullish_pattern_mask(Open, High, Low, Close, list(cfg.enable_patterns))

    # Entries
    entries = bull_candle.copy()
    if cfg.use_rsi_confirm:
        entries &= (rsi14 > 50)
    if cfg.use_trend_fast_slow:
        entries &= (ema_fast > ema_slow)
    if cfg.use_htf_trend:
        entries &= (Close > ema_htf)

    # 52w filter
    if cfg.use_52w_filter:
        high_52w = Close.rolling(cfg.filter_52w_window, min_periods=50).max()
        entries &= (Close >= cfg.within_pct_of_52w_high * high_52w)

    # VOLAR top-K
    if cfg.use_volar_ranking and cfg.top_k_daily > 0:
        ranked = pd.DataFrame(False, index=entries.index, columns=entries.columns)
        idx_with = entries.index[entries.any(axis=1)]
        for dt in idx_with:
            elig = entries.loc[dt]
            if not elig.any(): continue
            scores = volar_scores(Close, bench_price, dt, cfg.volar_lookback)
            scores = scores.where(elig, np.nan).dropna()
            if scores.empty: continue
            topn = scores.sort_values(ascending=False).head(cfg.top_k_daily).index
            ranked.loc[dt, topn] = True
        entries = ranked

    # Exits
    exits = pd.DataFrame(False, index=entries.index, columns=entries.columns)
    if cfg.use_indicator_exit:
        exits = ema_fast < ema_slow

    # Weights
    weights = pd.DataFrame(0.0, index=entries.index, columns=entries.columns)
    ret = Close.pct_change()
    idx_with = entries.index[entries.any(axis=1)]
    for dt in idx_with:
        cols_today = entries.columns[entries.loc[dt]]
        if len(cols_today) == 0: continue
        if cfg.use_mvo_sizing and len(cols_today) >= 2:
            R = ret.loc[:dt, cols_today].dropna().iloc[-cfg.volar_lookback:]
            w = try_mvo_weights(R)
            if w.empty:
                w = pd.Series(1/len(cols_today), index=cols_today)
            else:
                w = w.reindex(cols_today).fillna(0.0)
                w = w / w.sum() if w.sum() > 0 else pd.Series(1/len(cols_today), index=cols_today)
        else:
            w = pd.Series(1/len(cols_today), index=cols_today)
        weights.loc[dt, cols_today] = w * cfg.deploy_cash_frac

    log.info("Building vectorbt portfolio...")
    pf = vbt.Portfolio.from_signals(
        Close,
        entries=entries,
        exits=exits,
        sl_stop=cfg.stop_loss_pct if cfg.stop_loss_pct > 0 else None,
        tp_stop=cfg.target_pct    if cfg.target_pct    > 0 else None,
        fees=cfg.fees,
        slippage=cfg.slippage,
        init_cash=1_000_000.0,
        cash_sharing=True,
        weights=weights,
        freq='1D'
    )

    stats = pf.stats()
    print("\n=== Portfolio Stats ===")
    print(stats)

    stamp = pd.Timestamp.now(tz="Asia/Kolkata").strftime("%Y%m%d_%H%M%S")
    if cfg.write_csv:
        ensure_dir(cfg.out_dir)
        entries.astype(int).to_csv(os.path.join(cfg.out_dir, f"entries_{stamp}.csv"))
        exits.astype(int).to_csv(os.path.join(cfg.out_dir, f"exits_{stamp}.csv"))
        weights.to_csv(os.path.join(cfg.out_dir, f"weights_{stamp}.csv"))
        pf.value().to_csv(os.path.join(cfg.out_dir, f"equity_{stamp}.csv"), header=["equity"])
        with open(os.path.join(cfg.out_dir, f"stats_{stamp}.json"), "w") as f:
            json.dump(stats.to_dict(), f, indent=2)
        log.info("Files written to %s", cfg.out_dir)

    if cfg.plot:
        try:
            pf.value().vbt.plot(title="Equity Curve (VectorBT Candle + RSI)").show()
        except Exception:
            pass

if __name__ == "__main__":
    main(CFG)


2025-11-02 19:01:04 | INFO | Loaded 5 tickers (after cleaning)
2025-11-02 19:01:04 | INFO | Bulk downloading with vectorbt.YFData (drop mismatches) ...
2025-11-02 19:01:06 | INFO | Per-symbol fallback via yfinance for 5 tickers ...
2025-11-02 19:01:36 | ERROR | No valid symbols after cleaning. Using SYNTHETIC OHLC (GBM) so you can run end-to-end.
2025-11-02 19:01:36 | INFO | Final universe: 5 symbols | Date range: 2023-07-17 -> 2025-10-31 | Points: 600
2025-11-02 19:01:37 | INFO | Building vectorbt portfolio...


TypeError: Portfolio.__init__() got an unexpected keyword argument 'weights'