# In [5]:  (Jupyter cell marker from the notebook this script was exported from)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Supertrend + Pivot (R1/S1) Intraday Backtest — NIFTY50 (5m)
-----------------------------------------------------------
Entries
  Long : Close crosses ABOVE R1 and Supertrend is GREEN
  Short: Close crosses BELOW S1 and Supertrend is RED

Exits
  Supertrend stop/flip, optional SL/TP, EOD square-off

Defaults
  • Date range: last 60 days up to today (IST), 5-minute bars
  • Capital: ₹1,00,000 with 5× intraday margin (deployable ₹5,00,000)
  • Max concurrent trades: 5
  • Long/Short toggles, SL/TP toggles
  • Next-bar-open execution (realistic)
  • Groww intraday cost model
  • Trade log: outputs/trades.csv
"""

from __future__ import annotations
import math, os, sys, warnings, logging, datetime
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
try:
    import yfinance as yf
    import pytz
except Exception:
    print("Please install dependencies: pip install pandas numpy yfinance pytz")
    sys.exit(1)

# =========================
# LOGGING
# =========================
# Root logging configuration. Per the logging docs, basicConfig does nothing
# if the root logger already has handlers (as can happen in notebooks).
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Module-wide logger used by the backtest and reporting code below.
log = logging.getLogger("supertrend_pivot_backtest")

# =========================
# CONFIG
# =========================
def default_dates_last_60d_ist() -> tuple[str, str]:
    """Return ``(start, end)`` ISO dates for the default backtest window.

    ``start`` is 60 days before today's date in India (IST). ``end`` is
    tomorrow, because Yahoo Finance treats ``end`` as exclusive, so today's
    bars are still included.
    """
    # IST is a fixed UTC+05:30 offset (India observes no daylight saving),
    # so this gives the same calendar date as the "Asia/Kolkata" zone.
    ist = datetime.timezone(datetime.timedelta(hours=5, minutes=30), "IST")
    today = datetime.datetime.now(ist).date()
    one_day = datetime.timedelta(days=1)
    return (today - 60 * one_day).isoformat(), (today + one_day).isoformat()

@dataclass
class Config:
    """Backtest settings; the module uses a single instance, ``CFG``.

    ``start_date`` / ``end_date`` are ISO ``YYYY-MM-DD`` strings. When left as
    ``None`` they are filled in at construction time with the default
    60-day IST window from ``default_dates_last_60d_ist()``. ``end_date`` is
    exclusive, because it is passed straight to Yahoo Finance as ``end``.
    Both are real dataclass fields, so they can be overridden with e.g.
    ``Config(start_date="2024-01-01", end_date="2024-02-01")``.
    """
    # Bars
    interval: str = "5m"
    tz: str = "Asia/Kolkata"

    # Universe
    use_nifty50_universe: bool = True
    tickers: Optional[List[str]] = None
    append_suffix_ns: bool = True

    # Strategy toggles
    enable_longs: bool  = True
    enable_shorts: bool = True
    use_next_bar_open: bool = True
    use_eod_squareoff: bool = True
    eod_squareoff_time: str = "15:25"  # IST

    # Supertrend
    st_atr_period: int = 7
    st_multiplier: float = 3.0

    # Entry logic
    require_close_cross: bool = True  # True=cross; False=touch or cross

    # Risk & money management
    account_rupees: float = 100_000.0
    intraday_margin: float = 5.0
    max_concurrent_trades: int = 5
    capital_fraction_per_trade: Optional[float] = None  # None = equal split

    # Optional SL/TP
    enable_stop_loss: bool = True
    stop_loss_pct: float = 0.01
    enable_take_profit: bool = False
    take_profit_pct: float = 0.02

    # Costs & outputs
    enable_costs: bool = True
    out_dir: str = "outputs"
    trades_csv: str = "trades.csv"

    # Dates. These are declared last so positional construction of the other
    # fields keeps its original order. None -> default window (see below).
    start_date: Optional[str] = None
    end_date: Optional[str] = None

    def __post_init__(self) -> None:
        # Previously these were plain class attributes (not fields), so they
        # could not be passed to the constructor. Resolve the defaults here.
        if self.start_date is None or self.end_date is None:
            default_start, default_end = default_dates_last_60d_ist()
            if self.start_date is None:
                self.start_date = default_start
            if self.end_date is None:
                self.end_date = default_end


CFG = Config()

# =========================
# Groww intraday cost model
# =========================
def groww_intraday_charges(buy_turnover: float, sell_turnover: float) -> Dict[str, float]:
    """Estimate total Groww intraday equity charges for one round trip.

    Args:
        buy_turnover: Rupee value of the buy leg (price x quantity).
        sell_turnover: Rupee value of the sell leg (price x quantity).

    Returns:
        ``{"total_charges": float}`` covering, per leg: brokerage, exchange
        transaction charge, SEBI fee, IPFT and GST on those four. It also
        includes STT on the sell leg and stamp duty on the buy leg.
    """
    # Rates as fractions of turnover, kept exactly as in this model.
    # NOTE(review): verify against the current Groww/NSE fee schedule.
    exch_rate, sebi_rate, ipft_rate, gst_rate = 0.0000297, 0.000001, 0.000001, 0.18

    def leg(turnover: float) -> Tuple[float, float, float, float, float]:
        # Brokerage is 0.1% of turnover capped at Rs 20, with a Rs 5 floor.
        bro = max(5.0, min(20.0, 0.001 * turnover))
        exch = exch_rate * turnover
        sebi = sebi_rate * turnover
        ipft = ipft_rate * turnover
        gst = gst_rate * (bro + exch + sebi + ipft)
        return bro, exch, sebi, ipft, gst

    buy_leg, sell_leg = leg(buy_turnover), leg(sell_turnover)
    # Interleave buy/sell components so floating-point summation order is
    # bro_b, bro_s, exch_b, exch_s, ... gst_b, gst_s.
    total = 0.0
    for buy_part, sell_part in zip(buy_leg, sell_leg):
        total = total + buy_part + sell_part
    total += 0.00025 * sell_turnover  # STT: sell leg only
    total += 0.00003 * buy_turnover   # stamp duty: buy leg only
    return {"total_charges": total}

# =========================
# Utilities
# =========================
def ensure_dir(p: str):
    """Create directory *p*, including missing parents; existing dirs are left alone."""
    os.makedirs(name=p, exist_ok=True)

def make_nifty50() -> List[str]:
    """Return the hard-coded NIFTY 50 symbol list.

    A ``.NS`` suffix is appended to every symbol when
    ``CFG.append_suffix_ns`` is true.

    NOTE(review): this is a static snapshot. Index constituents change over
    time, so check it against the current NSE list before relying on it.
    """
    symbols = """
        ADANIENT ADANIPORTS APOLLOHOSP ASIANPAINT AXISBANK BAJAJ-AUTO BAJFINANCE
        BAJAJFINSV BHARTIARTL BPCL BRITANNIA CIPLA COALINDIA DIVISLAB DRREDDY
        EICHERMOT GRASIM HCLTECH HDFCBANK HDFCLIFE HEROMOTOCO HINDALCO HINDUNILVR
        ICICIBANK INDUSINDBK INFY ITC JSWSTEEL KOTAKBANK LT M&M MARUTI
        NESTLEIND NTPC ONGC POWERGRID RELIANCE SBILIFE SBIN SUNPHARMA TATACONSUM
        TATAMOTORS TATASTEEL TCS TECHM TITAN ULTRACEMCO UPL WIPRO SHRIRAMFIN
    """.split()
    if not CFG.append_suffix_ns:
        return symbols
    return [f"{sym}.NS" for sym in symbols]

def as_ist(ts: pd.Timestamp) -> pd.Timestamp:
    """Express *ts* in ``CFG.tz``; a naive timestamp is taken to be UTC first."""
    if not ts.tzinfo:
        ts = ts.tz_localize("UTC")
    return ts.tz_convert(CFG.tz)

# =========================
# Indicators
# =========================
def atr(df: pd.DataFrame, period: int) -> pd.Series:
    """Average True Range, computed as a simple rolling mean of the true range.

    True range = max(High-Low, |High-prevClose|, |Low-prevClose|). On the
    first bar there is no previous close, and ``DataFrame.max`` skips NaN, so
    it falls back to High-Low. The first ``period-1`` values are NaN.
    """
    h, l, c = df["High"], df["Low"], df["Close"]
    prev_c = c.shift(1)
    tr = pd.concat([(h - l), (h - prev_c).abs(), (l - prev_c).abs()], axis=1).max(axis=1)
    return tr.rolling(period).mean()


def supertrend(df: pd.DataFrame, period: int, mult: float) -> pd.DataFrame:
    """Classic Supertrend line and direction.

    Definitions:
      * hl2 = (High + Low) / 2
      * basic_upper = hl2 + mult*ATR, basic_lower = hl2 - mult*ATR
      * The final upper band may only move down, and the final lower band
        only up, while the previous close stays on the same side of it.
        Otherwise the band resets to its basic value.
      * Direction -1 (red/down) shows the final upper band. It stays -1
        while Close <= that band and flips to +1 (green/up) otherwise.
        Direction +1 shows the final lower band. It stays +1 while
        Close >= that band and flips to -1 otherwise.

    The previous implementation kept a single ratcheting line that was never
    reset to the opposite band on a flip. After the first flip it therefore
    degenerated into a near-static level.

    Returns:
        DataFrame indexed like *df* with columns ``st`` (the line) and
        ``st_dir`` (+1/-1 as float). Both are NaN while ATR is warming up
        (the first ``period-1`` bars) or where bands/close are undefined.
        The state is re-seeded as red on the next valid bar.
    """
    hl2 = (df["High"] + df["Low"]) / 2.0
    band = mult * atr(df, period)
    basic_up = (hl2 + band).to_numpy(dtype=float)
    basic_dn = (hl2 - band).to_numpy(dtype=float)
    close = df["Close"].to_numpy(dtype=float)

    n = len(df)
    st = np.full(n, np.nan)
    direction = np.full(n, np.nan)
    fin_up = fin_dn = np.nan
    prev_dir = 0  # 0 = no valid previous bar -> (re)seed on the next valid one

    for i in range(n):
        up_i, dn_i, c_i = basic_up[i], basic_dn[i], close[i]
        if np.isnan(up_i) or np.isnan(dn_i) or np.isnan(c_i):
            prev_dir = 0
            continue
        if prev_dir == 0:
            # Seed: start red (matching the original), unless price is
            # already above the upper band.
            fin_up, fin_dn = up_i, dn_i
            cur = -1 if c_i <= fin_up else 1
        else:
            prev_close = close[i - 1]  # valid: bar i-1 was not skipped
            if up_i < fin_up or prev_close > fin_up:
                fin_up = up_i
            if dn_i > fin_dn or prev_close < fin_dn:
                fin_dn = dn_i
            if prev_dir < 0:
                cur = -1 if c_i <= fin_up else 1
            else:
                cur = 1 if c_i >= fin_dn else -1
        direction[i] = cur
        st[i] = fin_dn if cur > 0 else fin_up
        prev_dir = cur

    return pd.DataFrame({"st": st, "st_dir": direction}, index=df.index)

# =========================
# Pivots (from previous day's daily bar) — SAFE
# =========================
def daily_pivots_from_prev_day(daily: pd.DataFrame) -> pd.DataFrame:
    """
    Compute previous-day Classic pivots (R1/S1) robustly.

    For each daily row, using the *previous* usable row's High/Low/Close:
        PP = (H + L + C) / 3,   R1 = 2*PP - L,   S1 = 2*PP - H
    so each row carries the levels to trade on that day. The first row is NaN.

    Columns may be flat or a MultiIndex such as yfinance's (field, ticker)
    layout. The level containing "High"/"Low"/"Close" is used, and the first
    column per field is kept if several tickers are present (previously such
    input silently produced an empty result). Rows with missing H/L/C are
    dropped before shifting, so one NaN day does not blank the next day.

    Returns an empty DataFrame with columns ["R1", "S1"] if the input is
    None/empty, the required columns are missing, or fewer than 2 usable
    rows remain.
    """
    empty = pd.DataFrame(columns=["R1", "S1"])
    if daily is None:
        return empty
    # Coerce to DataFrame if Series (a single daily row)
    if isinstance(daily, pd.Series):
        daily = daily.to_frame().T
    if daily.empty:
        return empty

    required = ["High", "Low", "Close"]
    cols = daily.columns
    if isinstance(cols, pd.MultiIndex):
        for lvl in range(cols.nlevels):
            values = cols.get_level_values(lvl)
            if set(required).issubset(values):
                daily = daily.copy()
                daily.columns = values
                break
        else:
            return empty
        daily = daily.loc[:, ~daily.columns.duplicated()]
    if not set(required).issubset(daily.columns):
        return empty

    hlc = daily[required].dropna()
    # Need at least 2 rows to reference a "previous day"
    if len(hlc.index) < 2:
        return empty

    prev = hlc.shift(1)  # prior day values aligned to the current day's index
    pp = (prev["High"] + prev["Low"] + prev["Close"]) / 3.0
    return pd.DataFrame(
        {"R1": 2 * pp - prev["Low"], "S1": 2 * pp - prev["High"]},
        index=hlc.index,
    )

def attach_pivots(mins: pd.DataFrame, pivots: pd.DataFrame,
                  tz: Optional[str] = None) -> pd.DataFrame:
    """
    Broadcast prior-day R1/S1 to each intraday row of the current day.

    Intraday bars and pivot rows are matched by *calendar date* in ``tz``
    (default ``CFG.tz``). The previous code looked pivots up by the exact
    IST-midnight timestamp. Daily bars localized as UTC midnight never equal
    IST midnight (18:30 UTC the evening before), so every level came out NaN.
    A naive intraday index is treated as UTC. A naive pivot index is treated
    as plain calendar dates.

    Args:
        mins: Intraday bars with a DatetimeIndex.
        pivots: Output of ``daily_pivots_from_prev_day`` (columns R1, S1).
        tz: Timezone whose calendar day defines "the trading day".

    Returns:
        ``mins`` with R1/S1 columns appended. Bars on days without a pivot row
        (or when ``pivots`` is empty) get NaN and are dropped later.
    """
    if mins.empty:
        return mins
    tz = CFG.tz if tz is None else tz

    idx = mins.index
    local = (idx.tz_localize("UTC") if idx.tz is None else idx).tz_convert(tz)
    bar_dates = local.date

    levels = pd.DataFrame(np.nan, index=mins.index, columns=["R1", "S1"])
    if not pivots.empty and isinstance(pivots.index, pd.DatetimeIndex):
        pidx = pivots.index
        piv_dates = (pidx.tz_convert(tz) if pidx.tz is not None else pidx).date
        by_date = pd.DataFrame(
            {"R1": pivots["R1"].to_numpy(), "S1": pivots["S1"].to_numpy()},
            index=piv_dates,
        )
        # reindex requires unique labels; keep the latest row per date.
        by_date = by_date[~by_date.index.duplicated(keep="last")]
        matched = by_date.reindex(bar_dates)
        levels = pd.DataFrame(matched.to_numpy(dtype=float),
                              index=mins.index, columns=["R1", "S1"])
    return pd.concat([mins, levels], axis=1)

# =========================
# Download data (timezone-safe)
# =========================
def _tz_to_utc(index: pd.DatetimeIndex) -> pd.DatetimeIndex:
    if getattr(index, "tz", None) is None:
        return index.tz_localize("UTC")
    return index.tz_convert("UTC")

def fetch_minute_and_daily(sym: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Download intraday bars and daily bars (for pivots) for one symbol.

    Intraday bars cover ``CFG.start_date``..``CFG.end_date`` at
    ``CFG.interval``. Daily bars start 5 days earlier, so the first intraday
    day has a previous session for its pivots. Both indexes are returned in
    UTC. Columns are flattened to plain field names ("Open", "High", ...).
    If the intraday download is empty, the same empty frame is returned twice.
    """
    def _flatten(frame: pd.DataFrame) -> pd.DataFrame:
        # Recent yfinance versions return (field, ticker) MultiIndex columns
        # even for a single symbol. Without flattening, the downstream
        # frame["High"] lookups return DataFrames and the pivot column check
        # fails. Keep the level that holds the price fields.
        cols = frame.columns
        if not isinstance(cols, pd.MultiIndex):
            return frame
        for lvl in range(cols.nlevels):
            values = cols.get_level_values(lvl)
            if "Close" in values:
                frame = frame.copy()
                frame.columns = values
                return frame.loc[:, ~frame.columns.duplicated()]
        return frame

    df = yf.download(
        sym,
        start=CFG.start_date,
        end=CFG.end_date,
        interval=CFG.interval,
        auto_adjust=False,
        progress=False,
    )
    if df.empty:
        return df, df
    df = _flatten(df)
    df.index = _tz_to_utc(df.index)

    daily = yf.download(
        sym,
        start=pd.to_datetime(CFG.start_date) - pd.Timedelta(days=5),
        end=CFG.end_date,
        interval="1d",
        auto_adjust=False,
        progress=False,
    )
    if not daily.empty:
        daily = _flatten(daily)
        daily.index = _tz_to_utc(daily.index)

    return df, daily

# =========================
# Core backtest
# =========================
class Position:
    """An open position: symbol, side ("long"/"short"), quantity, entry time
    and price, plus the R1/S1 levels recorded at entry."""

    def __init__(self, sym, side, qty, t_in, p_in, r1, s1):
        self.sym = sym
        self.side = side
        self.qty = qty
        self.t_in = t_in
        self.p_in = p_in
        self.r1 = r1
        self.s1 = s1

    def sign(self):
        """+1 for a long position, -1 for anything else (i.e. short)."""
        if self.side == "long":
            return 1
        return -1

def capital_per_trade() -> float:
    """Rupees allotted to each new position.

    Deployable capital is ``account_rupees * intraday_margin``. It is split
    equally across ``max_concurrent_trades`` when
    ``capital_fraction_per_trade`` is None; otherwise that fraction of it is used.
    """
    deployable = CFG.account_rupees * CFG.intraday_margin
    fraction = CFG.capital_fraction_per_trade
    if fraction is None:
        return deployable / CFG.max_concurrent_trades
    return deployable * fraction

def _prepare_symbol(sym: str) -> Optional[pd.DataFrame]:
    """Download and enrich one symbol. Return None (after logging why) if unusable."""
    m, d = fetch_minute_and_daily(sym)
    if m.empty or d.empty:
        log.warning(f"{sym}: empty minute/daily; skipping.")
        return None
    piv = daily_pivots_from_prev_day(d)
    if piv.empty:
        log.warning(f"{sym}: not enough daily history for pivots; skipping.")
        return None
    ohlc = ["Open", "High", "Low", "Close"]
    if isinstance(m.columns, pd.MultiIndex) or not set(ohlc).issubset(m.columns):
        log.warning(f"{sym}: unexpected intraday columns {list(m.columns)}; skipping.")
        return None
    # Unique, time-ordered, complete bars keep the indicator loops and the
    # per-timestamp dict lookups in run_backtest well-defined.
    m = m[~m.index.duplicated(keep="last")].sort_index().dropna(subset=ohlc)
    if m.empty:
        log.warning(f"{sym}: no complete intraday bars; skipping.")
        return None
    m = attach_pivots(m, piv)
    st = supertrend(m, CFG.st_atr_period, CFG.st_multiplier)
    df = pd.concat([m, st], axis=1)
    # Need valid levels and supertrend
    df = df.dropna(subset=["R1", "S1", "st", "st_dir"])
    if df.empty:
        log.warning(f"{sym}: indicators/levels NA; skipping.")
        return None
    return _add_execution_columns(df)


def _add_execution_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Add prev_close, same-day next-bar fill price/time and a last-bar-of-day flag."""
    out = df.copy()
    day = pd.Series(out.index.tz_convert(CFG.tz).date, index=out.index)
    next_same_day = day.shift(-1).eq(day)  # False on each day's last bar
    out["prev_close"] = out["Close"].shift(1)
    out["next_open"] = out["Open"].shift(-1).where(next_same_day)
    out["next_time"] = pd.Series(out.index, index=out.index).shift(-1).where(next_same_day)
    out["last_bar_of_day"] = ~next_same_day
    return out


def _fill(bar: dict, ts: pd.Timestamp) -> Tuple[Optional[float], pd.Timestamp]:
    """Execution (price, time) for a decision taken at the close of *bar*.

    With ``CFG.use_next_bar_open`` the order fills at the open of the next
    bar of the same trading day. ``(None, ts)`` means there is no such bar.
    Otherwise it fills at this bar's close.
    """
    if not CFG.use_next_bar_open:
        return float(bar["Close"]), ts
    nxt = bar["next_open"]
    if nxt is None or pd.isna(nxt):
        return None, ts
    return float(nxt), bar["next_time"]


def _exit_reason(pos: "Position", bar: dict, eod_due: bool) -> Optional[str]:
    """Why *pos* should be closed on *bar* (None = keep). Later checks take priority."""
    c, st, st_dir = bar["Close"], bar["st"], int(bar["st_dir"])
    is_long = pos.side == "long"
    reason = None
    if (c < st or st_dir < 0) if is_long else (c > st or st_dir > 0):
        reason = "supertrend_exit"
    if CFG.enable_stop_loss:
        sl = pos.p_in * (1 - CFG.stop_loss_pct * pos.sign())
        if (c <= sl) if is_long else (c >= sl):
            reason = "stop_loss"
    if CFG.enable_take_profit:
        tp = pos.p_in * (1 + CFG.take_profit_pct * pos.sign())
        if (c >= tp) if is_long else (c <= tp):
            reason = "take_profit"
    if eod_due:
        reason = "eod_squareoff"
    return reason


def _close_trade(sym: str, pos: "Position", px: float, reason: str,
                 t_exit: pd.Timestamp) -> dict:
    """Compute gross/net P&L and charges for closing *pos* at *px*, log it, return the record."""
    q = pos.qty
    if pos.side == "long":
        buy_turn, sell_turn = pos.p_in * q, px * q
        gross = (px - pos.p_in) * q
    else:
        buy_turn, sell_turn = px * q, pos.p_in * q
        gross = (pos.p_in - px) * q
    cost = groww_intraday_charges(buy_turn, sell_turn)["total_charges"] if CFG.enable_costs else 0.0
    net = gross - cost
    log.info(f"CLOSE {sym} {pos.side} x{q} @ {px:.2f} | {reason} | Net ₹{net:.2f}")
    return dict(
        symbol=sym, side=pos.side, qty=q,
        entry_time=as_ist(pos.t_in), exit_time=as_ist(t_exit),
        entry_price=round(pos.p_in, 2), exit_price=round(px, 2),
        gross_pnl=round(gross, 2), charges=round(cost, 2),
        net_pnl=round(net, 2), exit_reason=reason,
    )


def run_backtest() -> pd.DataFrame:
    """Run the Supertrend + R1/S1 pivot backtest over the configured universe.

    Execution model:
      * Signals are evaluated on each bar's close.
      * With ``CFG.use_next_bar_open``, orders fill at the open of the *next*
        bar of the same IST trading day. The old code filled at the signal
        bar's own open, which precedes the close it read (look-ahead).
        Without the flag, orders fill at the signal bar's close. An exit
        with no later bar that day fills at the close.
      * With ``CFG.use_eod_squareoff``, no entries are taken at/after the
        cut-off time or on a symbol's last bar of the day. Open positions are
        squared off at the cut-off or on that last bar, so nothing is carried
        overnight even if the cut-off bar is missing.
      * Entries are ranked by distance beyond R1/S1 and capped by
        ``CFG.max_concurrent_trades``.

    Returns:
        One row per closed trade (empty DataFrame if none / no usable data).
        Positions still open when data runs out are logged, not recorded.

    Raises:
        ValueError: if the ticker universe is empty.
    """
    universe = make_nifty50() if CFG.use_nifty50_universe else (CFG.tickers or [])
    if not universe:
        raise ValueError("No tickers found.")
    ensure_dir(CFG.out_dir)

    cap_each = capital_per_trade()
    log.info(f"Capital per trade (with {CFG.intraday_margin}x): ₹{cap_each:,.0f}")

    data: Dict[str, pd.DataFrame] = {}
    for s in universe:
        df = _prepare_symbol(s)
        if df is not None:
            data[s] = df
    if not data:
        log.error("No usable symbols after preparation. Adjust window/toggles.")
        return pd.DataFrame()

    ts_union = None
    for df in data.values():
        ts_union = df.index if ts_union is None else ts_union.union(df.index)
    ts_union = ts_union.sort_values()

    # Per-symbol {timestamp: row-dict} for O(1) bar lookups in the main loop.
    bars = {s: df.to_dict("index") for s, df in data.items()}
    cutoff = tuple(map(int, CFG.eod_squareoff_time.split(":")))

    trades: List[dict] = []
    openpos: Dict[str, Position] = {}
    for ts in ts_union:
        ts_local = ts.tz_convert(CFG.tz)
        past_cutoff = CFG.use_eod_squareoff and (ts_local.hour, ts_local.minute) >= cutoff

        # --------- Exits ----------
        for s in list(openpos):
            bar = bars[s].get(ts)
            if bar is None:
                continue
            eod_due = past_cutoff or (CFG.use_eod_squareoff and bool(bar["last_bar_of_day"]))
            reason = _exit_reason(openpos[s], bar, eod_due)
            if reason is None:
                continue
            px, t_exit = _fill(bar, ts)
            if px is None:  # no later bar today: exit at this bar's close
                px, t_exit = float(bar["Close"]), ts
            trades.append(_close_trade(s, openpos.pop(s), px, reason, t_exit))

        # --------- Entries ----------
        capacity = CFG.max_concurrent_trades - len(openpos)
        if capacity <= 0 or past_cutoff:
            continue

        candidates: List[Tuple[str, str, float, float, pd.Timestamp]] = []
        for s, sym_bars in bars.items():
            if s in openpos:
                continue
            bar = sym_bars.get(ts)
            if bar is None:
                continue
            if CFG.use_eod_squareoff and bool(bar["last_bar_of_day"]):
                continue  # would have to be held overnight
            px, t_fill = _fill(bar, ts)
            if px is None:
                continue
            c, prev = bar["Close"], bar["prev_close"]
            r1, s1, st_dir = bar["R1"], bar["S1"], int(bar["st_dir"])

            if CFG.enable_longs and st_dir > 0:
                cond = (prev <= r1 and c > r1) if CFG.require_close_cross else (c >= r1)
                if cond:
                    candidates.append((s, "long", float(c - r1), px, t_fill))

            if CFG.enable_shorts and st_dir < 0:
                cond = (prev >= s1 and c < s1) if CFG.require_close_cross else (c <= s1)
                if cond:
                    candidates.append((s, "short", float(s1 - c), px, t_fill))

        # Rank by distance from level (stronger break first)
        candidates.sort(key=lambda x: x[2], reverse=True)
        for s, side, _, px, t_fill in candidates[:capacity]:
            qty = int(cap_each // px)
            if qty <= 0:
                continue
            bar = bars[s][ts]
            openpos[s] = Position(s, side, qty, t_fill, px, bar["R1"], bar["S1"])
            log.info(f"OPEN  {s} {side} x{qty} @ {px:.2f}")

    if openpos:
        log.warning(f"{len(openpos)} position(s) still open at end of data; not recorded as trades.")
    return pd.DataFrame(trades)

# =========================
# Main
# =========================
# Column order of the trade log. It is also used to write a header-only CSV
# when the backtest produces no trades.
EXPECTED_COLS = [
    "symbol", "side", "qty", "entry_time", "exit_time",
    "entry_price", "exit_price", "gross_pnl", "charges", "net_pnl", "exit_reason",
]


def main():
    """Run the backtest, write the trade log CSV and log summary statistics."""
    ensure_dir(CFG.out_dir)
    trades = run_backtest()
    out = os.path.join(CFG.out_dir, CFG.trades_csv)

    if not trades.empty:
        if "entry_time" in trades.columns:
            trades = trades.sort_values("entry_time")
        trades.to_csv(out, index=False)
        net = trades["net_pnl"]
        total = net.sum()
        winrate = (net > 0).mean() * 100
        log.info(f"Saved -> {out}")
        log.info(f"Trades={len(trades)} | Winrate={winrate:.1f}% | Net P&L=₹{total:,.2f}")
        return

    pd.DataFrame(columns=EXPECTED_COLS).to_csv(out, index=False)
    log.warning("No trades generated. Try: disable 'require_close_cross', enable both long & short, or widen the window.")


if __name__ == "__main__":
    main()


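# --- Captured notebook output from an earlier run of this cell ---
# (Its log text "Capital/trade = ..." differs from the "Capital per trade
#  (with ...)" message run_backtest emits now, so it predates the code above.)
# 2025-11-06 00:44:42 | INFO | Capital/trade = ₹100,000
#
# ValueError: If using all scalar values, you must pass an index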