Initial draft of hooks #81

Closed

seanmor5
Contributor

Resolves #77 when merged.

This is an initial draft but it motivates 2 discussions: host callbacks and reusable model blocks.

To perform side-effecting operations with hooks, Nx needs to support host callbacks. Once that's done, hooks will likely always be implemented as host callbacks, so the implementation in this PR will need to evolve once Nx supports them. I don't expect the API to change much when host callbacks are introduced.

As for model blocks: PyTorch has two kinds of hooks, tensor hooks and module hooks. In Axon we're more focused on module-style hooks for now. Module hooks have the following signature: hook(module, input, output) :: None. The equivalent in Axon would be: hook(axon_struct, input, output) :: nil. For both forward and backward hooks, even if they are only meant to be side-effecting operations, I'd prefer they always return an output, so that the output after a layer is basically what's shown in this PR.
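To make the "always return an output" preference concrete, here is a minimal sketch in plain Elixir. It does not use Axon or Nx: `HookSketch`, `apply_with_hooks/4`, the map standing in for an Axon struct, and the integer standing in for a tensor are all hypothetical names and stand-ins chosen for illustration, not the API in this PR.

```elixir
defmodule HookSketch do
  # Hypothetical helper, not Axon's API.
  # Runs the layer function, then threads the output through each hook.
  # Every hook receives (layer, input, current_output) and must return an
  # output, which becomes the current output for the next hook.
  def apply_with_hooks(layer, layer_fun, input, hooks) do
    output = layer_fun.(input)

    Enum.reduce(hooks, output, fn hook, acc ->
      hook.(layer, input, acc)
    end)
  end
end

# A purely side-effecting hook: it logs, then returns the output unchanged
# (IO.inspect/2 returns the value it was given).
log_hook = fn layer, _input, output ->
  IO.inspect(output, label: "output of #{layer.name}")
end

layer = %{name: "dense_0"}   # stand-in for an Axon struct
double = fn x -> x * 2 end   # stand-in for the layer's forward function

result = HookSketch.apply_with_hooks(layer, double, 3, [log_hook])
```

Because each hook returns a value, a side-effecting hook and a hook that transforms the output share the same shape, and they can be chained in any order.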

One issue is that we're currently unable to support hooks registered on large blocks of a model. For example, if I have a block of Convs:

block =
  x
  |> conv()
  |> conv()
  ...

and I tried to register a hook, I would only have access to the input to the last conv in the block, not x. This motivates introducing a module-like abstraction which allows you to treat groups of layers as a single block. This is essentially the same issue we have in #59 as well.
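The limitation can be illustrated with a small plain-Elixir sketch, assuming ordinary functions stand in for conv layers. `BlockSketch` and `block/2` are hypothetical names, not an existing or proposed Axon abstraction; the point is only that wrapping a group of layers as one unit lets a hook see the block's original input `x` as well as its final output.

```elixir
defmodule BlockSketch do
  # Hypothetical: wraps a list of layer functions into a single function.
  # The hook is called once with the block's original input and its final
  # output, and its return value is the block's output.
  def block(layer_funs, hook) do
    fn x ->
      output = Enum.reduce(layer_funs, x, fn layer_fun, acc -> layer_fun.(acc) end)
      hook.(x, output)
    end
  end
end

# Stand-ins for two conv layers.
conv_a = fn x -> x + 1 end
conv_b = fn x -> x * 10 end

# Side-effecting hook: reports what it saw to the current process,
# then returns the output unchanged.
hook = fn input, output ->
  send(self(), {:block_hook, input, output})
  output
end

block_fun = BlockSketch.block([conv_a, conv_b], hook)
result = block_fun.(2)
```

A hook attached only to `conv_b` would see the intermediate value produced by `conv_a` as its input, whereas the block-level hook sees the value originally passed to the block.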

@josevalim
Contributor

Looks great! Just one note: I think even with host callbacks, you would want to choose whether the data is brought back to the host or the hook is an Nx transformation.

@seanmor5
Contributor Author

seanmor5 commented Dec 5, 2021

I will have a new PR for this shortly, integrated with Nx's new hook API.

@seanmor5 seanmor5 closed this Dec 5, 2021
@seanmor5 seanmor5 deleted the sm-hooks branch December 10, 2021 01:50
Development

Successfully merging this pull request may close these issues.

Add training hooks