In [25]:
import h5py
import numpy as np
import torch
import torchvision
import glob
import os
from tqdm.auto import tqdm

In [33]:
class Rain13K(torch.utils.data.Dataset):
    def __init__(self, root, split, transform=None):
        if split == "train":
            self.base_dir = os.path.join(root, split, "Rain13K")
        else:
            self.base_dir = os.path.join(root, split, "Test100")
        self.input_dir = os.path.join(self.base_dir, "input")
        self.target_dir = os.path.join(self.base_dir, "target")
        self.transform = transform
        self.length = len(os.listdir(self.input_dir))

        if split == "train":
            self.image_ids = glob.glob(os.path.join(self.input_dir, "*.jpg"))
        elif split == "test":
            self.image_ids = glob.glob(os.path.join(self.input_dir, "*.png"))
        self.image_ids = [image_id.split("/")[-1].split(".")[0] for image_id in self.image_ids]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        x_path = list(glob.glob(os.path.join(self.input_dir, f"{self.image_ids[idx]}.*")))[0]
        y_path = list(glob.glob(os.path.join(self.target_dir, f"{self.image_ids[idx]}.*")))[0]
        x = torchvision.io.read_image(x_path).float().to("cuda")
        y = torchvision.io.read_image(y_path).float().to("cuda")
        x = self.transform(x)
        y = self.transform(y)
        return x, y

In [38]:
def divide_by_255(x):
    return x / 255


ds_mean = [0.49139968, 0.48215841, 0.44653091]
ds_std = [0.24703223, 0.24348513, 0.26158784]

transform = torchvision.transforms.Compose(
    [
        divide_by_255,
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.Normalize(ds_mean, ds_std),
    ]
)

train_dataset = Rain13K("data/Rain13K", "train", transform=transform)
test_dataset = Rain13K("data/Rain13K", "test", transform=transform)

batch_size = 256
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [39]:
with h5py.File("../../workspace/Rain13K.hdf5", "w") as f:
    train_h5 = f.create_dataset("train", (len(train_dataset), 2, 3, 256, 256), dtype="f")
    for batch_idx, (X, Y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        XY = torch.stack([X, Y], dim=1)
        train_h5[batch_idx * batch_size : (batch_idx + 1) * batch_size] = XY.to("cpu")

    test_h5 = f.create_dataset("test", (len(test_dataset), 2, 3, 256, 256), dtype="f")
    for batch_idx, (X, Y) in tqdm(enumerate(test_loader), total=len(test_loader)):
        XY = torch.stack([X, Y], dim=1)
        test_h5[batch_idx * batch_size : (batch_idx + 1) * batch_size] = XY.to("cpu")

!cp data/Rain13K.hdf5 data/Rain13K.hdf5.bak

  0%|          | 0/54 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

cp: cannot stat 'data/Rain13K.hdf5': No such file or directory


In [41]:
with h5py.File("../../workspace/Rain13K.hdf5", "r") as f:
    print(f.keys())
    print(f["train"].shape)
    print(f["test"].shape)

<KeysViewHDF5 ['test', 'train']>
(13711, 2, 3, 256, 256)
(98, 2, 3, 256, 256)


In [42]:
ds = h5py.File("../../workspace/Rain13K.hdf5", "r")
train_ds = ds["train"]
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8, shuffle=False)

In [45]:
for batch_idx, XY in tqdm(enumerate(train_loader)):
    XY = XY.to("cuda")
    X, Y = XY[:, 0], XY[:, 1]

0it [00:00, ?it/s]

KeyboardInterrupt: 

In [46]:
def get_dataset(
    ds_root="../../workspace/Rain13K.hdf5", batch_size=64, train_transform=None, test_transform=None, num_workers=4
):
    ds = h5py.File(ds_root, "r")
    train_ds = ds["train"]
    train_ds, val_ds = torch.utils.data.random_split(train_ds, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
    test_ds = ds["test"]

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=num_workers
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers
    )
    return train_loader, val_loader, test_loader

In [47]:
train_loader, val_loader, test_loader = get_dataset()

In [53]:
for batch_idx, batch in tqdm(enumerate(train_loader)):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        batch = batch.to("cuda", non_blocking=True)
        imgs = batch[:, 0, :, :, :]
        labels = batch[:, 1, :, :, :]

0it [00:00, ?it/s]