In [1]:
import pandas as pd
import polars as pl
import numpy as np
import os
import gc
import seaborn as sns
from tqdm import tqdm
from sklearn.model_selection import KFold, StratifiedKFold
import xgboost as xgb
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor, log_evaluation, record_evaluation
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
#from sklearn.impute import IterativeImputer
import pickle
import optuna
from optuna.visualization import plot_slice, plot_param_importances
import shap
import random

# Explicitly turn on the cyclic garbage collector for this session.
gc.enable()

# pandas display: allow up to 500 columns and never truncate cell contents.
# (A preceding `max_columns = None` assignment was immediately overridden by
# this setting, so it has been dropped.)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_colwidth', None)

# polars display: print every column and allow long string values.
pl.Config.set_tbl_cols(-1)
pl.Config.set_fmt_str_lengths(10000)

# Keep Optuna quiet: only warnings and errors are logged.
optuna.logging.set_verbosity(optuna.logging.WARNING)

In [2]:
# Root directory of the competition data. It defaults to the original local
# location but can be overridden with the JANE_STREET_DATA_DIR environment
# variable, so the notebook runs on other machines without editing this cell.
# os.path.join(..., '') guarantees a trailing separator, because the rest of
# the notebook builds file paths with `path + 'name'`.
path = os.path.join(
    os.environ.get(
        'JANE_STREET_DATA_DIR',
        'I:/Kaggle/jane-street-real-time-market-data-forecasting/',
    ),
    '',
)

In [3]:
# Show what is available in the data directory.
os.listdir(path)

['features.csv',
 'kaggle_evaluation',
 'lags.parquet',
 'my_folder',
 'responders.csv',
 'sample_submission.csv',
 'team_folder',
 'test.parquet',
 'top_100000_rows_sorted_by_weight_descending.parquet',
 'top_10000_rows_sorted_by_weight_descending.parquet',
 'train.parquet']

In [4]:
# Load the full training set, keeping responder_6 as the only responder column:
# every other responder and the partition key are dropped, then all columns are
# passed through shrink_dtype() to reduce memory use.
train_df = (
    pl.read_parquet(path + 'train.parquet/')
    .drop([f'responder_{i}' for i in (0, 1, 2, 3, 4, 5, 7, 8)] + ['partition_id'])
    .select(pl.all().shrink_dtype())
)
print(train_df.shape)
train_df.head()

(47127338, 84)


date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,1,3.889038,,,,,,0.851033,0.242971,0.2634,-0.891687,11,7,76,-0.883028,0.003067,-0.744703,,-0.169586,,-1.335938,-1.707803,0.91013,,1.636431,1.522133,-1.551398,-0.229627,,,1.378301,-0.283712,0.123196,,,,0.28118,0.269163,0.349028,-0.012596,-0.225932,,-1.073602,,,-0.181716,,,,0.564021,2.088506,0.832022,,0.204797,,,-0.808103,,-2.037683,0.727661,,-0.989118,-0.345213,-1.36224,,,,,,-1.251104,-0.110252,-0.491157,-1.02269,0.152241,-0.659864,,,-0.261412,-0.211486,-0.335556,-0.281498,0.775981
0,0,7,1.370613,,,,,,0.676961,0.151984,0.192465,-0.521729,11,7,76,-0.865307,-0.225629,-0.582163,,0.317467,,-1.250016,-1.682929,1.412757,,0.520378,0.744132,-0.788658,0.641776,,,0.2272,0.580907,1.128879,,,,-1.512286,-1.414357,-1.823322,-0.082763,-0.184119,,,,,,,,,-10.835207,-0.002704,-0.621836,,1.172836,,,-1.625862,,-1.410017,1.063013,,0.888355,0.467994,-1.36224,,,,,,-1.065759,0.013322,-0.592855,-1.052685,-0.393726,-0.741603,,,-0.281207,-0.182894,-0.245565,-0.302441,0.703665
0,0,9,2.285698,,,,,,1.056285,0.187227,0.249901,-0.77305,11,7,76,-0.675719,-0.199404,-0.586798,,-0.814909,,-1.296782,-2.040234,0.639589,,1.597359,0.657514,-1.350148,0.364215,,,-0.017751,-0.317361,-0.122379,,,,-0.320921,-0.95809,-2.436589,0.070999,-0.245239,,,,,,,,,-1.420632,-3.515137,-4.67776,,0.535897,,,-0.72542,,-2.29417,1.764551,,-0.120789,-0.063458,-1.36224,,,,,,-0.882604,-0.072482,-0.617934,-0.86323,-0.241892,-0.709919,,,0.377131,0.300724,-0.106842,-0.096792,2.109352
0,0,10,0.690606,,,,,,1.139366,0.273328,0.306549,-1.262223,42,5,150,-0.694008,3.004091,0.114809,,-0.251882,,-1.902009,-0.979447,0.241165,,-0.392359,-0.224699,-2.129397,-0.855287,,,0.404142,-0.578156,0.105702,,,,0.544138,-0.087091,-1.500147,-0.201288,-0.038042,,,,,,,,,0.382074,2.669135,0.611711,,2.413415,,,1.313203,,-0.810125,2.939022,,3.988801,1.834661,-1.36224,,,,,,-0.697595,1.074309,-0.206929,-0.530602,4.765215,0.571554,,,-0.226891,-0.251412,-0.215522,-0.296244,1.114137
0,0,14,0.44057,,,,,,0.9552,0.262404,0.344457,-0.613813,44,3,16,-0.947351,-0.030018,-0.502379,,0.646086,,-1.844685,-1.58656,-0.182024,,-0.969949,-0.673813,-1.282132,-1.399894,,,0.043815,-0.320225,-0.031713,,,,-0.08842,-0.995003,-2.635336,-0.196461,-0.618719,,,,,,,,,-2.0146,-2.321076,-3.711265,,1.253902,,,0.476195,,-0.771732,2.843421,,1.379815,0.411827,-1.36224,,,,,,-0.948601,-0.136814,-0.447704,-1.141761,0.099631,-0.661928,,,3.678076,2.793581,2.61825,3.418133,-3.57282


