Efficiently handling large-scale vision datasets #374

Open
clemsgrs opened this issue Feb 10, 2024 · 2 comments
clemsgrs commented Feb 10, 2024

Hi, I'm using DINOv2 to pretrain a ViT on a dataset significantly larger than ImageNet22k (between 100M and 1B JPEG images). I stuck with the ImageNet22k dataset class for handling and loading data, i.e. storing images in tarball files and using a single .npy file for metadata (start and end byte offsets, plus information about which tarball file a given image is located in). I've included the code below.
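The storage layout described above (tarballs plus a flat index of byte offsets) can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not DINOv2's actual preparation script: `build_tar_index`, `read_member` and `roundtrip_example` are hypothetical names, and it assumes uncompressed tarballs, since byte offsets into a gzip-compressed tar cannot be sliced directly. In the real setup the index would live in the `.npy` entries file rather than in a Python list.

```python
import io
import mmap
import os
import tarfile
import tempfile
from typing import Dict, List, Tuple


def build_tar_index(tar_path: str) -> List[Tuple[str, int, int]]:
    """Return (member_name, start_offset, end_offset) for each regular file.

    Offsets are absolute byte positions of each member's payload inside the
    tarball, so they are only valid for uncompressed tar files.
    """
    index = []
    with tarfile.open(tar_path, mode="r:") as tar:  # "r:" = no compression
        for member in tar:  # reads headers one after another
            if member.isfile():
                start = member.offset_data  # first byte of the file's payload
                index.append((member.name, start, start + member.size))
    return index


def read_member(mm: mmap.mmap, start: int, end: int) -> bytes:
    """Slice one image's raw bytes out of a memory-mapped tarball."""
    return mm[start:end]


def roundtrip_example() -> Tuple[Dict[str, bytes], Dict[str, bytes]]:
    """Write a tiny tarball, index it, then read every member back via mmap."""
    payloads = {"img_0.jpg": b"fake-jpeg-0", "img_1.jpg": b"fake-jpeg-11"}
    recovered: Dict[str, bytes] = {}
    with tempfile.TemporaryDirectory() as tmp:
        tar_path = os.path.join(tmp, "shard-000.tar")
        with tarfile.open(tar_path, mode="w") as tar:
            for name, data in payloads.items():
                info = tarfile.TarInfo(name=name)
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))

        index = build_tar_index(tar_path)
        with open(tar_path, "rb") as f:
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
                for name, start, end in index:
                    recovered[name] = read_member(mm, start, end)
    return payloads, recovered
```

Because each read is a slice of a memory mapping, only the pages around that image are faulted in (plus whatever the OS reads ahead), rather than the whole tarball.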

Unfortunately, I am facing very slow data loading times:

  1. Large tarball files: some tarballs I work with contain as many as 6M images. I suspect this increases RAM usage, which could explain the slow data loading times (or even out-of-memory errors) I am facing.

  2. To mitigate this issue, I split the large tarballs into smaller ones (1 GB each). Although this offers some relief by reducing the memory footprint during data loading, it doesn't scale well with the batch size: the bigger the batch size, the more tarball files have to be opened and closed concurrently, which seems to add significant overhead and slows down data loading.
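A rough back-of-the-envelope model helps explain why smaller shards interact badly with large batches. Assuming S equally sized shards and a batch of B indices drawn uniformly at random (an approximation of `shuffle=True` when B is much smaller than the dataset), the expected number of distinct shards touched by one batch is S(1 - (1 - 1/S)^B). The function name below is hypothetical.

```python
def expected_distinct_shards(num_shards: int, batch_size: int) -> float:
    """Expected number of distinct shards hit by one uniformly random batch.

    Each of the `batch_size` samples independently lands in a given shard with
    probability 1/num_shards, so that shard is missed by the whole batch with
    probability (1 - 1/num_shards) ** batch_size.
    """
    if num_shards < 1 or batch_size < 0:
        raise ValueError("num_shards must be >= 1 and batch_size >= 0")
    p_missed = (1.0 - 1.0 / num_shards) ** batch_size
    return num_shards * (1.0 - p_missed)
```

For example, `expected_distinct_shards(1000, 1024)` is roughly 640, so most samples in a batch come from different shards. With the `lru_cache`-based mmap cache in the dataset code below capped at `mmap_cache_size=16` entries per process, a working set that large would plausibly cause constant eviction and re-mapping, which is consistent with the slowdown described in point 2.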

I've tried looking into alternative tools (WebDataset, TorchData), but without success. I am therefore reaching out for any advice or alternative strategies for handling large-scale vision datasets. Thank you!
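For comparison, WebDataset-style loaders avoid random access altogether: they shuffle the order of shards, read each shard sequentially, and shuffle samples within a bounded buffer. The snippet below is a minimal stdlib sketch of that access pattern, not WebDataset's or TorchData's API; `iter_shards` and `demo_names` are hypothetical names.

```python
import io
import os
import random
import tarfile
import tempfile
from typing import Iterable, Iterator, List, Tuple


def iter_shards(
    shard_paths: Iterable[str], shuffle_buffer: int = 1000, seed: int = 0
) -> Iterator[Tuple[str, bytes]]:
    """Yield (member_name, raw_bytes) by reading whole shards sequentially.

    Randomness comes from (1) shuffling the shard order and (2) a bounded
    shuffle buffer, instead of random access into many tarballs at once.
    """
    rng = random.Random(seed)
    paths = list(shard_paths)
    rng.shuffle(paths)  # shard-level shuffle
    buffer: List[Tuple[str, bytes]] = []
    for path in paths:
        # "r|" = forward-only stream: one open file per shard, read front to back.
        with tarfile.open(path, mode="r|") as tar:
            for member in tar:
                if not member.isfile():
                    continue
                fileobj = tar.extractfile(member)
                buffer.append((member.name, fileobj.read()))
                if len(buffer) >= shuffle_buffer:
                    # Swap a random element to the end and emit it.
                    i = rng.randrange(len(buffer))
                    buffer[i], buffer[-1] = buffer[-1], buffer[i]
                    yield buffer.pop()
    rng.shuffle(buffer)  # drain whatever is left in the buffer
    yield from buffer


def demo_names() -> List[str]:
    """Write two tiny shards and return the member names streamed from them."""
    with tempfile.TemporaryDirectory() as tmp:
        paths = []
        for s in range(2):
            path = os.path.join(tmp, f"shard-{s:03d}.tar")
            with tarfile.open(path, mode="w") as tar:
                for k in range(3):
                    data = f"sample {s}-{k}".encode()
                    info = tarfile.TarInfo(name=f"{s}_{k}.jpg")
                    info.size = len(data)
                    tar.addfile(info, io.BytesIO(data))
            paths.append(path)
        return [name for name, _ in iter_shards(paths, shuffle_buffer=2)]
```

If each DataLoader worker is given a disjoint subset of shard paths, it keeps roughly one tarball open at a time regardless of batch size. The trade-off is that samples are only shuffled within the buffer and across shard order, not globally.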

Dataset code

```python
from functools import lru_cache
from io import BytesIO
from mmap import ACCESS_READ, mmap
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image
from torchvision.datasets import VisionDataset


class Decoder:
    def decode(self) -> Any:
        raise NotImplementedError


class ImageDataDecoder(Decoder):
    def __init__(self, image_data: bytes) -> None:
        self._image_data = image_data

    def decode(self) -> Image.Image:
        f = BytesIO(self._image_data)
        return Image.open(f).convert(mode="RGB")


class TargetDecoder(Decoder):
    def __init__(self, target: Any):
        self._target = target

    def decode(self) -> Any:
        return self._target


_DEFAULT_MMAP_CACHE_SIZE = 16  # Warning: This can exhaust file descriptors


def _get_tarball_path(dataset_name: str) -> str:
    return f"{dataset_name}.tar"


def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
    # Keep at most `mmap_cache_size` tarballs memory-mapped at once (LRU eviction).
    @lru_cache(maxsize=mmap_cache_size)
    def _mmap_tarball(dataset_name: str) -> mmap:
        tarball_path = _get_tarball_path(dataset_name)
        tarball_full_path = Path(tarballs_root, tarball_path)
        # Open in binary mode; the mapping stays valid after the file object is closed.
        with open(tarball_full_path, "rb") as f:
            return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)

    return _mmap_tarball


class FoundationDataset(VisionDataset):

    def __init__(
        self,
        *,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        self._get_entries()
        self._get_dataset_names()
        self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)

    @property
    def _tarballs_root(self) -> str:
        return self.root

    @property
    def _entries_name(self) -> str:
        return "pretrain_entries.npy"

    def _get_entries(self) -> None:
        self._entries = self._load_entries(self._entries_name)

    def _load_entries(self, _entries_name: str) -> np.ndarray:
        # Memory-mapped, so the entries array is not read into RAM up front.
        entries_path = Path(self.root, _entries_name)
        return np.load(entries_path, mmap_mode="r")

    def _get_filepaths_dict(self, dataset_name: str) -> dict:
        return self._load_filepaths_dict(dataset_name)

    def _load_filepaths_dict(self, dataset_name: str) -> dict:
        # Loads the pickled file_idx -> filepath mapping for this dataset.
        filepaths_dict_path = Path(self.root, f"{dataset_name}_file_indices.npy")
        return np.load(filepaths_dict_path, allow_pickle=True).item()

    def _get_dataset_names(self) -> None:
        self._dataset_names = self._load_dataset_names()

    def _load_dataset_names(self) -> dict:
        dataset_dict_path = Path(self.root, "dataset_indices.npy")
        return np.load(dataset_dict_path, allow_pickle=True).item()

    def get_image_data(self, index: int) -> Tuple[bytes, Path]:
        entry = self._entries[index]
        # Entry layout: [target, file_idx, start_offset, end_offset, dataset_idx]
        file_idx, start_offset, end_offset, dataset_idx = (
            entry[1],
            entry[2],
            entry[3],
            entry[4],
        )
        dataset_name = self._dataset_names[dataset_idx]
        filepaths_dict = self._get_filepaths_dict(dataset_name)
        filepath = filepaths_dict[file_idx]
        class_mmap = self._mmap_tarball(dataset_name)
        # Slice the raw image bytes directly out of the memory-mapped tarball.
        data = class_mmap[start_offset:end_offset]
        return data, Path(filepath)

    def get_target(self, index: int) -> int:
        return int(self._entries[index][0])

    def get_targets(self) -> np.ndarray:
        return self._entries[:, 0]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        try:
            image_data, _ = self.get_image_data(index)
            image = ImageDataDecoder(image_data).decode()
        except Exception as e:
            raise RuntimeError(f"cannot read image for sample {index} ({e})") from e
        target = self.get_target(index)
        target = TargetDecoder(target).decode()

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self._entries)
```
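One detail in the code above that may be worth checking: `_get_filepaths_dict` calls `np.load(..., allow_pickle=True)` on every `get_image_data` call, so a pickled dictionary is re-read and unpickled for every sample, even though `__getitem__` discards the returned path. Below is a minimal sketch of memoizing it per process, reusing the `functools.lru_cache` pattern that `_make_mmap_tarball` already applies to the mmaps. `make_cached_loader` is a hypothetical helper, and the stand-in loader replaces the real `np.load` call so the snippet runs without data files.

```python
from functools import lru_cache
from typing import Callable, Dict, List


def make_cached_loader(
    load_fn: Callable[[str], Dict[int, str]], maxsize: int = 128
) -> Callable[[str], Dict[int, str]]:
    """Memoize a per-dataset loader so each dict is loaded once per process.

    With num_workers > 0, each DataLoader worker is a separate process with
    its own copy of this cache.
    """
    return lru_cache(maxsize=maxsize)(load_fn)


def demo_load_calls() -> List[str]:
    """Show that repeated lookups for the same dataset hit the cache."""
    calls: List[str] = []

    def fake_load(dataset_name: str) -> Dict[int, str]:
        # Hypothetical stand-in for the np.load(...).item() call in the dataset.
        calls.append(dataset_name)
        return {0: f"{dataset_name}/image_0.jpg"}

    cached = make_cached_loader(fake_load)
    for name in ["dataset_a", "dataset_a", "dataset_b", "dataset_a"]:
        cached(name)
    return calls
```

Inside `FoundationDataset.__init__`, this would amount to something like `self._get_filepaths_dict = make_cached_loader(self._load_filepaths_dict)`, mirroring how `_mmap_tarball` is wired up. Since the path is unused in `__getitem__`, skipping that lookup there is another option.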
clemsgrs changed the title from "Efficiently handling large-scale image datasets" to "Efficiently handling large-scale vision datasets" on Feb 10, 2024
@CaedenMotley

I ran into a similar issue, and it really comes down to your device's computational power. One workaround I initially used when running out of memory was to write data to the hard drive and read it back, using the disk as pseudo-RAM in a sense. This is incredibly slow, though. If you are working with a dataset that large, I highly recommend offloading your processing onto a computing cluster if you are able to. If you are simply trying to create and load the dataset, I would recommend using Hugging Face to store and load the dataset in pieces: do not load it all at once, but rather in batches. Please let me know if anything I said is unclear or not applicable to your situation.

@cdancette

@clemsgrs Did you manage to find a solution to your issues? I am facing similar problems.
