In [1]:
import pandas as pd
import polars as pl
import numpy as np
import os
import gc
import seaborn as sns
from tqdm import tqdm
from sklearn.model_selection import KFold, StratifiedKFold
import xgboost as xgb
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor, log_evaluation, record_evaluation
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
#from sklearn.impute import IterativeImputer
import pickle
import optuna
from optuna.visualization import plot_slice, plot_param_importances
import shap
import random

# Garbage collection is enabled by default; the call just makes that explicit.
gc.enable()

# pandas display: show up to 500 columns and never truncate cell contents.
# (A preceding `max_columns = None` assignment was removed: it was overridden
# immediately by the 500 below and had no effect.)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_colwidth', None)

# polars display: show every column and long strings when printing frames.
pl.Config.set_tbl_cols(-1)
pl.Config.set_fmt_str_lengths(10000)

# Silence optuna's per-trial INFO logging; warnings and errors still show.
optuna.logging.set_verbosity(optuna.logging.WARNING)

In [2]:
# Root folder of the competition data. Defaults to the original local location;
# set the JANE_STREET_DATA_DIR environment variable to run on another machine.
path = os.environ.get('JANE_STREET_DATA_DIR', 'I:/Kaggle/jane-street-real-time-market-data-forecasting/')
# Later cells build file names with `path + '...'`, so guarantee a trailing separator.
if not path.endswith(('/', '\\')):
    path += '/'

In [3]:
os.listdir(path)

['features.csv',
 'imputed_train_ffill.parquet',
 'kaggle_evaluation',
 'lags.parquet',
 'my_folder',
 'responders.csv',
 'sample_submission.csv',
 'team_folder',
 'test.parquet',
 'top_100000_rows_sorted_by_weight_descending.parquet',
 'top_10000_rows_sorted_by_weight_descending.parquet',
 'train.parquet']

In [12]:
# Only responder_6 is kept (it is the regression target further down); the other
# responder columns and the partition key are dropped right after loading.
unused_columns = [f'responder_{idx}' for idx in range(9) if idx != 6] + ['partition_id']

# Kept as a single chain so the un-shrunk intermediate frame is not held in a variable.
train_df = (
    pl.read_parquet(path + 'train.parquet/')
    .drop(unused_columns)
    .select(pl.all().shrink_dtype())
)
print(train_df.shape)
train_df

(47127338, 84)


