In [90]:
from prj.config import DATA_DIR
from prj.data.data_loader import DataConfig, DataLoader
import polars as pl
from sklearn.metrics import r2_score
from catboost import CatBoostRegressor
from prj.data.data_loader import PARTITIONS_DATE_INFO
import pandas as pd
import lightgbm as lgb
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
import gc
from tqdm import tqdm
import numpy as np
from prj.utils import online_iterator, online_iterator_daily
import time

SEED = 42

# Base

In [91]:
import optuna

study_name = "lgbm_offline_2025-01-04_16-04-39"
storage = "mysql+pymysql://admin:F1g5w#6zP4TN@janestreet.c3uaekuseqse.us-east-1.rds.amazonaws.com/janestreet"
study = optuna.load_study(study_name=study_name, storage=storage)
best_trial = study.best_trials[0]

model_path = f'/home/lorecampa/projects/jane_street_forecasting/experiments/lgbm_offline/{study_name}/trial_{best_trial.number}/best_model.txt'
model = lgb.Booster(model_file=model_path)
model

<lightgbm.basic.Booster at 0x7f26e7552d40>

In [92]:
model_path

'/home/lorecampa/projects/jane_street_forecasting/experiments/lgbm_offline/lgbm_offline_2025-01-04_16-04-39/trial_0/best_model.txt'

In [37]:
import optuna.visualization as vis

fig = vis.plot_pareto_front(study, target_names=["r2_score", "sharpe"])
fig.show()


In [75]:
data_args = {'include_time_id': True, 'include_intrastock_norm_temporal': True}
config = DataConfig(**data_args)
loader = DataLoader(data_dir=DATA_DIR, config=config)
start_dt, end_dt = 1530, 1698
# start_dt, end_dt = 1600, 1635
test_ds = loader.load(start_dt-1, end_dt)
X_test, y_test, w_test, _ = loader._build_splits(test_ds.filter(pl.col('date_id').ge(start_dt)))

100%|██████████| 177/177 [00:10<00:00, 16.27it/s]


Skipping 1692-1698
Skipping 1693-1698
Skipping 1694-1698
Skipping 1695-1698
Skipping 1696-1698
Skipping 1697-1698
Skipping 1698-1698


In [76]:
y_hat = model.predict(X_test).clip(-5, 5).flatten()
offline_score = r2_score(y_test, y_hat, sample_weight=w_test)
offline_score

0.006788041325703875

In [77]:
del X_test
gc.collect()

4

In [78]:
features = loader.features
print(len(features))

134


In [79]:
MEAN_FEATURES = [0, 2, 3, 5, 6, 7, 18, 19, 34, 35, 36, 37, 38, 41, 43, 44, 48, 53, 55, 59, 62, 65, 68, 73, 74, 75, 76, 77, 78]
STD_FEATURES = [39, 42, 46, 53, 57, 66]
SKEW_FEATURES = [5, 40, 41, 42, 43, 44]
ZSCORE_FEATURES = [1, 36, 40, 45, 48, 49, 51, 52, 53, 54, 55, 59, 60]

def include_intrastock_norm(df: pl.LazyFrame, responder) -> pl.LazyFrame:
    df = df.with_columns(
        pl.col([f'feature_{j:02d}' for j in set(MEAN_FEATURES + ZSCORE_FEATURES)]).mean().over(['date_id', 'time_id', f'cluster_label_{responder}']).name.suffix(f'_{responder}_mean'),
        pl.col([f'feature_{j:02d}' for j in set(STD_FEATURES + ZSCORE_FEATURES)]).std().over(['date_id', 'time_id', f'cluster_label_{responder}']).name.suffix(f'_{responder}_std'),
        pl.col([f'feature_{j:02d}' for j in SKEW_FEATURES]).skew().over(['date_id', 'time_id', f'cluster_label_{responder}']).name.suffix(f'_{responder}_skew'),
    ).with_columns(
        pl.col(f'feature_{j:02d}').sub(f'feature_{j:02d}_{responder}_mean').truediv(f'feature_{j:02d}_{responder}_std').name.suffix(f'_{responder}_zscore') for j in ZSCORE_FEATURES
    ).drop([f'feature_{j:02d}_{responder}_std' for j in ZSCORE_FEATURES if j not in STD_FEATURES] + \
        [f'feature_{j:02d}_{responder}_mean' for j in ZSCORE_FEATURES if j not in MEAN_FEATURES])
    return df

## LGBM Callback

# Inference

In [81]:
_test_ds = test_ds.collect()

In [86]:
MAX_ITERATIONS = 1000
FINE_TUNING_TIME_LIMIT = 50

