This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding experimental
nnc_compile
option to NUTS and HMC (#1385)
Summary: **This PR is not ready for review yet**: I'm creating the PR just so that the changes can be imported and run against the internal tests. I'll update the summary of the PR after polishing the files. ### Motivation With the first Beta release of [functorch](https://github.com/pytorch/functorch), we can begin to merge in our BM-NNC integration prototype, which uses NNC to JIT compile part of the algorithm to accelerate inference. ### Changes proposed - `functorch>=0.1.0` is added to our list of dependencies - Because NNC is yet to support control flow primitives, in NUTS, NNC is applied on the base case of the recursive tree building algorithm. In HMC, NNC is applied on a single leapfrog step. - We use `torch.Tensor` instead of raw scalars for some variables because the TorchScript tracer requires inputs/outputs to be of the same type (?) (i.e., we can't return a tuple of a mixture of `Tensor`s and `float`s) - All of the NNC utils are put into `beanmachine.ppl.experimental.nnc.utils`, which will throw a warning when it's being imported for the first time. - The docstrings of the HMC & NUTS classes are updated as well To try NNC out, simply set `nnc_compile` to `True` when initializing the inference class, e.g. ``` nuts = bm.GlobalNoUTurnSampler(nnc_compile=True) nuts.infer(...) # same arguments as usual ``` Pull Request resolved: #1385 Test Plan: I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with ``` buck test //beanmachine/beanmachine/ppl:test-ppl -- nnc ``` or equivalently, for OSS: ``` pytest src/beanmachine/ppl/experimental/tests/nnc_test.py ``` Differential Revision: D35127777 Pulled By: horizon-blue fbshipit-source-id: 8efcb3c6234a7f4558517a50661d7b65f1f4bb2e
- Loading branch information
1 parent
2af74d2
commit 7817d4a
Showing
11 changed files
with
286 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import sys | ||
from typing import Callable, TypeVar, Optional, Tuple | ||
|
||
from typing_extensions import ParamSpec | ||
|
||
# Type variables used to preserve the wrapped function's full signature:
# P captures the parameter list, R the return type.
P = ParamSpec("P")
R = TypeVar("R")
|
||
|
||
def nnc_jit(
    f: Callable[P, R], static_argnums: Optional[Tuple[int, ...]] = None
) -> Callable[P, R]:
    """
    Lazily import the NNC utils (whose module-level setup initializes the
    compiler and displays an experimental warning), then invoke the underlying
    ``nnc_jit`` on ``f``.

    Args:
        f: The function to compile with NNC.
        static_argnums: Optional indices of arguments to treat as static;
            forwarded unchanged to the underlying ``nnc_jit``.

    Returns:
        The NNC-compiled version of ``f``.

    Raises:
        RuntimeError: If the NNC utils cannot be imported, e.g. on Windows
            (where functorch is unavailable) or when the installed functorch
            does not match the installed PyTorch.
    """
    try:
        # The setup code in `nnc.utils` will only be executed once in a Python
        # session.
        from beanmachine.ppl.experimental.nnc.utils import nnc_jit as raw_nnc_jit
    except ImportError as e:
        if sys.platform.startswith("win"):
            message = "functorch is not available on Windows."
        else:
            message = (
                "Failed to initialize NNC. This is likely caused by a version "
                "mismatch between PyTorch and functorch. Please check out the "
                "functorch project for an installation guide."
            )
        raise RuntimeError(message) from e

    return raw_nnc_jit(f, static_argnums)
|
||
|
||
# Public API of this module.
__all__ = ["nnc_jit"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import warnings | ||
|
||
import functorch | ||
import torch | ||
import torch.jit | ||
import torch.utils._pytree as pytree | ||
from functorch.compile import ( | ||
nop, | ||
aot_function, | ||
decomposition_table, | ||
register_decomposition, | ||
) | ||
|
||
# The warning will only be shown to the user once, when this module is first
# imported. Note the trailing spaces inside the implicitly-concatenated string
# pieces: without them the message runs words together.
warnings.warn(
    "The support of NNC compiler is experimental and the API is subject to "
    "change in the future releases of Bean Machine. For questions regarding NNC, "
    "please check out the functorch project (https://github.com/pytorch/functorch)."
)

# Enable TensorExpr (NNC) handling of reduction ops in the JIT fuser.
torch._C._jit_set_texpr_reductions_enabled(True)
|
||
# override the usage of torch.jit.script, which has a bit of issue handling | ||
# empty lists (functorch#440) | ||
def simple_ts_compile(fx_g, example_inps):
    """Compile an FX graph module by TorchScript-tracing it on example inputs,
    freezing the result, and running the remove-mutation JIT pass on its graph.
    """
    # strict=False tolerates outputs (e.g. dicts/lists) that strict tracing
    # would reject.
    traced = torch.jit.trace(fx_g, example_inps, strict=False)
    frozen = torch.jit.freeze(traced.eval())
    # Functionalize in-place ops left in the traced graph.
    torch._C._jit_pass_remove_mutation(frozen.graph)
    return frozen
|
||
|
||
# Build the decomposition table used for NNC compilation: start from the
# entries of functorch's default `decomposition_table` listed in
# `decompositions` (currently just `aten.detach`); the custom decompositions
# registered below are added to `bm_decompositions` via @register_decomposition.
aten = torch.ops.aten
decompositions = [aten.detach]
bm_decompositions = {
    k: v for k, v in decomposition_table.items() if k in decompositions
}
|
||
|
||
@register_decomposition(aten.mv, bm_decompositions)
def mv(a, b):
    """Decompose matrix-vector product: broadcast-multiply then reduce the
    last dimension."""
    return torch.sum(a * b, dim=-1)
|
||
|
||
@register_decomposition(aten.dot, bm_decompositions)
def dot(a, b):
    """Decompose vector dot product: elementwise multiply then reduce the
    last dimension."""
    return torch.sum(a * b, dim=-1)
|
||
|
||
# @register_decomposition(aten.nan_to_num, bm_decompositions) | ||
# def nan_to_num(a, val): | ||
# return aten.where(a != a, val, a) | ||
|
||
|
||
@register_decomposition(aten.zeros_like, bm_decompositions)
def zeros_like(a, **kwargs):
    # Decompose zeros_like into fusable arithmetic.
    # NOTE(review): unlike aten.zeros_like, `a * 0` propagates NaN/inf from `a`
    # and ignores kwargs such as dtype/device — confirm this is acceptable for
    # the inputs NNC sees here.
    return a * 0
|
||
|
||
@register_decomposition(aten.ones_like, bm_decompositions)
def ones_like(a, **kwargs):
    # Decompose ones_like into fusable arithmetic.
    # NOTE(review): unlike aten.ones_like, `a * 0 + 1` propagates NaN/inf from
    # `a` and ignores kwargs such as dtype/device — confirm this is acceptable
    # for the inputs NNC sees here.
    return a * 0 + 1
|
||
|
||
def nnc_jit(f, static_argnums=None):
    """Ahead-of-time compile ``f`` with functorch's ``aot_function``, using
    ``simple_ts_compile`` for the forward graph, ``nop`` (no compilation) for
    the backward graph, and the Bean Machine decomposition table above.
    """
    return aot_function(
        f,
        simple_ts_compile,
        nop,
        static_argnums=static_argnums,
        decompositions=bm_decompositions,
    )
|
||
|
||
# Monkey-patch functorch's internal TS compiler with the tracing-based variant
# defined in this module (works around torch.jit.script's trouble with empty
# lists, functorch#440).
functorch._src.compilers.simple_ts_compile = simple_ts_compile
|
||
|
||
# override default dict flatten (which requires keys to be sortable) | ||
def _dict_flatten(d): | ||
keys = list(d.keys()) | ||
values = [d[key] for key in keys] | ||
return values, keys | ||
|
||
|
||
def _dict_unflatten(values, context): | ||
return {key: value for key, value in zip(context, values)} | ||
|
||
|
||
# Override pytree's default dict flatten/unflatten (which requires keys to be
# sortable) with insertion-order-preserving versions.
pytree._register_pytree_node(dict, _dict_flatten, _dict_unflatten)

# Public API of this module.
__all__ = ["nnc_jit"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import sys | ||
import warnings | ||
|
||
import beanmachine.ppl as bm | ||
import pytest | ||
import torch | ||
import torch.distributions as dist | ||
|
||
if sys.platform.startswith("win"): | ||
pytest.skip("functorch is not available on Windows", allow_module_level=True) | ||
|
||
|
||
class SampleModel:
    """Minimal two-variable model for smoke-testing inference:
    foo ~ Normal(0, 1) and bar ~ Normal(foo, 1)."""

    @bm.random_variable
    def foo(self):
        return dist.Normal(0.0, 1.0)

    @bm.random_variable
    def bar(self):
        return dist.Normal(self.foo(), 1.0)
|
||
|
||
@pytest.mark.parametrize(
    "algorithm",
    [
        bm.GlobalNoUTurnSampler(nnc_compile=True),
        bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, nnc_compile=True),
    ],
)
def test_nnc_compile(algorithm):
    """Smoke test: inference with nnc_compile=True runs end-to-end on both
    NUTS and HMC and produces non-NaN samples."""
    model = SampleModel()
    queries = [model.foo()]
    observations = {model.bar(): torch.tensor(0.5)}
    num_samples = 30
    num_chains = 2
    with warnings.catch_warnings():
        # Silence the experimental-NNC warning emitted on first use.
        warnings.simplefilter("ignore")
        # verify that NNC can run through
        samples = algorithm.infer(
            queries,
            observations,
            num_samples,
            num_adaptive_samples=num_samples,
            num_chains=num_chains,
        )
    # sanity check: make sure that the samples are valid
    assert not torch.isnan(samples[model.foo()]).any()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.