
Adding experimental nnc_compile option to NUTS and HMC #1385

Closed
wants to merge 1 commit

Conversation

@horizon-blue (Contributor) commented on Mar 24, 2022

This PR is not ready for review yet: I'm creating the PR just so that the changes can be imported and run against the internal tests. I'll update the summary of the PR after polishing the files.

Motivation

With the first beta release of functorch, we can begin to merge in our BM-NNC integration prototype, which uses NNC to JIT-compile part of the algorithm to accelerate inference.

Changes proposed

  • functorch>=0.1.0 is added to our list of dependencies.
  • Because NNC does not yet support control-flow primitives, in NUTS, NNC is applied to the base case of the recursive tree-building algorithm; in HMC, it is applied to a single leapfrog step (see the sketch after this list).
  • We use torch.Tensor instead of raw scalars for some variables because the TorchScript tracer seems to require inputs and outputs to have consistent types (i.e., we can't return a tuple that mixes Tensors and floats).
  • All of the NNC utilities live in beanmachine.ppl.experimental.nnc.utils, which emits a warning the first time it is imported.
  • The docstrings of the HMC and NUTS classes are updated as well.
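
To make the scope of compilation concrete, here is a rough sketch of the idea. It is illustrative only and is not the code in this PR: it assumes functorch's `functorch.compile.nnc_jit` entry point, and `potential_fn`/`leapfrog_step` below are made-up stand-ins (the real implementation differentiates the model's joint log density with functorch).

```
# Illustrative sketch only -- not the code in this PR.
# Assumes functorch>=0.1.0 exposes functorch.compile.nnc_jit (an assumption).
import torch
from functorch.compile import nnc_jit

def grad_potential(q: torch.Tensor) -> torch.Tensor:
    # Hand-written gradient of a stand-in potential 0.5 * q.dot(q); the real
    # integration obtains this gradient from the model via functorch.
    return q

def leapfrog_step(q: torch.Tensor, p: torch.Tensor, step_size: torch.Tensor):
    # One straight-line leapfrog update. step_size is a Tensor so that all
    # inputs and outputs share a single type, as noted in the list above.
    p = p - 0.5 * step_size * grad_potential(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p

# NNC cannot compile Python control flow, so only this single step is
# compiled; the surrounding tree building / trajectory loop stays in Python.
compiled_step = nnc_jit(leapfrog_step)
q, p = compiled_step(torch.randn(3), torch.randn(3), torch.tensor(0.1))
```

Again, the names and the exact entry point here are assumptions for illustration; the `nnc_compile` flag below is the interface this PR actually adds.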

To try NNC out, simply set nnc_compile to True when initializing the inference class, e.g.

```
nuts = bm.GlobalNoUTurnSampler(nnc_compile=True)
nuts.infer(...)  # same arguments as usual
```
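
For a fuller picture of where the flag goes, here is a hypothetical end-to-end run on a toy normal-normal model. The model, data, and argument values are made up for illustration; only `nnc_compile=True` is what this PR adds.

```
# Hypothetical toy example; the model and data are not from this PR.
import beanmachine.ppl as bm
import torch
import torch.distributions as dist

@bm.random_variable
def mu():
    return dist.Normal(torch.tensor(0.0), torch.tensor(1.0))

@bm.random_variable
def x():
    return dist.Normal(mu(), torch.tensor(1.0))

nuts = bm.GlobalNoUTurnSampler(nnc_compile=True)
samples = nuts.infer(
    queries=[mu()],
    observations={x(): torch.tensor(0.5)},
    num_samples=500,
    num_adaptive_samples=500,
    num_chains=2,
)
print(samples[mu()].shape)  # roughly (num_chains, num_samples)
```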

Test Plan

I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with

```
pytest src/beanmachine/ppl/experimental/tests/nnc_test.py
```

Types of changes

  • Docs change / refactoring / dependency upgrade
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Checklist

  • My code follows the code style of this project.
  • My change requires a change to the documentation.
  • I have updated the documentation accordingly.
  • I have read the CONTRIBUTING document.
  • I have added tests to cover my changes.
  • All new and existing tests passed.
  • The title of my pull request is a short description of the requested changes.

@facebook-github-bot added the CLA Signed label on Mar 24, 2022
@facebook-github-bot (Collaborator) commented:

@horizon-blue has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

horizon-blue added a commit that referenced this pull request Mar 25, 2022
Summary: same as the PR description above.

Pull Request resolved: #1385

Test Plan:
I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with
```
buck test //beanmachine/beanmachine/ppl:test-ppl -- nnc
```
or equivalently, for OSS:
```
pytest src/beanmachine/ppl/experimental/tests/nnc_test.py
```

Differential Revision: D35127777

Pulled By: horizon-blue

fbshipit-source-id: 8efcb3c6234a7f4558517a50661d7b65f1f4bb2e
@facebook-github-bot (Collaborator) commented:

This pull request was exported from Phabricator. Differential Revision: D35127777

horizon-blue added a commit that referenced this pull request Mar 25, 2022
Summary and Test Plan: same as the commit above.

Reviewed By: jpchen

Differential Revision: D35127777

Pulled By: horizon-blue

fbshipit-source-id: 43779efec1d72380b06229755cb5c699dfe75b05
@facebook-github-bot (Collaborator) commented:

This pull request was exported from Phabricator. Differential Revision: D35127777

Summary and Test Plan: same as the commit above.

Reviewed By: jpchen

Differential Revision: D35127777

Pulled By: horizon-blue

fbshipit-source-id: 578c62b7ef10555fc0f58849e47a6a22c3e51f68
@facebook-github-bot (Collaborator) commented:

This pull request was exported from Phabricator. Differential Revision: D35127777

@horizon-blue deleted the nnc-prototype branch on March 29, 2022 07:07