In [2]:
import os

import numpy as np
import pandas as pd
from lightgbm import LGBMClassifier
from lightgbm import early_stopping
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [3]:
# Walk-forward LightGBM training and scoring.
#
# For every test year, fit a classifier on the three preceding years of
# factor data, persist the fitted model, then score the test year. The yearly
# prediction tables are concatenated into `score`.
#
# To train TabNet instead, swap in `TabNetClassifier()` and persist it with
# `model.save_model(f"./log/tabnet/all_features/{year}")`.

SEED = 42                    # fixes the train/validation split so runs are reproducible
EARLY_STOPPING_ROUNDS = 200  # patience (in boosting rounds) on the validation metric
LGBM_LOG_DIR = "./log/lgbm"

# Create the output directory up front so a missing folder cannot crash the
# cell after an expensive fit.
os.makedirs(LGBM_LOG_DIR, exist_ok=True)

score = []
year_list = list(range(2018, 2024))
for year in year_list:
    print(year)

    # Stack the three years preceding `year` as the training sample.
    X_train = []
    y_train = []
    for i in [3, 2, 1]:
        # Infinite factor values are replaced by 0 before training.
        factor_stack = pd.read_pickle(f"./data/processed_data/factor_stack_{year - i}.pkl").replace([np.inf, -np.inf], [0, 0])
        stock_return = pd.read_pickle(f"./data/processed_data/quantile_return_{year - i}.pkl")
        # Keep only the feature rows that have a label, in label order.
        factor_stack = factor_stack.loc[stock_return.index, :]
        X_train.append(factor_stack)
        y_train.append(stock_return)
    del factor_stack, stock_return  # release the last yearly frames before concatenating
    X_train = pd.concat(X_train)
    y_train = pd.concat(y_train)

    # NOTE(review): this is a shuffled random split, so rows from the same
    # dates can land in both train and validation; consider a chronological
    # hold-out if look-ahead leakage matters here - confirm intent.
    X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, random_state=SEED)

    model = LGBMClassifier(random_state=SEED)
    # Early stopping is passed as a callback: the `early_stopping_rounds`
    # keyword of `fit` was deprecated in LightGBM 3.3 and removed in 4.0.
    # NOTE(review): with the default n_estimators (100 per the LightGBM docs)
    # a 200-round patience can never trigger; raise n_estimators if early
    # stopping is actually wanted.
    model.fit(
        X_train.values,
        y_train.values,
        eval_set=[(X_valid.values, y_valid.values)],
        callbacks=[early_stopping(stopping_rounds=EARLY_STOPPING_ROUNDS)],
    )
    pd.to_pickle(model, f"{LGBM_LOG_DIR}/{year}.pkl")

    # Out-of-sample scoring on the held-out year.
    X_test = pd.read_pickle(f"./data/processed_data/factor_stack_{year}.pkl").replace([np.inf, -np.inf], [0, 0])
    y_test = pd.read_pickle(f"./data/processed_data/quantile_return_{year}.pkl")
    X_test = X_test.loc[y_test.index, :]
    # Column 1 of predict_proba is the probability of the second class.
    pred = model.predict_proba(X_test.values)[:, 1]
    # Pivot the innermost index level into columns (one column per value).
    pred = pd.Series(pred, X_test.index).unstack()
    score.append(pred)

score = pd.concat(score)
score.head()

2018




