In [57]:
from torch.utils.data import Dataset, DataLoader
import torch
from fxdata import load_ticker

In [91]:
class ForexDataset(Dataset):
    """FX Timeseries data loader
    
    Uses the next Close as the target. Rescales all values by the last Close
    in the features and subtracts 1."""
    
    def __init__(self, ticker, granularity, look_back=1, tensor=False):
        self.feature_df = load_ticker(ticker, sample_mins=granularity)
        self.target_df = self.feature_df["Close"].shift(periods=-1)/self.feature_df["Close"] - 1
        
        self.tensor = tensor
        self.look_back = look_back
        
    def __len__(self):
        return len(self.feature_df.index) - self.look_back - 1
    
    def __getitem__(self, idx):
        features = self.feature_df.iloc[idx:idx+self.look_back]
        #print(features, self.feature_df["Close"].iloc[idx+self.look_back-1])
        features /= self.feature_df["Close"].iloc[idx+self.look_back-1]
        features = (features - 1).to_numpy()
        targets = self.target_df.iloc[idx+self.look_back]
        if self.tensor:
            features = torch.from_numpy(features)
            targets = torch.tensor(targets)
        
        return {"features": features,
                "target": targets}

In [95]:
fx = ForexDataset("EURUSD", 60, look_back=10, tensor=True)
fx[0]

{'features': tensor([[-0.0006, -0.0004, -0.0007, -0.0006],
         [-0.0006, -0.0006, -0.0012, -0.0011],
         [-0.0010, -0.0006, -0.0010, -0.0009],
         [-0.0009, -0.0008, -0.0012, -0.0011],
         [-0.0011, -0.0009, -0.0012, -0.0010],
         [-0.0010, -0.0009, -0.0011, -0.0010],
         [-0.0010, -0.0006, -0.0010, -0.0007],
         [-0.0008, -0.0003, -0.0009, -0.0003],
         [-0.0003,  0.0004, -0.0005,  0.0002],
         [ 0.0003,  0.0004, -0.0005,  0.0000]], dtype=torch.float64),
 'target': tensor(0.0008, dtype=torch.float64)}

In [96]:
fxd = DataLoader(fx, batch_size=4, shuffle=False, num_workers=0)
for x, y in enumerate(fxd):
    print(x, y)
    break

0 {'features': tensor([[[-6.2961e-04, -4.4972e-04, -7.1955e-04, -6.2961e-04],
         [-6.2961e-04, -6.2961e-04, -1.1693e-03, -1.0793e-03],
         [-9.8939e-04, -6.2961e-04, -9.8939e-04, -8.9944e-04],
         [-8.9944e-04, -8.0950e-04, -1.1693e-03, -1.0793e-03],
         [-1.0793e-03, -8.9944e-04, -1.1693e-03, -9.8939e-04],
         [-9.8939e-04, -8.9944e-04, -1.0793e-03, -9.8939e-04],
         [-9.8939e-04, -6.2961e-04, -9.8939e-04, -7.1955e-04],
         [-8.0950e-04, -2.6983e-04, -8.9944e-04, -2.6983e-04],
         [-2.6983e-04,  3.5978e-04, -5.3967e-04,  1.7989e-04],
         [ 2.6983e-04,  3.5978e-04, -5.3967e-04,  0.0000e+00]],

        [[-8.9993e-05, -8.9993e-05, -6.2995e-04, -5.3996e-04],
         [-4.4996e-04, -8.9993e-05, -4.4996e-04, -3.5997e-04],
         [-3.5997e-04, -2.6998e-04, -6.2995e-04, -5.3996e-04],
         [-5.3996e-04, -3.5997e-04, -6.2995e-04, -4.4996e-04],
         [-4.4996e-04, -3.5997e-04, -5.3996e-04, -4.4996e-04],
         [-4.4996e-04, -8.9993e-05, -4