In [1]:
# backtest_groww_fees_grid_threaded.py
"""
Threaded grid-search backtester with:
 - warm cache (parquet preferred, CSV fallback)
 - robust column normalization / duplicate handling
 - indicators: EMA, Supertrend (from custom_indicators) + optional RSI/MACD/ADX/SMA/Bollinger
 - Groww-style fees + GST + STT + DP + slippage
 - isolated capital_per_trade allocation (₹50,000 default)
 - threaded combos (no spawn pickling issues)
"""

import os
import time
import pathlib
import itertools
import traceback
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd
import numpy as np
import yfinance as yf

# Import your custom indicators module: must define calculate_ema(df, length) and calculate_supertrend(df, length, multiplier)
from custom_indicators import calculate_ema, calculate_supertrend

# -------------------- CONFIG --------------------
# Universe of symbols to backtest (Yahoo Finance tickers; ".NS" suffix as used by the download helper).
TICKER_LIST = ['SBICARD.NS', 'BDL.NS', 'INDHOTEL.NS', 'BSE.NS', 'NYKAA.NS', 'BAJFINANCE.NS', 'PAYTM.NS', 'SOLARINDS.NS', 'CHOLAFIN.NS', 'UNITDSPR.NS', 'DIVISLAB.NS', 'MUTHOOTFIN.NS', 'BHARTIARTL.NS', 'ICICIBANK.NS', 'MAZDOCK.NS', 'SHREECEM.NS', 'DIXON.NS', 'PERSISTENT.NS', 'SRF.NS', 'TVSMOTOR.NS', 'SBILIFE.NS', 'MAXHEALTH.NS', 'MFSL.NS', 'COFORGE.NS', 'HDFCLIFE.NS', 'INDIGO.NS', 'KOTAKBANK.NS', 'HDFCBANK.NS', 'BEL.NS', 'BAJAJFINSV.NS']


# Download window. Only consulted when a ticker is missing from the on-disk cache;
# the cache files are keyed by ticker alone, not by date range.
START_DATE = "2015-01-01"
END_DATE = "2025-01-01"

# base indicators (Supertrend length/multiplier and the two EMA spans used for the base buy signal)
ST_LENGTH = 10
ST_MULTIPLIER = 3.0
EMA_FAST_LENGTH = 9
EMA_SLOW_LENGTH = 15

# trade behavior defaults (stops are percentages; holding limit is counted in bars of the loaded data)
DEFAULT_TRAILING_STOP_PCT = 10.0
DEFAULT_HARD_STOP_PCT = 5.0
DEFAULT_MAX_HOLD_DAYS = 10

# grid results / output
GRID_RESULTS_CSV = "grid_search_results_with_fees.csv"  # summary table, rewritten after every finished combo
SAVE_DETAILED_TRADES = True  # also write one per-combo trades CSV
DETAILED_DIR = "grid_search_runs"  # directory for the per-combo trades CSVs

# threading / parallelization
NUM_WORKERS = min(8, (os.cpu_count() or 1))  # size of the outer (per-combo) thread pool
PARALLELIZE = 'combos'  # 'combos' or 'tickers' ('tickers' adds an inner per-ticker thread pool inside each combo)

# warm cache before threaded run? (recommended True)
WARM_CACHE = True

# isolated capital per trade (INR)
CAPITAL_PER_TRADE = 50000.0

# -------------------- FEES & SLIPPAGE (Groww-style) --------------------
# All *_PCT constants below are expressed in percent; the fee model divides them by 100 before use.
SLIPPAGE_PCT = 0.05  # percent per side (adverse slippage)
# Groww brokerage: "₹20 OR 0.1% per executed order — whichever is lower, minimum ₹5"
BROKERAGE_CAP_RUPEES = 20.0
BROKERAGE_PCT_CAP = 0.1  # percent
BROKERAGE_MIN_RUPEES = 5.0

# regulatory/exchange charges (approx values used on Groww page)
EXCHANGE_TXN_PCT = 0.00297  # percent per side
SEBI_TURNOVER_PCT = 0.0001  # percent per side
IPFT_PCT = 0.0001  # percent per side

# STT for delivery sell
# NOTE(review): 0.025% sell-only looks like the intraday STT rate; equity delivery STT is
# commonly quoted as 0.1% on both buy and sell — confirm against the current schedule.
STT_SELL_PCT = 0.025  # percent on sell notional (delivery)

# stamp duty (typical small state example)
# NOTE(review): 0.003% looks like the intraday stamp-duty rate; delivery is commonly
# quoted as 0.015% on the buy side — confirm which is intended.
STAMP_DUTY_BUY_PCT = 0.003  # percent on buy notional

# DP charge
DP_CHARGE_SELL = 16.5  # rupees per sell (delivery) if notional >= 100

# GST (percent, applied to brokerage + exchange + IPFT + SEBI + DP charges in the fee model)
GST_PCT = 18.0

# -------------------- GRID SEARCH SPACE --------------------
# Every key maps to the list of values to try; the grid is the full Cartesian product
# of all lists. Because an indicator's thresholds are only read when its "*_enabled"
# flag is True, combos that differ only in the parameters of a disabled indicator
# produce identical backtests (the product below still enumerates them separately).
GRID_SEARCH_SPACE = {
    # RSI filter: buy only when RSI <= rsi_buy_below; exit when RSI >= rsi_sell_above
    "rsi_enabled": [True, False],
    "rsi_buy_below": [35, 40],
    "rsi_sell_above": [65, 70],

    # MACD filter (fast/slow/signal EMA spans)
    "macd_enabled": [True, False],
    "macd_fast": [12],
    "macd_slow": [26],
    "macd_signal": [9],

    # ADX filter: requires ADX >= adx_min_adx and +DI > -DI
    "adx_enabled": [True, False],
    "adx_length": [14],
    "adx_min_adx": [18, 20, 25],

    # SMA filter: requires Close above the SMA of the given length
    "sma_enabled": [False, True],
    "sma_length": [30, 50],

    # Bollinger filter: previous close below the lower band, current close back above it
    "bb_enabled": [False, True],
    "bb_length": [20],
    "bb_stddev": [2.0],

    # How the enabled indicator checks are combined: 'all' (AND) or 'any' (OR)
    "combination_mode": ['all', 'any']
}
# --------------------------------------------------