date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,1,3.889038,,,,,,0.851033,0.242971,0.2634,-0.891687,11,7,76,-0.883028,0.003067,-0.744703,,-0.169586,,-1.335938,-1.707803,0.91013,,1.636431,1.522133,-1.551398,-0.229627,,,1.378301,-0.283712,0.123196,,,,0.28118,0.269163,0.349028,-0.012596,-0.225932,,-1.073602,,,-0.181716,,,,0.564021,2.088506,0.832022,,0.204797,,,-0.808103,,-2.037683,0.727661,,-0.989118,-0.345213,-1.36224,,,,,,-1.251104,-0.110252,-0.491157,-1.02269,0.152241,-0.659864,,,-0.261412,-0.211486,-0.335556,-0.281498,0.775981
0,0,7,1.370613,,,,,,0.676961,0.151984,0.192465,-0.521729,11,7,76,-0.865307,-0.225629,-0.582163,,0.317467,,-1.250016,-1.682929,1.412757,,0.520378,0.744132,-0.788658,0.641776,,,0.2272,0.580907,1.128879,,,,-1.512286,-1.414357,-1.823322,-0.082763,-0.184119,,,,,,,,,-10.835207,-0.002704,-0.621836,,1.172836,,,-1.625862,,-1.410017,1.063013,,0.888355,0.467994,-1.36224,,,,,,-1.065759,0.013322,-0.592855,-1.052685,-0.393726,-0.741603,,,-0.281207,-0.182894,-0.245565,-0.302441,0.703665
0,0,9,2.285698,,,,,,1.056285,0.187227,0.249901,-0.77305,11,7,76,-0.675719,-0.199404,-0.586798,,-0.814909,,-1.296782,-2.040234,0.639589,,1.597359,0.657514,-1.350148,0.364215,,,-0.017751,-0.317361,-0.122379,,,,-0.320921,-0.95809,-2.436589,0.070999,-0.245239,,,,,,,,,-1.420632,-3.515137,-4.67776,,0.535897,,,-0.72542,,-2.29417,1.764551,,-0.120789,-0.063458,-1.36224,,,,,,-0.882604,-0.072482,-0.617934,-0.86323,-0.241892,-0.709919,,,0.377131,0.300724,-0.106842,-0.096792,2.109352
0,0,10,0.690606,,,,,,1.139366,0.273328,0.306549,-1.262223,42,5,150,-0.694008,3.004091,0.114809,,-0.251882,,-1.902009,-0.979447,0.241165,,-0.392359,-0.224699,-2.129397,-0.855287,,,0.404142,-0.578156,0.105702,,,,0.544138,-0.087091,-1.500147,-0.201288,-0.038042,,,,,,,,,0.382074,2.669135,0.611711,,2.413415,,,1.313203,,-0.810125,2.939022,,3.988801,1.834661,-1.36224,,,,,,-0.697595,1.074309,-0.206929,-0.530602,4.765215,0.571554,,,-0.226891,-0.251412,-0.215522,-0.296244,1.114137
0,0,14,0.44057,,,,,,0.9552,0.262404,0.344457,-0.613813,44,3,16,-0.947351,-0.030018,-0.502379,,0.646086,,-1.844685,-1.58656,-0.182024,,-0.969949,-0.673813,-1.282132,-1.399894,,,0.043815,-0.320225,-0.031713,,,,-0.08842,-0.995003,-2.635336,-0.196461,-0.618719,,,,,,,,,-2.0146,-2.321076,-3.711265,,1.253902,,,0.476195,,-0.771732,2.843421,,1.379815,0.411827,-1.36224,,,,,,-0.948601,-0.136814,-0.447704,-1.141761,0.099631,-0.661928,,,3.678076,2.793581,2.61825,3.418133,-3.57282
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1698,967,34,3.242493,2.52516,-0.721981,2.544025,2.477615,0.417557,0.785812,1.117796,2.199436,0.415427,42,5,150,0.804403,1.157257,1.031543,-0.671189,-0.3286,-0.486132,1.730176,-0.006173,-0.001144,-0.213062,0.932618,1.367338,-0.238197,-0.692615,-0.121163,1.090798,1.444294,-0.675626,-1.013264,-0.242888,3.427639,-0.958278,3.139836,3.416278,-1.655316,-0.59944,-0.932876,2.493458,0.969462,1.102016,0.158982,-0.496177,0.036177,1.309866,0.828025,1.577955,1.040802,1.255398,2.577441,0.057455,0.953005,1.377051,-0.396358,0.520262,1.179617,1.127657,2.231928,0.614652,2.412886,-1.101531,-0.384833,-0.275818,-0.40804,2.427115,-0.108427,0.739734,0.830205,0.366287,1.33325,1.075499,1.798264,-0.183443,-0.190222,0.234211,0.347142,-0.044463,0.016936,-0.132337
1698,967,35,1.079139,1.857906,-0.790646,2.745439,2.339877,0.845065,0.65137,1.180301,1.966379,0.321543,25,7,195,-0.075294,-0.152726,-0.20417,-0.421137,0.21708,-0.258775,1.874978,0.19988,-0.199219,-0.125619,-1.004547,-0.051933,0.450905,0.009246,0.164127,-0.939974,-1.143421,-0.320071,-0.379835,-0.142429,3.862469,-1.451786,3.477489,2.861663,0.763459,0.075972,-0.119677,0.626035,0.148815,0.653281,0.059313,-0.845099,0.098528,0.409564,-0.675728,-0.011334,0.930534,0.83198,0.808955,0.219276,-0.315776,0.687755,-1.189577,0.180146,-0.175486,-1.60435,-0.209283,0.249847,0.288816,-1.101531,-0.343868,-0.253991,-0.278832,2.050639,-0.059506,-0.029396,-0.101381,-0.187759,-0.180839,-0.0861,-0.153405,-0.196077,-0.175292,1.04578,0.739733,0.03372,0.05086,-0.249584
1698,967,36,1.033172,2.515527,-0.672298,2.28925,2.521592,0.255077,0.919892,1.172018,2.180496,0.24846,49,7,297,1.026715,-0.096892,0.224309,-0.528109,-0.704952,-0.704818,2.312482,0.32804,-0.108193,,-0.945684,-0.244173,0.205989,-0.357343,,,-1.11075,-0.580242,-0.400568,,2.397877,-0.637258,3.260638,3.046786,0.440965,0.234842,-0.17558,1.022406,-0.500069,2.071033,0.413488,-0.450016,-0.156616,-0.253755,-0.769588,0.066086,0.047826,1.713707,0.772772,-0.549192,1.338474,0.933568,0.032978,-0.519118,-0.290343,-0.806786,0.106295,0.183461,1.830421,-1.101531,-0.341991,-0.249132,-0.34365,2.251358,0.601888,1.035051,-0.283241,0.107244,0.86016,0.024223,0.374852,-0.220933,-0.161584,0.032771,0.036888,0.168908,0.152333,-0.065355
1698,967,37,1.243116,2.663298,-0.889112,2.313155,3.101428,0.324454,0.618944,1.185663,1.599724,0.319719,34,4,214,0.759314,0.284057,0.41716,-0.611075,-0.513717,-0.891423,1.84994,0.406756,-1.608196,-0.252663,-0.271574,-0.051405,0.098146,-0.653961,0.173676,-0.016497,-0.404509,-0.577262,-0.731429,-0.21646,3.018564,-0.472061,3.13922,3.065858,0.842925,0.053283,-0.074403,0.500129,0.08263,0.336223,0.643934,-0.422367,-0.418195,0.203037,-0.702278,0.543305,-0.195764,0.693364,0.953293,0.352567,0.471775,1.876459,-0.143377,0.845516,0.301135,-0.395703,0.738038,-0.04124,1.270645,-1.101531,-0.358106,-0.141883,-0.255192,2.489247,0.537652,0.982107,-0.158009,0.137389,0.478357,0.782692,0.581421,-0.106056,-0.111017,0.163867,0.169331,-0.037563,-0.029483,-0.148711


