In [1]:
import pandas as pd
import polars as pl
import numpy as np
import os
import gc
import seaborn as sns
from tqdm import tqdm
from sklearn.model_selection import KFold, StratifiedKFold
import xgboost as xgb
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor, log_evaluation, record_evaluation
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
#from sklearn.impute import IterativeImputer
import pickle
import optuna
import shap

# Make sure the cyclic garbage collector is on; later cells call gc.collect()
# after dropping large frames.
gc.enable()

# pandas display limits: show up to 500 rows/columns and never truncate cell text.
# (A previous `pd.options.display.max_columns = None` was immediately overridden
# by the set_option call below, so it has been removed.)
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_colwidth', None)

# polars display limits: negative values lift the row/column cap for printed
# tables; string cells are shown up to 10000 characters.
pl.Config.set_tbl_rows(-1)
pl.Config.set_tbl_cols(-1)
pl.Config.set_fmt_str_lengths(10000)

polars.config.Config

In [2]:
# Root directory of the competition data. Set the JANE_STREET_DATA_DIR environment
# variable to point somewhere else; the default is the original local location.
# os.path.join(..., '') guarantees a trailing separator because later cells build
# file names with plain string concatenation (path + 'train.parquet/').
path = os.path.join(
    os.environ.get('JANE_STREET_DATA_DIR', 'I:/Kaggle/jane-street-real-time-market-data-forecasting/'),
    '',
)

In [3]:
# List the entries of the competition data directory.
os.listdir(path)

['features.csv',
 'kaggle_evaluation',
 'lags.parquet',
 'my_folder',
 'responders.csv',
 'sample_submission.csv',
 'team_folder',
 'test.parquet',
 'train.parquet']

In [4]:
# Read the full training set and shrink every numeric column to the smallest
# dtype that still holds its values.
train_df = pl.read_parquet(path + 'train.parquet/').select(pl.all().shrink_dtype())

# Lagged responders: shifting date_id forward by one means that joining on
# (date_id, time_id, symbol_id) attaches the responders of date d-1 to date d.
lags_df = (
    train_df
    .with_columns(pl.col('date_id') + 1)
    .drop(['weight', 'partition_id'] + [name for name in train_df.columns if 'feature' in name])
    .rename({f'responder_{idx}': f'responder_{idx}_lag_1' for idx in range(9)})
)

# Of the same-row responders only responder_6 is kept; partition_id is dropped too.
train_df = (
    train_df
    .drop([f'responder_{idx}' for idx in range(9) if idx != 6] + ['partition_id'])
    .select(pl.all().shrink_dtype())
)

# Left join keeps every training row; rows without a previous-day match get nulls.
train_df = (
    train_df
    .join(lags_df, on=['date_id', 'time_id', 'symbol_id'], how='left')
    .select(pl.all().shrink_dtype())
)

# The lag frame is no longer needed once joined.
del lags_df
gc.collect()
print(train_df.shape)
train_df.head()

(47127338, 93)


