Use jax.tree_util.tree_map instead of deprecated jax.tree_map.
PiperOrigin-RevId: 609937836
IvyZX authored and Flax Authors committed Feb 26, 2024
1 parent 83e7466 commit 85eb0de
Showing 39 changed files with 450 additions and 373 deletions.
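The change itself is mechanical: jax.tree_map has been deprecated in favor of jax.tree_util.tree_map (newer JAX releases also expose a jax.tree.map alias), with the same call semantics. A minimal sketch of the substitution on a toy pytree (the names batch and arrays are made up for illustration and do not come from the files below):

import jax
import jax.numpy as jnp

batch = {'inputs': [1, 2, 3], 'labels': (4, 5)}  # toy pytree

# Before (deprecated):
#   arrays = jax.tree_map(jnp.asarray, batch)
# After:
arrays = jax.tree_util.tree_map(jnp.asarray, batch)
print(jax.tree_util.tree_map(jnp.shape, arrays))  # same nesting, with leaf shapes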
6 changes: 3 additions & 3 deletions examples/lm1b/train.py
@@ -515,11 +515,11 @@ def encode_strings(strs, max_len):
predict_step,
in_axes=(
0,
jax.tree_map(lambda x: None, state.params),
jax.tree_util.tree_map(lambda x: None, state.params),
0,
None,
None,
jax.tree_map(lambda x: None, predict_config),
jax.tree_util.tree_map(lambda x: None, predict_config),
None,
None,
),
@@ -558,7 +558,7 @@ def encode_strings(strs, max_len):
# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation("train", step_num=step):
batch = next(train_iter)
batch = jax.tree_map(lambda x: jnp.array(x), batch)
batch = jax.tree_util.tree_map(lambda x: jnp.array(x), batch)
state, metrics = jit_train_step(
state, batch, train_config, learning_rate_fn, 0.0, dropout_rngs
)
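A note on the in_axes change in the first hunk above: jax.tree_util.tree_map(lambda x: None, state.params) builds a pytree of None with the same structure as the params, which tells jax.vmap to broadcast every parameter leaf while mapping axis 0 of the batched arguments. A small self-contained sketch of the same pattern (apply, params and xs are made-up names, not the example's real model):

import jax
import jax.numpy as jnp

def apply(x, params):
  return x @ params['w'] + params['b']

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
xs = jnp.ones((8, 3))  # a batch of 8 inputs

# Map axis 0 of xs; broadcast every leaf of params (axis None per leaf).
in_axes = (0, jax.tree_util.tree_map(lambda _: None, params))
ys = jax.vmap(apply, in_axes=in_axes)(xs, params)
print(ys.shape)  # (8, 2)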
8 changes: 6 additions & 2 deletions flax/core/frozen_dict.py
@@ -251,7 +251,9 @@ def copy(
if isinstance(x, FrozenDict):
return x.copy(add_or_replace)
elif isinstance(x, dict):
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
new_dict = jax.tree_util.tree_map(
lambda x: x, x
) # make a deep copy of dict x
new_dict.update(add_or_replace)
return new_dict
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
@@ -280,7 +282,9 @@ def pop(
if isinstance(x, FrozenDict):
return x.pop(key)
elif isinstance(x, dict):
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
new_dict = jax.tree_util.tree_map(
lambda x: x, x
) # make a deep copy of dict x
value = new_dict.pop(key)
return new_dict, value
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
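Both hunks above rely on the identity-map idiom: jax.tree_util.tree_map(lambda x: x, x) rebuilds every nested dict, so the subsequent update/pop does not mutate the caller's dict (the leaves themselves are reused rather than copied, which is fine for immutable JAX arrays). A minimal illustration with a toy dict:

import jax

original = {'a': {'b': 1, 'c': 2}}
copy = jax.tree_util.tree_map(lambda x: x, original)  # rebuilds the nested dict containers
copy['a']['b'] = 99
print(original['a']['b'])  # 1 -- the original nesting is untouched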
38 changes: 23 additions & 15 deletions flax/core/meta.py
@@ -25,11 +25,10 @@
import functools
from typing import Any, Callable, Dict, Generic, Optional, TypeVar

import jax
from jax.experimental import maps

from flax import errors, struct
from flax.typing import LogicalNames
import jax
from jax.experimental import maps

A = TypeVar('A')
B = TypeVar('B')
@@ -77,6 +76,7 @@ def replace_boxed(self, val: B) -> 'AxisMetadata[B]':
Args:
val: The new value to be boxed by this AxisMetadata wrapper
Returns:
A new instance of the same type as self with `val` as the new ``unbox``
content
@@ -85,7 +85,7 @@ def replace_boxed(self, val: B) -> 'AxisMetadata[B]':

@abc.abstractmethod
def add_axis(
self: TAxisMetadata, index: int, params: Dict[Any, Any]
self: TAxisMetadata, index: int, params: Dict[Any, Any]
) -> TAxisMetadata:
"""Adds a new axis to the axis metadata.
@@ -98,6 +98,7 @@ def add_axis(
that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The
user passes this dictionary as the `metadata_param` argument to the
transformation.
Returns:
A new instance of the same type as self and with the same ``unbox``
content with updated axis metadata.
@@ -106,7 +107,7 @@ def add_axis(

@abc.abstractmethod
def remove_axis(
self: TAxisMetadata, index: int, params: Dict[Any, Any]
self: TAxisMetadata, index: int, params: Dict[Any, Any]
) -> TAxisMetadata:
"""Removes an axis from the axis metadata.
@@ -116,9 +117,10 @@ def remove_axis(
Args:
index: The position of the axis that is to be removed
params: An arbitrary dictionary of parameters passed by the transformation
that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The
user passes this dictionary as the `metadata_param` argument to the
that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The user
passes this dictionary as the `metadata_param` argument to the
transformation.
Returns:
A new instance of the same type as self and with the same ``unbox``
content with updated axis metadata.
@@ -167,7 +169,9 @@ def inner_update(c, v):
else:
return v

return jax.tree_util.tree_map(inner_update, tree, updates, is_leaf=is_axis_metadata)
return jax.tree_util.tree_map(
inner_update, tree, updates, is_leaf=is_axis_metadata
)


PARTITION_NAME = 'partition_name'
@@ -233,13 +237,12 @@ def body(mdl, c):
body, variable_axes={"params": 0}, split_rngs={"params": 0}, length=8,
metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
return c
"""

value: Any
names: LogicalNames = struct.field(pytree_node=False)
mesh: Optional[jax.sharding.Mesh] = struct.field(
default=None, pytree_node=False
default=None, pytree_node=False
)

def unbox(self, apply_constraint=True) -> A:
@@ -285,9 +288,9 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:


def with_partitioning(
fn: Callable[..., Any],
names: LogicalNames,
mesh: Optional[jax.sharding.Mesh] = None,
fn: Callable[..., Any],
names: LogicalNames,
mesh: Optional[jax.sharding.Mesh] = None,
) -> Callable[..., Partitioned[Any]]:
"""Wraps a function's return value with Partitioned.
Expand All @@ -303,6 +306,7 @@ def with_partitioning(
names: The logical axis passed to ``Partitioned``.
mesh: The mesh to use for the partitioning. If None, the global mesh
resource is used if available.
Returns:
A function wrapping ``fn`` that will return an instance of ``Partitioned``.
"""
@@ -326,10 +330,14 @@ def f(x):
else:
return None

return jax.tree_util.tree_map(f, tree, is_leaf=lambda x: isinstance(x, Partitioned))
return jax.tree_util.tree_map(
f, tree, is_leaf=lambda x: isinstance(x, Partitioned)
)


def get_sharding(tree: Any, mesh: jax.sharding.Mesh) -> Any:
"""Extracts a jax.sharding tree from a PyTree containing ``Partitioned`` values and a mesh."""
pspec_tree = get_partition_spec(tree)
return jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), pspec_tree)
return jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(mesh, x), pspec_tree
)
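get_sharding above maps a pytree of PartitionSpecs into jax.sharding.NamedSharding objects for a given mesh. A rough standalone sketch of that last step, assuming a 1D mesh over whatever devices are available and a made-up pytree of specs:

import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('data',))
pspecs = {
    'kernel': jax.sharding.PartitionSpec('data'),
    'bias': jax.sharding.PartitionSpec(),  # replicated
}
# PartitionSpec instances are pytree leaves, so tree_map applies per spec.
shardings = jax.tree_util.tree_map(
    lambda spec: jax.sharding.NamedSharding(mesh, spec), pspecs
)
print(shardings)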
42 changes: 21 additions & 21 deletions flax/experimental/nnx/examples/lm1b/train.py
@@ -534,27 +534,27 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
# Since the inputs and rngkey args for predict_step will be batched,
# we must vmap them, otherwise the global arrays will be seen in each device
jit_pred_step = jax.jit(
jax.vmap(
predict_step,
in_axes=(
0,
jax.tree_map(lambda x: None, state.params),
0,
None,
None,
None,
None,
None,
None,
jax.vmap(
predict_step,
in_axes=(
0,
jax.tree_util.tree_map(lambda x: None, state.params),
0,
None,
None,
None,
None,
None,
None,
),
),
),
in_shardings=(
data_sharding,
state_sharding.params,
data_sharding,
), # type: ignore
out_shardings=data_sharding, # type: ignore
static_argnums=tuple(range(3, 9)),
in_shardings=(
data_sharding,
state_sharding.params,
data_sharding,
), # type: ignore
out_shardings=data_sharding, # type: ignore
static_argnums=tuple(range(3, 9)),
)

# Main Train Loop
@@ -582,7 +582,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
# Shard data to devices and do a training step.
with jax.profiler.StepTraceAnnotation('train', step_num=step):
batch = next(train_iter)
batch = jax.tree_map(lambda x: jnp.asarray(x), batch)
batch = jax.tree_util.tree_map(lambda x: jnp.asarray(x), batch)
state, metrics = jit_train_step(
state, batch, learning_rate_fn, 0.0, dropout_rngs
)
2 changes: 1 addition & 1 deletion flax/experimental/nnx/examples/lm1b/utils.py
@@ -161,7 +161,7 @@ def setup_initial_state(
state = TrainState.create(
apply_fn=static.apply, params=params, tx=tx, graphdef=static
)
state = jax.tree_map(_to_array, state)
state = jax.tree_util.tree_map(_to_array, state)
state_spec = nnx.get_partition_spec(state)
state = jax.lax.with_sharding_constraint(state, state_spec)

(Another changed file; filename not shown in this view)
@@ -74,7 +74,7 @@ def loss_fn(params):

grad, counts = jax.grad(loss_fn, has_aux=True)(params)
# |-------- sgd ---------|
params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad)
params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grad)

return params, counts
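The SGD line above uses tree_map over two pytrees at once: params and grad share the same structure, so the lambda receives matching leaves pairwise. A tiny sketch with made-up values:

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2,)), 'b': jnp.zeros(())}
grads = {'w': jnp.full((2,), 0.5), 'b': jnp.array(1.0)}

# Element-wise SGD step over two structurally identical pytrees.
params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grads)
print(params)  # {'b': -0.1, 'w': [0.95, 0.95]}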

(Another changed file; filename not shown in this view)
@@ -72,7 +72,9 @@ def loss_fn(model: MLP):
grad: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model)
# sgd update
model.update(
jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grad)
jax.tree_util.tree_map(
lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grad
)
)

# no return!!!
(Another changed file; filename not shown in this view)
@@ -83,5 +83,5 @@ def scan_fn(
with nnx.flags(deterministic=False):
y = model(x, rngs=nnx.Rngs(dropout=1))

print(jax.tree_map(jnp.shape, model.get_state()))
print(jax.tree_util.tree_map(jnp.shape, model.get_state()))
print(y.shape)
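The print above uses tree_map(jnp.shape, ...) to display a state pytree as nested shapes, a handy debugging idiom. For instance, on a toy pytree (not the NNX model state from this example):

import jax
import jax.numpy as jnp

state = {'linear': {'kernel': jnp.zeros((3, 4)), 'bias': jnp.zeros((4,))}}
print(jax.tree_util.tree_map(jnp.shape, state))
# {'linear': {'bias': (4,), 'kernel': (3, 4)}}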
(Another changed file; filename not shown in this view)
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import jax

from flax.experimental import nnx
@@ -52,5 +51,10 @@ def __call__(self, x):
# split the parameters into trainable and non-trainable parameters
trainable_params, non_trainable, static = model.split(is_trainable, ...)

print('trainable_params =', jax.tree_map(jax.numpy.shape, trainable_params))
print('non_trainable = ', jax.tree_map(jax.numpy.shape, non_trainable))
print(
'trainable_params =',
jax.tree_util.tree_map(jax.numpy.shape, trainable_params),
)
print(
'non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable)
)
(Another changed file; filename not shown in this view)
@@ -166,7 +166,7 @@ def forward(state: nnx.TrainState[MLP], inputs: jax.Array) -> jax.Array:
state, loss = train_step(state, x_batch, y_batch)

metrics = eval_step(state, X_test, Y_test)
metrics = jax.tree_map(lambda x: x.item(), metrics)
metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
print(f'Epoch {epoch} - {metrics}')

# %%
@@ -238,7 +238,7 @@ def optimize(
tx = optax.adam(1e-3)
opt_state = tx.init(q_hparams)

print(jax.tree_map(lambda x: x.shape, q_hparams))
print(jax.tree_util.tree_map(lambda x: x.shape, q_hparams))

@jax.jit
def optimization_step(
8 changes: 4 additions & 4 deletions flax/experimental/nnx/ideas/shape_inference.py
@@ -150,12 +150,12 @@ def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array:
# eager
m1 = Linear(din=32, dout=10, rngs=nnx.Rngs(params=0))
y = m1(x=jnp.ones((1, 32)))
print(jax.tree_map(jnp.shape, m1.get_state()))
print(jax.tree_util.tree_map(jnp.shape, m1.get_state()))

# lazy
m2 = Linear(dout=10)
y = m2.init(x=jnp.ones((1, 32)), rngs=nnx.Rngs(params=0))
print(jax.tree_map(jnp.shape, m2.get_state()))
print(jax.tree_util.tree_map(jnp.shape, m2.get_state()))

# usage
y1 = m1(x=jnp.ones((1, 32)))
@@ -199,12 +199,12 @@ def __call__(self, x: jax.Array, _, *, train: bool, rngs: nnx.Rngs):
mlp = MLP(din=10, dout=10, rngs=nnx.Rngs(params=0))
y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1))
print(f'{y.shape=}')
print('state =', jax.tree_map(jnp.shape, mlp.get_state()))
print('state =', jax.tree_util.tree_map(jnp.shape, mlp.get_state()))
print()

# lazy
mlp = MLP(dout=10)
mlp.init(jnp.ones((1, 10)), None, train=False, rngs=nnx.Rngs(params=0))
y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1))
print(f'{y.shape=}')
print('state =', jax.tree_map(jnp.shape, mlp.get_state()))
print('state =', jax.tree_util.tree_map(jnp.shape, mlp.get_state()))
6 changes: 3 additions & 3 deletions flax/experimental/nnx/nnx/nn/linear.py
@@ -115,12 +115,12 @@ class LinearGeneral(Module):
>>> # output features (4, 5)
>>> layer = nn.LinearGeneral(features=(4, 5))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_map(jnp.shape, params)
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
>>> # apply transformation on the the second and last axes
>>> layer = nn.LinearGeneral(features=(4, 5), axis=(1, -1))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
>>> jax.tree_map(jnp.shape, params)
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
Attributes:
@@ -193,7 +193,7 @@ def kernel_init_wrap(rng, shape, dtype) -> jax.Array:
* np.prod(shape[n_batch_axis : n_in_features + n_batch_axis]),
np.prod(shape[-n_out_features:]),
)
flat_shape = jax.tree_map(int, flat_shape)
flat_shape = jax.tree_util.tree_map(int, flat_shape)
kernel = self.kernel_init(rng, flat_shape, dtype)
if isinstance(kernel, variables.VariableMetadata):
kernel.value = jnp.reshape(kernel.value, shape)
(The remaining changed files are not shown in this view.)
