In [1]:
import polars as pl
import polars.selectors as cs
from utils import PathsData, import_data, split_data, plot_correlation
from features import FeaturesFrame

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, StandardScaler

# Labelled data: features and targets used for fitting and local evaluation.
x = import_data(PathsData.X_TRAIN.value)
y = import_data(PathsData.Y_TRAIN.value)

# Unlabelled features, kept aside until the end to build the submission.
x_validation = import_data(PathsData.X_TEST.value)

# Hold out part of the labelled data for local evaluation.
x_train, x_test, y_train, y_test = split_data(x=x, y=y)

In [2]:
# Wrap the training features, parse the ISO date string and expand it into
# calendar features.
test = FeaturesFrame(x_train)
test = (
    test
    .with_columns(pl.col('date').str.to_date("%Y-%m-%d"))
    .with_columns(
        pl.col('date').dt.weekday().alias('weekday'),
        # No 'year' feature: per the original author's note the data only
        # covers 2023, so the column would be constant.
        pl.col('date').dt.month().alias('month'),
        pl.col('date').dt.day().alias('day')
    )
)

# Pull every string column out as plain Python lists for scikit-learn.
dict_string = test.select(cs.string()).to_dict(as_series=False)

# Keep the fitted encoders instead of discarding them: the same
# category -> integer mapping has to be re-applied (via .transform) to any
# other split, otherwise the integer codes would not be comparable.
train_encoder = LabelEncoder().fit(dict_string['train'])
gare_encoder = LabelEncoder().fit(dict_string['gare'])

# Attach the integer codes, then drop the raw string and date columns so that
# only numeric features remain.
test = (
    test
    .with_columns(
        encoded_train = train_encoder.transform(dict_string['train']),
        encoded_gare = gare_encoder.transform(dict_string['gare'])
    )
    .drop(cs.string(), cs.date())
)

In [3]:
# Inspect the frame after the date expansion and label encoding above.
test

arret,p2q0,p3q0,p4q0,p0q2,p0q3,p0q4,weekday,month,day,encoded_train,encoded_gare
i64,f64,f64,f64,f64,f64,f64,i8,i8,i8,i64,i64
8,0.0,0.0,1.0,-3.0,-1.0,-2.0,1,4,3,24311,34
9,0.0,0.0,0.0,1.0,0.0,1.0,1,4,3,24311,26
10,-1.0,0.0,0.0,-1.0,0.0,0.0,1,4,3,24311,14
11,-1.0,-1.0,0.0,2.0,-2.0,0.0,1,4,3,24311,68
12,-1.0,-1.0,-1.0,-1.0,3.0,2.0,1,4,3,24311,43
…,…,…,…,…,…,…,…,…,…,…,…
16,0.0,1.0,0.0,0.0,-2.0,0.0,2,10,3,1029,68
17,0.0,0.0,1.0,0.0,-2.0,0.0,2,10,3,1029,43
18,1.0,0.0,0.0,0.0,-3.0,0.0,2,10,3,1029,81
19,1.0,1.0,0.0,-3.0,-1.0,0.0,2,10,3,1029,64


In [4]:
# Preview of the standardised features, transposed so that each row of the
# displayed array corresponds to one column of `test`.
StandardScaler().fit_transform(test.to_numpy()).T

array([[-1.46526568, -1.3231836 , -1.18110151, ..., -0.04444485,
         0.09763723,  0.23971931],
       [ 0.0860831 ,  0.0860831 , -0.42864226, ...,  0.60080847,
         0.60080847,  0.60080847],
       [ 0.08531018,  0.08531018,  0.08531018, ...,  0.08531018,
         0.6033934 ,  0.6033934 ],
       ...,
       [-1.55791028, -1.55791028, -1.55791028, ..., -1.55791028,
        -1.55791028, -1.55791028],
       [ 1.09222656,  1.09222656,  1.09222656, ..., -1.61397382,
        -1.61397382, -1.61397382],
       [-0.29245576, -0.61992056, -1.11111777, ...,  1.63139996,
         0.93553725, -0.21058956]])

In [5]:
# The scaler must be fit on the training data only, so keep the fitted
# instance around: it can then be re-applied with `scaler.transform(...)`
# instead of being refit on other data.
features_matrix = test.to_numpy()
scaler = StandardScaler().fit(features_matrix)

