In [1]:
import datetime as dt

import plotly.express as px
import polars as pl

from stocksense.config import config
from stocksense.database import DatabaseHandler
from stocksense.model import XGBoostClassifier
from stocksense.pipeline import clean, engineer_features

features = config.model.features
date_col = config.model.date_col
aux_cols = ["tic", "datadate", "rdq"]
target = "excess_return_4Q_pos"

In [2]:
constituents = DatabaseHandler().fetch_constituents(dt.datetime(2022, 6, 1))

[32m2024-12-12 20:35:46.853[0m | [32m[1mSUCCESS [0m | [36mstocksense.database.schema[0m:[36mcreate_tables[0m:[36m121[0m - [32m[1mTables created successfully[0m


In [3]:
data = engineer_features()
data = clean(data)
data.head()

[32m2024-12-12 20:35:46.862[0m | [1mINFO    [0m | [36mstocksense.pipeline.preprocess[0m:[36mengineer_features[0m:[36m20[0m - [1mSTART processing stock data[0m
[32m2024-12-12 20:35:46.864[0m | [32m[1mSUCCESS [0m | [36mstocksense.database.schema[0m:[36mcreate_tables[0m:[36m121[0m - [32m[1mTables created successfully[0m
[32m2024-12-12 20:35:50.783[0m | [1mINFO    [0m | [36mstocksense.pipeline.preprocess[0m:[36mengineer_features[0m:[36m33[0m - [1mSTART feature engineering[0m
[32m2024-12-12 20:36:05.135[0m | [32m[1mSUCCESS [0m | [36mstocksense.pipeline.preprocess[0m:[36mengineer_features[0m:[36m46[0m - [32m[1mEND 58960 rows PROCESSED[0m
[32m2024-12-12 20:36:05.138[0m | [1mINFO    [0m | [36mstocksense.pipeline.preprocess[0m:[36mclean[0m:[36m68[0m - [1mSTART cleaning data[0m
[32m2024-12-12 20:36:05.280[0m | [32m[1mSUCCESS [0m | [36mstocksense.pipeline.preprocess[0m:[36mclean[0m:[36m106[0m - [32m[1m37372 rows retained

tdq,tic,datadate,rdq,saleq,cogsq,xsgaq,niq,ebitdaq,cshoq,actq,atq,cheq,rectq,invtq,ppentq,lctq,dlttq,ltq,req,seqq,oancfq,ivncfq,fincfq,dvq,capxq,icaptq,surprise_pct,stock_split,n_purch,val_purch,n_sales,val_sales,insider_balance,roa,roi,roe,…,f_score,f_score_gr1,f_score_gr4,forward_vol_yoy,forward_vol_sos,forward_vol_qoq,excess_return_1Q,sharpe_ratio_1Q,risk_return_1Q,fwd_return_1Q_pos,excess_return_1Q_pos,excess_return_2Q,sharpe_ratio_2Q,risk_return_2Q,fwd_return_2Q_pos,excess_return_2Q_pos,excess_return_3Q,sharpe_ratio_3Q,risk_return_3Q,fwd_return_3Q_pos,excess_return_3Q_pos,excess_return_4Q,sharpe_ratio_4Q,risk_return_4Q,fwd_return_4Q_pos,excess_return_4Q_pos,sector_communication_services,sector_consumer_discretionary,sector_consumer_staples,sector_energy,sector_financials,sector_health_care,sector_industrials,sector_information_technology,sector_materials,sector_real_estate,sector_utilities
date,str,date,date,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i8,u32,f64,u32,f64,f64,f64,f64,f64,…,i8,i8,i8,f64,f64,f64,f64,f64,f64,i8,i8,f64,f64,f64,i8,i8,f64,f64,f64,i8,i8,f64,f64,f64,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8
2007-03-01,"""A""",2007-01-31,2007-02-15,1280.0,540.0,564.0,150.0,176.0,405.0,3749.0,7193.0,2090.0,671.0,648.0,775.0,1372.0,1500.0,3536.0,1705.0,3657.0,93.0,-93.0,-169.0,0.0,37.0,5157.0,0.011,0,0,0.0,1,8.598,8.598,0.089114,0.029087,0.17528,…,7,1,,1.780824,1.749898,1.355307,8.952132,6.605244,10.89898,1,1,13.694795,10.104573,15.070857,1,1,8.583285,4.905021,9.151445,1,1,10.709357,6.013709,6.242292,1,1,0,0,0,0,0,1,0,0,0,0,0
2007-06-01,"""A""",2007-04-30,2007-05-14,1320.0,540.0,586.0,123.0,194.0,395.958,3791.0,7283.0,2050.0,718.0,650.0,777.0,1554.0,1500.0,3728.0,1881.0,3555.0,302.0,-36.0,-320.0,0.0,42.0,5055.0,-0.023,0,0,0.0,3,2.695,2.695,0.089112,0.024332,0.18256,…,7,0,,1.927423,1.918095,2.106162,1.404396,0.666804,-0.836463,0,1,-2.215787,-1.05205,-2.205524,0,0,-0.292617,-0.152556,-4.454244,0,0,-5.91296,-3.067807,-8.865139,0,0,0,0,0,0,0,1,0,0,0,0,0
2007-09-01,"""A""",2007-07-31,2007-08-14,1374.0,563.0,580.0,185.0,231.0,386.548,3281.0,7024.0,1486.0,738.0,674.0,787.0,1510.0,1500.0,3736.0,2069.0,3288.0,176.0,-258.0,-484.0,0.0,36.0,4788.0,-0.01,0,0,0.0,2,5.747,5.747,0.086418,0.038638,0.184611,…,6,-1,,1.772393,1.803532,1.789715,-1.703536,-0.951848,0.168495,1,0,-0.089423,-0.049965,-2.674412,0,0,-6.394643,-3.545622,-7.722925,0,0,10.925009,6.163987,-0.514147,0,1,0,0,0,0,0,1,0,0,0,0,0
2007-12-01,"""A""",2007-10-31,2007-11-15,1446.0,611.0,598.0,180.0,237.0,370.0,3671.0,7554.0,1826.0,735.0,643.0,801.0,1663.0,2087.0,4320.0,2580.0,3234.0,398.0,-69.0,-7.0,0.0,39.0,5321.0,0.022,0,0,0.0,2,2.75,2.75,0.084459,0.033828,0.197279,…,5,-1,-1.0,3.004222,1.948736,1.848409,-3.735506,-2.02093,-6.186043,0,0,-8.003799,-4.330101,-8.594016,0,0,8.239344,4.228045,-2.316993,0,1,-2.055557,-0.684223,-10.49239,0,0,0,0,0,0,0,1,0,0,0,0,0
2008-03-01,"""A""",2008-01-31,2008-02-13,1393.0,580.0,617.0,120.0,196.0,368.0,5070.0,7459.0,3148.0,726.0,674.0,801.0,2674.0,626.0,4286.0,2657.0,3173.0,4.0,-295.0,-168.0,0.0,34.0,3799.0,0.021,0,0,0.0,2,3.418,3.418,0.081512,0.031587,0.191617,…,6,1,-1.0,3.238751,1.741436,2.050642,1.053516,0.513749,2.536928,1,1,20.442391,9.968778,8.760186,1,1,6.726355,3.862535,-7.914079,0,1,-6.951969,-2.146497,-13.35568,0,0,0,0,0,0,0,1,0,0,0,0,0


In [4]:
df = data.to_pandas()
df_standardized = (df[features] - df[features].mean()) / df[features].std()
corr = df_standardized.corrwith(df[target])
corr_df = pl.DataFrame({"Feature": corr.index, "Correlation": corr.values}).sort(
    "Correlation", descending=False
)

fig = px.bar(
    corr_df,
    x="Correlation",
    y="Feature",
    orientation="h",
    title=f"Feature Correlations with Target ({target})",
    width=1000,
    height=1200,
)
fig.update_layout(yaxis={"tickfont": {"size": 10}}, showlegend=False, margin={"l": 200})
fig.add_vline(x=0, line_dash="dash", line_color="gray")
fig.show()

In [5]:
data = data.select(["tic", date_col] + features + ["fwd_return_4Q", target])
data = data.filter((~pl.all_horizontal(pl.col(target).is_null())))
train = data.filter((pl.col("tdq").dt.year() >= 2007) & (pl.col("tdq").dt.year() < 2021))
val = data.filter(
    (pl.col("tdq").dt.year() >= 2022)
    & (pl.col("tdq").dt.year() <= 2023)
    & pl.col("tic").is_in(constituents)
)

X_train = train.select(features).to_pandas()
y_train = train.select(target).to_pandas().values.ravel()
X_val = val.select(features).to_pandas()
y_val = val.select(target).to_pandas().values.ravel()

params = {
    "learning_rate": 0.01,
    "n_estimators": 100,
    "max_depth": 7,
    "min_child_weight": 3.21,
    "gamma": 0.45,
    "subsample": 0.50,
    "colsample_bytree": 0.53,
    "reg_alpha": 1.83,
    "reg_lambda": 1.2,
    "nthread": -1,
    "seed": 100,
}

model = XGBoostClassifier(params)
model.train(X_train, y_train)

print(f"PR AUC: {model.get_pr_auc(X_val, y_val)}")
print(f"ROC AUC: {model.get_roc_auc(X_val, y_val)}")

PR AUC: 0.40451412813410054
ROC AUC: 0.5182640890573731


In [6]:
val

tic,tdq,n_purch,n_sales,insider_balance,volume_ma20,volume_ma50,price_mom,price_qoq,price_yoy,price_2y,price_risk_qoq,price_risk_sos,price_risk_yoy,price_risk_2y,rsi_14d,rsi_30d,rsi_60d,rsi_90d,rsi_1y,vol_mom,vol_qoq,vol_sos,vol_yoy,vol_2y,rel_vol_mom,rel_vol_qoq,rel_vol_yoy,momentum_mom,momentum_qoq,momentum_yoy,momentum_2y,index_mom,index_qoq,index_sos,index_yoy,index_2y,…,niq_2y,ltq_yoy,ltq_2y,dlttq_yoy,gpm_yoy,gpm_2y,roa_yoy,roa_2y,roi_yoy,roi_2y,roe_yoy,fcf_yoy,der_yoy,dr_yoy,dr_2y,ltda_yoy,ev_ebitda_yoy,ltcr_yoy,pe_yoy,pe_2y,pb_yoy,ps_yoy,atr_yoy,size_yoy,sector_utilities,sector_health_care,sector_financials,sector_consumer_discretionary,sector_consumer_staples,sector_energy,sector_industrials,sector_information_technology,sector_communication_services,sector_materials,sector_real_estate,fwd_return_4Q,excess_return_4Q_pos
str,date,u32,u32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,f64,i8
"""A""",2022-03-01,0,5,19.293,133.226941,117.561988,-3.74289,-10.984423,6.704944,60.949116,-150.0,-150.0,200.0,300.0,43.158399,42.053527,43.715374,45.665021,51.522485,2.161613,1.857107,1.641768,1.424811,1.894748,50.0,150.0,160.441865,50.0,50.0,50.0,50.0,-2.833813,-4.58184,-4.913681,11.264523,39.350786,…,43.654822,6.221766,11.175586,24.942792,0.418303,2.029907,39.358651,45.107156,-12.893125,20.895223,38.663007,7.142857,-0.991586,-0.494881,2.283261,17.042371,-23.129185,-14.246468,-30.140404,5.066142,-3.130584,-10.827631,-0.76069,0.711764,0,1,0,0,0,0,0,0,0,0,0,15.553968,1
"""A""",2022-06-01,0,1,0.165,116.531326,115.558614,2.943879,-8.066325,-9.870404,36.327384,-150.0,-150.0,-200.0,300.0,49.272664,47.577979,46.541324,46.77372,50.376317,3.014695,2.469218,2.175493,1.758636,1.591078,50.0,147.055598,153.737685,-50.0,50.0,50.0,50.0,-1.303128,-5.25865,-10.198599,-2.540092,33.121373,…,171.287129,-4.56335,13.782803,0.110011,0.008074,0.958421,35.79613,67.232292,21.762787,126.509135,28.22335,-40.042373,-10.37675,-5.083665,2.899704,-0.435782,-28.314049,-40.10826,-35.577858,-28.576028,-17.395771,-21.606108,1.776929,0.059105,0,1,0,0,0,0,0,0,0,0,0,9.318676,1
"""A""",2022-09-01,0,2,0.868,79.928753,82.561189,-4.184011,1.018566,-26.514679,27.501967,46.302237,-150.0,-200.0,300.0,46.529941,50.295003,49.851588,49.37193,50.749185,2.316539,2.19982,2.343169,2.031332,1.71518,50.0,148.682818,153.04007,50.0,-27.157106,50.0,50.0,-4.532181,-3.750642,-9.089974,-12.317168,12.482106,…,65.326633,-2.741208,18.138007,0.146628,1.743347,3.022338,34.230978,74.990489,22.24579,53.513405,30.320847,-2.39521,-5.511298,-2.67627,7.568238,0.213494,-37.978017,-2.538116,-46.471204,-36.421035,-30.24082,-33.770549,8.064162,-0.007209,0,1,0,0,0,0,0,0,0,0,0,-5.57725,0
"""A""",2022-12-01,0,14,29.504,84.294291,79.954035,10.809849,20.742453,5.337014,35.332866,150.0,150.0,200.0,300.0,64.82618,61.518024,57.445711,55.2601,52.630406,2.743982,2.300815,2.214761,2.21394,1.804737,50.0,134.867777,145.846538,50.0,50.0,-50.0,50.0,5.717434,4.308392,-2.400146,-9.671307,11.307188,…,65.765766,-1.674191,9.949516,0.146574,0.950664,3.005054,5.338708,59.422189,-15.912726,47.596638,5.277354,1.587302,-0.117289,-0.059079,0.501708,1.791594,-10.702488,1.438619,-0.676569,-25.389269,4.565081,-5.016431,12.339549,-0.175597,0,1,0,0,0,0,0,0,0,0,0,-26.851359,0
"""A""",2023-03-01,0,5,9.322,85.991285,81.411412,-9.37789,-11.272423,5.484807,10.467538,-150.0,150.0,200.0,300.0,27.642313,39.893702,46.853905,48.655084,50.816258,1.430282,1.57085,1.948803,2.161252,1.828145,50.0,148.188834,145.011024,50.0,50.0,-50.0,50.0,-1.652163,-3.154822,-1.965465,-9.660628,1.270428,…,22.222222,2.648367,9.034908,0.10989,1.406293,1.830478,3.839863,44.709833,17.55397,2.397589,0.886202,16.078431,-5.678431,-2.916963,-3.397409,-5.31781,-10.779239,15.951013,-6.946706,-34.993345,-6.122066,-4.984549,2.849477,0.60311,0,1,0,0,0,0,0,0,0,0,0,-3.250208,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""ZTS""",2022-12-01,0,0,0.0,139.909142,110.406077,2.566788,-0.127517,-28.689795,-2.925137,-4.954486,-150.0,-200.0,-174.274658,61.163976,53.440047,49.431148,48.308086,49.582144,3.726094,2.573759,2.202554,2.014371,1.678464,50.0,150.0,132.699596,44.894051,-2.959725,50.0,-25.869712,5.717434,4.308392,-2.400146,-9.671307,11.307188,…,10.438413,-0.132979,-10.931014,-20.964806,-1.236584,0.362138,4.525028,24.757024,9.226616,13.914201,4.713626,-19.448698,0.27403,0.093427,-10.598813,-20.785627,-31.462157,1.918269,-32.655876,-23.432926,-29.481525,-33.197357,1.812715,-0.023773,0,1,0,0,0,0,0,0,0,0,0,12.704655,1
"""ZTS""",2023-03-01,0,1,2.277,101.933659,103.999792,1.742569,8.712864,-13.46759,6.265461,150.0,150.0,-200.0,300.0,53.3973,54.744564,52.724753,51.216697,50.329901,1.935118,1.719409,2.152536,2.038631,1.717367,50.0,150.0,136.783662,-50.0,-50.0,50.0,50.0,-1.652163,-3.154822,-1.965465,-9.660628,1.270428,…,28.412256,12.462591,6.97438,-0.606796,-2.167728,1.872541,-3.347206,17.680081,13.220269,21.607934,7.103482,9.131075,16.06405,4.739029,-2.458001,-7.432795,-16.962455,9.797321,-17.252629,-19.069925,-11.374684,-17.355672,0.1066,0.745823,0,1,0,0,0,0,0,0,0,0,0,14.673911,0
"""ZTS""",2023-06-01,0,2,3.059,113.412167,93.601906,-6.16167,-2.163021,-2.884217,-4.701764,-150.0,150.0,-150.545009,-274.970853,36.872081,44.714764,49.034177,49.615524,50.134387,1.691554,1.411919,1.591245,1.915851,1.709914,50.0,150.0,145.437831,-50.0,-50.0,-50.0,-50.0,2.462385,4.263394,6.655249,2.150699,0.451685,…,-1.252236,0.662899,-4.544518,25.459067,-1.050388,-0.081241,0.673463,17.098316,-17.010903,-4.570163,3.618482,77.669903,4.406098,1.438693,-4.253029,26.425961,-4.914543,41.615834,-2.508209,-20.640202,1.019514,-5.045032,-2.52209,-0.080502,0,1,0,0,0,0,0,0,0,0,0,3.072419,0
"""ZTS""",2023-09-01,0,2,2.564,82.354987,88.746835,6.435205,13.153876,22.992401,-5.8858,150.0,150.0,200.0,-300.0,62.382436,59.24916,56.672467,55.061073,51.742329,1.613205,1.651363,1.542894,1.870124,1.778463,50.0,150.0,164.216626,50.0,50.0,50.0,50.0,0.308312,5.41382,11.620662,14.178762,-0.152789,…,31.054688,-0.674646,-6.061542,25.550661,3.758961,1.85494,6.046895,18.622645,11.177887,28.516994,4.854933,-94.940476,-1.641055,-0.522938,-3.875178,25.742425,9.935686,-95.970134,11.457617,-22.294022,16.969973,14.758768,6.728389,-0.016014,0,1,0,0,0,0,0,0,0,0,0,-6.69165,0


In [7]:
y_pred = model.predict(X_val)
val_subset = val.with_columns(pl.Series("pred", y_pred)).sort("pred", descending=True)
n = 100
top = val_subset.head(n)
bottom = val_subset.tail(n)

top_freturn = top.select(pl.col(target)).mean().item()
bottom_freturn = bottom.select(pl.col(target)).mean().item()
freturn = val_subset.select(pl.col(target)).mean().item()

print(f"Average freturn: {freturn:.2f}%")

print(f"\nTop {n} stocks: {top_freturn * 100:.2f}%")
print(f"\nBottom {n} stocks: {bottom_freturn * 100:.2f}%")

Average freturn: 0.40%

Top 100 stocks: 34.00%

Bottom 100 stocks: 42.00%


In [8]:
top = val_subset.head(100).filter((pl.col("pe") < 50) & (pl.col("f_score") >= 5))
print(top.select(pl.col(target)).mean().item())
top

0.36363636363636365


tic,tdq,n_purch,n_sales,insider_balance,volume_ma20,volume_ma50,price_mom,price_qoq,price_yoy,price_2y,price_risk_qoq,price_risk_sos,price_risk_yoy,price_risk_2y,rsi_14d,rsi_30d,rsi_60d,rsi_90d,rsi_1y,vol_mom,vol_qoq,vol_sos,vol_yoy,vol_2y,rel_vol_mom,rel_vol_qoq,rel_vol_yoy,momentum_mom,momentum_qoq,momentum_yoy,momentum_2y,index_mom,index_qoq,index_sos,index_yoy,index_2y,…,ltq_yoy,ltq_2y,dlttq_yoy,gpm_yoy,gpm_2y,roa_yoy,roa_2y,roi_yoy,roi_2y,roe_yoy,fcf_yoy,der_yoy,dr_yoy,dr_2y,ltda_yoy,ev_ebitda_yoy,ltcr_yoy,pe_yoy,pe_2y,pb_yoy,ps_yoy,atr_yoy,size_yoy,sector_utilities,sector_health_care,sector_financials,sector_consumer_discretionary,sector_consumer_staples,sector_energy,sector_industrials,sector_information_technology,sector_communication_services,sector_materials,sector_real_estate,fwd_return_4Q,excess_return_4Q_pos,pred
str,date,u32,u32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,f64,i8,i64
"""A""",2022-03-01,0,5,19.293,133.226941,117.561988,-3.74289,-10.984423,6.704944,60.949116,-150.0,-150.0,200.0,300.0,43.158399,42.053527,43.715374,45.665021,51.522485,2.161613,1.857107,1.641768,1.424811,1.894748,50.0,150.0,160.441865,50.0,50.0,50.0,50.0,-2.833813,-4.58184,-4.913681,11.264523,39.350786,…,6.221766,11.175586,24.942792,0.418303,2.029907,39.358651,45.107156,-12.893125,20.895223,38.663007,7.142857,-0.991586,-0.494881,2.283261,17.042371,-23.129185,-14.246468,-30.140404,5.066142,-3.130584,-10.827631,-0.76069,0.711764,0,1,0,0,0,0,0,0,0,0,0,15.553968,1,1
"""A""",2022-06-01,0,1,0.165,116.531326,115.558614,2.943879,-8.066325,-9.870404,36.327384,-150.0,-150.0,-200.0,300.0,49.272664,47.577979,46.541324,46.77372,50.376317,3.014695,2.469218,2.175493,1.758636,1.591078,50.0,147.055598,153.737685,-50.0,50.0,50.0,50.0,-1.303128,-5.25865,-10.198599,-2.540092,33.121373,…,-4.56335,13.782803,0.110011,0.008074,0.958421,35.79613,67.232292,21.762787,126.509135,28.22335,-40.042373,-10.37675,-5.083665,2.899704,-0.435782,-28.314049,-40.10826,-35.577858,-28.576028,-17.395771,-21.606108,1.776929,0.059105,0,1,0,0,0,0,0,0,0,0,0,9.318676,1,1
"""A""",2022-09-01,0,2,0.868,79.928753,82.561189,-4.184011,1.018566,-26.514679,27.501967,46.302237,-150.0,-200.0,300.0,46.529941,50.295003,49.851588,49.37193,50.749185,2.316539,2.19982,2.343169,2.031332,1.71518,50.0,148.682818,153.04007,50.0,-27.157106,50.0,50.0,-4.532181,-3.750642,-9.089974,-12.317168,12.482106,…,-2.741208,18.138007,0.146628,1.743347,3.022338,34.230978,74.990489,22.24579,53.513405,30.320847,-2.39521,-5.511298,-2.67627,7.568238,0.213494,-37.978017,-2.538116,-46.471204,-36.421035,-30.24082,-33.770549,8.064162,-0.007209,0,1,0,0,0,0,0,0,0,0,0,-5.57725,0,1
"""A""",2022-12-01,0,14,29.504,84.294291,79.954035,10.809849,20.742453,5.337014,35.332866,150.0,150.0,200.0,300.0,64.82618,61.518024,57.445711,55.2601,52.630406,2.743982,2.300815,2.214761,2.21394,1.804737,50.0,134.867777,145.846538,50.0,50.0,-50.0,50.0,5.717434,4.308392,-2.400146,-9.671307,11.307188,…,-1.674191,9.949516,0.146574,0.950664,3.005054,5.338708,59.422189,-15.912726,47.596638,5.277354,1.587302,-0.117289,-0.059079,0.501708,1.791594,-10.702488,1.438619,-0.676569,-25.389269,4.565081,-5.016431,12.339549,-0.175597,0,1,0,0,0,0,0,0,0,0,0,-26.851359,0,1
"""A""",2023-03-01,0,5,9.322,85.991285,81.411412,-9.37789,-11.272423,5.484807,10.467538,-150.0,150.0,200.0,300.0,27.642313,39.893702,46.853905,48.655084,50.816258,1.430282,1.57085,1.948803,2.161252,1.828145,50.0,148.188834,145.011024,50.0,50.0,-50.0,50.0,-1.652163,-3.154822,-1.965465,-9.660628,1.270428,…,2.648367,9.034908,0.10989,1.406293,1.830478,3.839863,44.709833,17.55397,2.397589,0.886202,16.078431,-5.678431,-2.916963,-3.397409,-5.31781,-10.779239,15.951013,-6.946706,-34.993345,-6.122066,-4.984549,2.849477,0.60311,0,1,0,0,0,0,0,0,0,0,0,-3.250208,0,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""AFL""",2022-06-01,1,4,0.706,99.709075,88.81247,4.380458,-0.482524,4.654418,60.606871,-32.525135,150.0,200.0,300.0,54.610327,49.381937,49.720993,50.589655,51.892807,1.588698,1.483543,1.512047,1.393566,1.645896,50.0,88.35318,121.823718,-50.0,9.175821,-50.0,50.0,-1.303128,-5.25865,-10.198599,-2.540092,33.121373,…,-6.770908,-6.145479,-3.956479,-3.434403,26.386743,-20.304906,42.956154,-12.73818,64.467749,-19.439403,-7.759883,1.362602,0.273612,-3.228601,3.300703,-2.620483,-3.960084,33.117502,4.705925,7.240255,5.287569,-6.460572,-0.608476,0,0,1,0,0,0,0,0,0,0,0,12.116626,1,1
"""AFL""",2022-09-01,0,4,2.339,72.651825,78.859641,0.864264,0.60852,5.644303,62.179834,41.50154,-150.0,200.0,300.0,44.451863,50.626418,50.664731,50.748047,51.576331,1.266545,1.46626,1.460371,1.439141,1.525142,50.0,99.102487,108.424517,-19.069493,-16.224432,-45.824681,50.0,-4.532181,-3.750642,-9.089974,-12.317168,12.482106,…,-14.492357,-14.432077,-8.681197,12.046824,29.782344,-10.529083,72.427181,55.534091,89.702007,-3.933021,-46.985447,9.318996,1.813037,-0.89429,8.732322,-10.822963,-41.945633,32.813317,-2.643767,27.589742,8.089467,9.746982,-1.455381,0,0,1,0,0,0,0,0,0,0,0,22.114378,1,1
"""AFL""",2022-12-01,0,5,4.705,102.120517,94.914768,7.633597,20.796246,33.985478,60.872489,150.0,150.0,200.0,300.0,68.400056,65.801368,60.926112,58.377648,53.893066,1.332801,1.56518,1.501103,1.502817,1.391457,50.0,91.746767,99.000261,50.0,50.0,-50.0,50.0,5.717434,4.308392,-2.400146,-9.671307,11.307188,…,-18.180172,-19.091503,-6.79395,3.774564,10.444318,50.016073,38.215451,136.185965,-17.299836,66.232068,-41.932002,13.664245,2.57626,1.659852,16.850992,-7.800877,-37.699326,5.309464,29.475684,75.058101,36.624981,12.404182,-1.886155,0,0,1,0,0,0,0,0,0,0,0,11.275425,1,1
"""AFL""",2023-03-01,0,4,3.293,89.37487,101.888445,-6.736032,-5.296813,11.507616,39.819378,-150.0,150.0,200.0,300.0,39.16616,45.705162,50.513318,51.8188,52.48003,1.628044,1.347271,1.453388,1.491305,1.384386,50.0,127.097152,100.060404,50.0,50.0,-50.0,50.0,-1.652163,-3.154822,-1.965465,-9.660628,1.270428,…,-10.210879,-15.152022,-6.460533,-6.414139,-0.615698,16.158731,10.180748,-75.556261,-73.133092,60.375468,18.735632,48.25013,7.376441,6.326294,11.861382,7.034539,26.9364,12.998248,41.584964,81.221468,24.414011,-7.451168,-1.494696,0,0,1,0,0,0,0,0,0,0,0,21.965642,1,1


In [9]:
eval_dates = [
    "2022-03-01",
    "2022-06-01",
    "2022-09-01",
    "2022-12-01",
    "2023-03-01",
    "2023-06-01",
    "2023-09-01",
    "2023-12-01",
]

for date in eval_dates:
    y_pred = model.predict(X_val)
    val_subset = (
        val.with_columns(pl.Series("pred", y_pred))
        .filter(pl.col("tdq") == pl.lit(date).str.to_date())
        .sort("pred", descending=True)
    )
    n = 50
    top = val_subset.head(n)
    bottom = val_subset.tail(n)

    # Calculate average returns for top stocks
    top_freturn = top.select(pl.col("fwd_return_4Q")).mean().item()

    # Calculate average returns for bottom stocks
    bottom_freturn = bottom.select(pl.col("fwd_return_4Q")).mean().item()

    print(f"\nDATE {date}")
    print(f"\nTop {n} stocks:")
    print(f"Average freturn: {top_freturn * 100:.2f}%")
    print(f"\nBottom {n} stocks:")
    print(f"Average freturn: {bottom_freturn * 100:.2f}%")


DATE 2022-03-01

Top 50 stocks:
Average freturn: 34.90%

Bottom 50 stocks:
Average freturn: -403.35%

DATE 2022-06-01

Top 50 stocks:
Average freturn: -103.18%

Bottom 50 stocks:
Average freturn: -202.87%

DATE 2022-09-01

Top 50 stocks:
Average freturn: 984.15%

Bottom 50 stocks:
Average freturn: 1141.68%

DATE 2022-12-01

Top 50 stocks:
Average freturn: 17.13%

Bottom 50 stocks:
Average freturn: -434.97%

DATE 2023-03-01

Top 50 stocks:
Average freturn: 1340.95%

Bottom 50 stocks:
Average freturn: -139.93%

DATE 2023-06-01

Top 50 stocks:
Average freturn: 2060.39%

Bottom 50 stocks:
Average freturn: 2086.70%

DATE 2023-09-01

Top 50 stocks:
Average freturn: 877.32%

Bottom 50 stocks:
Average freturn: 1577.30%

DATE 2023-12-01

Top 50 stocks:
Average freturn: 2006.79%

Bottom 50 stocks:
Average freturn: 2286.45%


In [10]:
model.get_importance("gain")

[('sector_energy', 40.191165924072266),
 ('momentum_yoy', 27.2480411529541),
 ('size', 25.87224769592285),
 ('vol_sos', 23.586788177490234),
 ('fear_ma30', 20.31974220275879),
 ('index_2y', 16.7189998626709),
 ('high_fear', 16.481006622314453),
 ('sector_real_estate', 15.660086631774902),
 ('sector_health_care', 15.417682647705078),
 ('vol_2y', 14.626969337463379),
 ('gpm_2y', 14.605377197265625),
 ('vol_yoy', 14.286331176757812),
 ('rsi_1y', 14.145264625549316),
 ('sector_information_technology', 13.999712944030762),
 ('sector_utilities', 13.80540943145752),
 ('index_sos', 13.57997989654541),
 ('index_yoy', 13.11374282836914),
 ('rel_vol_yoy', 12.731752395629883),
 ('index_mom', 12.614106178283691),
 ('rel_vol_qoq', 12.564565658569336),
 ('low_fear', 12.17973804473877),
 ('sector_communication_services', 11.731181144714355),
 ('vol_qoq', 11.615520477294922),
 ('price_qoq', 11.137831687927246),
 ('gpm_yoy', 11.062193870544434),
 ('sector_industrials', 10.805398941040039),
 ('ps_yoy', 1

In [11]:
model.get_importance("weight")

[('size', 249.0),
 ('vol_2y', 247.0),
 ('index_2y', 240.0),
 ('fear_ma30', 212.0),
 ('gpm_2y', 197.0),
 ('ev_ebitda', 189.0),
 ('ps', 188.0),
 ('ebitdam', 176.0),
 ('roa_2y', 174.0),
 ('roa', 164.0),
 ('dr_2y', 154.0),
 ('vol_qoq', 153.0),
 ('vol_sos', 153.0),
 ('saleq_2y', 150.0),
 ('pe_2y', 150.0),
 ('ltda_yoy', 148.0),
 ('index_mom', 147.0),
 ('index_sos', 141.0),
 ('vol_yoy', 140.0),
 ('ltcr_yoy', 138.0),
 ('pb', 136.0),
 ('dlttq_yoy', 132.0),
 ('roa_yoy', 124.0),
 ('rsi_1y', 123.0),
 ('ev_ebitda_yoy', 122.0),
 ('gpm_yoy', 121.0),
 ('price_2y', 120.0),
 ('atr_yoy', 118.0),
 ('ps_yoy', 114.0),
 ('volume_ma20', 113.0),
 ('rel_vol_yoy', 113.0),
 ('pe', 113.0),
 ('index_qoq', 112.0),
 ('gpm', 112.0),
 ('ltq_2y', 110.0),
 ('roe', 106.0),
 ('index_yoy', 101.0),
 ('price_qoq', 100.0),
 ('saleq_yoy', 98.0),
 ('vol_mom', 97.0),
 ('der_yoy', 97.0),
 ('size_yoy', 97.0),
 ('niq_2y', 96.0),
 ('price_mom', 88.0),
 ('insider_balance', 81.0),
 ('volume_ma50', 81.0),
 ('price_yoy', 81.0),
 ('dr_yoy

In [12]:
import numpy as np
import shap


def plot_shap(model, X_train):
    explainer = shap.TreeExplainer(model.model)
    shap_values = explainer(X_train)

    feature_names = [
        a + ": " + str(b)
        for a, b in zip(X_train.columns, np.abs(shap_values.values).mean(0).round(2), strict=False)
    ]

    shap.summary_plot(
        shap_values,
        X_train,
        max_display=X_train.shape[1],
        feature_names=feature_names,
        plot_size=(8, 13.5),
    )


#  plot_shap(model, X_train)