In [1]:
from get_stock_data import Downloader, mkdir
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import catboost as cb
import datetime

# 获取全部股票的日K线数据
mkdir('stockdata/d_data')
raw_train_path = 'stockdata/d_train'
raw_test_path = 'stockdata/d_test'
train_path = 'stockdata/d_data/train.csv'
test_path = 'stockdata/d_data/test.csv'
mode = 'train'

In [2]:
if mode == 'debug':
    train = pd.read_csv(train_path, nrows=100000)
    test = pd.read_csv(test_path, nrows=100000)
else:
    train = pd.read_csv(train_path)
    test = pd.read_csv(test_path)

In [3]:
train.date = train.date.apply(lambda x: int(x.replace('-', '')))
test.date = test.date.apply(lambda x: int(x.replace('-', '')))

In [4]:
train = train[train['date'] >= 20220301].reset_index(drop=True)

In [5]:
def feature_engineer(train, test, split=20220501):
    train_len = len(train)
    data = pd.concat((train, test), sort=False).reset_index(drop=True)
    data = data.sort_values(by=['code', 'date'])
    
    stock_industry = pd.read_csv("stock_industry.csv", encoding="gbk")
    from sklearn.preprocessing import LabelEncoder
    lbe = LabelEncoder()
    stock_industry['industry'] = lbe.fit_transform(stock_industry['industry'])
    data = pd.merge(data, stock_industry[['code', 'industry']], how='left', on='code')

    # alpha net 
    length = 30
    for name in tqdm(['open', 'high', 'low', 'close', 'volume', 'amount', 'adjustflag', 'turn', 'pctChg', 'peTTM', 'psTTM', 'pcfNcfTTM', 'pbMRQ']):
#     for name in tqdm(['open']):
        roll_feature = []
        for i, group in data.groupby('code', sort=False)[name]:
            values = group.tolist()
            values = [0]*(length - 1) + values
            roll_feature = roll_feature + [values[i:i+length] for i in range(len(group))]
        data = pd.concat([data, pd.DataFrame(roll_feature, columns=[f'{name}_{i}' for i in range(length)])], axis=1).reset_index(drop=True)
        
    # generate label
    data['label'] = data.groupby('code').close.transform(lambda x:(x - x.shift(-14)) / (x + 1e-7) )
    data = data.dropna(subset = ['label'], inplace=False)
    data = data.replace(np.nan, 0)
    
    return data[data['date'] <= split].reset_index(drop=True), data[data['date'] > split].reset_index(drop=True)

In [6]:
train, test = feature_engineer(train, test)

# f_train_path = 'stockdata/d_data/f_train_debug.csv'
# f_test_path = 'stockdata/d_data/f_test_debug.csv'
# train.to_csv(f_train_path, index=False)
# test.to_csv(f_test_path, index=False)

100%|██████████| 13/13 [04:17<00:00, 19.79s/it]


In [7]:
ycol = 'label'
feature_names = list(
    filter(lambda x: x not in [ycol, 'code', 'date', ''], train.columns))

# print(feature_names)

In [8]:
def label_quantile(x):
    if x<quantile_30:
        return 0
    elif x<quantile_70:
        return 1
    else:
        return 2

In [9]:
quantile_30, quantile_70 = train.label.quantile([0.3, 0.7]).values

train['hard_label'] = train.label.apply(label_quantile)
test['hard_label'] = test.label.apply(label_quantile)

In [10]:
from_file = cb.CatBoostClassifier()

model = from_file.load_model("model/next_2week_alphanet30_1year.model")

In [11]:
X_val = test[feature_names]
Y_val = test[ycol]

test['preds'] = model.predict(X_val)
test['pred_probs'] = model.predict_proba(X_val)

In [24]:
model.predict_proba(X_val)

array([[0.14798592, 0.68496401, 0.16705007],
       [0.15632642, 0.65135681, 0.19231677],
       [0.15307564, 0.64996177, 0.19696259],
       ...,
       [0.23522732, 0.49584386, 0.26892883],
       [0.22686717, 0.51155928, 0.26157355],
       [0.21858472, 0.55528349, 0.22613178]])

In [12]:
test

