In [2]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from my_model.linear_model import layer_4 as MyModel
from my_dataset.linear_dataset import MyDataset

In [3]:
factor_concat = pd.read_pickle("F:\\Neural_Networks\data\\factor_concat_2018_2019.pkl")
stock_return = pd.read_pickle("../data/stock_return.pkl")
print("loaded")

loaded


In [3]:
dataset = MyDataset(factor_concat, stock_return)
dataloader = DataLoader(dataset, batch_size=256, shuffle=False, drop_last=False)

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
feature_num = len(factor_concat.columns.levels[0])
model = MyModel(975)
model.load_state_dict(torch.load('./log/optimizer_3.pth'))
model.to(device)

MyModel(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=975, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=16, bias=True)
    (5): ReLU()
    (6): Linear(in_features=16, out_features=1, bias=True)
  )
)

In [5]:
score = stock_return.copy()
score.iloc[:,:] = np.nan
score_dict = {}
for date in dataset.date_num_dict.keys():
    score_dict[date] = {}
for date_num, code_num, x, y in tqdm(dataloader):
    date_list = date_num.numpy().flatten()
    code_list = code_num.numpy().flatten()
    x = x.float().to(device)
    y_pred = model(x)
    y_pred = torch.sigmoid(y_pred).to("cpu").detach().numpy().flatten()
    for i in range(len(date_list)):
        date = dataset.num_date_dict[date_list[i]]
        code = dataset.num_code_dict[code_list[i]]
        score_dict[date][code] = y_pred[i]

100%|██████████| 2824/2824 [00:37<00:00, 74.79it/s]


In [6]:
score = pd.DataFrame(score_dict)
score.reindex(stock_return.columns)
score = score.sort_index().T
score.to_pickle("F:\Multifactor_Project\score.pkl")
score

Unnamed: 0,000001.SZ,000002.SZ,000004.SZ,000005.SZ,000006.SZ,000007.SZ,000008.SZ,000009.SZ,000010.SZ,000011.SZ,...,603987.SH,603988.SH,603989.SH,603990.SH,603991.SH,603993.SH,603996.SH,603997.SH,603998.SH,603999.SH
2018-01-02,1.674431e-09,0.000061,9.367107e-01,1.964904e-08,,,2.181135e-03,2.760838e-03,5.561548e-05,8.349029e-03,...,8.993637e-01,1.684310e-09,5.218646e-09,9.910941e-01,,9.910941e-01,3.314729e-04,1.978273e-07,2.179466e-05,0.000034
2018-01-03,7.049394e-03,0.051552,7.180343e-09,1.231847e-09,,,5.317497e-03,8.014403e-01,2.452897e-08,8.402367e-01,...,1.938251e-05,6.693937e-13,8.695626e-09,9.910941e-01,,9.910941e-01,2.982812e-03,1.196739e-01,9.910941e-01,0.991094
2018-01-04,1.345494e-04,0.000167,5.793924e-05,4.346762e-06,,,8.676589e-04,5.084209e-10,1.341131e-10,6.881686e-01,...,1.501812e-06,7.016414e-14,9.787318e-01,9.910941e-01,,1.063346e-04,1.895422e-09,9.468600e-04,4.352411e-11,0.775334
2018-01-05,1.466781e-04,0.000386,7.622755e-02,2.226872e-08,,,1.428700e-01,2.016734e-02,1.795323e-17,1.531438e-18,...,3.755191e-05,6.454463e-13,9.871002e-01,9.581183e-09,,9.674773e-11,7.638127e-06,9.910941e-01,4.052931e-08,0.006881
2018-01-08,9.372060e-05,0.000203,9.910941e-01,6.434327e-10,,,4.259214e-01,1.495047e-01,,3.924577e-14,...,3.285881e-01,3.440670e-07,2.387230e-01,5.932048e-10,,4.222314e-16,5.874323e-03,9.910941e-01,1.785989e-01,0.024010
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2018-12-24,1.596441e-01,0.991094,3.692443e-14,8.149223e-24,1.179057e-23,5.078403e-10,4.761026e-22,2.943645e-01,2.692356e-13,9.910941e-01,...,7.846678e-09,4.481966e-21,2.181221e-05,3.371537e-03,9.910941e-01,6.823422e-03,1.373195e-01,1.930915e-10,4.406850e-06,0.000043
2018-12-25,7.199393e-03,0.991094,8.392545e-08,6.631390e-13,1.355390e-22,1.766647e-16,1.943868e-24,9.802596e-01,2.325973e-13,9.910941e-01,...,9.910941e-01,2.411407e-16,5.627196e-03,9.919490e-07,5.853765e-13,2.316767e-04,7.036497e-07,6.291958e-10,3.286891e-03,0.991094
2018-12-26,7.280520e-01,0.991094,3.732641e-11,7.646464e-10,2.345414e-32,6.043851e-12,4.251574e-21,4.708403e-01,2.530946e-10,9.910941e-01,...,7.959883e-01,2.509803e-07,6.028892e-04,4.234668e-13,1.809653e-13,2.887509e-06,9.221887e-09,8.901346e-01,9.899746e-01,0.000003
2018-12-27,1.766204e-02,0.915746,4.002693e-06,7.986500e-12,1.988469e-14,8.318009e-16,1.972729e-20,9.644081e-01,1.566926e-02,9.910941e-01,...,1.302051e-01,3.219269e-09,3.278830e-02,1.320072e-16,2.631806e-15,1.659571e-08,1.965088e-15,7.767084e-01,9.910941e-01,0.000024


In [7]:
del factor_concat

In [85]:
quantile = 5
return_stack = stock_return.stack().dropna()
quantile_return = return_stack.groupby("dt").apply(
    lambda x: pd.qcut(
        x, np.arange(quantile + 1) / quantile, np.arange(quantile)
    )
)
quantile_return[quantile_return < (quantile - 1)] = 0
quantile_return[quantile_return == (quantile - 1)] = 1
quantile_return.unstack().loc["2019-01-01":, :]

code,000001.SZ,000002.SZ,000004.SZ,000005.SZ,000006.SZ,000007.SZ,000008.SZ,000009.SZ,000010.SZ,000011.SZ,...,688786.SH,688787.SH,688788.SH,688789.SH,688793.SH,688798.SH,688799.SH,688800.SH,688819.SH,688981.SH
dt,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2019-01-02,1,0,0,0,0,0,0,0,0,0,...,,,,,,,,,,
2019-01-03,1,0,0,1,0,0,0,0,0,0,...,,,,,,,,,,
2019-01-04,0,0,0,1,0,0,0,0,1,0,...,,,,,,,,,,
2019-01-07,1,0,0,1,0,0,0,0,1,0,...,,,,,,,,,,
2019-01-08,1,0,0,1,0,0,0,0,1,0,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-12-16,0,0,0,1,1,0,0,0,0,0,...,0,0,0,1,1,0,0,0,0,0
2022-12-19,0,0,1,1,1,1,0,1,1,0,...,0,0,0,0,0,0,0,0,0,0
2022-12-20,1,0,1,1,1,0,0,0,0,1,...,0,0,0,0,1,0,0,1,0,0
2022-12-21,0,0,1,1,1,0,0,0,0,1,...,0,0,0,0,1,0,0,0,1,0
