In [1]:
import pandas as pd
import torch as th
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader,Dataset
from matplotlib import pyplot as plt
# %matplotlib inline

In [2]:
# Load the training and test splits.
train_data = pd.read_csv('../../data/cbtc/test/train.csv')  # training set
test_data = pd.read_csv('../../data/cbtc/test/test.csv')  # test set

# Model input columns (features) and target column.
input_cols = ['brake', 'target', 'speed', 'slope']
output_cols = ['acc']

# Training-set inputs / outputs, then test-set inputs / outputs.
train_input, train_output = train_data.loc[:, input_cols], train_data.loc[:, output_cols]
test_input, test_output = test_data.loc[:, input_cols], test_data.loc[:, output_cols]

In [3]:
# Input standardization.
# Fit the scaler on the training inputs only and reuse its mean/std for the test
# inputs, so both splits are expressed on the same feature scale the model was
# trained on. (Fitting a second scaler on the test set would rescale it by its
# own statistics and shift the inputs the model sees at evaluation time.)
input_scaler = StandardScaler().fit(train_input)
train_features = input_scaler.transform(train_input)
test_features = input_scaler.transform(test_input)

In [4]:
# Convert data to a tensor; x is the input.
def pre_data(x):
    """Convert array-like data (ndarray, DataFrame, nested lists) to a float32 tensor.

    ``torch.tensor`` always copies, so the returned tensor never shares memory
    with ``x``.
    """
    return th.tensor(np.asarray(x), dtype=th.float32)

In [5]:
# Convert features and labels of both splits to float32 tensors.
train_features, train_labels = pre_data(train_features), pre_data(train_output)
test_features, test_labels = pre_data(test_features), pre_data(test_output)

In [6]:
# Hyperparameters.
in_size = 4  # 4 inputs per sample (brake, target, speed, slope)
hidden_size = 128  # hidden layer width
out_size = 1  # 1 output per sample (acc)
batch_size = 32
drop_out = 0.1  # dropout probability
epoch_nums = 100
learn_rate = 0.01
wgt_dcy = 0.1  # passed to Adam as weight_decay

# Seed torch's global RNG so weight initialization, DataLoader shuffling and
# dropout masks are reproducible on Restart & Run All.
seed = 0
th.manual_seed(seed)

