File to compute the baselines with skitlearn

In [28]:
from master_bert import MASTERModel
import pickle
import numpy as np
import time

from utils import load_all_csv_data_with_market_indexes, load_all_csv_data_without_index, csvs_to_qlib_df, PandasDataLoader
# Please install qlib first before load the data.

# Qlib
# import qlib
# from qlib.config import REG_US           # S&P 500 is a US market
# qlib.init(provider_uri=".", region=REG_US)   # provider_uri just needs to exist





# ------------------------------------------------------------
# 1.  Init Qlib and build *one* handler
import qlib, pandas as pd, numpy as np, torch
qlib.init()                               # client mode is fine

from qlib.data.dataset.loader import StaticDataLoader
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset import TSDatasetH          # <-- here
from qlib.data.dataset.processor import (
    DropnaProcessor, CSZScoreNorm, DropnaLabel,
)

# your tensor, names, dates exactly as before  ----------------
stock_tensor, stock_names, feature_names = load_all_csv_data_without_index()
# stock_tensor, stock_names, feature_names = load_all_csv_data_with_market_indexes()
N, T, K   = stock_tensor.shape
print("Shape: ", stock_tensor.shape)
# dates     = pd.read_csv("data/enriched/market_indexes_aggregated.csv")["Date"]
dates = pd.to_datetime(                     # <-- NEW
    pd.read_csv("data/enriched/market_indexes_aggregated.csv")["Date"]
)

# tensor ➜ tidy multi-index frame --------------------------------
def tensor_to_df(tensor, inst, feats, dt_index):
    flat = tensor.numpy().reshape(N * T, K)
    idx  = pd.MultiIndex.from_product([dt_index, inst],
                                      names=["datetime", "instrument"])
    cols = pd.MultiIndex.from_product([["feature"], feats])
    return pd.DataFrame(flat, index=idx, columns=cols)

df_raw = tensor_to_df(stock_tensor, stock_names, feature_names, dates)

# optional: build a forward-return label
df_raw[("label", "FWD_RET")] = (
    df_raw[("feature", "Adjusted Close")]
      .groupby("instrument").shift(-1) / df_raw[("feature", "Adjusted Close")] - 1
)

last_date = dates.iloc[-1]
df_raw = df_raw.drop(index=last_date, level="datetime")

# handler with learn / infer processors ------------------------
# proc_feat = [
#     {"class": "DropnaProcessor", "kwargs": {"fields_group": "feature"}},
#     {"class": "CSZScoreNorm",   "kwargs": {"fields_group": "feature"}},
# ]

# proc_feat = [
#     {"class": "CSZScoreNorm",   "kwargs": {"fields_group": "feature"}},
# ]

proc_feat = [
    {"class": "Fillna",          # <— correct name
     "kwargs": {"fields_group": "feature", "fill_value": 0}},  # zero-fill; choose ffill/bfill/etc. if you like
    # {"class": "CSZScoreNorm",
    #  "kwargs": {"fields_group": "feature"}},
]

proc_label = [{"class": "DropnaLabel"}]

handler = DataHandlerLP(
    data_loader      = StaticDataLoader(df_raw),
    infer_processors = proc_feat,          # what the model will see later
    learn_processors = proc_feat + proc_label,
)
handler.fit_process_data()                 # learn z-scores, etc.

# ------------------------------------------------------------
# 2.  Attach time splits in a TSDatasetH
split = {
    "train": (dates.iloc[0],              dates.iloc[int(T*0.8) - 1]),
    "valid": (dates.iloc[int(T*0.8)],     dates.iloc[int(T*0.9) - 1]),
    "test" : (dates.iloc[int(T*0.9)],     dates.iloc[-2]),
}

ts_ds = TSDatasetH(
    handler  = handler,
    segments = split,
    step_len = 8,          # same window the MASTER code expects
)

