# Scaling Batch Inference with Ray Data

This template walks through GPU batch inference on a subset of the ImageNet dataset using a PyTorch ResNet model.

The framework and data format used in this template can be easily replaced to suit your own application!

> Slot in your code below wherever you see the ✂️ icon to build off of this template!

## Set up the dependencies

Since we're running on a distributed Ray cluster with multiple nodes, we need to first
set up dependencies so that our batch inference workers can access all the required packages.

There are two sets of dependencies that we'll set up.

### Set up local dependencies

The first set contains any dependencies that are needed locally by this notebook.
Install the dependencies with the following command:

```
pip install -r requirements.txt
```


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tempfile
from typing import Dict

import ray


### Set up runtime dependencies

The second set contains dependencies that are required by each worker.

Later on, we define a class that implements our custom model initialization and inference logic.
Ray Data will run batch inference on many workers using copies of this class somewhere in our Ray cluster.
It's important to note that the workers may live on a different node than the one running this notebook.
Therefore, a dependency installed locally here may not be available to our inference code at runtime.

To address this, we can specify a [Ray Runtime Environment](https://docs.ray.io/en/latest/ray-core/handling-dependencies.html#runtime-environments)
to dynamically set up dependencies, which enables us to import the specified dependencies
on the workers.

In [None]:
ray.init(runtime_env={"pip": ["torch", "torchvision"]})
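
If your local dependencies already live in a requirements file, you may want the workers to install the same list. The following is a minimal sketch under stated assumptions: `pip_runtime_env` is a hypothetical helper (not part of Ray), and the version pins are placeholders. It only relies on the `pip` field taking a list of requirement strings, as in the cell above.

```python
from typing import Dict, Iterable, List


def pip_runtime_env(requirement_lines: Iterable[str]) -> Dict[str, List[str]]:
    """Build a runtime_env dict from requirements.txt-style lines.

    Hypothetical helper: skips blank lines and comments, and drops
    duplicates while preserving order.
    """
    packages: List[str] = []
    for line in requirement_lines:
        # Strip trailing comments and surrounding whitespace.
        requirement = line.split("#", 1)[0].strip()
        if requirement and requirement not in packages:
            packages.append(requirement)
    return {"pip": packages}


# Placeholder pins for illustration only; use the versions you actually test with.
example_lines = [
    "torch==2.0.1",
    "",
    "# vision models",
    "torchvision==0.15.2",
    "torch==2.0.1",
]
runtime_env = pip_runtime_env(example_lines)
print(runtime_env)
```

The resulting dictionary could then be passed as `ray.init(runtime_env=runtime_env)`.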


## Load the dataset

> ✂️ Replace this function with logic to load your own data with Ray Data.
>
> See [the Ray Data guide on creating datasets](https://docs.ray.io/en/latest/data/creating-datasets.html) to learn how to create a dataset based on your data type and file storage format.

In [None]:
def load_ray_dataset():
    from ray.data.datasource.partitioning import Partitioning

    s3_uri = "s3://anonymous@air-example-data-2/imagenette2/val/"
    partitioning = Partitioning("dir", field_names=["class"], base_dir=s3_uri)
    ds = ray.data.read_images(
        s3_uri, size=(256, 256), partitioning=partitioning, mode="RGB"
    )
    return ds


In [None]:
ds = load_ray_dataset()
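
The `Partitioning("dir", field_names=["class"], ...)` argument above derives a `class` column from the directory that each image sits in, relative to `base_dir`. The sketch below illustrates that idea in plain Python. It is not Ray's implementation: `parse_dir_partition` is a hypothetical function and the file name in the example path is made up.

```python
from typing import Dict, List


def parse_dir_partition(path: str, base_dir: str, field_names: List[str]) -> Dict[str, str]:
    """Map the directory levels below `base_dir` to the given field names."""
    base = base_dir.rstrip("/") + "/"
    if not path.startswith(base):
        raise ValueError(f"{path!r} is not located under {base_dir!r}")
    # Everything after the base directory, minus the file name itself.
    directories = path[len(base):].split("/")[:-1]
    if len(directories) != len(field_names):
        raise ValueError(
            f"Expected {len(field_names)} directory levels, found {len(directories)}"
        )
    return dict(zip(field_names, directories))


base_dir = "s3://anonymous@air-example-data-2/imagenette2/val/"
# Hypothetical file name, used only to show the directory layout.
example_path = base_dir + "n01440764/example.JPEG"
fields = parse_dir_partition(example_path, base_dir, ["class"])
print(fields)
```

With one field name, exactly one directory level is expected between the base directory and the file.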


In [None]:
sample_images = [sample["image"] for sample in ds.take(5)]

_, axs = plt.subplots(1, 5, figsize=(10, 5))

for i, image in enumerate(sample_images):
    axs[i].imshow(image)
    axs[i].axis("off")


## Preprocess the dataset

We may need to preprocess the dataset before passing it to the model.
This just amounts to writing a function that performs the preprocessing logic, and then
applying the function to the entire dataset with a call to `map_batches`.

> ✂️ Replace this function with your own data preprocessing logic.

In [None]:
def preprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    import torch
    from torchvision import transforms

    def to_tensor(batch: np.ndarray) -> torch.Tensor:
        tensor = torch.as_tensor(batch, dtype=torch.float)
        # (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.permute(0, 3, 1, 2).contiguous()
        # [0., 255.] -> [0., 1.]
        tensor = tensor.div(255)
        return tensor

    transform = transforms.Compose(
        [
            transforms.Lambda(to_tensor),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return {"image": transform(batch["image"]).numpy()}


In [None]:
ds = ds.map_batches(preprocess, batch_format="numpy")

print("Dataset schema:\n", ds.schema())
print("Number of images:", ds.count())
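
To sanity-check the shapes and value ranges without a cluster, the same steps can be reproduced on a synthetic batch. This is a NumPy-only sketch, assuming the crop offsets are computed by rounding half the size difference; `preprocess_numpy` is a hypothetical stand-in and not part of the template.

```python
import numpy as np

# Same per-channel statistics as in `preprocess`, shaped to broadcast over (B, C, H, W).
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)


def preprocess_numpy(images: np.ndarray, crop: int = 224) -> np.ndarray:
    """Layout change, scaling, center crop, and normalization in NumPy."""
    # (B, H, W, C) in [0, 255] -> (B, C, H, W) in [0, 1]
    x = images.astype(np.float32).transpose(0, 3, 1, 2) / 255.0
    _, _, height, width = x.shape
    # Center crop: keep a `crop` x `crop` window in the middle of each image.
    top = int(round((height - crop) / 2.0))
    left = int(round((width - crop) / 2.0))
    x = x[:, :, top:top + crop, left:left + crop]
    # Per-channel normalization.
    return ((x - MEAN) / STD).astype(np.float32)


# Synthetic batch of two all-black 256x256 RGB images.
fake_batch = np.zeros((2, 256, 256, 3), dtype=np.uint8)
out = preprocess_numpy(fake_batch)
print(out.shape, out.dtype)
```

The output should have the channels-first layout and 224x224 spatial size that the ResNet model expects.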


## Set up your model for inference

Define a class that loads the model on initialization, and also performs inference with the loaded model whenever the class is called (by implementing `__call__`).

> ✂️ Replace parts of this callable class with your own model initialization and inference logic.

In [None]:
class PredictCallable:
    def __init__(self):
        # <Replace this with your own model initialization>
        import torch
        from torchvision import models
        from torchvision.models import ResNet152_Weights

        self.model = models.resnet152(weights=ResNet152_Weights.DEFAULT)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # <Replace this with your own model inference logic>
        import torch

        input_data = torch.as_tensor(batch["image"], device=self.device)
        with torch.inference_mode():
            pred = self.model(input_data)
        return {"predicted_class_index": pred.argmax(dim=1).detach().cpu().numpy()}
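
The key property of this pattern is that the expensive setup in `__init__` runs once per worker, while `__call__` runs once per batch. The sketch below shows that with a stand-in model so you can try the structure locally without a GPU. `DummyPredictor`, its random weights, and the `features` column are all hypothetical.

```python
from typing import Dict

import numpy as np


class DummyPredictor:
    """Stand-in for `PredictCallable` that uses random weights instead of a ResNet."""

    setup_calls = 0  # counts how often the "model" gets initialized

    def __init__(self, num_features: int = 4, num_classes: int = 10):
        DummyPredictor.setup_calls += 1
        rng = np.random.default_rng(seed=0)
        self.weights = rng.normal(size=(num_features, num_classes))

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # One matrix multiply per batch, then pick the highest-scoring class.
        logits = batch["features"] @ self.weights
        return {"predicted_class_index": logits.argmax(axis=1)}


predictor = DummyPredictor()  # setup happens once
batches = [{"features": np.ones((3, 4))}, {"features": np.zeros((2, 4))}]
outputs = [predictor(batch) for batch in batches]  # reused for every batch
print(DummyPredictor.setup_calls, [o["predicted_class_index"].shape for o in outputs])
```

Each output keeps one prediction per input row, which is the contract `map_batches` relies on.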


## Run batch inference

We'll first configure the number of workers and the resource requirements of each worker.

These defaults assume that your cluster has 4 GPUs available.
Be sure to stay within the resource constraints of your Ray Cluster if autoscaling is not enabled.

`NUM_GPUS_PER_WORKER` can be a fractional amount! This will leverage Ray's fractional resource allocation, which means you can schedule multiple batch inference workers to use the same GPU, assuming that the models can all fit in GPU memory.

In [None]:
NUM_WORKERS: int = 4
NUM_GPUS_PER_WORKER: float = 1  # 0 <= NUM_GPUS_PER_WORKER <= 1
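
To reason about fractional settings before launching anything, a small back-of-the-envelope calculation can help. This is a sketch with hypothetical helper names, assuming each worker's share has to fit on a single GPU (so a share of 0.3 allows 3 workers per GPU, not 3.33).

```python
import math


def gpus_required(num_workers: int, gpus_per_worker: float) -> float:
    """Total GPUs that the worker pool asks for."""
    return num_workers * gpus_per_worker


def max_workers(total_gpus: int, gpus_per_worker: float) -> int:
    """Largest number of workers that fit, assuming shares cannot span GPUs."""
    if not 0 < gpus_per_worker <= 1:
        raise ValueError("gpus_per_worker must be in the range (0, 1]")
    # Small epsilon guards against floating-point results such as 3.9999999.
    workers_per_gpu = math.floor(1 / gpus_per_worker + 1e-9)
    return total_gpus * workers_per_gpu


print(gpus_required(4, 1))     # the template defaults
print(gpus_required(8, 0.5))   # twice the workers on the same number of GPUs
print(max_workers(4, 0.25))    # upper bound on workers for a 4-GPU cluster
```

Keep in mind that this only accounts for scheduling; every model copy sharing a GPU still has to fit in that GPU's memory.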


In [None]:
# `.get` avoids a KeyError when the cluster currently reports no GPU resource.
if NUM_WORKERS * NUM_GPUS_PER_WORKER > ray.available_resources().get("GPU", 0):
    print(
        "Your cluster does not currently have enough resources to run with these settings. "
        "Consider decreasing the number of workers, or decreasing the resources needed "
        "per worker."
    )

assert (
    0 <= NUM_GPUS_PER_WORKER <= 1
), "`NUM_GPUS_PER_WORKER` must be within the range [0, 1]"


You can check the available resources in your Ray Cluster with:

In [None]:
!ray status

Now, use Ray Data to perform batch inference using `NUM_WORKERS` copies of the `PredictCallable` class you defined.

In [None]:
predictions = ds.map_batches(
    PredictCallable,
    batch_size=128,
    compute=ray.data.ActorPoolStrategy(
        # Fix the number of batch inference workers to `NUM_WORKERS`.
        min_size=NUM_WORKERS,
        max_size=NUM_WORKERS,
    ),
    num_gpus=NUM_GPUS_PER_WORKER,
    batch_format="numpy",
)

preds = predictions.materialize()


See the appendix for more information about setting `min_size` and `max_size`.
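
As a rough mental model for `batch_size=128`, each call to `PredictCallable` receives up to 128 rows. The sketch below shows the idealized arithmetic for a hypothetical dataset of 1,000 rows; `batch_ranges` is an illustrative helper, and the real batch boundaries in Ray Data may differ because batches are formed from the dataset's underlying blocks.

```python
import math
from typing import Iterator, Tuple


def batch_ranges(num_rows: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) row ranges; the final batch may be smaller."""
    for start in range(0, num_rows, batch_size):
        yield start, min(start + batch_size, num_rows)


# Hypothetical row count, chosen only for illustration.
ranges = list(batch_ranges(1000, 128))
print(len(ranges), math.ceil(1000 / 128), ranges[-1])
```

A larger batch size means fewer calls per worker but more GPU memory per call, so it is worth tuning together with `NUM_GPUS_PER_WORKER`.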

## View the predictions

Show the first few predictions, which are the predicted class indices for the images shown earlier. They should all be index 0, which maps to the class label `"tench"` (a type of fish).

In [None]:
preds.take(5)
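
To make the output easier to read, the indices can be mapped to label names. The sketch below is deliberately partial: `KNOWN_LABELS` is a hypothetical lookup that contains only the one mapping mentioned above, and unknown indices fall back to a placeholder string. The example rows are hand-written in the same shape as the rows returned by `take`.

```python
from typing import Any, Dict, List

# Only the mapping stated in this template; extend it with your own label table.
KNOWN_LABELS: Dict[int, str] = {0: "tench"}


def add_labels(rows: List[Dict[str, Any]], labels: Dict[int, str]) -> List[Dict[str, Any]]:
    """Return copies of the rows with a human-readable `predicted_label` field."""
    labeled = []
    for row in rows:
        index = int(row["predicted_class_index"])
        labeled.append({**row, "predicted_label": labels.get(index, f"class_{index}")})
    return labeled


# Hand-written example rows, not real model output.
example_rows = [{"predicted_class_index": 0}, {"predicted_class_index": 217}]
labeled_rows = add_labels(example_rows, KNOWN_LABELS)
print(labeled_rows)
```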


Shard the predictions into a few partitions, and save each partition to a file.

This currently saves to the local filesystem under a temporary directory, but you could also save to a cloud bucket (e.g., `s3://predictions-bucket`).

In [None]:
num_shards = 3

temp_dir = tempfile.mkdtemp()
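# Reuse the materialized results instead of the lazy `predictions` dataset,
# so that inference is not executed a second time.
preds.repartition(num_shards).write_parquet(temp_dir)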
print(f"Predictions saved to `{temp_dir}`!")


## Summary

This template used [Ray Data](https://docs.ray.io/en/latest/data/dataset.html) to scale out batch inference. Ray Data is one of many libraries under the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html). See [this blog post](https://www.anyscale.com/blog/model-batch-inference-in-ray-actors-actorpool-and-datasets) for more details on batch inference with Ray!

At a high level, this template showed how to:
1. [Load your dataset using Ray Data.](https://docs.ray.io/en/latest/data/loading-data.html)
2. [Preprocess your dataset before feeding it to your model.](https://docs.ray.io/en/latest/data/transforming-data.html)
3. [Initialize your model and perform inference on a shard of your dataset with a remote actor.](https://docs.ray.io/en/latest/data/transforming-data.html#reduce-setup-overheads-using-actors)
4. [Save your prediction results.](https://docs.ray.io/en/latest/data/api/input_output.html)



## Appendix

### Automatically determine the number of workers

Play around with the `min_size` and `max_size` parameters to enable Ray Data to scale the number of workers based on the dataset size.
For example, try commenting out `max_size`. This will start with `min_size` workers, then spin up as many new workers as needed (until the cluster runs out of resources to assign).
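
One way to switch between a fixed and an autoscaling pool without commenting code in and out is to build the keyword arguments conditionally. This is a small sketch; `actor_pool_kwargs` is a hypothetical helper, and it only uses the `min_size` and `max_size` parameters already shown in this template.

```python
from typing import Dict, Optional


def actor_pool_kwargs(min_size: int, max_size: Optional[int] = None) -> Dict[str, int]:
    """Build keyword arguments for the actor pool; omit `max_size` to leave it unbounded."""
    if min_size < 1:
        raise ValueError("min_size must be at least 1")
    if max_size is not None and max_size < min_size:
        raise ValueError("max_size must be greater than or equal to min_size")
    kwargs = {"min_size": min_size}
    if max_size is not None:
        kwargs["max_size"] = max_size
    return kwargs


print(actor_pool_kwargs(4, 4))  # fixed-size pool, as in the main example
print(actor_pool_kwargs(2))     # starts with 2 workers and may scale up
```

The result could then be passed along as `compute=ray.data.ActorPoolStrategy(**actor_pool_kwargs(2))`.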