In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils import data

  from .autonotebook import tqdm as notebook_tqdm


# 将prepareData里面的功能改成csv输入

In [9]:
def prepare_x(data):
    """Extract the feature matrix from an already-loaded DataFrame.

    Parameters
    ----------
    data : pandas.DataFrame
        The loaded frame (a DataFrame, not a file path).

    Returns
    -------
    numpy.ndarray
        Array built from the first 40 columns of ``data``, selected by
        position. If ``data`` has fewer than 40 columns, all of them are
        returned.
    """
    leading_columns = slice(None, 40)
    return np.array(data.iloc[:, leading_columns])

def get_label(data):
    """Return the label columns of ``data``.

    The last five columns are selected by position and returned as a
    DataFrame; no conversion to NumPy is done here.
    """
    return data.iloc[:, -5:]

def data_classification(X, Y, T):
    """Turn a row-per-time-step matrix into overlapping windows of length ``T``.

    Parameters
    ----------
    X : array with a 2-D ``.shape`` of ``(N, D)``
        ``N`` is the number of sample points, ``D`` the number of features.
    Y : array-like
        Labels aligned row-by-row with ``X``.
    T : int
        Number of consecutive rows in each window.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``dataX`` has shape ``(N - T + 1, T, D)`` and is float64 (it is
        allocated with ``np.zeros``); window ``j`` holds rows ``j .. j+T-1``.
        ``dataY`` is ``Y[T - 1:N]``, i.e. the label of the last row of each
        window.
    """
    n_rows, n_features = X.shape
    features = np.array(X)
    labels = np.array(Y)

    n_windows = n_rows - T + 1
    windows = np.zeros((n_windows, T, n_features))
    for start in range(n_windows):
        # Window `start` covers rows [start, start + T).
        windows[start] = features[start:start + T, :]

    return windows, labels[T - 1:n_rows]

class Dataset(data.Dataset):
    """PyTorch dataset of fixed-length windows with one label per window."""

    def __init__(self, data, k, num_classes, T):
        """Build the window and label tensors from a DataFrame.

        Parameters
        ----------
        data : pandas.DataFrame
            Frame passed to ``prepare_x`` (features) and ``get_label`` (labels).
        k : int
            Column index, within the label columns, of the label to use.
        num_classes : int
            Stored on the instance; not otherwise used by this class.
        T : int
            Window length handed to ``data_classification``.
        """
        self.k, self.num_classes, self.T = k, num_classes, T

        features = prepare_x(data)
        labels = get_label(data)
        windows, targets = data_classification(features, labels, self.T)

        # Labels in the source data are 1, 2, 3 while the model output is a
        # softmax index 0, 1, 2, so shift by one to line them up.
        targets = targets[:, self.k] - 1
        self.length = len(windows)

        # Insert a singleton axis at dim=1 so each sample is (1, T, D),
        # the input layout expected on the PyTorch side.
        self.x = torch.from_numpy(windows).unsqueeze(1)
        self.y = torch.from_numpy(targets)

    def __len__(self):
        """Number of windows in the dataset."""
        return self.length

    def __getitem__(self, index):
        """Return the ``(window, label)`` pair stored at ``index``."""
        sample, target = self.x[index], self.y[index]
        return sample, target
    
def splitDataset(train_test_rate, train_val_rate, train7path, test7path, test8path, test9path):
    """Load the CSV files and return ``(train, validation, test)`` DataFrames.

    Parameters
    ----------
    train_test_rate : float
        Fraction of the rows of ``train7path`` (counted from the top) kept as
        the training slice, e.g. 0.8 keeps the first 80% of the rows.
    train_val_rate : float
        Fraction of the rows of ``train7path`` at which the validation slice
        starts; validation runs from that row to the end of the file.
    train7path : str
        CSV that is cut into the training and validation slices.
    test7path, test8path, test9path : str
        CSVs concatenated, in this order, into the test frame.

    Notes
    -----
    The two rates are applied independently to the same file. With equal
    values the slices are adjacent; with different values they overlap or
    leave a gap between them.
    The test frame is built with ``pd.concat`` without ``ignore_index``, so
    each file keeps its own row labels.
    """
    source_frame = pd.read_csv(train7path)
    n_rows = source_frame.shape[0]

    train_stop = int(np.floor(n_rows * train_test_rate))
    val_start = int(np.floor(n_rows * train_val_rate))
    dec_train = source_frame.iloc[:train_stop]
    dec_val = source_frame.iloc[val_start:]

    test_parts = [pd.read_csv(path) for path in (test7path, test8path, test9path)]
    dec_test = pd.concat(test_parts)
    return dec_train, dec_val, dec_test