[1]	valid_0's binary_logloss: 0.692212
[2]	valid_0's binary_logloss: 0.691413
[3]	valid_0's binary_logloss: 0.690714
[4]	valid_0's binary_logloss: 0.690093
[5]	valid_0's binary_logloss: 0.689514
[6]	valid_0's binary_logloss: 0.689026
[7]	valid_0's binary_logloss: 0.688583
[8]	valid_0's binary_logloss: 0.688161
[9]	valid_0's binary_logloss: 0.687747
[10]	valid_0's binary_logloss: 0.687394
[11]	valid_0's binary_logloss: 0.687071
[12]	valid_0's binary_logloss: 0.686805
[13]	valid_0's binary_logloss: 0.68653
[14]	valid_0's binary_logloss: 0.68624
[15]	valid_0's binary_logloss: 0.685953
[16]	valid_0's binary_logloss: 0.685729
[17]	valid_0's binary_logloss: 0.685515
[18]	valid_0's binary_logloss: 0.685298
[19]	valid_0's binary_logloss: 0.68507
[20]	valid_0's binary_logloss: 0.684866
[21]	valid_0's binary_logloss: 0.684666
[22]	valid_0's binary_logloss: 0.684483
[23]	valid_0's binary_logloss: 0.684294
[24]	valid_0's binary_logloss: 0.684084
[25]	valid_0's binary_logloss: 0.683921
[26]	valid_0



[1]	valid_0's binary_logloss: 0.692347
[2]	valid_0's binary_logloss: 0.691654
[3]	valid_0's binary_logloss: 0.691063
[4]	valid_0's binary_logloss: 0.690529
[5]	valid_0's binary_logloss: 0.690048
[6]	valid_0's binary_logloss: 0.689593
[7]	valid_0's binary_logloss: 0.689209
[8]	valid_0's binary_logloss: 0.68885
[9]	valid_0's binary_logloss: 0.68852
[10]	valid_0's binary_logloss: 0.688217
[11]	valid_0's binary_logloss: 0.687914
[12]	valid_0's binary_logloss: 0.68767
[13]	valid_0's binary_logloss: 0.687422
[14]	valid_0's binary_logloss: 0.687173
[15]	valid_0's binary_logloss: 0.686931
[16]	valid_0's binary_logloss: 0.686699
[17]	valid_0's binary_logloss: 0.686499
[18]	valid_0's binary_logloss: 0.686267
[19]	valid_0's binary_logloss: 0.686064
[20]	valid_0's binary_logloss: 0.685874
[21]	valid_0's binary_logloss: 0.685695
[22]	valid_0's binary_logloss: 0.685498
[23]	valid_0's binary_logloss: 0.685335
[24]	valid_0's binary_logloss: 0.685155
[25]	valid_0's binary_logloss: 0.684962
[26]	valid_0



[1]	valid_0's binary_logloss: 0.692284
[2]	valid_0's binary_logloss: 0.691558
[3]	valid_0's binary_logloss: 0.690929
[4]	valid_0's binary_logloss: 0.690382
[5]	valid_0's binary_logloss: 0.689901
[6]	valid_0's binary_logloss: 0.689479
[7]	valid_0's binary_logloss: 0.689076
[8]	valid_0's binary_logloss: 0.688732
[9]	valid_0's binary_logloss: 0.688403
[10]	valid_0's binary_logloss: 0.688111
[11]	valid_0's binary_logloss: 0.687839
[12]	valid_0's binary_logloss: 0.687598
[13]	valid_0's binary_logloss: 0.687371
[14]	valid_0's binary_logloss: 0.687137
[15]	valid_0's binary_logloss: 0.686956
[16]	valid_0's binary_logloss: 0.686741
[17]	valid_0's binary_logloss: 0.686593
[18]	valid_0's binary_logloss: 0.686406
[19]	valid_0's binary_logloss: 0.686253
[20]	valid_0's binary_logloss: 0.68611
[21]	valid_0's binary_logloss: 0.685954
[22]	valid_0's binary_logloss: 0.685818
[23]	valid_0's binary_logloss: 0.685678
[24]	valid_0's binary_logloss: 0.685529
[25]	valid_0's binary_logloss: 0.685392
[26]	valid



