Hi, I'm trying to do meta-learning with some slow and fast weights. From the params returned by .init, I obtain the slow and fast weights by indexing into the dict (e.g. fast = params['fast']).
Just before calling .apply I want to merge them back together (the fast weights will have been modified in between). My first attempt was to pass the two inner dicts directly to hk.data_structures.merge.
But this raises the exception AttributeError: 'DeviceArray' object has no attribute 'items'. Is this behaviour intended, or should I use a different approach for what I want to do?
Thanks!
Hey @sebamenabar, the utilities in hk.data_structures expect to be passed a "two-level" mapping (in Haiku the params and state dicts have the structure {'module_name': {'parameter_name': parameter_value}}). When you do fast = params['fast'] you get one of the inner mappings, and that cannot be used with functions expecting a two-level mapping: they treat each array in it as if it were a module's parameter dict, hence the 'DeviceArray' object has no attribute 'items' error.
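To make the failure concrete, here is a simplified, plain-Python stand-in for a two-level merge. It is not Haiku's source: merge_two_level is a hypothetical helper, and floats stand in for DeviceArrays so it runs without JAX. It shows that re-wrapping each inner mapping under its module name restores the expected structure, while passing the inner mappings directly makes the helper call .items() on a leaf value.

```python
# Hypothetical stand-in for Haiku params: {module_name: {parameter_name: value}}.
# Floats replace DeviceArrays so the sketch runs without JAX.
params = {
    "fast_weights": {"w": 0.5, "b": 0.0},
    "slow_weights": {"w": 1.5, "b": 0.1},
}


def merge_two_level(*structures):
    """Simplified two-level merge (an illustration, not Haiku's implementation)."""
    out = {}
    for structure in structures:
        for module_name, module_params in structure.items():
            # Every value is assumed to be a {parameter_name: value} mapping.
            for name, value in module_params.items():
                out.setdefault(module_name, {})[name] = value
    return out


# Works: each inner mapping is re-wrapped under its module name.
merged = merge_two_level(
    {"fast_weights": params["fast_weights"]},
    {"slow_weights": params["slow_weights"]},
)
assert merged == params

# Fails: the inner mapping's values are leaves, so .items() is called on a float.
try:
    merge_two_level(params["fast_weights"], params["slow_weights"])
except AttributeError as err:
    print(err)  # 'float' object has no attribute 'items'
```

The same thing happens with real params, except the leaf is a DeviceArray, which matches the error in the question.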
There is an easy solution though 😄 I would suggest using hk.data_structures.partition and hk.data_structures.merge to split and recombine the params dict, since both keep the two-level structure:
import haiku as hk

params = f.init(..)

# Split fast and slow params to operate on them independently.
predicate = lambda m, n, v: m == 'fast_weights'
fast_params, slow_params = hk.data_structures.partition(predicate, params)

# Recombine fast and slow to (for example) pass to apply.
params = hk.data_structures.merge(fast_params, slow_params)
out = f.apply(params, ..)
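The partition / update / merge round trip can be mimicked in plain Python to see what each step produces. The partition and merge functions below are hypothetical, simplified stand-ins for hk.data_structures.partition and hk.data_structures.merge (same argument order as in the snippet above, but not Haiku's code), and the "inner-loop step" is an assumed SGD-like update on float stand-ins rather than real gradients.

```python
from typing import Any, Callable, Dict, Mapping, Tuple

# Two-level params: {module_name: {parameter_name: value}}.
Params = Dict[str, Dict[str, Any]]


def partition(predicate: Callable[[str, str, Any], bool],
              structure: Mapping[str, Mapping[str, Any]]) -> Tuple[Params, Params]:
    """Split into (matching, non-matching); both results stay two-level."""
    matching: Params = {}
    rest: Params = {}
    for module_name, module_params in structure.items():
        for name, value in module_params.items():
            target = matching if predicate(module_name, name, value) else rest
            target.setdefault(module_name, {})[name] = value
    return matching, rest


def merge(*structures: Mapping[str, Mapping[str, Any]]) -> Params:
    """Recombine two-level mappings; in this sketch a later structure wins on clashes."""
    out: Params = {}
    for structure in structures:
        for module_name, module_params in structure.items():
            out.setdefault(module_name, {}).update(module_params)
    return out


params = {"fast_weights": {"w": 1.0}, "slow_weights": {"w": 1.5}}

# Same predicate as in the answer: select the module named 'fast_weights'.
predicate = lambda m, n, v: m == "fast_weights"
fast_params, slow_params = partition(predicate, params)

# Assumed inner-loop step: only the fast weights are modified.
lr, grad = 0.5, 1.0
fast_params = {m: {n: v - lr * grad for n, v in p.items()}
               for m, p in fast_params.items()}

# Recombine before calling apply; the slow weights are untouched.
params = merge(fast_params, slow_params)
print(params)  # {'fast_weights': {'w': 0.5}, 'slow_weights': {'w': 1.5}}
```

Because fast_params and slow_params both keep the {module: {param: value}} shape, each can be modified on its own and then handed to anything that expects a two-level mapping, which is exactly what the Haiku utilities require.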
Fabio added an example utility function (jax_fn_with_filter) to the tests [0] when he authored these utilities; you may find it useful too.