def getDataLoader(dec_train, dec_val, dec_test, k=4, num_classes=3, T=100, batch_size=64):
    """Wrap the three DataFrames in ``Dataset`` objects and return DataLoaders.

    Parameters
    ----------
    dec_train, dec_val, dec_test : pandas.DataFrame
        Training, validation and test frames.
    k : int
        Which label column to use.
    num_classes : int
        Forwarded to ``Dataset``.
    T : int
        Number of time steps of features collected for each sample.
    batch_size : int
        Batch size shared by all three loaders.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles; validation and test keep their order.
    """
    # Build every dataset first, then the loaders, in train/val/test order.
    datasets = [
        Dataset(data=frame, k=k, num_classes=num_classes, T=T)
        for frame in (dec_train, dec_val, dec_test)
    ]
    shuffle_flags = (True, False, False)

    train_loader, val_loader, test_loader = (
        torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=flag)
        for dataset, flag in zip(datasets, shuffle_flags)
    )
    return train_loader, val_loader, test_loader

In [10]:
# CSV inputs: the first file is cut into the train/validation slices, the
# other three are concatenated into the test set (see splitDataset).
traint7path = '../../data/random_processed/stock0/randomOrder_train7_part0.csv'
test7path = '../../data/random_processed/stock0/randomOrder_test7_part0.csv'
test8path = '../../data/random_processed/stock0/randomOrder_test8_part0.csv'
test9path = '../../data/random_processed/stock0/randomOrder_test9_part0.csv'
# Both cut points are 0.8: rows before the 80% mark go to train, the rest to validation.
dec_train, dec_val, dec_test = splitDataset(0.8,0.8,traint7path,test7path,test8path,test9path)
train_loader, val_loader, test_loader = getDataLoader(dec_train,dec_val,dec_test,k=4,num_classes=3,T=100,batch_size=64)
# Sanity check on a single batch: x should be (batch, 1, T, features), y should be (batch,).
for x,y in train_loader:
    print(x.shape)
    print(y.shape)
    break

torch.Size([64, 1, 100, 40])
torch.Size([64])


In [11]:
# Print the raw values of the first batch yielded by train_loader. The loader
# was built with shuffle=True, so the batch content presumably changes between
# runs unless a seed is fixed elsewhere.
for x,y in train_loader:
    print(x)
    print(y)
    break