In [14]:
# Feature columns are selected purely by the substring 'feature' in their name.
feature_cols = list(filter(lambda name: 'feature' in name, train_df.columns))
feature_cols

['feature_00',
 'feature_01',
 'feature_02',
 'feature_03',
 'feature_04',
 'feature_05',
 'feature_06',
 'feature_07',
 'feature_08',
 'feature_09',
 'feature_10',
 'feature_11',
 'feature_12',
 'feature_13',
 'feature_14',
 'feature_15',
 'feature_16',
 'feature_17',
 'feature_18',
 'feature_19',
 'feature_20',
 'feature_21',
 'feature_22',
 'feature_23',
 'feature_24',
 'feature_25',
 'feature_26',
 'feature_27',
 'feature_28',
 'feature_29',
 'feature_30',
 'feature_31',
 'feature_32',
 'feature_33',
 'feature_34',
 'feature_35',
 'feature_36',
 'feature_37',
 'feature_38',
 'feature_39',
 'feature_40',
 'feature_41',
 'feature_42',
 'feature_43',
 'feature_44',
 'feature_45',
 'feature_46',
 'feature_47',
 'feature_48',
 'feature_49',
 'feature_50',
 'feature_51',
 'feature_52',
 'feature_53',
 'feature_54',
 'feature_55',
 'feature_56',
 'feature_57',
 'feature_58',
 'feature_59',
 'feature_60',
 'feature_61',
 'feature_62',
 'feature_63',
 'feature_64',
 'feature_65',
 'feature_

In [15]:
# Lazy handles on the raw train/test parquet data; nothing is read until .collect().
train_scan = pl.scan_parquet(path + 'train.parquet/')
test_scan = pl.scan_parquet(path + 'test.parquet/')

In [16]:
# Distinct symbol ids seen in train and in test, each as a sorted Python list.
train_symbol_ids_list = sorted(
    train_scan.select('symbol_id').unique().collect()['symbol_id'].to_list()
)
test_symbol_ids_list = sorted(
    test_scan.select('symbol_id').unique().collect()['symbol_id'].to_list()
)

# Sorted union of both lists; this is what get_medians iterates over.
unique_symbol_ids_list = sorted(set(train_symbol_ids_list) | set(test_symbol_ids_list))
unique_symbol_ids_list

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38]

In [24]:
def get_medians(df):
    """Compute the per-symbol median of every feature column.

    Relies on the notebook-level ``unique_symbol_ids_list`` and ``feature_cols``.

    Parameters
    ----------
    df : polars.DataFrame
        Frame containing a ``symbol_id`` column and all columns in ``feature_cols``.

    Returns
    -------
    dict
        ``{str(symbol_id): {feature_name: median}}`` with one entry for every id in
        ``unique_symbol_ids_list``. Note that the outer keys are strings.
    """
    per_symbol_medians = {}
    for sid in tqdm(unique_symbol_ids_list):
        rows_for_symbol = df.filter(pl.col('symbol_id') == sid)
        feature_medians = {}
        for feature_name in feature_cols:
            feature_medians[feature_name] = rows_for_symbol[feature_name].median()
        per_symbol_medians[str(sid)] = feature_medians

    return per_symbol_medians

In [25]:
# Per-symbol feature medians over the full training frame (one filter pass per symbol id).
train_medians = get_medians(train_df)

100%|██████████████████████████████████████████████████████████████████████████████████| 39/39 [00:25<00:00,  1.52it/s]


In [26]:
# Inspect the nested {symbol_id (as str): {feature: median}} mapping.
train_medians

