# Array Refs (experimental)

In [1]:
from flax import nnx
import jax
import jax.numpy as jnp
import optax

## Basics

### Array Refs 101

In [2]:
a_ref = jax.new_ref(jnp.array([1, 2, 3]))

@jax.jit
def increment(a_ref: jax.Ref):  # no return!
  array: jax.Array = a_ref[...]  # access
  a_ref[...] = array + 1         # update

print("[1] =", a_ref); increment(a_ref); print("[2] =", a_ref)

[1] = ArrayRef([1, 2, 3], dtype=int32)
[2] = ArrayRef([2, 3, 4], dtype=int32)


In [3]:
@jax.jit
def inc(x):
  """Augmented-assignment spelling of ``increment``: read ``x[...]``, add 1, write it back."""
  x[...] += 1

# Show the StableHLO that `increment` lowers to: the ref argument becomes a plain
# tensor<3xi32> input aliased to the output (`tf.aliasing_output = 0`), so the
# in-place update is expressed as buffer aliasing.
# NOTE(review): `inc` is not used in this cell.
print(increment.lower(a_ref).as_text())

module @jit_increment attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {tf.aliasing_output = 0 : i32}) -> (tensor<3xi32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<3xi32>
    %1 = stablehlo.add %arg0, %0 : tensor<3xi32>
    return %1 : tensor<3xi32>
  }
}



### Variable Refs

In [4]:
# Opt a single Variable into ref-backed storage explicitly.
variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True)
print(f"{variable.is_hijax = }\n")

# `increment` (written for raw refs above) also accepts the Variable directly.
print("[1] =", variable)
increment(variable)
print("[2] =", variable)

variable.is_hijax = True

[1] = Variable( # 3 (12 B)
  value=ArrayRef([1, 2, 3], dtype=int32)
)
[2] = Variable( # 3 (12 B)
  value=ArrayRef([2, 3, 4], dtype=int32)
)


In [5]:
# Alternative to `is_hijax=True`: Variables constructed inside this context use refs.
with nnx.use_refs(True):
  variable = nnx.Variable(jnp.array([1, 2, 3]))

print(f"{variable.is_hijax = }")

variable.is_hijax = True


Besides being used as a context manager (as above), `nnx.use_refs` can also be used as a global flag.

### Converting Between Refs and Arrays

In [6]:
class Linear(nnx.Module):
  """Dense layer computing ``x @ kernel + bias``.

  ``kernel`` has shape ``(in_features, out_features)`` and is drawn from a
  standard normal using ``rngs``; ``bias`` has shape ``(out_features,)`` and
  starts at zero.
  """
  def __init__(self, in_features, out_features, rngs: nnx.Rngs):
    kernel_shape = (in_features, out_features)
    self.kernel = nnx.Param(jax.random.normal(rngs(), kernel_shape))
    self.bias = nnx.Param(jnp.zeros(out_features))

  def __call__(self, x):
    projected = x @ self.kernel
    # bias[None] adds a leading axis so it broadcasts over the batch rows
    return projected + self.bias[None]

model = Linear(1, 3, rngs=nnx.Rngs(0))     # built with plain arrays
refs_model = nnx.to_refs(model)            # same structure, values become ArrayRefs
arrays_model = nnx.to_arrays(refs_model)   # and back to plain arrays again

print("nnx.to_refs(model) =", refs_model)
print("nnx.to_arrays(refs_model) =", arrays_model)

nnx.to_refs(model) = Linear( # Param: 6 (24 B)
  bias=Param( # 3 (12 B)
    value=ArrayRef(shape=(3,), dtype=dtype('float32'))
  ),
  kernel=Param( # 3 (12 B)
    value=ArrayRef(shape=(1, 3), ...

## Examples

In [7]:
class Block(nnx.Module):
  """Two ``Linear`` layers with BatchNorm, Dropout(0.1) and GELU in between.

  Forward pass: linear -> bn -> dropout -> gelu -> linear_out.
  """
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)
    self.linear_out = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    hidden = self.linear(x)
    hidden = self.bn(hidden)
    hidden = self.dropout(hidden)
    hidden = nnx.gelu(hidden)
    return self.linear_out(hidden)

### Training Loop

In [8]:
with nnx.use_refs(True):
  model = Block(2, 64, 3, rngs=nnx.Rngs(0))
  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@jax.jit
def train_step(model, optimizer, x, y):
  """Run one optimizer step on the mean-squared error and return the loss.

  Only the loss is returned; `model` and `optimizer` are not passed back out,
  so any state `optimizer.update` changes must persist through their refs.
  """
  graphdef, params, rest = nnx.split(model, nnx.Param, ...)

  def loss_fn(param_arrays):
    merged = nnx.merge(graphdef, param_arrays, rest)
    residual = merged(x) - y
    return (residual ** 2).mean()

  # jax.grad differentiates plain arrays, so freeze the Param ArrayRefs first
  param_arrays = nnx.to_arrays(params)
  loss, grads = jax.value_and_grad(loss_fn)(param_arrays)
  optimizer.update(model, grads)
  return loss

train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))

Array(1.000178, dtype=float32)

### Scan Over Layers

In [9]:
@nnx.vmap
def create_stack(rngs):
  """Build one Block per slice of ``rngs`` along the vmapped axis.

  din == dout == 2, so a layer's output can be fed to the next layer as input.
  """
  return Block(2, 64, 2, rngs=rngs)

with nnx.use_refs(True):
  # fork(split=8) yields 8 key slices -> a stack of 8 Blocks
  block_stack = create_stack(nnx.Rngs(0).fork(split=8))