# -------------------- CACHE HELPERS (parquet preferred, csv fallback) --------------------
CACHE_DIR = "cache"
# Create the cache directory at import time so the helpers below can assume it exists.
pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Probe for a parquet engine once: pyarrow first, then fastparquet.
# _parquet_engine stays None when neither imports, which switches the cache to CSV.
_parquet_engine = None
try:
    import pyarrow  # noqa: F401
    _parquet_engine = "pyarrow"
except Exception:
    try:
        import fastparquet  # noqa: F401
        _parquet_engine = "fastparquet"
    except Exception:
        _parquet_engine = None

# Report the chosen cache format (printed as a side effect of importing/running the module).
if _parquet_engine:
    print(f"Cache: using parquet engine '{_parquet_engine}'.")
else:
    print("Cache: parquet engine not found — falling back to CSV cache (no extra deps required).")

def ensure_cache_dir():
    """Create CACHE_DIR (and any missing parents) if it does not exist yet."""
    cache_root = pathlib.Path(CACHE_DIR)
    cache_root.mkdir(parents=True, exist_ok=True)

def _cache_path_for_ticker(ticker, use_parquet):
    """Return the cache file path for `ticker` inside CACHE_DIR.

    '/' and ':' in the ticker are replaced with '_' so the name is a single
    path component; the extension is ``.parquet`` or ``.csv`` depending on
    `use_parquet`.
    """
    sanitized = ticker.replace('/', '_').replace(':', '_')
    extension = "parquet" if use_parquet else "csv"
    return os.path.join(CACHE_DIR, f"{sanitized}.{extension}")

def download_with_retry(ticker, start, end, interval="4h", max_retries=3, backoff_sec=2):
    """Download price history for `ticker` via yfinance, retrying on failure.

    Args:
        ticker: symbol passed straight to ``yf.download``.
        start, end: date bounds passed straight to ``yf.download``.
        interval: bar interval. The default is kept for backward compatibility;
            every caller in this file passes "1d" explicitly.
        max_retries: total number of attempts.
        backoff_sec: base wait; attempt k waits ``backoff_sec * k`` seconds
            before the next attempt.

    Returns:
        A non-empty DataFrame on success, or None when every attempt either
        raised or produced no rows.
    """
    for attempt in range(1, max_retries + 1):
        try:
            df = yf.download(ticker, start=start, interval=interval, end=end,
                             auto_adjust=True, progress=False, threads=True, multi_level_index=False)
        except Exception as e:
            problem = e
        else:
            if df is not None and not df.empty:
                return df
            # yfinance frequently reports a failed request as an empty frame
            # instead of raising, so an empty result is retried as well.
            problem = "empty result"

        if attempt < max_retries:
            wait = backoff_sec * attempt
            print(f"  download error for {ticker} (attempt {attempt}/{max_retries}): {problem}. retrying in {wait}s")
            time.sleep(wait)
        else:
            # Last attempt: nothing left to wait for, so do not sleep.
            print(f"  download error for {ticker} (attempt {attempt}/{max_retries}): {problem}.")
    print(f"  download failed for {ticker} after {max_retries} attempts.")
    return None

def _read_cache(ticker):
    """Return the cached DataFrame for `ticker`, or None when no usable cache exists.

    The preferred format (parquet when an engine is available, CSV otherwise)
    is tried first, then the other format. Zero-byte files and files that fail
    to parse are deleted so they get re-downloaded. A parquet file found while
    no parquet engine is installed is skipped and left on disk: it cannot be
    read in this environment, but that does not make it corrupt.
    """
    use_parquet = _parquet_engine is not None
    candidates = (
        _cache_path_for_ticker(ticker, use_parquet),
        _cache_path_for_ticker(ticker, not use_parquet),
    )

    for p in candidates:
        if not os.path.exists(p):
            continue

        try:
            if os.path.getsize(p) == 0:
                print(f"  Warning: cache file {p} is zero bytes — removing.")
                try:
                    os.remove(p)
                except Exception as e:
                    print(f"    Could not remove zero-byte file {p}: {e}")
                continue
        except OSError:
            continue

        is_parquet = p.endswith('.parquet')
        if is_parquet and not use_parquet:
            # Reading would raise ImportError and the handler below would then
            # delete a perfectly valid cache file; skip it instead.
            continue

        try:
            if is_parquet:
                return pd.read_parquet(p, engine=_parquet_engine)
            return pd.read_csv(p, index_col=0, parse_dates=True)
        except Exception as e:
            print(f"  Warning: failed to read cache for {ticker} at {p}: {e}")
            try:
                os.remove(p)
                print(f"    Removed corrupted cache file {p}.")
            except Exception as rm_e:
                print(f"    Could not remove corrupted cache file {p}: {rm_e}")
    return None

def _atomic_write_frame(df, path, as_parquet):
    """Write `df` to `path` through a temp file + os.replace; re-raise on failure.

    The temp name includes the process and thread ids so that two worker
    threads caching the same ticker at once never write to the same temp file.
    The temp file is removed if anything goes wrong.
    """
    import threading

    tmp_path = f"{path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        if as_parquet:
            df.to_parquet(tmp_path, engine=_parquet_engine)
        else:
            df.to_csv(tmp_path)
        os.replace(tmp_path, path)
    except Exception:
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise


def _write_cache(ticker, df):
    """Persist `df` as the cache entry for `ticker`.

    Writes parquet when an engine is available, CSV otherwise. If the parquet
    write fails, a CSV copy is attempted as a fallback. Failures are reported
    with a printed warning, never raised.

    Returns:
        True if a cache file (primary or CSV fallback) was written, else False.
    """
    use_parquet = _parquet_engine is not None
    path = _cache_path_for_ticker(ticker, use_parquet)
    try:
        os.makedirs(os.path.dirname(path), exist_ok=True)
        _atomic_write_frame(df, path, use_parquet)
        return True
    except Exception as e:
        print(f"  Warning: failed to write cache for {ticker} to {path}: {e}")

    if not use_parquet:
        # The primary format already was CSV; there is nothing to fall back to.
        return False

    try:
        alt = _cache_path_for_ticker(ticker, use_parquet=False)
        _atomic_write_frame(df, alt, False)
    except Exception as e2:
        print(f"  Warning: fallback CSV write also failed for {ticker}: {e2}")
        return False
    print(f"  Wrote CSV fallback cache for {ticker} to {alt}")
    return True

