# Databricks — Momentum Swing Backtest (Clean Universe + Parameter Grid + Regimes)

This notebook is the main engine for **pattern recognition + backtesting**.

It includes:
- Single-symbol iteration (fast debugging)
- Multi-symbol backtest at scale using Spark `groupBy(...).applyInPandas(...)`
- Universe sourced from `clean_universe` (generated by `databricks_eda_sydney_time.ipynb`)
- A small parameter grid (stop/TP/hold)
- ATR% regime buckets recorded per trade

Timezone:
- We compute UTC timestamps from epoch ms for correct ordering.
- We add Sydney timestamps (AEST/AEDT) for reporting, not for trading logic.


In [None]:
%sql
-- Select the catalog and schema for this session; the Python cells below refer to tables by unqualified name.
USE CATALOG `workspace`;
USE SCHEMA `squeeze`;


In [None]:
import uuid
import numpy as np
import pandas as pd
from dataclasses import dataclass

from pyspark.sql import functions as F
from pyspark.sql import types as T

# IANA zone name used to derive the Sydney (AEST/AEDT) reporting timestamps.
TZ = 'Australia/Sydney'

# Random 32-character hex identifier generated once per notebook session.
RUN_ID = uuid.uuid4().hex
print('RUN_ID =', RUN_ID)

# Widen pandas' rendering so wide frames are not truncated when displayed.
pd.options.display.max_columns = 200
pd.options.display.width = 140


## Strategy definition
Start with one simple momentum pattern, then iterate based on results.

In [None]:
@dataclass
class BacktestParams:
    """Exit settings for one backtest configuration (one point of the parameter grid)."""

    # Stop distance below the entry price, expressed as a multiple of ATR(14).
    atr_mult_stop: float
    # Take-profit distance above the entry price, expressed as a multiple of ATR(14).
    atr_mult_tp: float
    # Number of bars after the entry bar at which a still-open trade is closed.
    max_hold_bars: int

# Small parameter grid to start (expand later).
# Each tuple reads (atr_mult_stop, atr_mult_tp, max_hold_bars).
PARAM_GRID = [
    BacktestParams(atr_mult_stop=stop_mult, atr_mult_tp=tp_mult, max_hold_bars=hold_bars)
    for stop_mult, tp_mult, hold_bars in (
        (1.5, 2.5, 48),
        (2.0, 3.0, 48),
        (2.5, 3.5, 72),
    )
]

# Boolean signal columns that the backtests iterate over.
SIGNALS = ['sig_breakout_long']
# Second grid entry (2.0 / 3.0 / 48) is the configuration used for single-symbol runs.
DEFAULT_PARAMS = PARAM_GRID[1]


## Helpers: features + signals

In [None]:
def ema(s: pd.Series, span: int) -> pd.Series:
    """Return the exponential moving average of ``s``.

    The average is computed in pandas' recursive form (``adjust=False``) with
    the decay specified through ``span``.
    """
    smoother = s.ewm(span=span, adjust=False)
    return smoother.mean()

def atr(high: pd.Series, low: pd.Series, close: pd.Series, n: int = 14) -> pd.Series:
    """Return the ``n``-bar simple moving average of the true range.

    True range is the largest of: the bar's high-low span, and the absolute
    distance of the high and of the low from the previous close. The first
    ``n - 1`` values are NaN because a full window is required.
    """
    prior_close = close.shift(1)
    candidates = [
        (high - low).abs(),
        (high - prior_close).abs(),
        (low - prior_close).abs(),
    ]
    # Row-wise max skips NaN, so the first bar falls back to its high-low span.
    true_range = pd.concat(candidates, axis=1).max(axis=1)
    window = true_range.rolling(n, min_periods=n)
    return window.mean()

def add_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with the indicator columns the signals rely on.

    Added columns: ``ema_20``, ``ema_50``, ``trend_up`` (fast EMA above slow
    EMA), ``atr_14``, ``atrp_14`` (ATR divided by close) and ``hh_20`` (rolling
    20-bar highest high, current bar included). The input frame is not modified.
    """
    lookback = 20
    close = df['close']
    fast = ema(close, 20)
    slow = ema(close, 50)
    atr_14 = atr(df['high'], df['low'], close, 14)
    highest_high = df['high'].rolling(lookback, min_periods=lookback).max()
    return df.assign(
        ema_20=fast,
        ema_50=slow,
        trend_up=fast > slow,
        atr_14=atr_14,
        atrp_14=atr_14 / close,
        hh_20=highest_high,
    )

def add_signals(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with the boolean ``sig_breakout_long`` column.

    The signal fires when ``trend_up`` is true and the close is above the
    previous bar's ``hh_20``. The input frame is not modified.
    """
    out = df.copy()
    # Compare against the prior bar's rolling high so the current bar's own
    # high cannot be part of the level it is breaking (no lookahead).
    prior_high = out['hh_20'].shift(1)
    broke_out = out['close'].gt(prior_high)
    out['sig_breakout_long'] = out['trend_up'] & broke_out
    return out


