# Databricks — Momentum Swing Backtest Starter (Sydney reporting)

This notebook supports two workflows:
1) **Single-symbol iteration** (fast signal development)
2) **Multi-symbol backtest** at scale using Spark `groupBy(...).applyInPandas(...)`

Timezone:
- Data is stored as epoch ms. We convert to UTC timestamps, and add a Sydney column for reporting (AEST/AEDT DST-aware).
- Strategy logic should generally use UTC ordering to avoid DST edge cases.


In [None]:
%sql
-- Select the default catalog and schema for this session; later cells refer to
-- tables (ohlc, backtest_trades, backtest_results) by unqualified name.
USE CATALOG `workspace`;
USE SCHEMA `squeeze`;


In [None]:
import uuid
import numpy as np
import pandas as pd
from dataclasses import dataclass

from pyspark.sql import functions as F
from pyspark.sql import types as T

# Time zone name used when deriving the Sydney reporting columns.
TZ = 'Australia/Sydney'
# Widen the pandas display so wide feature/trade frames are not truncated.
pd.set_option('display.max_columns', 200)
pd.set_option('display.width', 140)

# A new random hex identifier is generated every time this cell runs; it is
# stamped on the trade rows produced by the multi-symbol backtest.
RUN_ID = uuid.uuid4().hex
print('RUN_ID =', RUN_ID)


## Strategy parameters

In [None]:
@dataclass
class BacktestParams:
    """Exit-rule parameters consumed by `backtest_long`.

    Attributes:
        atr_mult_stop: Stop distance below the entry price, in multiples of ATR.
        atr_mult_tp: Take-profit distance above the entry price, in multiples of ATR.
        max_hold_bars: Maximum number of bars after the entry bar before a
            trade is closed at that bar's close (time exit).
    """
    atr_mult_stop: float = 2.0
    atr_mult_tp: float = 3.0
    max_hold_bars: int = 48

# Parameter set passed to the backtest calls below (values equal the dataclass defaults).
PARAMS = BacktestParams(atr_mult_stop=2.0, atr_mult_tp=3.0, max_hold_bars=48)
# Signal columns to backtest; each name must be a column produced by add_signals().
SIGNALS = ['sig_breakout_long']


## Single-symbol iteration (optional)
Use this section to iterate quickly on features/signals for one symbol, then run the multi-symbol section.

In [None]:
# Filters selecting the single market used for quick iteration.
EXCHANGE = 'binance'
SYMBOL = 'BTCUSDT'
INTERVAL = '1h'
# Maximum number of bars kept before the data is pulled into pandas.
LIMIT_ROWS = 20000


In [None]:
# Bars for the single configured market, with a UTC timestamp derived from the
# epoch-millisecond `open_time` and a Sydney-rendered copy for reporting.
# NOTE(review): ordering ascending by open_time and then limiting keeps the
# OLDEST LIMIT_ROWS bars, not the most recent ones — confirm this is intended.
ohlc_s = (spark.table('ohlc')
  .where((F.col('exchange')==EXCHANGE) & (F.col('symbol')==SYMBOL) & (F.col('interval')==INTERVAL))
  .select('exchange','symbol','interval','open_time','close_time','open','high','low','close','volume')
  # open_time is divided by 1000 (ms -> s) before conversion to a timestamp
  .withColumn('open_dt_utc', F.to_timestamp(F.col('open_time')/1000))
  .withColumn('open_dt_syd', F.from_utc_timestamp(F.col('open_dt_utc'), TZ))
  .orderBy('open_time')
  .limit(LIMIT_ROWS)
)
# Preview only the first few rows.
display(ohlc_s.limit(5))


In [None]:
# Bring the capped Spark slice to the driver, coerce the price/volume columns
# to numeric (unparseable values become NaN) and order the bars by open time.
numeric_cols = ['open', 'high', 'low', 'close', 'volume']
df = ohlc_s.toPandas()
df = (
    df.assign(**{col: pd.to_numeric(df[col], errors='coerce') for col in numeric_cols})
      .sort_values('open_time')
      .reset_index(drop=True)
)
df.tail()


In [None]:
def ema(s: pd.Series, span: int) -> pd.Series:
    """Return the exponential moving average of `s`.

    Uses the recursive (``adjust=False``) form of pandas' exponentially
    weighted mean with the given `span`.
    """
    weighted = s.ewm(span=span, adjust=False)
    return weighted.mean()