{'0': {'feature_00': 0.3473193049430847,
  'feature_01': 0.007909894920885563,
  'feature_02': 0.3482334017753601,
  'feature_03': 0.3475276231765747,
  'feature_04': 0.002359127625823021,
  'feature_05': -0.04910224676132202,
  'feature_06': -0.009338218718767166,
  'feature_07': -0.03103272244334221,
  'feature_08': 0.05644051730632782,
  'feature_09': 11.0,
  'feature_10': 7.0,
  'feature_11': 76.0,
  'feature_12': -0.21585240960121155,
  'feature_13': -0.23796901106834412,
  'feature_14': -0.22476664185523987,
  'feature_15': -0.31632620096206665,
  'feature_16': -0.283399373292923,
  'feature_17': -0.3105772137641907,
  'feature_18': 0.09238912165164948,
  'feature_19': 0.04509960487484932,
  'feature_20': 0.6573883891105652,
  'feature_21': -0.15391001105308533,
  'feature_22': 0.8330196142196655,
  'feature_23': 0.6250646114349365,
  'feature_24': 0.054774921387434006,
  'feature_25': -0.16233320534229279,
  'feature_26': 0.9890378713607788,
  'feature_27': 1.1365611553192139,
 

In [23]:
# Medians for symbol_id 0 (the keys of train_medians are strings).
# NOTE(review): this cell's execution count (23) is lower than those of the cells that
# define get_medians and build train_medians (24-25), so its saved output may come from
# an earlier version of the function — re-run to refresh.
train_medians['0']

{'feature_78': -0.237059086561203}

In [27]:
# Number of distinct sample-weight values in the training set.
train_df['weight'].n_unique()

50341

In [30]:
# Order the whole training frame from the heaviest to the lightest sample weight.
sorted_df = train_df.sort('weight', descending=True)
print(sorted_df.shape)
sorted_df

(47127338, 84)


date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
957,0,1,10.240419,-0.111762,0.541699,0.197509,0.60798,2.297127,-0.708253,-2.733304,-1.030854,0.137474,11,7,76,-1.20938,0.274233,-0.456101,,-0.677073,,-0.652768,-1.581288,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.755011,-0.596669,1.894601,-0.010734,-0.366429,,0.271573,,,-1.123766,,-1.407963,1.481466,0.461492,-0.703476,0.143205,,-0.167775,,,-0.926009,,-2.113329,1.285392,,0.025718,-0.016277,0.690947,-0.384646,-0.25405,-0.387314,-1.522583,-1.653084,-0.911377,0.282105,-0.505506,-0.990799,0.167858,-0.470706,,,-0.194636,-0.133774,-0.287852,-0.196749,-0.741283
957,1,1,10.240419,0.045592,0.663517,0.569123,0.200289,2.107541,-0.400604,-0.957494,-0.570551,0.070622,11,7,76,-0.905697,0.190957,-0.670279,,-0.444593,,-1.311299,-1.358684,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.640106,-0.273644,2.132119,-0.008013,-0.313746,,0.2231,,,-0.694956,,-1.467145,0.708822,0.387088,-0.693303,0.113265,,-0.571529,,,-0.075534,,-2.367978,0.878476,,-0.178436,-0.056209,0.690947,-0.374019,-0.452783,-0.3876,-1.180055,-1.655294,-0.765336,0.161461,-0.356588,-0.805825,0.048269,-0.571969,,,-0.264392,-0.223558,-0.328832,-0.208607,-0.779799
957,2,1,10.240419,0.946682,0.698162,0.200073,0.297052,1.570207,-0.278999,-0.100121,-0.391649,0.07389,11,7,76,-1.104573,0.135747,-0.519934,,-0.435951,,-1.436819,-1.566505,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.517803,-0.096473,2.093153,-0.019315,-0.30623,,-0.272451,,,-1.00995,,-1.334793,1.253801,0.537276,-0.276118,0.214213,,0.393833,,,0.251851,,-2.07757,0.944253,,0.058492,0.02664,0.690947,-0.453473,-0.34958,-0.395326,-1.10431,-2.144159,-0.965245,0.075267,-0.568569,-1.09512,0.159139,-0.682688,,,-0.277421,-0.265382,-0.34605,-0.178963,-0.763515
957,3,1,10.240419,0.315907,0.801992,-0.414742,0.173344,1.669185,-0.42923,-1.249202,-0.532247,0.120997,11,7,76,-0.814337,0.129616,-0.613119,,-0.642771,,-0.757715,-1.675575,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.865244,-0.090707,1.77257,0.028363,-0.232334,,0.130099,,,-1.338459,,-1.509099,1.500026,0.421293,-0.680312,-0.10092,,-0.440924,,,-0.859091,,-2.479372,1.088899,,-1.139394,-0.418533,0.690947,-0.414845,-0.258394,-0.363882,-1.034853,-1.581279,-1.031972,0.164122,-0.458048,-0.88297,0.072577,-0.623852,,,-0.228645,-0.206837,-0.299041,-0.299685,-0.709867
957,4,1,10.240419,0.181526,1.239259,0.090425,-0.160181,1.921098,-0.662245,-1.71344,-1.358971,0.150708,11,7,76,-1.046015,0.23802,-0.292382,,-0.459641,-0.868039,-1.321723,-1.838364,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.085544,-0.626138,1.830449,0.045586,-0.246236,,-0.643497,,,-0.840332,,-1.660702,0.623261,0.393863,-1.149277,-0.201798,,-0.525527,,,-0.13862,,-2.650985,1.049739,,-1.158757,-0.585836,0.690947,-0.372317,-0.338869,-0.402958,-1.098662,-0.902667,-0.890817,0.196073,-0.41206,-1.032269,0.152,-0.331594,,,-0.268855,-0.329539,-0.26388,-0.252379,-0.529056
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
572,844,21,0.149967,-0.551016,0.219968,-0.266276,0.26196,-0.921703,-0.612699,-2.800681,-1.809055,-2.389769,34,4,214,1.95667,2.459976,2.514925,-0.029653,-0.07606,-0.23378,2.446756,2.217652,-2.265971,-0.337342,-2.156907,-1.686209,-0.381813,-0.734935,0.371254,-1.677513,-1.866786,0.128419,0.190471,-0.295125,-1.531707,0.728662,-1.997347,-0.874166,-0.776611,0.845105,1.200227,-0.279795,0.870752,0.159551,1.592036,1.507372,1.597248,-0.320082,-0.619713,-0.22331,-0.058419,0.052234,-0.759772,0.141514,-0.508191,0.75854,-0.650356,0.80465,0.647422,-0.785344,-0.446973,-0.257175,-0.050606,1.420116,1.387356,1.401032,1.394435,0.442947,1.275306,2.30792,2.585586,3.443758,1.012033,0.632749,0.234233,-0.558215,-0.799003,-1.756907,-2.012227,-1.583463,-0.956693,-0.324382
572,845,21,0.149967,-0.087373,0.20099,-0.272012,-0.643299,-1.129103,-0.348704,-0.95436,-0.685615,-2.379219,34,4,214,2.252772,0.858005,2.194028,-0.033232,-0.272845,-0.2355,2.145804,1.539531,-2.265971,-0.337342,-2.156907,-1.686209,-0.381813,-0.734935,0.371254,-1.677513,-1.866786,0.128419,0.190471,-0.295125,-1.451935,0.883702,-1.510782,-1.658086,0.37841,0.581518,0.976628,-0.308339,1.073546,0.104037,1.418401,1.169612,1.093266,-0.097671,-0.319596,-0.258461,-0.038022,0.068841,-0.556162,-0.96786,-0.833888,0.353032,0.482764,1.121152,0.357589,-1.249307,-0.569501,0.155415,0.06184,1.420116,1.145647,1.275237,1.513812,0.671959,0.735505,2.769902,0.886521,3.668134,0.949382,0.04926,0.19528,-0.556684,-0.587505,-1.68637,-1.750495,-0.725666,-0.929144,-0.123372
572,846,21,0.149967,-0.463196,-0.210953,-0.390834,-0.084623,0.2811,-0.388404,-0.445499,-1.06537,-2.382157,34,4,214,1.482859,0.914056,2.378988,-0.030009,-0.404089,-0.289309,1.68474,1.571505,-2.265971,-0.337342,-2.156907,-1.686209,-0.381813,-0.734935,0.371254,-1.677513,-1.866786,0.128419,0.190471,-0.295125,-1.265049,0.700265,-1.874117,-1.410517,0.712778,0.775024,0.966389,0.082255,1.303716,0.92631,1.492,1.413478,1.589726,0.00705,-0.256196,-0.142491,0.380472,0.347837,-0.218426,-1.275196,-0.214466,0.774408,-0.202671,1.102191,-0.006744,-0.702485,-0.469381,0.028403,0.073955,1.420116,0.852876,1.730469,1.38721,0.838962,1.063359,2.300827,0.238356,2.259255,1.059295,0.860549,0.37428,-0.457053,-0.586391,-1.830631,-2.003999,-0.822576,-1.10617,-0.1878
572,847,21,0.149967,-1.091798,0.137555,-0.258189,-0.828415,0.223384,-0.363649,-0.237761,-1.172222,-1.944605,34,4,214,1.76824,1.878567,2.766998,-0.040308,-0.600277,-0.312141,1.826355,1.674618,-2.265971,-0.337342,-2.156907,-1.686209,-0.381813,-0.734935,0.371254,-1.677513,-1.866786,0.128419,0.190471,-0.295125,-1.256181,0.498872,-1.240031,-1.733059,0.758456,0.769268,1.183452,-0.187827,1.171909,0.580495,2.457778,0.892859,1.98162,-0.136776,0.0725,-0.092755,0.169506,0.257092,-0.300471,-1.953463,-0.293386,0.462829,0.286417,1.204375,-0.608182,-0.825013,-0.488694,-0.200829,-0.034363,1.420116,0.646206,1.267465,1.705583,0.881047,1.040308,3.272652,-0.032657,2.085831,0.940176,2.718485,0.685388,-0.818328,-0.91364,-1.725063,-1.458995,-1.580992,-1.245342,-0.269456


In [31]:
# Rows whose weight equals the maximum weight in the training frame.
weight_max_df = train_df.filter(
    pl.col('weight').eq(pl.col('weight').max())
)
print(weight_max_df.shape)
weight_max_df

(968, 84)


date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
957,0,1,10.240419,-0.111762,0.541699,0.197509,0.60798,2.297127,-0.708253,-2.733304,-1.030854,0.137474,11,7,76,-1.20938,0.274233,-0.456101,,-0.677073,,-0.652768,-1.581288,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.755011,-0.596669,1.894601,-0.010734,-0.366429,,0.271573,,,-1.123766,,-1.407963,1.481466,0.461492,-0.703476,0.143205,,-0.167775,,,-0.926009,,-2.113329,1.285392,,0.025718,-0.016277,0.690947,-0.384646,-0.25405,-0.387314,-1.522583,-1.653084,-0.911377,0.282105,-0.505506,-0.990799,0.167858,-0.470706,,,-0.194636,-0.133774,-0.287852,-0.196749,-0.741283
957,1,1,10.240419,0.045592,0.663517,0.569123,0.200289,2.107541,-0.400604,-0.957494,-0.570551,0.070622,11,7,76,-0.905697,0.190957,-0.670279,,-0.444593,,-1.311299,-1.358684,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.640106,-0.273644,2.132119,-0.008013,-0.313746,,0.2231,,,-0.694956,,-1.467145,0.708822,0.387088,-0.693303,0.113265,,-0.571529,,,-0.075534,,-2.367978,0.878476,,-0.178436,-0.056209,0.690947,-0.374019,-0.452783,-0.3876,-1.180055,-1.655294,-0.765336,0.161461,-0.356588,-0.805825,0.048269,-0.571969,,,-0.264392,-0.223558,-0.328832,-0.208607,-0.779799
957,2,1,10.240419,0.946682,0.698162,0.200073,0.297052,1.570207,-0.278999,-0.100121,-0.391649,0.07389,11,7,76,-1.104573,0.135747,-0.519934,,-0.435951,,-1.436819,-1.566505,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.517803,-0.096473,2.093153,-0.019315,-0.30623,,-0.272451,,,-1.00995,,-1.334793,1.253801,0.537276,-0.276118,0.214213,,0.393833,,,0.251851,,-2.07757,0.944253,,0.058492,0.02664,0.690947,-0.453473,-0.34958,-0.395326,-1.10431,-2.144159,-0.965245,0.075267,-0.568569,-1.09512,0.159139,-0.682688,,,-0.277421,-0.265382,-0.34605,-0.178963,-0.763515
957,3,1,10.240419,0.315907,0.801992,-0.414742,0.173344,1.669185,-0.42923,-1.249202,-0.532247,0.120997,11,7,76,-0.814337,0.129616,-0.613119,,-0.642771,,-0.757715,-1.675575,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.865244,-0.090707,1.77257,0.028363,-0.232334,,0.130099,,,-1.338459,,-1.509099,1.500026,0.421293,-0.680312,-0.10092,,-0.440924,,,-0.859091,,-2.479372,1.088899,,-1.139394,-0.418533,0.690947,-0.414845,-0.258394,-0.363882,-1.034853,-1.581279,-1.031972,0.164122,-0.458048,-0.88297,0.072577,-0.623852,,,-0.228645,-0.206837,-0.299041,-0.299685,-0.709867
957,4,1,10.240419,0.181526,1.239259,0.090425,-0.160181,1.921098,-0.662245,-1.71344,-1.358971,0.150708,11,7,76,-1.046015,0.23802,-0.292382,,-0.459641,-0.868039,-1.321723,-1.838364,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,,,-0.085544,-0.626138,1.830449,0.045586,-0.246236,,-0.643497,,,-0.840332,,-1.660702,0.623261,0.393863,-1.149277,-0.201798,,-0.525527,,,-0.13862,,-2.650985,1.049739,,-1.158757,-0.585836,0.690947,-0.372317,-0.338869,-0.402958,-1.098662,-0.902667,-0.890817,0.196073,-0.41206,-1.032269,0.152,-0.331594,,,-0.268855,-0.329539,-0.26388,-0.252379,-0.529056
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
957,963,1,10.240419,0.443445,0.752214,0.276545,-0.096122,-1.529969,0.056257,-3.539673,-1.470971,-0.728129,11,7,76,0.107186,0.021152,-0.001472,-0.632178,-0.571416,-0.699716,-1.077018,2.469527,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,0.108423,-0.552554,-0.151291,0.216959,-1.751433,-1.458245,-1.590026,0.250108,-0.712731,-1.063856,-0.75711,-0.982332,-0.811626,1.626265,-0.01778,0.828262,-0.875282,-0.552392,0.807897,-0.731431,-0.96959,-0.332158,-0.577289,-0.980529,1.806795,-0.071918,0.724698,-0.856261,-0.181972,0.690947,-0.334547,-0.325586,-0.377528,-1.246756,2.188196,0.028994,0.063568,-0.025097,0.257951,-0.0185,0.031826,-0.411507,-0.10491,-0.294776,-0.049972,-0.258121,-0.053828,-0.240253
957,964,1,10.240419,0.655231,0.567509,-0.188986,0.100081,-1.474321,0.111132,-1.925912,-1.366499,-0.96886,11,7,76,0.101647,0.318253,0.081409,-0.718489,-0.517661,-0.585977,-0.608632,2.588153,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,-0.25903,-0.745839,-0.402004,-0.063103,-0.97876,-1.429877,-1.112029,0.210597,0.140108,-1.129258,-0.573366,0.028808,-1.035378,1.353414,0.039351,0.518908,-0.164104,-0.332156,0.113875,0.497264,-0.014083,-0.738109,1.360997,-0.907521,1.645756,0.121964,1.295786,-0.084932,-0.179012,0.690947,-0.32414,-0.388152,-0.384441,-0.70093,2.880427,0.052811,0.083414,0.010913,0.318776,0.449095,0.157329,-0.318287,-0.149907,-0.198751,-0.05336,-0.277348,-0.098026,-0.159647
957,965,1,10.240419,0.252774,-0.116959,0.386532,0.240184,-2.53997,0.063919,-0.655256,-1.485479,-0.801469,11,7,76,0.148086,-0.040067,0.027223,-0.63073,-0.538008,-0.731834,-0.273187,2.453068,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,-0.384141,-1.605115,-0.071394,0.509293,-1.375531,-1.401514,-1.715178,0.61165,0.044726,-0.589326,-1.572394,-0.481316,-1.318698,1.196813,-0.205978,0.28859,-0.415069,-0.69819,0.11839,-0.317769,-0.415709,-0.714154,0.373688,-0.771894,1.235244,0.590063,0.878521,-0.259158,-0.185412,0.690947,-0.405841,-0.328342,-0.384995,-0.776024,2.377786,0.029632,-0.054432,-0.007284,0.259862,0.00386,0.078477,-0.309196,-0.121923,-0.366908,-0.048572,-0.238305,-0.066983,-0.324444
957,966,1,10.240419,-0.108106,0.81553,-0.335174,0.441192,-2.623096,-0.005555,-0.959441,-1.5255,-0.908088,11,7,76,0.112896,0.112087,0.058127,-0.563693,-0.363605,-0.583898,-0.993849,2.100822,2.034377,1.134396,2.3325,1.487128,1.646361,1.631919,0.316101,1.376159,1.581527,-0.197882,-0.760353,1.502041,-0.400567,-0.519503,0.339676,-0.04587,-1.469132,-1.243873,-1.722833,0.317777,0.802461,0.225503,-0.296764,-0.069401,-0.17263,0.721774,-0.130912,0.453909,-0.145018,-0.486786,0.256794,0.352037,-0.067409,-0.766896,-0.399781,-0.398337,1.597465,-0.140645,1.239611,-0.337173,-0.403232,0.690947,-0.431717,-0.255214,-0.36157,-0.658,2.657073,0.05275,0.094435,0.038374,0.26448,0.027604,0.061516,-0.281304,-0.101449,-0.242508,-0.051756,-0.17902,-0.050198,-0.231988


In [5]:
# Approximate in-memory footprint of train_df in gigabytes (estimated bytes / 1e9).
train_df.estimated_size() / 1e9

15.282841602

In [6]:
# Output directory for this run's artefacts (e.g. lgb_params.csv).
models_path = path + 'my_folder/models/20250110_01/'
# exist_ok=True makes the cell safe to re-run without a separate existence check.
os.makedirs(models_path, exist_ok=True)

In [7]:
# Directory of the earlier run whose saved hyperparameters are reused below.
previous_models_path = path + 'my_folder/models/20250109_03/'

In [8]:
# Hyperparameters saved by the previous run; lgb_sliding_window reads row 0 of this frame.
# NOTE(review): the displayed frame contains an 'Unnamed: 0' column, which suggests the
# CSV was written together with its index. Harmless here because columns are read by name.
lgb_params_df = pd.read_csv(previous_models_path + 'lgb_params.csv')
lgb_params_df

Unnamed: 0,val_window_size,training_window_size,fraction,learning_rate,max_depth,min_data_in_leaf,num_leaves,min_gain_to_split,lambda_l1,lambda_l2,feature_fraction
0,60,618,0.152297,0.04979,18,102,8995,0.237934,7.797875,1046.995871,0.94769


In [12]:
def _lgb_split_windows(sample_df, date_id, val_window_size, training_window_size):
    """Build the (training, validation, test) frames for one test ``date_id``.

    Layout when enough history exists::

        [ training_window_size dates ][ val_window_size dates ][ date_id ]

    When there is not enough history before ``date_id`` the missing part of a
    window is filled from the earliest dates in the sample instead (excluding the
    test rows, and for the training window also the validation rows).

    NOTE: in that warm-up regime the windows contain rows whose ``date_id`` is
    later than the test date, so the scores for those early dates are not
    strictly out-of-sample in time.
    """
    key_cols = ['date_id', 'time_id', 'symbol_id']

    test_date_id_df = sample_df.filter(pl.col('date_id') == date_id)

    val_date_id_cut_lower = date_id - val_window_size
    if val_date_id_cut_lower < 0:
        # Warm-up: take the first dates of the sample, minus the test rows.
        val_window_df = (
            sample_df
            .filter(pl.col('date_id') <= val_window_size)
            .join(test_date_id_df, on=key_cols, how='anti')
        )
    else:
        val_window_df = sample_df.filter(
            (pl.col('date_id') >= val_date_id_cut_lower) & (pl.col('date_id') < date_id)
        )

    training_date_id_cut_lower = val_date_id_cut_lower - training_window_size
    if training_date_id_cut_lower < 0:
        # Warm-up: take the first dates of the sample, minus test and validation rows.
        training_window_df = (
            sample_df
            .filter(pl.col('date_id') <= val_window_size + training_window_size)
            .join(test_date_id_df, on=key_cols, how='anti')
            .join(val_window_df, on=key_cols, how='anti')
        )
    else:
        training_window_df = sample_df.filter(
            (pl.col('date_id') >= training_date_id_cut_lower)
            & (pl.col('date_id') < val_date_id_cut_lower)
        )

    return training_window_df, val_window_df, test_date_id_df


def _lgb_features_target_weights(window_df):
    """Split a window frame into (features, target, sample weights) as pandas objects."""
    features = (
        window_df
        .drop(['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6'])
        .select(pl.all().shrink_dtype())
        .to_pandas()
    )
    target = window_df['responder_6'].to_pandas()
    weights = window_df['weight'].to_pandas()
    return features, target, weights


def lgb_sliding_window(train_data, seed=None):
    """Fit and score one LightGBM model per ``date_id`` with sliding windows.

    A random row sample of ``train_data`` is drawn once. For every ``date_id`` in
    that sample a fresh ``LGBMRegressor`` is trained on the training window, early
    stopped against the validation window, and scored on the rows of ``date_id``
    with the sample-weighted R2. Window sizes, the sampling fraction and the model
    hyperparameters are read from row 0 of the notebook-level ``lgb_params_df``.

    Parameters
    ----------
    train_data : polars.DataFrame
        Must contain ``date_id``, ``time_id``, ``symbol_id``, ``weight``,
        ``responder_6`` and the feature columns.
    seed : int, optional
        Seed for the row sampling. ``None`` (the default) keeps the sampling
        unseeded; pass an integer to make a run reproducible.

    Returns
    -------
    None
        Results are printed: the number of date ids once, then the date id and
        its test score for every fitted model.
    """
    # Sample once; the date list comes from the sample so every date has test rows.
    sample_df = train_data.sample(fraction=lgb_params_df['fraction'][0], seed=seed)
    unique_date_ids = sorted(sample_df['date_id'].unique())
    print(len(unique_date_ids))

    # Loop-invariant settings, read once from the saved hyperparameter row.
    val_window_size = int(lgb_params_df['val_window_size'][0])
    training_window_size = int(lgb_params_df['training_window_size'][0])

    model_params = {
        'verbosity': -1,
        'device': 'gpu',
        'early_stopping_round': 20,
        'learning_rate': lgb_params_df['learning_rate'][0],
        'max_depth': lgb_params_df['max_depth'][0],
        'min_data_in_leaf': lgb_params_df['min_data_in_leaf'][0],
        'num_leaves': lgb_params_df['num_leaves'][0],
        'min_gain_to_split': lgb_params_df['min_gain_to_split'][0],
        'lambda_l1': lgb_params_df['lambda_l1'][0],
        'lambda_l2': lgb_params_df['lambda_l2'][0],
        'feature_fraction': lgb_params_df['feature_fraction'][0],
    }

    for date_id in unique_date_ids:
        training_window_df, val_window_df, test_date_id_df = _lgb_split_windows(
            sample_df, date_id, val_window_size, training_window_size
        )

        # A model cannot be fitted or early-stopped on an empty window (possible
        # when train_data spans fewer dates than the configured windows).
        if training_window_df.height == 0 or val_window_df.height == 0:
            print('date_id is:', date_id)
            print('Skipped: empty training or validation window')
            continue

        X_train, y_train, weights_train = _lgb_features_target_weights(training_window_df)
        X_val, y_val, weights_val = _lgb_features_target_weights(val_window_df)
        X_test, y_test, weights_test = _lgb_features_target_weights(test_date_id_df)

        # n_estimators is only an upper bound; 'early_stopping_round' is expected
        # to end training well before it is reached.
        model = LGBMRegressor(**model_params, n_estimators=100000)

        # The training set is passed as an eval set too, alongside the validation set.
        model.fit(
            X_train,
            y_train,
            sample_weight=weights_train,
            eval_set=[(X_train, y_train), (X_val, y_val)],
            eval_sample_weight=[weights_train, weights_val],
        )

        test_preds = model.predict(X_test)

        # NOTE(review): sklearn's r2_score centres the target on its (weighted) mean;
        # confirm this is the intended definition of the evaluation metric.
        test_score = r2_score(y_test, test_preds, sample_weight=weights_test)

        print('date_id is:', date_id)
        print('Test Weighted R2 score is:', test_score)

In [13]:
# Walk-forward evaluation over the training frame: one LightGBM fit per sampled date_id.
# NOTE: the saved output of this cell ends in a KeyboardInterrupt, i.e. the recorded run
# was stopped manually and covers only the first few date ids.
lgb_sliding_window(train_df)

1699
date_id is: 0
Test Weighted R2 score is: -0.01813452461849452
date_id is: 1
Test Weighted R2 score is: 0.021921658541187772
date_id is: 2
Test Weighted R2 score is: -0.01900653235990246
date_id is: 3
Test Weighted R2 score is: -0.013140583556615404
date_id is: 4
Test Weighted R2 score is: 0.02271487988410248
date_id is: 5
Test Weighted R2 score is: 0.0264420889255238
date_id is: 6
Test Weighted R2 score is: 0.031511946097709354
date_id is: 7
Test Weighted R2 score is: -0.01184603153920083
date_id is: 8
Test Weighted R2 score is: -0.019137489701575694
date_id is: 9
Test Weighted R2 score is: -0.005798053224977062
date_id is: 10
Test Weighted R2 score is: 0.002475282178731053
date_id is: 11
Test Weighted R2 score is: 0.01447552228673521
date_id is: 12
Test Weighted R2 score is: 0.021866309699152375
date_id is: 13
Test Weighted R2 score is: 0.007273602883855035
date_id is: 14
Test Weighted R2 score is: 0.045534937331239034
date_id is: 15
Test Weighted R2 score is: 0.00418688038349135

KeyboardInterrupt: 

In [None]:
# NOTE(review): `lgb_study` is not defined anywhere in the visible notebook — presumably an
# Optuna study created in a cell that is no longer present; confirm before re-running.
# One slice plot per tuned hyperparameter.
for param in lgb_study.best_params:
    fig = plot_slice(lgb_study, params=[param])
    fig.show()

In [None]:
# Relative importance of each tuned hyperparameter.
# NOTE(review): `lgb_study` is not defined in the visible notebook — confirm where it comes from.
plot_param_importances(lgb_study)

In [None]:
# Hyperparameters of the best trial.
# NOTE(review): `lgb_study` is not defined in the visible notebook — confirm where it comes from.
lgb_study.best_params

In [None]:
# Objective value of the best trial.
# NOTE(review): `lgb_study` is not defined in the visible notebook — confirm where it comes from.
lgb_study.best_value

In [None]:
# Print the best hyperparameters, one "name value" pair per line.
# NOTE(review): `lgb_study` is not defined in the visible notebook — confirm where it comes from.
for param_name, param_value in lgb_study.best_params.items():
    print(param_name, param_value)

In [None]:
# One-row frame with a column per best hyperparameter.
# NOTE(review): this rebinds `lgb_params_df`, the same name lgb_sliding_window reads its
# settings from, and `lgb_study` is not defined in the visible notebook — confirm both
# before re-running cells out of order.
lgb_params_df = pd.DataFrame([lgb_study.best_params])

In [None]:
# Inspect the hyperparameter frame before saving it.
lgb_params_df

In [None]:
# Persist the hyperparameters in this run's model directory; index=False keeps the
# row index out of the CSV.
lgb_params_df.to_csv(models_path + 'lgb_params.csv', index=False)