In [1]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from tqdm import tqdm
import numpy as np
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
class Stock(Dataset):
    def __init__(self, df, news_max = 20) -> None:
        features, prices, values = [], [], []
        
        date = df["date"]
        symbols = df["symbol"]
        date, symbols = list(set(date)), list(set(symbols))
        for days in tqdm(date):
            for symbol in symbols:
                day = df["date"] == days
                sym = df["symbol"] == symbol
                day_data = df[day & sym]
                news_num = len(day_data)
                if news_num == 0: continue
                feature = pd.concat([day_data[f"f{idx+1:02d}"] 
                                        for idx in range(16)], axis=1).to_numpy()
                price = [day_data[f"pre{idx}dreturn"].to_numpy() for idx in range(2, -1, -1)]
                price = np.concatenate([price], axis=0)[:,0]
                value = day_data["nextreturn"].to_numpy()[0]
                if news_num > news_max:
                    choice = np.random.choice(news_num, news_max, replace=False)
                    feature, news_num = feature[choice], news_max
                feature = np.pad(feature, [(0, news_max - news_num), (0, 0)])
                features.append(feature)
                prices.append(price)
                values.append(value)
                
        self.len = len(features)
        self.features, self.prices, self.values = features, prices, values
        
    def __getitem__(self, idx):
        return self.features[idx], self.prices[idx], self.values[idx]

    def __len__(self):
        return self.len

In [17]:
merge = pd.read_csv("../data/predict_dataset.csv")
dataset = Stock(merge)

100%|██████████| 571/571 [00:09<00:00, 62.37it/s]


In [20]:
batch_size = 16
validation_split = .2

In [29]:
np.random.seed(42)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
np.random.shuffle(indices)
indices = {
    "train" : indices[split:],
    "valid" : indices[:split]
}
sampler = {x : SubsetRandomSampler(indices[x]) for x in ["train", "valid"]}
loader = {x : DataLoader(dataset, batch_size=batch_size, sampler=sampler[x]) for x in ["train", "valid"]}

In [30]:
print(len(loader["train"]), len(loader["valid"]))

111 28


In [31]:
for state in ["train", "valid"]:
    for p in loader[state]:
        continue
print("SUCCESS")

SUCCESS