In [5]:
# Feature columns are recognised purely by name: any column containing 'feature'.
feature_cols = list(filter(lambda name: 'feature' in name, train_df.columns))
feature_cols

['feature_00',
 'feature_01',
 'feature_02',
 'feature_03',
 'feature_04',
 'feature_05',
 'feature_06',
 'feature_07',
 'feature_08',
 'feature_09',
 'feature_10',
 'feature_11',
 'feature_12',
 'feature_13',
 'feature_14',
 'feature_15',
 'feature_16',
 'feature_17',
 'feature_18',
 'feature_19',
 'feature_20',
 'feature_21',
 'feature_22',
 'feature_23',
 'feature_24',
 'feature_25',
 'feature_26',
 'feature_27',
 'feature_28',
 'feature_29',
 'feature_30',
 'feature_31',
 'feature_32',
 'feature_33',
 'feature_34',
 'feature_35',
 'feature_36',
 'feature_37',
 'feature_38',
 'feature_39',
 'feature_40',
 'feature_41',
 'feature_42',
 'feature_43',
 'feature_44',
 'feature_45',
 'feature_46',
 'feature_47',
 'feature_48',
 'feature_49',
 'feature_50',
 'feature_51',
 'feature_52',
 'feature_53',
 'feature_54',
 'feature_55',
 'feature_56',
 'feature_57',
 'feature_58',
 'feature_59',
 'feature_60',
 'feature_61',
 'feature_62',
 'feature_63',
 'feature_64',
 'feature_65',
 'feature_

In [6]:
# Lazy handles on the train and test parquet folders; data is only read on .collect().
train_scan, test_scan = (
    pl.scan_parquet(path + folder) for folder in ('train.parquet/', 'test.parquet/')
)

In [7]:
# Distinct symbol ids seen in each split, sorted ascending.
train_symbol_ids_list = train_scan.select('symbol_id').unique().collect()['symbol_id'].to_list()
train_symbol_ids_list.sort()
test_symbol_ids_list = test_scan.select('symbol_id').unique().collect()['symbol_id'].to_list()
test_symbol_ids_list.sort()

# Sorted union of the ids from both splits.
unique_symbol_ids_list = sorted(set(train_symbol_ids_list) | set(test_symbol_ids_list))
unique_symbol_ids_list

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38]

In [10]:
def bfill_impute(df):
    """Fill nulls per symbol: forward fill first, then backward fill the rest.

    Despite the name, a forward fill runs before the backward fill. Each
    symbol's rows are put in (date_id, time_id) order before filling, so the
    result does not depend on the row order of the input. Every column of
    ``df`` is filled, not only the feature columns.

    Note: the backward pass copies values from later rows into earlier ones,
    which is what fills nulls at the start of a symbol's history.

    Parameters
    ----------
    df : pl.DataFrame
        Frame containing at least ``symbol_id``, ``date_id`` and ``time_id``.

    Returns
    -------
    pl.DataFrame
        A new frame sorted by (date_id, time_id, symbol_id). Rows whose
        ``symbol_id`` is null are not part of the result.
    """
    # Use the symbols actually present in `df` rather than a separately built
    # id list, so rows of a symbol missing from that list are never dropped.
    symbol_ids = df['symbol_id'].unique().drop_nulls().sort().to_list()

    results = []
    for symbol_id in tqdm(symbol_ids):
        symbol_id_df = (
            df.filter(pl.col("symbol_id") == symbol_id)
            # Forward/backward fill is positional: enforce temporal order first.
            .sort(by=['date_id', 'time_id'])
            .fill_null(strategy="forward")
            .fill_null(strategy="backward")
        )
        results.append(symbol_id_df)

    if not results:
        # No symbols to process (e.g. empty input): nothing to concatenate.
        return df.clone()

    # Combine the per-symbol frames and restore the global ordering.
    result_df = pl.concat(results).sort(by=['date_id', 'time_id', 'symbol_id'])
    return result_df

