In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import polars as pl
import seaborn as sns
from tqdm import tqdm


from river.tree import HoeffdingTreeRegressor
from river.metrics import MAE, r2

# Display settings: lift the row/column limits for both polars and pandas
# output, and render pandas floats with two decimal places.
pl.Config.set_tbl_rows(-1)
pl.Config.set_tbl_cols(-1)

pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.float_format", "{:.2f}".format)

In [2]:
# Column-name configuration shared by the cells below.

# Key columns used to identify a row.
multi_index = ['date_id', 'time_id', 'symbol_id']
# The 'weight' column plus feature_00 .. feature_78 (zero-padded to two digits).
feature_col = ['weight'] + [f'feature_{i:02d}' for i in range(79)]

# Groups of columns to exclude. NOTE(review): the list names suggest the reason
# for each exclusion (mostly missing / time-related / date-symbol-related), but
# the selection criteria are not visible in this notebook section -- confirm.
most_na_drop = ['feature_00', 'feature_01', 'feature_02', 'feature_03', 'feature_04',
                'feature_21', 'feature_26', 'feature_27', 'feature_31', ]

# Names must match the dataset columns exactly ('feature_73', not 'feture_73'),
# otherwise the column is silently kept or a strict drop raises.
relate_time_drop = ['feature_15', 'feature_17', 'feature_50', 'feature_52', 'feature_73', 'feature_74']

relate_date_symbol = ['weight', 'feature_09', 'feature_10', 'feature_11', 'feature_20', 'feature_22', 
                      'feature_23', 'feature_24', 'feature_25', 'feature_28','feature_29', 'feature_30']
drop_feat = most_na_drop + relate_time_drop + relate_date_symbol

# Guard against misspelled names: every column slated for dropping must be one
# of the known input columns.
assert set(drop_feat) <= set(feature_col), (
    f'unknown columns in drop_feat: {sorted(set(drop_feat) - set(feature_col))}'
)

# Prediction target and the full set of columns loaded from the dataset.
target = 'responder_6'
interest_col = multi_index + feature_col + [target]
# Key columns plus responder_0 .. responder_8.
resp = multi_index + [f'responder_{i:d}' for i in range(9)]

In [4]:
# Build a lazy query over the training parquet file; nothing is materialised
# until .collect() is called below.
all_df = pl.scan_parquet('data/train.parquet')
symbol = 4

# Keep the key, feature and target columns, then discard the most_na_drop group.
lazy_subset = all_df.select(interest_col)
lazy_subset = lazy_subset.drop(most_na_drop)
df = lazy_subset.collect()

df.head()