def warm_cache(tickers, start=START_DATE, end=END_DATE):
    """Sequentially download and cache daily data for every ticker not cached yet.

    Tickers with a readable cache entry are skipped. Any error for one ticker
    is printed with its traceback and the loop moves on to the next ticker.
    """
    print(f"Warm cache: downloading {len(tickers)} tickers sequentially ...")
    ensure_cache_dir()
    for symbol in tickers:
        try:
            if _read_cache(symbol) is not None:
                continue

            frame = download_with_retry(symbol, start, end, interval="1d", max_retries=3, backoff_sec=2)
            if frame is None or frame.empty:
                print(f"  warm_cache: no data for {symbol}; skipping.")
                continue

            frame = collapse_duplicate_columns_take_first(normalize_df_columns(frame))
            if _write_cache(symbol, frame):
                print(f"  warm_cache: cached {symbol}")
            else:
                print(f"  warm_cache: failed to write cache for {symbol}")
            # Short pause between downloads.
            time.sleep(0.25)
        except Exception as exc:
            print(f"  warm_cache: error for {symbol}: {exc}\n{traceback.format_exc()}")
    print("Warm cache: done.")

# -------------------- NORMALIZATION & DUPLICATE HANDLING --------------------
def normalize_df_columns(df):
    """Flatten multi-level columns and rename OHLCV columns to canonical names.

    Multi-level column labels are joined with '_' (this assignment happens on
    the passed-in frame). Then, for each of close/high/low/open/volume, the
    column whose lowercase name matches exactly is renamed to the capitalized
    form; failing that, the first column whose lowercase name contains the
    word is used. Returns the renamed frame, or `df` itself when nothing
    needed renaming.
    """
    if hasattr(df, "columns") and getattr(df.columns, "nlevels", 1) > 1:
        flattened = []
        for levels in df.columns.values:
            parts = [str(part) for part in levels if part is not None]
            flattened.append("_".join(parts).strip())
        df.columns = flattened

    original_names = list(df.columns)
    by_lowercase = {col.lower(): col for col in original_names}
    renames = {}
    for wanted in ('close', 'high', 'low', 'open', 'volume'):
        canonical = wanted.capitalize()
        if wanted in by_lowercase:
            renames[by_lowercase[wanted]] = canonical
            continue
        # No exact match: fall back to the first column containing the word.
        for col in original_names:
            if wanted in col.lower():
                renames[col] = canonical
                break

    return df.rename(columns=renames) if renames else df

def collapse_duplicate_columns_take_first(df):
    """Collapse columns that share a name into a single column.

    For every duplicated name the first occurrence is kept, and any missing
    values in it are filled from the later duplicates (first non-null value
    per row). Column order is preserved. A frame without duplicate names is
    returned unchanged (same object).

    This avoids ``groupby(..., axis=1)``, which newer pandas releases
    deprecate/remove and which also re-sorts the columns alphabetically.
    """
    dup_mask = df.columns.duplicated()
    if not dup_mask.any():
        return df

    # dict.fromkeys de-duplicates while keeping first-seen order.
    dup_names = list(dict.fromkeys(df.columns[dup_mask]))
    print(f"  Warning: duplicate columns found and collapsed for: {dup_names}")

    collapsed = df.loc[:, ~dup_mask].copy()
    for name in dup_names:
        same_name = df.loc[:, df.columns == name]
        # Back-fill across the duplicates so column 0 holds the first non-null value per row.
        collapsed[name] = same_name.bfill(axis=1).iloc[:, 0]
    return collapsed

# -------------------- INDICATOR HELPERS --------------------
def calculate_rsi(df, length=14, column='Close'):
    """Add an 'RSI' column (smoothed with ``ewm(alpha=1/length)``) to `df` in place.

    Args:
        df: frame containing `column`.
        length: smoothing length.
        column: price column to use.

    Returns:
        The same `df`, with 'RSI' added. Rows where RSI is undefined (the
        first row, or no price movement at all so far) are set to 50. Rows
        with gains but zero average loss are set to 100.
    """
    delta = df[column].diff()
    gain = delta.clip(lower=0)
    loss = -1 * delta.clip(upper=0)
    avg_gain = gain.ewm(alpha=1/length, adjust=False).mean()
    avg_loss = loss.ewm(alpha=1/length, adjust=False).mean()
    # Zero average loss is replaced by NaN to avoid dividing by zero.
    rs = avg_gain / (avg_loss.replace(0, np.nan))
    rsi = 100 - (100 / (1 + rs))
    # No losses but some gains means RSI is 100, not "undefined"; without this
    # a pure uptrend would be reported as a neutral 50 by the fillna below.
    rsi = rsi.mask((avg_loss == 0) & (avg_gain > 0), 100.0)
    df['RSI'] = rsi.fillna(50)
    return df

def calculate_macd(df, fast=12, slow=26, signal=9, column='Close'):
    """Add 'MACD', 'MACD_signal' and 'MACD_hist' columns to `df` in place.

    MACD is the fast EMA minus the slow EMA of `column`; the signal line is
    an EMA of MACD; the histogram is their difference. Returns the same `df`.
    """
    def _ema(series, span):
        return series.ewm(span=span, adjust=False).mean()

    prices = df[column]
    df['MACD'] = _ema(prices, fast) - _ema(prices, slow)
    df['MACD_signal'] = _ema(df['MACD'], signal)
    df['MACD_hist'] = df['MACD'] - df['MACD_signal']
    return df

def calculate_sma(df, length=50, column='Close'):
    """Add a simple moving average column named ``SMA_<length>`` to `df` in place and return `df`."""
    target = f'SMA_{length}'
    window = df[column].rolling(length)
    df[target] = window.mean()
    return df

def calculate_bbands(df, length=20, stddev=2, column='Close'):
    """Add Bollinger band columns to `df` in place and return `df`.

    'BB_middle' is the rolling mean of `column`; 'BB_upper' / 'BB_lower' sit
    `stddev` rolling standard deviations above / below it.
    """
    window = df[column].rolling(length)
    middle = window.mean()
    half_width = stddev * window.std()
    df['BB_middle'] = middle
    df['BB_upper'] = middle + half_width
    df['BB_lower'] = middle - half_width
    return df

