<a href="https://colab.research.google.com/github/mridul-sahu/jax-sharding-tutorials/blob/main/Project_Aurora_Mastering_JAX_Arrays.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Project Aurora: Mastering JAX Arrays - The Lifeblood of Colossal AI

Welcome, Architect, to a foundational briefing for Project Aurora. Our ambition is to construct AI models of unprecedented scale and intelligence. To do this, we must first master the very essence of data within our chosen framework, JAX. It's not enough to just have numbers; we need to understand how JAX represents, manages, and places data across Aurora's vast computational hardware. This is where JAX Arrays come into play – they are far more than simple containers; they are the lifeblood of our models.

## The Anatomy of a JAX Array: More Than Just Numbers

In Project Aurora, precision and efficiency are paramount. A JAX array, represented by `jax.Array`, is not merely a collection of numbers like a standard Python list or even a NumPy array. It's a sophisticated entity designed for high-performance numerical computation, especially on accelerators like GPUs and TPUs.
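As a quick check of what "represented by `jax.Array`" means in practice, here is a minimal sketch. The variable names are illustrative only. It shows that `jax.Array` works as an `isinstance` target, while a NumPy array is not a `jax.Array` even though both report a shape and a dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np

host_arr = np.ones((2, 2), dtype=np.float32)   # plain NumPy array in host RAM
dev_arr = jnp.ones((2, 2), dtype=jnp.float32)  # JAX array on the default device

# jax.Array is the public array type; the concrete class that type()
# prints (e.g. ArrayImpl) is an implementation detail.
print(isinstance(dev_arr, jax.Array))   # True
print(isinstance(host_arr, jax.Array))  # False

# Both carry shape/dtype metadata...
print(dev_arr.shape, dev_arr.dtype)
# ...but only the JAX array can report which device(s) hold its data.
print(dev_arr.devices())
```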

**Key Characteristics of JAX Arrays:**

1.  **Immutability:** Once a JAX array is created, its values cannot be changed in place. Any operation that appears to modify an array actually returns a *new* JAX array. This functional purity is a cornerstone of JAX, enabling cleaner code, easier reasoning about transformations, and powerful compiler optimizations. For Aurora, this means predictability and the ability to safely parallelize operations.

2.  **Device Affinity:** This is where JAX arrays truly diverge from their host-bound cousins. A JAX array has a "home": its data resides physically on one or more computational **devices** (a CPU, a GPU, or a TPU). This matters because computations involving the array run where its data lives, which avoids costly data transfers. (Older JAX releases exposed a separate `DeviceArray` type; current releases use the single `jax.Array` type, which is why `type()` reports `ArrayImpl` in the output below.)
    * A NumPy array, by contrast, lives in the host computer's RAM, accessible by the CPU.

3.  **Asynchronous Execution:** Many operations that produce JAX arrays, especially on accelerators, are dispatched asynchronously. JAX tells the device to start the work and hands control back to Python immediately, without waiting for the result. This lets the host keep queueing work while the device computes, which keeps Aurora's systems responsive and efficient. We'll touch on how to manage this later.

Let's see these concepts in action.

```python
import jax
import jax.numpy as jnp
import numpy as np

print(f"JAX version: {jax.__version__}")
print(f"Default JAX backend: {jax.default_backend()}")

# A standard NumPy array (lives in Host RAM)
numpy_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(f"\n--- NumPy Array ---")
print(f"NumPy array: {numpy_arr}")
print(f"Type of NumPy array: {type(numpy_arr)}")

# Creating a JAX array from a NumPy array
# jnp.array() will typically place it on the default JAX device
jax_arr_from_numpy = jnp.array(numpy_arr)
print(f"\n--- JAX Array (from NumPy) ---")
print(f"JAX array: {jax_arr_from_numpy}")
print(f"Type of JAX array: {type(jax_arr_from_numpy)}")
print(f"Device of JAX array: {jax_arr_from_numpy.device}")
print(f"Shape of JAX array: {jax_arr_from_numpy.shape}")
print(f"Dtype of JAX array: {jax_arr_from_numpy.dtype}")

# Creating a JAX array directly from a Python list
python_list = [[5.0, 6.0], [7.0, 8.0]]
jax_arr_from_list = jnp.array(python_list)
print(f"\n--- JAX Array (from Python List) ---")
print(f"JAX array: {jax_arr_from_list}")
print(f"Device of JAX array: {jax_arr_from_list.device}")

# Immutability in action
original_jax_arr = jnp.array([10, 20, 30])
print(f"\n--- Immutability ---")
print(f"Original JAX array ({id(original_jax_arr)}): {original_jax_arr}")
modified_jax_arr = original_jax_arr.at[0].set(100)  # This creates a NEW array
print(f"Supposedly modified JAX array ({id(modified_jax_arr)}): {modified_jax_arr}")
print(f"Original JAX array after .at[0].set(100) ({id(original_jax_arr)}): {original_jax_arr}")
# Note the different IDs and that 'original_jax_arr' is unchanged.
```

```text
JAX version: 0.5.2
Default JAX backend: gpu

--- NumPy Array ---
NumPy array: [1. 2. 3. 4.]
Type of NumPy array: <class 'numpy.ndarray'>

--- JAX Array (from NumPy) ---
JAX array: [1. 2. 3. 4.]
Type of JAX array: <class 'jaxlib.xla_extension.ArrayImpl'>
Device of JAX array: cuda:0
Shape of JAX array: (4,)
Dtype of JAX array: float32

--- JAX Array (from Python List) ---
JAX array: [[5. 6.]
 [7. 8.]]
Device of JAX array: cuda:0

--- Immutability ---
Original JAX array (766908064): [10 20 30]
Supposedly modified JAX array (775258032): [100  20  30]
Original JAX array after .at[0].set(100) (766908064): [10 20 30]
```


As Aurora's architects, understanding that our JAX arrays are device-aware and immutable is the first step to wielding them effectively. The `device` attribute tells us exactly which piece of Aurora's hardware is responsible for this specific piece of data.
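Immutability is enforced, not just a convention. This minimal sketch uses a hypothetical `weights` array. It shows that ordinary item assignment raises a `TypeError`, and that the `.at[...]` family (`set`, `add`, `multiply`) always hands back a fresh array while the original stays untouched.

```python
import jax.numpy as jnp

weights = jnp.array([1.0, 2.0, 3.0])  # hypothetical parameter vector

# Direct item assignment is rejected because jax.Array is immutable.
rejected = False
try:
    weights[0] = 10.0
except TypeError as err:
    rejected = True
    print("In-place assignment rejected:", type(err).__name__)

# Each .at[...] update returns a NEW array; `weights` itself never changes.
set_w = weights.at[0].set(10.0)      # [10.,  2.,  3.]
add_w = weights.at[1].add(5.0)       # [ 1.,  7.,  3.]
mul_w = weights.at[2].multiply(2.0)  # [ 1.,  2.,  6.]
print(weights, set_w, add_w, mul_w)

# Every result carries its own placement information (a set of devices).
print(set_w.devices())
```

Inside `jax.jit`, XLA can often turn these out-of-place updates into in-place buffer writes, so the functional style does not automatically mean an extra copy.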

## Discovering Aurora's Hardware: Where JAX Arrays Reside

To intelligently place and manage our JAX arrays, we first need a map of Aurora's computational landscape. JAX provides tools to discover the available devices.

```python
print(f"\n--- Discovering Devices ---")
# List all available JAX devices (CPUs, GPUs, TPUs)
all_devices = jax.devices()
print(f"All available JAX devices: {all_devices}")
print(f"Total number of JAX devices: {jax.device_count()}")

# List devices local to the current JAX process
# In a typical single-process setup, this is the same as jax.devices()
local_devices = jax.local_devices()
print(f"Local JAX devices: {local_devices}")
print(f"Number of local JAX devices: {jax.local_device_count()}")

# You can also query devices for a specific backend.
# jax.devices() raises RuntimeError if the requested backend is unavailable.
try:
    cpu_devices = jax.devices("cpu")
    print(f"CPU devices: {cpu_devices}")
except RuntimeError:
    cpu_devices = []
    print("CPU backend not explicitly found or no CPU devices listed this way.")

try:
    gpu_devices = jax.devices("gpu")
    print(f"GPU devices: {gpu_devices}")
except RuntimeError:
    gpu_devices = []
    print("GPU backend not found or no GPU devices available.")

# The default device is typically the first one in the local_devices list
if local_devices:
    default_device_of_arr = jax_arr_from_list.device
    print(f"The jax_arr_from_list is on device: {default_device_of_arr}")
    print(f"This is likely the same as local_devices[0]: {local_devices[0]}")
```


```text
--- Discovering Devices ---
All available JAX devices: [CudaDevice(id=0)]
Total number of JAX devices: 1
Local JAX devices: [CudaDevice(id=0)]
Number of local JAX devices: 1
CPU devices: [CpuDevice(id=0)]
GPU devices: [CudaDevice(id=0)]
The jax_arr_from_list is on device: cuda:0
This is likely the same as local_devices[0]: cuda:0
```


When we simply use `jnp.array()`, JAX places the new array on the default device: the first device of the highest-priority available backend (a TPU or GPU if present, otherwise the CPU), which is `jax.devices()[0]` in a single-process setup. For Aurora's complex models, relying on defaults isn't always optimal. We need explicit control.
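One lightweight form of control, sketched here under the assumption of a standard JAX install where the CPU backend is always present, is the `jax.default_device` context manager. It changes where *newly created* arrays land without touching any existing array.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]  # the CPU backend is always available

# Without any context, new arrays land on the default device, jax.devices()[0].
x_default = jnp.zeros(3)

# jax.default_device temporarily overrides that choice for array creation.
with jax.default_device(cpu):
    x_cpu = jnp.zeros(3)

print(x_default.devices())  # first device of the default backend
print(x_cpu.devices())      # the CPU device
```

This only sets a default; `jax.device_put` remains the tool for placing one specific array.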

## Taking Command: Explicitly Placing JAX Arrays with `jax.device_put()`

This is where Aurora's architects gain true mastery. `jax.device_put()` allows us to dictate precisely which device a JAX array should live on. This is crucial for:

* **Performance**: Minimizing data movement by placing arrays on the device where they'll be used.
* **Resource Management**: Distributing data across multiple accelerators for parallel processing (the focus of later Aurora briefings on sharding).
* **Interfacing**: Moving data from host-based NumPy arrays into the JAX-controlled device memory.

```python
# Let's create some data on the host (CPU RAM)
host_blueprint_data = np.random.rand(2, 3).astype(np.float32)
print(f"\n--- Explicit Placement with device_put ---")
print(f"Host blueprint data (NumPy array on CPU RAM): \n{host_blueprint_data}")

# Select target devices (if available)
# For Aurora, imagine these are specific processing units in our vast cluster
target_cpu = None
if 'cpu_devices' in locals() and cpu_devices:  # from previous cell
    target_cpu = cpu_devices[0]

target_accelerator = None
if 'gpu_devices' in locals() and gpu_devices:
    target_accelerator = gpu_devices[0]
elif local_devices and local_devices[0].platform.lower() != 'cpu':  # any non-CPU device counts as an accelerator
    target_accelerator = local_devices[0]
else:  # Fallback if no distinct accelerator: use the CPU for demonstration
    target_accelerator = target_cpu if target_cpu else local_devices[0] if local_devices else None

print(f"Target CPU device for placement: {target_cpu}")
print(f"Target Accelerator device for placement: {target_accelerator}")

# 1. Host-to-Device (H2D) Transfer: NumPy array to a specific JAX device
if target_accelerator:
    print(f"\nPlacing host data onto Accelerator ({target_accelerator.platform}): {target_accelerator}")
    aurora_data_on_accel = jax.device_put(host_blueprint_data, device=target_accelerator)
    print(f"Aurora data on accelerator: \n{aurora_data_on_accel}")
    print(f"Device of aurora_data_on_accel: {aurora_data_on_accel.device}")
    print(f"Type: {type(aurora_data_on_accel)}")
else:
    print("\nNo specific accelerator found to demonstrate H2D, using default placement.")
    aurora_data_on_accel = jax.device_put(host_blueprint_data)  # uses default device

# 2. Device-to-Device (D2D) Transfer (if distinct devices are available)
# Imagine transferring processed data from one Aurora GPU to another, or GPU to CPU JAX device
if target_cpu and target_accelerator and target_cpu != target_accelerator:
    print(f"\nTransferring data from Accelerator ({target_accelerator}) to CPU device ({target_cpu})")
    aurora_data_on_cpu_device = jax.device_put(aurora_data_on_accel, device=target_cpu)
    print(f"Aurora data now on CPU device: \n{aurora_data_on_cpu_device}")
    print(f"Device of aurora_data_on_cpu_device: {aurora_data_on_cpu_device.device}")
elif target_cpu == target_accelerator and target_cpu is not None:
    print(f"\nAccelerator and CPU target are the same JAX device ({target_cpu}), D2D demo for distinct devices not applicable here.")
    # If data is already on the target device, device_put can be a no-op or very fast.
    already_on_target_check = jax.device_put(aurora_data_on_accel, target_accelerator)
    print(f"Data put to its own device. Device: {already_on_target_check.device}. ID of array: {id(already_on_target_check)} vs {id(aurora_data_on_accel)}")
else:
    print("\nNot enough distinct devices to demonstrate D2D transfer clearly.")

# device_put also works with JAX arrays already on a device
if local_devices:
    initial_arr = jnp.ones(3)  # on default device
    print(f"\nInitial array on {initial_arr.device}")
    arr_on_specific_dev = jax.device_put(initial_arr, local_devices[0])  # ensure it is on the first device
    print(f"Array explicitly on {arr_on_specific_dev.device}")
```


```text
--- Explicit Placement with device_put ---
Host blueprint data (NumPy array on CPU RAM): 
[[0.22286515 0.93983203 0.92845   ]
 [0.37479925 0.9169095  0.02802854]]
Target CPU device for placement: TFRT_CPU_0
Target Accelerator device for placement: cuda:0

Placing host data onto Accelerator (gpu): cuda:0
Aurora data on accelerator: 
[[0.22286515 0.93983203 0.92845   ]
 [0.37479925 0.9169095  0.02802854]]
Device of aurora_data_on_accel: cuda:0
Type: <class 'jaxlib.xla_extension.ArrayImpl'>

Transferring data from Accelerator (cuda:0) to CPU device (TFRT_CPU_0)
Aurora data now on CPU device: 
[[0.22286515 0.93983203 0.92845   ]
 [0.37479925 0.9169095  0.02802854]]
Device of aurora_data_on_cpu_device: TFRT_CPU_0

Initial array on cuda:0
Array explicitly on cuda:0
```


With `jax.device_put`, we give Aurora direct instructions. If data is already a JAX array on the target device, `jax.device_put` is often very efficient, potentially just returning the same array. If it needs to move (e.g., host RAM to GPU HBM, or GPU0 HBM to GPU1 HBM), a transfer occurs.
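Real models rarely move one array at a time. `jax.device_put` also accepts a pytree (nested dicts, lists, tuples) and places every leaf on the requested device in one call. In this sketch, the `params` dict is a hypothetical parameter bundle mixing a host NumPy array and an existing JAX array, and the CPU device stands in for any target.

```python
import jax
import jax.numpy as jnp
import numpy as np

cpu = jax.devices("cpu")[0]

# Hypothetical parameter bundle: a dict (a JAX pytree) mixing host and device data.
params = {
    "w": np.random.rand(4, 3).astype(np.float32),  # NumPy array in host RAM
    "b": jnp.zeros(3),                              # already a jax.Array
}

# device_put maps over every leaf of the pytree and returns the same structure.
params_on_cpu = jax.device_put(params, cpu)

for name, leaf in params_on_cpu.items():
    print(name, type(leaf).__name__, leaf.devices())
```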

## Retrieving Aurora's Discoveries: `jax.device_get()`

Sometimes, after complex computations on Aurora's accelerators, we need to bring results back to the host CPU – perhaps for saving to disk, for visualization with libraries like Matplotlib, or for parts of the application logic that run in standard Python/NumPy. For this, we use `jax.device_get()`.

```python
print(f"\n--- Retrieving Data with device_get ---")
# Assuming aurora_data_on_accel is a JAX array on an accelerator (from previous step)
if 'aurora_data_on_accel' not in locals():  # Create one if it doesn't exist
    aurora_data_on_accel = jnp.array([[10.,20.],[30.,40.]], device=target_accelerator if target_accelerator else local_devices[0])
    print("(Recreated aurora_data_on_accel for this section)")

print(f"Data on accelerator ({aurora_data_on_accel.device}): \n{aurora_data_on_accel}")
print(f"Type before get: {type(aurora_data_on_accel)}")

# Retrieve data from the device to the host CPU RAM as a NumPy array
retrieved_blueprint = jax.device_get(aurora_data_on_accel)

print(f"\nRetrieved blueprint (now on host CPU RAM): \n{retrieved_blueprint}")
print(f"Type after get: {type(retrieved_blueprint)}")  # Should be <class 'numpy.ndarray'>

# Now you can use standard NumPy operations or save it
sum_on_host = np.sum(retrieved_blueprint)
print(f"Sum computed on host using NumPy: {sum_on_host}")
```


```text
--- Retrieving Data with device_get ---
Data on accelerator (cuda:0): 
[[0.22286515 0.93983203 0.92845   ]
 [0.37479925 0.9169095  0.02802854]]
Type before get: <class 'jaxlib.xla_extension.ArrayImpl'>

Retrieved blueprint (now on host CPU RAM): 
[[0.22286515 0.93983203 0.92845   ]
 [0.37479925 0.9169095  0.02802854]]
Type after get: <class 'numpy.ndarray'>
Sum computed on host using NumPy: 3.410884380340576
```


A critical point for Aurora's performance: `jax.device_get()` is a synchronous operation. Your Python program will pause and wait until the data has been fully copied from the device to the host. Frequent, unnecessary `device_get` calls within performance-critical loops can become serious bottlenecks. For Aurora, we aim to keep data on devices as long as possible, only retrieving it when absolutely necessary.
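When several results are needed on the host, one practical pattern is to gather them into a single pytree and fetch it with one `jax.device_get` call. `jax.device_get` accepts pytrees, and it starts the copies for all leaves before waiting on any of them. That beats paying one blocking round trip per value (for example, calling `float()` on each metric in turn). The metric names below are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

preds = jnp.linspace(0.0, 1.0, 8)
targets = jnp.ones(8)

# Hypothetical per-step metrics, computed and kept on the device.
metrics = {
    "mse": jnp.mean((preds - targets) ** 2),
    "max_err": jnp.max(jnp.abs(preds - targets)),
    "preds": preds,
}

# One synchronous retrieval for the whole pytree; every leaf becomes a NumPy array.
host_metrics = jax.device_get(metrics)

for name, value in host_metrics.items():
    print(name, type(value).__name__, value)
```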

## The Pulse of Aurora: Asynchronous Execution and `block_until_ready()`

As mentioned, JAX operations, particularly those targeting accelerators, often execute asynchronously. The Python host tells the device "do this computation," and JAX immediately returns a `jax.Array` that acts as a *future* or *handle* to the result, even if the computation isn't finished yet. The host can then continue queueing more work.

```python
print(f"\n--- Asynchronous Execution & Blocking ---")
# Let's use a JAX array that is hopefully on an accelerator
data_for_op = jnp.arange(1_000_000, dtype=jnp.float32).reshape(1000, 1000)
if target_accelerator:
    data_for_op = jax.device_put(data_for_op, target_accelerator)
print(f"(Created large data_for_op on {data_for_op.device} for async demo)")

# This operation is dispatched to the device. Python might continue before it's done.
# For very fast ops, the effect might be hard to see without proper profiling.
print("Dispatching a potentially large computation...")
result_future = jnp.dot(data_for_op, data_for_op.T)  # Matrix multiplication
print(f"Python host sees a result 'future' of type: {type(result_future)}")
print(f"Result future is on device: {result_future.device}")

# If we need to ensure the computation is complete before proceeding (e.g., timing,
# or using the result outside JAX), we use .block_until_ready()
result_future.block_until_ready()
print("Computation is now guaranteed to be complete.")

# Accessing the result (e.g., printing or jax.device_get) also implicitly blocks:
# printing copies the indexed value to the host, which waits for it to be computed.
print(f"First element of result: {result_future[0,0]}")
# (Don't print large matrices usually)
```


```text
--- Asynchronous Execution & Blocking ---
(Created large data_for_op on cuda:0 for async demo)
Dispatching a potentially large computation...
Python host sees a result 'future' of type: <class 'jaxlib.xla_extension.ArrayImpl'>
Result future is on device: cuda:0
Computation is now guaranteed to be complete.
First element of result: 332833216.0
```


For Aurora's architects, `array.block_until_ready()` is the tool to synchronize the host with the device. It's essential when timing operations accurately or when external actions depend on the result being available.
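A minimal timing sketch shows why the blocking call matters. Without it, the stopwatch measures only how long dispatch took. The `matmul` function and the 1000×1000 size are illustrative choices. The warm-up call keeps one-time JIT compilation out of the measurement.

```python
import time

import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000), dtype=jnp.float32)

@jax.jit
def matmul(a):
    return a @ a.T

# Warm-up: triggers compilation so it is excluded from the timing below.
matmul(x).block_until_ready()

start = time.perf_counter()
out = matmul(x)                # returns as soon as the work is dispatched
dispatch_s = time.perf_counter() - start
out.block_until_ready()        # wait for the device to actually finish
total_s = time.perf_counter() - start

print(f"dispatch only: {dispatch_s * 1e3:.3f} ms, dispatch + compute: {total_s * 1e3:.3f} ms")
```

On an accelerator the first number is typically much smaller than the second. On a CPU backend the gap can be narrower.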

## The Engine Room: A Brief Word on XLA

Beneath the elegant JAX NumPy API lies a powerful compiler: **XLA (Accelerated Linear Algebra)**. When you JIT-compile a function with `jax.jit`, JAX traces it into a computation graph and hands the whole graph to XLA. XLA optimizes it (for example, by fusing operations) and compiles it into efficient machine code for the target CPU, GPU, or TPU. Even operations you run eagerly on JAX arrays execute this way, one small compiled computation at a time.

JAX arrays are JAX's way of representing data that XLA manages on these devices. XLA handles the low-level details of memory allocation and kernel launching, and it can choose the physical layout of your data in device memory for performance. We usually interact only with JAX APIs, but knowing that XLA is the engine explains why features like JIT compilation are so effective: XLA can only optimize across operations that it sees together in one graph.
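JAX's ahead-of-time API exposes this pipeline one stage at a time. The sketch below uses a hypothetical `scaled_sum` function. It lowers the function with `jax.jit(...).lower(...)`, prints the start of the IR handed to XLA (StableHLO text in recent JAX versions; the exact text varies by version), and compiles it into a callable executable.

```python
import jax
import jax.numpy as jnp

def scaled_sum(x, y):
    # Several elementwise ops plus a reduction: a natural candidate for XLA fusion.
    return jnp.sum(2.0 * x + y)

x = jnp.arange(4.0)  # [0., 1., 2., 3.]
y = jnp.ones(4)

# Stage 1: trace and lower to the IR that XLA consumes.
lowered = jax.jit(scaled_sum).lower(x, y)
ir_text = lowered.as_text()
print(ir_text[:400])  # truncated; content depends on the JAX version

# Stage 2: let XLA compile it for the current backend, then call the executable.
compiled = lowered.compile()
print(compiled(x, y))  # same result as jax.jit(scaled_sum)(x, y)
```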

## Foundations Laid

Architect, you now have a much deeper understanding of JAX arrays – Aurora's fundamental data building blocks. You know they are immutable, device-aware entities. You can query Aurora's hardware landscape, explicitly command where data resides using `jax.device_put`, retrieve it using `jax.device_get`, and appreciate the asynchronous nature of JAX's execution model.

This mastery of individual data primitives is the bedrock upon which all of Aurora's distributed computing strategies will be built. Next, we will explore how to take these JAX arrays and the computations upon them, and distribute them across many devices working in concert – the true path to colossal AI.