Adding experimental nnc_compile option to NUTS and HMC (#1385)
Summary:
**This PR is not ready for review yet**: I'm creating the PR just so that the changes can be imported and run against the internal tests. I'll update the summary of the PR after polishing the files.

### Motivation
With the first beta release of [functorch](https://github.com/pytorch/functorch), we can begin to merge in our BM-NNC integration prototype, which uses NNC to JIT-compile part of the algorithm and accelerate inference.

### Changes proposed
- `functorch>=0.1.0` is added to our list of dependencies.
- Because NNC does not yet support control-flow primitives, in NUTS it is applied to the base case of the recursive tree-building algorithm; in HMC, it is applied to a single leapfrog step (see the sketch after this list).
- We use `torch.Tensor` instead of raw scalars for some variables because the TorchScript tracer seems to require inputs and outputs to be of consistent types (i.e., we can't return a tuple that mixes `Tensor`s and `float`s).
- All of the NNC utils are put into `beanmachine.ppl.experimental.nnc.utils`, which emits a warning the first time it is imported.
- The docstrings of the HMC and NUTS classes are updated as well.
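
For intuition, here is a minimal sketch of the HMC case: compiling one leapfrog step with the new wrapper. The `grad_log_prob` helper and the `leapfrog` signature below are illustrative stand-ins, not the actual proposer internals:

```python
import torch

from beanmachine.ppl.experimental.nnc import nnc_jit


# stand-in for the model's log-density gradient (here: a standard normal)
def grad_log_prob(q):
    return -q


def leapfrog(q, p, step_size):
    p = p + 0.5 * step_size * grad_log_prob(q)
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_log_prob(q)
    return q, p  # only `Tensor`s are returned, per the note above


leapfrog = nnc_jit(leapfrog)

q, p = torch.randn(3), torch.randn(3)
step_size = torch.tensor(0.1)  # a `Tensor` rather than a raw float, per the note above
q_new, p_new = leapfrog(q, p, step_size)
```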

To try NNC out, simply set `nnc_compile` to `True` when initializing the inference class, e.g.

```python
nuts = bm.GlobalNoUTurnSampler(nnc_compile=True)
nuts.infer(...)  # same arguments as usual
```
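
The wrapper can also be applied to standalone functions. A minimal sketch (the `log_density` function here is arbitrary):

```python
import torch

from beanmachine.ppl.experimental.nnc import nnc_jit


def log_density(x):
    return -0.5 * (x * x).sum()


compiled = nnc_jit(log_density)
x = torch.randn(100)
compiled(x)  # the first call emits the experimental warning, traces, and compiles
compiled(x)  # subsequent calls should reuse the compiled graph
```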

Pull Request resolved: #1385

Test Plan:
I've added a simple sanity check to cover the NNC compile option on NUTS and HMC, which can be run with
```sh
buck test //beanmachine/beanmachine/ppl:test-ppl -- nnc
```
or equivalently, for OSS:
```sh
pytest src/beanmachine/ppl/experimental/tests/nnc_test.py
```

Differential Revision: D35127777

Pulled By: horizon-blue

fbshipit-source-id: 8efcb3c6234a7f4558517a50661d7b65f1f4bb2e
horizon-blue authored and facebook-github-bot committed Mar 25, 2022
1 parent 2af74d2 commit 7817d4a
Showing 11 changed files with 286 additions and 58 deletions.
39 changes: 20 additions & 19 deletions setup.py
@@ -16,7 +16,24 @@
REQUIRED_MAJOR = 3
REQUIRED_MINOR = 7


INSTALL_REQUIRES = [
    "arviz>=0.11.0",
    "astor>=0.7.1",
    "botorch>=0.5.1",
    "flowtorch>=0.3",
    "gpytorch>=1.3.0",
    "graphviz>=0.17",
    "numpy>=1.18.1",
    "pandas>=0.24.2",
    "parameterized>=0.8.1",
    "plotly>=2.2.1",
    "scipy>=0.16",
    "statsmodels>=0.12.0",
    "torch>=1.9.0",
    "tqdm>=4.46.0",
    "typing-extensions>=3.10",
    "xarray>=0.16.0",
]
TEST_REQUIRES = ["pytest>=7.0.0", "pytest-cov"]
TUTORIALS_REQUIRES = [
"bokeh",
@@ -48,6 +65,7 @@
CPP_COMPILE_ARGS = ["/WX", "/permissive-", "-DEIGEN_HAS_C99_MATH"]
else:
CPP_COMPILE_ARGS = ["-std=c++17", "-Werror"]
INSTALL_REQUIRES.append("functorch>=0.1.0")


# Check for python version
@@ -125,24 +143,7 @@
    long_description=long_description,
    long_description_content_type="text/markdown",
    python_requires=">={}.{}".format(REQUIRED_MAJOR, REQUIRED_MINOR),
    install_requires=[
        "arviz>=0.11.0",
        "astor>=0.7.1",
        "botorch>=0.5.1",
        "flowtorch>=0.3",
        "gpytorch>=1.3.0",
        "graphviz>=0.17",
        "numpy>=1.18.1",
        "pandas>=0.24.2",
        "parameterized>=0.8.1",
        "plotly>=2.2.1",
        "scipy>=0.16",
        "statsmodels>=0.12.0",
        "torch>=1.9.0",
        "tqdm>=4.46.0",
        "typing-extensions>=3.10",
        "xarray>=0.16.0",
    ],
    install_requires=INSTALL_REQUIRES,
    packages=find_packages("src"),
    package_dir={"": "src"},
    ext_modules=[
40 changes: 40 additions & 0 deletions src/beanmachine/ppl/experimental/nnc/__init__.py
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
from typing import Callable, TypeVar, Optional, Tuple

from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def nnc_jit(
    f: Callable[P, R], static_argnums: Optional[Tuple[int]] = None
) -> Callable[P, R]:
    """
    A helper function that lazily imports the NNC utils (which initialize the
    compiler and emit an experimental warning), then invokes the underlying
    nnc_jit on the function f.
    """
    try:
        # the setup code in `nnc.utils` will only be executed once per Python session
        from beanmachine.ppl.experimental.nnc.utils import nnc_jit as raw_nnc_jit
    except ImportError as e:
        if sys.platform.startswith("win"):
            message = "functorch is not available on Windows."
        else:
            message = (
                "Failed to initialize NNC. This is likely caused by a version "
                "mismatch between PyTorch and functorch. Please check out the "
                "functorch project for installation instructions."
            )
        raise RuntimeError(message) from e

    return raw_nnc_jit(f, static_argnums)


__all__ = ["nnc_jit"]
96 changes: 96 additions & 0 deletions src/beanmachine/ppl/experimental/nnc/utils.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import functorch
import torch
import torch.jit
import torch.utils._pytree as pytree
from functorch.compile import (
    nop,
    aot_function,
    decomposition_table,
    register_decomposition,
)

# this warning will only be shown to the user once, when this module is imported
warnings.warn(
    "Support for the NNC compiler is experimental and the API is subject to "
    "change in future releases of Bean Machine. For questions regarding NNC, "
    "please check out the functorch project (https://github.com/pytorch/functorch)."
)

# enable NNC's tensor-expression fuser to handle reductions (e.g., sum)
torch._C._jit_set_texpr_reductions_enabled(True)

# override the usage of torch.jit.script, which has issues handling empty
# lists (functorch#440)
def simple_ts_compile(fx_g, example_inps):
    # trace rather than script the FX graph, freeze it so the traced module
    # can be optimized, then remove mutation to put the graph in a functional
    # form that NNC can consume
    f = torch.jit.trace(fx_g, example_inps, strict=False)
    f = torch.jit.freeze(f.eval())
    torch._C._jit_pass_remove_mutation(f.graph)

    return f


aten = torch.ops.aten
decompositions = [aten.detach]
# keep only the aten ops we want to decompose; the functions registered below
# rewrite ops such as aten.mv and aten.dot into broadcasted multiplies plus
# reductions, pointwise forms that NNC's fuser can handle
bm_decompositions = {
    k: v for k, v in decomposition_table.items() if k in decompositions
}


@register_decomposition(aten.mv, bm_decompositions)
def mv(a, b):
    return (a * b).sum(dim=-1)


@register_decomposition(aten.dot, bm_decompositions)
def dot(a, b):
    return (a * b).sum(dim=-1)


# @register_decomposition(aten.nan_to_num, bm_decompositions)
# def nan_to_num(a, val):
#     return aten.where(a != a, val, a)


@register_decomposition(aten.zeros_like, bm_decompositions)
def zeros_like(a, **kwargs):
    return a * 0


@register_decomposition(aten.ones_like, bm_decompositions)
def ones_like(a, **kwargs):
    return a * 0 + 1


def nnc_jit(f, static_argnums=None):
    # compile the forward graph with simple_ts_compile (NNC via TorchScript)
    # and leave the backward graph uncompiled (nop)
    return aot_function(
        f,
        simple_ts_compile,
        nop,
        static_argnums=static_argnums,
        decompositions=bm_decompositions,
    )


functorch._src.compilers.simple_ts_compile = simple_ts_compile


# override default dict flatten (which requires keys to be sortable)
def _dict_flatten(d):
    keys = list(d.keys())
    values = [d[key] for key in keys]
    return values, keys


def _dict_unflatten(values, context):
    return {key: value for key, value in zip(context, values)}


pytree._register_pytree_node(dict, _dict_flatten, _dict_unflatten)
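# a quick illustration of the round trip (the values here are arbitrary; in
# Bean Machine the dictionary keys are typically RVIdentifier objects, which
# are not sortable):
#   values, ctx = _dict_flatten({"b": 2, "a": 1})  # -> [2, 1], ["b", "a"]
#   _dict_unflatten(values, ctx)                   # -> {"b": 2, "a": 1}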

__all__ = ["nnc_jit"]
52 changes: 52 additions & 0 deletions src/beanmachine/ppl/experimental/tests/nnc_test.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import warnings

import beanmachine.ppl as bm
import pytest
import torch
import torch.distributions as dist

if sys.platform.startswith("win"):
    pytest.skip("functorch is not available on Windows", allow_module_level=True)


class SampleModel:
    @bm.random_variable
    def foo(self):
        return dist.Normal(0.0, 1.0)

    @bm.random_variable
    def bar(self):
        return dist.Normal(self.foo(), 1.0)


@pytest.mark.parametrize(
"algorithm",
[
bm.GlobalNoUTurnSampler(nnc_compile=True),
bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, nnc_compile=True),
],
)
def test_nnc_compile(algorithm):
model = SampleModel()
queries = [model.foo()]
observations = {model.bar(): torch.tensor(0.5)}
num_samples = 30
num_chains = 2
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# verify that NNC can run through
samples = algorithm.infer(
queries,
observations,
num_samples,
num_adaptive_samples=num_samples,
num_chains=num_chains,
)
# sanity check: make sure that the samples are valid
assert not torch.isnan(samples[model.foo()]).any()
8 changes: 8 additions & 0 deletions src/beanmachine/ppl/inference/hmc_inference.py
@@ -30,6 +30,8 @@ class GlobalHamiltonianMonteCarlo(BaseInference):
        adapt_mass_matrix (bool): Whether to adapt the mass matrix. Defaults to True.
        target_accept_prob (float): Target accept probability. Increasing this value
            leads to a smaller step size. Defaults to 0.8.
        nnc_compile (bool): (Experimental) If True, the NNC compiler will be used to
            accelerate inference. Defaults to False.
    """

    def __init__(
@@ -39,12 +41,14 @@ def __init__(
        adapt_step_size: bool = True,
        adapt_mass_matrix: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.trajectory_length = trajectory_length
        self.initial_step_size = initial_step_size
        self.adapt_step_size = adapt_step_size
        self.adapt_mass_matrix = adapt_mass_matrix
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposer = None

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -66,6 +70,7 @@ def get_proposers(
                self.adapt_step_size,
                self.adapt_mass_matrix,
                self.target_accept_prob,
                self.nnc_compile,
            )
        return [self._proposer]

@@ -94,12 +99,14 @@ def __init__(
        adapt_step_size: bool = True,
        adapt_mass_matrix: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.trajectory_length = trajectory_length
        self.initial_step_size = initial_step_size
        self.adapt_step_size = adapt_step_size
        self.adapt_mass_matrix = adapt_mass_matrix
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposers = {}

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -123,6 +130,7 @@ def get_proposers(
                    self.adapt_step_size,
                    self.adapt_mass_matrix,
                    self.target_accept_prob,
                    self.nnc_compile,
                )
            proposers.append(self._proposers[node])
        return proposers
10 changes: 10 additions & 0 deletions src/beanmachine/ppl/inference/nuts_inference.py
@@ -38,6 +38,8 @@ class GlobalNoUTurnSampler(BaseInference):
            defaults to True.
        target_accept_prob (float): Target accept probability. Increasing this
            leads to a smaller step size. Defaults to 0.8.
        nnc_compile (bool): (Experimental) If True, the NNC compiler will be used to
            accelerate inference. Defaults to False.
    """

    def __init__(
@@ -49,6 +51,7 @@ def __init__(
        adapt_mass_matrix: bool = True,
        multinomial_sampling: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.max_tree_depth = max_tree_depth
        self.max_delta_energy = max_delta_energy
@@ -57,6 +60,7 @@
        self.adapt_mass_matrix = adapt_mass_matrix
        self.multinomial_sampling = multinomial_sampling
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposer = None

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -80,6 +84,7 @@ def get_proposers(
                self.adapt_mass_matrix,
                self.multinomial_sampling,
                self.target_accept_prob,
                self.nnc_compile,
            )
        return [self._proposer]

@@ -106,6 +111,8 @@ class SingleSiteNoUTurnSampler(BaseInference):
            defaults to True.
        target_accept_prob (float): Target accept probability. Increasing this
            leads to a smaller step size. Defaults to 0.8.
        nnc_compile (bool): (Experimental) If True, the NNC compiler will be used to
            accelerate inference. Defaults to False.
    """

    def __init__(
@@ -117,6 +124,7 @@ def __init__(
        adapt_mass_matrix: bool = True,
        multinomial_sampling: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.max_tree_depth = max_tree_depth
        self.max_delta_energy = max_delta_energy
@@ -125,6 +133,7 @@
        self.adapt_mass_matrix = adapt_mass_matrix
        self.multinomial_sampling = multinomial_sampling
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposers = {}

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -150,6 +159,7 @@ def get_proposers(
                    self.adapt_mass_matrix,
                    self.multinomial_sampling,
                    self.target_accept_prob,
                    self.nnc_compile,
                )
            proposers.append(self._proposers[node])
        return proposers
