In [1]:
import math
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt
from datetime import datetime
import idx_w_stock_basic_filter
from sklearn.preprocessing import normalize

return_threshold = 0.097 # 涨停bar

## Helper Function

In [2]:
def get_date_interval_idx(time,start,end):
    '''
    得到时间区间开头和结尾的idx
    '''
    time = np.array(time)
    start_idx = np.where(time>=start)
    end_idx   = np.where(time<=end)
    return start_idx[0][0],end_idx[0][-1]+1

def complete_all_datelist_info(stock_data5,stock_data10,datelist):
    '''
    个股某一天某些数据有nan则这一天的个股数据都是nan
    '''
    idx5  = np.unique((stock_data5 != stock_data5).nonzero()[0])
    idx10 = np.unique((stock_data10 != stock_data10).nonzero()[0])
    stock_data5[idx5,1:]   = np.nan
    stock_data10[idx10,1:] = np.nan

    complete_stock_data_5  = np.empty((len(datelist),stock_data5.shape[1]))
    complete_stock_data_10 = np.empty((len(datelist),stock_data10.shape[1]))
    iter_ = 0
    for date in datelist:
        if date in stock_data5[:,0]:
            idx = np.where(stock_data5[:,0] == date)
            idx = idx[0][0]
            complete_stock_data_5[iter_,:] = stock_data5[idx,:]
            complete_stock_data_10[iter_,:] = stock_data10[idx,:]
        else:
            complete_stock_data_5[iter_,0]  = date
            complete_stock_data_5[iter_,1:] = np.nan
            complete_stock_data_10[iter_,0]  = date
            complete_stock_data_10[iter_,:] = np.nan
        iter_ += 1
    return complete_stock_data_5,complete_stock_data_10


## 读取数据

In [3]:
def load_data(ts_code):
    mat_daily       = sio.loadmat('../daily/' + ts_code + '.mat')
    mat_daily_basic = sio.loadmat('../daily_basic/'+ ts_code + '.mat')
    daily           = mat_daily['daily']
    daily_basic     = mat_daily_basic['daily_basic']
    start_date      = 20160101
    end_date        = 20200801
    '''
    daily
    '''
    # read data starting from 2016/01 to 2020/08/01
    time_daily   = daily[0,0]['time']
    start_index_daily,end_index_daily  = get_date_interval_idx(time_daily,start_date,end_date)
    open_        = daily[0,0]['open'][start_index_daily:end_index_daily]
    high         = daily[0,0]['high'][start_index_daily:end_index_daily]
    low          = daily[0,0]['low'][start_index_daily:end_index_daily]
    close        = daily[0,0]['close'][start_index_daily:end_index_daily]
    volume       = daily[0,0]['volume'][start_index_daily:end_index_daily] # 成交量
    turnover     = daily[0,0]['turnover'][start_index_daily:end_index_daily] # 成交额
    adj_factor   = daily[0,0]['adj_factor'][start_index_daily:end_index_daily] # 复权因子
    time_daily   = time_daily[start_index_daily:end_index_daily]
    matrix_daily = np.concatenate((time_daily,open_,high,low,close,volume,turnover,adj_factor),axis=1)
    matrix_daily = matrix_daily[np.lexsort((matrix_daily[:,-1],matrix_daily[:,0]))] # sort the matrix by time then by adj_factor
    # remove duplicate daily: duplicate rows of same time
    _,de_duplicate_index = np.unique(matrix_daily[:,0],return_index=True)
    matrix_daily = matrix_daily[de_duplicate_index]

    '''
    daily_basic
    '''
    # read data starting from 2016/01 to 2020/08/01
    time_daily_basic  = daily_basic[0,0]['time'] # obtain time in the format yyyymmdd as integer
    start_index_daily_basic,end_index_daily_basic = get_date_interval_idx(time_daily_basic,start_date,end_date)
    turnover_rate      = daily_basic[0,0]['turnover_rate'][start_index_daily_basic:end_index_daily_basic]
    turnover_rate_free = daily_basic[0,0]['turnover_rate_free'][start_index_daily_basic:end_index_daily_basic]
    float_share        = daily_basic[0,0]['float_share'][start_index_daily_basic:end_index_daily_basic]
    free_share         = daily_basic[0,0]['free_share'][start_index_daily_basic:end_index_daily_basic]
    time_daily_basic   = time_daily_basic[start_index_daily_basic:end_index_daily_basic]
    matrix_daily_basic = np.concatenate((time_daily_basic,turnover_rate,turnover_rate_free,float_share,free_share),axis=1)
    matrix_daily_basic = matrix_daily_basic[matrix_daily_basic[:,0].argsort()] # sort the matrix by time   
    # remove duplicate daily_basic: duplicate rows of same time
    _,de_duplicate_index = np.unique(matrix_daily_basic[:,0],return_index=True)
    matrix_daily_basic = matrix_daily_basic[de_duplicate_index]
    return matrix_daily, matrix_daily_basic

