This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
Adding experimental nnc_compile
option to NUTS and HMC
#1385
Closed
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
facebook-github-bot
added
the
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
label
Mar 24, 2022
@horizon-blue has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
horizon-blue
added a commit
that referenced
this pull request
Mar 25, 2022
Summary: **This PR is not ready for review yet**: I'm creating the PR just so that the changes can be imported and run against the internal tests. I'll update the summary of the PR after polishing the files. ### Motivation With the first Beta release of [functorch](https://github.com/pytorch/functorch), we can begin to merge in our BM-NNC integration prototype, which uses NNC to JIT compile part of the algorithm to accelerate inferences. ### Changes proposed - `functorch>=0.1.0` is added to out list of dependencies - Because NNC is yet to support control flow primitives, in NUTS, NNC is applied on the base case of recursive tree building algorithm. In HMC, NNC is applied on a single leapfrog step. - We use `torch.Tensor` instead of raw scalars for some variables because TorchScript tracer requires inputs/outputs to be of the same type (?) (i.e., we can't return a tuple of a mixture of `Tensor`s and `float`s) - All of the NNC utils are put into `beanmachine.ppl.experimental.nnc.utils`, which will throw a warning when it's being imported for the first time. - The docstring of HMC & NUTS classes are updated as well To try NNC out, simply set `nnc_compile` to `True` when initializing the inference class, e.g. ``` nuts = bm.GlobalNoUTurnSampler(nnc_compile=True) nuts.infer(...) # same arguments as usual ``` Pull Request resolved: #1385 Test Plan: I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with ``` buck test //beanmachine/beanmachine/ppl:test-ppl -- nnc ``` or equivalently, for OSS: ``` pytest src/beanmachine/ppl/experimental/tests/nnc_test.py ``` Differential Revision: D35127777 Pulled By: horizon-blue fbshipit-source-id: 8efcb3c6234a7f4558517a50661d7b65f1f4bb2e
horizon-blue
force-pushed
the
nnc-prototype
branch
from
March 25, 2022 00:18
acd2287
to
7817d4a
Compare
This pull request was exported from Phabricator. Differential Revision: D35127777 |
horizon-blue
added a commit
that referenced
this pull request
Mar 25, 2022
Summary: ### Motivation With the first Beta release of [functorch](https://github.com/pytorch/functorch), we can begin to merge in our BM-NNC integration prototype, which uses NNC to JIT compile part of the algorithm to accelerate inferences. ### Changes proposed - `functorch>=0.1.0` is added to out list of dependencies - Because NNC is yet to support control flow primitives, in NUTS, NNC is applied on the base case of recursive tree building algorithm. In HMC, NNC is applied on a single leapfrog step. - We use `torch.Tensor` instead of raw scalars for some variables because TorchScript tracer requires inputs/outputs to be of the same type (?) (i.e., we can't return a tuple of a mixture of `Tensor`s and `float`s) - All of the NNC utils are put into `beanmachine.ppl.experimental.nnc.utils`, which will throw a warning when it's being imported for the first time. - The docstring of HMC & NUTS classes are updated as well To try NNC out, simply set `nnc_compile` to `True` when initializing the inference class, e.g. ``` nuts = bm.GlobalNoUTurnSampler(nnc_compile=True) nuts.infer(...) # same arguments as usual ``` Pull Request resolved: #1385 Test Plan: I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with ``` buck test //beanmachine/beanmachine/ppl:test-ppl -- nnc ``` or equivalently, for OSS: ``` pytest src/beanmachine/ppl/experimental/tests/nnc_test.py ``` Reviewed By: jpchen Differential Revision: D35127777 Pulled By: horizon-blue fbshipit-source-id: 43779efec1d72380b06229755cb5c699dfe75b05
horizon-blue
force-pushed
the
nnc-prototype
branch
from
March 25, 2022 22:45
7817d4a
to
30af6cb
Compare
This pull request was exported from Phabricator. Differential Revision: D35127777 |
Summary: ### Motivation With the first Beta release of [functorch](https://github.com/pytorch/functorch), we can begin to merge in our BM-NNC integration prototype, which uses NNC to JIT compile part of the algorithm to accelerate inferences. ### Changes proposed - `functorch>=0.1.0` is added to out list of dependencies - Because NNC is yet to support control flow primitives, in NUTS, NNC is applied on the base case of recursive tree building algorithm. In HMC, NNC is applied on a single leapfrog step. - We use `torch.Tensor` instead of raw scalars for some variables because TorchScript tracer requires inputs/outputs to be of the same type (?) (i.e., we can't return a tuple of a mixture of `Tensor`s and `float`s) - All of the NNC utils are put into `beanmachine.ppl.experimental.nnc.utils`, which will throw a warning when it's being imported for the first time. - The docstring of HMC & NUTS classes are updated as well To try NNC out, simply set `nnc_compile` to `True` when initializing the inference class, e.g. ``` nuts = bm.GlobalNoUTurnSampler(nnc_compile=True) nuts.infer(...) # same arguments as usual ``` Pull Request resolved: #1385 Test Plan: I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with ``` buck test //beanmachine/beanmachine/ppl:test-ppl -- nnc ``` or equivalently, for OSS: ``` pytest src/beanmachine/ppl/experimental/tests/nnc_test.py ``` Reviewed By: jpchen Differential Revision: D35127777 Pulled By: horizon-blue fbshipit-source-id: 578c62b7ef10555fc0f58849e47a6a22c3e51f68
horizon-blue
force-pushed
the
nnc-prototype
branch
from
March 25, 2022 22:56
30af6cb
to
cf0f890
Compare
This pull request was exported from Phabricator. Differential Revision: D35127777 |
Sign up for free
to subscribe to this conversation on GitHub.
Already have an account?
Sign in.
Labels
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR is not ready for review yet: I'm creating the PR just so that the changes can be imported and run against the internal tests. I'll update the summary of the PR after polishing the files.
Motivation
With the first Beta release of functorch, we can begin to merge in our BM-NNC integration prototype, which uses NNC to JIT compile part of the algorithm to accelerate inferences.
Changes proposed
functorch>=0.1.0
is added to out list of dependenciestorch.Tensor
instead of raw scalars for some variables because TorchScript tracer requires inputs/outputs to be of the same type (?) (i.e., we can't return a tuple of a mixture ofTensor
s andfloat
s)beanmachine.ppl.experimental.nnc.utils
, which will throw a warning when it's being imported for the first time.To try NNC out, simply set
nnc_compile
toTrue
when initializing the inference class, e.g.Test Plan
I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with
Types of changes
Checklist