def atr(high: pd.Series, low: pd.Series, close: pd.Series, n: int = 14) -> pd.Series:
    """Return the average true range as a simple rolling mean over `n` bars.

    The true range of a bar is the largest of: high - low, |high - previous
    close| and |low - previous close|. The result is NaN until `n` true-range
    values are available.
    """
    previous_close = close.shift(1)
    bar_range = (high - low).abs()
    high_gap = (high - previous_close).abs()
    low_gap = (low - previous_close).abs()
    # Row-wise max skips NaN, so the first bar (no previous close) falls back to high - low.
    true_range = pd.concat([bar_range, high_gap, low_gap], axis=1).max(axis=1)
    return true_range.rolling(n, min_periods=n).mean()

def add_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` with trend, volatility and channel features added.

    Added columns: ema_20, ema_50, trend_up (ema_20 above ema_50), atr_14,
    atrp_14 (ATR as a fraction of close), hh_20 / ll_20 (rolling 20-bar highest
    high / lowest low, including the current bar). The input is not modified.
    """
    lookback = 20
    close, high, low = df['close'], df['high'], df['low']

    fast_ema = ema(close, 20)
    slow_ema = ema(close, 50)
    atr_14 = atr(high, low, close, 14)

    out = df.copy()
    out['ema_20'] = fast_ema
    out['ema_50'] = slow_ema
    out['trend_up'] = fast_ema > slow_ema
    out['atr_14'] = atr_14
    out['atrp_14'] = atr_14 / close
    out['hh_20'] = high.rolling(lookback, min_periods=lookback).max()
    out['ll_20'] = low.rolling(lookback, min_periods=lookback).min()
    return out

def add_signals(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` with the boolean `sig_breakout_long` column added.

    The signal is true when `trend_up` holds and the close is above the
    previous bar's `hh_20` value. The input frame is not modified.
    """
    prior_channel_high = df['hh_20'].shift(1)
    broke_out = df['close'] > prior_channel_high
    return df.assign(sig_breakout_long=df['trend_up'] & broke_out)

# Recompute features and signals on the single-symbol frame. Re-running this
# cell overwrites the derived columns from the OHLC columns already in `df`.
df = add_signals(add_features(df))
# Show the last 50 bars with the columns most useful for eyeballing the signal.
df[['open_dt_syd','close','ema_20','ema_50','atr_14','hh_20','sig_breakout_long']].tail(50)


## Backtest core (pandas)
Used both for single-symbol iteration and in Spark `applyInPandas`.

In [None]:
def backtest_long(df: pd.DataFrame, signal_col: str, p: "BacktestParams") -> pd.DataFrame:
    """Simulate one long trade for every bar where `signal_col` is truthy.

    Rules:
      * A signal on bar ``i`` enters at the open of bar ``i + 1``.
      * Stop and take-profit are placed ``p.atr_mult_stop`` / ``p.atr_mult_tp``
        ATRs away from the entry, using ``atr_14`` of the SIGNAL bar. The entry
        bar's ATR depends on that bar's own high/low/close, which are not known
        at its open, so using it would leak future information.
      * From the entry bar onward, a bar whose low reaches the stop exits at the
        stop; otherwise a bar whose high reaches the target exits at the target.
        The stop is checked first, so a bar touching both counts as a stop.
      * If neither level is hit within ``p.max_hold_bars`` bars after the entry
        bar (or before the data ends), the trade exits at that bar's close.
      * Signals are evaluated independently: a new signal opens a new trade
        even if an earlier one is still open.

    Args:
        df: Bars ordered oldest to newest with columns open, high, low, close,
            atr_14, open_dt_utc, open_dt_syd and `signal_col`.
        signal_col: Name of the boolean signal column.
        p: Object exposing atr_mult_stop, atr_mult_tp and max_hold_bars.

    Returns:
        One row per trade with the columns listed in `columns` below;
        `r_multiple` is the profit expressed in units of initial risk
        (entry - stop). An empty frame with those columns if there are no trades.
    """
    columns = [
        'entry_i', 'exit_i', 'entry_time_utc', 'entry_time_syd', 'exit_time_utc',
        'entry', 'stop', 'tp', 'exit', 'bars_held', 'outcome', 'r_multiple',
    ]

    # Extract the numeric columns once; per-element .iloc in the loops is slow.
    signal = df[signal_col].to_numpy()
    opens = df['open'].to_numpy(dtype=float)
    highs = df['high'].to_numpy(dtype=float)
    lows = df['low'].to_numpy(dtype=float)
    closes = df['close'].to_numpy(dtype=float)
    atrs = df['atr_14'].to_numpy(dtype=float)

    rows = []
    n = len(df)
    for i in range(n - 2):
        if not bool(signal[i]):
            continue
        entry_i = i + 1  # enter next bar open (avoid lookahead)
        entry = float(opens[entry_i])
        atrv = float(atrs[i])  # ATR known at signal time, not the entry bar's
        if not np.isfinite(entry) or not np.isfinite(atrv) or atrv <= 0:
            continue
        stop = entry - p.atr_mult_stop * atrv
        tp = entry + p.atr_mult_tp * atrv
        last_i = min(n - 1, entry_i + p.max_hold_bars)

        exit_i = None
        exit_px = None
        outcome = None

        # Conservative intrabar ordering for long: stop can trigger before TP within bar
        for j in range(entry_i, last_i + 1):
            if lows[j] <= stop:
                exit_i, exit_px, outcome = j, stop, 'stop'
                break
            if highs[j] >= tp:
                exit_i, exit_px, outcome = j, tp, 'tp'
                break

        if exit_i is None:
            exit_i, exit_px, outcome = last_i, float(closes[last_i]), 'time'

        risk = entry - stop
        r = (exit_px - entry) / risk if risk != 0 else np.nan

        rows.append({
            'entry_i': entry_i,
            'exit_i': exit_i,
            'entry_time_utc': df['open_dt_utc'].iloc[entry_i],
            'entry_time_syd': df['open_dt_syd'].iloc[entry_i],
            'exit_time_utc': df['open_dt_utc'].iloc[exit_i],
            'entry': entry,
            'stop': stop,
            'tp': tp,
            'exit': exit_px,
            'bars_held': int(exit_i - entry_i),
            'outcome': outcome,
            'r_multiple': float(r),
        })

    return pd.DataFrame(rows, columns=columns)

def summarize(trades: pd.DataFrame) -> pd.DataFrame:
    """Return a one-row frame of summary statistics for a set of trades.

    For an empty input the row contains only ``n_trades = 0``. Otherwise it
    holds the trade count, the share of trades with a positive R multiple, the
    mean / median / 10th / 90th percentile R multiple and the mean bars held.
    """
    if trades.empty:
        return pd.DataFrame([{'n_trades': 0}])

    r_mult = trades['r_multiple']
    stats = {
        'n_trades': int(len(trades)),
        'win_rate': float(r_mult.gt(0).mean()),
        'avg_r': float(r_mult.mean()),
        'median_r': float(r_mult.median()),
        'p10_r': float(r_mult.quantile(0.10)),
        'p90_r': float(r_mult.quantile(0.90)),
        'avg_bars_held': float(trades['bars_held'].mean()),
    }
    return pd.DataFrame([stats])

# Single-symbol quick check: run the breakout signal through the backtest on
# the pandas frame built above, then show the first trades and the summary row.
trades = backtest_long(df, 'sig_breakout_long', PARAMS)
display(trades.head(20))
display(summarize(trades))


# Multi-symbol backtest (Spark)
This section backtests a whole universe by grouping OHLC per `(exchange, symbol, interval)` and running your pandas logic with `applyInPandas`.

Tip: Start with a small universe (e.g., top 20 by recent volume) until the strategy logic is stable.

In [None]:
# Universe selection settings: keep the top-N markets by summed volume on one
# interval, requiring at least MIN_BARS bars inside the ranking window.
UNIVERSE_TOP_N = 30
UNIVERSE_INTERVAL = '1h'
MIN_BARS = 500

# Ranking window: only bars whose open_time falls within the last
# DAYS_FOR_RANK days (expressed in epoch milliseconds) count towards the ranking.
DAYS_FOR_RANK = 180
MS_PER_DAY = 86400000
# "Now" is taken from Spark's current_timestamp() and collected once as a plain
# integer, so the cutoff below is a fixed literal for the rest of the run.
now_ms = int(spark.sql('SELECT CAST(unix_millis(current_timestamp()) AS BIGINT) AS now_ms').collect()[0]['now_ms'])
min_open_time = now_ms - DAYS_FOR_RANK * MS_PER_DAY

# Bars on the chosen interval inside the ranking window.
base = (spark.table('ohlc')
  .where(F.col('interval') == UNIVERSE_INTERVAL)
  .where(F.col('open_time') >= F.lit(min_open_time))
  .select('exchange','symbol','interval','open_time','close_time','open','high','low','close','volume')
)

# One row per market with its summed volume and bar count, filtered and ranked.
universe = (base
  .groupBy('exchange','symbol','interval')
  .agg(
    F.sum(F.col('volume').cast('double')).alias('vol_sum'),
    F.count('*').alias('bars')
  )
  .where(F.col('bars') >= MIN_BARS)
  .orderBy(F.col('vol_sum').desc())
  .limit(UNIVERSE_TOP_N)
)

# `universe` is not cached here, so display() and count() each run the aggregation.
display(universe)
print('Universe size:', universe.count())


## Run grouped backtest via applyInPandas

In [None]:
# Schema for trade-level output written to Delta
trade_schema = T.StructType([
    T.StructField('run_id', T.StringType(), False),
    T.StructField('exchange', T.StringType(), True),
    T.StructField('symbol', T.StringType(), True),
    T.StructField('interval', T.StringType(), True),
    T.StructField('signal', T.StringType(), True),
    T.StructField('entry_time_utc', T.TimestampType(), True),
    T.StructField('entry_time_syd', T.TimestampType(), True),
    T.StructField('exit_time_utc', T.TimestampType(), True),
    T.StructField('entry', T.DoubleType(), True),
    T.StructField('stop', T.DoubleType(), True),
    T.StructField('tp', T.DoubleType(), True),
    T.StructField('exit', T.DoubleType(), True),
    T.StructField('bars_held', T.IntegerType(), True),
    T.StructField('outcome', T.StringType(), True),
    T.StructField('r_multiple', T.DoubleType(), True),
])

def backtest_group(pdf: pd.DataFrame) -> pd.DataFrame:
    """Backtest every configured signal for one (exchange, symbol, interval) group.

    Intended for ``groupBy(...).applyInPandas(backtest_group, schema=trade_schema)``.
    The bars are sorted by open_time, price/volume columns are coerced to
    numeric, time columns are derived, features and signals are added, and
    `backtest_long` is run once per name in SIGNALS.

    Args:
        pdf: All OHLC rows of a single group.

    Returns:
        Trade rows with the columns of `trade_schema`, in schema order; an
        empty frame with those columns if the group yields no trades.
    """
    out_cols = [f.name for f in trade_schema.fields]
    if pdf.empty:
        return pd.DataFrame(columns=out_cols)

    # ensure ordering and dtypes
    pdf = pdf.sort_values('open_time').reset_index(drop=True)
    for c in ['open', 'high', 'low', 'close', 'volume']:
        pdf[c] = pd.to_numeric(pdf[c], errors='coerce')

    # time columns
    pdf['open_dt_utc'] = pd.to_datetime(pd.to_numeric(pdf['open_time'], errors='coerce'), unit='ms', utc=True)
    # Sydney wall-clock time as a NAIVE timestamp. A tz-aware value would be
    # normalised back to its UTC instant when returned to Spark, making
    # entry_time_syd identical to entry_time_utc. Dropping the tz after the
    # conversion keeps the Sydney wall-clock value, consistent with the
    # from_utc_timestamp() column in the single-symbol section.
    # NOTE(review): Spark interprets naive values in the session time zone;
    # confirm behaviour if spark.sql.session.timeZone is not UTC.
    pdf['open_dt_syd'] = pdf['open_dt_utc'].dt.tz_convert(TZ).dt.tz_localize(None)

    pdf = add_signals(add_features(pdf))

    # Group keys are constant within the group, so read them once.
    ex = pdf['exchange'].iloc[0] if 'exchange' in pdf.columns else None
    sym = pdf['symbol'].iloc[0] if 'symbol' in pdf.columns else None
    itv = pdf['interval'].iloc[0] if 'interval' in pdf.columns else None

    frames = []
    for sig in SIGNALS:
        tr = backtest_long(pdf, sig, PARAMS)
        if tr.empty:
            continue
        # enrich with run/group identifiers, then project to the schema columns
        tr = tr.assign(run_id=RUN_ID, exchange=ex, symbol=sym, interval=itv, signal=sig)
        frames.append(tr[out_cols])

    if not frames:
        return pd.DataFrame(columns=out_cols)
    return pd.concat(frames, ignore_index=True)

# Full OHLC history (not just the ranking window) for the selected markets:
# the inner join keeps only rows whose (exchange, symbol, interval) is in `universe`.
ohlc_universe = (spark.table('ohlc')
  .join(universe.select('exchange','symbol','interval'), on=['exchange','symbol','interval'], how='inner')
  .select('exchange','symbol','interval','open_time','close_time','open','high','low','close','volume')
)

# One backtest_group call per market; the output rows follow trade_schema.
trades_s = (ohlc_universe
  .groupBy('exchange','symbol','interval')
  .applyInPandas(backtest_group, schema=trade_schema)
)

# trades_s is lazy and not cached here, so each action below (display, count)
# re-executes the grouped backtest.
display(trades_s.limit(50))
print('Trade rows:', trades_s.count())


## Write outputs to Delta + compute summary stats

In [None]:
# Ensure tables exist (Delta). You can change names if you prefer.
spark.sql('''
CREATE TABLE IF NOT EXISTS backtest_trades (
  run_id STRING,
  exchange STRING,
  symbol STRING,
  interval STRING,
  signal STRING,
  entry_time_utc TIMESTAMP,
  entry_time_syd TIMESTAMP,
  exit_time_utc TIMESTAMP,
  entry DOUBLE,
  stop DOUBLE,
  tp DOUBLE,
  exit DOUBLE,
  bars_held INT,
  outcome STRING,
  r_multiple DOUBLE
) USING DELTA
''')

spark.sql('''
CREATE TABLE IF NOT EXISTS backtest_results (
  run_id STRING,
  exchange STRING,
  symbol STRING,
  interval STRING,
  signal STRING,
  n_trades BIGINT,
  win_rate DOUBLE,
  avg_r DOUBLE,
  median_r DOUBLE,
  p10_r DOUBLE,
  p90_r DOUBLE,
  avg_bars_held DOUBLE
) USING DELTA
''')

# Predicate identifying this run's rows. RUN_ID is a uuid4 hex string generated
# by this notebook, so embedding it in the predicate is safe.
run_predicate = f"run_id = '{RUN_ID}'"

# Write trades for this run. A selective overwrite (replaceWhere) replaces only
# rows with this run_id: for a new run it behaves like an append, and re-running
# this cell with the same RUN_ID does not duplicate rows.
(trades_s
  .write
  .format('delta')
  .mode('overwrite')
  .option('replaceWhere', run_predicate)
  .saveAsTable('backtest_trades')
)

# Build summary per symbol+signal from the trades just persisted. Reading them
# back avoids re-running the whole grouped backtest for every later action and
# guarantees the summary matches what is stored in backtest_trades.
trades_run = spark.table('backtest_trades').where(F.col('run_id') == RUN_ID)

results_s = (trades_run
  .groupBy('run_id','exchange','symbol','interval','signal')
  .agg(
    F.count('*').alias('n_trades'),
    F.avg(F.when(F.col('r_multiple') > 0, 1.0).otherwise(0.0)).alias('win_rate'),
    F.avg('r_multiple').alias('avg_r'),
    F.expr('percentile_approx(r_multiple, 0.5)').alias('median_r'),
    F.expr('percentile_approx(r_multiple, 0.1)').alias('p10_r'),
    F.expr('percentile_approx(r_multiple, 0.9)').alias('p90_r'),
    F.avg('bars_held').alias('avg_bars_held'),
  )
)

# Same per-run selective overwrite for the summary table.
(results_s
  .write
  .format('delta')
  .mode('overwrite')
  .option('replaceWhere', run_predicate)
  .saveAsTable('backtest_results')
)

display(results_s.orderBy(F.col('avg_r').desc()).limit(100))
print('Wrote run_id', RUN_ID)
