Please connect to the `Metrax (go/metrax)` Colab runtime.

If you don't see `Metrax (go/metrax)` in the dropdown menu, please run `/google/bin/releases/colaboratory/public/tools/authorize_colab` on your gLinux workstation or cloudtop and try again.

# Getting Started with Metrax 🚀

Welcome to this hands-on guide for `metrax`, a powerful and flexible metrics library for JAX.

In this Colab, you'll learn how to:
* Use the **Functional metrax API** (`metrax`) and the **Object-Oriented metrax API** (`metrax.nnx`).
* Verify that batch and iterative calculations give **identical results**.
* Scale your metric computations to **multiple devices** using (1) `jax.pmap` and (2) `jax.jit` with a `Mesh`.
* Scale your metric computations to **multiple hosts** for large-scale distributed training, using (1) Multi-Controller JAX and (2) Pathways.

## ⚙️ Environment Setup: Simulating Multiple Devices

First, let's configure our environment. To demonstrate `metrax`'s multi-device capabilities, we'll instruct JAX to simulate an environment with **4 virtual CPU devices**. This allows us to test `jax.pmap` and `jax.jit` with `mesh` logic even on single-device hardware.

In [None]:
# This environment variable instructs JAX's underlying XLA compiler
# to create a specific number of virtual CPU devices.
#
# This MUST be set *before* the JAX backend is initialized, which happens on
# the first import of `jax`.
#
# In a script, ensure this line comes before `import jax`.
# In a notebook, a kernel restart may be needed if JAX has already been used.
import os
print("Configuring JAX to simulate 4 CPU devices...")
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
import jax


# --- Verify the JAX Environment ---
print("\nVerifying JAX environment configuration:")
print("-" * 40)
device_count = jax.device_count()
process_count = jax.process_count()
print(f"✅ Number of available JAX devices: {device_count}")
print(f"✅ Number of JAX processes: {process_count}")
print("-" * 40)

if device_count == 4:
  print("Success! JAX is now set up for multi-device simulation.")
else:
  print("Warning: JAX device count is not the expected value.")

## 📊 Data Preparation for Realistic Scenarios

Next, let's generate some data. A good metrics demo uses realistic data, so we'll create a dataset that is **imbalanced** and where the model's **predictions are imperfect but correlated** with the true labels.

In [None]:
import numpy as np

# --- 1. Data Generation Setup ---
np.random.seed(42)
N_BATCHES = 4
BATCH_SIZE = 8
TOTAL_SAMPLES = N_BATCHES * BATCH_SIZE

# --- 2. Create Realistic, Correlated Data ---
# Create an imbalanced dataset (80% class 0, 20% class 1).
labels = np.random.choice([0, 1], size=(TOTAL_SAMPLES,), p=[0.8, 0.2])

# Generate predictions correlated with labels, adding some noise for realism.
noise = np.random.normal(loc=0, scale=0.25, size=TOTAL_SAMPLES)
clean_preds = np.where(labels == 1, 0.8, 0.2)
predictions = np.clip(clean_preds + noise, 0, 1)

# Generate sample weights to give more importance to the rare positive class.
sample_weights = np.where(labels == 1, 2.0, 1.0)

# --- 3. Reshape Data into Batched Format ---
# The batched format is useful for demonstrating iterative calculations.
labels_batched = labels.reshape(N_BATCHES, BATCH_SIZE).astype(np.float32)
predictions_batched = predictions.reshape(N_BATCHES, BATCH_SIZE).astype(np.float32)
sample_weights_batched = sample_weights.reshape(N_BATCHES, BATCH_SIZE).astype(np.float32)

# --- 4. Data Shape Verification ---
print("✅ Data generation complete. Verifying array shapes:")
print("-" * 50)
print("Flat arrays for full-dataset processing:")
print(f"  - predictions.shape:    {predictions.shape}")
print(f"  - labels.shape:         {labels.shape}")
print("\nBatched arrays for iterative/streaming processing:")
print(f"  - predictions_batched.shape: {predictions_batched.shape}")
print(f"  - labels_batched.shape:      {labels_batched.shape}")
print("-" * 50)

## Metrax Linen API (Functional)

The core `metrax` API is functional and stateless, making it a natural fit for JAX. It works by creating immutable `Metric` state objects that can be merged.

Each `metrax` metric inherits from the CLU [`Metric`](http://shortn/_e70RtO7j36) class and provides the following APIs:

* `Metric.from_model_output()`: Creates a metric state from data.
* `Metric.empty()`: Creates an empty, initial state.
* `metric_a.merge(metric_b)`: Combines two metric states.
* `metric.compute()`: Computes the final value.

Let's demonstrate by calculating several metrics on our dataset in two ways: (1) on the full dataset at once, and (2) by iteratively merging per-batch results. The second method resembles how metrics are accumulated in real-world training and evaluation loops.
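To see why merging per-batch states can reproduce the full-dataset value, here is a minimal pure-Python sketch of the same four-method pattern. `ToyPrecision`, its fields, and the `>= 0.5` threshold rule are hypothetical and exist only for illustration; this is not how `metrax` implements precision. The point is that the state holds additive counts, so `merge` can simply sum them.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyPrecision:
  """Hypothetical stand-in for a functional metric state (not metrax code)."""
  true_positives: float = 0.0
  false_positives: float = 0.0

  @classmethod
  def empty(cls):
    # An empty state holds zero counts, so merging it changes nothing.
    return cls()

  @classmethod
  def from_model_output(cls, predictions, labels, threshold=0.5):
    # Reduce a batch to the two counts that precision depends on.
    pairs = list(zip(predictions, labels))
    tp = sum(1.0 for p, y in pairs if p >= threshold and y == 1)
    fp = sum(1.0 for p, y in pairs if p >= threshold and y == 0)
    return cls(true_positives=tp, false_positives=fp)

  def merge(self, other):
    # Counts are additive, so merging returns a new state with summed counts.
    return ToyPrecision(
        self.true_positives + other.true_positives,
        self.false_positives + other.false_positives,
    )

  def compute(self):
    predicted_positives = self.true_positives + self.false_positives
    if not predicted_positives:
      return 0.0
    return self.true_positives / predicted_positives


toy_predictions = [0.9, 0.2, 0.7, 0.4, 0.8, 0.1, 0.6, 0.3]
toy_labels = [1, 0, 0, 0, 1, 0, 1, 1]

# Full-batch state versus two half-batches merged into an empty state.
full_state = ToyPrecision.from_model_output(toy_predictions, toy_labels)
merged_state = ToyPrecision.empty()
for start in (0, 4):
  batch_state = ToyPrecision.from_model_output(
      toy_predictions[start:start + 4], toy_labels[start:start + 4])
  merged_state = merged_state.merge(batch_state)

print(full_state.compute(), merged_state.compute())
```

Because the state is immutable, every `merge` returns a new object, which is what makes this style easy to use inside JAX transformations.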

In [None]:
import metrax

# Define the metrics we want to calculate.
metrics_to_compute = {
    'Precision': metrax.Precision,
    'Recall': metrax.Recall,
    'AUCPR': metrax.AUCPR,
    'AUCROC': metrax.AUCROC,
}

# Define which of these metrics should receive sample weights.
metrics_with_weights = {'AUCPR', 'AUCROC'}


# --- Method 1: Full-Batch Calculation ---
print("--- Method 1: Full-Batch Calculation (on all 32 samples) ---")
full_batch_results = {}
for name, MetricClass in metrics_to_compute.items():
  # Conditionally add sample_weights for supported metrics.
  if name in metrics_with_weights:
    metric_state = MetricClass.from_model_output(
        predictions=predictions,
        labels=labels,
        sample_weights=sample_weights
    )
  else:
    metric_state = MetricClass.from_model_output(
        predictions=predictions,
        labels=labels
    )
  full_batch_results[name] = metric_state.compute()
  print(f"{name}: {full_batch_results[name]}")


# --- Method 2: Iterative Merging by Batch ---
print("\n--- Method 2: Iterative Merging (4 batches of 8 samples) ---")
iterative_metrics = {
    name: MetricClass.empty() for name, MetricClass in metrics_to_compute.items()
}

for labels_b, predictions_b, weights_b in zip(labels_batched, predictions_batched, sample_weights_batched):
  for name, MetricClass in metrics_to_compute.items():
    if name in metrics_with_weights:
      current_metric_state = MetricClass.from_model_output(
          predictions=predictions_b,
          labels=labels_b,
          sample_weights=weights_b
      )
    else:
      current_metric_state = MetricClass.from_model_output(
          predictions=predictions_b,
          labels=labels_b
      )
    iterative_metrics[name] = iterative_metrics[name].merge(current_metric_state)

iterative_results = {}
for name, metric_state in iterative_metrics.items():
  iterative_results[name] = metric_state.compute()
  print(f"{name}: {iterative_results[name]}")


# --- Verification ---
print("\n--- Verification ---")
for name in metrics_to_compute.keys():
  assert np.allclose(full_batch_results[name], iterative_results[name])

print("✅ Success! Both methods produce identical results.")

## Metrax NNX API (Object-Oriented)

For users who prefer an object-oriented style, `metrax.nnx` provides stateful metric objects. This can simplify the code for iterative updates, as you update a single object in place.

Each `metrax.nnx` metric inherits from the NNX [`Metric`](http://shortn/_VyVVvvsQ00) class and provides the following APIs:

* `metric = Metric()`: Creates a stateful metric object.
* `metric.update()`: Updates the metric's internal state with new data.
* `metric.compute()`: Computes the final value from the accumulated state.
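For contrast with the functional style, the following pure-Python sketch shows the stateful pattern in isolation. `ToyRecall` and its attributes are hypothetical and only mirror the `update()` / `compute()` shape listed above; the real `metrax.nnx` classes are not implemented this way.

```python
class ToyRecall:
  """Hypothetical stand-in for a stateful metric object (not metrax.nnx code)."""

  def __init__(self):
    self.true_positives = 0.0
    self.false_negatives = 0.0

  def update(self, predictions, labels, threshold=0.5):
    # Accumulate counts in place; nothing is returned.
    for p, y in zip(predictions, labels):
      if y == 1 and p >= threshold:
        self.true_positives += 1.0
      elif y == 1:
        self.false_negatives += 1.0

  def compute(self):
    actual_positives = self.true_positives + self.false_negatives
    if not actual_positives:
      return 0.0
    return self.true_positives / actual_positives


toy_predictions = [0.9, 0.2, 0.7, 0.4, 0.8, 0.1, 0.6, 0.3]
toy_labels = [1, 0, 0, 0, 1, 0, 1, 1]

# Streaming: the same object is updated once per batch.
streaming_recall = ToyRecall()
for start in (0, 4):
  streaming_recall.update(
      toy_predictions[start:start + 4], toy_labels[start:start + 4])

# One shot: a fresh object sees all samples in a single update.
one_shot_recall = ToyRecall()
one_shot_recall.update(toy_predictions, toy_labels)

print(streaming_recall.compute(), one_shot_recall.compute())
```

Note that there is no `merge` step here: the object itself carries the running state, so a fresh instance is needed for each independent evaluation.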


In [None]:
import metrax.nnx

# Define the nnx metrics we want to calculate.
metrics_to_compute_nnx = {
    'Precision': metrax.nnx.Precision,
    'Recall': metrax.nnx.Recall,
    'AUCPR': metrax.nnx.AUCPR,
    'AUCROC': metrax.nnx.AUCROC,
}

# Define which metrics take a `threshold` argument in `update()`.
# `metrics_with_weights` is reused from the previous cell.
metrics_with_threshold = {'Precision', 'Recall'}


# --- Method 1: Full-Batch Calculation (nnx) ---
print("--- Method 1: Full-Batch Calculation with nnx ---")
full_batch_metrics_nnx = {
    name: MetricClass() for name, MetricClass in metrics_to_compute_nnx.items()
}
for name, metric_obj in full_batch_metrics_nnx.items():
  update_kwargs = {'predictions': predictions, 'labels': labels}
  if name in metrics_with_weights:
    update_kwargs['sample_weights'] = sample_weights
  if name in metrics_with_threshold:
    update_kwargs['threshold'] = 0.5
  metric_obj.update(**update_kwargs)

full_batch_results_nnx = {}
for name, metric_obj in full_batch_metrics_nnx.items():
  full_batch_results_nnx[name] = metric_obj.compute()
  print(f"{name}: {full_batch_results_nnx[name]}")


# --- Method 2: Iterative Updating by Batch (nnx) ---
print("\n--- Method 2: Iterative Updating with nnx ---")
iterative_metrics_nnx = {
    name: MetricClass() for name, MetricClass in metrics_to_compute_nnx.items()
}
for labels_b, predictions_b, weights_b in zip(labels_batched, predictions_batched, sample_weights_batched):
  for name, metric_obj in iterative_metrics_nnx.items():
    update_kwargs = {'predictions': predictions_b, 'labels': labels_b}
    if name in metrics_with_weights:
      update_kwargs['sample_weights'] = weights_b
    if name in metrics_with_threshold:
      update_kwargs['threshold'] = 0.5
    metric_obj.update(**update_kwargs)

iterative_results_nnx = {}
for name, metric_obj in iterative_metrics_nnx.items():
  iterative_results_nnx[name] = metric_obj.compute()
  print(f"{name}: {iterative_results_nnx[name]}")


# --- Verification ---
print("\n--- Verification ---")
for name in metrics_to_compute_nnx.keys():
  assert np.allclose(full_batch_results_nnx[name], iterative_results_nnx[name])

print("✅ Success! Both methods produce identical results using the nnx API.")

## Scaling to Multiple Devices

`metrax` is designed from the ground up to work seamlessly in distributed environments. Let's explore the two primary ways to scale metric computations in JAX.

### Method 1: The `pmap` Approach (Simple Data Parallelism)

The legacy way to scale computations across multiple devices in JAX is with `jax.pmap`. It's a powerful transformation designed for **data parallelism**, where you want to run the exact same program on different slices of data. `metrax`'s composable metrics are perfectly suited for this.

`pmap` is more than just a parallel loop; it's a transformation that handles several complex steps for you automatically. When you apply `pmap` to a function like `metrax.AUCPR.from_model_output`, here's what happens:

1.  **Compile (`jit`):** First, `pmap` **JIT-compiles** your function into highly efficient machine code that is optimized for your hardware (like GPUs or TPUs). You don't need to add `@jax.jit` yourself; it's an inherent part of the `pmap` process.

2.  **Distribute (Shard):** It then automatically splits your input arrays along their first (batch) axis and sends each data "shard" to a different device.

3.  **Execute in Parallel:** The compiled function runs simultaneously on all devices, with each device processing its own shard of the data. This step produces a distributed `metrax` object, where each device holds the intermediate metric state calculated from its local data.

4.  **Aggregate (Reduce):** To get a final, globally correct result, you call the `.reduce()` method on the output from `pmap`. This merges the intermediate states along the leading (device) axis, combining the per-device states into one.
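The shard, execute, and reduce steps above can be simulated with plain NumPy, without `pmap` or any real devices. This is only a sketch under stated assumptions: the "devices" are rows of a reshaped array, and the metric is a hypothetical weighted accuracy at a 0.5 threshold, chosen because its state is just two sums.

```python
import numpy as np

rng = np.random.default_rng(0)
n_devices = 4  # Simulated device count, matching the 4 virtual CPU devices.
sim_labels = rng.integers(0, 2, size=32)
sim_predictions = rng.random(32)
sim_weights = np.where(sim_labels == 1, 2.0, 1.0)

# Step 2 (shard): split the leading axis into one row per simulated device.
shard_labels = sim_labels.reshape(n_devices, -1)
shard_predictions = sim_predictions.reshape(n_devices, -1)
shard_weights = sim_weights.reshape(n_devices, -1)

# Step 3 (execute): each row produces its own partial state. Here the state
# is a pair of sums: weighted correct predictions and total weight.
shard_correct = (
    (shard_predictions >= 0.5) == (shard_labels == 1)) * shard_weights
per_device_correct = shard_correct.sum(axis=1)  # shape: (n_devices,)
per_device_total = shard_weights.sum(axis=1)    # shape: (n_devices,)

# Step 4 (reduce): combine the partial states along the device axis.
reduced_accuracy = per_device_correct.sum() / per_device_total.sum()

# Reference: the same quantity computed on the unsharded arrays.
direct_correct = ((sim_predictions >= 0.5) == (sim_labels == 1)) * sim_weights
direct_accuracy = direct_correct.sum() / sim_weights.sum()

print(reduced_accuracy, direct_accuracy)
```

The reduction combines states, not final values. Averaging the four per-device accuracies instead would generally give a different number, because each shard can carry a different total weight.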

---

### Method 2: The `jit` and `Mesh` Approach (Advanced Parallelism)

For more advanced control over distributed computation, JAX provides an explicit sharding mechanism using the `jax.sharding` API. This **SPMD (Single Program, Multiple Data)** approach is more powerful and flexible than `pmap` and is the standard for large-scale models.

Instead of `pmap`'s automatic behavior, you take explicit control over each step of the process:

1.  **Define a `Mesh`**: You first create a logical grid of your physical devices and give names to the axes (e.g., `Mesh(jax.devices(), ('data',))`). This describes the topology you'll be working with.

2.  **Create a `Sharding` Rule**: You specify exactly how each dimension of your array should be mapped to the mesh's axes. This is done using `NamedSharding` and `PartitionSpec`. For data parallelism, you would shard the batch axis of your data across the `'data'` axis of your mesh.

3.  **Explicitly Place Data**: You use `jax.device_put` to apply this sharding rule to your data arrays. At this point, your JAX arrays are "aware" of how they are distributed across the physical hardware.

4.  **`jit`-Compile the Function**: You write a function that looks like a normal, single-device calculation and decorate it with `@jax.jit`. When JAX's compiler sees that the inputs to this function are sharded arrays, it automatically generates a distributed version of the code. It implicitly handles all the necessary cross-device communication, so **no explicit `.reduce()` call is needed.**

This method provides the fine-grained control that is essential for more complex scenarios like model parallelism (sharding a model's weights) in addition to data parallelism.
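Before applying this to a metric, it can help to look at what the four steps do to a plain array. The sketch below is independent of `metrax` and uses an arbitrary `values` array (an assumption for illustration). It sizes the array from the number of devices JAX reports, so it should also run on a single device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
n = devices.size

# Step 1: a one-dimensional mesh whose single axis is named 'data'.
mesh = Mesh(devices, axis_names=('data',))

# Step 2: shard the first (batch) array axis across the 'data' mesh axis.
batch_sharding = NamedSharding(mesh, PartitionSpec('data'))

# Step 3: place a batch whose length divides evenly across the devices.
values = np.arange(8 * n, dtype=np.float32)
sharded_values = jax.device_put(values, batch_sharding)

# Each device now holds one contiguous slice of the batch.
shard_shapes = [
    shard.data.shape for shard in sharded_values.addressable_shards]
print(f"{n} device(s), per-device shard shapes: {shard_shapes}")

# Step 4: a plain jitted function sees the sharded input and returns the
# global result; no manual cross-device reduction is written here.
global_sum = jax.jit(jnp.sum)(sharded_values)
print(f"global sum: {global_sum}")
```

With the 4 simulated CPU devices from the setup cell, each shard holds 8 of the 32 values.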

In [None]:
import jax
import numpy as np
import metrax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# This cell assumes that the JAX environment is configured for 4 devices
# and that the data arrays `predictions`, `labels`, and `sample_weights`
# have been created in a previous cell.

# --- 1. Metric Calculation Functions ---

# Method 1: pmap (Simple Data Parallelism)
def calculate_aucpr_pmap(predictions, labels, sample_weights):
  """
  Distributes data across devices using pmap and aggregates with .reduce().
  """
  n_devices = jax.device_count()
  # Reshape data to have a leading device dimension for pmap.
  sharded_preds = predictions.reshape((n_devices, -1))
  sharded_labels = labels.reshape((n_devices, -1))
  sharded_weights = sample_weights.reshape((n_devices, -1))

  # pmap the metric calculation function across devices.
  per_device_metric = jax.pmap(metrax.AUCPR.from_model_output)(
      predictions=sharded_preds,
      labels=sharded_labels,
      sample_weights=sharded_weights
  )
  # reduce() combines the states from all devices into one.
  return per_device_metric.reduce()


# Method 2: jit + Mesh (Advanced SPMD Parallelism)
def calculate_aucpr_mesh(predictions, labels, sample_weights):
  """
  Explicitly shards data across a device Mesh and calculates with jit.
  """
  # 1. Define the device mesh and sharding rule.
  mesh = Mesh(jax.devices(), axis_names=('data',))
  sharding_rule = NamedSharding(mesh, PartitionSpec('data'))

  # 2. Explicitly move and shard the data onto the mesh.
  sharded_predictions = jax.device_put(predictions, sharding_rule)
  sharded_labels = jax.device_put(labels, sharding_rule)
  sharded_weights = jax.device_put(sample_weights, sharding_rule)

  # 3. Define the function to be JIT-compiled.
  def _calculate(preds, labs, weights):
    return metrax.AUCPR.from_model_output(
        predictions=preds, labels=labs, sample_weights=weights)

  # 4. JIT-compile the function with explicit sharding annotations.
  #    - in_shardings: Specifies how each input array is expected to be sharded.
  #    - out_shardings: Specifies the desired sharding for the output.
  #                     `None` leaves the choice to the compiler.
  jitted_calculate = jax.jit(
      _calculate,
      in_shardings=(sharding_rule, sharding_rule, sharding_rule),
      out_shardings=None
  )

  # The result is already a globally correct metric state, so no explicit
  # .reduce() call is needed.
  return jitted_calculate(sharded_predictions, sharded_labels, sharded_weights)


# Baseline: Single-Device (Direct)
@jax.jit
def calculate_aucpr_direct(predictions, labels, sample_weights):
  """Computes AUCPR on the entire dataset on a single device."""
  return metrax.AUCPR.from_model_output(
      predictions=predictions,
      labels=labels,
      sample_weights=sample_weights
  )


# --- 2. Execution and Verification ---
print("\nRunning all three AUCPR calculation methods...")

# Execute each of the three methods.
state_pmap = calculate_aucpr_pmap(predictions, labels, sample_weights)
state_mesh = calculate_aucpr_mesh(predictions, labels, sample_weights)
state_direct = calculate_aucpr_direct(predictions, labels, sample_weights)

# Compute the final values from the metric states.
result_pmap = state_pmap.compute()
result_mesh = state_mesh.compute()
result_direct = state_direct.compute()

# Ensure all computations are finished before verifying.
result_pmap.block_until_ready()
result_mesh.block_until_ready()
result_direct.block_until_ready()

# Verify that all results are numerically identical.
assert np.allclose(result_pmap, result_direct, rtol=1e-6)
assert np.allclose(result_mesh, result_direct, rtol=1e-6)


# --- 3. Display Results ---
print("\n" + "="*60)
print("          Comparison of Multi-Device AUCPR Calculations")
print("="*60)
print(f"{'Method':<35} {'AUCPR Value'}")
print("-" * 60)
print(f"{'Method 1: pmap':<35} {result_pmap}")
print(f"{'Method 2: jit + Mesh':<35} {result_mesh}")
print(f"{'Baseline: Direct Single-Device':<35} {result_direct}")
print("="*60)
print("\n✅ Verification successful: All three methods yield identical results.")

## 🧠 Advanced Use: Multi-Host Environments

For large-scale training (e.g., on TPU Pods), JAX uses multiple hosts (controllers), each managing multiple devices. In these scenarios, `metrax` integrates seamlessly with JAX's sharding capabilities provided by `jax.sharding` and `Mesh`.

By `jit`-compiling your training step with the appropriate device mesh, you can calculate metrics across hundreds or thousands of devices without changing the core metric logic.

For a detailed example of multi-controller and Pathways training, please see our [**example training script**](https://source.corp.google.com/piper///depot/google3/third_party/py/metrax/examples/xm_launch.py). The example demonstrates a multi-host setup with two hosts, each controlling four local devices, for a total of eight (`viperlite=4x2`).
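As a small orientation aid, the sketch below prints how JAX describes the current process and its devices. It is not taken from the example script. In this single-process Colab it should report one process that owns every device. In the two-host setup described above, each process would be expected to report four local devices and eight global devices. The `is_primary_host` name is a hypothetical convenience for logging from one host only.

```python
import jax

# In a single-process run, process_count is 1 and every device is local.
# In a multi-host run, each process sees only its own local devices, while
# jax.device_count() reports the global total.
process_index = jax.process_index()
process_count = jax.process_count()
local_devices = jax.local_device_count()
global_devices = jax.device_count()

print(f"process {process_index} of {process_count}: "
      f"{local_devices} local device(s), {global_devices} global device(s)")

# Hypothetical convention: report metric values from one host only, so the
# same number is not logged once per process.
is_primary_host = process_index == 0
if is_primary_host:
  print("This process would be the one to log final metric values.")
```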