Merge FlatMappings with DeviceArray as values raises AttributeError: 'DeviceArray' object has no attribute 'items' #66

Closed
sebamenabar opened this issue Aug 14, 2020 · 2 comments


@sebamenabar

Hi, I'm trying to do meta-learning with some slow and fast weights. From the params returned when calling .init, I obtain the slow and fast weights like this:

params = f.init(...)
fast_weights = params["fast_weights"]

Just before calling .apply, I want to merge them back together (the fast weights will have been modified). My first attempt was to use hk.data_structures.merge, like this:

params = hk.data_structures.merge(params, fast_weights)
output = f.apply(params, rng, inputs)

But this raises AttributeError: 'DeviceArray' object has no attribute 'items'. Is this behaviour intended, or should I use a different approach for what I want to do?

Thanks!

@tomhennigan
Collaborator

Hey @sebamenabar, the utilities in data_structures expect to be passed a "two-level" mapping. In Haiku, the params and state dicts have the structure {'module_name': {'parameter_name': parameter_value}}. When you do fast_weights = params['fast_weights'], you get one of the inner mappings, {'parameter_name': parameter_value}, which cannot be used with functions expecting a two-level mapping. That is why merge ends up calling .items() on a DeviceArray.
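To make the nesting concrete, here is a minimal sketch that uses NumPy arrays in place of DeviceArrays. `merge_two_level` is a hypothetical, simplified stand-in written only for illustration (it is not Haiku's code); it walks its inputs the way any two-level merge has to, so passing it an inner mapping reproduces the same kind of AttributeError.

```python
import numpy as np

# A Haiku-style params dict is a two-level mapping:
# {module_name: {parameter_name: array}}.
params = {
    "fast_weights": {"w": np.ones((2, 2)), "b": np.zeros(2)},
    "slow_weights": {"w": np.full((2, 2), 2.0), "b": np.ones(2)},
}


def merge_two_level(*structures):
    """Hypothetical stand-in that walks inputs like a two-level merge.

    Not Haiku's implementation; it only illustrates the expected nesting.
    """
    out = {}
    for structure in structures:
        for module_name, inner in structure.items():
            # `inner` must itself be a mapping; an array has no .items().
            for name, value in inner.items():
                out.setdefault(module_name, {})[name] = value
    return out


# Indexing by module name yields the *inner* (one-level) mapping.
fast_weights = params["fast_weights"]

try:
    merge_two_level(params, fast_weights)
except AttributeError as e:
    print(e)  # 'numpy.ndarray' object has no attribute 'items'

# Re-wrapping under the module name restores a valid two-level mapping.
merged = merge_two_level(params, {"fast_weights": fast_weights})
```

Re-wrapping by hand works for a single module, but the partition/merge approach scales to any number of modules selected by a predicate.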

There is an easy solution though 😄 I would suggest using partition and merge to split and combine the params dict while maintaining two level mappings:

params = f.init(..)

# Split fast and slow params to operate on them independently.
predicate = lambda m, n, v: m == 'fast_weights'
fast_params, slow_params = hk.data_structures.partition(predicate, params)

# Recombine fast and slow to (for example) pass to apply.
params = hk.data_structures.merge(fast_params, slow_params)
out = f.apply(params, ..)
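For intuition about what partition and merge do to the structure, the sketch below uses hypothetical pure-Python stand-ins: `partition` and `merge` here are simplified re-implementations over plain dicts, not Haiku's own functions (which return immutable mappings). It splits the params with the same predicate, applies a placeholder inner-loop update to the fast weights only, and recombines them. `lr` and `grads` are made-up values; in real code the gradients would come from something like `jax.grad`.

```python
import numpy as np


def partition(predicate, structure):
    """Hypothetical stand-in: split a two-level mapping into
    (matching, non_matching), keeping both halves two-level."""
    matching, non_matching = {}, {}
    for module_name, inner in structure.items():
        for name, value in inner.items():
            target = matching if predicate(module_name, name, value) else non_matching
            target.setdefault(module_name, {})[name] = value
    return matching, non_matching


def merge(*structures):
    """Hypothetical stand-in: combine two-level mappings; later values win."""
    out = {}
    for structure in structures:
        for module_name, inner in structure.items():
            out.setdefault(module_name, {}).update(inner)
    return out


params = {
    "fast_weights": {"w": np.ones((2, 2)), "b": np.zeros(2)},
    "slow_weights": {"w": np.full((2, 2), 2.0), "b": np.ones(2)},
}

# Same predicate as above: select every parameter of the 'fast_weights' module.
predicate = lambda m, n, v: m == "fast_weights"
fast_params, slow_params = partition(predicate, params)

# Placeholder inner-loop step applied to the fast weights only.
lr = 0.1
grads = {m: {n: np.ones_like(v) for n, v in inner.items()}
         for m, inner in fast_params.items()}
fast_params = {m: {n: v - lr * grads[m][n] for n, v in inner.items()}
               for m, inner in fast_params.items()}

# Both halves are still two-level, so they recombine cleanly.
params = merge(fast_params, slow_params)
```

Because both halves keep the outer module level, the recombined params have the same nesting as the output of f.init, which is what f.apply expects.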

Fabio added an example utility function (jax_fn_with_filter) to the tests [0] when he authored these; you may find it useful too.

[0] https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/filtering_test.py

@sebamenabar
Author

Thanks Tom!