[1]	valid_0's binary_logloss: 0.692349
[2]	valid_0's binary_logloss: 0.691703
[3]	valid_0's binary_logloss: 0.691119
[4]	valid_0's binary_logloss: 0.690631
[5]	valid_0's binary_logloss: 0.690183
[6]	valid_0's binary_logloss: 0.689833
[7]	valid_0's binary_logloss: 0.689487
[8]	valid_0's binary_logloss: 0.689171
[9]	valid_0's binary_logloss: 0.688899
[10]	valid_0's binary_logloss: 0.688651
[11]	valid_0's binary_logloss: 0.688361
[12]	valid_0's binary_logloss: 0.688143
[13]	valid_0's binary_logloss: 0.687936
[14]	valid_0's binary_logloss: 0.68771
[15]	valid_0's binary_logloss: 0.687537
[16]	valid_0's binary_logloss: 0.687351
[17]	valid_0's binary_logloss: 0.687171
[18]	valid_0's binary_logloss: 0.687019
[19]	valid_0's binary_logloss: 0.686835
[20]	valid_0's binary_logloss: 0.686686
[21]	valid_0's binary_logloss: 0.686528
[22]	valid_0's binary_logloss: 0.686403
[23]	valid_0's binary_logloss: 0.686278
[24]	valid_0's binary_logloss: 0.686162
[25]	valid_0's binary_logloss: 0.686041
[26]	valid



[1]	valid_0's binary_logloss: 0.692467
[2]	valid_0's binary_logloss: 0.691926
[3]	valid_0's binary_logloss: 0.691429
[4]	valid_0's binary_logloss: 0.690986
[5]	valid_0's binary_logloss: 0.690602
[6]	valid_0's binary_logloss: 0.690273
[7]	valid_0's binary_logloss: 0.689952
[8]	valid_0's binary_logloss: 0.689644
[9]	valid_0's binary_logloss: 0.689404
[10]	valid_0's binary_logloss: 0.689169
[11]	valid_0's binary_logloss: 0.688943
[12]	valid_0's binary_logloss: 0.688722
[13]	valid_0's binary_logloss: 0.688532
[14]	valid_0's binary_logloss: 0.688321
[15]	valid_0's binary_logloss: 0.688154
[16]	valid_0's binary_logloss: 0.687979
[17]	valid_0's binary_logloss: 0.687839
[18]	valid_0's binary_logloss: 0.687682
[19]	valid_0's binary_logloss: 0.687526
[20]	valid_0's binary_logloss: 0.687398
[21]	valid_0's binary_logloss: 0.687226
[22]	valid_0's binary_logloss: 0.687118
[23]	valid_0's binary_logloss: 0.687004
[24]	valid_0's binary_logloss: 0.68689
[25]	valid_0's binary_logloss: 0.686756
[26]	valid



[1]	valid_0's binary_logloss: 0.692475
[2]	valid_0's binary_logloss: 0.691932
[3]	valid_0's binary_logloss: 0.691415
[4]	valid_0's binary_logloss: 0.691001
[5]	valid_0's binary_logloss: 0.690632
[6]	valid_0's binary_logloss: 0.690292
[7]	valid_0's binary_logloss: 0.689963
[8]	valid_0's binary_logloss: 0.689671
[9]	valid_0's binary_logloss: 0.689366
[10]	valid_0's binary_logloss: 0.689121
[11]	valid_0's binary_logloss: 0.688911
[12]	valid_0's binary_logloss: 0.688663
[13]	valid_0's binary_logloss: 0.688472
[14]	valid_0's binary_logloss: 0.688268
[15]	valid_0's binary_logloss: 0.688071
[16]	valid_0's binary_logloss: 0.687904
[17]	valid_0's binary_logloss: 0.687739
[18]	valid_0's binary_logloss: 0.687584
[19]	valid_0's binary_logloss: 0.687454
[20]	valid_0's binary_logloss: 0.687315
[21]	valid_0's binary_logloss: 0.687181
[22]	valid_0's binary_logloss: 0.687034
[23]	valid_0's binary_logloss: 0.686893
[24]	valid_0's binary_logloss: 0.686746
[25]	valid_0's binary_logloss: 0.686569
[26]	vali