In [13]:
# Replaces the raw `train_df` loaded earlier with its imputed version; all later cells use the filled frame.
# NOTE(review): the output saved under this cell shows a progress bar stuck at 0/39 and frames that
# still contain nulls — it looks like it comes from an earlier debugging run; re-run to refresh it.
train_df = bfill_impute(train_df)

  0%|                                                                                           | 0/39 [00:00<?, ?it/s]

date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
1,0,0,1.749479,,,,,,0.053447,2.192887,1.160708,-0.00983,11,7,76,-0.819115,1.403962,-0.098782,,0.340309,,-1.577595,-1.237946,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,1.48572,0.760141,2.145851,-0.021399,-0.016591,,,,,,,,,3.88576,7.287179,4.877959,,2.20351,,,-0.850504,,-1.543995,1.741718,,2.035935,0.892778,-1.081097,,,,,,-0.734145,0.276101,-0.576441,-0.799783,4.812901,0.393107,,,-0.226359,-0.327119,-0.31504,-0.385573,2.337418
1,1,0,1.749479,,,,,,-0.029047,0.847413,0.769509,0.009426,11,7,76,-0.67729,2.703819,0.307246,,0.361542,,-1.508338,-1.414294,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,1.272551,0.986756,2.769935,0.028861,0.036394,,,,,,,,,3.062142,1.634271,3.670358,,1.121471,,,-1.338483,,-1.557803,1.283185,,-0.10197,0.154968,-1.081097,,,,,,-0.828776,3.034146,0.070293,-0.49952,3.384181,0.708746,,,-0.286174,-0.187674,-0.354506,-0.302946,2.492198
1,2,0,1.749479,,,,,,0.017732,0.727187,1.153705,-0.0056,11,7,76,-0.587609,3.081821,0.2564,,0.19876,,-1.091023,-1.343286,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,0.09876,1.217269,2.32622,0.054985,0.054874,,,,,,,,,4.411406,1.200419,3.279326,,1.229474,,,-2.004817,,-1.917212,0.923784,,-0.364484,0.043953,-1.081097,,,,,,-0.528249,1.561701,0.044169,-0.697554,2.512118,0.71579,,,-0.32492,-0.306634,-0.342315,-0.33082,1.993902
1,3,0,1.749479,,,,,,0.247528,1.072752,1.534643,-0.041765,11,7,76,-0.658146,1.82926,0.238465,,0.01081,,-1.370826,-1.561505,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,0.427911,1.390987,2.505615,0.059815,0.037799,,,,,,,,,2.351531,1.33464,2.359366,,0.967031,,,0.252059,,-1.972575,0.864815,,-0.252517,0.055876,-1.081097,,,,,,-0.720327,1.186939,0.029493,-0.728824,2.11122,0.556564,,,-0.242646,-0.302316,-0.253091,-0.313789,1.864082
1,4,0,1.749479,,,,,,-0.024495,-0.506043,0.517926,0.006077,11,7,76,-0.984723,0.849526,0.356314,,-0.123173,-0.03724,-1.412314,-1.341913,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,0.252309,1.170764,2.555653,0.0561,0.056733,,,,,,,,,2.886509,0.777438,3.265261,,-0.935642,,,0.216384,,-1.541225,0.946789,,-0.191459,0.046544,-1.081097,,,,,,-0.618408,0.875253,-0.00241,-0.783194,1.114563,0.517224,,,-0.294731,-0.211305,-0.36654,-0.224103,2.604931
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1698,963,0,3.743546,2.568135,-0.793953,3.069358,2.539425,0.889981,0.319947,0.116308,1.438311,0.195987,11,7,76,0.336648,0.064829,-0.053682,-0.856546,-0.85187,-0.67437,2.034368,-0.18008,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.446697,0.422791,3.657838,2.798806,0.985844,-0.039476,-0.081274,0.848423,-0.162918,-0.259538,0.853406,0.024779,-0.412884,-0.014181,-0.188477,0.337806,-0.03369,0.442706,1.06784,-0.694643,-1.275397,1.45678,0.511287,1.454648,-0.338394,0.026483,0.489739,0.033094,0.557365,-1.101531,-0.548784,-0.392891,-0.386612,2.12372,0.750058,0.314409,-0.15676,-0.240713,0.705058,0.321062,0.211622,-0.164847,-0.223636,-0.225986,-0.128775,-0.051006,-0.138405,0.006146
1698,964,0,3.743546,2.718452,-0.511277,2.608628,2.033411,0.387283,0.340792,0.109136,0.728542,0.281848,11,7,76,0.458838,0.021207,-0.082063,-1.183542,-0.851339,-0.782223,1.545191,0.094937,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.677647,0.332351,3.497165,2.811974,0.713854,-0.038192,-0.091909,0.423664,-0.117968,-0.036031,0.821478,-0.33453,-0.106793,-0.292852,-0.575695,0.459095,-0.468565,0.350401,0.699706,-0.379422,-0.994081,1.62243,0.539615,1.356119,0.465958,0.310901,0.446288,-0.083168,0.426014,-1.101531,-0.342624,-0.251435,-0.485331,1.970657,0.310639,0.32353,-0.177152,-0.241838,0.50289,0.279055,0.216196,-0.133364,-0.220455,-0.180748,-0.125245,-0.112044,-0.122066,-0.142388
1698,965,0,3.743546,2.227324,-1.210031,2.470631,2.569757,0.842342,0.56668,0.739505,1.696227,0.207651,11,7,76,0.525306,-0.153095,-0.097157,-0.945363,-0.758521,-0.814847,2.055125,0.509254,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.261821,0.437572,2.817772,3.30254,0.062859,-0.039016,-0.078513,0.201333,-0.10028,-0.459648,0.226261,-0.473897,-0.492614,0.338202,0.061678,0.372418,-0.043383,0.343548,0.56,-0.291057,-0.360662,1.861143,0.398721,1.400427,-0.011518,-0.288103,0.389545,0.153374,0.491404,-1.101531,-0.493324,-0.380566,-0.492167,1.660286,0.7646,0.301674,-0.202914,-0.210503,0.417254,0.038902,0.147303,-0.204672,-0.225959,-0.276943,-0.200029,-0.161556,-0.087569,-0.23691
1698,966,0,3.743546,2.294538,-1.234424,1.828253,2.229581,0.048254,0.617474,0.645901,1.696628,0.233323,11,7,76,0.35473,0.056575,-0.038474,-0.695993,-0.769787,-0.665899,2.26081,-0.044641,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.583516,0.274626,2.908904,3.163464,-1.332066,-0.061103,-0.104068,0.848918,-0.239616,-0.610005,0.417844,-0.019155,0.285854,0.217884,0.165823,0.31112,-0.033976,0.314276,1.274701,-0.380444,-0.577036,1.675805,0.526321,1.285819,0.067032,0.08497,0.493787,0.35573,0.852465,-1.101531,-0.557998,-0.299636,-0.391958,2.011111,0.332355,0.325992,-0.148839,-0.254832,0.411193,0.385642,0.275157,-0.190027,-0.151191,-0.260749,-0.252218,-0.164368,-0.121943,-0.067162