date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,1,3.889038,,,,,,0.851033,0.242971,0.2634,-0.891687,11,7,76,-0.883028,0.003067,-0.744703,,-0.169586,,-1.335938,-1.707803,0.91013,,1.636431,1.522133,-1.551398,-0.229627,,,1.378301,-0.283712,0.123196,,,,0.28118,0.269163,0.349028,-0.012596,-0.225932,,-1.073602,,,-0.181716,,,,0.564021,2.088506,0.832022,,0.204797,,,-0.808103,,-2.037683,0.727661,,-0.989118,-0.345213,-1.36224,,,,,,-1.251104,-0.110252,-0.491157,-1.02269,0.152241,-0.659864,,,-0.261412,-0.211486,-0.335556,-0.281498,0.775981,,,,,,,,,
0,0,7,1.370613,,,,,,0.676961,0.151984,0.192465,-0.521729,11,7,76,-0.865307,-0.225629,-0.582163,,0.317467,,-1.250016,-1.682929,1.412757,,0.520378,0.744132,-0.788658,0.641776,,,0.2272,0.580907,1.128879,,,,-1.512286,-1.414357,-1.823322,-0.082763,-0.184119,,,,,,,,,-10.835207,-0.002704,-0.621836,,1.172836,,,-1.625862,,-1.410017,1.063013,,0.888355,0.467994,-1.36224,,,,,,-1.065759,0.013322,-0.592855,-1.052685,-0.393726,-0.741603,,,-0.281207,-0.182894,-0.245565,-0.302441,0.703665,,,,,,,,,
0,0,9,2.285698,,,,,,1.056285,0.187227,0.249901,-0.77305,11,7,76,-0.675719,-0.199404,-0.586798,,-0.814909,,-1.296782,-2.040234,0.639589,,1.597359,0.657514,-1.350148,0.364215,,,-0.017751,-0.317361,-0.122379,,,,-0.320921,-0.95809,-2.436589,0.070999,-0.245239,,,,,,,,,-1.420632,-3.515137,-4.67776,,0.535897,,,-0.72542,,-2.29417,1.764551,,-0.120789,-0.063458,-1.36224,,,,,,-0.882604,-0.072482,-0.617934,-0.86323,-0.241892,-0.709919,,,0.377131,0.300724,-0.106842,-0.096792,2.109352,,,,,,,,,
0,0,10,0.690606,,,,,,1.139366,0.273328,0.306549,-1.262223,42,5,150,-0.694008,3.004091,0.114809,,-0.251882,,-1.902009,-0.979447,0.241165,,-0.392359,-0.224699,-2.129397,-0.855287,,,0.404142,-0.578156,0.105702,,,,0.544138,-0.087091,-1.500147,-0.201288,-0.038042,,,,,,,,,0.382074,2.669135,0.611711,,2.413415,,,1.313203,,-0.810125,2.939022,,3.988801,1.834661,-1.36224,,,,,,-0.697595,1.074309,-0.206929,-0.530602,4.765215,0.571554,,,-0.226891,-0.251412,-0.215522,-0.296244,1.114137,,,,,,,,,
0,0,14,0.44057,,,,,,0.9552,0.262404,0.344457,-0.613813,44,3,16,-0.947351,-0.030018,-0.502379,,0.646086,,-1.844685,-1.58656,-0.182024,,-0.969949,-0.673813,-1.282132,-1.399894,,,0.043815,-0.320225,-0.031713,,,,-0.08842,-0.995003,-2.635336,-0.196461,-0.618719,,,,,,,,,-2.0146,-2.321076,-3.711265,,1.253902,,,0.476195,,-0.771732,2.843421,,1.379815,0.411827,-1.36224,,,,,,-0.948601,-0.136814,-0.447704,-1.141761,0.099631,-0.661928,,,3.678076,2.793581,2.61825,3.418133,-3.57282,,,,,,,,,


In [5]:
# Lazy scans of the raw parquet datasets; data is only read when a query is collected.
train_scan, test_scan = (
    pl.scan_parquet(path + dataset_name)
    for dataset_name in ('train.parquet/', 'test.parquet/')
)

In [6]:
# Distinct symbol ids present in each dataset, ascending.
train_symbol_ids_list = sorted(
    train_scan.select('symbol_id').unique().collect().get_column('symbol_id').to_list()
)
test_symbol_ids_list = sorted(
    test_scan.select('symbol_id').unique().collect().get_column('symbol_id').to_list()
)

# Union of both datasets, ascending.
unique_symbol_ids_list = sorted(set(train_symbol_ids_list) | set(test_symbol_ids_list))
unique_symbol_ids_list

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38]

In [7]:
def one_hot_cat_cols(df, symbol_ids=None):
    """Append one Int8 indicator column per symbol id.

    For every id ``v`` a column ``symbol_id_<v>`` is added that is 1 where
    ``symbol_id == v`` and 0 otherwise. The original ``symbol_id`` column is
    kept. All indicators are built in a single ``with_columns`` call instead
    of rebuilding the frame once per id.

    Parameters
    ----------
    df : pl.DataFrame
        Frame containing a ``symbol_id`` column.
    symbol_ids : iterable of int, optional
        Ids to encode. Defaults to the notebook-level ``unique_symbol_ids_list``.

    Returns
    -------
    pl.DataFrame
        ``df`` plus the indicator columns, with every column passed through
        ``shrink_dtype``.
    """
    if symbol_ids is None:
        symbol_ids = unique_symbol_ids_list

    # dict.fromkeys drops repeated ids while keeping first-seen order; a
    # duplicate output name inside one with_columns call would raise.
    indicator_exprs = [
        (pl.col('symbol_id') == v).cast(pl.Int8).alias('symbol_id_' + str(v))
        for v in dict.fromkeys(symbol_ids)
    ]

    return df.with_columns(indicator_exprs).select(pl.all().shrink_dtype())

