This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Adding experimental nnc_compile option to NUTS and HMC #1385

Closed
wants to merge 1 commit
39 changes: 20 additions & 19 deletions setup.py
@@ -16,7 +16,24 @@
REQUIRED_MAJOR = 3
REQUIRED_MINOR = 7


INSTALL_REQUIRES = [
    "arviz>=0.11.0",
    "astor>=0.7.1",
    "botorch>=0.5.1",
    "flowtorch>=0.3",
    "gpytorch>=1.3.0",
    "graphviz>=0.17",
    "numpy>=1.18.1",
    "pandas>=0.24.2",
    "parameterized>=0.8.1",
    "plotly>=2.2.1",
    "scipy>=0.16",
    "statsmodels>=0.12.0",
    "torch>=1.9.0",
    "tqdm>=4.46.0",
    "typing-extensions>=3.10",
    "xarray>=0.16.0",
]
TEST_REQUIRES = ["pytest>=7.0.0", "pytest-cov"]
TUTORIALS_REQUIRES = [
"bokeh",
@@ -48,6 +65,7 @@
CPP_COMPILE_ARGS = ["/WX", "/permissive-", "-DEIGEN_HAS_C99_MATH"]
else:
CPP_COMPILE_ARGS = ["-std=c++17", "-Werror"]
INSTALL_REQUIRES.append("functorch>=0.1.0")


# Check for python version
@@ -125,24 +143,7 @@
    long_description=long_description,
    long_description_content_type="text/markdown",
    python_requires=">={}.{}".format(REQUIRED_MAJOR, REQUIRED_MINOR),
    install_requires=[
        "arviz>=0.11.0",
        "astor>=0.7.1",
        "botorch>=0.5.1",
        "flowtorch>=0.3",
        "gpytorch>=1.3.0",
        "graphviz>=0.17",
        "numpy>=1.18.1",
        "pandas>=0.24.2",
        "parameterized>=0.8.1",
        "plotly>=2.2.1",
        "scipy>=0.16",
        "statsmodels>=0.12.0",
        "torch>=1.9.0",
        "tqdm>=4.46.0",
        "typing-extensions>=3.10",
        "xarray>=0.16.0",
    ],
    install_requires=INSTALL_REQUIRES,
    packages=find_packages("src"),
    package_dir={"": "src"},
    ext_modules=[
2 changes: 1 addition & 1 deletion sphinx/Makefile
@@ -12,7 +12,7 @@ ALLSPHINXOPTS = -q -d $(BUILDDIR)/doctrees $(SPHINXOPTS) ./source
# NOTE: Since we cannot ignore the no-duplicates warning specifically,
# we use a Bash hack for the same purpose. Note: ! is used to invert the
# error status of grep.
BASHHACK = 2>&1 >/dev/null | grep -v -E "(use :noindex:)|(more than one target found)"
BASHHACK = 2>&1 >/dev/null | grep -v -E "(use :noindex:)|(more than one target found)|(warnings.warn)|(experimental)"


# Put it first so that "make" without argument is like "make help".
40 changes: 40 additions & 0 deletions src/beanmachine/ppl/experimental/nnc/__init__.py
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
from typing import Callable, TypeVar, Optional, Tuple

from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def nnc_jit(
    f: Callable[P, R], static_argnums: Optional[Tuple[int, ...]] = None
) -> Callable[P, R]:
    """
    A helper function that lazily imports the NNC utils (which initialize the
    compiler and display an experimental warning), then invokes the underlying
    nnc_jit on the function f.
    """
    try:
        # The setup code in `nnc.utils` will only be executed once in a Python session
        from beanmachine.ppl.experimental.nnc.utils import nnc_jit as raw_nnc_jit
    except ImportError as e:
        if sys.platform.startswith("win"):
            message = "functorch is not available on Windows."
        else:
            message = (
                "Failed to initialize NNC. This is likely caused by a version "
                "mismatch between PyTorch and functorch. Please check out the "
                "functorch project for installation instructions "
                "(https://github.com/pytorch/functorch)."
            )
        raise RuntimeError(message) from e

    return raw_nnc_jit(f, static_argnums)


__all__ = ["nnc_jit"]
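For context, a minimal usage sketch of this helper (the log-density function and tensors below are invented for illustration and are not part of this diff):

import torch

from beanmachine.ppl.experimental.nnc import nnc_jit


def gaussian_log_prob(mu, x):
    # the kind of function the inference loop evaluates repeatedly
    return -0.5 * ((x - mu) ** 2).sum()


compiled = nnc_jit(gaussian_log_prob)
mu, x = torch.zeros(3), torch.ones(3)
# the first call triggers compilation; subsequent calls with same-shaped
# tensors reuse the compiled function
assert torch.allclose(compiled(mu, x), gaussian_log_prob(mu, x))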
93 changes: 93 additions & 0 deletions src/beanmachine/ppl/experimental/nnc/utils.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import functorch
import torch
import torch.jit
import torch.utils._pytree as pytree
from functorch.compile import (
    nop,
    aot_function,
    decomposition_table,
    register_decomposition,
)

# the warning will only be shown to the user once per Python session, when
# this module is imported
warnings.warn(
    "Support for the NNC compiler is experimental and the API is subject to "
    "change in future releases of Bean Machine. For questions regarding NNC, "
    "please check out the functorch project "
    "(https://github.com/pytorch/functorch)."
)

# allows reductions to be compiled by NNC
torch._C._jit_set_texpr_reductions_enabled(True)

# override the usage of torch.jit.script, which has issues handling empty
# lists (functorch#440)
def simple_ts_compile(fx_g, example_inps):
    # trace and freeze the FX graph into TorchScript, then strip mutation so
    # the fuser can handle the resulting graph
    f = torch.jit.trace(fx_g, example_inps, strict=False)
    f = torch.jit.freeze(f.eval())
    torch._C._jit_pass_remove_mutation(f.graph)

    return f


# Overrides decomposition rules for some operators
aten = torch.ops.aten
decompositions = [aten.detach]
bm_decompositions = {
    k: v for k, v in decomposition_table.items() if k in decompositions
}


@register_decomposition(aten.mv, bm_decompositions)
def mv(a, b):
    return (a * b).sum(dim=-1)


@register_decomposition(aten.dot, bm_decompositions)
def dot(a, b):
    return (a * b).sum(dim=-1)


@register_decomposition(aten.zeros_like, bm_decompositions)
def zeros_like(a, **kwargs):
    return a * 0


@register_decomposition(aten.ones_like, bm_decompositions)
def ones_like(a, **kwargs):
    return a * 0 + 1
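As a sanity check (not part of the diff), these rewrites agree with the operators they replace:

import torch

a, b = torch.randn(4, 3), torch.randn(3)
# aten.mv decomposes into a broadcast multiply plus a sum over the last dim
assert torch.allclose(torch.mv(a, b), (a * b).sum(dim=-1))

v, w = torch.randn(5), torch.randn(5)
# aten.dot decomposes the same way for 1-D inputs
assert torch.allclose(torch.dot(v, w), (v * w).sum(dim=-1))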


def nnc_jit(f, static_argnums=None):
    return aot_function(
        f,
        simple_ts_compile,
        nop,
        static_argnums=static_argnums,
        decompositions=bm_decompositions,
    )


functorch._src.compilers.simple_ts_compile = simple_ts_compile


# override default dict flatten (which requires keys to be sortable)
def _dict_flatten(d):
    keys = list(d.keys())
    values = [d[key] for key in keys]
    return values, keys


def _dict_unflatten(values, context):
    return {key: value for key, value in zip(context, values)}


pytree._register_pytree_node(dict, _dict_flatten, _dict_unflatten)

__all__ = ["nnc_jit"]
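To see what the override buys (an illustrative sketch; the Key class is invented here): the default flatten sorts dictionary keys, which fails for key types that define no ordering, such as the random variable identifiers Bean Machine uses as dict keys. With the override registered, such a dict round-trips:

import torch.utils._pytree as pytree


class Key:
    # stand-in for RVIdentifier-like keys that define no ordering
    pass


d = {Key(): 1.0, Key(): 2.0}
values, spec = pytree.tree_flatten(d)
assert pytree.tree_unflatten(values, spec) == d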
52 changes: 52 additions & 0 deletions src/beanmachine/ppl/experimental/tests/nnc_test.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import warnings

import beanmachine.ppl as bm
import pytest
import torch
import torch.distributions as dist

if sys.platform.startswith("win"):
    pytest.skip("functorch is not available on Windows", allow_module_level=True)


class SampleModel:
    @bm.random_variable
    def foo(self):
        return dist.Normal(0.0, 1.0)

    @bm.random_variable
    def bar(self):
        return dist.Normal(self.foo(), 1.0)


@pytest.mark.parametrize(
    "algorithm",
    [
        bm.GlobalNoUTurnSampler(nnc_compile=True),
        bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, nnc_compile=True),
    ],
)
def test_nnc_compile(algorithm):
    model = SampleModel()
    queries = [model.foo()]
    observations = {model.bar(): torch.tensor(0.5)}
    num_samples = 30
    num_chains = 2
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # verify that NNC can run through
        samples = algorithm.infer(
            queries,
            observations,
            num_samples,
            num_adaptive_samples=num_samples,
            num_chains=num_chains,
        )
    # sanity check: make sure that the samples are valid
    assert not torch.isnan(samples[model.foo()]).any()
8 changes: 8 additions & 0 deletions src/beanmachine/ppl/inference/hmc_inference.py
@@ -30,6 +30,8 @@ class GlobalHamiltonianMonteCarlo(BaseInference):
        adapt_mass_matrix (bool): Whether to adapt the mass matrix. Defaults to True.
        target_accept_prob (float): Target acceptance probability. Increasing this
            value leads to a smaller step size. Defaults to 0.8.
        nnc_compile (bool): (Experimental) If True, the NNC compiler will be used to
            accelerate inference. Defaults to False.
    """

    def __init__(
@@ -39,12 +41,14 @@ def __init__(
        adapt_step_size: bool = True,
        adapt_mass_matrix: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.trajectory_length = trajectory_length
        self.initial_step_size = initial_step_size
        self.adapt_step_size = adapt_step_size
        self.adapt_mass_matrix = adapt_mass_matrix
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposer = None

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -66,6 +70,7 @@ def get_proposers(
                self.adapt_step_size,
                self.adapt_mass_matrix,
                self.target_accept_prob,
                self.nnc_compile,
            )
        return [self._proposer]

@@ -94,12 +99,14 @@ def __init__(
        adapt_step_size: bool = True,
        adapt_mass_matrix: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.trajectory_length = trajectory_length
        self.initial_step_size = initial_step_size
        self.adapt_step_size = adapt_step_size
        self.adapt_mass_matrix = adapt_mass_matrix
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposers = {}

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -123,6 +130,7 @@ def get_proposers(
                    self.adapt_step_size,
                    self.adapt_mass_matrix,
                    self.target_accept_prob,
                    self.nnc_compile,
                )
            proposers.append(self._proposers[node])
        return proposers
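A hypothetical end-to-end usage of the new flag (the model and sample counts are invented for illustration; the API calls mirror the test above):

import beanmachine.ppl as bm
import torch
import torch.distributions as dist


@bm.random_variable
def theta():
    return dist.Normal(0.0, 1.0)


@bm.random_variable
def y():
    return dist.Normal(theta(), 1.0)


samples = bm.GlobalHamiltonianMonteCarlo(
    trajectory_length=1.0,
    nnc_compile=True,  # opt in to the experimental compiler
).infer(
    queries=[theta()],
    observations={y(): torch.tensor(0.5)},
    num_samples=100,
    num_chains=2,
)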
10 changes: 10 additions & 0 deletions src/beanmachine/ppl/inference/nuts_inference.py
@@ -38,6 +38,8 @@ class GlobalNoUTurnSampler(BaseInference):
            defaults to True.
        target_accept_prob (float): Target acceptance probability. Increasing this
            value leads to a smaller step size. Defaults to 0.8.
        nnc_compile (bool): (Experimental) If True, the NNC compiler will be used to
            accelerate inference. Defaults to False.
    """

    def __init__(
@@ -49,6 +51,7 @@ def __init__(
        adapt_mass_matrix: bool = True,
        multinomial_sampling: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.max_tree_depth = max_tree_depth
        self.max_delta_energy = max_delta_energy
@@ -57,6 +60,7 @@
        self.adapt_mass_matrix = adapt_mass_matrix
        self.multinomial_sampling = multinomial_sampling
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposer = None

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -80,6 +84,7 @@ def get_proposers(
                self.adapt_mass_matrix,
                self.multinomial_sampling,
                self.target_accept_prob,
                self.nnc_compile,
            )
        return [self._proposer]

@@ -106,6 +111,8 @@ class SingleSiteNoUTurnSampler(BaseInference):
            defaults to True.
        target_accept_prob (float): Target acceptance probability. Increasing this
            value leads to a smaller step size. Defaults to 0.8.
        nnc_compile (bool): (Experimental) If True, the NNC compiler will be used to
            accelerate inference. Defaults to False.
    """

    def __init__(
@@ -117,6 +124,7 @@ def __init__(
        adapt_mass_matrix: bool = True,
        multinomial_sampling: bool = True,
        target_accept_prob: float = 0.8,
        nnc_compile: bool = False,
    ):
        self.max_tree_depth = max_tree_depth
        self.max_delta_energy = max_delta_energy
@@ -125,6 +133,7 @@
        self.adapt_mass_matrix = adapt_mass_matrix
        self.multinomial_sampling = multinomial_sampling
        self.target_accept_prob = target_accept_prob
        self.nnc_compile = nnc_compile
        self._proposers = {}

    def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -150,6 +159,7 @@
                    self.adapt_mass_matrix,
                    self.multinomial_sampling,
                    self.target_accept_prob,
                    self.nnc_compile,
                )
            proposers.append(self._proposers[node])
        return proposers
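The same flag applies to NUTS. Reusing the model and imports from the HMC sketch above:

nuts_samples = bm.GlobalNoUTurnSampler(nnc_compile=True).infer(
    queries=[theta()],
    observations={y(): torch.tensor(0.5)},
    num_samples=100,
    num_chains=2,
)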