In [6]:
import h5py
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch

class ImageDataset(Dataset):
    """Map-style dataset over a subset of the top-level groups of an HDF5 file.

    Every selected group must contain the datasets ``'original'``,
    ``'grayscale'``, ``'A_channel'`` and ``'B_channel'``. The file is opened
    read-only when the dataset is constructed and stays open until
    :meth:`close` is called (or the ``with`` block using this object exits).
    """

    def __init__(self, file_path, indices):
        """Open the HDF5 file and record which groups this dataset serves.

        Args:
            file_path: Path of the HDF5 file to read.
            indices: Sequence of integer positions into the file's top-level
                key listing; ``len(indices)`` is the dataset length.
        """
        self.file = h5py.File(file_path, 'r')
        self.indices = indices
        # Enumerate the top-level names once. Rebuilding this list inside
        # __getitem__ would cost a full key listing per sample.
        self._group_keys = list(self.file)

    def __len__(self):
        """Return the number of selected groups."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position within ``indices`` (not within the file).

        Returns:
            ``((input_color, input_gray), (output_a, output_b))`` as float32
            tensors. ``input_color`` is ``'original'`` with its last axis moved
            to the front and divided by 255; ``input_gray`` is ``'grayscale'``
            with a leading axis added and divided by 255; ``output_a`` and
            ``output_b`` are ``'A_channel'`` and ``'B_channel'`` unscaled.
        """
        group_key = self._group_keys[self.indices[idx]]
        data = self.file[group_key]

        # 'original' is stored channels-last; permute to channels-first and scale to [0, 1]
        # (assumes 8-bit pixel values -- TODO confirm against the file writer).
        input_color = torch.tensor(data['original'][:], dtype=torch.float32).permute(2, 0, 1) / 255.0
        # 'grayscale' has no channel axis on disk; add one in front.
        input_gray = torch.tensor(data['grayscale'][:], dtype=torch.float32).unsqueeze(0) / 255.0
        # The A/B targets are passed through without rescaling.
        output_a = torch.tensor(data['A_channel'][:], dtype=torch.float32)
        output_b = torch.tensor(data['B_channel'][:], dtype=torch.float32)

        return (input_color, input_gray), (output_a, output_b)

    def close(self):
        """Close the underlying HDF5 file."""
        self.file.close()

    def __enter__(self):
        """Allow ``with ImageDataset(...) as ds:`` so the file is always closed."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file on leaving the ``with`` block; exceptions propagate."""
        self.close()
        return False

In [7]:
def create_loaders(file_path, batch_size=32, test_size=0.2, random_seed=42):
    """Split the groups of an HDF5 file into train/test sets and wrap each in a DataLoader.

    Prints the total, train and test group counts (one number per line).

    Args:
        file_path: Path of the HDF5 file; one sample per top-level entry.
        batch_size: Batch size used by both loaders.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        random_seed: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Only the entry count is needed here, so the file is closed again right away.
    with h5py.File(file_path, 'r') as h5_file:
        all_indices = list(range(len(h5_file)))

    print(len(all_indices))

    train_indices, test_indices = train_test_split(
        all_indices,
        test_size=test_size,
        random_state=random_seed,
    )

    for split in (train_indices, test_indices):
        print(len(split))

    # Each dataset opens its own read-only handle on the same file.
    train_set, test_set = [
        ImageDataset(file_path, split) for split in (train_indices, test_indices)
    ]

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )

In [10]:
# Build the train/test loaders over the HDF5 image file.
file_path = 'image_data.h5'
train_loader, test_loader = create_loaders(file_path, batch_size=64)

# Sanity check: walk the training loader once and print every batch's tensor shapes.
for (inputs_color, inputs_gray), (targets_a, targets_b) in train_loader:
    print(*[t.shape for t in (inputs_color, inputs_gray, targets_a, targets_b)])

5
4
1
torch.Size([4, 3, 256, 256]) torch.Size([4, 1, 256, 256]) torch.Size([4, 256, 256]) torch.Size([4, 256, 256])