Unnamed: 0_level_0,000001.SZ,000002.SZ,000004.SZ,000005.SZ,000008.SZ,000009.SZ,000010.SZ,000011.SZ,000012.SZ,000014.SZ,...,688306.SH,301237.SZ,688197.SH,603209.SH,301226.SZ,603051.SH,301102.SZ,301216.SZ,301258.SZ,301263.SZ
dt,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2018-01-02,0.493846,0.444707,0.60768,0.465254,0.471888,0.522898,0.452501,0.522979,0.459546,0.499655,...,,,,,,,,,,
2018-01-03,0.478124,0.474631,0.466554,0.429369,0.523186,0.519098,0.447803,0.50777,0.446901,0.529041,...,,,,,,,,,,
2018-01-04,0.470652,0.449117,0.50404,0.472498,0.507653,0.459414,0.492937,0.534335,0.434041,0.53294,...,,,,,,,,,,
2018-01-05,0.450209,0.461409,0.515412,0.452461,0.559107,0.455729,0.455199,0.488204,0.496707,,...,,,,,,,,,,
2018-01-08,0.490038,0.475742,0.52239,0.433434,0.512971,0.466106,,0.422902,0.49505,,...,,,,,,,,,,


In [17]:
# Score every year with previously saved TabNet checkpoints.
#
# Expects one checkpoint per year at ./log/tabnet/all_features/{year}.zip.
# A dedicated TabNetClassifier is created here instead of reusing whatever
# `model` happens to be bound to in the kernel: after the LightGBM training
# cell, `model` is an LGBMClassifier, which has no `load_model` method.
# Results are kept in `score_tabnet` so that the LightGBM `score` table is not
# overwritten before it is saved.
tabnet_model = TabNetClassifier()

score_tabnet = []
year_list = list(range(2018, 2024))
for year in year_list:
    # Replace the network weights with the checkpoint for this year.
    tabnet_model.load_model(f"./log/tabnet/all_features/{year}.zip")
    # Infinite factor values are replaced by 0, as was done for training data.
    X_test = pd.read_pickle(f"./data/processed_data/factor_stack_{year}.pkl").replace([np.inf, -np.inf], [0, 0])
    y_test = pd.read_pickle(f"./data/processed_data/quantile_return_{year}.pkl")
    # Keep only the feature rows that have a label, in label order.
    X_test = X_test.loc[y_test.index, :]
    # Column 1 of predict_proba is the probability of the second class.
    pred = tabnet_model.predict_proba(X_test.values)[:, 1]
    # Pivot the innermost index level into columns (one column per value).
    pred = pd.Series(pred, X_test.index).unstack()
    score_tabnet.append(pred)
score_tabnet = pd.concat(score_tabnet)
score_tabnet.head()



Unnamed: 0_level_0,000001.SZ,000002.SZ,000004.SZ,000005.SZ,000008.SZ,000009.SZ,000010.SZ,000011.SZ,000012.SZ,000014.SZ,...,688306.SH,301237.SZ,688197.SH,603209.SH,301226.SZ,603051.SH,301102.SZ,301216.SZ,301258.SZ,301263.SZ
dt,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2018-01-02,0.450708,0.489603,0.426504,0.42052,0.509872,0.756832,0.435362,0.396833,0.252163,0.689781,...,,,,,,,,,,
2018-01-03,0.438138,0.518821,0.394988,0.418144,0.620185,0.678303,0.402919,0.40957,0.507633,0.673246,...,,,,,,,,,,
2018-01-04,0.450056,0.347225,0.37555,0.430921,0.640478,0.612415,0.451481,0.421166,0.259603,0.671142,...,,,,,,,,,,
2018-01-05,0.361006,0.374675,0.416526,0.417695,0.630809,0.710633,0.378134,0.333742,0.360839,,...,,,,,,,,,,
2018-01-08,0.373535,0.608893,0.460506,0.426929,0.65096,0.775479,,0.316277,0.507321,,...,,,,,,,,,,


In [4]:
# Persist the current `score` frame to disk as a pickle.
pd.to_pickle(score, "./data/score/score_lgbm.pkl")