In [2]:
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

def test_metadata_preservation():
    """Test what actually survives through JAX compilation."""
    
    @jax.jit
    def compute_with_metadata(x):
        with set_xla_metadata(operation_id="op_001", challengeable="true"):
            y = jnp.sin(x)
        
        with set_xla_metadata(operation_id="op_002", challengeable="true"):
            z = jnp.cos(y)
        
        with set_xla_metadata(operation_id="op_003", challengeable="true"):
            result = y + z
        
        return result
    
    # Compile and get StableHLO
    x = jnp.array([1.0, 2.0, 3.0])
    lowered = compute_with_metadata.lower(x)
    
    # Check different representations
    print("=== StableHLO ===")
    stablehlo_text = lowered.as_text(dialect="stablehlo")
    print(stablehlo_text[:1000])  # First 1000 chars
    
    print("\n=== HLO ===")
    hlo_text = lowered.compile().as_text()
    print(hlo_text[:1000])
    
    # Check if our metadata appears anywhere
    print("\n=== Metadata search ===")
    print(f"'operation_id' in StableHLO: {'operation_id' in stablehlo_text}")
    print(f"'challengeable' in StableHLO: {'challengeable' in stablehlo_text}")
    print(f"'op_001' in StableHLO: {'op_001' in stablehlo_text}")
    
    return stablehlo_text, hlo_text

# Execute the probe and keep both IR texts for inspection in later cells.
(stablehlo, hlo) = test_metadata_preservation()

=== StableHLO ===
module @jit_compute_with_metadata attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32>) -> (tensor<3xf32> {jax.result_info = "result"}) {
    %0 = stablehlo.sine %arg0 {mhlo.frontend_attributes = {challengeable = "true", operation_id = "op_001"}} : tensor<3xf32>
    %1 = stablehlo.cosine %0 {mhlo.frontend_attributes = {challengeable = "true", operation_id = "op_002"}} : tensor<3xf32>
    %2 = stablehlo.add %0, %1 {mhlo.frontend_attributes = {challengeable = "true", operation_id = "op_003"}} : tensor<3xf32>
    return %2 : tensor<3xf32>
  }
}


=== HLO ===
HloModule jit_compute_with_metadata, is_scheduled=true, entry_computation_layout={(f32[3]{0})->f32[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.1: f32[3]) -> f32[3] {
  %param_0.1 = f32[3]{0} parameter(0)
  %sin.0 = f32[3]{0} sine(%param_0.1), frontend_attri

In [4]:
def glug(z):
    """Return ``z`` incremented by one."""
    incremented = z + 1
    return incremented

def foo(x):
    """Return ``x`` shifted up by two."""
    shifted = x + 2
    return shifted

@jax.jit
def main(x):
    """Jitted pipeline: apply ``foo`` first, then ``glug`` to its result."""
    return glug(foo(x))

# Single-element float input; the cell displays the result of the last expression.
x = jnp.asarray([1.0])
main(x)

Array([4.], dtype=float32)

In [None]:
@jax.jit
def simple_test(x):
    with set_xla_metadata(my_custom_attr="test_value"):
        return x * 2

x = jnp.array([1.0])
lowered = simple_test.lower(x)

# Dump both IR stages in full (untruncated) so the custom attribute is easy to spot.
# `sep="\n"` puts each header on its own line above the IR text.
print("Full StableHLO:", lowered.as_text(dialect="stablehlo"), sep="\n")
print("\nFull HLO:", lowered.compile().as_text(), sep="\n")