date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
1,0,0,1.749479,,,,,,0.053447,2.192887,1.160708,-0.00983,11,7,76,-0.819115,1.403962,-0.098782,,0.340309,,-1.577595,-1.237946,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,1.48572,0.760141,2.145851,-0.021399,-0.016591,,,,,,,,,3.88576,7.287179,4.877959,,2.20351,,,-0.850504,,-1.543995,1.741718,,2.035935,0.892778,-1.081097,,,,,,-0.734145,0.276101,-0.576441,-0.799783,4.812901,0.393107,,,-0.226359,-0.327119,-0.31504,-0.385573,2.337418
1,1,0,1.749479,,,,,,-0.029047,0.847413,0.769509,0.009426,11,7,76,-0.67729,2.703819,0.307246,,0.361542,,-1.508338,-1.414294,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,1.272551,0.986756,2.769935,0.028861,0.036394,,,,,,,,,3.062142,1.634271,3.670358,,1.121471,,,-1.338483,,-1.557803,1.283185,,-0.10197,0.154968,-1.081097,,,,,,-0.828776,3.034146,0.070293,-0.49952,3.384181,0.708746,,,-0.286174,-0.187674,-0.354506,-0.302946,2.492198
1,2,0,1.749479,,,,,,0.017732,0.727187,1.153705,-0.0056,11,7,76,-0.587609,3.081821,0.2564,,0.19876,,-1.091023,-1.343286,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,0.09876,1.217269,2.32622,0.054985,0.054874,,,,,,,,,4.411406,1.200419,3.279326,,1.229474,,,-2.004817,,-1.917212,0.923784,,-0.364484,0.043953,-1.081097,,,,,,-0.528249,1.561701,0.044169,-0.697554,2.512118,0.71579,,,-0.32492,-0.306634,-0.342315,-0.33082,1.993902
1,3,0,1.749479,,,,,,0.247528,1.072752,1.534643,-0.041765,11,7,76,-0.658146,1.82926,0.238465,,0.01081,,-1.370826,-1.561505,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,0.427911,1.390987,2.505615,0.059815,0.037799,,,,,,,,,2.351531,1.33464,2.359366,,0.967031,,,0.252059,,-1.972575,0.864815,,-0.252517,0.055876,-1.081097,,,,,,-0.720327,1.186939,0.029493,-0.728824,2.11122,0.556564,,,-0.242646,-0.302316,-0.253091,-0.313789,1.864082
1,4,0,1.749479,,,,,,-0.024495,-0.506043,0.517926,0.006077,11,7,76,-0.984723,0.849526,0.356314,,-0.123173,-0.03724,-1.412314,-1.341913,1.039095,,0.756488,0.432157,-1.626238,-0.137137,,,0.952274,-0.186008,0.127051,,,,0.252309,1.170764,2.555653,0.0561,0.056733,,,,,,,,,2.886509,0.777438,3.265261,,-0.935642,,,0.216384,,-1.541225,0.946789,,-0.191459,0.046544,-1.081097,,,,,,-0.618408,0.875253,-0.00241,-0.783194,1.114563,0.517224,,,-0.294731,-0.211305,-0.36654,-0.224103,2.604931
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1698,963,0,3.743546,2.568135,-0.793953,3.069358,2.539425,0.889981,0.319947,0.116308,1.438311,0.195987,11,7,76,0.336648,0.064829,-0.053682,-0.856546,-0.85187,-0.67437,2.034368,-0.18008,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.446697,0.422791,3.657838,2.798806,0.985844,-0.039476,-0.081274,0.848423,-0.162918,-0.259538,0.853406,0.024779,-0.412884,-0.014181,-0.188477,0.337806,-0.03369,0.442706,1.06784,-0.694643,-1.275397,1.45678,0.511287,1.454648,-0.338394,0.026483,0.489739,0.033094,0.557365,-1.101531,-0.548784,-0.392891,-0.386612,2.12372,0.750058,0.314409,-0.15676,-0.240713,0.705058,0.321062,0.211622,-0.164847,-0.223636,-0.225986,-0.128775,-0.051006,-0.138405,0.006146
1698,964,0,3.743546,2.718452,-0.511277,2.608628,2.033411,0.387283,0.340792,0.109136,0.728542,0.281848,11,7,76,0.458838,0.021207,-0.082063,-1.183542,-0.851339,-0.782223,1.545191,0.094937,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.677647,0.332351,3.497165,2.811974,0.713854,-0.038192,-0.091909,0.423664,-0.117968,-0.036031,0.821478,-0.33453,-0.106793,-0.292852,-0.575695,0.459095,-0.468565,0.350401,0.699706,-0.379422,-0.994081,1.62243,0.539615,1.356119,0.465958,0.310901,0.446288,-0.083168,0.426014,-1.101531,-0.342624,-0.251435,-0.485331,1.970657,0.310639,0.32353,-0.177152,-0.241838,0.50289,0.279055,0.216196,-0.133364,-0.220455,-0.180748,-0.125245,-0.112044,-0.122066,-0.142388
1698,965,0,3.743546,2.227324,-1.210031,2.470631,2.569757,0.842342,0.56668,0.739505,1.696227,0.207651,11,7,76,0.525306,-0.153095,-0.097157,-0.945363,-0.758521,-0.814847,2.055125,0.509254,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.261821,0.437572,2.817772,3.30254,0.062859,-0.039016,-0.078513,0.201333,-0.10028,-0.459648,0.226261,-0.473897,-0.492614,0.338202,0.061678,0.372418,-0.043383,0.343548,0.56,-0.291057,-0.360662,1.861143,0.398721,1.400427,-0.011518,-0.288103,0.389545,0.153374,0.491404,-1.101531,-0.493324,-0.380566,-0.492167,1.660286,0.7646,0.301674,-0.202914,-0.210503,0.417254,0.038902,0.147303,-0.204672,-0.225959,-0.276943,-0.200029,-0.161556,-0.087569,-0.23691
1698,966,0,3.743546,2.294538,-1.234424,1.828253,2.229581,0.048254,0.617474,0.645901,1.696628,0.233323,11,7,76,0.35473,0.056575,-0.038474,-0.695993,-0.769787,-0.665899,2.26081,-0.044641,1.515887,-0.202102,1.285443,0.787999,1.921844,1.28118,1.223592,1.002706,0.587736,-0.643336,-0.943377,-0.148149,3.583516,0.274626,2.908904,3.163464,-1.332066,-0.061103,-0.104068,0.848918,-0.239616,-0.610005,0.417844,-0.019155,0.285854,0.217884,0.165823,0.31112,-0.033976,0.314276,1.274701,-0.380444,-0.577036,1.675805,0.526321,1.285819,0.067032,0.08497,0.493787,0.35573,0.852465,-1.101531,-0.557998,-0.299636,-0.391958,2.011111,0.332355,0.325992,-0.148839,-0.254832,0.411193,0.385642,0.275157,-0.190027,-0.151191,-0.260749,-0.252218,-0.164368,-0.121943,-0.067162


  0%|                                                                                           | 0/39 [00:00<?, ?it/s]


