In [6]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

In [17]:
class CustomDataset(Dataset):
    def __init__(self, char_csv_file, macro_csv_file):
        self.char_df = pd.read_csv(char_csv_file)
        self.macro_df = pd.read_csv(macro_csv_file)
        self.dates = self.char_df["time_avail_m"].unique()[:-1]
        self.init_date = 196501
    def __len__(self):
        return len(self.dates)
    def __getitem__(self, idx):
        date = self.dates[idx]
        char, ret = self._get_char_data_by_date(date)
        macro = self._get_macro_by_date(self.init_date, date)
        return {"char": char, "macro": macro, "ret": ret}
    def _get_char_data_by_date(self, date):
        char_df = self.char_df
        next_month = (date + 1) if (date % 100 < 12) else (date + 89)
        ret = char_df[char_df["time_avail_m"] == next_month][["permno","RET"]]
        ret_cp = ret.copy()
        feature = char_df[char_df["time_avail_m"] == date][char_df.columns[0:29]]
        feature_cp = feature.copy()
        merge_df = pd.merge(feature_cp, ret_cp, left_on=["permno"],right_on=["permno"])

        merge_df.drop(merge_df.columns[[0, 1]], axis=1, inplace=True)

        feature_list = []
        ret_list = []
        for index, row in merge_df.iterrows():
            array = row.values
            feature = array[0:27]
            ret = array[27]
            feature_list.append(feature)

            ret_list.append(ret)

        feature_array = np.array(feature_list)
        ret_array = np.array(ret_list)
        return feature_array,ret_array
    def _get_macro_by_date(self, begin_date, end_date):
        
        macro_df = self.macro_df
        if (begin_date == end_date):
            sub_df = macro_df[macro_df["sasdate"]==begin_date]
            sub_df = sub_df.drop("sasdate", axis=1)
            out = sub_df.values
            return out
        else:
            sub_df = macro_df[(macro_df["sasdate"] >= begin_date)
                              & (macro_df["sasdate"] <= end_date)]

            sub_df = sub_df.drop("sasdate", axis=1)
            out = sub_df.values
            return out

In [18]:
custom_dataset = CustomDataset("./data/27_features_rets_normalized.csv","./data/124_macro_data.csv")

In [19]:
for i in range(len(custom_dataset)):
    sample = custom_dataset[i]

    print(i, sample['char'].shape, sample['macro'].shape,sample['ret'].shape)


0 (1965, 27) (1, 124) (1965,)
1 (1964, 27) (2, 124) (1964,)
2 (1970, 27) (3, 124) (1970,)
3 (1968, 27) (4, 124) (1968,)
4 (1965, 27) (5, 124) (1965,)
5 (1970, 27) (6, 124) (1970,)
6 (1975, 27) (7, 124) (1975,)
7 (1985, 27) (8, 124) (1985,)
8 (1981, 27) (9, 124) (1981,)
9 (1986, 27) (10, 124) (1986,)
10 (1975, 27) (11, 124) (1975,)
11 (1976, 27) (12, 124) (1976,)
12 (1978, 27) (13, 124) (1978,)
13 (2003, 27) (14, 124) (2003,)
14 (1999, 27) (15, 124) (1999,)
15 (1997, 27) (16, 124) (1997,)
16 (2001, 27) (17, 124) (2001,)
17 (2007, 27) (18, 124) (2007,)
18 (2011, 27) (19, 124) (2011,)
19 (2016, 27) (20, 124) (2016,)
20 (2013, 27) (21, 124) (2013,)
21 (2011, 27) (22, 124) (2011,)
22 (2013, 27) (23, 124) (2013,)
23 (2015, 27) (24, 124) (2015,)
24 (2030, 27) (25, 124) (2030,)
25 (2033, 27) (26, 124) (2033,)
26 (2036, 27) (27, 124) (2036,)
27 (2027, 27) (28, 124) (2027,)
28 (2026, 27) (29, 124) (2026,)
29 (2014, 27) (30, 124) (2014,)
30 (2011, 27) (31, 124) (2011,)
31 (2003, 27) (32, 124) (20

248 (5087, 27) (249, 124) (5087,)
249 (5050, 27) (250, 124) (5050,)
250 (5068, 27) (251, 124) (5068,)
251 (5118, 27) (252, 124) (5118,)
252 (5086, 27) (253, 124) (5086,)
253 (5075, 27) (254, 124) (5075,)
254 (5095, 27) (255, 124) (5095,)
255 (5096, 27) (256, 124) (5096,)
256 (5115, 27) (257, 124) (5115,)
257 (5109, 27) (258, 124) (5109,)
258 (5108, 27) (259, 124) (5108,)
259 (5085, 27) (260, 124) (5085,)
260 (5077, 27) (261, 124) (5077,)
261 (5053, 27) (262, 124) (5053,)
262 (5074, 27) (263, 124) (5074,)
263 (5177, 27) (264, 124) (5177,)
264 (5219, 27) (265, 124) (5219,)
265 (5250, 27) (266, 124) (5250,)
266 (5296, 27) (267, 124) (5296,)
267 (5357, 27) (268, 124) (5357,)
268 (5406, 27) (269, 124) (5406,)
269 (5468, 27) (270, 124) (5468,)
270 (5488, 27) (271, 124) (5488,)
271 (5504, 27) (272, 124) (5504,)
272 (5576, 27) (273, 124) (5576,)
273 (5603, 27) (274, 124) (5603,)
274 (5593, 27) (275, 124) (5593,)
275 (5625, 27) (276, 124) (5625,)
276 (5571, 27) (277, 124) (5571,)
277 (5602, 27)

489 (4481, 27) (490, 124) (4481,)
490 (4468, 27) (491, 124) (4468,)
491 (4466, 27) (492, 124) (4466,)
492 (4462, 27) (493, 124) (4462,)
493 (4472, 27) (494, 124) (4472,)
494 (4467, 27) (495, 124) (4467,)
495 (4464, 27) (496, 124) (4464,)
496 (4458, 27) (497, 124) (4458,)
497 (4455, 27) (498, 124) (4455,)
498 (4447, 27) (499, 124) (4447,)
499 (4450, 27) (500, 124) (4450,)
500 (4445, 27) (501, 124) (4445,)
501 (4435, 27) (502, 124) (4435,)
502 (4431, 27) (503, 124) (4431,)
503 (4417, 27) (504, 124) (4417,)
504 (4403, 27) (505, 124) (4403,)
505 (4394, 27) (506, 124) (4394,)
506 (4374, 27) (507, 124) (4374,)
507 (4372, 27) (508, 124) (4372,)
508 (4369, 27) (509, 124) (4369,)
509 (4365, 27) (510, 124) (4365,)
510 (4348, 27) (511, 124) (4348,)
511 (4341, 27) (512, 124) (4341,)
512 (4326, 27) (513, 124) (4326,)
513 (4311, 27) (514, 124) (4311,)
514 (4320, 27) (515, 124) (4320,)
515 (4329, 27) (516, 124) (4329,)
516 (4326, 27) (517, 124) (4326,)
517 (4322, 27) (518, 124) (4322,)
518 (4298, 27)