Unnamed: 0,date,code,open,high,low,close,preclose,volume,amount,adjustflag,...,pbMRQ_23,pbMRQ_24,pbMRQ_25,pbMRQ_26,pbMRQ_27,pbMRQ_28,pbMRQ_29,label,hard_label,preds
0,20220505,sh.000001,3044.8491,3082.2295,3042.1155,3067.7587,3047.0624,3.830729e+10,4.093520e+11,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.012943,0,1
1,20220506,sh.000001,3011.3187,3030.6926,2992.7152,3001.5605,3067.7587,3.432642e+10,3.462144e+11,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.040495,0,1
2,20220509,sh.000001,2990.1990,3015.9425,2983.6119,3004.1409,3001.5605,2.920616e+10,2.999833e+11,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.041975,0,1
3,20220510,sh.000001,2965.7759,3043.7789,2957.3968,3035.8442,3004.1409,3.706610e+10,3.840126e+11,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.037293,0,1
4,20220511,sh.000001,3035.3893,3100.8998,3034.6745,3058.7027,3035.8442,4.239919e+10,4.729634e+11,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.041758,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
145634,20220607,sz.399998,2523.9088,2584.3235,2492.7403,2538.4848,2538.8463,1.895658e+09,1.732184e+10,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.008656,1,1
145635,20220608,sz.399998,2559.4755,2671.5649,2558.9731,2671.4231,2538.4848,2.809501e+09,2.723744e+10,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.035723,1,1
145636,20220609,sz.399998,2660.1679,2744.8170,2632.7360,2677.0317,2671.4231,2.964074e+09,2.919304e+10,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.051611,1,1
145637,20220610,sz.399998,2643.9133,2743.9429,2634.8225,2723.6248,2677.0317,2.873916e+09,2.541947e+10,3,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.069272,1,1


In [13]:
test[test['hard_label']==2]

Unnamed: 0,date,code,open,high,low,close,preclose,volume,amount,adjustflag,...,pbMRQ_23,pbMRQ_24,pbMRQ_25,pbMRQ_26,pbMRQ_27,pbMRQ_28,pbMRQ_29,label,hard_label,preds
6224,20220606,sh.600006,9.05,9.05,9.05,9.05,8.23,13948533.0,1.262342e+08,3,...,1.518438,1.518438,1.670773,1.837851,1.837851,2.022127,2.223603,0.132597,2,1
6225,20220607,sh.600006,9.40,9.86,9.05,9.26,9.05,368996416.0,3.464790e+09,3,...,1.518438,1.670773,1.837851,1.837851,2.022127,2.223603,2.275200,0.138229,2,2
6226,20220608,sh.600006,9.25,10.19,8.66,10.06,9.26,390761088.0,3.700435e+09,3,...,1.670773,1.837851,1.837851,2.022127,2.223603,2.275200,2.471762,0.127237,2,2
6227,20220609,sh.600006,9.31,9.38,9.05,9.05,10.06,186466710.0,1.699867e+09,3,...,1.837851,1.837851,2.022127,2.223603,2.275200,2.471762,2.223603,0.127072,2,2
6229,20220613,sh.600006,8.13,8.94,8.13,8.70,8.30,195553144.0,1.700105e+09,3,...,2.022127,2.223603,2.275200,2.471762,2.223603,2.039326,2.137607,0.129885,2,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
137637,20220523,sz.301288,27.37,27.68,26.07,26.94,28.19,5873693.0,1.569430e+08,3,...,8.385158,8.032987,7.785594,7.771042,7.852536,8.204707,7.840894,0.120267,2,2
137654,20220607,sz.301298,26.10,32.94,25.62,30.05,27.46,24182097.0,7.274795e+08,3,...,0.000000,0.000000,0.000000,0.000000,0.000000,8.875088,9.712178,0.114809,2,2
137656,20220609,sz.301298,28.28,31.70,27.40,30.75,29.30,19572044.0,5.852985e+08,3,...,0.000000,0.000000,0.000000,8.875088,9.712178,9.469777,9.938418,0.165203,2,2
137657,20220610,sz.301298,30.00,33.33,28.55,33.33,30.75,19943154.0,6.118434e+08,3,...,0.000000,0.000000,8.875088,9.712178,9.469777,9.938418,10.772276,0.236424,2,2


In [25]:
test[test['preds']==0]

