Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622967775
  • Loading branch information
mohammadasghari authored and DeepMind committed Apr 8, 2024
1 parent 2a0733a commit 0314d5e
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions enn/losses/base.py
Expand Up @@ -90,11 +90,6 @@ def loss_fn(enn: base.EpistemicNetwork[base.Input, base.Output],
new_state, state)
mean_metrics = jax.tree_util.tree_map(batch_mean, metrics)

# TODO(author2): Adding a logging method for keeping track of state counter.
# This piece of code is only used for debugging/metrics.
if len(new_state) > 0: # pylint:disable=g-explicit-length-test
first_state_layer = new_state[list(new_state.keys())[0]]
mean_metrics['state_counter'] = jnp.mean(first_state_layer['counter'])
return mean_loss, (new_state, mean_metrics)

return loss_fn
Expand Down

0 comments on commit 0314d5e

Please sign in to comment.