# Rebuild a polars frame from the scaled matrix; transposing yields one array
# per column, which is zipped back with the original column names.
test = pl.DataFrame(
    data = dict(zip(test.columns, scaler.transform(features_matrix).transpose()))
)

In [None]:
# Add the pairwise interaction ("*_multiplied") and squared ("*_square")
# columns, then standard-scale the result.
# NOTE(review): `scale_standard` is indexed with [1]; confirm against
# `features.FeaturesFrame` which element of its return value this selects.
test = (
    FeaturesFrame(test)
    .add_feature_interactions()
    .add_feature_square()
    .scale_standard(set="train")[1]
)

arret,p2q0,p3q0,p4q0,p0q2,p0q3,p0q4,weekday,month,day,encoded_train,encoded_gare,arret_p2q0_multiplied,arret_p3q0_multiplied,arret_p4q0_multiplied,arret_p0q2_multiplied,arret_p0q3_multiplied,arret_p0q4_multiplied,arret_weekday_multiplied,arret_month_multiplied,arret_day_multiplied,arret_encoded_train_multiplied,arret_encoded_gare_multiplied,p2q0_p3q0_multiplied,p2q0_p4q0_multiplied,p2q0_p0q2_multiplied,p2q0_p0q3_multiplied,p2q0_p0q4_multiplied,p2q0_weekday_multiplied,p2q0_month_multiplied,p2q0_day_multiplied,p2q0_encoded_train_multiplied,p2q0_encoded_gare_multiplied,p3q0_p4q0_multiplied,p3q0_p0q2_multiplied,p3q0_p0q3_multiplied,p3q0_p0q4_multiplied,…,p3q0_encoded_gare_multiplied_square,p4q0_p0q2_multiplied_square,p4q0_p0q3_multiplied_square,p4q0_p0q4_multiplied_square,p4q0_weekday_multiplied_square,p4q0_month_multiplied_square,p4q0_day_multiplied_square,p4q0_encoded_train_multiplied_square,p4q0_encoded_gare_multiplied_square,p0q2_p0q3_multiplied_square,p0q2_p0q4_multiplied_square,p0q2_weekday_multiplied_square,p0q2_month_multiplied_square,p0q2_day_multiplied_square,p0q2_encoded_train_multiplied_square,p0q2_encoded_gare_multiplied_square,p0q3_p0q4_multiplied_square,p0q3_weekday_multiplied_square,p0q3_month_multiplied_square,p0q3_day_multiplied_square,p0q3_encoded_train_multiplied_square,p0q3_encoded_gare_multiplied_square,p0q4_weekday_multiplied_square,p0q4_month_multiplied_square,p0q4_day_multiplied_square,p0q4_encoded_train_multiplied_square,p0q4_encoded_gare_multiplied_square,weekday_month_multiplied_square,weekday_day_multiplied_square,weekday_encoded_train_multiplied_square,weekday_encoded_gare_multiplied_square,month_day_multiplied_square,month_encoded_train_multiplied_square,month_encoded_gare_multiplied_square,day_encoded_train_multiplied_square,day_encoded_gare_multiplied_square,encoded_train_encoded_gare_multiplied_square
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
-1.465266,0.086083,0.08531,0.620097,-1.467297,-0.30091,-0.352369,-1.441567,-1.216083,-1.55791,1.092227,-0.292456,-0.150739,-0.138783,-0.979862,2.173955,0.585502,0.663018,2.117761,1.780098,2.273517,-1.599571,0.442847,-0.017766,0.052483,-0.014439,-0.006982,-0.010063,-0.128835,-0.0856,-0.154948,0.100499,-0.009989,-0.015745,-0.019814,-0.010132,-0.01471,…,-0.017596,-0.002547,-0.003511,-0.002774,-0.005701,-0.008929,0.012344,-0.014383,-0.017184,-0.002953,-0.003799,0.148665,0.039701,0.268262,0.035892,-0.015417,-0.002919,-0.024908,-0.015028,-0.002405,-0.006651,-0.002528,-0.001731,-0.001998,-0.000836,-0.001937,-0.003202,1.271563,2.5736,1.013612,-0.541802,1.377603,0.511604,-0.562254,1.186005,-0.483858,-0.577292
-1.323184,0.086083,0.08531,0.091992,0.598992,0.06198,0.226497,-1.441567,-1.216083,-1.55791,1.092227,-0.619921,-0.13771,-0.125636,-0.11268,-0.869438,-0.151226,-0.419219,1.912768,1.606858,2.052319,-1.444242,0.848702,-0.017766,0.032114,-0.000506,-0.00332,0.003622,-0.128835,-0.0856,-0.154948,0.100499,-0.038983,-0.017628,-0.004027,-0.007286,-0.003919,…,-0.017557,-0.002557,-0.003513,-0.002781,-0.035694,-0.018359,-0.044869,-0.030252,-0.01766,-0.002959,-0.003845,-0.009169,-0.009784,0.005931,-0.012175,-0.016224,-0.002921,-0.040964,-0.019801,-0.003823,-0.008081,-0.002542,-0.002215,-0.003222,-0.002939,-0.002245,-0.003174,1.271563,2.5736,1.013612,-0.132041,1.377603,0.511604,-0.280077,1.186005,-0.04136,-0.34825
-1.181102,-0.428642,0.08531,0.091992,-0.434153,0.06198,0.033542,-1.441567,-1.216083,-1.55791,1.092227,-1.111118,0.522926,-0.112488,-0.098276,0.480649,-0.138819,-0.074286,1.707776,1.433618,1.83112,-1.288913,1.358501,-0.019594,0.010899,0.010032,-0.00706,-0.005681,0.632158,0.507262,0.770082,-0.500854,0.50578,-0.017628,-0.011921,-0.007286,-0.007516,…,-0.017447,-0.002557,-0.003513,-0.002781,-0.035694,-0.018359,-0.044869,-0.030252,-0.017544,-0.002959,-0.003849,-0.024151,-0.014481,-0.01897,-0.016737,-0.014571,-0.002921,-0.040964,-0.019801,-0.003823,-0.008081,-0.002535,-0.002549,-0.004065,-0.004387,-0.002457,-0.00323,1.271563,2.5736,1.013612,1.034108,1.377603,0.511604,0.522977,1.186005,1.217959,0.303585
-1.039019,-0.428642,-0.432773,0.091992,1.115564,-0.663801,0.033542,-1.441567,-1.216083,-1.55791,1.092227,1.09927,0.45805,0.484558,-0.083872,-1.248515,0.93601,-0.067965,1.502783,1.260378,1.609922,-1.133584,-1.184419,-0.01035,0.010899,-0.042001,0.029417,-0.005681,0.632158,0.507262,0.770082,-0.500854,-0.468745,-0.01962,-0.05184,0.018639,-0.011313,…,-0.013576,-0.002557,-0.003512,-0.002781,-0.035694,-0.018359,-0.044869,-0.030252,-0.017548,-0.002941,-0.003849,0.068746,0.014645,0.135431,0.011553,0.007594,-0.002921,0.039923,0.004242,0.003319,-0.000873,-0.00137,-0.002549,-0.004065,-0.004387,-0.002457,-0.00323,1.271563,2.5736,1.013612,0.99819,1.377603,0.511604,0.498243,1.186005,1.179172,0.283508
-0.896937,-0.428642,-0.432773,-0.436113,-0.434153,1.150652,0.419453,-1.441567,-1.216083,-1.55791,1.092227,0.075942,0.393174,0.41786,0.452546,0.35305,-1.489716,-0.520711,1.297791,1.087138,1.388723,-0.978255,-0.071683,-0.01035,0.112324,0.010032,-0.061777,-0.051109,0.632158,0.507262,0.770082,-0.500854,-0.017576,-0.010068,0.008223,-0.053555,-0.047806,…,-0.017588,-0.002557,-0.003502,-0.002776,-0.0212,-0.013802,-0.01722,-0.022583,-0.017695,-0.002951,-0.003843,-0.024151,-0.014481,-0.01897,-0.016737,-0.018609,-0.002864,0.20351,0.052865,0.017762,0.013704,-0.002528,-0.001387,-0.001128,0.000658,-0.001718,-0.003231,1.271563,2.5736,1.013612,-0.651196,1.377603,0.511604,-0.637587,1.186005,-0.601992,-0.638439
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
-0.328609,0.086083,0.603393,0.091992,0.08242,-0.663801,0.033542,-0.721911,1.924277,-1.55791,-1.613974,1.09927,-0.046508,-0.218269,-0.011851,-0.077714,0.271633,-0.036362,0.241138,-0.6405,0.503929,0.533164,-0.375356,-0.01591,0.032114,-0.003989,-0.010645,-0.00094,-0.0653,0.170443,-0.154948,-0.148685,0.113237,-0.015636,-0.00415,-0.044596,-0.003719,…,-0.00977,-0.002557,-0.003512,-0.002781,-0.0362,-0.01804,-0.044869,-0.02983,-0.017548,-0.002959,-0.003849,-0.040582,-0.019211,-0.045536,-0.021389,-0.018485,-0.002921,-0.021212,0.040716,0.003319,0.007733,-0.00137,-0.002554,-0.004036,-0.004387,-0.002452,-0.00323,0.551696,0.161536,0.245921,-0.243484,4.288681,5.779812,2.214555,3.335374,1.179172,1.379932
-0.186527,0.086083,0.08531,0.620097,0.08242,-0.663801,0.033542,-0.721911,1.924277,-1.55791,-1.613974,0.075942,-0.033479,-0.020453,-0.106004,-0.065602,0.138757,-0.030042,0.138482,-0.366372,0.282731,0.303635,-0.01579,-0.017766,0.052483,-0.003989,-0.010645,-0.00094,-0.0653,0.170443,-0.154948,-0.148685,0.02263,-0.015745,-0.007974,-0.012978,-0.007516,…,-0.017606,-0.002557,-0.003506,-0.002781,-0.028678,0.005573,0.012344,0.004823,-0.017677,-0.002959,-0.003849,-0.040582,-0.019211,-0.045536,-0.021389,-0.018628,-0.002921,-0.021212,0.040716,0.003319,0.007733,-0.00254,-0.002554,-0.004036,-0.004387,-0.002452,-0.003234,0.551696,0.161536,0.245921,-0.657122,4.288681,5.779812,-0.629396,3.335374,-0.601992,-0.633206
-0.044445,0.600808,0.08531,0.091992,0.08242,-1.026691,0.033542,-0.721911,1.924277,-1.55791,-1.613974,1.6314,-0.04482,-0.007305,0.016958,-0.05349,0.028605,-0.023721,0.035825,-0.092245,0.061533,0.074107,-0.076234,-0.015938,0.05333,-0.000666,-0.076273,0.003801,-0.446392,1.108562,-1.079978,-1.037299,1.024058,-0.017628,-0.007974,-0.015825,-0.007516,…,-0.017262,-0.002557,-0.003512,-0.002781,-0.0362,-0.01804,-0.044869,-0.02983,-0.01735,-0.002959,-0.003849,-0.040582,-0.019211,-0.045536,-0.021389,-0.018313,-0.002921,0.007278,0.125264,0.013349,0.029839,0.003647,-0.002554,-0.004036,-0.004387,-0.002452,-0.003225,0.551696,0.161536,0.245921,0.256293,4.288681,5.779812,5.650755,3.335374,3.331261,3.812305
0.097637,0.600808,0.603393,0.091992,-1.467297,-0.30091,0.033542,-0.721911,1.924277,-1.55791,-1.613974,0.935537,0.046114,0.060712,0.031362,-0.197874,-0.077076,-0.017401,-0.066832,0.181883,-0.159666,-0.155422,0.093519,-0.00298,0.05333,-0.073598,-0.025143,0.003801,-0.446392,1.108562,-1.079978,-1.037299,0.594037,-0.015636,-0.087893,-0.024465,-0.003719,…,-0.011931,-0.002557,-0.003513,-0.002781,-0.0362,-0.01804,-0.044869,-0.02983,-0.017593,-0.002953,-0.003848,0.006766,0.129002,0.268262,0.104158,0.014229,-0.002921,-0.037471,-0.007533,-0.002405,-0.004882,-0.00237,-0.002554,-0.004036,-0.004387,-0.002452,-0.003231,0.551696,0.161536,0.245921,-0.358074,4.288681,5.779812,1.426695,3.335374,0.685735,0.822232
