Iterating over Image feature columns is extremely slow #7508

@sohamparikh

Description

We are trying to load datasets where the image column stores PIL.PngImagePlugin.PngImageFile images. However, iterating over these datasets is extremely slow.
What I have found:

  1. It is the presence of the image column that causes the slowdown: removing the column from the dataset makes iteration blazingly fast (as expected)
  2. Iteration is ~2x faster when the column contains a single image rather than a list of images (i.e., when the feature is a Sequence of Image objects). We often need multiple images per sample, so we have to work with a list of images
  3. Storing paths to PNG files and loading them with PIL.Image.open is ~17x faster than iterating over a Dataset with an Image column, and ~30x faster than a Sequence of Images. See the simple script below, which uses an openly available dataset.

It would be great to understand the standard practices for storing and loading multimodal datasets (image + text).

https://huggingface.co/docs/datasets/en/image_load also seems a bit underdeveloped? (e.g., dataset.decode only works with IterableDataset, but this isn't clear from the docs)

Thanks!

from datasets import load_dataset, load_from_disk
from PIL import Image
from pathlib import Path

ds = load_dataset("getomni-ai/ocr-benchmark")

# Directory for the standalone PNG files
Path("/tmp/ds_files/images").mkdir(parents=True, exist_ok=True)

# Save each image as an individual PNG file
for idx, sample in enumerate(ds["test"]):
    image = sample["image"]
    image.save(f"/tmp/ds_files/images/image_{idx}.png")

ds.save_to_disk("/tmp/ds_columns")

# Remove the 'image' column
ds["test"] = ds["test"].remove_columns(["image"])

# Create image paths for each sample
image_paths = [f"images/image_{idx}.png" for idx in range(len(ds["test"]))]

# Add the 'image_path' column to the dataset
ds["test"] = ds["test"].add_column("image_path", image_paths)

# Save the updated dataset
ds.save_to_disk("/tmp/ds_files")
files_path = Path("/tmp/ds_files")
column_path = Path("/tmp/ds_columns")

import time

# Load both saved datasets and benchmark iteration
ds_file = load_from_disk(files_path)
ds_column = load_from_disk(column_path)

images_files = []
start = time.time()
for idx in range(len(ds_file["test"])):
    image_path = files_path / ds_file["test"][idx]["image_path"]
    image = Image.open(image_path)
    images_files.append(image)
end = time.time()
print(f"Time taken to load images from files: {end - start} seconds")

# Time taken to load images from files: 1.2364635467529297 seconds


images_column = []
start = time.time()
for idx in range(len(ds_column["test"])):
    images_column.append(ds_column["test"][idx]["image"])
end = time.time()
print(f"Time taken to load images from columns: {end - start} seconds")

# Time taken to load images from columns: 20.49347186088562 seconds