## Backtest core (pandas)

In [None]:
def backtest_long(df: pd.DataFrame, signal_col: str, p: BacktestParams) -> pd.DataFrame:
    """Simulate one long trade for every bar on which ``signal_col`` is true.

    A signal on bar ``i`` is entered at the open of bar ``i + 1``. The stop and
    take-profit are placed ``p.atr_mult_stop`` / ``p.atr_mult_tp`` ATRs away
    from the entry, using ``atr_14`` of the *signal* bar: the entry bar's ATR
    includes that bar's own high/low/close, which are not yet known at its
    open, so using it would leak future information into the levels.

    Exits are scanned from the entry bar through ``entry_i + p.max_hold_bars``
    (capped at the last row). When a bar touches both levels the stop wins.
    If neither level is touched the trade is closed at the close of the last
    scanned bar with outcome ``'time'``. Trades are independent of each other:
    a new signal opens a new trade even if an earlier one is still open.

    NOTE: a bar that opens beyond the stop is still recorded as filled at the
    stop price; gaps are not modelled.

    Args:
        df: Bars in chronological row order with columns ``open``, ``high``,
            ``low``, ``close``, ``atr_14``, ``open_dt_utc``, ``open_dt_syd``,
            ``signal_col`` and optionally ``atrp_14``.
        signal_col: Name of the boolean entry-signal column.
        p: Stop / take-profit multiples and the maximum holding period.

    Returns:
        One row per trade with columns ``entry_i``, ``exit_i``,
        ``entry_time_utc``, ``entry_time_syd``, ``exit_time_utc``, ``entry``,
        ``stop``, ``tp``, ``exit``, ``bars_held``, ``outcome``, ``r_multiple``
        and ``atrp_entry``. The columns are present even when there are no
        trades, so callers can index them unconditionally.
    """
    columns = [
        'entry_i', 'exit_i',
        'entry_time_utc', 'entry_time_syd', 'exit_time_utc',
        'entry', 'stop', 'tp', 'exit',
        'bars_held', 'outcome', 'r_multiple', 'atrp_entry',
    ]
    n = len(df)
    if n < 3:
        # The loop below needs a signal bar, an entry bar and one spare bar.
        return pd.DataFrame(columns=columns)

    # Pull the columns into NumPy arrays once; per-element ``.iloc`` lookups
    # inside the loops dominate the runtime on long series.
    signal = df[signal_col].to_numpy()
    open_px = df['open'].to_numpy(dtype=float)
    high_px = df['high'].to_numpy(dtype=float)
    low_px = df['low'].to_numpy(dtype=float)
    close_px = df['close'].to_numpy(dtype=float)
    atr_14 = df['atr_14'].to_numpy(dtype=float)
    atrp_14 = df['atrp_14'].to_numpy(dtype=float) if 'atrp_14' in df.columns else None

    rows = []
    for i in range(n - 2):
        if not bool(signal[i]):
            continue
        entry_i = i + 1
        entry = float(open_px[entry_i])
        atrv = float(atr_14[i])  # signal-bar ATR: known at the entry open
        if not np.isfinite(entry) or not np.isfinite(atrv) or atrv <= 0:
            continue
        stop = entry - p.atr_mult_stop * atrv
        tp = entry + p.atr_mult_tp * atrv
        last_i = min(n - 1, entry_i + p.max_hold_bars)

        # First bar in the holding window that touches either level.
        # Comparisons against NaN are False, so bad bars are simply skipped.
        window_low = low_px[entry_i:last_i + 1]
        window_high = high_px[entry_i:last_i + 1]
        touched = (window_low <= stop) | (window_high >= tp)
        if touched.any():
            k = int(touched.argmax())
            exit_i = entry_i + k
            # conservative bar model: stop before tp within bar
            if window_low[k] <= stop:
                exit_px, outcome = stop, 'stop'
            else:
                exit_px, outcome = tp, 'tp'
        else:
            exit_i, exit_px, outcome = last_i, float(close_px[last_i]), 'time'

        risk = entry - stop
        r = (exit_px - entry) / risk if risk != 0 else np.nan
        # Reporting-only regime value, taken at the entry bar so it lines up
        # with bucket lookups that callers perform by ``entry_i``.
        atrp_entry = float(atrp_14[entry_i]) if atrp_14 is not None else np.nan

        rows.append({
            'entry_i': int(entry_i),
            'exit_i': int(exit_i),
            'entry_time_utc': df['open_dt_utc'].iloc[entry_i],
            'entry_time_syd': df['open_dt_syd'].iloc[entry_i],
            'exit_time_utc': df['open_dt_utc'].iloc[exit_i],
            'entry': entry,
            'stop': stop,
            'tp': tp,
            'exit': exit_px,
            'bars_held': int(exit_i - entry_i),
            'outcome': outcome,
            'r_multiple': float(r),
            'atrp_entry': atrp_entry,
        })

    return pd.DataFrame(rows, columns=columns)