def build_splits(df: pl.DataFrame, features: list):
    X = df.select(features).to_numpy()
    y = df['responder_6'].to_numpy().flatten()
    w = df['weight'].to_numpy().flatten()
    return X, y, w

def train_with_es(init_model: CatBoostRegressor, params: dict, train_df: pl.DataFrame, val_df: pl.DataFrame, use_weighted_loss, es_patience):
    start_time = time.time()
    _params = params.copy()
    _params.pop('num_iterations', None)
        
    X_train, y_train, w_train = build_splits(train_df, features)
    train_data = lgb.Dataset(data=X_train, label=y_train, weight=w_train if use_weighted_loss else None)    
    del X_train, y_train, w_train
    gc.collect()
    
    X_val, y_val, w_val = build_splits(val_df, features)
    val_data = lgb.Dataset(data=X_val, label=y_val, weight=w_val if use_weighted_loss else None, reference=train_data)
    del X_val, y_val, w_val
    gc.collect()

    callbacks = [
        lgb.log_evaluation(period=50),
        LGBMEarlyStoppingCallbackWithTimeout(es_patience, timeout_seconds=FINE_TUNING_TIME_LIMIT)
    ]
    
    print(f"Learning rate: {_params['learning_rate']:e}")
    model = lgb.train(
        train_set=train_data,
        params=_params,
        init_model=init_model,
        num_boost_round=MAX_ITERATIONS,
        valid_sets=[val_data],
        callbacks=callbacks,
    )
    
    print(f'Train completed in {((time.time() - start_time)/60):.3f} minutes')
    
    return model

In [87]:
responder_replay_buffer_config = DataConfig()
responder_replay_buffer_loader = DataLoader(data_dir=DATA_DIR, config=responder_replay_buffer_config)
base_responder_replay_buffer = responder_replay_buffer_loader.load(start_dt-1-loader.window_period, start_dt-2)\
    .select('date_id', 'time_id', 'symbol_id', 'responder_6')\
    .collect()
    

TREE_OLD_DATASET_MAX_HISTORY = 30
AUX_COLS = ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']

base_old_dataset = loader.load(start_dt-TREE_OLD_DATASET_MAX_HISTORY, start_dt-1)\
    .select(AUX_COLS + features) \
    .collect()

100%|██████████| 37/37 [00:01<00:00, 18.78it/s]


Skipping 1523-1529
Skipping 1524-1529
Skipping 1525-1529
Skipping 1526-1529
Skipping 1527-1529
Skipping 1528-1529
Skipping 1529-1529


In [89]:
TREE_FINE_TUNING = True

TREE_OLD_DATA_FRACTION = 0.2
TREE_EARLY_STOPPING_DAYS = 14
TREE_ES_PATIENCE = 20
TREE_TRAIN_EVERY = 30
TREE_USE_WEIGHTED_LOSS = best_trial.params['use_weighted_loss']

TREE_LR_DECAY = 0.5
assert TREE_TRAIN_EVERY > TREE_EARLY_STOPPING_DAYS

TREE_MAX_FINE_TUNING_TIME_LIMIT = time.time() + 60 * 60 * 8


USE_INTRA_STOCK_NORM = True
USE_TIME_NORM_ID = True
verbose=True

corr_responder = 'responder_6'
model = lgb.Booster(model_file=model_path)
params = model.params.copy()

period = loader.window_period
stock_cluster_mapping = {}
stock_max_time_id = {}
default_cluster=-1
default_max_time_id=967


responder_replay_buffer = base_responder_replay_buffer
    

TREE_OLD_DATASET_MAX_HISTORY = 30
AUX_COLS = ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']

old_dataset = base_old_dataset
    

current_day_data: pl.DataFrame | None = None
new_dataset: pl.DataFrame | None = None
date_idx = 0