tensor([[[[3.4100e-01, 4.5000e-04, 3.4080e-01,  ..., 1.6160e-02,
           3.3940e-01, 4.1500e-03],
          [1.2200e-01, 6.2170e-02, 1.2170e-01,  ..., 4.2750e-02,
           1.2080e-01, 1.1950e-02],
          [2.6460e-01, 1.2000e-04, 2.6420e-01,  ..., 4.0000e-03,
           2.6150e-01, 3.0000e-03],
          ...,
          [1.2640e-01, 5.0100e-03, 1.2630e-01,  ..., 2.2000e-02,
           1.2540e-01, 3.2500e-03],
          [3.4260e-01, 2.0000e-03, 3.4200e-01,  ..., 2.5400e-03,
           3.3990e-01, 1.3610e-02],
          [1.2200e-01, 1.0000e-03, 1.2180e-01,  ..., 3.2360e-02,
           1.2090e-01, 3.2250e-01]]],


        [[[1.2620e-01, 3.2860e-02, 1.2600e-01,  ..., 1.0000e-02,
           1.2510e-01, 5.0870e-02],
          [1.2550e-01, 1.2000e-02, 1.2540e-01,  ..., 7.5000e-03,
           1.2450e-01, 3.2450e-02],
          [1.2580e-01, 1.4660e-02, 1.2570e-01,  ..., 2.4000e-02,
           1.2480e-01, 2.0750e-02],
          ...,
          [1.3310e-01, 4.0440e-02, 1.3290e-01,  ..., 3.12

# 下面是测试dataframe的合并操作

In [3]:
# Load two of the test CSVs to experiment with stacking DataFrames.
df1 = pd.read_csv('../../data/random_processed/stock0/randomOrder_test7_part0.csv')
df2 = pd.read_csv('../../data/random_processed/stock0/randomOrder_test8_part0.csv')

In [4]:
# Display df1 (pandas truncates the rendered rows and columns).
df1

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,139,140,141,142,143,144,145,146,147,148
0,0.1270,0.03340,0.1266,0.03827,0.1271,0.01100,0.1265,0.05399,0.1272,0.02507,...,0.000569,0.000569,0.0,0.0,0.0,2.0,1.0,1.0,1.0,1.0
1,0.3623,0.00706,0.3619,0.00430,0.3624,0.00618,0.3618,0.00200,0.3625,0.00139,...,0.114571,0.114571,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
2,0.1271,0.00659,0.1268,0.05778,0.1272,0.00500,0.1267,0.01100,0.1273,0.07353,...,0.006287,0.006287,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
3,0.3654,0.00200,0.3649,0.00353,0.3655,0.00500,0.3648,0.00100,0.3656,0.01307,...,0.005490,0.005490,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
4,0.1283,0.00200,0.1280,0.02828,0.1284,0.06545,0.1279,0.01100,0.1285,0.04861,...,0.004952,0.004952,0.0,0.0,0.0,2.0,2.0,2.0,2.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11090,0.3610,0.00132,0.3606,0.00535,0.3614,0.00183,0.3603,0.01050,0.3615,0.00325,...,0.112822,0.112822,0.0,0.0,0.0,2.0,2.0,2.0,2.0,3.0
11091,0.1730,0.00090,0.1728,0.01927,0.1731,0.02355,0.1727,0.02000,0.1732,0.01000,...,0.000288,0.000288,0.0,0.0,0.0,2.0,2.0,2.0,3.0,3.0
11092,0.3657,0.01081,0.3652,0.00308,0.3658,0.00400,0.3651,0.00100,0.3659,0.01745,...,0.111114,0.111114,0.0,0.0,0.0,2.0,1.0,1.0,2.0,2.0
11093,0.1724,0.02051,0.1720,0.01415,0.1725,0.06760,0.1719,0.04385,0.1726,0.09769,...,0.018871,0.018871,0.0,0.0,0.0,2.0,3.0,1.0,1.0,1.0


In [5]:
# Display df2 (pandas truncates the rendered rows and columns).
df2

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,139,140,141,142,143,144,145,146,147,148
0,0.1299,0.00459,0.1296,0.02434,0.1300,0.01379,0.1295,0.01900,0.1301,0.03200,...,0.002584,0.002584,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
1,0.3640,0.00863,0.3626,0.00200,0.3641,0.00328,0.3625,0.00200,0.3642,0.02074,...,0.185185,0.185185,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0
2,0.1322,0.01739,0.1319,0.00300,0.1323,0.01855,0.1318,0.02356,0.1324,0.01889,...,0.025144,0.025144,0.0,0.0,0.0,1.0,1.0,3.0,3.0,3.0
3,0.1297,0.03658,0.1295,0.02711,0.1298,0.01000,0.1294,0.01651,0.1299,0.03152,...,0.006173,0.006173,0.0,0.0,0.0,2.0,2.0,2.0,3.0,3.0
4,0.3601,0.00200,0.3592,0.00709,0.3604,0.00720,0.3591,0.00665,0.3605,0.00474,...,0.188889,0.188889,0.0,0.0,0.0,3.0,3.0,3.0,3.0,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10429,0.3624,0.00221,0.3619,0.00674,0.3625,0.00425,0.3618,0.00264,0.3626,0.00680,...,0.006823,0.006823,0.0,0.0,0.0,2.0,2.0,3.0,3.0,3.0
10430,0.3650,0.00125,0.3648,0.00417,0.3656,0.00100,0.3647,0.00200,0.3657,0.00200,...,0.023333,0.023333,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
10431,0.1324,0.02253,0.1322,0.03087,0.1325,0.00680,0.1321,0.01000,0.1326,0.07118,...,0.222781,0.222781,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
10432,0.3636,0.00110,0.3633,0.00148,0.3638,0.00566,0.3632,0.00200,0.3640,0.00730,...,0.046841,0.046841,0.0,0.0,0.0,2.0,1.0,1.0,1.0,2.0


In [6]:
# Stack df2 below df1 and renumber the rows 0..n-1.
# DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0;
# pd.concat with ignore_index=True is the supported equivalent and produces
# the same frame.
result = pd.concat([df1, df2], ignore_index=True)
result


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,139,140,141,142,143,144,145,146,147,148
0,0.1270,0.03340,0.1266,0.03827,0.1271,0.01100,0.1265,0.05399,0.1272,0.02507,...,0.000569,0.000569,0.0,0.0,0.0,2.0,1.0,1.0,1.0,1.0
1,0.3623,0.00706,0.3619,0.00430,0.3624,0.00618,0.3618,0.00200,0.3625,0.00139,...,0.114571,0.114571,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
2,0.1271,0.00659,0.1268,0.05778,0.1272,0.00500,0.1267,0.01100,0.1273,0.07353,...,0.006287,0.006287,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
3,0.3654,0.00200,0.3649,0.00353,0.3655,0.00500,0.3648,0.00100,0.3656,0.01307,...,0.005490,0.005490,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
4,0.1283,0.00200,0.1280,0.02828,0.1284,0.06545,0.1279,0.01100,0.1285,0.04861,...,0.004952,0.004952,0.0,0.0,0.0,2.0,2.0,2.0,2.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21524,0.3624,0.00221,0.3619,0.00674,0.3625,0.00425,0.3618,0.00264,0.3626,0.00680,...,0.006823,0.006823,0.0,0.0,0.0,2.0,2.0,3.0,3.0,3.0
21525,0.3650,0.00125,0.3648,0.00417,0.3656,0.00100,0.3647,0.00200,0.3657,0.00200,...,0.023333,0.023333,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
21526,0.1324,0.02253,0.1322,0.03087,0.1325,0.00680,0.1321,0.01000,0.1326,0.07118,...,0.222781,0.222781,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
21527,0.3636,0.00110,0.3633,0.00148,0.3638,0.00566,0.3632,0.00200,0.3640,0.00730,...,0.046841,0.046841,0.0,0.0,0.0,2.0,1.0,1.0,1.0,2.0


In [8]:
# Stack the two frames with pd.concat. ignore_index is not set, so every row
# keeps the label it had in its own frame and labels repeat across df1 and df2
# (the last rows shown are still labelled with df2's own row numbers).
frames = [df1, df2]
result = pd.concat(frames)
result

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,139,140,141,142,143,144,145,146,147,148
0,0.1270,0.03340,0.1266,0.03827,0.1271,0.01100,0.1265,0.05399,0.1272,0.02507,...,0.000569,0.000569,0.0,0.0,0.0,2.0,1.0,1.0,1.0,1.0
1,0.3623,0.00706,0.3619,0.00430,0.3624,0.00618,0.3618,0.00200,0.3625,0.00139,...,0.114571,0.114571,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
2,0.1271,0.00659,0.1268,0.05778,0.1272,0.00500,0.1267,0.01100,0.1273,0.07353,...,0.006287,0.006287,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
3,0.3654,0.00200,0.3649,0.00353,0.3655,0.00500,0.3648,0.00100,0.3656,0.01307,...,0.005490,0.005490,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
4,0.1283,0.00200,0.1280,0.02828,0.1284,0.06545,0.1279,0.01100,0.1285,0.04861,...,0.004952,0.004952,0.0,0.0,0.0,2.0,2.0,2.0,2.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10429,0.3624,0.00221,0.3619,0.00674,0.3625,0.00425,0.3618,0.00264,0.3626,0.00680,...,0.006823,0.006823,0.0,0.0,0.0,2.0,2.0,3.0,3.0,3.0
10430,0.3650,0.00125,0.3648,0.00417,0.3656,0.00100,0.3647,0.00200,0.3657,0.00200,...,0.023333,0.023333,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
10431,0.1324,0.02253,0.1322,0.03087,0.1325,0.00680,0.1321,0.01000,0.1326,0.07118,...,0.222781,0.222781,0.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0
10432,0.3636,0.00110,0.3633,0.00148,0.3638,0.00566,0.3632,0.00200,0.3640,0.00730,...,0.046841,0.046841,0.0,0.0,0.0,2.0,1.0,1.0,1.0,2.0