In [None]:
train_df

In [None]:
train_df.null_count().sum_horizontal()

In [None]:
def get_medians(df):
    """Compute the median of every feature column separately for each symbol.

    Iterates over the notebook-level ``unique_symbol_ids_list`` and
    ``feature_cols``.

    Returns a nested dict: the outer key is the symbol id formatted as a
    string, the inner dict maps each feature column name to its median among
    that symbol's rows in ``df``.
    """
    per_symbol = {}
    for sid in tqdm(unique_symbol_ids_list):
        rows = df.filter(pl.col('symbol_id') == sid)
        feature_medians = {}
        for name in feature_cols:
            feature_medians[name] = rows[name].median()
        per_symbol[f'{sid}'] = feature_medians

    return per_symbol

In [None]:
train_medians = get_medians(train_df)

In [None]:
train_medians

In [None]:
train_df.estimated_size() / 1e9

In [None]:
# Output folder for this run's artefacts. exist_ok=True creates it when missing
# and is a no-op otherwise, without a separate check-then-create step.
models_path = path + 'my_folder/models/20250110_02/'
os.makedirs(models_path, exist_ok=True)

In [None]:
train_df.write_parquet(models_path + 'imputed_train_ffill.parquet')

In [None]:
previous_models_path = path + 'my_folder/models/20250109_03/'

