In [1]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Sequence, TypeAlias, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

# Rich console with named styles ("info", "warning", "error") that can be
# referenced by name via `style=` in console.print(...).
_theme_colors = {
    "info": "#76FF7B",
    "warning": "#FBDDFE",
    "error": "#FF0000",
}
custom_theme = Theme(_theme_colors)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

warnings.filterwarnings("ignore")


# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [2]:
import time


def load_model(payload_size: int = 10, latency_s: float = 2.0):
    """Build a dummy "model" used to demonstrate passing models to Ray tasks.

    Args:
        payload_size: Length of the zero-filled NumPy array attached as
            ``model.payload`` so that serializing the model ships some data.
        latency_s: Seconds each call sleeps, to simulate inference latency.

    Returns:
        A callable ``model(batch)``. It takes a DataFrame with a
        ``passenger_count`` column and returns a DataFrame with a boolean
        ``default`` column that is True where the count is even.
    """

    def model(batch: pd.DataFrame) -> pd.DataFrame:
        time.sleep(latency_s)
        return pd.DataFrame({"default": batch["passenger_count"] % 2 == 0})

    # Attach the dummy payload at load time so it already exists when the model
    # is serialized (e.g. by `ray.put`). If it were assigned inside `model`, it
    # would only appear after the first call, and no payload data would be
    # copied across nodes.
    model.payload = np.zeros(payload_size)
    return model

In [3]:
# Load data: two synthetic frames with a single `passenger_count` column.
# The global NumPy seed is set once; both draws below consume the same stream,
# so their order must be kept to reproduce the values shown.
np.random.seed(42)


def _passenger_counts(low: int, high: int, n_rows: int) -> pd.DataFrame:
    """Draw `n_rows` integers uniformly from [low, high) into a one-column frame."""
    counts = np.random.choice(np.arange(low, high), size=n_rows)
    return pd.DataFrame({"passenger_count": counts})


input_df_1: pd.DataFrame = _passenger_counts(1, 10, 50_000)
input_df_2: pd.DataFrame = _passenger_counts(2, 9, 55_000)
input_df_1.head()

Unnamed: 0,passenger_count
0,7
1,4
2,8
3,5
4,7


In [4]:
import ray


@ray.remote(num_cpus=2, max_retries=2)
def make_prediction(model, data: pd.DataFrame):
    """Ray task: run `model` on `data` and return only the output's shape.

    Returning the ``.shape`` tuple instead of the prediction frame keeps the
    object sent back to the driver small. Each task reserves 2 CPUs and is
    retried up to 2 times on system failures (e.g. a worker crash).
    """
    return model(data).shape

In [5]:
# ray.put() the model just once into the object store, then pass the same
# reference to every task. Passing the model itself, i.e.
# make_prediction.remote(model, df), would ray.put(model) again for each task,
# which can overwhelm the local object store and cause out-of-disk errors.
model = load_model()
model_ref = ray.put(model)

# Launch one prediction task per input frame. `.remote` returns an ObjectRef
# immediately. The comprehension keeps the per-frame name local instead of
# leaking a DataFrame into the global `data` that later cells reuse.
result_refs = [
    make_prediction.remote(model_ref, batch_df)
    for batch_df in (input_df_1, input_df_2)
]

# Block until every task has finished; results come back in launch order.
results = ray.get(result_refs)

2024-04-16 00:36:53,635	INFO worker.py:1743 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265


In [6]:
# Show the shape tuple each prediction task returned, in launch order.
for r in results:
    line = f"Predictions: {r}"
    console.print(line, style="bold red")

In [7]:
# Connect to the local Ray instance. Ray was already started implicitly by the
# first ray.put() call above (see that cell's log output), so
# ignore_reinit_error=True makes this return the existing context instead of
# raising an error.
ray.init(ignore_reinit_error=True)

2024-04-16 00:37:01,313	INFO worker.py:1585 -- Calling ray.init() again after it has already been called.


0,1
Python version:,3.11.8
Ray version:,2.10.0
Dashboard:,http://127.0.0.1:8265


In [17]:
# Toy "database": a list of words looked up by integer position.
database: list[str] = [
    "Learning",
    "Ray",
    "Flexible",
    "Distributed",
    "Python",
    "for",
    "Machine",
    "Learning",
]


def retrieve(item: int):
    """Simulate a slow lookup of ``database[item]``.

    Sleeps ``item / 10`` seconds (so higher indices are slower), then returns
    the ``(item, value)`` pair.
    """
    delay_s = item / 10.0
    time.sleep(delay_s)
    value = database[item]
    return (item, value)


def print_runtime(input_data: list[Any], start_time: float):
    """Print the time elapsed since ``start_time``, then each item on its own line.

    Args:
        input_data: Items to print, one per line.
        start_time: Timestamp from ``time.time()`` taken before the work began.
            (The previous annotation, ``time.time``, named a function, not a type.)
    """
    elapsed = time.time() - start_time
    print(f"Runtime: {elapsed:.2f} seconds, \ndata:")
    print(*input_data, sep="\n")

In [18]:
start = time.time()
# Sequential baseline: each lookup waits for the previous one, so the runtime
# is roughly the sum of all eight simulated delays.
data: list[tuple[int, str]] = list(map(retrieve, range(8)))
print_runtime(data, start)

Runtime: 2.83 seconds, 
data:
(0, 'Learning')
(1, 'Ray')
(2, 'Flexible')
(3, 'Distributed')
(4, 'Python')
(5, 'for')
(6, 'Machine')
(7, 'Learning')


In [19]:
# Parallelize: wrap the existing `retrieve` as a Ray task.
@ray.remote
def retrieve_task(item: int):
    """Ray task wrapper around `retrieve`; returns its ``(item, value)`` tuple.

    NOTE: a later cell redefines `retrieve_task` with an ``(item, db)``
    signature, and that definition shadows this one from then on.
    """
    looked_up = retrieve(item)
    return looked_up

In [20]:
start = time.time()
# Fan out: every `.remote` call returns an ObjectRef immediately, so the eight
# lookups run concurrently and ray.get blocks until all of them are done.
object_references = list(map(retrieve_task.remote, range(8)))
data = ray.get(object_references)
print_runtime(data, start)

Runtime: 0.71 seconds, 
data:
(0, 'Learning')
(1, 'Ray')
(2, 'Flexible')
(3, 'Distributed')
(4, 'Python')
(5, 'for')
(6, 'Machine')
(7, 'Learning')


### Object Stores

- The retrieve definition directly accesses items from the database. While this works well on a local Ray cluster, consider how it functions on an actual cluster with multiple computers.
- A `Ray cluster` has:
  -  a `head node` with `a driver process` 
  -  and `multiple worker nodes` with `worker processes` executing tasks.
- In this scenario the database is only defined on the driver, but the worker processes need access to it to run the retrieve task.
- Ray’s solution for sharing objects between the driver and workers or between workers is to use the `ray.put` function to place the data into `Ray’s distributed object store`.
- Add a `db` argument to the `retrieve_task` definition, and pass `db_object_ref` for it when launching tasks; Ray resolves the reference to the stored list before the task runs.


In [21]:
# Store the database once in Ray's object store; tasks receive this reference.
db_object_ref = ray.put(database)


@ray.remote
def retrieve_task(item, db):
    """Ray task: simulate a slow lookup of ``db[item]``.

    Takes the database as an argument instead of reading the driver-side
    global, so it can run on workers that never defined `database`. When an
    ObjectRef is passed as a top-level argument, Ray resolves it to the stored
    value before the task runs.

    NOTE: this redefines (and shadows) the earlier single-argument `retrieve_task`.
    """
    pause_s = item / 10.0
    time.sleep(pause_s)
    value = db[item]
    return item, value

- By using the object store, you allow Ray to manage data access throughout the entire cluster.
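- Putting data into the object store adds some overhead, but it pays off for larger datasets: the object is stored once and every task reads it from the store, instead of the data being serialized and shipped with each call.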
- This step is crucial for a truly distributed environment.
- Rerun the example with the retrieve_task function to confirm that it executes as you expect.

In [22]:
start = time.time()
# Every task receives the same object-store reference to the database rather
# than its own copy of the list.
object_references: list[Any] = list(
    map(retrieve_task.remote, range(8), [db_object_ref] * 8)
)
data: list[Any] = ray.get(object_references)
print_runtime(data, start)

Runtime: 0.71 seconds, 
data:
(0, 'Learning')
(1, 'Ray')
(2, 'Flexible')
(3, 'Distributed')
(4, 'Python')
(5, 'for')
(6, 'Machine')
(7, 'Learning')


## Non-blocking Calls


- In the previous section, you used `ray.get(object_references)` to retrieve results.
- This call blocks the driver process until all results are available. This dependency can cause problems if each database item takes several minutes to process.
- More efficiency gains are possible if you allow the driver process to perform other tasks while waiting for results, and to process results as they are completed rather than waiting for all items to finish.
- Additionally, if one of the database items cannot be retrieved due to an issue like a deadlock in the database connection, the driver hangs indefinitely.
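- To prevent indefinite hangs, call `ray.wait` with a reasonable `timeout`: it returns after at most that many seconds, together with whichever tasks have finished so far.
- For example, the slowest simulated retrieval takes 0.7 seconds (`item / 10.0` for item 7), so the cell below waits at most ten times that (`timeout=7.0`). Note that `ray.wait` only stops *waiting*; it does not stop unfinished tasks. Cancel them explicitly (e.g. with `ray.cancel`) if they should not keep running.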



In [24]:
start = time.time()
object_references: list[Any] = [
    retrieve_task.remote(item, db_object_ref) for item in range(8)
]
all_data: list[Any] = []
# Upper bound on how long to wait for the *next* result: ~10x the slowest
# simulated retrieval (item 7 sleeps 0.7 s).
wait_timeout_s = 7.0

while object_references:
    # Returns as soon as one task finishes (num_returns defaults to 1), or after
    # `wait_timeout_s` seconds with an empty `finished` list.
    finished, object_references = ray.wait(object_references, timeout=wait_timeout_s)
    if not finished:
        # No progress within the timeout. Without this check the loop would call
        # ray.wait again forever, so cancel the stragglers and fail loudly.
        for ref in object_references:
            ray.cancel(ref)
        raise TimeoutError(
            f"{len(object_references)} retrieval task(s) made no progress "
            f"within {wait_timeout_s} s; cancelled them."
        )
    # Process results as they arrive instead of waiting for the whole batch.
    all_data.extend(ray.get(finished))

print_runtime(all_data, start)

Runtime: 0.71 seconds, 
data:
(0, 'Learning')
(1, 'Ray')
(2, 'Flexible')
(3, 'Distributed')
(4, 'Python')
(5, 'for')
(6, 'Machine')
(7, 'Learning')