def predict(test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame:
    global responder_replay_buffer, stock_cluster_mapping, stock_max_time_id, default_max_time_id, old_dataset, new_dataset, date_idx, model, current_day_data, params, TREE_FINE_TUNING
    curr_date = test['date_id'].first()

    if lags is not None:            
        _lags = lags.select(pl.col('date_id').sub(1), 'time_id', 'symbol_id', pl.col(f'{corr_responder}_lag_1').alias(corr_responder))
            
        # Time id norm preparation
        if USE_TIME_NORM_ID:
            stock_max_time_id_map = lags.group_by('symbol_id').agg(pl.col('time_id').max())
            stock_max_time_id = dict(zip(stock_max_time_id_map['symbol_id'], stock_max_time_id_map['time_id']))
            default_max_time_id = max(list(stock_max_time_id.values()))

        
        # Intrastock normalization preparation
        if USE_INTRA_STOCK_NORM:
            responder_replay_buffer = responder_replay_buffer.vstack(
                _lags
            ).filter(pl.col('date_id').is_between(curr_date-period, curr_date))

            pivot = responder_replay_buffer.filter(pl.col('date_id') < curr_date)\
                        .pivot(index=['date_id', 'time_id'], values=[corr_responder], separator='_', on='symbol_id')\
                        .sort('date_id', 'time_id') \
                        .fill_nan(None)\
                        .fill_null(strategy='zero')

            corr_cols = [col for col in pivot.columns if col not in ['date_id', 'time_id']]
            stocks = [int(col) for col in corr_cols]
            df_corr_responder = pivot.select(corr_cols).corr()
            linked = linkage(df_corr_responder, method='ward')
            cluster_labels = fcluster(linked, t=2.5, criterion='distance')
            stock_cluster_mapping = dict(zip(stocks, cluster_labels))
            
        TREE_FINE_TUNING = TREE_FINE_TUNING and time.time() < TREE_MAX_FINE_TUNING_TIME_LIMIT
        if TREE_FINE_TUNING:
            if current_day_data is not None:
                current_day_data = current_day_data.join(_lags, on=['date_id', 'time_id', 'symbol_id'], how='left', maintain_order='left')\
                    .select(AUX_COLS + features)
                
                new_dataset = current_day_data if new_dataset is None else new_dataset.vstack(current_day_data)
            
            current_day_data = test
            
            if (date_idx + 1) % TREE_TRAIN_EVERY == 0:
                print(f'Starting fine tuning at date {curr_date}')
                max_date = new_dataset['date_id'].max()
                split_date = max_date - TREE_EARLY_STOPPING_DAYS
                new_validation_dataset = new_dataset.filter(pl.col('date_id').gt(split_date))
                new_training_dataset = new_dataset.filter(pl.col('date_id').le(split_date))
                

                new_training_dataset_len = new_training_dataset.shape[0]
                old_dataset_len = old_dataset.shape[0]
                old_data_len = min(int(TREE_OLD_DATA_FRACTION * new_training_dataset_len / (1 - TREE_OLD_DATA_FRACTION)), old_dataset_len)
                
                if verbose:
                    old_days = old_dataset['date_id'].unique().sort().to_list()
                    train_days = new_training_dataset['date_id'].unique().sort().to_list()
                    val_days = new_validation_dataset['date_id'].unique().sort().to_list()
                    print('Old days: ', old_days)
                    print('Train days: ', train_days)
                    print('Val days: ', val_days)
                    print(new_training_dataset_len, old_data_len, TREE_OLD_DATA_FRACTION)
                
                train_df = old_dataset.sample(n=old_data_len).vstack(new_training_dataset)
                val_df = new_validation_dataset
                
                print(f'Starting fine tuning at date {curr_date}')

                params['learning_rate'] = max(params['learning_rate'] * TREE_LR_DECAY, 1e-5)
                
                model = train_with_es(
                    init_model= model, 
                    train_df=train_df,
                    val_df=val_df,
                    use_weighted_loss=TREE_USE_WEIGHTED_LOSS,
                    es_patience=TREE_ES_PATIENCE,
                    params=params,
                )
                
                del train_df, val_df
                gc.collect()
                
                new_max_old_dataset_date = new_training_dataset['date_id'].max()
                old_dataset = pl.concat([
                    old_dataset,
                    new_training_dataset
                ]).filter(
                    pl.col('date_id').is_between(new_max_old_dataset_date-TREE_OLD_DATASET_MAX_HISTORY, new_max_old_dataset_date)
                )
                new_dataset = new_validation_dataset
    
        date_idx += 1
    else:
        if TREE_FINE_TUNING:
            current_day_data = current_day_data.vstack(test)
    
    
    if USE_TIME_NORM_ID:
        test = test.with_columns(
            pl.col('symbol_id').replace_strict(
                stock_max_time_id, default=default_max_time_id, return_dtype=pl.Int16
            ).alias('max_prev_stock_time_id'),
        ).with_columns(
            pl.col('time_id').truediv(
                'max_prev_stock_time_id'
            ).alias('time_id_norm')
        ).drop('max_prev_stock_time_id')
        
    if USE_INTRA_STOCK_NORM:
        test = test.with_columns(
            pl.col('symbol_id').replace_strict(
                stock_cluster_mapping, default=default_cluster, return_dtype=pl.Int8
            ).alias(f'cluster_label_{corr_responder}')
        ).pipe(
            include_intrastock_norm, 
            corr_responder
        ).drop(f'cluster_label_{corr_responder}')
    
    
    
    X = test.select(features).cast(pl.Float32).to_numpy()
    y_hat = model.predict(X).clip(-5, 5).flatten()

    predictions = test.select('row_id', pl.Series(y_hat).alias('responder_6'))


    assert len(predictions) == len(test)

    return predictions



y_hat_iterator = []

for test, lags in online_iterator_daily(_test_ds, show_progress=True):
    res = predict(test, lags)
    y_hat_iterator.append(res['responder_6'].to_numpy())

y_hat_iterator = np.concatenate(y_hat_iterator, dtype=np.float32)

online_score = r2_score(y_true=y_test, y_pred=y_hat_iterator, sample_weight=w_test)
gain = online_score - offline_score

print(f'Online score: {online_score:.4f}, Offline score: {offline_score:.4f} -> Gain: {gain:.4f}')


 17%|█▋        | 29/169 [00:07<00:33,  4.15it/s]

Starting fine tuning at date 1559
Old days:  [1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529]
Train days:  [1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544]
Val days:  [1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558]
552728 138182 0.2
Starting fine tuning at date 1559
Learning rate: 2.955130e-02
Training until validation scores don't improve for 20 rounds
[250]	valid_0's l2: 0.512678
Early stopping, best iteration is:
[247]	valid_0's l2: 0.51263
Train completed in 0.152 minutes


 35%|███▍      | 59/169 [00:24<00:27,  4.03it/s]

Starting fine tuning at date 1589
Old days:  [1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544]
Train days:  [1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574]
Val days:  [1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588]
1114168 278542 0.2
Starting fine tuning at date 1589
Learning rate: 1.477565e-02
Training until validation scores don't improve for 20 rounds
[250]	valid_0's l2: 0.560586
Early stopping, best iteration is:
[268]	valid_0's l2: 0.56055
Train completed in 0.236 minutes


 53%|█████▎    | 89/169 [00:47<00:23,  3.42it/s]

Starting fine tuning at date 1619
Old days:  [1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574]
Train days:  [1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604]
Val days:  [1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618]
1109328 277332 0.2
Starting fine tuning at date 1619
Learning rate: 7.387825e-03
Training until validation scores don't improve for 20 rounds
[300]	valid_0's l2: 0.71511
[350]	valid_0's l2: 0.715016
[400]	valid_0's l2: 0.714943
[450]	valid_0's l2: 0.714895
[500]	valid_0's l2: 0.714842
[550]	valid_0's l2: 0.714793
[600]	valid_0's l2: 0.714766
[650]	valid_0's l2: 0.714729
[700]	valid_0's l2: 0.714725
Early stopping, best iteration is:
[680]	valid_0's l2: 0.714714
Train comple

 70%|███████   | 119/169 [01:24<00:20,  2.46it/s]

Starting fine tuning at date 1649
Old days:  [1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604]
Train days:  [1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634]
Val days:  [1635, 1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648]
1123848 280962 0.2
Starting fine tuning at date 1649
Learning rate: 3.693912e-03
Training until validation scores don't improve for 20 rounds
[700]	valid_0's l2: 0.500325
[750]	valid_0's l2: 0.50025
[800]	valid_0's l2: 0.500176
[850]	valid_0's l2: 0.500126
[900]	valid_0's l2: 0.500072
[950]	valid_0's l2: 0.500031
[1000]	valid_0's l2: 0.500005
[1050]	valid_0's l2: 0.499958
[1100]	valid_0's l2: 0.499922
[1150]	valid_0's l2: 0.499889
[1200]	valid_0's l2: 0.49985
[1250]	valid_0

 88%|████████▊ | 149/169 [02:29<00:15,  1.33it/s]

Starting fine tuning at date 1679
Old days:  [1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634]
Train days:  [1635, 1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664]
Val days:  [1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678]
1107392 276848 0.2
Starting fine tuning at date 1679
Learning rate: 1.846956e-03
Training until validation scores don't improve for 20 rounds
[1600]	valid_0's l2: 0.605943
[1650]	valid_0's l2: 0.605883
[1700]	valid_0's l2: 0.605832
[1750]	valid_0's l2: 0.605783
[1800]	valid_0's l2: 0.605743
[1850]	valid_0's l2: 0.605703
[1900]	valid_0's l2: 0.605664
[1950]	valid_0's l2: 0.605623
[2000]	valid_0's l2: 0.605584
[2050]	valid_0's l2: 0.60555
[2100]	valid_0's l2: 0.605511
[2150]	

100%|██████████| 169/169 [03:41<00:00,  1.31s/it]

Online score: 0.0051, Offline score: 0.0068 -> Gain: -0.0016