## Single-symbol iteration (optional)
Use this to debug strategy logic quickly.

In [None]:
# Market and bar size loaded for the single-symbol debugging run.
EXCHANGE = 'binance'
SYMBOL = 'BTCUSDT'
INTERVAL = '1h'
# Maximum number of bars pulled from Spark into pandas.
LIMIT_ROWS = 20_000


In [None]:
# Narrow the OHLC table to the single market / bar size under test.
ohlc_s = spark.table('ohlc').where(
    (F.col('exchange')==EXCHANGE) & (F.col('symbol')==SYMBOL) & (F.col('interval')==INTERVAL)
)
ohlc_s = ohlc_s.select('exchange','symbol','interval','open_time','open','high','low','close','volume')

# open_time is epoch milliseconds: derive a timestamp from it, then a
# Sydney-shifted copy for reporting.
ohlc_s = ohlc_s.withColumn('open_dt_utc', F.to_timestamp(F.col('open_time')/1000))
ohlc_s = ohlc_s.withColumn('open_dt_syd', F.from_utc_timestamp(F.col('open_dt_utc'), TZ))

# Ascending order + limit keeps the earliest LIMIT_ROWS bars.
ohlc_s = ohlc_s.orderBy('open_time').limit(LIMIT_ROWS)

df = ohlc_s.toPandas()

# Coerce price/volume columns to numeric; unparseable values become NaN.
numeric_cols = ['open','high','low','close','volume']
df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors='coerce')

# Re-sort on the pandas side and reset to a clean positional index, which
# the backtest relies on for bar positions.
df = df.sort_values('open_time', ignore_index=True)
df = add_features(df)
df = add_signals(df)
df.tail()


In [None]:
# Regime buckets on the full series: quintiles of ATR% over every loaded bar.
# Best effort: if the quantile edges cannot be formed (e.g. too few distinct
# values), every bar gets a missing bucket instead of aborting the run.
try:
    df['atrp_bucket'] = pd.qcut(df['atrp_14'], q=5, labels=['Q1_low','Q2','Q3','Q4','Q5_high'])
except Exception:
    df['atrp_bucket'] = None

tr = backtest_long(df, 'sig_breakout_long', DEFAULT_PARAMS)
if tr.empty:
    # Nothing to label or summarise; indexing the trade columns would fail.
    print('No trades generated for this symbol/parameter set.')
else:
    # Look up each trade's bucket by its entry-bar position. Missing buckets
    # (warm-up bars, failed qcut) stay None rather than the strings
    # 'nan' / 'None' that str() would produce.
    entry_buckets = df['atrp_bucket'].iloc[tr['entry_i'].to_numpy()]
    tr['atrp_bucket'] = [None if pd.isna(b) else str(b) for b in entry_buckets]
    display(tr.head(50))
    display(tr['outcome'].value_counts())
    display(tr['r_multiple'].describe())


# Multi-symbol backtest (Spark applyInPandas)
Universe is read from `clean_universe`.

In [None]:
UNIVERSE_INTERVAL = '1h'
UNIVERSE_LIMIT = 200  # set None for full clean_universe

# distinct(): a repeated (exchange, symbol, interval) row would duplicate every
# bar of that market in the later join against `ohlc`.
universe = (spark.table('clean_universe')
  .where(F.col('interval') == UNIVERSE_INTERVAL)
  .select('exchange','symbol','interval')
  .distinct()
)
if UNIVERSE_LIMIT is not None:
    # `universe` is lazy and is re-evaluated by every action that uses it
    # (display, count, the backtest join). Without an ordering, limit() may
    # pick a different subset each time, so sort first to pin the selection.
    universe = universe.orderBy('exchange', 'symbol').limit(int(UNIVERSE_LIMIT))

