参考自NLP代码 [PyTorch快餐教程2019 (1) - 从Transformer说起](https://blog.csdn.net/lusing/article/details/102666617/)


In [142]:
import numpy as np
import pandas as pd
from pandas import DataFrame, Series
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn

import sys
sys.path.append(r'./')
from AModelFactory import ModelFactory

In [143]:
'''全局参数'''
# ---- Global hyper-parameters ----

# Intended train/test split ratio (fraction of rows, in date order, for training).
r_train = 0.9

# Batching / sequence shape.
batch_size = 20
seq_len = 35    # maximum chunk length produced by get_batch
emsize = 6      # embedding size handed to the model factory

# nn.TransformerEncoder configuration.
nhead = 2       # number of heads in each multi-head attention layer
nhid = 256      # width of the feed-forward sub-layer in each encoder layer
nlayers = 2     # number of stacked nn.TransformerEncoderLayer modules
dropout = 0.2   # dropout probability

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [144]:
'''1. 读取数据，选择需要的列'''
# Read the daily bars with the first CSV column (the date string) as the index.
# Dates are deliberately left as strings (no parse_dates): the next cell turns
# them into integers with str.replace. Then keep only the six listed columns.
data = pd.read_csv('XSHG300.csv', index_col=0).loc[
    :, ['open', 'close', 'high', 'low', 'volume', 'money']
]

data

Unnamed: 0,open,close,high,low,volume,money
2005-04-08,984.66,1003.45,1003.70,979.53,1.476253e+09,9.151350e+09
2005-04-11,1003.88,995.42,1008.74,992.77,1.593607e+09,1.043623e+10
2005-04-12,993.71,978.70,993.71,978.20,1.022619e+09,6.479563e+09
2005-04-13,987.95,1000.90,1006.50,987.95,1.607169e+09,1.002960e+10
2005-04-14,1004.64,986.98,1006.42,985.58,1.294571e+09,7.813425e+09
...,...,...,...,...,...,...
2015-12-25,3832.09,3838.20,3848.03,3813.20,1.196240e+10,1.633145e+11
2015-12-28,3847.53,3727.63,3853.39,3727.63,1.539884e+10,2.100260e+11
2015-12-29,3723.05,3761.88,3762.05,3710.48,1.018856e+10,1.404051e+11
2015-12-30,3762.91,3765.18,3765.66,3726.28,1.056300e+10,1.557441e+11


In [145]:
'''2.1 划分"训练集、测试集"；构造“序列 - trian_seq/test_seq” '''
# Turn each date-string index label ('2005-04-08') into an integer (20050408):
# tensors only hold numeric data, and these integers are the keys later used
# to look rows up in Dict.
data_seq = torch.tensor([int(day.replace('-', '')) for day in data.index])
# Re-index with plain Python ints instead of the tensor object itself, so pandas
# builds an ordinary integer index without converting a torch type.
data.index = data_seq.tolist()

# Chronological split: the first r_train fraction of the rows is the training set.
train_len = int(r_train * len(data_seq))
train_seq = data_seq[:train_len]
# Explicitly positional slice -- the index now holds integer labels, so plain
# [] slicing would be ambiguous between position and label.
train_data = data.iloc[:train_len]

train_seq

tensor([20050408, 20050411, 20050412,  ..., 20141204, 20141205, 20141208])

In [146]:
'''2.2 归一化；构造"字典 - Dict" '''
# Learn mean/std from the training rows only, then apply that same transform to
# every row, so held-out rows are scaled with training statistics.
scaler = StandardScaler()
scaler.fit(train_data)
data_norm = scaler.transform(data)

# Lookup table: integer date key -> standardized feature row (same columns as data).
Dict = DataFrame(data_norm, columns=data.columns, index=data.index)

Dict

Unnamed: 0,open,close,high,low,volume,money
20050408,-1.594964,-1.579738,-1.584240,-1.598419,-1.132051,-1.217806
20050411,-1.575985,-1.587671,-1.579326,-1.585116,-1.102353,-1.190584
20050412,-1.586027,-1.604188,-1.593979,-1.599756,-1.246846,-1.274410
20050413,-1.591715,-1.582257,-1.581510,-1.589959,-1.098921,-1.199199
20050414,-1.575234,-1.596008,-1.581588,-1.592340,-1.178027,-1.246151
...,...,...,...,...,...,...
20151225,1.216760,1.220656,1.188694,1.248826,1.521554,2.048287
20151228,1.232006,1.111426,1.193919,1.162847,2.391173,3.037915
20151229,1.109087,1.145261,1.104872,1.145614,1.072668,1.562929
20151230,1.148448,1.148521,1.108391,1.161490,1.167423,1.887901


In [147]:
'''3. 模型训练的准备'''
'''扩展出seq_len
'''
def get_batch(source, i):
    """Cut one (inputs, targets) chunk out of a 1-D sequence of date keys.

    Args:
        source: 1-D tensor of integer date keys (e.g. ``train_seq``).
        i: start position of the chunk within ``source``.

    Returns:
        ``(inputs, targets)``:

        - ``inputs`` is ``source[i:i + L]``, the raw date keys. The model is
          built with ``Dict`` (``ModelFactory.createTransformer(Dict, ...)``)
          and does the input lookup itself.
        - ``targets`` is a float32 tensor holding ``Dict``'s ``close`` value
          for the key one step after each input position.

        ``L`` is ``seq_len``, shortened near the end of ``source`` so that the
        shifted targets stay in range.

    Raises:
        KeyError: if a target key is not present in ``Dict``'s index.
    """
    real_seq_len = min(seq_len, len(source) - 1 - i)

    inputs = source[i:i + real_seq_len]
    # Targets are the inputs shifted one position ahead.
    next_keys = source[i + 1:i + 1 + real_seq_len].view(-1)
    # A single vectorized label lookup replaces one Dict.loc call per element.
    close = Dict.loc[next_keys.tolist(), 'close'].to_numpy()
    targets = torch.tensor(close, dtype=torch.float32)

    return inputs, targets

In [148]:
# Build the Transformer (it receives the Dict lookup table) and move it to `device`.
model = ModelFactory.createTransformer(Dict, emsize, nhead, nhid, nlayers, dropout)
model = model.to(device)  # nn.Module.to moves in place and returns the same module

# Regression objective on the next-step close value, optimized with Adam.
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [149]:
def train():
    """Run one training epoch over ``train_seq``; return the mean per-target loss.

    The sequence is walked in consecutive chunks of up to ``seq_len`` positions
    (produced by ``get_batch``). Each chunk is one optimizer step, with the
    gradient norm clipped to 0.5.

    Returns:
        The loss averaged over every target value seen in the epoch, or ``0.0``
        when ``train_seq`` is too short to produce a single chunk.
    """
    model.train()  # training mode, so dropout is active

    total_loss = 0.
    n = 0
    # Step by seq_len so consecutive chunks from get_batch tile the sequence.
    for i in range(0, train_seq.size(0) - 1, seq_len):
        inputs, targets = get_batch(train_seq, i)
        optimizer.zero_grad()

        output = model(inputs)
        # Flatten the predictions, and put the targets on the same device as the
        # output so the loss also works when the model runs on CUDA.
        output = output.reshape(-1)
        targets = targets.to(output.device)
        loss = loss_func(output, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # MSELoss() returns a per-chunk mean. Weight it by chunk size so the
        # epoch average counts every target equally (the last chunk may be shorter).
        total_loss += loss.item() * len(targets)
        n += len(targets)

    return total_loss / n if n else 0.0
