In [1]:
import ray
import time

In [2]:
# An ordinary (local) Python function, shown for comparison with the remote version below.
def normal_function():
    """Return the constant 1."""
    answer = 1
    return answer


# Decorating a plain Python function with `@ray.remote` turns it into a Ray
# remote function, which is invoked with `.remote()` rather than called directly.
@ray.remote
def my_function():
    """Remote counterpart of `normal_function`: always returns 1."""
    one = 1
    return one

In [3]:
# Start Ray explicitly instead of relying on the implicit auto-initialisation
# triggered by the first `.remote()` call. This pairs with the `ray.shutdown()`
# at the end of the notebook, and the guard keeps the cell safe to re-run.
if not ray.is_initialized():
    ray.init()

# To invoke this remote function, use the `remote` method.
# This immediately returns an object ref (a future) and schedules a task
# that will be executed on a worker process.
obj_ref = my_function.remote()

# The result can be retrieved with `ray.get`, which blocks until the task finishes.
assert ray.get(obj_ref) == 1

In [4]:
@ray.remote
def slow_function(seconds=10):
    """Simulate a slow task: sleep for `seconds` (default 10), then return 1.

    The duration is a parameter rather than a hard-coded constant, so the
    parallelism demo can be shortened without editing the function body.
    Calling `slow_function.remote()` with no arguments behaves as before.
    """
    time.sleep(seconds)
    return 1

In [5]:
# Invocations of Ray remote functions run in parallel: each `.remote()` call
# returns an object ref immediately while the task executes on a worker process.
# Keep the refs so the tasks can later be waited on (`ray.get(slow_refs)`) or
# cancelled (`ray.cancel(ref)`); discarding them leaves no handle on the work.
slow_refs = [slow_function.remote() for _ in range(4)]  # This doesn't block.

In [6]:
@ray.remote(num_gpus=0.5)
def h():
    """Trivial task that requests half a GPU from Ray's scheduler and returns 1.

    `num_gpus=0.5` is a fractional resource request, so two of these tasks
    can be scheduled onto the same GPU at once.
    """
    result = 1
    return result

In [7]:
# Submit `h`; this returns an object ref immediately.
# NOTE(review): `h` requests 0.5 GPU, so it presumably can only be scheduled when
# the Ray cluster exposes a GPU resource — confirm on CPU-only machines.
res = h.remote()

In [8]:
# Block until `h` finishes and display its return value.
ray.get(res)

1

In [9]:
@ray.remote(num_gpus=0.1)
def has_gpu():
    """Report whether JAX inside this Ray worker runs on a GPU backend.

    Prints JAX's default backend platform (e.g. "gpu" or "cpu"), so it appears
    in the driver's log stream. It then runs a tiny array computation as a smoke
    test of the device.

    Returns:
        True if JAX's default backend platform is "gpu", otherwise False.
    """
    import os

    # Each task reserves only 0.1 GPU, so several of them may share one device.
    # By default JAX preallocates a large fraction of GPU memory per process,
    # which can exhaust the device when it is shared. Disable preallocation
    # unless the user has configured it. This only takes effect if JAX has not
    # already initialised its backend in this (possibly reused) worker.
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

    # Import inside the task so JAX is initialised in the worker process,
    # not in the driver.
    import jax
    import jax.numpy as jnp

    # `jax.lib.xla_bridge.get_backend()` is deprecated; `jax.default_backend()`
    # is the public API that returns the same platform string.
    platform = jax.default_backend()
    print(platform)

    # Small smoke-test allocation and computation on the default device;
    # block so that any device error surfaces inside this task.
    jnp.ones((10, 10)).block_until_ready()
    return platform == "gpu"

In [10]:
# Launch three concurrent GPU probes. Each requests 0.1 GPU, so they can be
# scheduled onto the same device. Their `print` output is streamed back to
# this notebook, prefixed with the task name and worker pid.
jax_objs = [has_gpu.remote() for _ in range(3)]

[2m[36m(has_gpu pid=206541)[0m gpu
[2m[36m(has_gpu pid=206556)[0m gpu
[2m[36m(has_gpu pid=206547)[0m gpu


In [12]:
# Wait for all probes with a single batched `ray.get` (the idiomatic form, rather
# than one blocking call per ref), and keep the results instead of discarding them.
gpu_results = ray.get(jax_objs)

In [13]:
# Shut down the Ray runtime used by this notebook so its worker processes are released.
ray.shutdown()