In [8]:
# Estimated in-memory size of train_df, converted from bytes to units of 1e9 bytes.
train_df.estimated_size() / 1e9

17.032444032

In [9]:
# Directory (under the data root) where this run's fitted models are pickled.
models_path = f'{path}my_folder/models/20250104_01/'

In [None]:
def lgb_online_learning(train_data):
    """Fit the base LightGBM regressor on the first half of the date range.

    Rows with ``date_id <= int(max(date_id) / 2)`` form the training split and
    all later rows form the validation split. The validation split drives
    early stopping and the weighted R2 printed at the end.

    Changes relative to the previous version of this function:

    * Features, target and weights for each split are all taken from the same
      filtered frame. Previously the features were split by a ``date_id``
      filter while target/weights were split by row position, which is only
      correct when ``train_data`` is sorted by ``date_id``.
    * The learning curve is drawn on an explicitly created axes, so no empty
      extra figure is produced.
    * Unreachable code after the ``return`` (an unfinished per-time-step loop
      and a string-literal copy of an older implementation) was removed.

    Parameters
    ----------
    train_data : pl.DataFrame
        Must contain ``date_id``, ``time_id``, ``symbol_id``, ``weight`` and
        ``responder_6``. Every other column is used as a model feature.

    Returns
    -------
    LGBMRegressor
        The fitted model.
    """
    non_feature_cols = ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']

    def _to_model_inputs(split_df):
        # Derive X, y and sample weights from one frame so they are row-aligned
        # regardless of the row order of train_data.
        features = split_df.drop(non_feature_cols).select(pl.all().shrink_dtype()).to_pandas()
        target = split_df['responder_6'].to_pandas()
        sample_weights = split_df['weight'].to_pandas()
        return features, target, sample_weights

    max_date_id = train_data['date_id'].max()
    train_date_id_cut = int(max_date_id / 2)

    print('max date:', max_date_id)
    print('date id cut:', train_date_id_cut)

    # The filtered polars frames are passed straight into the helper so they can
    # be released as soon as the pandas copies exist.
    X_train, y_train, weights_train = _to_model_inputs(
        train_data.filter(pl.col('date_id') <= train_date_id_cut)
    )
    X_val, y_val, weights_val = _to_model_inputs(
        train_data.filter(pl.col('date_id') > train_date_id_cut)
    )

    # Fraction of rows that ended up in the training split.
    print(X_train.shape[0] / train_data.shape[0])

    print(X_train.shape)
    display(X_train.head())
    display(X_train.tail())

    base_params = {
        'verbosity': -1,
        'learning_rate': 0.05,
        'feature_fraction': 0.8,
        'device': 'gpu',
        'early_stopping_round': 30,
        'lambda_l2': 100,
    }

    # n_estimators is only an upper bound; early stopping decides the real count.
    model = LGBMRegressor(
        **base_params,
        n_estimators=90000
    )

    model.fit(
        X_train,
        y_train,
        sample_weight=weights_train,
        eval_set=[(X_train, y_train), (X_val, y_val)],
        eval_sample_weight=[weights_train, weights_val],
        callbacks=[log_evaluation(period=10)],
    )

    best_iteration = model.best_iteration_
    print(f"Best iteration: {best_iteration}")

    # plot_metric creates its own figure unless given an axes; pass one
    # explicitly so exactly one figure is shown.
    fig, ax = plt.subplots()
    lgb.plot_metric(model, ax=ax)
    ax.set_ylim(0, 2)
    plt.show()

    val_preds = model.predict(X_val)

    print('Val Weighted R2 score is:', r2_score(y_val, val_preds, sample_weight=weights_val))

    return model

In [None]:
# Train the base model on the full joined training frame.
lgb_model = lgb_online_learning(train_df)

In [None]:
# Show the fitted estimator.
lgb_model

In [None]:
# Create the model output directory (including parents). exist_ok=True makes the
# call safe to re-run and avoids the gap between checking and creating.
os.makedirs(models_path, exist_ok=True)

In [None]:
# save model
# Serialise the fitted estimator with pickle into the run's model directory.
model_file_path = models_path + "lgb_model.pkl"
with open(model_file_path, 'wb') as model_file:
    pickle.dump(lgb_model, model_file)