display(universe)
print('Universe size:', universe.count())


In [None]:
# Output schema of the per-group backtest. Field order here is the column
# order of the returned pandas frames; only run_id is declared non-nullable.
trade_schema = (
    T.StructType()
    # Run / market identification
    .add('run_id', T.StringType(), False)
    .add('exchange', T.StringType(), True)
    .add('symbol', T.StringType(), True)
    .add('interval', T.StringType(), True)
    .add('signal', T.StringType(), True)
    # Parameter-grid point that produced the trade
    .add('atr_mult_stop', T.DoubleType(), True)
    .add('atr_mult_tp', T.DoubleType(), True)
    .add('max_hold_bars', T.IntegerType(), True)
    # Volatility regime at entry
    .add('atrp_entry', T.DoubleType(), True)
    .add('atrp_bucket', T.StringType(), True)
    # Trade timing
    .add('entry_time_utc', T.TimestampType(), True)
    .add('entry_time_syd', T.TimestampType(), True)
    .add('exit_time_utc', T.TimestampType(), True)
    # Prices and outcome
    .add('entry', T.DoubleType(), True)
    .add('stop', T.DoubleType(), True)
    .add('tp', T.DoubleType(), True)
    .add('exit', T.DoubleType(), True)
    .add('bars_held', T.IntegerType(), True)
    .add('outcome', T.StringType(), True)
    .add('r_multiple', T.DoubleType(), True)
)

def backtest_group(pdf: pd.DataFrame) -> pd.DataFrame:
    """Backtest every signal / parameter combination for one market.

    Intended as the ``applyInPandas`` function for bars grouped by
    (exchange, symbol, interval): it sorts the bars, derives timestamps,
    features and signals, assigns ATR% quintile buckets over the group's
    whole history, and runs ``backtest_long`` for each entry of ``SIGNALS``
    and ``PARAM_GRID``.

    Args:
        pdf: Bars of a single group with ``open_time`` in epoch milliseconds
            and ``open``/``high``/``low``/``close``/``volume`` columns.

    Returns:
        One row per simulated trade, with exactly the columns of
        ``trade_schema`` in schema order; an empty frame with those columns
        when the group has no bars or no trades.
    """
    out_cols = [f.name for f in trade_schema.fields]
    if pdf.empty:
        return pd.DataFrame(columns=out_cols)

    pdf = pdf.sort_values('open_time').reset_index(drop=True)
    for c in ['open','high','low','close','volume']:
        pdf[c] = pd.to_numeric(pdf[c], errors='coerce')

    pdf['open_dt_utc'] = pd.to_datetime(pd.to_numeric(pdf['open_time'], errors='coerce'), unit='ms', utc=True)
    # NOTE(review): this column is tz-aware. Spark may normalise tz-aware
    # values to the UTC instant when converting to TimestampType, which would
    # make entry_time_syd equal entry_time_utc in the output - confirm.
    pdf['open_dt_syd'] = pdf['open_dt_utc'].dt.tz_convert(TZ)

    pdf = add_signals(add_features(pdf))

    # Best effort: when quintile edges cannot be formed (e.g. a short or flat
    # series) the group is still backtested, just without regime labels.
    try:
        pdf['atrp_bucket'] = pd.qcut(pdf['atrp_14'], q=5, labels=['Q1_low','Q2','Q3','Q4','Q5_high'])
    except Exception:
        pdf['atrp_bucket'] = None

    ex = pdf['exchange'].iloc[0] if 'exchange' in pdf.columns else None
    sym = pdf['symbol'].iloc[0] if 'symbol' in pdf.columns else None
    itv = pdf['interval'].iloc[0] if 'interval' in pdf.columns else None

    out_frames = []
    for sig in SIGNALS:
        for p in PARAM_GRID:
            tr = backtest_long(pdf, sig, p)
            if tr.empty:
                continue

            # Bucket of each trade's entry bar. Missing buckets stay None so
            # they are stored as NULL rather than the literal string 'nan'.
            entry_buckets = pdf['atrp_bucket'].iloc[tr['entry_i'].to_numpy()]
            tr['atrp_bucket'] = [None if pd.isna(b) else str(b) for b in entry_buckets]

            # Attach run / market / parameter metadata. The trade columns
            # already exist on ``tr``, so only the new ones are added here:
            # DataFrame.insert() raises ValueError for a column that is
            # already present, so it cannot be used to "re-insert" them.
            tr = tr.assign(
                run_id=RUN_ID,
                exchange=ex,
                symbol=sym,
                interval=itv,
                signal=sig,
                atr_mult_stop=float(p.atr_mult_stop),
                atr_mult_tp=float(p.atr_mult_tp),
                max_hold_bars=int(p.max_hold_bars),
            )

            # Select in schema order; this also drops entry_i / exit_i.
            out_frames.append(tr[out_cols])

    if not out_frames:
        return pd.DataFrame(columns=out_cols)
    return pd.concat(out_frames, ignore_index=True)

