In [1]:
import numpy as np

In [2]:
from torch.utils.data import Dataset, ConcatDataset, DataLoader, random_split

In [9]:
import json, torch

In [1]:
# PyTorch Dataset for Loading Supervised Training data for Mariokart DS
class RaceDataset(Dataset):
    def __init__(self, folder_path, seq_len=32, stride=1, dilation=1):
        # Metadata stores the mean, std, min/max, and other data required for feature scaling
        with open(f"{folder_path}/metadata.json", 'r') as f:
            self.metadata: Metadata = json.load(f)
        
        self.obs_data = np.memmap(f"{folder_path}/samples.dat", dtype=np.float32, mode="r").reshape(-1, len(self.metadata['mean']))
        self.act_data = np.memmap(f"{folder_path}/targets.dat", dtype=np.int32, mode="r")

        self.seq_len = seq_len
        self.stride = stride
        self.dilation = dilation
        self.window_span = (seq_len - 1) * dilation + 1

        self.valid_indices = range(0, len(self.obs_data) - self.window_span + 1, stride)

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        start_idx = self.valid_indices[idx]
        
        end_idx = start_idx + self.window_span
        obs_seq = torch.from_numpy(self.obs_data[start_idx : end_idx : self.dilation])

        if start_idx == 0:
            future_prev_acts = torch.from_numpy(self.act_data[self.dilation - 1 : end_idx - 1 : self.dilation])
            prev_act_seq = torch.cat([torch.tensor([0]), future_prev_acts])
        else:
            prev_act_seq = torch.from_numpy(self.act_data[start_idx - 1 : end_idx - 1 : self.dilation])

        last_frame_idx = end_idx - 1
        target = torch.tensor(self.act_data[last_frame_idx], dtype=torch.long)

        return obs_seq, prev_act_seq, target

NameError: name 'Dataset' is not defined

In [11]:
ds = RaceDataset("private/training_data/rdp1_pikalex")

In [12]:
ds[0]

  obs_seq = torch.from_numpy(self.obs_data[start_idx : end_idx : self.dilation])


(tensor([[ 3.4322e+02,  3.4489e+02,  3.5018e+02,  ...,  2.9297e-03,
           2.9297e-03,  0.0000e+00],
         [ 3.4321e+02,  3.4488e+02,  3.5017e+02,  ...,  5.9204e-03,
           5.8594e-03,  0.0000e+00],
         [ 3.4320e+02,  3.4487e+02,  3.5017e+02,  ...,  8.8196e-03,
           8.7891e-03,  0.0000e+00],
         ...,
         [ 4.2706e+02,  4.0041e+02,  3.8057e+02,  ...,  5.9900e+00,
           4.1878e+00, -1.4964e-02],
         [ 4.3273e+02,  4.0500e+02,  3.8431e+02,  ...,  4.7621e+00,
           3.2282e+00, -1.2547e-02],
         [ 4.3873e+02,  4.0987e+02,  3.8831e+02,  ...,  4.6677e+00,
           3.1149e+00, -1.2407e-02]]),
 tensor([  0,   1,   1,   1,   1,   1,   1,  17, 273, 273, 273, 273, 273, 273,
         273, 273, 273, 273, 273, 273, 273, 273, 273, 273, 257, 257, 289, 289,
         289, 289, 289, 289]),
 tensor(289))