From ade424f3d4177292c1b1a4f55f258ec56780949f Mon Sep 17 00:00:00 2001
From: Xiaoyan Wang
Date: Thu, 21 Apr 2022 20:54:55 -0700
Subject: [PATCH] MiniBM (#1415)

Summary:
Pull Request resolved: https://github.com/facebookresearch/beanmachine/pull/1415

This diff introduces MiniBM, a minimal implementation of Bean Machine in a
little over 100 lines of code. The script comes with an implementation of the
Metropolis-Hastings algorithm and a coin flipping model at the end. It is
standalone: MiniBM does not depend on the Bean Machine framework at all. To
try it out, you can simply download `minibm.py` and run it with

```
python minibm.py
```

The only two dependencies for MiniBM are the PyTorch library and tqdm (for
the progress bar). The goal of this file is to help developers get familiar
with key Bean Machine concepts (such as `World` and `random_variable`),
rather than to provide a performant implementation.

Reviewed By: jpchen

Differential Revision: D27111773

fbshipit-source-id: 8ca261af4b8a9ed3546b7296490f8e9b3d6e3db7
---
 src/beanmachine/minibm.py | 155 ++++++++++++++++++++++++++++++++++++++
 website/Makefile          |   2 +-
 2 files changed, 156 insertions(+), 1 deletion(-)
 create mode 100644 src/beanmachine/minibm.py

diff --git a/src/beanmachine/minibm.py b/src/beanmachine/minibm.py
new file mode 100644
index 0000000000..f5181a8a5e
--- /dev/null
+++ b/src/beanmachine/minibm.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import random
+from collections import defaultdict
+from functools import wraps
+from typing import NamedTuple, Callable, Tuple
+
+import torch
+import torch.distributions as dist
+from tqdm.auto import tqdm
+
+WORLD_STACK = []
+
+
+class RVIdentifier(NamedTuple):
+    wrapper: Callable
+    args: Tuple
+
+    @property
+    def function(self):
+        return self.wrapper.__wrapped__  # the original function, saved by @wraps
+
+
+class World:
+    def __init__(self, observations=None):
+        self.observations = observations or {}
+        self.variables = {}
+
+    def __getitem__(self, node):
+        return self.variables[node]
+
+    def __enter__(self):
+        WORLD_STACK.append(self)
+
+    def __exit__(self, *args):
+        WORLD_STACK.pop()
+
+    @property
+    def latent_nodes(self):
+        return self.variables.keys() - self.observations.keys()
+
+    def update_graph(self, node):
+        if node not in self.variables:
+            # parent nodes will be invoked recursively inside self.invoke(node)
+            distribution = self.invoke(node)
+            if node in self.observations:
+                self.variables[node] = self.observations[node]
+            else:
+                self.variables[node] = distribution.sample()
+        return self.variables[node]
+
+    def replace(self, values):
+        new_world = World(self.observations)
+        new_world.variables = {**self.variables, **values}
+        return new_world
+
+    def log_prob(self):
+        log_prob = torch.tensor(0.0)
+        for node, value in self.variables.items():
+            log_prob += self.invoke(node).log_prob(value).sum()
+        return log_prob
+
+    def invoke(self, node):
+        # return the distribution at node, conditioned on the rest of the values in the world
+        with self:
+            return node.function(*node.args)
+
+
+def random_variable(f):
+    @wraps(f)
+    def wrapper(*args):
+        rvid = RVIdentifier(wrapper, args)
+        if len(WORLD_STACK) > 0:
+            # return the value of the random variable if it is invoked under
+            # an active world context
+            return WORLD_STACK[-1].update_graph(rvid)
+        # return an ID for the random variable
+        return rvid
+
+    return wrapper
+
+
+class MH:
+    @staticmethod
+    def initialize_world(queries, observations):
+        world = World(observations)
+        # recursively invoke parent nodes to construct the graph
+        for node in itertools.chain(queries, observations):
+            world.update_graph(node)
+        return world
+
+    @staticmethod
+    def infer(
+        queries,
+        observations,
+        num_samples,
+    ):
+        world = MH.initialize_world(queries, observations)
+
+        samples = defaultdict(list)
+        for _ in tqdm(range(num_samples)):
+            randomized_nodes = list(world.latent_nodes)  # latent_nodes is a set; shuffle needs a sequence
+            random.shuffle(randomized_nodes)
+            for node in randomized_nodes:
+                proposer_distribution = world.invoke(node)
+                new_value = proposer_distribution.sample()
+                new_world = world.replace({node: new_value})
+                backward_dist = new_world.invoke(node)
+
+                # log P(x, y)
+                old_log_prob = world.log_prob()
+                # log P(x', y)
+                new_log_prob = new_world.log_prob()
+                # log g(x'|x)
+                forward_log_prob = proposer_distribution.log_prob(new_value).sum()
+                # log g(x|x')
+                backward_log_prob = backward_dist.log_prob(world[node]).sum()
+
+                accept_log_prob = (
+                    new_log_prob + backward_log_prob - old_log_prob - forward_log_prob
+                )
+                if torch.bernoulli(accept_log_prob.exp().clamp(max=1)):
+                    world = new_world
+
+            for node in queries:
+                samples[node].append(world[node])
+        samples = {node: torch.stack(samples[node]) for node in samples}
+        return samples
+
+
+if __name__ == "__main__":
+    # coin flipping model adapted from our tutorial (https://beanmachine.org/docs/overview/tutorials/Coin_flipping/CoinFlipping/)
+    @random_variable
+    def theta():
+        return dist.Beta(2, 2)
+
+    @random_variable
+    def y():
+        return dist.Bernoulli(theta()).expand((N,))
+
+    # data generation
+    true_theta = 0.75
+    true_y = dist.Bernoulli(true_theta)
+    N = 100
+    y_obs = true_y.sample((N,))
+
+    print("Empirical mean:", y_obs.mean())
+
+    # running inference
+    samples = MH.infer([theta()], {y(): y_obs}, num_samples=1000)
+    print("Sample mean:", samples[theta()].mean())
diff --git a/website/Makefile b/website/Makefile
index fc689429ec..916900cb13 100644
--- a/website/Makefile
+++ b/website/Makefile
@@ -6,7 +6,7 @@
 SPHINXOPTS ?=
 SPHINXBUILD ?= sphinx-build
 SOURCEDIR = ./sphinx
 MODULEDIR = ../src/beanmachine
-MODULEIGNOREDIRS = ../src/beanmachine/{applications,graph,tutorials}* ../src/beanmachine/ppl/{conftest.py,compiler,diagnostics,examples,experimental,inference/utils,legacy,testlib,utils}
+MODULEIGNOREDIRS = ../src/beanmachine/{applications,graph,minibm,tutorials}* ../src/beanmachine/ppl/{conftest.py,compiler,diagnostics,examples,experimental,inference/utils,legacy,testlib,utils}
 BUILDDIR = ./static
 ALLSPHINXOPTS = -q -d $(BUILDDIR)/doctrees $(SPHINXOPTS) $(SOURCEDIR)
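
To make the `World` / `random_variable` interplay mentioned in the summary concrete, here is a minimal sketch; it is not part of the patch. It uses only the APIs defined in `minibm.py` above and assumes the file is saved in the working directory so that `from minibm import ...` resolves (inside the repo the same module would be importable as `beanmachine.minibm`).

```
import torch.distributions as dist

from minibm import World, random_variable  # assumes minibm.py is on the path


@random_variable
def theta():
    return dist.Beta(2, 2)


# Outside of any World context, calling theta() returns only an RVIdentifier
# naming the random variable; nothing is sampled yet.
rvid = theta()
print(type(rvid).__name__)  # RVIdentifier

# Inside an active World context, the same call materializes a value:
# update_graph samples from Beta(2, 2) on first access and caches it.
world = World()  # note: __enter__ does not return self, so keep a reference
with world:
    value = theta()

# The cached sample can be looked up again through the world by identifier.
assert world[theta()] is value
print(value)
```

Running `python minibm.py` as described in the summary exercises the same machinery end to end, with `MH.infer` driving the coin flipping model.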