To make your code run with maximum efficiency, you also need to load your data efficiently into your device's memory. Fortunately, PyTorch offers a tool that makes data loading easy: the DataLoader. A DataLoader uses multiple workers to load data from a Dataset simultaneously, and optionally uses a Sampler to pick data entries and form batches.
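As a small illustration of these moving parts, here is a sketch that uses PyTorch's built-in `TensorDataset` together with an explicit `RandomSampler`. The toy tensors are made up for the example, and `num_workers=0` keeps loading in the main process; any positive value would start that many worker processes instead.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Ten toy items: row i of `features` is [4i, 4i+1, 4i+2, 4i+3], label i
features = torch.arange(40, dtype=torch.float32).reshape(10, 4)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# The sampler decides the order in which indices are visited; the DataLoader
# groups those indices into batches of 4 and fetches the matching items.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(
    dataset,
    batch_size=4,
    sampler=RandomSampler(dataset, generator=generator),
    num_workers=0,  # 0 = load in the main process, >0 = background workers
)

batches = list(loader)
for x, y in batches:
    print(x.shape, y.tolist())
```

Passing `shuffle=True` instead of a sampler gives the same random-order behavior; an explicit sampler is mainly useful when you need custom index selection, for example `WeightedRandomSampler`.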

If your data supports random access, using a DataLoader is very easy: you simply implement a Dataset class with a `__getitem__` method (to read each data item) and a `__len__` method (to return the number of items in the dataset). For example, here's a dataset that serves images and labels stored in memory as tensors:

In [1]:
import torch
from torch.utils.data import Dataset

class CustomTensorDataset(Dataset):
    def __init__(self, data_tensor, target_tensor, transform=None):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transform = transform

    def __len__(self):
        return len(self.data_tensor)

    def __getitem__(self, idx):
        data = self.data_tensor[idx]
        target = self.target_tensor[idx]
        
        if self.transform:
            data = self.transform(data)

        return data, target

# Let's create some random data tensors for demonstration
data_tensor = torch.randn(100, 3, 32, 32)  # Assuming 100 images with shape (3, 32, 32)
target_tensor = torch.randint(0, 10, (100,))  # Assuming 100 labels

# Create an instance of the custom dataset
dataset = CustomTensorDataset(data_tensor=data_tensor, target_tensor=target_tensor)

You can then do the following:

In [10]:
# dataloader = torch.utils.data.DataLoader(dataset, num_workers=2)
# for data in dataloader:
#     print(data[0].shape, data[1].shape)
#     break
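For data stored as one file per item, the same two methods are enough. A minimal sketch, assuming a flat directory of image files that PIL can read; the class name `ImageDirectoryDataset` and its `pattern` parameter are illustrative, not part of PyTorch:

```python
import glob
import os

from PIL import Image
from torch.utils.data import Dataset


class ImageDirectoryDataset(Dataset):
    """Hypothetical map-style dataset over the image files in one directory."""

    def __init__(self, directory, pattern="*.jpg", transform=None):
        # Sort the paths so that index i always refers to the same file
        self.paths = sorted(glob.glob(os.path.join(directory, pattern)))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Every access opens and decodes one file: one separate read per item
        with Image.open(self.paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image
```

Such a dataset plugs into a `DataLoader` like any other, but each `__getitem__` call issues its own file read, which is what makes this pattern sensitive to storage latency.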

Using a DataLoader to read data with random access may be fine if you have fast storage or if your data items are large. But imagine reading from a network file system over a slow connection: requesting individual files this way can be extremely slow and would probably become the bottleneck of your training pipeline.

A better approach is to store your data in a contiguous file format that can be read sequentially. For example, if you have a large collection of images, you can use tar to create a single archive and then read the files from the archive sequentially in Python. To do this, you can use PyTorch's IterableDataset. To create an IterableDataset class, you only need to implement an `__iter__` method that sequentially reads and yields data items from the dataset.
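Creating the archive itself does not require leaving Python. A minimal sketch using the standard-library `tarfile` module, assuming a flat directory of files; `pack_directory` is a hypothetical helper name:

```python
import glob
import os
import tarfile


def pack_directory(directory, tar_path, pattern="*"):
    """Write every regular file matching `pattern` in `directory` into one tar."""
    paths = [p for p in sorted(glob.glob(os.path.join(directory, pattern)))
             if os.path.isfile(p)]
    # Mode "w" writes an uncompressed archive: image formats are usually
    # compressed already, and a plain tar can be streamed without decompressing.
    with tarfile.open(tar_path, "w") as tar:
        for path in paths:
            # Store only the file name so the archive has a flat layout
            tar.add(path, arcname=os.path.basename(path))
    return len(paths)
```

Sorting the paths makes the member order, and therefore the order in which a sequential reader sees the items, reproducible.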

In [8]:
import tarfile
from io import BytesIO

from PIL import Image
from torch.utils.data import IterableDataset


def tar_image_iterator(path, transform=None):
    # Mode "r|*" opens the archive as a stream: members are read strictly in
    # order, with no seeking and nothing extracted to disk
    with tarfile.open(path, "r|*") as tar:
        for member in tar:
            if not member.isfile():
                continue  # skip directories, links, etc.
            with tar.extractfile(member) as file:
                content = file.read()
            image = Image.open(BytesIO(content)).convert("RGB")
            if transform is not None:
                image = transform(image)
            yield image


class TarImageDataset(IterableDataset):
    def __init__(self, path, transform=None):
        super().__init__()
        self.path = path
        self.transform = transform

    def __iter__(self):
        # Every call reads the archive from the beginning
        yield from tar_image_iterator(self.path, self.transform)


# dataset = TarImageDataset("images.tar")

But there's a major problem with this implementation. If you try to use a DataLoader to read from this dataset with more than one worker, you'll observe a lot of duplicated images:

In [11]:
# dataloader = torch.utils.data.DataLoader(
#     TarImageDataset("/data/imagenet.tar"), batch_size=None, num_workers=8)  # unbatched images
# for image in dataloader:
#     pass  # each image is yielded 8 times, once per worker
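The duplication can be reproduced without any tar file. The sketch below uses a toy `IterableDataset` (the `NumberStream` class is hypothetical, standing in for a sequential file reader) that ignores which worker it runs in. It assumes a platform where the `fork` start method is available, such as Linux, so that the script runs as-is without a `__main__` guard:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class NumberStream(IterableDataset):
    """Toy stand-in for a sequential file: yields 0, 1, ..., n - 1."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        # No worker awareness: every copy of the dataset starts from 0
        return iter(range(self.n))


loader = DataLoader(
    NumberStream(6),
    batch_size=None,                 # yield individual items, not batches
    num_workers=2,
    multiprocessing_context="fork",  # assumption: fork is available (e.g. Linux)
)
items = [int(x) for x in loader]
print(sorted(items))  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
```

With two workers every number appears twice, because each worker process holds its own copy of the dataset and iterates it from the start.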

The problem is that each worker gets its own copy of the dataset, and each copy starts reading from the beginning. One way to avoid this is, instead of having one tar file, to split your data into `num_workers` separate tar files and load each with a separate worker:

In [12]:
class TarImageDataset(IterableDataset):
    def __init__(self, paths):
        super().__init__()
        self.paths = paths

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # For simplicity we assume num_workers is equal to number of tar files
        if worker_info is None or worker_info.num_workers != len(self.paths):
            raise ValueError("Number of workers doesn't match number of files.")
        # Each worker streams only the tar file that matches its id
        yield from tar_image_iterator(self.paths[worker_info.id])

This is how our dataset class can be used:

In [13]:
# dataloader = torch.utils.data.DataLoader(
#     TarImageDataset(["/data/imagenet_part1.tar", "/data/imagenet_part2.tar"]),
#     batch_size=None, num_workers=2)
# for image in dataloader:
#     pass  # do something with each image

We discussed a simple strategy to avoid the duplicated-entries problem. The [tfrecord](https://github.com/vahidk/tfrecord) package uses slightly more sophisticated strategies to shard your data on the fly.
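To make "sharding on the fly" concrete, here is one simple strategy, which is not necessarily what tfrecord does: every worker reads the same stream but keeps only every `num_workers`-th record, starting at its own worker id. `ShardedNumberStream` is a hypothetical toy class, and the example assumes a platform where the `fork` start method is available, such as Linux:

```python
import itertools

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedNumberStream(IterableDataset):
    """Toy stream where each worker keeps every num_workers-th record."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def records(self):
        # Stand-in for sequentially reading records from a single file
        return iter(range(self.n))

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:  # single-process loading: keep everything
            return self.records()
        # Worker k keeps records k, k + W, k + 2W, ... where W = num_workers
        return itertools.islice(
            self.records(), worker_info.id, None, worker_info.num_workers)


loader = DataLoader(ShardedNumberStream(10), batch_size=None, num_workers=3,
                    multiprocessing_context="fork")  # assumption: fork available
items = sorted(int(x) for x in loader)
print(items)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Every worker still reads the whole stream and discards the records it doesn't keep, so this trades extra I/O for simplicity; expensive work such as image decoding should happen after the `islice` so that each record is decoded by exactly one worker.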