hk.add_loss #54
Comments
Hey @cgarciae, thanks for the FR! In general we try to keep Haiku fairly lean and encourage features (e.g. training loops, optimizers, etc.) to be solved in other libraries (so they can benefit all JAX users, not just Haiku users) or built from existing Haiku/JAX features. With respect to using existing features, you might consider using `hk.set_state`:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def f(x):
      y = hk.nets.ResNet50(1000)(x, True)
      loss_1 = y.sum()
      loss_2 = loss_1 ** 2
      # Stash the intermediate loss in Haiku state so it can be read back later.
      hk.set_state("loss_1", loss_1)
      return loss_2

    f = hk.transform_with_state(f)
    rng = jax.random.PRNGKey(42)
    x = jnp.ones([1, 224, 224, 3])
    params, state = f.init(rng, x)

    # Apply as usual (apply returns the output and the updated state):
    loss_2, state = f.apply(params, state, rng, x)

    # Extract losses from state:
    is_loss = lambda m, n, v: n.startswith("loss")
    losses = hk.data_structures.filter(is_loss, state)
    print(losses)  # frozendict({'~': frozendict({'loss_1': DeviceArray(0., dtype=float32)})})

If you want to implement this as a standalone feature (e.g. decoupled from Haiku), you could use a context manager along these lines:

    from contextlib import contextmanager

    loggers = []

    @contextmanager
    def context():
      data = {}
      loggers.append(data)
      try:
        yield data
      finally:
        loggers.pop()

    def log(name, value):
      # NOTE: log(..) ignored when not logging.
      if loggers:
        data = loggers[-1]
        data[name] = value

    def f(x):
      x = x ** 2
      log("a", x)
      x = x ** 2
      log("b", x)
      return x

    def g():
      with context() as data:
        y = f(2)
      return y, data

    y, data = g()
    assert y == 16
    assert data == {'a': 4, 'b': 16}

I hope that's useful; please feel free to reopen if this does not solve your use case.
Thanks for the info @tomhennigan! I still think a first-class logger would be nice and less error-prone, but fortunately it's not a blocker.
PiperOrigin-RevId: 363357569
To enable users to easily create per-layer weight and activity regularizers, plus other forms of losses created by intermediate layers, it would be very useful if Haiku had an `hk.add_loss` utility that, when called within a transform, would append a loss to a list of losses which the user could later retrieve as an additional output from `apply`. I guess that this would require an additional flag to `hk.transform` and friends.