48 changes: 24 additions & 24 deletions flax/experimental/nnx/examples/lm1b/models_test.py
@@ -19,7 +19,7 @@
from typing import Any

# add project_root to import lm1b Linen model
project_root = str(Path(__file__).parents[6])
project_root = str(Path(__file__).parents[5])
sys.path.append(project_root)
from examples.lm1b.models import TransformerLM as TransformerLinen

@@ -98,66 +98,66 @@ def copy_var(nnx_name, linen_name):
flat_params_linen[linen_name].names
)

copy_var('decoder/output_embed/embedding', 'decoder/Embed_0/embedding')
copy_var(('decoder','output_embed','embedding'), 'decoder/Embed_0/embedding')
copy_var(
'decoder/encoderdecoder_norm/bias', 'decoder/encoderdecoder_norm/bias'
('decoder', 'encoderdecoder_norm', 'bias'), 'decoder/encoderdecoder_norm/bias'
)
copy_var(
'decoder/encoderdecoder_norm/scale', 'decoder/encoderdecoder_norm/scale'
('decoder', 'encoderdecoder_norm', 'scale'), 'decoder/encoderdecoder_norm/scale'
)

for idx in range(config.num_layers):
copy_var(
f'decoder/encoderdecoderblock_{idx}/ln1/bias',
('decoder', f'encoderdecoderblock_{idx}', 'ln1', 'bias'),
f'decoder/encoderdecoderblock_{idx}/LayerNorm_0/bias',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/ln1/scale',
('decoder', f'encoderdecoderblock_{idx}', 'ln1', 'scale'),
f'decoder/encoderdecoderblock_{idx}/LayerNorm_0/scale',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/ln2/bias',
('decoder', f'encoderdecoderblock_{idx}', 'ln2', 'bias'),
f'decoder/encoderdecoderblock_{idx}/LayerNorm_1/bias',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/ln2/scale',
('decoder', f'encoderdecoderblock_{idx}', 'ln2', 'scale'),
f'decoder/encoderdecoderblock_{idx}/LayerNorm_1/scale',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/query/kernel',
('decoder', f'encoderdecoderblock_{idx}', 'attention', 'query', 'kernel'),
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/query/kernel',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/key/kernel',
('decoder', f'encoderdecoderblock_{idx}', 'attention', 'key', 'kernel'),
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/key/kernel',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/value/kernel',
('decoder', f'encoderdecoderblock_{idx}', 'attention', 'value', 'kernel'),
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/value/kernel',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/out/kernel',
('decoder', f'encoderdecoderblock_{idx}', 'attention', 'out', 'kernel'),
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/out/kernel',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/mlp/linear1/kernel',
('decoder', f'encoderdecoderblock_{idx}', 'mlp', 'linear1', 'kernel'),
f'decoder/encoderdecoderblock_{idx}/MlpBlock_0/Dense_0/kernel',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/mlp/linear1/bias',
('decoder', f'encoderdecoderblock_{idx}', 'mlp', 'linear1', 'bias'),
f'decoder/encoderdecoderblock_{idx}/MlpBlock_0/Dense_0/bias',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/mlp/linear2/kernel',
('decoder', f'encoderdecoderblock_{idx}', 'mlp', 'linear2', 'kernel'),
f'decoder/encoderdecoderblock_{idx}/MlpBlock_0/Dense_1/kernel',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/mlp/linear2/bias',
('decoder', f'encoderdecoderblock_{idx}', 'mlp', 'linear2', 'bias'),
f'decoder/encoderdecoderblock_{idx}/MlpBlock_0/Dense_1/bias',
)

copy_var('decoder/logitdense/kernel', 'decoder/logitdense/kernel')
copy_var('decoder/logitdense/bias', 'decoder/logitdense/bias')
copy_var(('decoder', 'logitdense', 'kernel'), 'decoder/logitdense/kernel')
copy_var(('decoder', 'logitdense', 'bias'), 'decoder/logitdense/bias')

def transfer_cache(
self,
@@ -177,20 +177,20 @@ def copy_var(nnx_name, linen_name):

for idx in range(config.num_layers):
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/cache_index',
('decoder', f'encoderdecoderblock_{idx}', 'attention', 'cache_index'),
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/cache_index',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/cached_key',
('decoder', f'encoderdecoderblock_{idx}', 'attention', 'cached_key'),
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/cached_key',
)
copy_var(
f'decoder/encoderdecoderblock_{idx}/attention/cached_value',
('decoder', f'encoderdecoderblock_{idx}', 'attention', 'cached_value'),
f'decoder/encoderdecoderblock_{idx}/MultiHeadDotProductAttention_0/cached_value',
)

copy_var(
'decoder/posembed_output/cache_index',
('decoder', 'posembed_output', 'cache_index'),
'decoder/posembed_output/cache_index',
)

@@ -207,7 +207,7 @@ def test_forward_eval(self):
)

model_nnx = TransformerLM.create_abstract(config, rngs=nnx.Rngs(0))
params_nnx, _ = model_nnx.split(nnx.Param)
_, params_nnx = model_nnx.split(nnx.Param)

model_linen = TransformerLinen(config)

@@ -246,7 +246,7 @@ def test_forward_decode(self):
input_shape = (batch_size, config.max_len, config.emb_dim)
m.init_cache(input_shape, dtype=config.dtype)

params_nnx, cache_nnx, _ = model_nnx.split(nnx.Param, nnx.Cache)
_, params_nnx, cache_nnx = model_nnx.split(nnx.Param, nnx.Cache)

model_linen = TransformerLinen(config)

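The test changes above hinge on two API shifts in nnx: Module.split now returns the static definition first, followed by the filtered states, and state paths are tuples of keys rather than '/'-joined strings. A minimal, hedged sketch of both conventions, not part of the diff; the nnx.Linear stand-in and the printed output are illustrative only:

import jax, jax.numpy as jnp
from flax.experimental import nnx

model = nnx.Linear(2, 4, rngs=nnx.Rngs(0))   # stand-in for the TransformerLM under test

# split now returns the static (graph) definition first, then the requested state(s)
static, params = model.split(nnx.Param)

# state entries are keyed by attribute name; nested keys combine into tuple paths
# such as ('decoder', 'output_embed', 'embedding') in the transformer test above
print(jax.tree_util.tree_map(jnp.shape, params))   # roughly State({'bias': (4,), 'kernel': (2, 4)})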
4 changes: 3 additions & 1 deletion flax/experimental/nnx/examples/lm1b/train.py
@@ -255,7 +255,9 @@ def predict_step(
input_shape = (inputs.shape[0], max_decode_len, config.emb_dim)
m.init_cache(input_shape, dtype=config.dtype)

cache = module.extract(nnx.Cache)
# passed in static argument does not contain cache
# need to split module to get updated static that contains cache
static, params, cache = module.split(nnx.Param, nnx.Cache)

def tokens_ids_to_logits(flat_ids, cache: nnx.State):
"""Token slice to logits from decoder model."""
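The comment added above is the crux of this change: the static definition passed into predict_step predates init_cache, so the module has to be split again once the cache variables exist. A hedged, self-contained sketch of that pattern using a single attention layer; the layer sizes and shapes here are illustrative, not taken from train.py:

import jax.numpy as jnp
from flax.experimental import nnx

layer = nnx.MultiHeadAttention(
    num_heads=2, in_features=4, qkv_features=4, decode=True, rngs=nnx.Rngs(0)
)
layer.init_cache((1, 8, 4), dtype=jnp.float32)   # creates nnx.Cache variables in place

# re-split so the static definition reflects the cache created by init_cache
static, params, cache = layer.split(nnx.Param, nnx.Cache)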
@@ -46,11 +46,11 @@ def __call__(self, x):
# create a filter to select all the parameters that are not part of the
# backbone, i.e. the classifier parameters
is_trainable = lambda path, node: (
path.startswith('backbone') and isinstance(node, nnx.Param)
(path[0] == 'backbone') and isinstance(node, nnx.Param)
)

# split the parameters into trainable and non-trainable parameters
trainable_params, non_trainable, static = model.split(is_trainable, ...)
static, trainable_params, non_trainable = model.split(is_trainable, ...)

print('trainable_params =', jax.tree_util.tree_map(jax.numpy.shape, trainable_params))
print('non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable))
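Since filter functions now receive the path as a tuple of keys, prefix checks compare path[0] instead of calling str.startswith. A small hedged sketch with an illustrative two-part model; the Classifier, backbone, and head names are made up for this example and are not from the changed file:

import jax
from flax.experimental import nnx

class Classifier(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.backbone = nnx.Linear(4, 8, rngs=rngs)
    self.head = nnx.Linear(8, 2, rngs=rngs)

  def __call__(self, x):
    return self.head(self.backbone(x))

model = Classifier(rngs=nnx.Rngs(0))

# tuple-based path filter: select Params whose first path element is 'backbone'
is_backbone = lambda path, node: path[0] == 'backbone' and isinstance(node, nnx.Param)
static, backbone_params, rest = model.split(is_backbone, ...)

print(jax.tree_util.tree_map(jax.numpy.shape, backbone_params))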
@@ -1,2 +1,2 @@
matplotlib>=3.7.1
datasets>=2.12.0"
datasets>=2.12.0
11 changes: 8 additions & 3 deletions flax/experimental/nnx/nnx/module.py
@@ -126,11 +126,11 @@ def partial_init(cls: type[M], state: State, *states: State) -> type[M]:
>>> from flax.experimental import nnx
...
>>> bias = jax.random.normal(jax.random.key(0), (4,))
>>> state = nnx.State({'bias': bias}) # in reality load it from a checkpoint
>>> state = nnx.State({'bias': nnx.Param(bias)}) # in reality load it from a checkpoint
>>> linear = nnx.Linear.partial_init(state)(2, 4, rngs=nnx.Rngs(1))
>>> y = linear(jnp.ones((1, 2)))
...
>>> assert jnp.allclose(linear.bias, bias)
>>> assert jnp.allclose(linear.bias.value, bias)
>>> assert y.shape == (1, 4)

Args:
@@ -332,7 +332,12 @@ def set_attributes(
``Filter``'s can be used to set the attributes of specific Modules::

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True, use_running_average=True)
>>> block.set_attributes(
... nnx.Dropout,
... deterministic=True,
... use_running_average=True,
... raise_if_not_found=False, # Don't raise an error if the attribute isn't found
... )
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
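For context on the raise_if_not_found=False added above: the call filters on nnx.Dropout, and use_running_average is never found on a Dropout module, so without the flag the call would presumably raise. A hedged sketch of the unfiltered variant; the Block definition is reconstructed from the docstring and the constructor arguments are assumptions based on the current nnx layers:

from flax.experimental import nnx

class Block(nnx.Module):
  def __init__(self, din, dout, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
    self.dropout = nnx.Dropout(0.5)
    self.batch_norm = nnx.BatchNorm(dout, rngs=rngs)

block = Block(2, 5, rngs=nnx.Rngs(0))

# with no Filter, every sub-Module that has one of these attributes is updated,
# so both flags flip and raise_if_not_found is not needed here
block.set_attributes(deterministic=True, use_running_average=True)
assert block.dropout.deterministic and block.batch_norm.use_running_average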
57 changes: 16 additions & 41 deletions flax/experimental/nnx/nnx/nn/attention.py
@@ -221,51 +221,26 @@ class MultiHeadAttention(Module):

Example usage::

>>> import flax.linen as nn
>>> from flax.experimental import nnx
>>> import jax

>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> layer = nnx.MultiHeadAttention(
... num_heads=8, in_features=5, qkv_features=16, decode=False, rngs=nnx.Rngs(0)
... )
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)
>>> q, k, v = (
... jax.random.uniform(key1, shape),
... jax.random.uniform(key2, shape),
... jax.random.uniform(key3, shape),
... )

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
... num_heads=8,
... qkv_features=16,
... kernel_init=nn.initializers.ones,
... bias_init=nn.initializers.zeros,
... dropout_rate=0.5,
... deterministic=False,
... )
>>> class Module(nn.Module):
... attention_kwargs: dict
...
... @nn.compact
... def __call__(self, x, dropout_rng=None):
... out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
... out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
... return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
>>> out = layer(q, k, v)
>>> # equivalent to layer(inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer(q, k)
>>> # equivalent to layer(inputs_q=q, inputs_k=q) and layer(inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer(q)

Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
@@ -609,7 +584,7 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
... out_features=6,
... decode=True,
... rngs=rngs,
>>> )
... )

>>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized

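A short, hedged sketch of how the decode-mode layer above is meant to be driven once the cache is initialized: queries are fed one position at a time and the cached keys/values advance in place. Sizes here are illustrative, and the loop assumes the same single-step semantics as the Linen layer:

import jax.numpy as jnp
from flax.experimental import nnx

layer = nnx.MultiHeadAttention(
    num_heads=2, in_features=4, qkv_features=4, decode=True, rngs=nnx.Rngs(0)
)
layer.init_cache((1, 6, 4), dtype=jnp.float32)   # room for 6 decoding steps

for _ in range(6):
    step = jnp.ones((1, 1, 4))    # one query position per call
    out = layer(step)             # cache_index advances by one each call

assert out.shape == (1, 1, 4)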
31 changes: 15 additions & 16 deletions flax/experimental/nnx/nnx/nn/linear.py
@@ -109,21 +109,20 @@ class LinearGeneral(Module):

Example usage::

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # equivalent to `nn.Linear(features=4)`
>>> layer = nn.LinearGeneral(features=4)
>>> # output features (4, 5)
>>> layer = nn.LinearGeneral(features=(4, 5))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
>>> # apply transformation on the the second and last axes
>>> layer = nn.LinearGeneral(features=(4, 5), axis=(1, -1))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
>>> from flax.experimental import nnx

>>> rngs = nnx.Rngs(0)
>>> # equivalent to `nn.Linear(3, 4)`
>>> layer = nnx.LinearGeneral(3, 4, rngs=rngs)
>>> layer.kernel.value.shape
(3, 4)
>>> layer.bias.value.shape
(4,)
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(0, 1), rngs=rngs)
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)

Attributes:
features: int or tuple with number of output features.
@@ -366,7 +365,7 @@ class Einsum(Module):
>>> from flax.experimental import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Einsum('abc,cde->abde', (3, 4, 5), (5, 6, 7), rngs=nnx.Rngs(0))
>>> layer = nnx.Einsum('abc,cde->abde', (5, 6, 7), (6, 7), rngs=nnx.Rngs(0))
>>> assert layer.kernel.value.shape == (5, 6, 7)
>>> assert layer.bias.value.shape == (6, 7)
>>> out = layer(jnp.ones((3, 4, 5)))
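The rewritten docstrings above only show parameter shapes; a small follow-up sketch, not in the diff, applies the same layers to confirm the output shapes they imply:

import jax.numpy as jnp
from flax.experimental import nnx

layer = nnx.LinearGeneral(3, 4, rngs=nnx.Rngs(0))
assert layer(jnp.ones((1, 3))).shape == (1, 4)

einsum = nnx.Einsum('abc,cde->abde', (5, 6, 7), (6, 7), rngs=nnx.Rngs(0))
assert einsum(jnp.ones((3, 4, 5))).shape == (3, 4, 6, 7)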
2 changes: 1 addition & 1 deletion flax/experimental/nnx/nnx/training/metrics.py
@@ -91,7 +91,7 @@ class MultiMetric(Metric):

>>> metrics = nnx.MultiMetric(
... accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
>>> )
... )
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
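The fix above only repairs the doctest continuation prompt. For completeness, a hedged sketch of the full update/compute cycle the docstring describes; the logits, labels, and loss values are invented for illustration:

import jax.numpy as jnp
from flax.experimental import nnx

logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
labels = jnp.array([1, 0])
batch_loss = jnp.array(0.35)

metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
)
metrics.update(logits=logits, labels=labels, values=batch_loss)
print(metrics.compute())   # expected: accuracy 1.0, loss 0.35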
8 changes: 4 additions & 4 deletions flax/experimental/nnx/nnx/training/optimizer.py
@@ -65,11 +65,11 @@ class Optimizer(graph_utils.GraphNode):

>>> loss_fn = lambda model: ((model(x)-y)**2).mean()
>>> loss_fn(state.model)
1.7055722
Array(1.7055722, dtype=float32)
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads)
>>> loss_fn(state.model)
1.6925814
Array(1.6925814, dtype=float32)

Note that you can easily extend this class by subclassing it for storing
additional data (e.g. adding metrics).
@@ -90,10 +90,10 @@ class Optimizer(graph_utils.GraphNode):
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
1.6925814
Array(1.6925814, dtype=float32)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
1.68612
Array(1.68612, dtype=float32)

For more exotic usecases (e.g. multiple optimizers) it's probably best to
fork the class and modify it.
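A hedged, self-contained sketch of the train-step pattern the corrected doctest outputs refer to; the nnx.Optimizer(model, tx) constructor is an assumption inferred from how state.model and state.update are used above, not something shown in this diff:

import jax.numpy as jnp
import optax
from flax.experimental import nnx

x = jnp.ones((1, 2))
y = jnp.zeros((1, 3))

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
state = nnx.Optimizer(model, optax.adam(1e-3))   # assumed constructor signature

loss_fn = lambda model: ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
state.update(grads)   # applies the optax update to the model's Params in place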
3 changes: 3 additions & 0 deletions flax/experimental/nnx/nnx/transforms.py
@@ -617,6 +617,9 @@ def grad(

Example::

>>> from flax.experimental import nnx
>>> import jax, jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
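The docstring example above is cut off by the fold; a hedged completion of it, mirroring the nnx.grad usage shown in the Optimizer docstring elsewhere in this PR. The loss function here is illustrative:

import jax, jax.numpy as jnp
from flax.experimental import nnx

m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x = jnp.ones((1, 2))
y = jnp.ones((1, 3))

loss_fn = lambda m: ((m(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn, wrt=nnx.Param)(m)   # a State of gradients for the Params
print(jax.tree_util.tree_map(jnp.shape, grads))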
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -146,6 +146,8 @@ filterwarnings = [
"ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning",
# DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
"ignore:.*jax.tree_map is deprecated.*:DeprecationWarning",
# DeprecationWarning: the imp module is deprecated in favour of importlib and slated for removal in Python 3.12; see the module's documentation for alternative uses
"ignore:.*the imp module is deprecated in favour of importlib.*:DeprecationWarning",
]

[tool.coverage.report]
3 changes: 2 additions & 1 deletion tests/run_all_tests.sh
@@ -85,7 +85,8 @@ if $RUN_DOCTEST; then
pytest -n auto flax \
--doctest-modules \
--suppress-no-test-exit-code \
--ignore=flax/experimental/nnx
--ignore=flax/experimental/nnx/ideas \
--ignore=flax/experimental/nnx/examples
fi

# check that flax is running on editable mode