def calculate_adx(df, n=14):
    """Add '+DI', '-DI' and 'ADX' columns to `df` in place and return `df`.

    True range and directional movement are smoothed with rolling sums over
    `n` rows, and ADX is the rolling mean of DX over `n` rows. Requires
    'High', 'Low' and 'Close' columns. Early rows, where the rolling windows
    are not yet full, have ADX = NaN.
    """
    high = df['High']
    low = df['Low']
    prev_close = df['Close'].shift(1)

    true_range = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    ).max(axis=1)

    up_move = high.diff()
    down_move = -low.diff()
    # np.where returns bare arrays. They must be wrapped with df's own index:
    # a default RangeIndex would not align with a date-indexed true range,
    # turning every ratio below into NaN (so +DI/-DI become 0 and ADX is NaN).
    plus_dm = pd.Series(
        np.where((up_move > down_move) & (up_move > 0), up_move, 0.0), index=df.index)
    minus_dm = pd.Series(
        np.where((down_move > up_move) & (down_move > 0), down_move, 0.0), index=df.index)

    tr_smooth = true_range.rolling(window=n).sum()
    plus_dm_smooth = plus_dm.rolling(window=n).sum()
    minus_dm_smooth = minus_dm.rolling(window=n).sum()

    plus_di = 100 * (plus_dm_smooth / tr_smooth).replace([np.inf, -np.inf], 0).fillna(0)
    minus_di = 100 * (minus_dm_smooth / tr_smooth).replace([np.inf, -np.inf], 0).fillna(0)
    dx = ((plus_di - minus_di).abs() / (plus_di + minus_di)).replace([np.inf, -np.inf], 0) * 100
    adx = dx.rolling(window=n).mean()

    df['+DI'] = plus_di
    df['-DI'] = minus_di
    df['ADX'] = adx
    return df

# -------------------- DATA PREP (cache-aware) --------------------
def get_stock_data(ticker, start, end, indicator_params):
    """
    Load price data for `ticker` (cache first, download on a miss) and attach indicators.

    Args:
        ticker: symbol to load.
        start, end: download window; only used when the ticker is not in the cache.
        indicator_params: dict of "*_enabled" flags and indicator settings; only the
            indicators whose flag is truthy are computed.

    Returns:
        A DataFrame with EMA_fast, EMA_slow, the Supertrend output and the enabled
        optional indicators, with every row containing a NaN dropped; or None when
        no data is available, no Close column can be found, an indicator fails,
        or no rows survive the NaN drop.
    """
    ensure_cache_dir()
    raw_df = _read_cache(ticker)
    if raw_df is None:
        # Cache miss: download daily bars, normalize, and try to persist them.
        raw_df = download_with_retry(ticker, start, end, interval="1d", max_retries=3, backoff_sec=2)
        if raw_df is None or raw_df.empty:
            return None
        raw_df = normalize_df_columns(raw_df)
        raw_df = collapse_duplicate_columns_take_first(raw_df)
        ok = _write_cache(ticker, raw_df)
        if not ok:
            print(f"  Warning: failed to cache {ticker} (continuing without cache).")

    # Work on a copy and normalize again: frames read back from the cache take this path too.
    df = raw_df.copy()
    df = normalize_df_columns(df)
    df = collapse_duplicate_columns_take_first(df)

    if 'Close' not in df.columns:
        # Last-resort lookup of any column whose name contains "close".
        close_col = next((c for c in df.columns if 'close' in c.lower()), None)
        if close_col:
            df = df.rename(columns={close_col: 'Close'})
        else:
            print(f"  Warning: {ticker} data missing Close column after normalization. Columns: {list(df.columns)[:10]}")
            return None

    # Coerce price/volume columns to numeric; unparseable values become NaN.
    for c in ['Close', 'High', 'Low', 'Open', 'Volume']:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors='coerce')

    # EMAs: pandas ewm first; the custom_indicators helper is only a fallback if that raises.
    try:
        df['EMA_fast'] = df['Close'].ewm(span=EMA_FAST_LENGTH, adjust=False).mean()
        df['EMA_slow'] = df['Close'].ewm(span=EMA_SLOW_LENGTH, adjust=False).mean()
    except Exception as e:
        try:
            df['EMA_fast'] = calculate_ema(df, EMA_FAST_LENGTH)
            df['EMA_slow'] = calculate_ema(df, EMA_SLOW_LENGTH)
        except Exception as e2:
            print(f"  Error computing EMA for {ticker}: {e} / fallback error: {e2}")
            return None

    # Supertrend comes from custom_indicators; the backtest loop reads a
    # 'supertrend_direction' column, presumably added here — confirm against that module.
    try:
        df = calculate_supertrend(df, ST_LENGTH, ST_MULTIPLIER)
    except Exception as e:
        print(f"  Error computing Supertrend for {ticker}: {e}")
        return None

    # Optional indicators, each gated by its "*_enabled" flag.
    if indicator_params.get('rsi_enabled', False):
        df = calculate_rsi(df, length=indicator_params.get('rsi_length', 14))
    if indicator_params.get('macd_enabled', False):
        df = calculate_macd(df, fast=indicator_params.get('macd_fast', 12),
                             slow=indicator_params.get('macd_slow', 26),
                             signal=indicator_params.get('macd_signal', 9))
    if indicator_params.get('adx_enabled', False):
        df = calculate_adx(df, n=indicator_params.get('adx_length', 14))
    if indicator_params.get('sma_enabled', False):
        df = calculate_sma(df, length=indicator_params.get('sma_length', 50))
    if indicator_params.get('bb_enabled', False):
        df = calculate_bbands(df, length=indicator_params.get('bb_length', 20), stddev=indicator_params.get('bb_stddev', 2.0))

    # Drops every row with a NaN in ANY column (indicator warm-up rows, but also
    # rows with a NaN in an unrelated column such as Volume).
    df.dropna(inplace=True)
    if df.empty:
        return None
    return df

# -------------------- FEE MODEL (Groww) & slippage helpers --------------------
def groww_brokerage_for_order(order_value):
    """Return the brokerage in rupees for one executed order of `order_value`.

    The fee is BROKERAGE_PCT_CAP percent of the order value, capped at
    BROKERAGE_CAP_RUPEES and floored at BROKERAGE_MIN_RUPEES.
    """
    percentage_fee = (BROKERAGE_PCT_CAP / 100.0) * order_value
    fee_after_cap = min(BROKERAGE_CAP_RUPEES, percentage_fee)
    fee_after_floor = max(fee_after_cap, BROKERAGE_MIN_RUPEES)
    return fee_after_floor