In [7]:
class DatasetHandler(Dataset):
    """Map-style dataset pairing row ``idx`` of ``x`` with entry ``idx`` of ``y``.

    ``x`` must expose ``shape`` and support 2-D indexing (``x[idx, :]``),
    e.g. a 2-D tensor or ndarray.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        # Number of samples is the number of rows of the feature matrix.
        n_rows = self.x.shape[0]
        return n_rows

    def __getitem__(self, idx):
        row = self.x[idx, :]
        target = self.y[idx]
        return row, target

# Wrap the training tensors in a Dataset, then batch and shuffle them.
# Note: from here on ``train_dataset`` names the DataLoader (the training loop iterates it).
train_set = DatasetHandler(train_features, train_labels)

# The model starts with BatchNorm1d, which raises in training mode when a batch
# holds a single sample. Drop the trailing batch only in that case so no other
# training data is discarded.
single_sample_tail = len(train_set) % batch_size == 1
train_dataset = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                           drop_last=single_sample_tail)

In [8]:
# Model definition: BatchNorm -> Linear -> ReLU -> Dropout -> Linear.
class Module(nn.Module):
    """Single-hidden-layer MLP regressor.

    Maps ``in_size`` input features to ``out_size`` outputs through a hidden
    layer of width ``hidden_size``. Layer sizes and the dropout rate are read
    from the notebook-level hyperparameters.
    """

    def __init__(self):
        super().__init__()
        # Attribute names define the state_dict keys; renaming them would break
        # loading previously saved weights.
        self.bn = nn.BatchNorm1d(in_size)
        self.linear1 = nn.Linear(in_size, hidden_size)
        self.RELU = nn.ReLU()
        self.dropout = nn.Dropout(drop_out)
        self.linear2 = nn.Linear(hidden_size, out_size)

    def forward(self, x):
        """Apply the layers in order and return the output of ``linear2``."""
        for layer in (self.bn, self.linear1, self.RELU, self.dropout, self.linear2):
            x = layer(x)
        return x

In [9]:
# Instantiate the model, the loss function and the optimizer.
model = Module()
cost = nn.MSELoss(reduction='mean')  # loss function: mean squared error
optimizer = th.optim.Adam(
    params=model.parameters(),
    lr=learn_rate,
    weight_decay=wgt_dcy,
)

In [None]:
# Training loop: one pass over the shuffled mini-batches per epoch.
for epoch in range(epoch_nums):
    model.train()  # Dropout active, BatchNorm uses batch statistics
    loop = tqdm(train_dataset, total=len(train_dataset))
    # Report a 1-based epoch so the last pass reads "Epoch [epoch_nums/epoch_nums]".
    # Set once per epoch rather than on every batch.
    loop.set_description(f'Epoch [{epoch + 1}/{epoch_nums}]')

    for features, labels in loop:
        optimizer.zero_grad()  # clear gradients from the previous step
        y_pre = model(features)
        loss = cost(y_pre, labels)
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

In [11]:
def predict(data_test_input, net=None):
    """Run inference on a batch of inputs.

    Args:
        data_test_input: float32 input tensor (presumably (n_samples, in_size) — TODO confirm).
        net: model to evaluate; defaults to the notebook-level ``model``.

    Returns:
        The model output tensor, computed without tracking gradients.
    """
    net = model if net is None else net
    # eval(): BatchNorm switches to running statistics and Dropout is disabled.
    net.eval()
    # Inference needs no autograd graph; no_grad avoids building one.
    with th.no_grad():
        return net(data_test_input)

In [12]:
# Predict on the test features and convert the result to a NumPy array.
test_pred = predict(test_features).detach().numpy()

In [None]:
# Reconstruct a predicted speed curve by integrating the predicted acceleration.
DT = 0.2  # sample period in seconds (the time axis uses 0.2 s steps)
test = test_data.loc[:, ['speed']]  # measured test-set speed
assert len(test_pred) == len(test), "prediction/test length mismatch"

# pred[k] = DT * (test_pred[0] + ... + test_pred[k-1]): an exclusive running sum,
# so pred[0] == 0. A single np.cumsum pass avoids re-summing the prefix for
# every index (O(n^2)); accumulating in float64 limits rounding drift.
# NOTE(review): the curve starts at 0 rather than the measured initial speed
# test['speed'].iloc[0] — confirm trains start from rest in this data.
pred = np.zeros(len(test))
pred[1:] = np.cumsum(test_pred.reshape(-1), dtype=np.float64)[:-1] * DT

y_pred = test_pred
y_test = test_labels.numpy()
x = DT * np.arange(len(y_test))  # time axis in seconds

# Plot the acceleration-time curve.
plt.figure(figsize=(16, 4))
plt.title("Train Acceleration")
plt.xlabel('Time / s')
plt.ylabel('Acceleration m/s$^2$')

# NOTE(review): the /100 rescaling presumably converts raw units to SI — confirm.
plt.plot(x, y_test / 100, 'b', linewidth=1, label="test acceleration")
plt.plot(x, y_pred / 100, 'r', linewidth=1, label="pred acceleration")
plt.legend()
plt.show()

In [None]:
# Plot the speed-time curve: measured speed vs. speed integrated from predictions.
plt.figure(figsize=(16, 4))  # same size as the acceleration figure
plt.title("Train Speed")
plt.xlabel('Time / s')
plt.ylabel('Speed m/s')

# Plot against time (0.2 s sample period), not sample index, so the x-axis
# matches the 'Time / s' label.
x = 0.2 * np.arange(len(test))

plt.plot(x, test['speed'].to_numpy() / 100, 'b', linewidth=1, label="test speed")
plt.plot(x, pred / 100, 'r', linewidth=1, label="pred speed")
plt.legend()
plt.show()