In [10]:
# load model
# models_path already ends with '/', so the previous f"{models_path}/lgb_model.pkl"
# produced a double slash and did not match the path used when saving.
# os.path.join yields the same file name as the save cell.
# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
with open(os.path.join(models_path, "lgb_model.pkl"), "rb") as f:
    lgb_model = pickle.load(f)

In [11]:
# Validation rows are the second half of the date range. The cut is derived the
# same way lgb_online_learning derives its hold-out split (int(max date_id / 2))
# instead of repeating the resulting number as a literal.
val_date_id_cut = int(train_df['date_id'].max() / 2)
val_df = train_df.filter(pl.col('date_id') > val_date_id_cut)
print(val_df.shape)
val_df.head()

(30302272, 93)


date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
850,0,0,2.087724,-0.276877,-2.385324,-1.086325,0.049463,3.427029,-4.671824,0.054977,-0.259751,1.343003,11,7,76,-0.793587,2.523406,0.303231,,0.523913,,-1.567069,-0.965586,0.014156,-0.171976,1.015679,0.746074,-1.633316,-1.309486,0.965614,1.612443,0.82306,-0.027811,0.72484,-0.184198,,,-0.187675,-0.346574,-1.421471,0.093479,1.387856,,1.254196,,,0.050192,,-0.954613,2.004981,-1.557791,0.678891,-0.066386,,2.456588,,,-1.159385,,-0.889724,1.428067,,0.817551,0.299599,0.352903,-0.328996,-0.151735,-0.224472,-1.477134,-1.643559,-0.556531,2.815019,0.356358,-0.527251,1.609195,0.076337,,,-0.228297,-0.273781,-0.277999,-0.295312,1.461546,0.402863,0.074029,0.36344,-0.558883,-0.419728,-0.238446,-1.213885,-0.616817,-1.411242
850,0,1,3.752097,-0.168178,-2.161023,-0.511679,0.192425,3.162096,-4.386098,0.130385,-0.368283,1.913416,11,7,76,-0.660111,3.052153,0.071869,,0.001913,,-0.625688,-1.11523,0.185483,0.019226,1.916643,0.710887,-1.102333,-0.981141,0.521467,1.665925,1.461316,-0.358575,0.058004,0.021168,,,-0.641563,-0.482115,-2.396556,-0.121039,1.409137,,0.932519,,,1.311157,,-0.749923,1.793136,-2.108881,1.227915,-0.146708,,0.888707,,,-1.427895,,-1.575317,0.556004,,0.321817,0.406464,0.352903,-0.388503,-0.100457,-0.201082,-1.926849,-1.763679,-0.612577,1.61283,-0.051637,-0.97052,2.79455,0.353143,,,-0.157027,-0.163802,-0.277016,-0.444008,0.789595,-0.240175,-0.445871,0.125748,0.264227,0.088307,-0.649466,0.597718,0.500387,-1.605249
850,0,2,1.225099,-0.520426,-1.718115,-0.817358,-0.270528,3.314825,-2.578923,0.1102,-0.20174,2.072351,81,2,59,-0.528026,3.354508,0.327966,,-0.215615,,-1.260532,-2.04301,-1.31462,-0.239955,0.017958,-0.27587,-0.705935,-0.782762,0.268385,1.391267,1.265022,-0.539895,-0.351402,-0.209022,,,-0.164031,-0.517534,0.71262,0.418721,1.150448,,-0.361983,,,-1.394171,,-1.067848,0.734942,-2.05364,-1.888152,-0.688585,,-0.588629,,,-2.212862,,-2.015984,0.025982,,-4.632971,-2.559358,0.352903,-0.316812,-0.264718,-0.248274,-1.383873,-2.433391,-0.728091,4.478824,0.497227,-0.449675,1.648489,-0.001233,,,-0.012737,-0.081892,-0.209053,-0.267447,-2.848316,-0.198698,-0.217445,0.086082,-0.509062,-0.734032,0.970075,-0.747389,-0.662997,1.603958
850,0,3,1.467042,-0.061985,-1.818735,-0.990254,0.274284,3.810929,-1.11177,0.043842,-0.090386,0.777759,4,3,11,-1.218813,1.769522,-0.076559,,-0.461771,,-1.905882,-2.141612,-0.347407,0.201398,0.629975,-0.282367,3.053007,2.292672,-0.312745,-0.370565,-0.473996,-0.535461,-0.979823,0.209806,,,-1.742934,-1.835673,1.511821,0.155715,0.446925,,0.255243,,,-0.066129,,-2.102276,1.224973,-0.932835,-0.402716,-0.369551,,0.237087,,,-1.359047,,-1.917404,1.456003,,-1.652316,-0.545286,0.352903,-0.646065,-0.38281,-0.345414,-1.605209,-1.561482,-0.529622,1.397832,-0.131058,-0.717516,2.507538,0.020102,,,0.377517,0.284319,-0.06742,-0.157564,-0.749164,0.170911,-0.580147,0.440651,-0.154337,-0.523854,-0.754856,-0.248197,-0.220052,-1.16779
850,0,5,3.144071,-0.321442,-1.964041,-0.409452,-0.343893,3.069664,-2.929145,0.084903,-0.214164,1.247011,2,10,171,-0.674077,2.17874,-0.058749,,-0.656464,,-1.158764,-1.013156,0.694589,0.229134,1.647839,0.301107,0.018292,0.099487,-0.772632,0.932901,1.751187,-0.621377,-0.6756,0.216979,,,-0.43032,0.001558,1.253969,0.103685,0.657733,,0.493284,,,-1.426275,,-1.534382,1.102636,-0.859073,0.555257,0.348532,,0.229975,,,-0.513925,,-2.259789,1.537998,,-1.00717,-0.823392,0.352903,-0.480987,-0.299463,-0.499426,-2.79966,-2.620816,-0.586428,2.413169,0.166888,-0.904685,1.361527,-0.238988,,,0.324627,0.262034,-0.147552,-0.138634,-0.8826,0.172549,0.048066,-0.045807,1.434816,0.04118,0.936963,1.769305,0.025372,1.907083