Unnamed: 0,date,code,open,high,low,close,preclose,volume,amount,adjustflag,...,pbMRQ_23,pbMRQ_24,pbMRQ_25,pbMRQ_26,pbMRQ_27,pbMRQ_28,pbMRQ_29,label,hard_label,preds
6681,20220510,sh.600026,7.56,7.68,7.24,7.36,8.04,74085083.0,5.448044e+08,3,...,1.237688,1.300989,1.314315,1.332085,1.385368,1.338745,1.225518,-0.302989,0,0
7099,20220506,sh.600051,6.35,6.41,6.21,6.25,6.55,16520000.0,1.041189e+08,3,...,0.826003,0.745002,0.670502,0.636729,0.646662,0.650635,0.620835,-0.075200,0,0
7100,20220509,sh.600051,6.26,6.45,6.25,6.41,6.25,11986188.0,7.640041e+07,3,...,0.745002,0.670502,0.636729,0.646662,0.650635,0.620835,0.636729,-0.054602,0,0
7112,20220525,sh.600051,6.48,6.74,6.46,6.70,6.51,10334470.0,6.868015e+07,3,...,0.650635,0.660569,0.666529,0.676462,0.696329,0.646662,0.665535,0.028358,1,0
7508,20220531,sh.600066,8.43,8.43,8.05,8.12,8.50,74141889.0,6.065845e+08,3,...,1.293693,1.259414,1.216191,1.277299,1.292203,1.266866,1.210229,-0.004926,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
137607,20220519,sz.301279,24.70,25.49,24.37,25.27,25.09,2505762.0,6.303300e+07,3,...,4.800301,4.712424,4.710593,4.653839,4.478084,4.593423,4.626377,0.012663,1,0
137608,20220520,sz.301279,25.35,25.53,24.80,25.05,25.27,2360934.0,5.919936e+07,3,...,4.712424,4.710593,4.653839,4.478084,4.593423,4.626377,4.586100,0.037126,1,0
137609,20220523,sz.301279,24.86,25.17,24.60,25.13,25.05,2084686.0,5.202261e+07,3,...,4.710593,4.653839,4.478084,4.593423,4.626377,4.586100,4.600747,0.027457,1,0
137610,20220524,sz.301279,25.13,25.33,23.28,23.33,25.13,3604171.0,8.713657e+07,3,...,4.653839,4.478084,4.593423,4.626377,4.586100,4.600747,4.271206,-0.064295,0,0


In [26]:
test[test['preds']==2]

Unnamed: 0,date,code,open,high,low,close,preclose,volume,amount,adjustflag,...,pbMRQ_23,pbMRQ_24,pbMRQ_25,pbMRQ_26,pbMRQ_27,pbMRQ_28,pbMRQ_29,label,hard_label,preds
6215,20220524,sh.600006,5.89,6.34,5.63,6.18,5.76,125825439.0,7.528270e+08,3,...,1.221139,1.275193,1.275193,1.294849,1.314505,1.415243,1.518438,-0.407767,0,2
6225,20220607,sh.600006,9.40,9.86,9.05,9.26,9.05,368996416.0,3.464790e+09,3,...,1.518438,1.670773,1.837851,1.837851,2.022127,2.223603,2.275200,0.138229,2,2
6226,20220608,sh.600006,9.25,10.19,8.66,10.06,9.26,390761088.0,3.700435e+09,3,...,1.670773,1.837851,1.837851,2.022127,2.223603,2.275200,2.471762,0.127237,2,2
6227,20220609,sh.600006,9.31,9.38,9.05,9.05,10.06,186466710.0,1.699867e+09,3,...,1.837851,1.837851,2.022127,2.223603,2.275200,2.471762,2.223603,0.127072,2,2
6228,20220610,sh.600006,8.17,9.07,8.17,8.30,9.05,230373422.0,1.969980e+09,3,...,1.837851,2.022127,2.223603,2.275200,2.471762,2.223603,2.039326,0.090361,1,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
137654,20220607,sz.301298,26.10,32.94,25.62,30.05,27.46,24182097.0,7.274795e+08,3,...,0.000000,0.000000,0.000000,0.000000,0.000000,8.875088,9.712178,0.114809,2,2
137655,20220608,sz.301298,27.60,30.95,26.58,29.30,30.05,19599110.0,5.513417e+08,3,...,0.000000,0.000000,0.000000,0.000000,8.875088,9.712178,9.469777,0.048805,1,2
137656,20220609,sz.301298,28.28,31.70,27.40,30.75,29.30,19572044.0,5.852985e+08,3,...,0.000000,0.000000,0.000000,8.875088,9.712178,9.469777,9.938418,0.165203,2,2
137657,20220610,sz.301298,30.00,33.33,28.55,33.33,30.75,19943154.0,6.118434e+08,3,...,0.000000,0.000000,8.875088,9.712178,9.469777,9.938418,10.772276,0.236424,2,2
