Skip to content

Commit

Permalink
[data][tests] Update image classification benchmarks (ray-project#38902)
Browse files Browse the repository at this point in the history
This update includes a few fixes for the image classification benchmarks:

    - Use Dataset.map instead of Dataset.map_batches. "[data] Implement zero-copy fusion for Read op" (ray-project#38789) ensures these will get fused with the Read, but map_batches also has some batch formatting overhead.
    - Fix a bug in the benchmark related to image array dimensions.
    - Avoid a copy in the map transform.
---------

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: Jim Thompson <jimthompson5802@gmail.com>
  • Loading branch information
stephanie-wang authored and jimthompson5802 committed Sep 12, 2023
1 parent 5493101 commit f645371
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 35 deletions.
16 changes: 5 additions & 11 deletions release/air_tests/air_benchmarks/mlperf-train/resnet50_ray_air.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def to_tensor_iterator():
print("epoch time", epoch, epoch_time_s)


def crop_and_flip_image_batch(image_batch):
def crop_and_flip_image(row):
transform = torchvision.transforms.Compose(
[
torchvision.transforms.RandomResizedCrop(
Expand All @@ -220,12 +220,9 @@ def crop_and_flip_image_batch(image_batch):
torchvision.transforms.RandomHorizontalFlip(),
]
)
batch_size, height, width, channels = image_batch["image"].shape
tensor_shape = (batch_size, channels, height, width)
image_batch["image"] = transform(
torch.Tensor(image_batch["image"].reshape(tensor_shape))
)
return image_batch
# Make sure to use torch.tensor here to avoid a copy from numpy.
row["image"] = transform(torch.tensor(np.transpose(row["image"], axes=(2, 0, 1))))
return row


def decode_tf_record_batch(tf_record_batch: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -326,10 +323,7 @@ def convert_class_to_idx(df, classes):
convert_class_to_idx,
fn_kwargs={"classes": classes},
)
ds = ds.map_batches(
crop_and_flip_image_batch,
zero_copy_batch=True,
)
ds = ds.map(crop_and_flip_image)
else:
filenames = get_tfrecords_filenames(
data_root, num_images_per_epoch, num_images_per_input_file
Expand Down
43 changes: 19 additions & 24 deletions release/nightly_tests/dataset/image_loader_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import json
import tensorflow as tf
import numpy as np


DEFAULT_IMAGE_SIZE = 224
Expand Down Expand Up @@ -130,14 +131,21 @@ def get_transform(to_torch_tensor):

def crop_and_flip_image_batch(image_batch):
transform = get_transform(False)
batch_size, height, width, channels = image_batch["image"].shape
tensor_shape = (batch_size, channels, height, width)
image_batch["image"] = transform(
torch.Tensor(image_batch["image"].reshape(tensor_shape))
# Make sure to use torch.tensor here to avoid a copy from numpy.
# Original dims are (batch_size, height, width, channels); transpose to (batch_size, channels, height, width).
torch.tensor(np.transpose(image_batch["image"], axes=(0, 3, 1, 2)))
)
return image_batch


def crop_and_flip_image(row):
transform = get_transform(False)
# Make sure to use torch.tensor here to avoid a copy from numpy.
row["image"] = transform(torch.tensor(np.transpose(row["image"], axes=(2, 0, 1))))
return row


if __name__ == "__main__":
import argparse

Expand Down Expand Up @@ -185,55 +193,42 @@ def crop_and_flip_image_batch(image_batch):
for i in range(args.num_epochs):
iterate(torch_dataset, "torch+transform", args.batch_size, metrics)

ray_dataset = ray.data.read_images(args.data_root).map_batches(
crop_and_flip_image_batch
)
ray_dataset = ray.data.read_images(args.data_root).map(crop_and_flip_image)
for i in range(args.num_epochs):
iterate(
ray_dataset.iter_torch_batches(batch_size=args.batch_size),
"ray_data+transform",
"ray_data+map_transform",
args.batch_size,
metrics,
)

ray_dataset = ray.data.read_images(args.data_root).map_batches(
crop_and_flip_image_batch, zero_copy_batch=True
crop_and_flip_image_batch
)
for i in range(args.num_epochs):
iterate(
ray_dataset.iter_torch_batches(batch_size=args.batch_size),
"ray_data+transform+zerocopy",
args.batch_size,
metrics,
)

ray_dataset = ray.data.read_images(args.data_root)
for i in range(args.num_epochs):
iterate(
ray_dataset.iter_torch_batches(batch_size=args.batch_size),
"ray_data",
"ray_data+transform",
args.batch_size,
metrics,
)

ray_dataset = ray.data.read_images(args.data_root).map_batches(
lambda x: x, batch_format="pyarrow", batch_size=args.batch_size
crop_and_flip_image_batch, zero_copy_batch=True
)
for i in range(args.num_epochs):
iterate(
ray_dataset.iter_torch_batches(batch_size=args.batch_size),
"ray_data+dummy_pyarrow_transform",
"ray_data+transform+zerocopy",
args.batch_size,
metrics,
)

ray_dataset = ray.data.read_images(args.data_root).map_batches(
lambda x: x, batch_format="numpy", batch_size=args.batch_size
)
ray_dataset = ray.data.read_images(args.data_root)
for i in range(args.num_epochs):
iterate(
ray_dataset.iter_torch_batches(batch_size=args.batch_size),
"ray_data+dummy_np_transform",
"ray_data",
args.batch_size,
metrics,
)
Expand Down

0 comments on commit f645371

Please sign in to comment.