def scan_fn(x, block):
  """Scan body: push the carry through one layer; no per-step output."""
  return block(x), None

x = jax.random.uniform(jax.random.key(0), (3, 2))
y, _ = jax.lax.scan(scan_fn, x, block_stack)

print("y = ", y)

y =  [[ 0.82840395 -0.25364894]
 [ 4.9552917   4.93638   ]
 [-7.6721525  -3.4668717 ]]


## Limitations

### Array Ref Outputs

In [10]:
@jax.jit
def create_model(rngs):
  """Build a Block inside plain ``jax.jit``."""
  return Block(2, 64, 3, rngs=rngs)

# With refs enabled, the Variables created while tracing hold Array Refs, which
# jax.jit refuses to return as outputs. Catch and print the error so the notebook
# keeps running; the exception itself is what this cell demonstrates.
try:
  with nnx.use_refs(True):
    model = create_model(nnx.Rngs(0))
except Exception as e:
  print("Error:", e)  # plain string: the f-prefix had no placeholders

Error: function create_model at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1421484665.py:1 traced for jit returned a mutable array reference of type Ref{float32[64]} at output tree path result.bn.bias.value, but mutable array references cannot be returned.

The returned mutable array was created on line /Users/cgarciae/repos/flax/flax/nnx/variablelib.py:250:17 (Variable.__init__).


In [11]:
# Build under jax.jit with plain arrays, so the returned values are not refs...
with nnx.use_refs(False):
  model = create_model(nnx.Rngs(0))

# ...then convert the Variables to Array Refs outside of jit.
model = nnx.to_refs(model)

print("model.linear =", model.linear)

model.linear = Linear( # Param: 192 (768 B)
  bias=Param( # 64 (256 B)
    value=ArrayRef(shape=(64,), dtype=dtype('float32'))
  ),
  kernel=Param( # 128 (512 B)
    value=ArrayRef(shape=(2, 64), ...

In [12]:
@nnx.jit
def create_model(rngs):
  """Build a Block under ``nnx.jit``; unlike plain ``jax.jit`` this succeeds with refs enabled."""
  return Block(2, 64, 3, rngs=rngs)

with nnx.use_refs(True):
  model = create_model(nnx.Rngs(0))

print("model.linear =", model.linear)

model.linear = Linear( # Param: 192 (768 B)
  bias=Param( # 64 (256 B)
    value=ArrayRef(shape=(64,), dtype=dtype('float32'))
  ),
  kernel=Param( # 128 (512 B)
    value=ArrayRef(shape=(2, 64), ...

### Reference Sharing (aliasing)

In [13]:
def get_error(f, *args):
  try:
    return f(*args)
  except Exception as e:
    return f"{type(e).__name__}: {e}"
  
x = jax.new_ref(jnp.array(0))

@jax.jit
def f(a, b):
  ...

print(get_error(f, x, x))

ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing f at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1563421490.py:9 for jit the mutable array reference of type Ref{int32[]} appeared at both a and b.


In [14]:
class SharedVariables(nnx.Pytree):
  """Pytree whose attribute ``c`` is the very same Variable object as ``a``."""
  def __init__(self):
    self.a = nnx.Variable(jnp.array(0))
    self.b = nnx.Variable(jnp.array(1))
    self.c = self.a  # alias: same Variable, hence the same ref, at two paths

class SharedModules(nnx.Pytree):
  """Pytree whose attribute ``f`` is the very same Linear module as ``d``."""
  def __init__(self):
    self.d = Linear(1, 1, rngs=nnx.Rngs(0))
    self.e = Linear(1, 1, rngs=nnx.Rngs(0))
    self.f = self.d  # alias: every Param of d also appears under f

@jax.jit
def g(pytree):
  """No-op; tracing alone is enough to trigger jit's duplicate-ref check."""
  ...

with nnx.use_refs(True):
  shared_variables = SharedVariables()
  shared_modules = SharedModules()

print("SharedVariables", get_error(g, shared_variables))
print("SharedModules", get_error(g, shared_modules))

SharedVariables ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing g at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1828746469.py:13 for jit the mutable array reference of type Ref{int32[]} appeared at both pytree.a.value and pytree.c.value.
SharedModules ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing g at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1828746469.py:13 for jit the mutable array reference of type Ref{float32[1]} appeared at both pytree.d.bias.value and pytree.f.bias.value.


In [15]:
# nnx.find_duplicates reports groups of paths that reach the same object, so
# aliasing can be detected before jax.jit rejects it. The result is treated as
# falsy when nothing is shared.
duplicates = nnx.find_duplicates(shared_variables)
if duplicates:
  print("shared variables duplicates:", duplicates)

duplicates = nnx.find_duplicates(shared_modules)
if duplicates:
  print("shared modules duplicates:  ", duplicates)

shared variables duplicates: [[('a',), ('c',)]]
shared modules duplicates:   [[('d',), ('f',)]]


In [16]:
@jax.jit
def h(graphdef, state):
  obj = nnx.merge(graphdef, state)
  obj.a[...] += 10

graphdef, state = nnx.split(shared_variables)
print(state) # split deduplicates the state

h(graphdef, state)

print("updated", shared_variables)

State({
  'a': Variable( # 1 (4 B)
    value=ArrayRef(0, dtype=int32, weak_type=True)
  ),
  'b': Variable( # 1 (4 B)
    value=ArrayRef(1, dtype=int32, weak_type=True)
  )
})
updated SharedVariables( # Variable: 2 (8 B)
  a=Variable( # 1 (4 B)
    value=ArrayRef(10, dtype=int32)
  ),
  [38;2;156;220;254mb[