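# EX03: Loading HDF5 (H5) data

This example writes the MNIST training set to an H5 file. It then compares iterating over that file with a plain PyTorch `DataLoader` + `Dataset` against EasyLoader's `H5DataLoader`.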

## 1. Download MNIST and save it to an H5 file

In [1]:
import h5py
import torch

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

# Download MNIST (torchvision skips the download if it is already under ./data)
# and convert the training split into one H5 file with two aligned datasets:
#   'images' - the raw image tensor, stored as a NumPy array
#   'labels' - the matching target tensor, stored as a NumPy array
transform = transforms.ToTensor()
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

mnist_h5_file = 'mnist.h5'
with h5py.File(mnist_h5_file, 'w') as f:
    # `.data` / `.targets` are the raw stored tensors. `transform` is only applied
    # when indexing `mnist_train[i]`, so it does not change what is written here.
    # The returned Dataset handles are not kept: they become invalid once the
    # file is closed at the end of this `with` block.
    f.create_dataset('images', data=mnist_train.data.numpy())
    f.create_dataset('labels', data=mnist_train.targets.numpy())


## 2. Compare EasyLoader with a DataLoader + Dataset

Define a minimal PyTorch `Dataset` that reads individual samples from the H5 file.

In [2]:
import h5py

from torch.utils.data import Dataset


class SimpleH5Dataset(Dataset):
    """Map-style dataset returning aligned samples from several datasets in one H5 file.

    ``dataset[idx]`` returns ``[h5[key][idx] for key in keys]``: one entry per
    key, in the order given. The length is taken from the first key.

    The HDF5 handle is opened lazily on first access, not in ``__init__``, and
    it is dropped when the object is pickled. As a result, the dataset can be
    used with ``DataLoader(num_workers > 0)``. Each worker process opens its
    own handle, instead of inheriting one opened in the parent (fork) or
    failing to pickle it (spawn).

    Args:
        h5_file: Path to the HDF5 file.
        keys: Names of the datasets in the file to read for each sample.

    Raises:
        ValueError: If ``keys`` is empty.
    """

    def __init__(self, h5_file, keys):
        if not keys:
            raise ValueError("keys must name at least one dataset")
        self.h5_file = h5_file
        self.keys = keys
        self._h5 = None
        # Read the length with a short-lived handle so the parent process does
        # not hold an open file that DataLoader workers would inherit.
        with h5py.File(h5_file, 'r') as f:
            self._length = len(f[keys[0]])

    @property
    def h5(self):
        """The open ``h5py.File``, opened on first access in the current process."""
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_file, 'r')
        return self._h5

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        h5 = self.h5
        return [h5[key][idx] for key in self.keys]

    def __getstate__(self):
        # h5py objects cannot be pickled. Each unpickled copy (for example, a
        # spawned DataLoader worker) reopens the file lazily through `h5`.
        state = self.__dict__.copy()
        state['_h5'] = None
        return state

    def close(self):
        """Close the HDF5 handle if one is open. Safe to call more than once."""
        if self._h5 is not None:
            self._h5.close()
            self._h5 = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()

Iterate over the data with a standard PyTorch `DataLoader` wrapped around this `Dataset`.

In [3]:
from tqdm.auto import tqdm

# Baseline: wrap the H5-backed dataset in a stock PyTorch DataLoader.
dataset = SimpleH5Dataset(mnist_h5_file, keys=['images', 'labels'])
dl_simple = DataLoader(
    dataset,
    shuffle=True,
    batch_size=100,
)

# Drain the loader without doing any per-batch work, so the rate that tqdm
# reports reflects loading speed alone.
for batch in tqdm(dl_simple):
    pass

  0%|          | 0/600 [00:00<?, ?it/s]

Now try EasyLoader's `H5DataLoader`, using graining (`grain_size=10`) to improve loading speed.

In [4]:
from easyloader.loader import H5DataLoader

# Same file, keys, and batch size as the DataLoader + Dataset loop, with
# graining enabled through grain_size.
dl_easy = H5DataLoader(
    mnist_h5_file,
    ['images', 'labels'],
    shuffle=True,
    batch_size=100,
    grain_size=10,
)

# Drain the loader so its tqdm rate can be compared with the plain
# DataLoader + Dataset loop.
for batch in tqdm(dl_easy):
    pass

  0%|          | 0/600 [00:00<?, ?it/s]