Skip to content

Commit

Permalink
Fix imports in examples and refactor base
Browse files Browse the repository at this point in the history
  • Loading branch information
jheek committed Feb 25, 2020
1 parent 5291c3c commit baf43e7
Show file tree
Hide file tree
Showing 7 changed files with 359 additions and 332 deletions.
2 changes: 1 addition & 1 deletion examples/cifar10/models/pyramidnet.py
Expand Up @@ -15,7 +15,7 @@
"""PyramidNet with Shake-Drop."""

from flax import nn
import utils
from models import utils
import jax
import jax.numpy as jnp

Expand Down
2 changes: 1 addition & 1 deletion examples/cifar10/models/wideresnet_shakeshake.py
Expand Up @@ -15,7 +15,7 @@
"""Wide Resnet Model with shake-shake regularization."""

from flax import nn
import utils
from models import utils
import jax


Expand Down
4 changes: 2 additions & 2 deletions examples/wip/moco/train_moco.py
Expand Up @@ -27,8 +27,8 @@

from flax import jax_utils
from flax import optim
import imagenet_data_source
import model_resnet
from moco import imagenet_data_source
from moco import model_resnet
from flax.metrics import tensorboard
import flax.nn
from flax.training import common_utils
Expand Down
17 changes: 2 additions & 15 deletions flax/nn/attention.py
Expand Up @@ -182,21 +182,8 @@ def _init(shape_fn):
return Cache(jax.tree_map(_init, self.state))


def _iterate_cache(cache):
# pylint: disable=protected-access
if cache._mutable:
raise ValueError('A mutable cache should not be transformed by Jax.')
return (cache.state,), cache.shared


def _cache_from_iterable(shared, state):
return Cache(state[0], shared=shared)


# make sure a collection is traced.
jax.tree_util.register_pytree_node(Cache,
_iterate_cache,
_cache_from_iterable)
jax.tree_util.register_pytree_node(
Cache, base.iterate_collection, base.collection_from_iterable)


def _scan_nd(body_fn, init, xs, n=1):
Expand Down

0 comments on commit baf43e7

Please sign in to comment.