## 前复权; vwap, turn, free turn; 得第二交易日是否涨跌停; Label: return5 and return10

In [4]:
def split_adjust(daily):
    last_adj = daily[:,-1][-1]
    adj = (daily[:,-1] / last_adj).reshape(daily.shape[0],1)
    daily[:,1:5] = daily[:,1:5]*adj
    daily = daily[:,:7]
    return daily

def get_return_vwap_turnover(daily,daily_basic):
    # get return
    close          = daily[:,4].reshape(len(daily),1)
    close_shift    = np.roll(close,-1)
    close_shift[-1] = np.nan
    return1        = (close_shift - close) / close
    # get vwap
    volume         = daily[:,5].reshape(len(daily),1)
    turnover       = daily[:,6].reshape(len(daily),1)
    vwap           = turnover / volume
    # get turnover 
    float_share    = daily_basic[:,3].reshape(len(daily_basic),1)
    free_share     = daily_basic[:,4].reshape(len(daily_basic),1)
    turn           = volume / float_share
    free_turn      = volume / free_share
    daily      = daily[:,:5]
    stock_data = np.concatenate((daily,vwap,volume,return1,turn,free_turn),axis=1)
    return stock_data

def get_reach_limit(stock_data):
    return1     = stock_data[:,7].reshape(len(stock_data),1)
    limit_up    = (return1 > return_threshold)*1  #涨停
    limit_down  = (return1 < -return_threshold)*1 #跌停
    limit       = limit_up+limit_down             #涨跌停的交易日
    # 每个交易日得到下一个交易日是否涨跌停
    limit       = np.roll(limit,-1).reshape(len(stock_data),1)
    return np.concatenate((stock_data,limit),axis=1)

def get_return(stock_data):
    close   = stock_data[:,4].reshape(len(stock_data),1)
    close5  = np.roll(close,-5).reshape(len(stock_data),1)
    close5[-5:,:]  = np.nan
    close10 = np.roll(close,-10).reshape(len(stock_data),1)
    close10[-10:,:] = np.nan
    return5  = (close5 - close) / close
    return10 = (close10 - close) / close
    stock_data5  = np.concatenate((stock_data,return5),axis=1)
    stock_data10 = np.concatenate((stock_data,return10),axis=1)
    return stock_data5, stock_data10

def get_datelist(ts_code_list):
    datelist = np.array([])
    for ts_code in ts_code_list:
        daily, daily_basic = load_data(ts_code)
        date_daily       = daily[:,0]
        date_daily_basic = daily_basic[:,0]
        datelist = np.concatenate((datelist,date_daily,date_daily_basic),axis=0)
        datelist = np.unique(np.sort(datelist,axis=0))
    return datelist

## 处理原始数据，为生成数据图片准备

In [5]:
def process_raw_data(ts_code:str,datelist:list):
    daily,daily_basic = load_data(ts_code)
    daily = split_adjust(daily) # 前复权，从矩阵中去除adj_factor
    stock_data = get_return_vwap_turnover(daily,daily_basic) # 计算vwap, turn, free turn, 从矩阵去除turnover
    stock_data = get_reach_limit(stock_data) # 得到下一交易日是否涨跌停, (0 or 1)
    stock_data5, stock_data10 = get_return(stock_data) # 计算个股5天后收益和个股10天后收益
    # 有所有交易日数据，如果一个交易日有nan则整个交易日所有数据都是nan
    stock_data5, stock_data10 = complete_all_datelist_info(stock_data5,stock_data10,datelist)
    return stock_data5, stock_data10

## debug

In [67]:
daily,daily_basic = load_data('002266.SZ')
ts_code_list = idx_w_stock_basic_filter.get_filtered_ts_code()[:,0]
datelist = get_datelist(ts_code_list)
# datelist = [datetime.strptime(str(int(item)),"%Y%m%d") for item in datelist]

df_daily       = pd.DataFrame(daily,columns=['date','open','high','low','close','volume','turnover','adj_factor'])
df_daily_basic = pd.DataFrame(daily_basic,columns=['date','turnover_rate','turnover_rate_free','float_share','free_share'])

daily = split_adjust(daily) # 前复权，从矩阵中去除adj_factor
df_daily       = pd.DataFrame(daily,columns=['date','open','high','low','close','volume','turnover'])
stock_data = get_return_vwap_turnover(daily,daily_basic) # 从矩阵去除turnover
df_stock       = pd.DataFrame(stock_data,columns=['date','open','high','low','close','vwap','volume','return','turn','free turn'])
# print(df_stock)