# Keep only the bars of markets listed in the universe.
ohlc_universe = spark.table('ohlc').join(universe, on=['exchange','symbol','interval'], how='inner')
ohlc_universe = ohlc_universe.select(
    'exchange','symbol','interval','open_time','open','high','low','close','volume'
)

# Run the pandas backtest once per (exchange, symbol, interval) group.
bars_by_market = ohlc_universe.groupBy('exchange','symbol','interval')
trades_s = bars_by_market.applyInPandas(backtest_group, schema=trade_schema)

# Both statements below are Spark actions on the lazy `trades_s` plan.
display(trades_s.limit(50))
print('Trade rows:', trades_s.count())


## Write outputs to Delta

In [None]:
# Create the per-trade output table on first use.
spark.sql('''
CREATE TABLE IF NOT EXISTS backtest_trades (
  run_id STRING,
  exchange STRING,
  symbol STRING,
  interval STRING,
  signal STRING,
  atr_mult_stop DOUBLE,
  atr_mult_tp DOUBLE,
  max_hold_bars INT,
  atrp_entry DOUBLE,
  atrp_bucket STRING,
  entry_time_utc TIMESTAMP,
  entry_time_syd TIMESTAMP,
  exit_time_utc TIMESTAMP,
  entry DOUBLE,
  stop DOUBLE,
  tp DOUBLE,
  exit DOUBLE,
  bars_held INT,
  outcome STRING,
  r_multiple DOUBLE
) USING DELTA
''')

# Create the aggregated results table on first use.
spark.sql('''
CREATE TABLE IF NOT EXISTS backtest_results (
  run_id STRING,
  exchange STRING,
  symbol STRING,
  interval STRING,
  signal STRING,
  atr_mult_stop DOUBLE,
  atr_mult_tp DOUBLE,
  max_hold_bars INT,
  atrp_bucket STRING,
  n_trades BIGINT,
  win_rate DOUBLE,
  avg_r DOUBLE,
  median_r DOUBLE,
  p10_r DOUBLE,
  p90_r DOUBLE,
  avg_bars_held DOUBLE
) USING DELTA
''')

# Persist the trades. This is the action that executes the backtest plan.
(trades_s.write.mode('append').saveAsTable('backtest_trades'))

# Aggregate from the rows just persisted instead of from `trades_s`.
# `trades_s` is lazy, so every further action on it would re-run the whole
# multi-symbol backtest (once for the results write, once more for the
# display below), and the summary could then describe a different
# evaluation than the trades that were stored.
# NOTE: re-running this cell with the same RUN_ID appends the trades again,
# which doubles the counts aggregated here.
run_trades = spark.table('backtest_trades').where(F.col('run_id') == RUN_ID)

# One summary row per market / signal / parameter set / ATR% bucket.
results_s = (run_trades
  .groupBy('run_id','exchange','symbol','interval','signal','atr_mult_stop','atr_mult_tp','max_hold_bars','atrp_bucket')
  .agg(
    F.count('*').alias('n_trades'),
    # A trade counts as a win when its R multiple is positive.
    F.avg(F.when(F.col('r_multiple') > 0, 1.0).otherwise(0.0)).alias('win_rate'),
    F.avg('r_multiple').alias('avg_r'),
    F.expr('percentile_approx(r_multiple, 0.5)').alias('median_r'),
    F.expr('percentile_approx(r_multiple, 0.1)').alias('p10_r'),
    F.expr('percentile_approx(r_multiple, 0.9)').alias('p90_r'),
    F.avg('bars_held').alias('avg_bars_held')
  )
)

(results_s.write.mode('append').saveAsTable('backtest_results'))

display(results_s.orderBy(F.col('avg_r').desc()).limit(200))
print('Wrote run_id', RUN_ID)
