In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os

In [90]:
class CustomDataset(Dataset):
    def __init__(self, path_to_npz_files, polar=False, transforms=None):
        self.transforms = transforms
        self.data = []
        for npz_file in sorted(os.listdir(path_to_npz_files)):
            data = np.load(os.path.join(path_to_npz_files, npz_file))
            data = data['q_nm']
            if polar:
                data = np.array([np.abs(data), np.angle(data)])
            else:
                data = np.array([data.real, data.imag])
            data = np.transpose(data, (1, 2, 0)) # (width, height, channels)
            self.data.append(data)
        self.data = np.array(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.transforms(self.data[idx]) if self.transforms is not None else self.data[idx]

In [78]:
tfms = transforms.Compose([
    transforms.ToTensor(), # convert to tensor of shape (channels, width, height)
    transforms.Resize((128, 128)),
    transforms.Lambda(lambda x: x / torch.max(torch.norm(x, dim=0))) # normalize data by largest norm across all channels
])

In [127]:
tfms_polar = transforms.Compose([
    transforms.ToTensor(), # convert to tensor of shape (channels, width, height)
    transforms.Resize((128, 128)),
    transforms.Lambda(lambda x: torch.stack([
        (x[0] / torch.max(torch.abs(x[0])) - 0.5) * 2, # normalize amplitude to [-1, 1]
        x[1] / torch.max(torch.abs(x[1])) # normalize phase to [-1, 1]
    ]))
])

In [128]:
path_to_npz_files = './data/training_npz'
dataset = CustomDataset(path_to_npz_files, polar=True, transforms=tfms_polar)

In [129]:
dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

In [130]:
for batch_idx, data in enumerate(dataloader):
    # Your training code here
    print(f"Batch {batch_idx}: Data shape {data.shape}")
    print(torch.min(data[:, 0, :, :]), torch.max(data[:, 0, :, :]))
    print(torch.min(data[:, 1, :, :]), torch.max(data[:, 1, :, :]))

Batch 0: Data shape torch.Size([16, 2, 128, 128])
tensor(-0.9999, dtype=torch.float64) tensor(1., dtype=torch.float64)
tensor(-1., dtype=torch.float64) tensor(1., dtype=torch.float64)
Batch 1: Data shape torch.Size([16, 2, 128, 128])
tensor(-0.9999, dtype=torch.float64) tensor(1., dtype=torch.float64)
tensor(-1., dtype=torch.float64) tensor(1., dtype=torch.float64)
Batch 2: Data shape torch.Size([16, 2, 128, 128])
tensor(-0.9999, dtype=torch.float64) tensor(1., dtype=torch.float64)
tensor(-1., dtype=torch.float64) tensor(1., dtype=torch.float64)
Batch 3: Data shape torch.Size([16, 2, 128, 128])
tensor(-1.0000, dtype=torch.float64) tensor(1., dtype=torch.float64)
tensor(-1., dtype=torch.float64) tensor(1., dtype=torch.float64)
Batch 4: Data shape torch.Size([16, 2, 128, 128])
tensor(-1.0000, dtype=torch.float64) tensor(1., dtype=torch.float64)
tensor(-1., dtype=torch.float64) tensor(1., dtype=torch.float64)
Batch 5: Data shape torch.Size([16, 2, 128, 128])
tensor(-0.9999, dtype=torch.fl