In [1]:
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [192]:
class Stock(Dataset):
    def __init__(self, df, news_max = 20) -> None:
        self.len = df.shape[0]
        features, prices, values = [], [], []
        
        date = df["date"]
        symbols = df["symbol"]
        date, symbols = list(set(date)), list(set(symbols))
        for days in tqdm(date):
            for symbol in symbols:
                day = df["date"] == days
                sym = df["symbol"] == symbol
                day_data = df[day & sym]
                news_num = len(day_data)
                if news_num == 0: continue
                feature = pd.concat([day_data[f"f{idx+1:02d}"] 
                                        for idx in range(16)], axis=1).to_numpy()
                price = [day_data[f"pre{idx}dprice"].to_numpy() for idx in range(3, -1, -1)]
                price = np.concatenate([price], axis=0)[:,0]
                value = day_data["nextprice"].to_numpy()[0]
                if news_num > news_max:
                    choice = np.random.choice(news_num, news_max, replace=False)
                    feature, news_num = feature[choice], news_max
                feature = np.pad(feature, [(0, news_max - news_num), (0, 0)])
                features.append(feature)
                prices.append(price)
                values.append(value)
        self.features, self.prices, self.values = features, prices, values
        
    def __getitem__(self, idx):
        return self.features[idx], self.prices[idx], self.values[idx]

    def __len__(self):
        return self.len

In [193]:
merge = pd.read_csv("../data/predict_dataset.csv")
dataset = Stock(merge)
loader = DataLoader(dataset, batch_size=24, num_workers=24) 


100%|██████████| 571/571 [00:09<00:00, 60.71it/s]


In [200]:
feature, price, value = next(iter(loader))
print(feature.shape)
print(price.shape)
print(value.shape)

torch.Size([24, 20, 16])
torch.Size([24, 4])
torch.Size([24])
