In [1]:
import numpy as np
import torch
import os
from torch.utils.data import Dataset
from pathlib import Path


In [2]:
class CANDataset(Dataset):
    def __init__(self, root_dir, is_train=True, transform=None):
        self.root_dir = Path(root_dir) / ('train' if is_train else 'val')
        self.is_train = is_train
        self.transform = transform
        self.total_size = len(os.listdir(self.root_dir))
            
    def __getitem__(self, idx):
        filename = f'{idx}.npz'
        filename = self.root_dir / filename
        data = np.load(filename)
        X, y = data['X'], data['y']
        X_tensor = torch.tensor(X)
        y_tensor = torch.tensor(y)
        return X_tensor, y_tensor

    def __len__(self):
        return self.total_size

In [3]:
data_dir = '../Data/CHD_w29_s14_ID_Data/1/'
train_dataset = CANDataset(root_dir=data_dir, is_train=True)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128,
    shuffle=True, num_workers=8, 
    pin_memory=True, sampler=None
)

In [4]:
X, y = next(iter(train_dataloader))

In [31]:
filename = '../Data/CHD_w29_s14_ID_Data/1/train/1.npz'
data = np.load(filename)
X, y = data['X'], data['y']

In [33]:
torch.tensor(y)

tensor(1, dtype=torch.uint8)

In [27]:
torch.tensor(X)

tensor([[[ 4.0188e+01, -4.3156e+01, -3.6750e+01,  ..., -2.5578e+01,
           3.8906e+01, -1.5950e+02],
         [-1.2175e+02,  7.4727e+00,  2.5500e+02,  ...,  7.1719e+00,
          -3.3300e+02,  3.5075e+02],
         [-8.5125e+01,  1.9700e+02,  6.2688e+01,  ...,  9.3062e+01,
          -8.1500e+01,  2.1862e+02],
         ...,
         [-5.3750e+02, -4.1400e+02, -4.6600e+02,  ..., -2.8688e+01,
          -9.2062e+01, -1.7988e+02],
         [-5.5600e+02, -5.2000e+02, -4.5900e+02,  ...,  4.9938e+01,
          -9.0125e+01, -1.3838e+02],
         [-6.1700e+02, -5.5650e+02, -3.4425e+02,  ...,  5.2656e+01,
          -1.5912e+02, -1.4088e+02]],

        [[-1.1945e-04, -8.9874e-03,  3.1738e-02,  ..., -7.3047e+00,
           7.1758e+00, -1.6621e+00],
         [ 2.6245e-03, -3.8916e-01,  1.2275e+00,  ..., -5.6406e+00,
          -2.4141e+01,  2.1641e+01],
         [ 3.5229e-01,  1.1230e+00, -9.5215e-01,  ...,  4.3562e+01,
          -1.0523e+01, -2.0906e+01],
         ...,
         [-5.4594e+01, -5