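# Online, parallel means

This notebook computes the element-wise mean of 10,000 random 64×64 arrays in parallel with raygent's `TaskRunner` and `OnlineMeanResultsHandler`. Each task returns a per-batch partial mean plus its observation count. The combined result is then checked against NumPy's `np.mean`.

TODO: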

In [1]:
from typing import override

import numpy as np
import numpy.typing as npt

In [2]:
# Reproducible synthetic input: N_OBSERVATIONS samples of an N_ROWS x N_COLS grid,
# every element drawn uniformly from the half-open interval [1, 10_000).
SEED = 8947501
N_OBSERVATIONS, N_ROWS, N_COLS = 10_000, 64, 64

# Legacy global seeding is kept on purpose so the generated values (and the
# outputs recorded below) stay exactly reproducible.
np.random.seed(SEED)
data = np.random.uniform(
    low=1,
    high=10_000,
    size=(N_OBSERVATIONS, N_ROWS, N_COLS),
)

In [3]:
# Reference answer: NumPy's element-wise mean over the observation axis (axis 0),
# used below to validate the parallel result.
mean_true = data.mean(axis=0)
print(mean_true)

[[4979.06187109 5042.03419628 5007.30759842 ... 5027.53444031
  5027.39917602 4988.98437983]
 [5004.40211044 4970.12877011 4974.80460791 ... 4987.42125333
  5020.60978713 4974.83743075]
 [5015.70418641 4964.27107736 4989.80859755 ... 4977.72149064
  5010.71360439 5037.07001349]
 ...
 [4959.73941289 4994.51878591 4988.9707429  ... 5038.30511538
  5065.64332467 4979.81483695]
 [5013.08498048 4979.70225719 4998.08322856 ... 5004.07146257
  5059.87903943 5005.13015452]
 [5027.95742434 5004.89616586 5021.97812961 ... 5015.61308805
  5004.63068546 4987.04909058]]


In [None]:
from raygent.task import Task
from raygent.results import MeanResult


class MeanTask(Task[MeanResult[npt.NDArray[np.float64]]]):
    """
    A task that computes the element-wise partial mean of a batch of NumPy arrays.

    Each call to `do` reduces one batch over its first axis (the observation axis)
    and returns a `MeanResult` holding:

    - `value`: the element-wise mean of the batch as a float64 array whose shape is
      that of a single observation (e.g. 2D when the batch stacks 2D arrays);
    - `count`: the number of observations in the batch.

    Keeping the count alongside each partial mean is what allows partial means from
    different batches to be merged into an overall mean by count-weighting.
    """

    @override
    def do(
        self, batch: npt.NDArray[np.float64], *args: object, **kwargs: object
    ) -> MeanResult[npt.NDArray[np.float64]]:
        """
        Compute the element-wise mean and observation count of one batch.

        Args:
            batch: Observations stacked along axis 0. Anything `np.asarray` can turn
                into a float64 array is accepted (e.g. a list of equally shaped 2D
                arrays).
            *args: Unused by this task.
            **kwargs: Unused by this task.

        Returns:
            A `MeanResult` whose `value` is the mean of `batch` over axis 0 and whose
            `count` is the length of `batch` along axis 0. An empty batch yields a
            zero-filled `value` with `count=0` rather than a NaN-filled mean.
        """
        # asarray avoids an extra copy when the batch is already a float64 array;
        # the data is only read, never modified, so sharing memory is safe.
        arr = np.asarray(batch, dtype=np.float64)
        # The count is the number of observations (length along axis 0).
        count = arr.shape[0]
        if count == 0:
            # np.mean over an empty axis warns and returns NaN. A NaN partial mean
            # would poison any count-weighted combination even with zero weight
            # (NaN * 0 is NaN), so report zeros with a zero count instead.
            return MeanResult(
                value=np.zeros(arr.shape[1:], dtype=np.float64), count=0
            )
        # Element-wise mean across all observations in this batch.
        partial_mean = np.mean(arr, axis=0)
        return MeanResult(value=partial_mean, count=count)

In [5]:
from raygent.runner import TaskRunner


from raygent.results.handlers import OnlineMeanResultsHandler

# Shorthand for the array type that flows through the runner and handler.
FloatArray = npt.NDArray[np.float64]

# Run MeanTask over `data` in parallel (n_cores=8 total, n_cores_worker=1 per
# worker); OnlineMeanResultsHandler combines the per-batch MeanResults.
runner = TaskRunner[FloatArray, OnlineMeanResultsHandler[FloatArray]](
    task_cls=MeanTask,
    handler_cls=OnlineMeanResultsHandler,
    in_parallel=True,
    n_cores=8,
    n_cores_worker=1,
)

# Each task receives 50 observations; `get()` returns the combined result.
handler = runner.submit_tasks(data, batch_size=50)
combined = handler.get()
mean_parallel = combined.value

print("Element-wise Mean:")
print(mean_parallel)

2025-07-15 22:41:45,487	INFO worker.py:1888 -- Started a local Ray instance.


Element-wise Mean:
[[4979.06187109 5042.03419628 5007.30759842 ... 5027.53444031
  5027.39917602 4988.98437983]
 [5004.40211044 4970.12877011 4974.80460791 ... 4987.42125333
  5020.60978713 4974.83743075]
 [5015.70418641 4964.27107736 4989.80859755 ... 4977.72149064
  5010.71360439 5037.07001349]
 ...
 [4959.73941289 4994.51878591 4988.9707429  ... 5038.30511538
  5065.64332467 4979.81483695]
 [5013.08498048 4979.70225719 4998.08322856 ... 5004.07146257
  5059.87903943 5005.13015452]
 [5027.95742434 5004.89616586 5021.97812961 ... 5015.61308805
  5004.63068546 4987.04909058]]


In [6]:
# Compare the parallel result with NumPy's reference mean. A signed mean of the
# differences lets positive and negative errors cancel, so also report the
# worst-case absolute error and fail loudly if the results disagree.
error = mean_parallel - mean_true
print("Parallel versus NumPy error")
print(f"mean signed error:  {np.mean(error):.3e}")
print(f"max absolute error: {np.max(np.abs(error)):.3e}")

# Tolerance is relative: the means are ~5e3, and float64 round-off from
# combining partial means should stay many orders of magnitude below rtol.
np.testing.assert_allclose(mean_parallel, mean_true, rtol=1e-9, atol=0)

Parallel versus NumPy error
5.517808432387028e-13
