In [2]:
shared = dict(a=1, b=2)
# Both top-level keys hold the *same* dict object, not two copies.
pytree = dict(x=shared, y=shared)

# in regular python: writing through "x" is visible through "y",
# because both paths lead to one shared object.
pytree["x"]["a"] = 100
assert pytree["y"]["a"] == 100

pytree

{'x': {'a': 100, 'b': 2}, 'y': {'a': 100, 'b': 2}}

In [3]:
import jax


@jax.jit
def f(pytree):
    """Write into the pytree inside jit and return it.

    As the output of this cell shows, the write to ``pytree["x"]`` is not
    reflected in ``pytree["y"]`` afterwards: the two branches no longer
    alias the same dict once the tree has gone through ``jax.jit``.
    """
    pytree["x"]["b"] = 200
    return pytree


# Bind the result to a new name so the previous cell's (aliased) `pytree`
# stays intact and this cell can be re-run with the same result.
pytree_jit = f(pytree)

# Explicit check instead of `assert` inside try/except: assert statements are
# stripped under `python -O`, which would silently skip this message.
if pytree_jit["y"]["b"] != 200:
    print("References lost")

pytree_jit

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


References lost


{'x': {'a': Array(100, dtype=int32, weak_type=True),
  'b': Array(200, dtype=int32, weak_type=True)},
 'y': {'a': Array(100, dtype=int32, weak_type=True),
  'b': Array(2, dtype=int32, weak_type=True)}}

In [4]:
import refx

# Same aliasing layout as the plain-dict example, but each leaf is a
# refx.Ref whose contents are read and written through `.value`.
shared = {"a": refx.Ref(1), "b": refx.Ref(2)}
pytree = {"x": shared, "y": shared}

# in regular python: a write through the "x" path is seen via the "y" path.
pytree["x"]["a"].value = 100
assert pytree["y"]["a"].value == 100

In general, mutating objects that live outside a JAX transformation is very dangerous, because it makes it very easy to leak tracers. For example, with a plain mutable container:

In [11]:
from dataclasses import dataclass


@dataclass
class FakeRef:
    """A plain mutable box with no protection against tracer leaks."""

    value: object


ref = FakeRef(1)


@jax.jit
def f():
    x = jax.numpy.empty(1)
    # Side effect: stash an intermediate value in state that outlives the
    # jitted call (during tracing, `x` is a tracer).
    ref.value = x
    return x


x = f()

# `ref.value` now holds the leaked tracer; using it outside the trace fails.
# Catch only that error: `BaseException` would also swallow unrelated
# failures and even KeyboardInterrupt/SystemExit.
try:
    ref.value + 1
except jax.errors.UnexpectedTracerError as e:
    print("\n".join(str(e).splitlines()[:2]))

Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[1] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.


In [13]:
from dataclasses import dataclass
ref = refx.Ref(1)


@jax.jit
def f():
    # Try to smuggle a traced value out of the jitted function by writing
    # it into a Ref that was created outside the trace.
    traced = jax.numpy.empty(1)
    ref.value = traced
    return traced


# Unlike FakeRef, refx refuses the write (raising ValueError) instead of
# silently leaking the tracer.
try:
    f()
except ValueError as err:
    print(err)

Cannot mutate ref from different trace level


So how do we actually use refs with JAX transformations? Passing a pytree of `Ref`s directly to a jitted function fails:

In [14]:
shared = {
    "a": refx.Ref(1),
    "b": refx.Ref(2),
}
pytree = {
    "x": shared,
    "y": shared,
}


@jax.jit
def f(pytree):
    return pytree


# `Ref` leaves have no dtype, so jit fails while abstracting the arguments
# and JAX raises a TypeError. Catch exactly that: `BaseException` would also
# swallow unrelated failures and KeyboardInterrupt/SystemExit.
try:
    f(pytree)
except TypeError as e:
    print(e)

Cannot interpret value of type <class 'refx.ref.Ref'> as an abstract array; it does not have a dtype attribute


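We first need to dereference the pytree with `refx.deref`. The first occurrence of each `Ref` becomes a `Value` holding its contents and an index. Later occurrences of the same `Ref` become an `Index` pointing back to that index: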

In [16]:
r1, r2 = refx.Ref(1), refx.Ref(2)

# r1 appears twice and r2 appears twice (once in the list, once under "y"),
# so the dereferenced form has to encode that sharing.
pytree = dict(
    x=[r1, r1, r2],
    y=r2,
    z=10,
)

refx.deref(pytree)

{'x': [Value(_value=1, index=0, ref_type=<class 'refx.ref.Ref'>),
  Index(index=0, ref_type=<class 'refx.ref.Ref'>),
  Value(_value=2, index=1, ref_type=<class 'refx.ref.Ref'>)],
 'y': Index(index=1, ref_type=<class 'refx.ref.Ref'>),
 'z': 10}

In [17]:
r1 = refx.Ref(1)
r2 = refx.Ref(2)
pytree = refx.deref({"x": [r1, r1, r2], "y": r2, "z": 10})


@jax.jit
def f(pytree):
    """Identity under jit: the dereferenced tree passes through unchanged."""
    return pytree


f(pytree)

{'x': [Value(_value=Array(1, dtype=int32, weak_type=True), index=0, ref_type=<class 'refx.ref.Ref'>),
  Index(index=0, ref_type=<class 'refx.ref.Ref'>),
  Value(_value=Array(2, dtype=int32, weak_type=True), index=1, ref_type=<class 'refx.ref.Ref'>)],
 'y': Index(index=1, ref_type=<class 'refx.ref.Ref'>),
 'z': Array(10, dtype=int32, weak_type=True)}

In [20]:
r1 = refx.Ref(1)
r2 = refx.Ref(2)

pytree = {
    "x": [r1, r1, r2],
    "y": r2,
    "z": 10,
}

pytree = refx.deref(pytree)


@jax.jit
def f(pytree):
    # x[1] is the second occurrence of r1, so after deref it is an `Index`
    # marker rather than a `Value`; reading `.value` from it is expected to fail.
    pytree['x'][1].value
    return pytree


# The exception type refx raises is not shown here, so catch `Exception`
# rather than `BaseException`, which would also swallow KeyboardInterrupt.
try:
    f(pytree)
except Exception as e:
    print(e)

Cannot get value of Index


In [22]:
r1 = refx.Ref(1)
r2 = refx.Ref(2)
pytree = refx.deref({"x": [r1, r1, r2], "y": r2, "z": 10})


@jax.jit
def f(pytree):
    """Inside jit: reref, mutate through an alias, then deref for the output."""
    # Turn the Value/Index markers back into Refs so `.value` can be written.
    refs = refx.reref(pytree)
    # x[1] is the second occurrence of r1, i.e. an alias of x[0].
    refs['x'][1].value = 100
    # Deref again before returning, so the output has no Ref leaves.
    return refx.deref(refs)


pytree = refx.reref(f(pytree))

# The write through x[1] is visible through x[0]: the sharing survived jit.
assert pytree['x'][0].value == 100