In [1]:
import jax
import jax.numpy as jnp

# jnp.arange produces a JAX array; the isinstance check (the cell's displayed
# result) confirms it is an instance of jax.Array.
x: jax.Array
x = jnp.arange(0, 5)
isinstance(x, jax.Array)

True

In [2]:
# The set of devices holding x's data; in this run, a single CUDA device.
x.devices()

{CudaDevice(id=0)}

In [None]:
# Here the array lives on a single device, but in general a JAX array can be sharded
# across multiple devices, or even multiple hosts.
# For details, see the "Introduction to parallel programming" tutorial in the JAX documentation.
# Describes how x's data is placed on devices; here a SingleDeviceSharding.
x.sharding

SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)

In [None]:
# Tracers
# Because f below is jit-compiled, the value it prints is not the array x itself but a
# Tracer instance that represents essential attributes of x, such as its shape and dtype.
# By executing the function with traced values, JAX can determine the sequence of
# operations encoded by the function before those operations are actually executed.
# Transformations like jit(), vmap(), and grad() can then map this sequence of input
# operations to a transformed sequence of operations.



def f(x):
  """Return ``x + 1``, printing the argument that f receives.

  Because f is wrapped with jax.jit, the ``print`` runs while the function is
  being traced, so it shows the abstract Tracer standing in for ``x`` (shape
  and dtype only) rather than the array's concrete values.
  """
  print(x)  # shows the tracer, not the array's contents
  incremented = x + 1
  return incremented


# Same as decorating f with @jax.jit.
f = jax.jit(f)

x = jnp.arange(5)
result = f(x)

Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>
