In [128]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os

In [216]:
class CustomDataset(Dataset):
    """In-memory dataset built from the ``q_nm`` arrays in a directory of ``.npz`` files.

    Each file's ``q_nm`` array is split into its real and imaginary parts,
    which are stacked along a new last axis. A 2-D ``q_nm`` therefore yields
    one ``(H, W, 2)`` sample per file. ``q_nm`` is presumably a complex field
    (TODO confirm); for a real array the second channel is all zeros. All
    samples are stacked into ``self.data`` of shape ``(N, H, W, 2)``, so every
    file must hold a ``q_nm`` of the same shape.

    Args:
        path_to_npz_files: Directory to scan. Only regular files whose names
            end in ``.npz`` are loaded, so stray entries such as
            ``.ipynb_checkpoints`` or ``.DS_Store`` are skipped. Files are
            read in sorted name order so sample indices are reproducible.
        transforms: Optional callable applied to each ``(H, W, 2)`` NumPy
            sample in ``__getitem__``. When ``None``, the raw sample is
            returned unchanged.
    """

    def __init__(self, path_to_npz_files, transforms=None):
        self.transforms = transforms
        samples = []
        # sorted(): os.listdir order is arbitrary and differs between filesystems.
        for name in sorted(os.listdir(path_to_npz_files)):
            full_path = os.path.join(path_to_npz_files, name)
            # os.listdir also returns sub-directories and unrelated files.
            if not (name.endswith('.npz') and os.path.isfile(full_path)):
                continue
            # NpzFile keeps the archive open; close it once q_nm has been read.
            with np.load(full_path) as npz:
                q_nm = npz['q_nm']
            samples.append(np.stack([q_nm.real, q_nm.imag], axis=-1))
        self.data = np.array(samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transforms is None:
            return sample
        return self.transforms(sample)

In [256]:
def normalize_by_max_magnitude(x):
    """Scale a ``(C, H, W)`` tensor so its largest per-pixel channel-norm is 1.

    At every pixel the L2 norm is taken across the channel axis (dim 0). For
    the (real, imag) channels produced by ``CustomDataset`` this is the
    magnitude of the complex value. The whole tensor is then divided by the
    maximum of those norms. An all-zero tensor is returned unchanged rather
    than turning into NaN through 0 / 0.
    """
    peak = torch.max(torch.norm(x, dim=0))
    if peak == 0:
        return x
    return x / peak


tfms = transforms.Compose([
    transforms.ToTensor(),  # (H, W, C) ndarray -> (C, H, W) tensor
    transforms.Resize((128, 128)),
    # A named function (unlike a lambda) can be pickled, which DataLoader
    # needs for num_workers > 0. Under the "spawn" start method it must also
    # live in an importable .py module.
    transforms.Lambda(normalize_by_max_magnitude),
])

In [257]:
# Directory with one .npz file (containing a `q_nm` array) per training sample.
path_to_npz_files = './data/training_npz'
dataset = CustomDataset(
    path_to_npz_files=path_to_npz_files,
    transforms=tfms,
)

In [258]:
# Fixed sample order and single-process loading.
# NOTE(review): a training loader usually wants shuffle=True with a seeded
# torch.Generator; confirm shuffle=False is intentional.
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

In [259]:
# One pass over the loader. After the loop, `data` stays bound to the final
# (possibly partial) batch, and the next cell reads it.
for batch_idx, data in enumerate(dataloader, start=0):
    # Your training code here
    print("Batch {}: Data shape {}".format(batch_idx, data.shape))

Batch 0: Data shape torch.Size([16, 2, 128, 128])
Batch 1: Data shape torch.Size([16, 2, 128, 128])
Batch 2: Data shape torch.Size([16, 2, 128, 128])
Batch 3: Data shape torch.Size([16, 2, 128, 128])
Batch 4: Data shape torch.Size([16, 2, 128, 128])
Batch 5: Data shape torch.Size([16, 2, 128, 128])
Batch 6: Data shape torch.Size([16, 2, 128, 128])
Batch 7: Data shape torch.Size([16, 2, 128, 128])
Batch 8: Data shape torch.Size([16, 2, 128, 128])
Batch 9: Data shape torch.Size([16, 2, 128, 128])
Batch 10: Data shape torch.Size([16, 2, 128, 128])
Batch 11: Data shape torch.Size([9, 2, 128, 128])


In [260]:
# `data` is the last batch left over from the loop cell above, so this shows
# the first sample of that batch, not dataset[0].
print(data[0, ...])

tensor([[[-1.8602e-03,  7.8000e-03, -1.6978e-02,  ...,  5.8471e-03,
          -5.3269e-03,  3.2776e-03],
         [-7.8214e-03,  2.8394e-02, -3.3722e-02,  ...,  4.3092e-03,
          -1.8106e-02,  2.0748e-03],
         [-5.3661e-02,  8.2054e-02, -4.0894e-02,  ...,  2.5697e-02,
           4.3883e-02, -4.4351e-02],
         ...,
         [ 1.6702e-04, -2.7817e-03,  2.4416e-03,  ...,  1.4293e-03,
           5.4289e-04, -1.1706e-03],
         [ 4.5386e-03, -4.5168e-03,  1.0166e-03,  ...,  1.8771e-03,
          -1.9520e-03,  2.4651e-03],
         [ 2.0920e-03,  1.5561e-03, -1.4456e-02,  ...,  1.3432e-02,
           2.9100e-03, -1.3333e-02]],

        [[ 4.1448e-03, -9.0129e-03,  1.2712e-02,  ..., -2.5560e-04,
           8.9126e-03, -9.2853e-03],
         [ 3.3426e-03, -1.3396e-02,  3.8496e-02,  ...,  1.0731e-03,
           1.6575e-02, -2.3991e-02],
         [ 3.2256e-05, -6.8214e-02,  7.1803e-02,  ...,  5.1753e-02,
           2.6384e-02, -7.8400e-03],
         ...,
         [ 1.5355e-03, -1