In [None]:
lgb_params_df = pd.read_csv(previous_models_path + 'lgb_params.csv')
lgb_params_df

In [None]:
# Identifier, weight and target columns: everything else is a model input.
_NON_FEATURE_COLS = ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']
# Columns that identify one row; used to anti-join windows against each other.
_ROW_KEY = ['date_id', 'time_id', 'symbol_id']


def _features_target_weight(window_df):
    """Split a window into (features as pandas, responder_6 target, sample weights)."""
    features = window_df.drop(_NON_FEATURE_COLS).select(pl.all().shrink_dtype()).to_pandas()
    target = window_df['responder_6'].to_pandas()
    weights = window_df['weight'].to_pandas()
    return features, target, weights


def lgb_sliding_window(train_data, seed=None):
    """Train and score one LightGBM model per date over a sliding window.

    A random fraction of ``train_data`` is sampled once. Then, for every
    date_id in the sample, the rows of that date are the test set, the
    ``val_window_size`` dates before it are the validation window, and the
    ``training_window_size`` dates before that are the training window. A
    fresh ``LGBMRegressor`` is fitted with sample weights and the weighted R2
    on the test date is printed.

    For early dates, where a window would start before date 0, the window is
    instead taken from the first dates of the sample (``date_id <=
    val_window_size`` for validation, ``date_id <= val_window_size +
    training_window_size`` for training), excluding the test rows (and, for
    training, the validation rows). Such windows can contain dates later than
    the test date.

    All hyper-parameters, the window sizes and the sampling fraction are read
    from row 0 of the notebook-level ``lgb_params_df``. The model is
    configured with ``device='gpu'``.

    Parameters
    ----------
    train_data : pl.DataFrame
        Frame with date_id, time_id, symbol_id, weight, responder_6 and the
        feature columns.
    seed : int, optional
        Seed for the row sampling. The default ``None`` keeps the original
        unseeded behaviour; pass an int to make the sample reproducible.

    Returns
    -------
    None
        Scores are printed, not returned.
    """
    # Everything read from lgb_params_df is constant across dates: look it up once.
    val_window_size = lgb_params_df['val_window_size'][0]
    training_window_size = lgb_params_df['training_window_size'][0]

    base_params = {
        'verbosity': -1,
        'device': 'gpu',
        'early_stopping_round': 20,
    }

    params_to_tune = {
        'learning_rate': lgb_params_df['learning_rate'][0],
        'max_depth': lgb_params_df['max_depth'][0],
        'min_data_in_leaf': lgb_params_df['min_data_in_leaf'][0],
        'num_leaves': lgb_params_df['num_leaves'][0],
        'min_gain_to_split': lgb_params_df['min_gain_to_split'][0],
        'lambda_l1': lgb_params_df['lambda_l1'][0],
        'lambda_l2': lgb_params_df['lambda_l2'][0],
        'feature_fraction': lgb_params_df['feature_fraction'][0],
    }

    sample_df = train_data.sample(fraction=lgb_params_df['fraction'][0], seed=seed)
    unique_date_ids = sorted(sample_df['date_id'].unique())
    print(len(unique_date_ids))

    for date_id in unique_date_ids:
        test_date_id_df = sample_df.filter(pl.col('date_id') == date_id)

        # Validation window: the val_window_size dates preceding the test date.
        val_date_id_cut_lower = date_id - val_window_size
        if val_date_id_cut_lower < 0:
            # Not enough history: use the first dates of the sample, minus the test rows.
            val_window_df = sample_df.filter(pl.col('date_id') <= val_window_size)
            val_window_df = val_window_df.join(test_date_id_df, on=_ROW_KEY, how='anti')
        else:
            val_window_df = sample_df.filter(
                (pl.col('date_id') >= val_date_id_cut_lower) & (pl.col('date_id') < date_id)
            )

        # Training window: the training_window_size dates preceding the validation window.
        training_date_id_cut_lower = val_date_id_cut_lower - training_window_size
        if training_date_id_cut_lower < 0:
            # Not enough history: use the first dates, minus test and validation rows.
            training_window_df = sample_df.filter(
                pl.col('date_id') <= val_window_size + training_window_size
            )
            training_window_df = training_window_df.join(test_date_id_df, on=_ROW_KEY, how='anti')
            training_window_df = training_window_df.join(val_window_df, on=_ROW_KEY, how='anti')
        else:
            training_window_df = sample_df.filter(
                (pl.col('date_id') >= training_date_id_cut_lower)
                & (pl.col('date_id') < val_date_id_cut_lower)
            )

        # Gaps in the sampled dates can leave a window empty; fitting would then
        # fail and abort the whole loop, so skip just this date.
        if training_window_df.height == 0 or val_window_df.height == 0:
            print('date_id is:', date_id)
            print('Skipped: empty training or validation window')
            continue

        # A fresh model per date, so nothing carries over between windows.
        model = LGBMRegressor(
            **base_params,
            **params_to_tune,
            n_estimators=100000
        )

        X_train, y_train, weights_train = _features_target_weight(training_window_df)
        X_val, y_val, weights_val = _features_target_weight(val_window_df)
        X_test, y_test, weights_test = _features_target_weight(test_date_id_df)

        # Both windows are passed as eval sets; early_stopping_round in
        # base_params bounds the number of boosting rounds.
        model.fit(
            X_train,
            y_train,
            sample_weight=weights_train,
            eval_set=[(X_train, y_train), (X_val, y_val)],
            eval_sample_weight=[weights_train, weights_val],
        )

        test_preds = model.predict(X_test)

        test_score = r2_score(y_test, test_preds, sample_weight=weights_test)

        print('date_id is:', date_id)
        print('Test Weighted R2 score is:', test_score)

In [None]:
lgb_sliding_window(train_df)#, 300)

In [None]:
# NOTE(review): `lgb_study` is not created in any cell visible in this notebook — presumably an
# optuna study left over from a tuning run; confirm it exists before running this and the cells below.
# Draw one slice plot per parameter in the study's best-parameter set.
for param in lgb_study.best_params.keys():
    fig = plot_slice(lgb_study, params=[param])
    fig.show()

In [None]:
plot_param_importances(lgb_study)

In [None]:
lgb_study.best_params

In [None]:
lgb_study.best_value

In [None]:
for k, v in lgb_study.best_params.items():
    print(k, v)

In [None]:
# Rebuilds `lgb_params_df` as a one-row frame with one column per best parameter of `lgb_study`.
# This overwrites the frame read earlier from the previous run's lgb_params.csv.
# NOTE(review): depends on `lgb_study`, which no visible cell defines — confirm before running.
lgb_params_df = pd.DataFrame({k:[v] for k, v in lgb_study.best_params.items()})

In [None]:
lgb_params_df

In [None]:
lgb_params_df.to_csv(models_path + 'lgb_params.csv', index=False)