In [8]:
import numpy as np
import pandas as pd
import os
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import MultiHeadAttention
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from tensorflow.keras.layers import TimeDistributed, Flatten

# Load the raw data (make sure the file path is correct), reverse the row
# order and renumber rows from 0, then name the two columns.
# NOTE(review): the reversal presumably means the CSV is stored newest-first
# — TODO confirm; the rename assumes the file has exactly two columns.
data = (
    pd.read_csv('stock_data.csv')
    .sort_index(ascending=False)
    .reset_index(drop=True)
    .set_axis(['price', 'vol'], axis=1)
)

# Model step-to-step changes instead of levels. dropna removes the leading
# NaN row produced by diff (and any other row that contains a NaN).
data_diff = data.diff().dropna()

# Standardise each column of the diffs; the fitted scaler is kept so the
# predictions can be mapped back to original units later.
scaler = StandardScaler()
data_scaled = scaler.fit_transform(data_diff)

# Window length fed to the model, and how many rows ahead each target lies.
lookback, delay = 500, 1

# Generator that yields training batches of sliding windows.
def generator(data, lookback=10, delay=1, batch_size=32):
    """Yield ``(samples, targets)`` batches of sliding windows over *data*.

    Each sample is ``lookback`` consecutive rows of *data*; its target is the
    same window shifted forward by ``delay`` rows, so ``targets[j, t]`` is the
    row ``delay`` steps after ``samples[j, t]``. Windows advance one row at a
    time (continuing across batches) and wrap back to the first window once
    the window whose target reaches the final row of *data* has been emitted.

    Args:
        data: Array supporting integer-array row indexing (``data[rows]``);
            its last axis is treated as the feature axis.
        lookback: Number of time steps per window.
        delay: How many rows ahead the targets are.
        batch_size: Number of windows per yielded batch.

    Yields:
        ``(samples, targets)``, each of shape
        ``(batch_size, lookback, n_features)``.

    Raises:
        ValueError: On the first ``next()`` if *data* has fewer than
            ``lookback + delay`` rows (no complete window exists).
    """
    n_rows = len(data)
    # Largest window end `i` (exclusive) whose shifted target row
    # i - 1 + delay still exists, i.e. i - 1 + delay <= n_rows - 1.
    last_end = n_rows - delay
    if last_end < lookback:
        raise ValueError(
            f"need at least lookback + delay = {lookback + delay} rows, "
            f"got {n_rows}"
        )
    n_features = data.shape[-1]
    i = lookback
    while True:
        samples = np.zeros((batch_size, lookback, n_features))
        # Targets are full sequences (same length as the inputs), one per
        # input time step.
        targets = np.zeros((batch_size, lookback, n_features))
        for j in range(batch_size):
            if i > last_end:
                i = lookback
            rows = np.arange(i - lookback, i)
            samples[j] = data[rows]
            targets[j] = data[rows + delay]  # same window shifted by `delay`
            i += 1
        yield samples, targets

# Endless batch generator over the scaled diffs, used for training.
train_gen = generator(data_scaled, lookback=lookback, delay=delay)

model_path = 'transformer_model.h5'  # make sure the path is correct

if os.path.exists(model_path):
    # Reuse a previously trained model.
    # NOTE(review): a file saved by the earlier, unmasked version of this cell
    # was trained with future leakage — delete it to force retraining.
    model = load_model(model_path)
else:
    n_features = data_diff.shape[-1]
    inputs = Input(shape=(lookback, n_features))
    # The generator's targets are the inputs shifted by `delay` rows, so output
    # position t is supervised with input position t + delay. Without a causal
    # mask each position can attend to that future row and simply copy it,
    # which cannot work at inference time, where only the last position is
    # read and it has no future row. use_causal_mask=True restricts position t
    # to attending positions <= t.
    x = MultiHeadAttention(num_heads=2, key_dim=2)(
        inputs, inputs, use_causal_mask=True
    )
    x = Dropout(0.5)(x)
    x = MultiHeadAttention(num_heads=2, key_dim=2)(x, x, use_causal_mask=True)
    x = Dropout(0.5)(x)
    # NOTE(review): there is no positional encoding, so the attention layers get
    # little explicit information about time order — consider adding one.
    x = TimeDistributed(Dense(20, activation='relu'))(x)
    x = TimeDistributed(Dense(20, activation='relu'))(x)
    # One output per input feature (the price and volume diffs) at every step.
    outputs = TimeDistributed(Dense(n_features))(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(), loss='mse')
    model.fit(train_gen, epochs=30, steps_per_epoch=20)

    # Save the trained model.
    model.save(model_path)

# Forecast the next `future` steps autoregressively.
future = 10
# Seed window: the most recent `lookback` scaled diffs.
data_to_predict = data_scaled[-lookback:]
predicted_data = []
n_features = data_diff.shape[-1]
for _ in range(future):
    samples = np.reshape(data_to_predict[-lookback:], (1, lookback, n_features))
    predictions = model.predict(samples, verbose=0)
    # Only the last position predicts the step after the current window.
    last_prediction = predictions[0, -1, :]
    predicted_data.append(last_prediction)
    # Slide the window: drop the oldest row and append the model's prediction.
    data_to_predict = np.vstack([data_to_predict[1:], last_prediction])

# Undo the scaling, then integrate the predicted diffs back into levels.
predicted_diffs = scaler.inverse_transform(np.asarray(predicted_data))
# Row k of data_diff is data.iloc[k + 1] - data.iloc[k] (assuming diff's
# leading NaN was the only row dropped), so the first predicted diff is the
# step after the LAST observed row. Anchor the cumulative sum there, not at a
# row `lookback` steps in the past.
predicted_data = np.cumsum(predicted_diffs, axis=0) + data.iloc[-1].to_numpy()

# Save the forecast to a CSV file.
predicted_df = pd.DataFrame(predicted_data, columns=['predicted_price', 'predicted_vol'])
predicted_df.to_csv('predicted_data_transformer.csv', index=False)


       price      vol
1      -0.55  97000.0
2      -2.52 -21100.0
3      -0.72 -18700.0
4       0.25  23697.0
5       0.35 -20497.0
...      ...      ...
27950   0.08  17240.0
27951   0.17 -17440.0
27952   0.02  10100.0
27953  -0.11 -38600.0
27954   0.00      0.0

[27954 rows x 2 columns]
Epoch 1/30


2023-06-30 11:05:18.218283: I tensorflow/core/common_runtime/executor.cc:1210] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]


Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


  saving_api.save_model(