In [None]:
def val_online_learning(val_data, current_model, optuna_n_trials):
    """Tune continued-training hyper-parameters on the first time step of ``val_data``.

    The rows of the earliest ``(date_id, time_id)`` pair are the online training
    batch; every other row is the evaluation set. An Optuna study searches
    LightGBM parameters for continuing training from ``current_model`` (passed
    as ``init_model``) on that batch, maximising the weighted R2 on the
    evaluation rows.

    Only the first time step is processed. The previous version looped over all
    dates and time ids but returned unconditionally at the end of the first
    iteration, so the observable behaviour is the same; the unreachable code
    that followed (which referenced an undefined ``online_model``) and the
    unused loop counter were removed.

    Other fixes:

    * The evaluation rows are selected with a filter instead of a positional
      slice, so the split no longer depends on ``val_data`` being sorted. The
      filter materialises a temporary copy of those rows until the pandas
      conversion is done.
    * Each trial draws its learning curve on an explicit axes and closes the
      figure afterwards, so no empty figures are created and figures do not
      accumulate across trials.

    Parameters
    ----------
    val_data : pl.DataFrame
        Must contain ``date_id``, ``time_id``, ``symbol_id``, ``weight`` and
        ``responder_6``. Every other column is used as a model feature.
    current_model
        Already fitted LightGBM model used as ``init_model``.
    optuna_n_trials : int
        Number of Optuna trials to run.

    Returns
    -------
    optuna.study.Study or None
        The finished study, or ``None`` when ``val_data`` has no rows.
    """
    print(val_data.shape)
    display(val_data.head())
    display(val_data.tail())

    if val_data.is_empty():
        return None

    non_feature_cols = ['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']

    def _to_model_inputs(split_df):
        # X, y and sample weights come from the same frame so they stay row-aligned.
        features = split_df.drop(non_feature_cols).select(pl.all().shrink_dtype()).to_pandas()
        target = split_df['responder_6'].to_pandas()
        sample_weights = split_df['weight'].to_pandas()
        return features, target, sample_weights

    # Earliest date, then earliest time id within that date.
    first_date_id = val_data['date_id'].min()
    first_time_id = val_data.filter(pl.col('date_id') == first_date_id)['time_id'].min()
    is_first_step = (pl.col('date_id') == first_date_id) & (pl.col('time_id') == first_time_id)

    X_train, y_train, weights_train = _to_model_inputs(val_data.filter(is_first_step))
    X_val, y_val, weights_val = _to_model_inputs(val_data.filter(~is_first_step))

    base_params = {
        'verbosity': -1,
        'device': 'gpu',
        'early_stopping_round': 10,
        'seed': 42
    }

    def objective(trial):
        # Search space for the continued-training run.
        params_to_tune = {
            'learning_rate': trial.suggest_float('learning_rate', 0.000001, 0.005),
            'min_data_in_leaf': trial.suggest_int('min_data_in_leaf', 10, 300),
            'num_leaves': trial.suggest_int('num_leaves', 20, 10000),
            'max_depth': trial.suggest_int('max_depth', 3, 50),
            'min_gain_to_split': trial.suggest_float('min_gain_to_split', 0, 0.3),
            'lambda_l1': trial.suggest_float('lambda_l1', 0, 10),
            'lambda_l2': trial.suggest_float('lambda_l2', 0, 2000)
        }

        # n_estimators is only an upper bound; early stopping ends training.
        online_model = LGBMRegressor(
            **base_params,
            **params_to_tune,
            n_estimators=100000
        )

        online_model.fit(
            X_train,
            y_train,
            sample_weight=weights_train,
            eval_set=[(X_train, y_train), (X_val, y_val)],
            eval_sample_weight=[weights_train, weights_val],
            init_model=current_model,
        )

        # One figure per trial, closed after display so figures do not pile up.
        fig, ax = plt.subplots()
        lgb.plot_metric(online_model, ax=ax)
        ax.set_ylim(0, 2)
        plt.show()
        plt.close(fig)

        best_iteration = online_model.best_iteration_
        print(f"Best iteration: {best_iteration}")

        val_preds = online_model.predict(X_val)

        return r2_score(y_val, val_preds, sample_weight=weights_val)

    with tqdm(total=optuna_n_trials, desc="Optimizing", unit="trial") as pbar:

        # Advance the progress bar once per finished trial.
        def progress_bar_callback(study, trial):
            pbar.update(1)

        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=optuna_n_trials, callbacks=[progress_bar_callback])

    return study

