hk.add_loss #54

Closed
cgarciae opened this issue Jul 5, 2020 · 2 comments

Comments

@cgarciae

cgarciae commented Jul 5, 2020

To enable users to easily create per-layer weight and activity regularizers, plus other forms of losses created by intermediate layers, it would be very useful if Haiku had an `hk.add_loss` utility. When called within a transform, it would append a loss to a list of losses, which the user could later retrieve as an additional output from `apply`. I guess this would require an additional flag to `hk.transform` and friends.

@tomhennigan
Collaborator

Hey @cgarciae, thanks for the feature request (FR)! In general we try to keep Haiku fairly lean and encourage features (e.g. training loops, optimizers, etc.) to be solved in other libraries (so that they can benefit all JAX users, not just Haiku users) or built from existing Haiku/JAX features.

With respect to existing features, you might consider using `hk.set_state` for this. It is a fairly general mechanism in Haiku for logging values associated with modules (it requires transforming with `hk.transform_with_state`). You could then use `hk.data_structures.filter` to extract all losses from the state dict:

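```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
  y = hk.nets.ResNet50(1000)(x, is_training=True)
  loss_1 = y.sum()
  loss_2 = loss_1 ** 2
  # Record an intermediate loss in the state dict (no matching get_state needed).
  hk.set_state("loss_1", loss_1)
  return loss_2

f = hk.transform_with_state(f)

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 224, 224, 3])
params, state = f.init(rng, x)

# Apply as usual. With transform_with_state, apply returns (output, new_state):
loss_2, state = f.apply(params, state, rng, x)

# Extract losses from state. The predicate receives (module_name, name, value):
is_loss = lambda m, n, v: n.startswith("loss")
losses = hk.data_structures.filter(is_loss, state)
print(losses)  # frozendict({'~': frozendict({'loss_1': DeviceArray(0., dtype=float32)})})
```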
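To use the extracted losses in training, they have to be folded into a single scalar. A minimal sketch, assuming you simply want a weighted sum of all auxiliary losses (`total_loss` and `weight` are hypothetical names, not Haiku API), flattens the filtered mapping with `jax.tree_util.tree_leaves`:

```python
import jax
import jax.numpy as jnp


def total_loss(main_loss, aux_losses, weight=1.0):
  """Return main_loss + weight * (sum of every leaf in the aux_losses pytree).

  `aux_losses` can be any pytree of arrays, e.g. the nested mapping returned by
  hk.data_structures.filter(is_loss, state). Each leaf is reduced with jnp.sum,
  so per-example (vector) losses are also accepted.
  """
  leaves = jax.tree_util.tree_leaves(aux_losses)
  aux = sum(jnp.sum(leaf) for leaf in leaves)  # Python 0 if there are no leaves
  return main_loss + weight * aux


# Illustrative values only, nested like the filtered state printed above.
losses = {"~": {"loss_1": jnp.array(2.0)}, "linear": {"loss_l2": jnp.array(0.5)}}
print(float(total_loss(jnp.array(1.0), losses)))  # 3.5
```

If `f.apply` is called inside the function passed to `jax.grad`, the values stored with `hk.set_state` are ordinary outputs of that computation, so gradients flow through them like any other loss term.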

If you want to implement this as a standalone feature (i.e. decoupled from `hk.set_state`), here is a version I've forked from @ibab, who has implemented something similar (his version is more robust, handling thread safety and nesting).

```python
from contextlib import contextmanager

# Stack of active logging contexts; log() writes into the innermost one.
# NOTE: this is a module-level global, so it is not thread-safe.
loggers = []

@contextmanager
def context():
  data = {}
  loggers.append(data)
  try:
    yield data
  finally:
    loggers.pop()

def log(name, value):
  # NOTE: log(..) ignored when not logging.
  if loggers:
    data = loggers[-1]
    data[name] = value  # a repeated name overwrites the earlier value

def f(x):
  x = x ** 2
  log("a", x)
  x = x ** 2
  log("b", x)
  return x

def g():
  with context() as data:
    y = f(2)
  return y, data

y, data = g()
assert y == 16
assert data == {'a': 4, 'b': 16}
```

I hope that's useful. Please feel free to reopen if this does not solve your use case.

@cgarciae
Author

cgarciae commented Jul 7, 2020

Thanks for the info @tomhennigan! I think `set_state` + filters can achieve what I am looking for :)
I didn't realize `set_state` could be called without a corresponding `get_state`.

I still think a first-class logger would be nice / less error-prone, but fortunately it's not a blocker.

li-zihang pushed a commit to li-zihang/dm-haiku that referenced this issue Mar 8, 2022