def apply_fees_and_slippage_groww(entry_price, exit_price, position_size=1.0, is_delivery=True):
    """
    Apply adverse slippage and the Groww-style fee schedule to one round trip.

    Slippage raises the entry price and lowers the exit price by SLIPPAGE_PCT.
    Brokerage, exchange, SEBI and IPFT charges apply to both legs; stamp duty
    (buy), STT (sell) and the DP charge (sell) apply only when `is_delivery`
    is true. GST is charged on brokerage + exchange + IPFT + SEBI + DP.

    Returns:
      entry_eff, exit_eff,
      gross_return_pct, net_return_pct,
      fees_total, fees_breakdown,
      gross_profit_currency, net_profit_currency
    """
    def _pct_of(rate_pct, amount):
        # Rates are configured in percent.
        return (rate_pct / 100.0) * amount

    entry_eff = entry_price * (1.0 + SLIPPAGE_PCT / 100.0)
    exit_eff = exit_price * (1.0 - SLIPPAGE_PCT / 100.0)
    buy_value = entry_eff * position_size
    sell_value = exit_eff * position_size

    gross_profit = sell_value - buy_value
    if buy_value != 0:
        gross_return_pct = (gross_profit / buy_value) * 100.0
    else:
        gross_return_pct = 0.0

    # Charges levied on both legs.
    brokerage_entry = groww_brokerage_for_order(buy_value)
    brokerage_exit = groww_brokerage_for_order(sell_value)
    exch_entry = _pct_of(EXCHANGE_TXN_PCT, buy_value)
    exch_exit = _pct_of(EXCHANGE_TXN_PCT, sell_value)
    sebi_entry = _pct_of(SEBI_TURNOVER_PCT, buy_value)
    sebi_exit = _pct_of(SEBI_TURNOVER_PCT, sell_value)
    ipft_entry = _pct_of(IPFT_PCT, buy_value)
    ipft_exit = _pct_of(IPFT_PCT, sell_value)

    # Delivery-only charges.
    if is_delivery:
        stamp_buy = _pct_of(STAMP_DUTY_BUY_PCT, buy_value)
        stt_sell = _pct_of(STT_SELL_PCT, sell_value)
        dp_sell = DP_CHARGE_SELL if sell_value >= 100.0 else 0.0
    else:
        stamp_buy = 0.0
        stt_sell = 0.0
        dp_sell = 0.0

    # Explicit left-to-right additions keep the floating-point result stable.
    gst_base = (brokerage_entry + brokerage_exit +
                exch_entry + exch_exit +
                ipft_entry + ipft_exit +
                sebi_entry + sebi_exit +
                dp_sell)
    gst = (GST_PCT / 100.0) * gst_base

    fees_total = (brokerage_entry + brokerage_exit +
                  exch_entry + exch_exit +
                  sebi_entry + sebi_exit +
                  ipft_entry + ipft_exit +
                  stamp_buy + stt_sell + dp_sell + gst)

    net_profit = gross_profit - fees_total
    if buy_value != 0:
        net_return_pct = (net_profit / buy_value) * 100.0
    else:
        net_return_pct = 0.0

    fees_breakdown = dict(
        brokerage_entry=brokerage_entry,
        brokerage_exit=brokerage_exit,
        exchange_entry=exch_entry,
        exchange_exit=exch_exit,
        sebi_entry=sebi_entry,
        sebi_exit=sebi_exit,
        ipft_entry=ipft_entry,
        ipft_exit=ipft_exit,
        stamp_buy=stamp_buy,
        stt_sell=stt_sell,
        dp_sell=dp_sell,
        gst=gst,
        fees_total=fees_total,
    )

    return (entry_eff, exit_eff,
            gross_return_pct, net_return_pct,
            fees_total, fees_breakdown,
            gross_profit, net_profit)

# -------------------- SIGNALS & BACKTEST (capital-per-trade sizing) --------------------
def indicator_buy_checks(row, prev_row, params):
    """Evaluate the buy condition of every enabled optional indicator.

    `row` / `prev_row` are the current and previous bars (anything supporting
    ``.get`` and ``[...]``); `params` holds the "*_enabled" flags and
    thresholds. Returns a dict with one boolean per enabled indicator
    ('rsi', 'macd', 'adx', 'sma', 'bb'); disabled indicators are absent.
    """
    checks = {}

    if params.get('rsi_enabled', False):
        rsi_value = row.get('RSI', np.nan)
        checks['rsi'] = (rsi_value <= params.get('rsi_buy_below', 40))

    if params.get('macd_enabled', False):
        macd_line = row.get('MACD', 0)
        signal_line = row.get('MACD_signal', 0)
        histogram = row.get('MACD_hist', 0)
        checks['macd'] = (macd_line > signal_line) and (histogram > 0.0)

    if params.get('adx_enabled', False):
        strength = row.get('ADX', 0)
        di_plus = row.get('+DI', 0)
        di_minus = row.get('-DI', 0)
        checks['adx'] = (strength >= params.get('adx_min_adx', 20.0)) and (di_plus > di_minus)

    if params.get('sma_enabled', False):
        sma_key = f"SMA_{params.get('sma_length', 50)}"
        checks['sma'] = (row['Close'] > row.get(sma_key, np.nan))

    if params.get('bb_enabled', False):
        if prev_row is None:
            checks['bb'] = False
        else:
            # Previous close below the lower band, current close back above it.
            checks['bb'] = ((prev_row['Close'] < prev_row.get('BB_lower', np.nan))
                            and (row['Close'] > row.get('BB_lower', np.nan)))

    return checks

def combine_indicator_signals(ind_results, mode='all'):
    """Reduce per-indicator booleans to one signal.

    With no results at all the answer is True. Otherwise mode 'all' requires
    every check to pass; any other mode requires at least one.
    """
    if not ind_results:
        return True
    outcomes = list(ind_results.values())
    if mode == 'all':
        return all(outcomes)
    return any(outcomes)