stock_data = get_reach_limit(stock_data) # 得到下一交易日是否涨跌停, (0 or 1)
st = pd.DataFrame(stock_data,columns=['date','open','high','low','close','vwap', 'volume','return','turn','free turn','limit'])
# print(st)
stock_data5, stock_data10 = get_return(stock_data)

# print(stock_data5)
# 有所有交易日数据，如果一个交易日有nan则整个交易日所有数据都是nan
# stock_data5, stock_data10 = complete_all_datelist_info(stock_data5,stock_data10,datelist)
df_stock = pd.DataFrame(stock_data5,columns=['date','open','high','low','close','vwap','volume','return','turn','free turn','limit','return5'])
# df_stock['return1']

In [68]:
# with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
#     display(df_stock)
# df_stock.to_string()
df_stock[-45:]

Unnamed: 0,date,open,high,low,close,vwap,volume,return,turn,free turn,limit,return5
971,20200528.0,3.67,3.67,3.61,3.63,3.639446,7247810.0,0.002755,0.004426,0.004732,0.0,0.121212
972,20200529.0,3.63,3.66,3.62,3.64,3.638521,8071120.0,0.071429,0.004929,0.00527,0.0,0.134615
973,20200601.0,3.66,4.0,3.66,3.9,3.872923,40626801.0,0.005128,0.02481,0.026527,0.0,0.05641
974,20200602.0,3.9,3.97,3.89,3.92,3.919855,21821896.0,0.022959,0.013326,0.014248,0.0,0.035714
975,20200603.0,3.94,4.04,3.91,4.01,3.961288,22450247.0,0.014963,0.01371,0.014659,0.0,0.012469
976,20200604.0,4.07,4.22,3.94,4.07,4.070979,39053017.0,0.014742,0.023848,0.025499,0.0,-0.002457
977,20200605.0,4.07,4.2,4.05,4.13,4.142647,24354796.0,-0.002421,0.014873,0.015902,0.0,-0.016949
978,20200608.0,4.18,4.19,4.1,4.12,4.130878,13844584.0,-0.014563,0.008454,0.00904,0.0,-0.024272
979,20200609.0,4.11,4.14,4.04,4.06,4.077769,17872080.0,0.0,0.010914,0.011669,0.0,0.002463
980,20200610.0,4.1,4.15,4.05,4.06,4.081959,9543113.0,0.0,0.005828,0.006231,0.0,-0.002463


## 生成处理数据图片,预测目标

In [65]:
def get_single_data_matrix(data,pass_date,sample_size,sample_interval):
    data_img = np.empty((sample_size,pass_date))
    for i in range(sample_size):
        data_img[i,:] = data[:,i*sample_interval:pass_date+i*sample_interval]
    return data_img

def get_img(stock_data,pass_date:int,sample_interval:int):
    '''
    pass_date: 过去30天的个股量价数据作为数据图片
    sample_interval: 每隔两天采样一次
    '''
    date_size      = len(stock_data)
    sample_size    = int((date_size-pass_date)/2) + 1 # 采样天数
    col_size       = stock_data.shape[1]
    stock_data_img = np.empty((col_size,sample_size,pass_date))
    for i in range(col_size):
        data     = stock_data[:,i].reshape(1,date_size)
        data_img = get_single_data_matrix(data,pass_date,sample_size,sample_interval)
        stock_data_img[i,...] = data_img
    stock_data_img = np.transpose(stock_data_img,(1,0,2))
    return stock_data_img

def handle_nan_limit_up_down(stock_data):
    '''
    handle_nan: 当图片中有nan, 整个图片都是nan
    limit_up_down: 截面期下一交易日涨跌停时, 整个图片都是nan
    '''
    sample_size = len(stock_data)
    limit = stock_data[:,-2,-1]
    nan_img_idx = np.unique((stock_data != stock_data).nonzero()[0])
    limit_idx   = np.where(limit == 1)
    stock_data[nan_img_idx,1:] = np.nan
    stock_data[limit_idx,1:]   = np.nan  
    return stock_data

def process_data_img(stock_data5,stock_data10,pass_date=30,sample_interval=2):
    stock_data5  = get_img(stock_data5,pass_date,sample_interval)
    stock_data10 = get_img(stock_data10,pass_date,sample_interval)
    display(pd.DataFrame(stock_data5[-1].T))
    stock_data5  = handle_nan_limit_up_down(stock_data5)
    print(stock_data5[-3])

    stock_data10 = handle_nan_limit_up_down(stock_data10)
    return5      = stock_data5[:,-1,-1].reshape(len(stock_data5),1)
    return10     = stock_data10[:,-1,-1].reshape(len(stock_data10),1)
    stock_data5  = stock_data5[:,0:10,:]
    stock_data10 = stock_data10[:,0:10,:]
    return stock_data5,stock_data10,return5,return10