In [None]:
# Run the Optuna search (100 trials) starting from the loaded base model.
lgb_study = val_online_learning(val_df, lgb_model, 100)

In [None]:
# Show the value returned by val_online_learning.
lgb_study

In [None]:
# Feature names must line up with the columns the model was trained on. The
# training code in this notebook drops the three id columns AND 'weight' and
# 'responder_6' before fitting, so the same five columns are dropped here;
# keeping 'weight' first would shift every name by one against the importances.
cols = train_df.drop(['date_id', 'time_id', 'symbol_id', 'weight', 'responder_6']).columns

# NOTE(review): first_shap_importance is not defined in any visible cell of this
# notebook — confirm where it comes from before relying on this table.
# zip() would silently truncate on a length mismatch, so fail loudly instead.
if len(cols) != len(first_shap_importance):
    raise ValueError(
        f'{len(cols)} feature names but {len(first_shap_importance)} importance values'
    )

imp_df = pd.DataFrame(
    {'Feature': cols, 'Importance': first_shap_importance}
).sort_values('Importance', ascending=False)

In [None]:
# Number of (feature, importance) rows and columns.
print(imp_df.shape)

In [None]:
# Show the importance table, highest importance first.
imp_df

In [None]:
# Horizontal bar chart of the importance table. The y axis is inverted so the
# first row of imp_df is drawn at the top.
fig, ax = plt.subplots(figsize=(10, 40))
ax.set_title("Feature importances")
ax.barh(imp_df['Feature'], imp_df['Importance'])
ax.set_xlabel("Importance")
ax.set_ylabel("Feature")
ax.invert_yaxis()
plt.show()

In [None]:
# Features whose importance is at or below the 30th percentile of all importances.
importance_cutoff = imp_df['Importance'].quantile(0.3)
unimportant_df = imp_df.loc[imp_df['Importance'].le(importance_cutoff)]
unimportant_cols = unimportant_df['Feature'].to_list()

In [None]:
# Drop the columns listed in unimportant_cols and preview the reduced frame.
train_selected_df = train_df.drop(unimportant_cols)
print(train_selected_df.shape)
train_selected_df.head()

In [None]:
# NOTE(review): neither `lgb_train` nor `y_sr` is defined in any visible cell of
# this notebook, so this cell fails on a fresh kernel — confirm where they are
# meant to come from (the cell may predate `lgb_online_learning`).
second_shap_importance = lgb_train(train_selected_df, y_sr)