# Databricks ML — 4h Per-Bar Barrier Classifier (Full Feature Engineering)

- Grain: **every 4h bar** (per-bar dataset)
- Label: **barrier outcome** (TP hit before SL within W bars)

This notebook builds a supervised dataset from `ohlc` (filtered by `clean_universe`) and trains a classifier in Spark ML.

Tables assumed (Unity Catalog):
- `workspace.squeeze.ohlc`
- `workspace.squeeze.clean_universe`

Outputs:
- `ml_barrier_dataset_4h` (features + label)
- `ml_barrier_predictions_4h` (scored rows with probability)

Notes:
- **No lookahead**: all features are lagged/rolling with correct window framing.
- **Sydney time** is included for time-of-day features, but ordering uses UTC.


In [None]:
%sql
-- Select the default catalog and schema for this session. Later cells refer to
-- tables by unqualified name (`ohlc`, `clean_universe`, `ml_barrier_*`), which is
-- intended to resolve to workspace.squeeze.
USE CATALOG `workspace`;
USE SCHEMA `squeeze`;


In [None]:
import uuid
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Random 32-character hex id; stamped onto every row this execution writes.
RUN_ID = uuid.uuid4().hex
# Timezone used to derive the Sydney-local timestamp columns.
TZ = 'Australia/Sydney'
print('RUN_ID =', RUN_ID)


## Config

In [None]:
# Bar interval used to filter both `clean_universe` and `ohlc`.
INTERVAL = '4h'
UNIVERSE_LIMIT = 300  # set None for full clean_universe

# Barrier labeling parameters
ATR_N = 14      # rolling window (bars) for the ATR
TP_ATR = 3.0    # take-profit distance, in ATR multiples above entry
SL_ATR = 2.0    # stop-loss distance, in ATR multiples below entry
W_BARS = 24  # lookahead window in bars (24*4h = 4 days)

# OHLC ambiguity policy when both TP and SL are hit in the *same bar*.
# Options:
# - 'tp_first'     : count as win
# - 'sl_first'     : count as loss
# - 'discard_both' : set label NULL (remove ambiguous samples)
AMBIGUITY_POLICY = 'discard_both'


# Feature windows
EMA_FAST = 20
EMA_SLOW = 50
ROLL_HH_N = 20
ROLL_LL_N = 20
RET_VOL_N = 20

# Split boundaries (UTC timestamps). Tune as you like.
TRAIN_END = '2024-01-01'
VALID_END = '2025-01-01'

# Sampling / class balance
# NOTE(review): not referenced by any cell visible in this notebook — confirm whether it is still needed.
MAX_ROWS_PER_SYMBOL = None  # optionally cap

# Fail fast on configuration mistakes. Without this check a mistyped policy
# (e.g. 'tp-first') matches none of the label branches and silently behaves
# like 'discard_both'.
_VALID_AMBIGUITY_POLICIES = ('tp_first', 'sl_first', 'discard_both')
if AMBIGUITY_POLICY not in _VALID_AMBIGUITY_POLICIES:
  raise ValueError(f'AMBIGUITY_POLICY must be one of {_VALID_AMBIGUITY_POLICIES}, got {AMBIGUITY_POLICY!r}')
if UNIVERSE_LIMIT is not None and int(UNIVERSE_LIMIT) <= 0:
  raise ValueError('UNIVERSE_LIMIT must be a positive integer or None')
if min(ATR_N, W_BARS, EMA_FAST, EMA_SLOW, ROLL_HH_N, ROLL_LL_N, RET_VOL_N) < 1:
  raise ValueError('all window lengths must be >= 1')
if TP_ATR <= 0 or SL_ATR <= 0:
  raise ValueError('TP_ATR and SL_ATR must be positive ATR multiples')
# ISO-8601 date strings compare chronologically as plain strings.
if not TRAIN_END < VALID_END:
  raise ValueError('TRAIN_END must be earlier than VALID_END')


## Load OHLC for clean universe (4h)

In [None]:
# Symbols to process: the clean universe restricted to the configured interval.
# distinct() guards against duplicate universe rows, which would multiply OHLC
# rows in the inner join below and corrupt the per-symbol window features.
universe = (spark.table('clean_universe')
  .where(F.col('interval') == INTERVAL)
  .select('exchange','symbol','interval')
  .distinct()
)
if UNIVERSE_LIMIT is not None:
  # limit() without an ordering may return a different subset each time the plan
  # is evaluated. `universe` is evaluated more than once (here and again when the
  # 1d bars are loaded), so sort first to select the same symbols every time.
  universe = universe.orderBy('exchange','symbol').limit(int(UNIVERSE_LIMIT))

# OHLC bars for the selected symbols, plus UTC and Sydney-local timestamps.
# `open_time` is divided by 1000 before conversion, i.e. it is treated as epoch milliseconds.
ohlc = (spark.table('ohlc')
  .join(universe, on=['exchange','symbol','interval'], how='inner')
  .select('exchange','symbol','interval','open_time','open','high','low','close','volume')
  .withColumn('open_dt_utc', F.to_timestamp(F.col('open_time')/1000))
  .withColumn('open_dt_syd', F.from_utc_timestamp(F.col('open_dt_utc'), TZ))
  .withColumn('hour_syd', F.hour('open_dt_syd'))
  .withColumn('dow_syd', F.date_format('open_dt_syd', 'E'))
)

display(ohlc.limit(5))
print('rows:', ohlc.count())


## Feature engineering in Spark (no leakage)
We compute lagged/rolling features using window frames that end at the current row.

In [None]:
# Per-series window ordered by bar open time; every rolling frame below ends at
# the current row (or the previous one), so no future bar is read.
w = Window.partitionBy('exchange','symbol','interval').orderBy('open_time')
w_prev = w.rowsBetween(Window.unboundedPreceding, -1)  # 'all history up to prev row' for some ops

def _safe_ratio(num, den):
  """Return num / den where den > 0, else NULL.

  Matches the guard already used for body_to_range and vol_z, and avoids a
  DIVIDE_BY_ZERO error when Spark runs with ANSI mode enabled.
  """
  return F.when(den > 0, num / den)

# Basic returns and candle anatomy
feat = (ohlc
  .withColumn('close_d', F.col('close').cast('double'))
  .withColumn('open_d', F.col('open').cast('double'))
  .withColumn('high_d', F.col('high').cast('double'))
  .withColumn('low_d', F.col('low').cast('double'))
  .withColumn('vol_d', F.col('volume').cast('double'))
  .withColumn('prev_close', F.lag('close_d', 1).over(w))
  .withColumn('ret_1', F.when(F.col('prev_close') > 0, F.log(F.col('close_d') / F.col('prev_close'))))
  .withColumn('range', F.col('high_d') - F.col('low_d'))
  .withColumn('body', F.abs(F.col('close_d') - F.col('open_d')))
  .withColumn('upper_wick', F.col('high_d') - F.greatest(F.col('open_d'), F.col('close_d')))
  .withColumn('lower_wick', F.least(F.col('open_d'), F.col('close_d')) - F.col('low_d'))
  .withColumn('body_to_range', _safe_ratio(F.col('body'), F.col('range')))
  .withColumn('gap', F.col('open_d') - F.col('prev_close'))
)

# ATR (simple moving average of true range).
# greatest() skips NULLs, so the first bar of a series (no prev_close) falls back to high - low.
tr = F.greatest(
  (F.col('high_d') - F.col('low_d')).cast('double'),
  F.abs(F.col('high_d') - F.col('prev_close')),
  F.abs(F.col('low_d') - F.col('prev_close'))
)
feat = (feat
  .withColumn('tr', tr)
  .withColumn('atr', F.avg('tr').over(w.rowsBetween(-(ATR_N-1), 0)))
  .withColumn('atrp', _safe_ratio(F.col('atr'), F.col('close_d')))
)

# Spark SQL has no built-in exponentially weighted mean, so simple moving averages
# stand in for the EMAs here as a baseline trend measure.
feat = (feat
  .withColumn('sma_fast', F.avg('close_d').over(w.rowsBetween(-(EMA_FAST-1), 0)))
  .withColumn('sma_slow', F.avg('close_d').over(w.rowsBetween(-(EMA_SLOW-1), 0)))
  .withColumn('trend_up', (F.col('sma_fast') > F.col('sma_slow')).cast('int'))
  .withColumn('trend_strength', _safe_ratio(F.col('sma_fast') - F.col('sma_slow'), F.col('atr')))
)

# Rolling highs/lows (shifted by 1 to avoid lookahead for breakout context)
feat = (feat
  .withColumn('hh_n', F.max('high_d').over(w.rowsBetween(-(ROLL_HH_N), -1)))
  .withColumn('ll_n', F.min('low_d').over(w.rowsBetween(-(ROLL_LL_N), -1)))
  .withColumn('dist_to_hh_atr', _safe_ratio(F.col('close_d') - F.col('hh_n'), F.col('atr')))
  .withColumn('dist_to_ll_atr', _safe_ratio(F.col('close_d') - F.col('ll_n'), F.col('atr')))
)

# Volatility and volume regime
w_vol = w.rowsBetween(-(RET_VOL_N-1), 0)
feat = (feat
  .withColumn('ret_vol', F.stddev('ret_1').over(w_vol))
  .withColumn('range_sma', F.avg('range').over(w_vol))
  .withColumn('vol_sma', F.avg('vol_d').over(w_vol))
  .withColumn('vol_std', F.stddev('vol_d').over(w_vol))
  .withColumn('vol_z', _safe_ratio(F.col('vol_d') - F.col('vol_sma'), F.col('vol_std')))
)

# Zero-based bar index per series, used for the forward barrier label joins
feat = feat.withColumn('bar_idx', (F.row_number().over(w) - 1).cast('long'))

display(feat.select('exchange','symbol','open_dt_syd','close','atr','atrp','trend_up','dist_to_hh_atr','vol_z').limit(10))


## Enhanced technical features (EMA/RSI/ADX) via applyInPandas
Spark doesn’t provide native EMA/RSI/ADX operators. For stronger features we compute them per symbol using `groupBy(...).applyInPandas(...)`.

This keeps the pipeline scalable while preserving true indicator definitions.


In [None]:
import pandas as pd
import numpy as np
from pyspark.sql import types as T

def _ema(s: pd.Series, span: int) -> pd.Series:
    return s.ewm(span=span, adjust=False).mean()

def _rsi(close: pd.Series, n: int = 14) -> pd.Series:
    delta = close.diff()
    up = delta.clip(lower=0)
    down = (-delta).clip(lower=0)
    roll_up = up.ewm(alpha=1/n, adjust=False).mean()
    roll_down = down.ewm(alpha=1/n, adjust=False).mean()
    rs = roll_up / roll_down
    return 100 - (100 / (1 + rs))

def _adx(high: pd.Series, low: pd.Series, close: pd.Series, n: int = 14) -> pd.Series:
    high_shift = high.shift(1)
    low_shift = low.shift(1)
    close_shift = close.shift(1)
    up_move = high - high_shift
    down_move = low_shift - low
    plus_dm = np.where((up_move > down_move) & (up_move > 0), up_move, 0.0)
    minus_dm = np.where((down_move > up_move) & (down_move > 0), down_move, 0.0)
    tr = pd.concat([(high - low).abs(), (high - close_shift).abs(), (low - close_shift).abs()], axis=1).max(axis=1)
    atr = tr.ewm(alpha=1/n, adjust=False).mean()
    plus_di = 100 * (pd.Series(plus_dm, index=high.index).ewm(alpha=1/n, adjust=False).mean() / atr)
    minus_di = 100 * (pd.Series(minus_dm, index=high.index).ewm(alpha=1/n, adjust=False).mean() / atr)
    dx = (100 * (plus_di - minus_di).abs() / (plus_di + minus_di)).replace([np.inf, -np.inf], np.nan)
    adx = dx.ewm(alpha=1/n, adjust=False).mean()
    return adx

def _bb(close: pd.Series, n: int = 20, k: float = 2.0):
    mid = close.rolling(n, min_periods=n).mean()
    sd = close.rolling(n, min_periods=n).std()
    upper = mid + k*sd
    lower = mid - k*sd
    bw = (upper - lower) / mid
    return bw

def _macd(close: pd.Series, fast: int = 12, slow: int = 26, sig: int = 9):
    """Return (MACD line, signal line, histogram), all built from `_ema`."""
    fast_ema = _ema(close, fast)
    slow_ema = _ema(close, slow)
    line = fast_ema - slow_ema
    trigger = _ema(line, sig)
    return line, trigger, line - trigger

# Output schema of `compute_tech`: the four join keys followed by the
# double-typed indicator columns, in the order listed here.
_TECH_INDICATOR_COLS = [
  'ema_20', 'ema_50', 'ema_ratio', 'rsi_14', 'adx_14',
  'macd_line', 'macd_signal', 'macd_hist',
  'bb_bw_20', 'donch_pos_20',
  'mom_logret_3', 'mom_logret_6', 'mom_logret_12',
]
tech_schema = T.StructType(
  [
    T.StructField('exchange', T.StringType()),
    T.StructField('symbol', T.StringType()),
    T.StructField('interval', T.StringType()),
    T.StructField('open_time', T.LongType()),
  ]
  + [T.StructField(col_name, T.DoubleType()) for col_name in _TECH_INDICATOR_COLS]
)

def compute_tech(pdf: pd.DataFrame) -> pd.DataFrame:
  """Compute the per-bar technical indicators for one (exchange, symbol, interval) group.

  Parameters
  ----------
  pdf : pandas DataFrame holding the group's bars, with columns exchange, symbol,
        interval, open_time, close_d, high_d, low_d.

  Returns
  -------
  DataFrame with the key columns plus one float64 column per indicator, matching
  `tech_schema`. Every indicator reads only the current and earlier bars.
  Non-finite values (±inf) are converted to NaN — the same treatment `_adx`
  applies to DX — so that infinities do not reach the model features.
  """
  pdf = pdf.sort_values('open_time').reset_index(drop=True)
  for c in ['close_d','high_d','low_d']:
    pdf[c] = pd.to_numeric(pdf[c], errors='coerce')
  close = pdf['close_d']
  ema20 = _ema(close, 20)
  ema50 = _ema(close, 50)
  macd_line, macd_sig, macd_hist = _macd(close)
  # Donchian channel over the last 20 bars, including the current one.
  hh = pdf['high_d'].rolling(20, min_periods=20).max()
  ll = pdf['low_d'].rolling(20, min_periods=20).min()

  indicators = {
    'ema_20': ema20,
    'ema_50': ema50,
    'ema_ratio': ema20 / ema50,
    'rsi_14': _rsi(close, 14),
    'adx_14': _adx(pdf['high_d'], pdf['low_d'], close, 14),
    'macd_line': macd_line,
    'macd_signal': macd_sig,
    'macd_hist': macd_hist,
    'bb_bw_20': _bb(close, 20, 2.0),
    # A completely flat window (hh == ll) makes this 0/0 or x/0.
    'donch_pos_20': (close - ll) / (hh - ll),
    'mom_logret_3': np.log(close / close.shift(3)),
    'mom_logret_6': np.log(close / close.shift(6)),
    'mom_logret_12': np.log(close / close.shift(12)),
  }

  out = pdf[['exchange','symbol','interval','open_time']].copy()
  for name, series in indicators.items():
    out[name] = series.astype('float64').replace([np.inf, -np.inf], np.nan)
  return out

# Run `compute_tech` once per series and attach its indicator columns to `feat`.
_tech_keys = ['exchange','symbol','interval','open_time']
_tech_input = feat.select(*_tech_keys, 'close_d','high_d','low_d')
tech = (_tech_input
  .groupBy('exchange','symbol','interval')
  .applyInPandas(compute_tech, schema=tech_schema)
)

# Left join so every 4h bar is kept even if no indicator row comes back for it.
feat = feat.join(tech, on=_tech_keys, how='left')
display(feat.select('exchange','symbol','open_dt_syd','macd_hist','bb_bw_20','donch_pos_20','mom_logret_12').limit(10))


## Multi-timeframe context: 1d features joined onto 4h bars
We compute a small set of 1d features per symbol and join the **most recent 1d bar at or before** each 4h bar.
This provides higher-timeframe trend/regime context without lookahead.


In [None]:
# Build 1d feature table for the same (exchange,symbol) universe
# Includes 1d RSI computed per symbol via applyInPandas.
import pandas as pd
import numpy as np
from pyspark.sql import types as T

def _rsi_sma(close: pd.Series, n: int = 14) -> pd.Series:
  delta = close.diff()
  up = delta.clip(lower=0)
  down = (-delta).clip(lower=0)
  roll_up = up.rolling(n, min_periods=n).mean()
  roll_down = down.rolling(n, min_periods=n).mean()
  rs = roll_up / roll_down
  return 100 - (100 / (1 + rs))

# Daily bars for the (exchange, symbol) pairs present in the 4h universe.
_universe_pairs = universe.select('exchange','symbol').distinct()
ohlc_1d_base = (spark.table('ohlc')
  .where(F.col('interval') == '1d')
  .join(_universe_pairs, on=['exchange','symbol'], how='inner')
  .select('exchange','symbol','interval','open_time','high','low','close')
  .withColumn('close_d', F.col('close').cast('double'))
  .withColumn('high_d', F.col('high').cast('double'))
  .withColumn('low_d', F.col('low').cast('double'))
)

# Output schema of `compute_1d`: (column name, Spark type) pairs in output order.
_SCHEMA_1D_FIELDS = [
  ('exchange', T.StringType()),
  ('symbol', T.StringType()),
  ('open_time', T.LongType()),
  ('sma_20d', T.DoubleType()),
  ('sma_50d', T.DoubleType()),
  ('trend_up_1d', T.IntegerType()),
  ('ret_vol_20d', T.DoubleType()),
  ('rsi_14_1d', T.DoubleType()),
  ('macd_hist_1d', T.DoubleType()),
  ('adx_14_1d', T.DoubleType()),
]
schema_1d = T.StructType([T.StructField(col_name, col_type) for col_name, col_type in _SCHEMA_1D_FIELDS])

def _ema_1d(s: pd.Series, span: int) -> pd.Series:
  return s.ewm(span=span, adjust=False).mean()

def _macd_hist_1d(close: pd.Series) -> pd.Series:
  """MACD histogram (12/26 line minus its 9-span signal), built from `_ema_1d`."""
  fast_ema = _ema_1d(close, 12)
  slow_ema = _ema_1d(close, 26)
  line = fast_ema - slow_ema
  return line - _ema_1d(line, 9)

def _adx_1d(high: pd.Series, low: pd.Series, close: pd.Series, n: int = 14) -> pd.Series:
  # Wilder ADX on 1d
  high_shift = high.shift(1)
  low_shift = low.shift(1)
  close_shift = close.shift(1)
  up_move = high - high_shift
  down_move = low_shift - low
  plus_dm = np.where((up_move > down_move) & (up_move > 0), up_move, 0.0)
  minus_dm = np.where((down_move > up_move) & (down_move > 0), down_move, 0.0)
  tr = pd.concat([(high - low).abs(), (high - close_shift).abs(), (low - close_shift).abs()], axis=1).max(axis=1)
  atr = tr.ewm(alpha=1/n, adjust=False).mean()
  plus_di = 100 * (pd.Series(plus_dm, index=high.index).ewm(alpha=1/n, adjust=False).mean() / atr)
  minus_di = 100 * (pd.Series(minus_dm, index=high.index).ewm(alpha=1/n, adjust=False).mean() / atr)
  dx = (100 * (plus_di - minus_di).abs() / (plus_di + minus_di)).replace([np.inf, -np.inf], np.nan)
  return dx.ewm(alpha=1/n, adjust=False).mean()

def compute_1d(pdf: pd.DataFrame) -> pd.DataFrame:
  """Daily trend/volatility/momentum features for one (exchange, symbol) group.

  Bars are sorted by open_time first. Returns the key columns plus the feature
  columns declared in `schema_1d`.
  """
  bars = pdf.sort_values('open_time').reset_index(drop=True)
  closes = pd.to_numeric(bars['close_d'], errors='coerce')
  highs = pd.to_numeric(bars.get('high_d'), errors='coerce')
  lows = pd.to_numeric(bars.get('low_d'), errors='coerce')
  log_ret = np.log(closes / closes.shift(1))

  sma_20 = closes.rolling(20, min_periods=20).mean()
  sma_50 = closes.rolling(50, min_periods=50).mean()

  out = bars[['exchange','symbol','open_time']].copy()
  out['sma_20d'] = sma_20
  out['sma_50d'] = sma_50
  # Comparisons against NaN are False, so warm-up rows are reported as 0.
  out['trend_up_1d'] = (sma_20 > sma_50).astype('int64')
  out['ret_vol_20d'] = log_ret.rolling(20, min_periods=20).std()
  out['rsi_14_1d'] = _rsi_sma(closes, 14)
  out['macd_hist_1d'] = _macd_hist_1d(closes)
  out['adx_14_1d'] = _adx_1d(highs, lows, closes, 14)
  return out

# Daily features, one row per (exchange, symbol, 1d bar).
ohlc_1d = (ohlc_1d_base
  .groupBy('exchange','symbol')
  .applyInPandas(compute_1d, schema=schema_1d)
)

# As-of join of the latest *completed* 1d bar onto each 4h bar.
#
# A daily bar's features are built from its close, which is only known once the
# bar has finished. Matching on the daily open time would attach a close from up
# to ~20 hours in the future to intraday bars, so we match on the daily close
# time instead. This assumes a 1d bar spans exactly MS_1D from its open_time, and
# that both tables store open_time in epoch milliseconds.
MS_1D = 86400000
f4 = feat.alias('f4')
d1 = (ohlc_1d
  .select(
    # Keys are renamed so the joined frame has no duplicate column names; duplicates
    # would make later references to exchange/symbol/open_time ambiguous.
    F.col('exchange').alias('d1_exchange'),
    F.col('symbol').alias('d1_symbol'),
    (F.col('open_time') + F.lit(MS_1D)).alias('d1_close_time'),
    F.col('sma_20d').alias('sma_20d_1d'),
    F.col('sma_50d').alias('sma_50d_1d'),
    F.col('trend_up_1d'),
    F.col('ret_vol_20d').alias('ret_vol_20d_1d'),
    F.col('rsi_14_1d'),
    F.col('macd_hist_1d'),
    F.col('adx_14_1d'),
  )
  .alias('d1')
)

# The time bounds live in the join condition rather than a post-join filter, so the
# left join really keeps 4h bars that have no recent daily bar (their 1d features
# are NULL). Candidates are limited to daily bars completed within the last 2 days.
asof_cond = (
  (F.col('f4.exchange') == F.col('d1.d1_exchange'))
  & (F.col('f4.symbol') == F.col('d1.d1_symbol'))
  & (F.col('d1.d1_close_time') <= F.col('f4.open_time'))
  & (F.col('d1.d1_close_time') > F.col('f4.open_time') - F.lit(2*MS_1D))
)
j = f4.join(d1, on=asof_cond, how='left')

# Keep the most recently completed candidate for each 4h bar.
w_asof = (Window
  .partitionBy('f4.exchange','f4.symbol','f4.interval','f4.open_time')
  .orderBy(F.col('d1.d1_close_time').desc())
)
feat = (j
  .withColumn('rn', F.row_number().over(w_asof))
  .where(F.col('rn')==1)
  .drop('rn','d1_exchange','d1_symbol','d1_close_time')
)
display(feat.select('exchange','symbol','open_dt_syd','trend_up_1d','rsi_14_1d','sma_20d_1d','sma_50d_1d').limit(10))


## Barrier label (TP hit before SL within W bars)
Label is computed from future highs/lows within the next W bars relative to entry (next bar open).

Conservative assumptions:
- We treat entry as next bar open (`open` at bar_idx+1)
- We compute whether any future `high` breaches TP and whether any future `low` breaches SL
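- OHLC data does not reveal the order of price moves inside a bar, so the label compares the earliest bar that touches TP with the earliest bar that touches SL; if both are first touched in the same bar, `AMBIGUITY_POLICY` decides the outcome.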

In [None]:
# Feature columns carried from `feat` into the labelled dataset. The training cell
# reads all of these, so they must survive the labelling step.
_core_feats = ['atrp','trend_up','trend_strength','ret_1','ret_vol','range','body_to_range',
               'upper_wick','lower_wick','gap','dist_to_hh_atr','dist_to_ll_atr','vol_z']
_tech_feats = ['ema_20','ema_50','ema_ratio','rsi_14','adx_14',
               'macd_line','macd_signal','macd_hist','bb_bw_20','donch_pos_20',
               'mom_logret_3','mom_logret_6','mom_logret_12']
_htf_feats = ['sma_20d_1d','sma_50d_1d','trend_up_1d','ret_vol_20d_1d','rsi_14_1d',
              'macd_hist_1d','adx_14_1d']
_key_cols = ['exchange','symbol','interval','bar_idx','entry_bar_idx']

# Entry is the next bar's open. Each row is a signal bar; the trade is entered at bar_idx + 1.
# `atr_entry` is the ATR known when the entry order is placed, i.e. the signal bar's ATR.
# The entry bar's own ATR includes that bar's high/low, which are not known at its open,
# so using it for the barriers or as a feature would leak future information.
entry = (feat
  .withColumn('entry_bar_idx', (F.col('bar_idx') + 1).cast('long'))
  .withColumn('entry_open', F.lead('open_d', 1).over(w))
  .withColumn('atr_entry', F.col('atr'))
  .withColumn('open_dt_utc_entry', F.lead('open_dt_utc', 1).over(w))
  .withColumn('open_dt_syd_entry', F.lead('open_dt_syd', 1).over(w))
  .where(F.col('entry_open').isNotNull() & F.col('atr_entry').isNotNull())
  .withColumn('tp_price', F.col('entry_open') + F.lit(TP_ATR) * F.col('atr_entry'))
  .withColumn('sl_price', F.col('entry_open') - F.lit(SL_ATR) * F.col('atr_entry'))
)

future = (feat
  .select('exchange','symbol','interval','bar_idx','high_d','low_d')
  .withColumnRenamed('bar_idx','f_bar_idx')
)

# Pair each entry with the W_BARS bars starting at the entry bar:
# [entry_bar_idx, entry_bar_idx + W_BARS). The upper bound is exclusive so that
# exactly W_BARS bars are inspected.
joined = (entry
  .select(*_key_cols, 'tp_price', 'sl_price')
  .join(future, on=['exchange','symbol','interval'], how='inner')
  .where((F.col('f_bar_idx') >= F.col('entry_bar_idx')) & (F.col('f_bar_idx') < (F.col('entry_bar_idx') + F.lit(W_BARS))))
)

# Earliest bar that hits TP and earliest bar that hits SL (NULL when never hit;
# min() ignores the NULLs produced for non-hitting bars).
tp_hit = F.when(F.col('high_d') >= F.col('tp_price'), F.col('f_bar_idx'))
sl_hit = F.when(F.col('low_d') <= F.col('sl_price'), F.col('f_bar_idx'))

hit_idx = (joined
  .groupBy(*_key_cols)
  .agg(
    F.min(tp_hit).alias('tp_hit_idx'),
    F.min(sl_hit).alias('sl_hit_idx'),
  )
)

# Re-attach the per-entry attributes and every feature column to the hit indices.
hits = (entry
  .join(hit_idx, on=_key_cols, how='inner')
  .select(
    *_key_cols,
    'tp_hit_idx', 'sl_hit_idx',
    'entry_open', 'tp_price', 'sl_price',
    F.col('open_dt_utc_entry').alias('entry_dt_utc'),
    F.col('open_dt_syd_entry').alias('entry_dt_syd'),
    'hour_syd', 'dow_syd',
    'atr_entry',
    *_core_feats, *_tech_feats, *_htf_feats,
  )
)

# Label with ambiguity handling
# - If TP and SL hit in different bars: whichever hits first wins
# - If both hit in the same bar: apply AMBIGUITY_POLICY
ambiguous_same_bar = (F.col('tp_hit_idx').isNotNull() & F.col('sl_hit_idx').isNotNull() & (F.col('tp_hit_idx') == F.col('sl_hit_idx')))

base_label = (
  F.when(F.col('tp_hit_idx').isNotNull() & (F.col('sl_hit_idx').isNull() | (F.col('tp_hit_idx') < F.col('sl_hit_idx'))), F.lit(1))
   .when(F.col('sl_hit_idx').isNotNull() & (F.col('tp_hit_idx').isNull() | (F.col('sl_hit_idx') < F.col('tp_hit_idx'))), F.lit(0))
   .otherwise(F.lit(None))
)

label = (
  F.when(ambiguous_same_bar & (F.lit(AMBIGUITY_POLICY) == F.lit('tp_first')), F.lit(1))
   .when(ambiguous_same_bar & (F.lit(AMBIGUITY_POLICY) == F.lit('sl_first')), F.lit(0))
   .when(ambiguous_same_bar & (F.lit(AMBIGUITY_POLICY) == F.lit('discard_both')), F.lit(None))
   .otherwise(base_label)
)

# A timeout means neither barrier was touched inside the window. Deriving it from a
# NULL label would also flag ambiguous rows discarded under 'discard_both'.
timed_out = F.col('tp_hit_idx').isNull() & F.col('sl_hit_idx').isNull()

dataset = (hits
  .withColumn('ambiguous_same_bar', ambiguous_same_bar.cast('int'))
  .withColumn('label', label)
  .withColumn('label_timeout', timed_out.cast('int'))
  .withColumn('label_variant', F.concat(F.lit('TP'), F.lit(TP_ATR), F.lit('_SL'), F.lit(SL_ATR), F.lit('_W'), F.lit(W_BARS), F.lit('_'), F.lit(AMBIGUITY_POLICY)))
  .withColumn('run_id', F.lit(RUN_ID))
)

display(dataset.select('exchange','symbol','entry_dt_syd','label','tp_hit_idx','sl_hit_idx','trend_up','dist_to_hh_atr','atrp').limit(20))
print('dataset rows:', dataset.count())
display(dataset.groupBy('label').count())


## Persist dataset to Delta

In [None]:
# Create the dataset table on first use. The column list covers every feature the
# training cell reads, including the 1d MACD histogram and ADX.
spark.sql('''
CREATE TABLE IF NOT EXISTS ml_barrier_dataset_4h (
  run_id STRING,
  exchange STRING,
  symbol STRING,
  interval STRING,
  bar_idx BIGINT,
  entry_bar_idx BIGINT,
  entry_dt_utc TIMESTAMP,
  entry_dt_syd TIMESTAMP,
  hour_syd INT,
  dow_syd STRING,
  entry_open DOUBLE,
  tp_price DOUBLE,
  sl_price DOUBLE,
  tp_hit_idx BIGINT,
  sl_hit_idx BIGINT,
  ambiguous_same_bar INT,
  label INT,
  label_timeout INT,
  label_variant STRING,
  atr_entry DOUBLE,
  atrp DOUBLE,
  trend_up INT,
  trend_strength DOUBLE,
  ret_1 DOUBLE,
  ret_vol DOUBLE,
  range DOUBLE,
  body_to_range DOUBLE,
  upper_wick DOUBLE,
  lower_wick DOUBLE,
  gap DOUBLE,
  dist_to_hh_atr DOUBLE,
  dist_to_ll_atr DOUBLE,
  vol_z DOUBLE,
  ema_20 DOUBLE,
  ema_50 DOUBLE,
  ema_ratio DOUBLE,
  rsi_14 DOUBLE,
  adx_14 DOUBLE,
  macd_line DOUBLE,
  macd_signal DOUBLE,
  macd_hist DOUBLE,
  bb_bw_20 DOUBLE,
  donch_pos_20 DOUBLE,
  mom_logret_3 DOUBLE,
  mom_logret_6 DOUBLE,
  mom_logret_12 DOUBLE,
  sma_20d_1d DOUBLE,
  sma_50d_1d DOUBLE,
  trend_up_1d INT,
  ret_vol_20d_1d DOUBLE,
  rsi_14_1d DOUBLE,
  macd_hist_1d DOUBLE,
  adx_14_1d DOUBLE
) USING DELTA
''')

# CREATE TABLE IF NOT EXISTS leaves an existing table untouched, so a table created
# by an older version of this notebook may lack newer columns. mergeSchema lets the
# append add them instead of failing Delta's schema enforcement.
(dataset
  .write
  .mode('append')
  .option('mergeSchema', 'true')
  .saveAsTable('ml_barrier_dataset_4h')
)
print('Wrote run_id', RUN_ID, 'to ml_barrier_dataset_4h')


## Train a model (Spark ML GBTClassifier)
We use a **time-based split** and train only on rows with non-null labels.

In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import mlflow
import mlflow.spark

# Train only on rows with a resolved outcome (label 0/1).
df = dataset.where(F.col('label').isNotNull())

# Time-based split on the entry timestamp.
# NOTE(review): each label looks ahead W_BARS bars, so rows just before TRAIN_END /
# VALID_END have outcomes determined inside the next split. Consider dropping a
# W_BARS-wide gap before each boundary if strict separation is required.
df = (df
  .withColumn('split',
    F.when(F.col('entry_dt_utc') < F.to_timestamp(F.lit(TRAIN_END)), F.lit('train'))
     .when(F.col('entry_dt_utc') < F.to_timestamp(F.lit(VALID_END)), F.lit('valid'))
     .otherwise(F.lit('test'))
  )
)

feature_cols = [
  # core
  'atr_entry','atrp','trend_up','trend_strength','ret_1','ret_vol','range','body_to_range','upper_wick','lower_wick','gap',
  'dist_to_hh_atr','dist_to_ll_atr','vol_z',
  # enhanced indicators (applyInPandas)
  'ema_20','ema_50','ema_ratio','rsi_14','adx_14',
  'macd_line','macd_signal','macd_hist','bb_bw_20','donch_pos_20',
  'mom_logret_3','mom_logret_6','mom_logret_12',
  # higher timeframe context
  'sma_20d_1d','sma_50d_1d','trend_up_1d','ret_vol_20d_1d','rsi_14_1d','macd_hist_1d','adx_14_1d',
]
cat_cols = ['dow_syd']

# handleInvalid='keep' lets unseen categories and NULL/NaN feature values through
# instead of failing the transform.
indexers = [StringIndexer(inputCol=c, outputCol=f'{c}_idx', handleInvalid='keep') for c in cat_cols]
assembler = VectorAssembler(inputCols=feature_cols + [f'{c}_idx' for c in cat_cols], outputCol='features', handleInvalid='keep')

# Explicit seed so repeated runs of this notebook train the same model.
MODEL_SEED = 42
clf = GBTClassifier(featuresCol='features', labelCol='label', maxDepth=5, maxIter=80, stepSize=0.1, seed=MODEL_SEED)
pipeline = Pipeline(stages=indexers + [assembler, clf])

train = df.where(F.col('split')=='train')
valid = df.where(F.col('split')=='valid')
test = df.where(F.col('split')=='test')

mlflow.autolog()
with mlflow.start_run(run_name=f'barrier_4h_{RUN_ID}'):
  # Fit on the train split only; valid/test are scored with the fitted pipeline.
  model = pipeline.fit(train)
  pred_valid = model.transform(valid)
  pred_test = model.transform(test)

  evaluator = BinaryClassificationEvaluator(labelCol='label', rawPredictionCol='rawPrediction', metricName='areaUnderROC')
  auc_valid = evaluator.evaluate(pred_valid)
  auc_test = evaluator.evaluate(pred_test)
  print('AUC valid:', auc_valid)
  print('AUC test :', auc_test)

  mlflow.log_metric('auc_valid', auc_valid)
  mlflow.log_metric('auc_test', auc_test)

  mlflow.spark.log_model(model, 'model')


## Feature importance
- Spark GBT feature importances
- Sampled permutation importance on validation


In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import pandas as pd

# Built-in GBT importances, paired with the assembler's input column order.
gbt_model = model.stages[-1]
assembler_stage = model.stages[-2]
input_cols = assembler_stage.getInputCols()
importances = gbt_model.featureImportances.toArray().tolist()
imp_df = pd.DataFrame({'feature': input_cols, 'importance': importances}).sort_values('importance', ascending=False)
display(imp_df.head(50))

# Permutation importance runs on a 5% sample of the validation split.
perm_sample = valid.select(['label'] + feature_cols + cat_cols).sample(False, 0.05, seed=42)
evaluator = BinaryClassificationEvaluator(labelCol='label', rawPredictionCol='rawPrediction', metricName='areaUnderROC')
base_auc = evaluator.evaluate(model.transform(perm_sample))
print('baseline AUC(sampled valid)=', base_auc)

K = 15
# The assembler sees indexed categoricals (e.g. 'dow_syd_idx'), but `perm_sample`
# only holds the raw columns; the pipeline recreates the index on transform. Map
# indexed names back to their raw column so the permutation targets a column
# that exists.
_idx_to_raw = {f'{c}_idx': c for c in cat_cols}
top_feats = [_idx_to_raw.get(name, name) for name in imp_df['feature'].head(K).tolist()]

from pyspark.sql import functions as F
from pyspark.sql.window import Window

def permute_col(df_in, col):
  """Return `df_in` with the values of column `col` randomly reassigned across rows.

  All other columns keep their row alignment. The shuffled values are given a
  temporary name before the join: joining two frames that both contain `col`
  and then calling drop(col) would remove both copies, leaving the result
  without the feature.

  Note: both windows are unpartitioned, so all rows pass through a single
  partition — intended for small samples only.
  """
  shuffled_name = f'__shuffled__{col}'
  shuffled = (df_in
    .select(F.col(col).alias(shuffled_name))
    .withColumn('_rn', F.row_number().over(Window.orderBy(F.rand(123))))
  )
  base = (df_in
    .drop(col)
    .withColumn('_rn', F.row_number().over(Window.orderBy(F.lit(1))))
  )
  out = (base.join(shuffled, on='_rn', how='inner')
           .withColumnRenamed(shuffled_name, col)
           .drop('_rn'))
  return out

# For each top feature: shuffle it, re-score the sample, and record the AUC lost.
def _auc_with_shuffled(feature_name):
  return evaluator.evaluate(model.transform(permute_col(perm_sample, feature_name)))

rows = []
for feature_name in top_feats:
  shuffled_auc = _auc_with_shuffled(feature_name)
  rows.append((feature_name, float(base_auc - shuffled_auc), float(shuffled_auc)))

perm_df = (
  pd.DataFrame(rows, columns=['feature','auc_drop','auc_after'])
    .sort_values('auc_drop', ascending=False)
)
display(perm_df)


## Score and write predictions to Delta
We write predicted probabilities for all labeled rows, plus a simple decision threshold column.

In [None]:
# Score all labelled rows (train/valid/test together, for inspection)
from pyspark.ml.functions import vector_to_array

pred = model.transform(df)

# Extract probability of class 1. `probability` is an ML Vector column, which does
# not support getItem() directly; convert it to an array first.
pred = pred.withColumn('p_win', vector_to_array(F.col('probability')).getItem(1))

# Simple decision rule: take the trade when the predicted win probability clears the threshold.
THRESHOLD = 0.55
pred = pred.withColumn('take_trade', (F.col('p_win') >= F.lit(THRESHOLD)).cast('int'))

spark.sql('''
CREATE TABLE IF NOT EXISTS ml_barrier_predictions_4h (
  run_id STRING,
  exchange STRING,
  symbol STRING,
  interval STRING,
  entry_dt_utc TIMESTAMP,
  entry_dt_syd TIMESTAMP,
  split STRING,
  label INT,
  p_win DOUBLE,
  take_trade INT
) USING DELTA
''')

(pred
  .select('run_id','exchange','symbol','interval','entry_dt_utc','entry_dt_syd','split','label','p_win','take_trade')
  .write
  .mode('append')
  .saveAsTable('ml_barrier_predictions_4h')
)

display(pred.select('exchange','symbol','entry_dt_syd','label','p_win','take_trade').orderBy(F.col('entry_dt_utc').desc()).limit(50))
print('Wrote predictions for run_id', RUN_ID)