dl_train = ts_ds.prepare("train")   # ➜ TSDataSampler
dl_valid = ts_ds.prepare("valid")
dl_test  = ts_ds.prepare("test")

print(len(dl_train), len(dl_valid), len(dl_test))
#  → continue with your for-loop over seeds exactly as before
# ------------------------------------------------------------



[1340207:MainThread](2025-05-15 19:58:32,752) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[1340207:MainThread](2025-05-15 19:58:32,754) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[1340207:MainThread](2025-05-15 19:58:32,755) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/home/gabrielecarrino/.qlib/qlib_data/cn_data')}


Shape:  torch.Size([336, 3764, 224])


[1340207:MainThread](2025-05-15 19:58:52,365) INFO - qlib.timer - [log.py:127] - Time cost: 0.003s | Loading data Done
[1340207:MainThread](2025-05-15 19:58:54,130) INFO - qlib.timer - [log.py:127] - Time cost: 0.094s | Fillna Done
[1340207:MainThread](2025-05-15 19:58:54,346) INFO - qlib.timer - [log.py:127] - Time cost: 0.094s | Fillna Done
[1340207:MainThread](2025-05-15 19:58:54,473) INFO - qlib.timer - [log.py:127] - Time cost: 0.126s | DropnaLabel Done
[1340207:MainThread](2025-05-15 19:58:54,475) INFO - qlib.timer - [log.py:127] - Time cost: 2.109s | fit & process data Done
[1340207:MainThread](2025-05-15 19:58:54,475) INFO - qlib.timer - [log.py:127] - Time cost: 2.114s | Init data Done
[1340207:MainThread](2025-05-15 19:58:56,218) INFO - qlib.timer - [log.py:127] - Time cost: 0.086s | Fillna Done
[1340207:MainThread](2025-05-15 19:58:56,428) INFO - qlib.timer - [log.py:127] - Time cost: 0.087s | Fillna Done
[1340207:MainThread](2025-05-15 19:58:56,549) INFO - qlib.timer - [log

1011696 126336 126336


In [36]:
# grab the very first sample
sample = dl_train[0]

# this will print something like (step_len, num_features)
print("Sample shape:", sample.shape)

# so the number of features is the second entry:
print("Number of features:", sample.shape[1])

Sample shape: (8, 225)
Number of features: 225


In [37]:
first_element = dl_train[0]
print(first_element)

[[           nan            nan            nan ...            nan
             nan            nan]
 [           nan            nan            nan ...            nan
             nan            nan]
 [           nan            nan            nan ...            nan
             nan            nan]
 ...
 [           nan            nan            nan ...            nan
             nan            nan]
 [           nan            nan            nan ...            nan
             nan            nan]
 [ 2.5836910e+01  2.6230330e+01  2.5987420e+06 ...  2.6354023e+01
  -7.9903883e-01 -4.6501362e-01]]


In [50]:
sample = dl_test[0]    # The first 6 samples have nan!!!
print(type(sample))
print(len(sample))
print(sample)

<class 'numpy.ndarray'>
8
[[ 1.2448000e+02  1.2827000e+02  1.3756000e+06 ...  1.2529812e+02
  -1.0189654e+00 -1.1230236e-01]
 [ 1.1251000e+02  1.1324000e+02  1.2317000e+06 ...  1.1262406e+02
  -1.5748899e+00  5.0312972e-01]
 [ 1.6992999e+02  1.7081000e+02  1.4009000e+06 ...  1.7469330e+02
  -9.0097588e-01 -3.0856305e-01]
 ...
 [ 2.3150000e+01  2.3200001e+01  1.6852400e+07 ...  2.3418959e+01
  -1.8052806e-01  2.7266800e-01]
 [ 2.9870001e+01  3.0299999e+01  7.8257000e+06 ...  3.0379154e+01
   4.4281870e-01  2.1042514e-01]
 [ 3.4820000e+01  3.5299999e+01  1.1510500e+07 ...  3.5295067e+01
   1.9438314e+00  4.4736218e-01]]


In [39]:
features = sample[:, :-1]
labels = sample[:, -1]
print("Features shape:", features.shape)
print("Labels shape:", labels.shape)
print("First row of features:", features[0])
print("First label:", labels[0])

Features shape: (8, 224)
Labels shape: (8,)
First row of features: [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan]
First label: nan


In [33]:
print(stock_tensor.shape)

torch.Size([336, 3764, 224])


In [34]:
# Drop all rows in df_raw where the datetime is the last date in 'dates'
# last_date = dates.iloc[-1]
# df_raw = df_raw.drop(index=last_date, level="datetime")

In [35]:
df_raw

Unnamed: 0_level_0,Unnamed: 1_level_0,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,feature,label
Unnamed: 0_level_1,Unnamed: 1_level_1,Low,Open,Volume,High,Close,Adjusted Close,ABER_ZG_5_15,ABER_SG_5_15,ABER_XG_5_15,ABER_ATR_5_15,...,VTXP_14,VTXM_14,VWAP_D,VWMA_10,WCP,WILLR_14,WMA_10,ZL_EMA_10,ZS_30,FWD_RET
datetime,instrument,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
2008-01-02,A,25.836910,26.230330,2598742.0,26.323318,25.965666,23.579496,26.494516,27.074533,25.914499,0.580017,...,0.923283,1.068689,26.041965,26.186686,26.022890,-58.490540,26.367668,26.354023,-0.799039,-0.465014
2008-01-02,AAL,25.658083,25.965666,2789569.0,26.001431,25.708155,23.345657,26.321888,26.886127,25.757648,0.564239,...,0.886000,1.159000,25.789223,26.249157,25.768955,-62.352940,26.277409,26.080526,-1.267800,-0.478854
2008-01-02,AAP,24.835480,25.450644,4939274.0,25.643778,24.871244,22.585649,26.013353,26.598154,25.428551,0.584802,...,0.779037,1.098206,25.116835,26.101868,25.055437,-98.220673,26.024191,25.506905,-2.399250,-0.446937
2008-01-02,AAPL,25.071531,25.278971,4901108.0,25.665236,25.278971,22.955914,25.736290,26.335037,25.137543,0.598748,...,0.820000,1.044546,25.338579,25.975060,25.323677,-77.935921,25.854597,25.283384,-1.620026,-0.468874
2008-01-02,ABC,25.143063,25.293276,5680773.0,25.572247,25.243204,22.923433,25.521221,26.108664,24.933777,0.587443,...,0.904943,1.011407,25.319504,25.777113,25.300430,-79.715340,25.692549,25.144722,-1.545506,-0.451969
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-12-09,XEL,54.029999,54.029999,1641200.0,55.759998,54.990002,52.839680,52.978001,54.819992,51.136009,1.841992,...,1.179321,0.689813,54.926666,52.587219,54.942501,-9.999960,52.887455,53.657833,1.929239,-0.107867
2022-12-09,XOM,54.160000,54.619999,1327500.0,55.400002,54.549999,52.416885,53.442001,55.243858,51.640141,1.801860,...,1.177995,0.811385,54.703335,52.899105,54.665001,-21.116915,53.256184,54.149136,1.738636,-0.102389
2022-12-09,YUM,53.700001,54.029999,1185100.0,55.549999,55.320000,53.156773,53.974667,55.779736,52.169598,1.805069,...,1.069615,0.858144,54.856667,53.265713,54.972500,-7.678860,53.709091,54.923840,2.072297,-0.108674
2022-12-09,ZBH,54.240002,54.869999,1240800.0,55.980000,55.959999,54.143669,54.585999,56.386730,52.785271,1.800731,...,1.116674,0.791902,55.393333,53.649754,55.535000,-0.361672,54.212181,55.628593,2.195595,-0.122889
