MiniBM (#1415)
Summary:
Pull Request resolved: #1415

This diff introduces MiniBM, a minimal implementation of Bean Machine in just over 100 lines of code.

The script includes an implementation of the Metropolis-Hastings algorithm and a coin flipping model at the end. It is standalone: MiniBM does not depend on the Bean Machine framework at all. To try it out, simply download `minibm.py` and run it with

```
python minibm.py
```

The only two dependencies for MiniBM are the PyTorch library and tqdm (for the progress bar).
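
If they are not already available, a typical way to install them (a minimal sketch, assuming a standard pip environment) is:

```
pip install torch tqdm
```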

The goal of this file is to help developers get familiar with key Bean Machine concepts (such as `World` and `random_variable`) rather than to provide a performant implementation.
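
As a quick illustration of those concepts, here is a sketch of how a model can be defined and queried with MiniBM. It mirrors the coin flipping example at the bottom of `minibm.py` and assumes the file is importable from the working directory:

```
import torch.distributions as dist
from minibm import MH, random_variable

@random_variable
def theta():
    # prior belief about the coin's bias
    return dist.Beta(2, 2)

@random_variable
def y():
    # 100 coin flips, conditioned on theta
    return dist.Bernoulli(theta()).expand((100,))

# simulate observations from a coin with true bias 0.75
y_obs = dist.Bernoulli(0.75).sample((100,))
samples = MH.infer([theta()], {y(): y_obs}, num_samples=1000)
print(samples[theta()].mean())  # posterior mean of theta
```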

Reviewed By: jpchen

Differential Revision: D27111773

fbshipit-source-id: 8ca261af4b8a9ed3546b7296490f8e9b3d6e3db7
horizon-blue authored and facebook-github-bot committed Apr 22, 2022
1 parent 7f88c0d commit ade424f
Showing 2 changed files with 156 additions and 1 deletion.
155 changes: 155 additions & 0 deletions src/beanmachine/minibm.py
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import random
from collections import defaultdict
from functools import wraps
from typing import NamedTuple, Callable, Tuple

import torch
import torch.distributions as dist
from tqdm.auto import tqdm

# Stack of active World contexts; random_variable wrappers consult the top of
# this stack to look up or sample values.
WORLD_STACK = []


# Uniquely identifies a random variable by its decorated function and arguments.
class RVIdentifier(NamedTuple):
wrapper: Callable
args: Tuple

@property
def function(self):
        return self.wrapper.__wrapped__  # the original, undecorated function


# A World maps each random variable (RVIdentifier) to its current value and
# provides the bookkeeping needed to compute the joint log probability.
class World:
def __init__(self, observations=None):
self.observations = observations or {}
self.variables = {}

def __getitem__(self, node):
return self.variables[node]

def __enter__(self):
WORLD_STACK.append(self)

def __exit__(self, *args):
WORLD_STACK.pop()

@property
def latent_nodes(self):
return self.variables.keys() - self.observations.keys()

def update_graph(self, node):
if node not in self.variables:
            # invoking the node's function under this world visits (and, if
            # needed, samples) its parent nodes first
distribution = self.invoke(node)
if node in self.observations:
self.variables[node] = self.observations[node]
else:
self.variables[node] = distribution.sample()
return self.variables[node]

def replace(self, values):
new_world = World(self.observations)
new_world.variables = {**self.variables, **values}
return new_world

def log_prob(self):
log_prob = torch.tensor(0.0)
for node, value in self.variables.items():
log_prob += self.invoke(node).log_prob(value).sum()
return log_prob

def invoke(self, node):
        # return the distribution at node, conditioned on the rest of the values
        # in this world
with self:
return node.function(*node.args)


# Decorator that turns a function returning a distribution into a random variable.
def random_variable(f):
@wraps(f)
def wrapper(*args):
rvid = RVIdentifier(wrapper, args)
if len(WORLD_STACK) > 0:
            # return the value of the random variable if it is invoked under
            # an active world context
return WORLD_STACK[-1].update_graph(rvid)
# return an ID for the random variable
return rvid

return wrapper


# A single-site Metropolis-Hastings sampler: each latent node is updated in turn
# with a proposal drawn from its distribution conditioned on the rest of the world.
class MH:
@staticmethod
def initialize_world(queries, observations):
world = World(observations)
        # recursively invoke parent nodes to construct the graph
for node in itertools.chain(queries, observations):
world.update_graph(node)
return world

@staticmethod
def infer(
queries,
observations,
num_samples,
):
world = MH.initialize_world(queries, observations)

samples = defaultdict(list)
for _ in tqdm(range(num_samples)):
            # latent_nodes is a set, so copy it into a list before shuffling
            randomized_nodes = list(world.latent_nodes)
            random.shuffle(randomized_nodes)
for node in randomized_nodes:
proposer_distribution = world.invoke(node)
new_value = proposer_distribution.sample()
new_world = world.replace({node: new_value})
backward_dist = new_world.invoke(node)

# log P(x, y)
old_log_prob = world.log_prob()
# log P(x', y)
new_log_prob = new_world.log_prob()
# log g(x'|x)
forward_log_prob = proposer_distribution.log_prob(new_value).sum()
# log g(x|x')
backward_log_prob = backward_dist.log_prob(world[node]).sum()

accept_log_prob = (
new_log_prob + backward_log_prob - old_log_prob - forward_log_prob
)
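                # accept or reject the proposal with probability min(1, exp(accept_log_prob))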
if torch.bernoulli(accept_log_prob.exp().clamp(max=1)):
world = new_world

for node in queries:
samples[node].append(world[node])
samples = {node: torch.stack(samples[node]) for node in samples}
return samples


if __name__ == "__main__":
    # coin flipping model adapted from our tutorial
    # (https://beanmachine.org/docs/overview/tutorials/Coin_flipping/CoinFlipping/)
@random_variable
def theta():
return dist.Beta(2, 2)

@random_variable
def y():
return dist.Bernoulli(theta()).expand((N,))

# data generation
true_theta = 0.75
true_y = dist.Bernoulli(true_theta)
N = 100
y_obs = true_y.sample((N,))

print("Empirical mean:", y_obs.mean())

# running inference
samples = MH.infer([theta()], {y(): y_obs}, num_samples=1000)
print("Sample mean", samples[theta()].mean())
2 changes: 1 addition & 1 deletion website/Makefile
@@ -6,7 +6,7 @@ SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = ./sphinx
MODULEDIR = ../src/beanmachine
-MODULEIGNOREDIRS = ../src/beanmachine/{applications,graph,tutorials}* ../src/beanmachine/ppl/{conftest.py,compiler,diagnostics,examples,experimental,inference/utils,legacy,testlib,utils}
+MODULEIGNOREDIRS = ../src/beanmachine/{applications,graph,minibm,tutorials}* ../src/beanmachine/ppl/{conftest.py,compiler,diagnostics,examples,experimental,inference/utils,legacy,testlib,utils}
BUILDDIR = ./static
ALLSPHINXOPTS = -q -d $(BUILDDIR)/doctrees $(SPHINXOPTS) $(SOURCEDIR)