date_id,time_id,symbol_id,weight,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_22,feature_23,feature_24,feature_25,feature_28,feature_29,feature_30,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
i16,i16,i8,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,1,3.889038,0.851033,0.242971,0.2634,-0.891687,11,7,76,-0.883028,0.003067,-0.744703,,-0.169586,,-1.335938,-1.707803,0.91013,1.636431,1.522133,-1.551398,-0.229627,1.378301,-0.283712,0.123196,,,0.28118,0.269163,0.349028,-0.012596,-0.225932,,-1.073602,,,-0.181716,,,,0.564021,2.088506,0.832022,,0.204797,,,-0.808103,,-2.037683,0.727661,,-0.989118,-0.345213,-1.36224,,,,,,-1.251104,-0.110252,-0.491157,-1.02269,0.152241,-0.659864,,,-0.261412,-0.211486,-0.335556,-0.281498,0.775981
0,0,7,1.370613,0.676961,0.151984,0.192465,-0.521729,11,7,76,-0.865307,-0.225629,-0.582163,,0.317467,,-1.250016,-1.682929,1.412757,0.520378,0.744132,-0.788658,0.641776,0.2272,0.580907,1.128879,,,-1.512286,-1.414357,-1.823322,-0.082763,-0.184119,,,,,,,,,-10.835207,-0.002704,-0.621836,,1.172836,,,-1.625862,,-1.410017,1.063013,,0.888355,0.467994,-1.36224,,,,,,-1.065759,0.013322,-0.592855,-1.052685,-0.393726,-0.741603,,,-0.281207,-0.182894,-0.245565,-0.302441,0.703665
0,0,9,2.285698,1.056285,0.187227,0.249901,-0.77305,11,7,76,-0.675719,-0.199404,-0.586798,,-0.814909,,-1.296782,-2.040234,0.639589,1.597359,0.657514,-1.350148,0.364215,-0.017751,-0.317361,-0.122379,,,-0.320921,-0.95809,-2.436589,0.070999,-0.245239,,,,,,,,,-1.420632,-3.515137,-4.67776,,0.535897,,,-0.72542,,-2.29417,1.764551,,-0.120789,-0.063458,-1.36224,,,,,,-0.882604,-0.072482,-0.617934,-0.86323,-0.241892,-0.709919,,,0.377131,0.300724,-0.106842,-0.096792,2.109352
0,0,10,0.690606,1.139366,0.273328,0.306549,-1.262223,42,5,150,-0.694008,3.004091,0.114809,,-0.251882,,-1.902009,-0.979447,0.241165,-0.392359,-0.224699,-2.129397,-0.855287,0.404142,-0.578156,0.105702,,,0.544138,-0.087091,-1.500147,-0.201288,-0.038042,,,,,,,,,0.382074,2.669135,0.611711,,2.413415,,,1.313203,,-0.810125,2.939022,,3.988801,1.834661,-1.36224,,,,,,-0.697595,1.074309,-0.206929,-0.530602,4.765215,0.571554,,,-0.226891,-0.251412,-0.215522,-0.296244,1.114137
0,0,14,0.44057,0.9552,0.262404,0.344457,-0.613813,44,3,16,-0.947351,-0.030018,-0.502379,,0.646086,,-1.844685,-1.58656,-0.182024,-0.969949,-0.673813,-1.282132,-1.399894,0.043815,-0.320225,-0.031713,,,-0.08842,-0.995003,-2.635336,-0.196461,-0.618719,,,,,,,,,-2.0146,-2.321076,-3.711265,,1.253902,,,0.476195,,-0.771732,2.843421,,1.379815,0.411827,-1.36224,,,,,,-0.948601,-0.136814,-0.447704,-1.141761,0.099631,-0.661928,,,3.678076,2.793581,2.61825,3.418133,-3.57282


In [18]:
# feature_51 for symbol 23, one row per date_id: the other selected columns
# are gathered into per-date lists and the result is ordered by date_id.
(
    df
    .filter(pl.col('symbol_id').eq(23))
    .select(multi_index + ['feature_51'])
    .group_by('date_id')
    .agg(pl.all())
    .sort('date_id')
)

date_id,time_id,symbol_id,feature_51
i16,list[i16],list[i8],list[f32]
487,"[0, 1, … 848]","[23, 23, … 23]","[-2.689121, -2.518059, … -0.60992]"
488,"[0, 1, … 848]","[23, 23, … 23]","[-1.423291, -2.384849, … -0.717525]"
489,"[0, 1, … 848]","[23, 23, … 23]","[-2.750756, 2.405879, … 0.830122]"
490,"[0, 1, … 848]","[23, 23, … 23]","[0.23566, 0.614342, … 0.248478]"
491,"[0, 1, … 848]","[23, 23, … 23]","[3.336976, 3.209603, … 0.967412]"
492,"[0, 1, … 848]","[23, 23, … 23]","[2.611406, 2.300066, … 1.300363]"
493,"[0, 1, … 848]","[23, 23, … 23]","[0.353284, 0.676122, … -0.635858]"
494,"[0, 1, … 848]","[23, 23, … 23]","[1.351237, 0.265874, … 1.121995]"
495,"[0, 1, … 848]","[23, 23, … 23]","[2.851859, 2.003689, … -0.605163]"
496,"[0, 1, … 848]","[23, 23, … 23]","[3.281615, 3.246241, … -0.84717]"