def run_backtest_on_df(df, ticker, params, trailing_stop_loss_pct=DEFAULT_TRAILING_STOP_PCT,
                       hard_stop_loss_pct=DEFAULT_HARD_STOP_PCT, max_holding_days=DEFAULT_MAX_HOLD_DAYS,
                       combination_mode='all', capital_per_trade=CAPITAL_PER_TRADE):
    """
    Run a long-only, one-position-at-a-time backtest over `df` and return the closed trades.

    Entry: the bar where (supertrend_direction == 1 and EMA_fast > EMA_slow) becomes true,
    filtered by the enabled optional indicators combined with `combination_mode`.
    Entry is taken at that bar's Close.

    Exit, checked on later bars in this priority order: hard stop (HSL), trailing stop (TSL),
    max holding period (Time), Supertrend flipping from 1 to -1 (SL), then the optional
    RSI / ADX / MACD exits.

    Each trade uses an isolated capital_per_trade; position_size = floor(capital / expected_entry_eff).
    If position_size == 0, the signal is skipped.

    Returns a list of dicts, one per closed trade. A position still open on the last
    bar is not closed and therefore not reported.
    """
    in_position = False
    trades = []
    entry_price = 0; entry_date = None
    peak_price = 0; trailing_stop_price = 0; hard_stop_price = 0; days_in_trade = 0
    # Multipliers turning a reference price into its stop level.
    tsl_multiplier = 1 - (trailing_stop_loss_pct / 100)
    hsl_multiplier = 1 - (hard_stop_loss_pct / 100)
    current_position_size = 0

    # Start at 1 so every bar has a previous bar to compare against.
    for i in range(1, len(df)):
        prev_row = df.iloc[i-1]; current_row = df.iloc[i]
        # Base signal is the transition into the bullish state, not the state itself.
        is_bullish_state = current_row['supertrend_direction'] == 1 and current_row['EMA_fast'] > current_row['EMA_slow']
        was_bullish_state = prev_row['supertrend_direction'] == 1 and prev_row['EMA_fast'] > prev_row['EMA_slow']
        buy_signal_base = is_bullish_state and not was_bullish_state
        stop_loss_signal = prev_row['supertrend_direction'] == 1 and current_row['supertrend_direction'] == -1

        ind_results = indicator_buy_checks(current_row, prev_row, params)
        ind_combined = combine_indicator_signals(ind_results, mode=combination_mode)
        enabled_any = any([params.get(k) for k in ['rsi_enabled', 'macd_enabled', 'adx_enabled', 'sma_enabled', 'bb_enabled']])
        buy_signal = buy_signal_base and (ind_combined if enabled_any else True)

        if not in_position and buy_signal:
            entry_price = current_row['Close']
            # Size the position on the slippage-adjusted price so the buy fits inside the capital.
            expected_entry_eff = entry_price * (1.0 + SLIPPAGE_PCT / 100.0)
            position_size = int(capital_per_trade // expected_entry_eff)
            if position_size <= 0:
                # skip trade
                in_position = False
                entry_price = 0; entry_date = None
                continue
            in_position = True
            current_position_size = position_size
            entry_date = current_row.name
            peak_price = entry_price
            trailing_stop_price = peak_price * tsl_multiplier
            hard_stop_price = entry_price * hsl_multiplier
            days_in_trade = 0

        elif in_position:
            # Exits are only evaluated on bars after the entry bar (this is an elif of the entry branch).
            days_in_trade += 1
            # The trailing stop is raised from this bar's High BEFORE this bar's Low is tested against it.
            # NOTE(review): that treats the High as occurring before the Low within the bar — confirm this is intended.
            if current_row['High'] > peak_price:
                peak_price = current_row['High']; trailing_stop_price = peak_price * tsl_multiplier

            # exit_price == 0 doubles as the "no exit yet" sentinel below.
            exit_price = 0; exit_reason = None

            # NOTE(review): stop exits fill exactly at the stop level even when the bar's
            # Low is far below it (no gap handling) — confirm this is acceptable.
            if current_row['Low'] <= hard_stop_price:
                exit_price = hard_stop_price; exit_reason = "HSL"
            elif current_row['Low'] <= trailing_stop_price:
                exit_price = trailing_stop_price; exit_reason = "TSL"
            elif days_in_trade >= max_holding_days:
                exit_price = current_row['Close']; exit_reason = "Time"
            elif stop_loss_signal:
                exit_price = current_row['Close']; exit_reason = "SL"
            else:
                # Optional indicator exits; the first one that fires wins (RSI, then ADX, then MACD).
                if params.get('rsi_enabled', False):
                    r = current_row.get('RSI', np.nan)
                    if r >= params.get('rsi_sell_above', 70):
                        exit_price = current_row['Close']; exit_reason = "RSI_exit"
                if exit_price == 0 and params.get('adx_enabled', False):
                    adx = current_row.get('ADX', np.nan)
                    if adx < params.get('adx_min_adx', 20.0):
                        exit_price = current_row['Close']; exit_reason = "ADX_weak"
                if exit_price == 0 and params.get('macd_enabled', False):
                    prev_macd = prev_row.get('MACD', 0); prev_sig = prev_row.get('MACD_signal', 0)
                    cur_macd = current_row.get('MACD', 0); cur_sig = current_row.get('MACD_signal', 0)
                    if (prev_macd > prev_sig) and (cur_macd < cur_sig):
                        exit_price = current_row['Close']; exit_reason = "MACD_cross_down"

            if exit_price > 0:
                in_position = False
                exit_date = current_row.name
                position_size = current_position_size

                # Slippage and the full fee schedule are applied once, at exit, for the whole round trip.
                (entry_eff, exit_eff,
                 gross_ret_pct_per_position, net_ret_pct_per_position,
                 fees_total, fees_breakdown,
                 gross_profit_currency, net_profit_currency) = apply_fees_and_slippage_groww(
                    entry_price, exit_price, position_size=position_size, is_delivery=True
                )

                entry_notional = entry_eff * position_size
                exit_notional = exit_eff * position_size

                # ROI relative to the full capital slice, not just the amount actually invested.
                roi_pct_on_capital = (net_profit_currency / capital_per_trade) * 100.0

                trades.append({
                    "ticker": ticker,
                    "entry_date": entry_date, "exit_date": exit_date,
                    "position_size": position_size,
                    "entry_price": entry_price, "exit_price": exit_price,
                    "entry_price_effective": entry_eff, "exit_price_effective": exit_eff,
                    "entry_notional": entry_notional, "exit_notional": exit_notional,
                    "gross_profit": gross_profit_currency,
                    "net_profit": net_profit_currency,
                    "gross_return_%": gross_ret_pct_per_position,
                    "net_return_%": net_ret_pct_per_position,
                    "return_%": roi_pct_on_capital,  # percent vs isolated capital_per_trade
                    "fees": fees_total,
                    "fees_breakdown": fees_breakdown,
                    "exit_reason": exit_reason,
                    "days_held": days_in_trade
                })
    return trades

# -------------------- GRID ENGINE --------------------
def generate_param_combinations(space):
    """Return every combination of `space` as a list of dicts.

    `space` maps a parameter name to the list of values to try; the result is
    the Cartesian product, one ``{name: value}`` dict per combination, in
    ``itertools.product`` order.
    """
    names = list(space)
    value_lists = (space[name] for name in names)
    return [dict(zip(names, choice)) for choice in itertools.product(*value_lists)]

def build_indicator_params_from_combo(combo):
    """Build the indicator parameter dict for one grid combination.

    Only the known keys are copied, with a default filled in for any key
    missing from `combo`; the "*_enabled" flags are coerced to bool.
    """
    # (key, default) pairs in the order they appear in the result.
    defaults = (
        ("rsi_enabled", False),
        ("rsi_buy_below", 40),
        ("rsi_sell_above", 70),
        ("macd_enabled", False),
        ("macd_fast", 12),
        ("macd_slow", 26),
        ("macd_signal", 9),
        ("adx_enabled", False),
        ("adx_length", 14),
        ("adx_min_adx", 20),
        ("sma_enabled", False),
        ("sma_length", 50),
        ("bb_enabled", False),
        ("bb_length", 20),
        ("bb_stddev", 2.0),
        ("combination_mode", 'all'),
    )
    built = {}
    for key, fallback in defaults:
        value = combo.get(key, fallback)
        built[key] = bool(value) if key.endswith("_enabled") else value
    return built

def aggregate_trades_metrics(trades, capital_per_trade=CAPITAL_PER_TRADE):
    """Summarize a list of trade dicts into headline metrics.

    Returns a dict with trade count, win rate (percent of trades with
    net_profit > 0), mean 'return_%', total and mean net profit, mean holding
    period, overall ROI (total net profit over capital_per_trade * trades,
    in percent) and exit counts by reason (HSL / TSL / Time / SL / other).
    With no trades every metric is zero.
    """
    if not trades:
        return dict(
            total_trades=0, win_rate=0.0, avg_roi_pct=0.0,
            total_net_profit=0.0, avg_net_profit=0.0, avg_holding=0.0,
            overall_roi_pct=0.0,
            hsl_exits=0, tsl_exits=0, time_exits=0, sl_exits=0, other_exits=0,
        )

    ledger = pd.DataFrame(trades)
    trade_count = len(ledger)
    net = ledger['net_profit']

    winners = int((net > 0).sum())
    win_rate = winners / trade_count * 100.0

    total_net_profit = net.sum()

    # Count exits per reason in one pass; anything not listed lands in "other".
    reason_counts = ledger['exit_reason'].value_counts()
    hsl_exits = int(reason_counts.get('HSL', 0))
    tsl_exits = int(reason_counts.get('TSL', 0))
    time_exits = int(reason_counts.get('Time', 0))
    sl_exits = int(reason_counts.get('SL', 0))
    other_exits = trade_count - (hsl_exits + tsl_exits + time_exits + sl_exits)

    capital_deployed = capital_per_trade * trade_count
    if capital_deployed:
        overall_roi_pct = (total_net_profit / capital_deployed) * 100.0
    else:
        overall_roi_pct = 0.0

    return dict(
        total_trades=trade_count,
        win_rate=win_rate,
        avg_roi_pct=ledger['return_%'].mean(),
        total_net_profit=total_net_profit,
        avg_net_profit=net.mean(),
        overall_roi_pct=overall_roi_pct,
        avg_holding=ledger['days_held'].mean(),
        hsl_exits=hsl_exits,
        tsl_exits=tsl_exits,
        time_exits=time_exits,
        sl_exits=sl_exits,
        other_exits=other_exits,
    )

def safe_filename(s):
    """Return `s` cut to 180 characters, with every character that is not alphanumeric or one of '._-' replaced by '_'."""
    keep_as_is = "._-"
    # Replacement is one-for-one, so truncating first gives the same result.
    head = s[:180]
    return "".join(ch if (ch.isalnum() or ch in keep_as_is) else "_" for ch in head)

def process_ticker_for_combo(ticker, params):
    """Load data for one ticker and backtest it with `params`.

    Returns the list of trades, or an empty list when the data cannot be
    loaded (the loading error is printed, not raised).
    """
    try:
        frame = get_stock_data(ticker, START_DATE, END_DATE, params)
    except Exception as exc:
        print(f"  Error fetching {ticker}: {exc}")
        return []

    if frame is None:
        return []

    return run_backtest_on_df(
        frame, ticker, params,
        trailing_stop_loss_pct=DEFAULT_TRAILING_STOP_PCT,
        hard_stop_loss_pct=DEFAULT_HARD_STOP_PCT,
        max_holding_days=DEFAULT_MAX_HOLD_DAYS,
        combination_mode=params.get('combination_mode', 'all'),
        capital_per_trade=CAPITAL_PER_TRADE,
    )

def process_combo_worker(args):
    """Backtest one parameter combination across every ticker in TICKER_LIST.

    `args` is an ``(index, combo_dict)`` pair. With PARALLELIZE == 'tickers'
    the tickers run in an inner thread pool; otherwise they run one after
    another. Every trade is tagged with the combo name.

    Returns ``(combo_idx, combo, metrics, trades)``.
    """
    combo_idx, combo = args
    params = build_indicator_params_from_combo(combo)
    combination_mode = params.get('combination_mode', 'all')
    settings_label = "_".join(f"{key}={value}" for key, value in combo.items())
    combo_name = f"{combo_idx:04d}_" + settings_label
    collected = []

    def _tag_and_collect(batch):
        # Stamp each trade with the combo it came from, then keep it.
        for trade in batch:
            trade['combo_name'] = combo_name
        collected.extend(batch)

    if PARALLELIZE == 'tickers':
        inner_workers = min(6, (os.cpu_count() or 1))
        with ThreadPoolExecutor(max_workers=inner_workers) as pool:
            pending = {pool.submit(process_ticker_for_combo, symbol, params): symbol
                       for symbol in TICKER_LIST}
            for finished in as_completed(pending):
                try:
                    batch = finished.result()
                except Exception as exc:
                    print(f"  Error in ticker worker: {exc}")
                    batch = []
                if batch:
                    _tag_and_collect(batch)
    else:
        for symbol in TICKER_LIST:
            try:
                frame = get_stock_data(symbol, START_DATE, END_DATE, params)
            except Exception as exc:
                print(f"[Worker {combo_idx}] Error fetching {symbol}: {exc}")
                continue
            if frame is None:
                continue
            batch = run_backtest_on_df(
                frame, symbol, params,
                trailing_stop_loss_pct=DEFAULT_TRAILING_STOP_PCT,
                hard_stop_loss_pct=DEFAULT_HARD_STOP_PCT,
                max_holding_days=DEFAULT_MAX_HOLD_DAYS,
                combination_mode=combination_mode,
                capital_per_trade=CAPITAL_PER_TRADE,
            )
            _tag_and_collect(batch)

    metrics = aggregate_trades_metrics(collected, capital_per_trade=CAPITAL_PER_TRADE)
    return (combo_idx, combo, metrics, collected)

def run_grid_search_parallel(tickers, grid_space, save_detailed=SAVE_DETAILED_TRADES):
    """Run every combination of `grid_space` in a thread pool and collect the metrics.

    After each finished combo the summary CSV (GRID_RESULTS_CSV) is rewritten
    as a checkpoint and, when `save_detailed` is set and the combo produced
    trades, a per-combo trades CSV is written under DETAILED_DIR. A combo
    whose worker raises is reported and skipped.

    Note: `tickers` is accepted but not used here; the workers read the
    module-level TICKER_LIST.

    Returns a DataFrame with one row per completed combo (parameters + metrics).
    """
    combos = generate_param_combinations(grid_space)
    print(f"Generated {len(combos)} parameter combinations.")
    os.makedirs(DETAILED_DIR, exist_ok=True)

    jobs = list(enumerate(combos, start=1))
    num_workers = NUM_WORKERS or (os.cpu_count() or 1)
    print(f"Launching thread pool with {num_workers} worker threads. PARALLELIZE='{PARALLELIZE}'")

    def _save_trades(combo_idx, combo, detailed_trades):
        # One CSV per combo; the fees_breakdown dicts are written as-is.
        combo_name = f"{combo_idx:04d}_" + "_".join([f"{k}={v}" for k, v in combo.items()])
        outpath = os.path.join(DETAILED_DIR, safe_filename(combo_name) + "_trades.csv")
        try:
            pd.DataFrame(detailed_trades).to_csv(outpath, index=False)
        except Exception as exc:
            print(f"  Warning: failed to save detailed trades for combo {combo_idx}: {exc}")

    summary_rows = []
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        job_of = {pool.submit(process_combo_worker, job): job for job in jobs}
        for finished in as_completed(job_of):
            submitted_idx = job_of[finished][0]
            try:
                combo_idx, combo, metrics, detailed_trades = finished.result()
            except Exception as exc:
                print(f"[Combo {submitted_idx}] Worker raised exception: {exc}")
                continue

            summary_row = dict(combo)
            summary_row.update(metrics)
            summary_rows.append(summary_row)

            if save_detailed and detailed_trades:
                _save_trades(combo_idx, combo, detailed_trades)

            # Checkpoint the summary after every combo so partial results survive an interruption.
            try:
                pd.DataFrame(summary_rows).to_csv(GRID_RESULTS_CSV, index=False)
            except Exception as exc:
                print(f"  Warning: failed to checkpoint grid results: {exc}")

            print(f"[Done combo {combo_idx}] total_trades={metrics['total_trades']} win_rate={metrics['win_rate']:.2f}% overall_roi={metrics['overall_roi_pct']:.2f}% total_net_profit={metrics['total_net_profit']:.2f}")

    return pd.DataFrame(summary_rows)

# -------------------- MAIN --------------------
if __name__ == "__main__":
    # Script entry point: optional cache warm-up, the threaded grid search, then two top-10 tables.
    print("Starting threaded grid search (warm cache optional) with Groww fees and isolated capital per trade.")
    combos = generate_param_combinations(GRID_SEARCH_SPACE)
    print(f"TOTAL combos to run: {len(combos)}")

    if WARM_CACHE:
        print("WARM_CACHE is True — warming cache sequentially before threaded run (recommended).")
        warm_cache(TICKER_LIST)

    results_df = run_grid_search_parallel(TICKER_LIST, GRID_SEARCH_SPACE, save_detailed=SAVE_DETAILED_TRADES)

    if results_df is None or results_df.empty:
        print("No results produced.")
    else:
        # (heading, sort key, columns to show) for each leaderboard.
        leaderboards = (
            ("\nTop 10 combos by total_net_profit:", 'total_net_profit',
             ['total_net_profit', 'overall_roi_pct', 'win_rate', 'total_trades']),
            ("\nTop 10 combos by win_rate:", 'win_rate',
             ['win_rate', 'total_net_profit', 'overall_roi_pct', 'total_trades']),
        )
        for heading, sort_key, shown_columns in leaderboards:
            best = results_df.sort_values(by=sort_key, ascending=False).head(10)
            print(heading)
            print(best[shown_columns].to_string(index=False))

    print("\nGrid search finished. Results in:", GRID_RESULTS_CSV)


Cache: using parquet engine 'pyarrow'.
Starting threaded grid search (warm cache optional) with Groww fees and isolated capital per trade.
TOTAL combos to run: 1536
WARM_CACHE is True — warming cache sequentially before threaded run (recommended).
Warm cache: downloading 30 tickers sequentially ...
  warm_cache: cached SBICARD.NS
  warm_cache: cached BDL.NS
  warm_cache: cached INDHOTEL.NS
  warm_cache: cached BSE.NS
  warm_cache: cached NYKAA.NS
  warm_cache: cached BAJFINANCE.NS
  warm_cache: cached PAYTM.NS
  warm_cache: cached SOLARINDS.NS
  warm_cache: cached CHOLAFIN.NS
  warm_cache: cached UNITDSPR.NS
  warm_cache: cached DIVISLAB.NS
  warm_cache: cached MUTHOOTFIN.NS
  warm_cache: cached BHARTIARTL.NS
  warm_cache: cached ICICIBANK.NS
  warm_cache: cached MAZDOCK.NS
  warm_cache: cached SHREECEM.NS
  warm_cache: cached DIXON.NS
  warm_cache: cached PERSISTENT.NS
  warm_cache: cached SRF.NS
  warm_cache: cached TVSMOTOR.NS
  warm_cache: cached SBILIFE.NS
  warm_cache: cached MAX