Bumps flax to 0.4.0 (#60)
* Bumps `flax` to `0.4.0`

Updates `flax` to the most recent version. This breaks the existing
implementation, in particular the way in which the RNG keys for dropout
are handled.
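
As a rough illustration (not code from this repo; the `Block` module and
variable names below are hypothetical), dropout keys are now supplied to
`init`/`apply` through flax's `rngs` dictionary rather than threaded
through the module call, which is the convention the updated tests
below follow:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Block(nn.Module):
        @nn.compact
        def __call__(self, x, training: bool):
            x = nn.Dense(16)(x)
            # When deterministic=False, Dropout draws its key from the
            # "dropout" RNG stream provided via `rngs`.
            x = nn.Dropout(0.5)(x, deterministic=not training)
            return x

    module = Block()
    x = jnp.ones((4, 8))
    params_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))

    # Provide both the "params" and "dropout" streams at initialization.
    variables = module.init({"params": params_key, "dropout": dropout_key}, x, True)

    # After the bump, the dropout key is passed via the `rngs` dict instead
    # of as a direct argument to the module call.
    y = module.apply(variables, x, True, rngs={"dropout": dropout_key})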

One of the dropout equivalence tests is disabled for now, as I am not
sure whether there is a way to directly set the value of `next_key`
within a treex module.

* [auto] black linting fixes

* fix test_pretrained_flax_module

Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>
ptigwe and cgarciae committed Feb 4, 2022
1 parent d83aca4 commit 4b9bdac
Showing 5 changed files with 51 additions and 19 deletions.
46 changes: 38 additions & 8 deletions poetry.lock

(Diff not rendered: poetry.lock is a generated file.)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ secondary = true

[tool.poetry.dependencies]
python = "^3.7"
-flax = "^0.3.4"
+flax = "^0.4.0"
PyYAML = "^5.4.1"
rich = "^10.7.0"
optax = "^0.0.9"
2 changes: 1 addition & 1 deletion tests/nn/test_dropout.py
@@ -60,7 +60,7 @@ def test_dropout_equivalence(
# split key same way tx.Dropout does internally
rng, _ = tx.iter_split(flax_key, 2)

-y_flax = flax_module.apply(variables, x, rng=rng)
+y_flax = flax_module.apply(variables, x, rngs={"dropout": rng})
y_treex = treex_module(x)

assert np.allclose(y_flax, y_treex)
18 changes: 10 additions & 8 deletions tests/nn/test_flax_module.py
@@ -29,10 +29,10 @@ def __call__(self, x, training):
def test_pretrained_flax_module(self):
class SomeModule(flax.linen.Module):
@flax.linen.compact
-def __call__(self, x, rng, training):
+def __call__(self, x, training):
x = flax.linen.Dense(16)(x)
x = flax.linen.BatchNorm()(x, use_running_average=not training)
-x = flax.linen.Dropout(0.5)(x, deterministic=not training, rng=rng)
+x = flax.linen.Dropout(0.5)(x, deterministic=not training)
x = flax.linen.Conv(16, [3])(x)

return x
@@ -48,13 +48,12 @@ def __call__(self, x, rng, training):
variables = flax_module.init(
{"params": params_key, "dropout": dropout_key},
x,
-rng,
False,
)

treex_module = tx.FlaxModule(SomeModule(), variables=variables,).init(
42,
-inputs=tx.Inputs(x, rng),
+inputs=tx.Inputs(x),
)

assert all(
@@ -72,23 +71,26 @@ def __call__(self, x, rng, training):
)
)

-y_treex = treex_module(x, rng)
-y_treex = treex_module(x, rng)
+flax_next_key = treex_module.next_key.copy()
+y_treex = treex_module(x)
+y_treex = treex_module(x)

rng, next_rng = tx.iter_split(dropout_key)
y_flax, updates = flax_module.apply(
variables,
x,
-rng,
training,
mutable=["batch_stats"],
+rngs={"dropout": flax_next_key()},
)

variables = variables.copy(updates)
y_flax, updates = flax_module.apply(
variables,
x,
-rng,
training,
mutable=["batch_stats"],
+rngs={"dropout": flax_next_key()},
)
variables = variables.copy(updates)

2 changes: 1 addition & 1 deletion treex/nn/dropout.py
@@ -89,7 +89,7 @@ def __call__(
variables,
x,
deterministic=not training,
-rng=rng,
+rngs={"dropout": rng},
)

return tp.cast(jnp.ndarray, output)
