This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

fix rng init
lkhphuc committed Mar 17, 2022
1 parent 1467753 commit 0d07d2a
Showing 2 changed files with 16 additions and 18 deletions.
26 changes: 15 additions & 11 deletions tests/nn/test_attention.py
@@ -29,9 +29,9 @@ class MultiHeadDotProductAttentionTest(unittest.TestCase):
         batch_size=st.integers(min_value=1, max_value=32),
         length=st.integers(min_value=1, max_value=32),
         log2_features_in=st.integers(min_value=1, max_value=5),
-        log2_num_heads=st.integers(min_value=1, max_value=3),
-        log2_qkv_features=st.integers(min_value=1, max_value=5),
-        log2_out_features=st.integers(min_value=1, max_value=5),
+        log2_num_heads=st.integers(min_value=0, max_value=2),
+        log2_qkv_features=st.integers(min_value=3, max_value=5),
+        log2_out_features=st.integers(min_value=3, max_value=5),
         broadcast_dropout=st.booleans(),
         dropout_rate=st.floats(min_value=0.0, max_value=1.0),
         deterministic=st.booleans(),
@@ -41,7 +41,7 @@ class MultiHeadDotProductAttentionTest(unittest.TestCase):
         decode=st.booleans(),
         training=st.booleans(),
     )
-    @hp.settings(deadline=None, max_examples=1)
+    @hp.settings(deadline=None, max_examples=20)
     def test_equivalence(
         self,
         batch_size,
@@ -73,7 +73,8 @@ def test_equivalence(
             out_features=2 ** log2_out_features,
             broadcast_dropout=broadcast_dropout,
             dropout_rate=dropout_rate,
-            deterministic=deterministic,
+            # deterministic=deterministic,
+            deterministic=False,
             kernel_init=kernel_init,
             bias_init=bias_init,
             use_bias=use_bias,
@@ -85,15 +86,19 @@ def test_equivalence(
             out_features=2 ** log2_out_features,
             broadcast_dropout=broadcast_dropout,
             dropout_rate=dropout_rate,
-            deterministic=deterministic,
+            # deterministic=deterministic,
+            deterministic=False,
             kernel_init=kernel_init,
             bias_init=bias_init,
             use_bias=use_bias,
             decode=False,
         ).train(training)
 
-        flax_key, _ = tx.iter_split(key) # emulate init split
-        variables = flax_module.init(key, inputs_q, inputs_kv)
+        _, flax_key = tx.iter_split(key) # emulate init split
+        param_key, _ = tx.iter_split(flax_key) # emulate init split
+        variables = flax_module.init(
+            {"params": param_key, "dropout": key}, inputs_q, inputs_kv
+        )
         treex_module = treex_module.init(key, (inputs_q, inputs_kv))
 
         assert np.allclose(
@@ -123,12 +128,11 @@ def test_equivalence(
         )
 
         # split key same way tx.Dropout does internally
-        rng, _ = tx.iter_split(flax_key, 2)
         y_flax = flax_module.apply(
-            variables, rngs={"dropout": rng}, inputs_q=inputs_q, inputs_kv=inputs_kv
+            variables, rngs={"dropout": key}, inputs_q=inputs_q, inputs_kv=inputs_kv
         )
 
-        y_treex = treex_module(inputs_q=inputs_q, inputs_kv=inputs_kv)
+        y_treex = treex_module(inputs_q=inputs_q, inputs_kv=inputs_kv, rng=key)
 
         assert np.allclose(y_flax, y_treex)
 
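Note: the Flax side of this test relies on named RNG streams. A flax.linen module that contains dropout takes both a "params" and a "dropout" key at init time, and a "dropout" key at apply time when dropout is active, which is what the rewritten init/apply calls above pass explicitly. Below is a minimal, self-contained sketch of that pattern; the toy Block module is illustrative only and is not part of this repository.

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    # Toy module with dropout: it needs both "params" and "dropout"
    # RNG streams whenever deterministic=False.
    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        x = nn.Dense(4)(x)
        return nn.Dropout(rate=0.5, deterministic=deterministic)(x)

key = jax.random.PRNGKey(0)
param_key, dropout_key = jax.random.split(key)
x = jnp.ones((2, 3))

# init takes a dict of named streams, like the test's
# flax_module.init({"params": param_key, "dropout": key}, inputs_q, inputs_kv)
variables = Block().init(
    {"params": param_key, "dropout": dropout_key}, x, deterministic=False
)

# apply only needs the "dropout" stream, mirroring the test's
# flax_module.apply(variables, rngs={"dropout": key}, ...)
y = Block().apply(
    variables, x, deterministic=False, rngs={"dropout": dropout_key}
)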
8 changes: 1 addition & 7 deletions treex/nn/attention.py
@@ -156,7 +156,7 @@ def __call__(
           output of shape `[batch_sizes..., length, features]`.
         """
         if self.initializing():
-            rngs = {"params": next_key()}
+            rngs = {"params": next_key(), "dropout": next_key()}
             variables = self.module.init(rngs, inputs_q, inputs_kv, mask)
 
             # Extract collections
@@ -172,12 +172,6 @@ def __call__(
         assert self.value is not None
         assert self.out is not None
 
-        if self.use_bias:
-            assert self.query["bias"] is not None
-            assert self.key["bias"] is not None
-            assert self.value["bias"] is not None
-            assert self.out["bias"] is not None
-
         params = {
             "query": self.query,
             "key": self.query,

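For orientation, a hedged usage sketch of the initialization path this commit fixes: the treex attention wrapper is initialized from a single key (during initialization it now requests both a "params" and a "dropout" key via next_key()), and an explicit rng is supplied at call time for dropout, as in the updated test. The class name tx.MultiHeadDotProductAttention and the exact constructor arguments below are assumptions based on the test file, not confirmed by this page.

import jax
import numpy as np
import treex as tx

key = jax.random.PRNGKey(42)
inputs_q = np.ones((4, 8, 16), dtype=np.float32)
inputs_kv = np.ones((4, 8, 16), dtype=np.float32)

# Assumed constructor name and arguments, mirroring the test's setup.
module = tx.MultiHeadDotProductAttention(
    num_heads=2,
    qkv_features=16,
    out_features=16,
    dropout_rate=0.1,
    deterministic=False,
    decode=False,
).init(key, (inputs_q, inputs_kv))

# Explicit rng for dropout at call time, as in the updated test.
y = module(inputs_q=inputs_q, inputs_kv=inputs_kv, rng=key)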