In [66]:
s5,s10,return5,return10 = process_data_img(stock_data5,stock_data10)
# for i in range(len(s5)):
#     print(s5[i])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,20200617.0,4.05,4.07,3.98,4.05,4.030943,11796314.0,0.002469,0.007204,0.007702,0.0,-0.022222
1,20200618.0,4.06,4.08,4.02,4.06,4.050609,8644420.0,0.007389,0.005279,0.005644,0.0,-0.046798
2,20200619.0,4.04,4.09,4.04,4.09,4.066936,12609606.0,-0.022005,0.0077,0.008233,0.0,-0.05379
3,20200622.0,4.09,4.13,3.99,4.0,4.052745,16522118.0,0.0,0.01009,0.010788,0.0,-0.0375
4,20200623.0,4.03,4.05,3.97,4.0,3.999123,11422048.0,-0.01,0.006975,0.007458,0.0,-0.01
5,20200624.0,4.01,4.03,3.96,3.96,3.984874,8106967.0,-0.022727,0.004951,0.005293,0.0,0.10101
6,20200629.0,3.96,3.97,3.84,3.87,3.895048,12934839.0,0.0,0.007899,0.008446,0.0,0.196382
7,20200630.0,3.9,3.9,3.85,3.87,3.868815,12901040.0,-0.005168,0.007878,0.008424,0.0,0.217054
8,20200701.0,3.9,3.9,3.83,3.85,3.858038,14102873.0,0.028571,0.008612,0.009208,1.0,0.264935
9,20200702.0,3.86,3.96,3.83,3.96,3.899034,20678911.0,0.10101,0.012628,0.013502,0.0,0.229798


[[20200611. 20200612. 20200615. 20200616. 20200617. 20200618. 20200619.
  20200622. 20200623. 20200624. 20200629. 20200630. 20200701. 20200702.
  20200703. 20200706. 20200707. 20200708. 20200709. 20200710. 20200713.
  20200714. 20200715. 20200716. 20200717. 20200720. 20200721. 20200722.
  20200723. 20200724.]
 [      nan       nan       nan       nan       nan       nan       nan
        nan       nan       nan       nan       nan       nan       nan
        nan       nan       nan       nan       nan       nan       nan
        nan       nan       nan       nan       nan       nan       nan
        nan       nan]
 [      nan       nan       nan       nan       nan       nan       nan
        nan       nan       nan       nan       nan       nan       nan
        nan       nan       nan       nan       nan       nan       nan
        nan       nan       nan       nan       nan       nan       nan
        nan       nan]
 [      nan       nan       nan       nan       nan       nan      

## 生成dataset

In [19]:
def construct_dataset():
    ts_code_list = idx_w_stock_basic_filter.get_filtered_ts_code()[:,0]
    datelist = get_datelist(ts_code_list)
    ts_code_size = len(ts_code_list) 
    sample_size  = int((len(datelist)-30)/2) + 1 # 采样天数
    x5  = np.empty((ts_code_size,sample_size,10,30))
    x10 = np.empty((ts_code_size,sample_size,10,30))
    y5  = np.empty((ts_code_size,sample_size,1))
    y10 = np.empty((ts_code_size,sample_size,1))
    iter_ = 0
    for ts_code in ts_code_list:
        stock_data5,stock_data10 = process_raw_data(ts_code,datelist)
        stock_data5,stock_data10,return5,return10 = process_data_img(stock_data5,stock_data10)
        x5[iter_,:]  = stock_data5
        x10[iter_,:] = stock_data10
        y5[iter_,:]  = return5
        y10[iter_,:] = return10
        iter_ += 1
    x5  = np.transpose(x5,(1,0,2,3)) 
    x10 = np.transpose(x10,(1,0,2,3))
    y5  = np.transpose(y5,(1,0,2)) 
    y10 = np.transpose(y10,(1,0,2)) 
    x5  = x5[:,:,np.newaxis,...]
    x10 = x10[:,:,np.newaxis,...]
    return x5,x10,y5,y10,ts_code_list

In [20]:
x5,x10,y5,y10,ts_code_list = construct_dataset()

In [28]:
np.save('./x5.npy',x5)
np.save('./x10.npy',x10)
np.save('./y5.npy',y5)
np.save('./y10.npy',y10)
np.save('./ts_code_list.npy',ts_code_list)

In [35]:
pd.DataFrame(np.squeeze(x5[-3,145].T))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,20200611.0,,,,,,,,,
1,20200612.0,,,,,,,,,
2,20200615.0,,,,,,,,,
3,20200616.0,,,,,,,,,
4,20200617.0,,,,,,,,,
5,20200618.0,,,,,,,,,
6,20200619.0,,,,,,,,,
7,20200622.0,,,,,,,,,
8,20200623.0,,,,,,,,,
9,20200624.0,,,,,,,,,
