This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #1415 This diff introduces MiniBM, a minimal implementation of Bean Machine in a little bit more than 100 lines of code. The script comes with an implementation of the Metropolis Hastings algorithm and a coin flipping model at the end. It is standalone, in that MiniBM does not depend on the Bean Machine framework at all. To try it out, you can simply download `minibm.py` and run it with ``` python minibm.py ``` The only two dependencies for MiniBM are the PyTorch library and tqdm (for progress bar). The goal of this file is to help developers get familiar with key Bean Machine concepts (such as `World` and `random_variable`), instead of providing a performant implementation. Reviewed By: jpchen Differential Revision: D27111773 fbshipit-source-id: 5f7deeb79a409484b9f7ad59ac9c73d21a091417
- Loading branch information
1 parent
7f88c0d
commit c4a719d
Showing
2 changed files
with
153 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import itertools | ||
import random | ||
from collections import defaultdict | ||
from functools import wraps | ||
from typing import NamedTuple, Callable, Tuple | ||
|
||
import torch | ||
import torch.distributions as dist | ||
from tqdm.auto import tqdm | ||
|
||
# Stack of currently active World contexts. World.__enter__ pushes onto it,
# World.__exit__ pops from it, and random_variable wrappers consult the top
# entry to decide whether to return a value or an identifier.
WORLD_STACK = []
|
||
|
||
class RVIdentifier(NamedTuple):
    """Identifier of a random variable: a decorated callable plus its arguments.

    Being a NamedTuple, two identifiers built from the same wrapper and equal
    arguments compare equal and hash alike, so they can be used as dict keys.
    """

    # The decorated callable; it is expected to expose the undecorated
    # function as ``__wrapped__`` (which ``functools.wraps`` sets).
    wrapper: Callable
    # Positional arguments the random variable was invoked with.
    args: Tuple

    @property
    def function(self):
        """Return (without calling) the original function behind ``wrapper``."""
        return self.wrapper.__wrapped__
|
||
|
||
class World:
    """An assignment of values to random variable nodes.

    Attributes:
        observations: mapping from node to its fixed (observed) value.
        variables: mapping from every node seen so far to its current value.

    A World is also a context manager: while it is active it sits on top of
    WORLD_STACK.
    """

    def __init__(self, observations=None):
        self.observations = observations or {}
        self.variables = {}

    def __getitem__(self, node):
        """Return the current value of ``node``; raises KeyError if unknown."""
        return self.variables[node]

    def __enter__(self):
        WORLD_STACK.append(self)
        # Return the world so that ``with World() as w:`` binds it.
        return self

    def __exit__(self, *args):
        WORLD_STACK.pop()

    def update_graph(self, node):
        """Ensure ``node`` has a value in this world and return that value.

        A node seen for the first time takes its observed value if it has
        one, otherwise a fresh sample from its distribution.
        """
        if node not in self.variables:
            # Invoking the node runs its function with this world active, so
            # any random variables called inside that function are routed
            # back to update_graph on this world first.
            distribution = self.invoke(node)
            if node in self.observations:
                self.variables[node] = self.observations[node]
            else:
                self.variables[node] = distribution.sample()
        return self.variables[node]

    def replace(self, values):
        """Return a new World sharing the observations, with ``values`` overriding variables."""
        new_world = World(self.observations)
        new_world.variables = {**self.variables, **values}
        return new_world

    def log_prob(self):
        """Return the sum over all nodes of the log probability of their current value."""
        log_prob = torch.tensor(0.0)
        for node, value in self.variables.items():
            log_prob += self.invoke(node).log_prob(value).sum()
        return log_prob

    def invoke(self, node):
        """Call the node's original function with this world active and return its result."""
        # return the distribution at node conditioned on the rest of values in world
        with self:
            return node.function(*node.args)
|
||
|
||
def random_variable(f):
    """Decorate ``f`` so that calling it yields a random variable.

    With no active world (WORLD_STACK is empty), calling the decorated
    function returns an RVIdentifier for the call. Under an active world it
    returns whatever that world's ``update_graph`` returns for the identifier.
    """

    @wraps(f)
    def wrapper(*args):
        rvid = RVIdentifier(wrapper, args)
        if WORLD_STACK:
            # return the value of random variable if it is invoked under
            # an active world context (the innermost one)
            return WORLD_STACK[-1].update_graph(rvid)
        # return an ID for the random variable
        return rvid

    return wrapper
|
||
|
||
def initialize_world(queries, observations):
    """Build a World from ``observations`` and register every queried and observed node.

    Args:
        queries: iterable of nodes to include in the world.
        observations: mapping from node to observed value.

    Returns:
        The populated World.
    """
    world = World(observations)
    # update_graph invokes each node, which in turn pulls in the nodes it
    # depends on, so the whole graph is constructed recursively
    for node in itertools.chain(queries, observations):
        world.update_graph(node)
    return world
|
||
|
||
class MetropolisHastings:
    """Single-site Metropolis-Hastings that proposes each node from its own distribution."""

    def infer(
        self,
        queries,
        observations,
        num_samples,
    ):
        """Run the sampler and collect the values of the queried nodes.

        Args:
            queries: iterable of nodes whose values are recorded.
            observations: mapping from node to observed value.
            num_samples: number of sweeps over the latent nodes; one value
                per query is recorded after each sweep.

        Returns:
            A dict mapping each queried node to its recorded values stacked
            with ``torch.stack``.
        """
        world = initialize_world(queries, observations)

        samples = defaultdict(list)
        for _ in tqdm(range(num_samples)):
            # The key difference is a set; random.shuffle needs a mutable
            # sequence and raises TypeError on a set once it has two or more
            # elements, so materialize a list first.
            latent_nodes = list(world.variables.keys() - world.observations.keys())
            random.shuffle(latent_nodes)
            for node in latent_nodes:
                proposer_distribution = world.invoke(node)
                new_value = proposer_distribution.sample()
                new_world = world.replace({node: new_value})
                backward_dist = new_world.invoke(node)

                # log P(x, y)
                old_log_prob = world.log_prob()
                # log P(x', y)
                new_log_prob = new_world.log_prob()
                # log g(x'|x)
                forward_log_prob = proposer_distribution.log_prob(new_value).sum()
                # log g(x|x')
                backward_log_prob = backward_dist.log_prob(world[node]).sum()

                accept_log_prob = (
                    new_log_prob + backward_log_prob - old_log_prob - forward_log_prob
                )
                # accept with probability min(1, exp(accept_log_prob))
                if torch.bernoulli(accept_log_prob.exp().clamp(max=1)):
                    world = new_world

            for node in queries:
                samples[node].append(world[node])
        samples = {node: torch.stack(samples[node]) for node in samples}
        return samples
|
||
|
||
if __name__ == "__main__":
    # coin flipping model adapted from our tutorial
    # (https://beanmachine.org/docs/overview/tutorials/Coin_flipping/CoinFlipping/)
    @random_variable
    def weight():
        return dist.Beta(2, 2)

    @random_variable
    def y():
        # N is looked up when y() is called, so defining it below is fine
        return dist.Bernoulli(weight()).expand((N,))

    # data generation
    true_weight = 0.75
    true_y = dist.Bernoulli(true_weight)
    N = 100
    y_obs = true_y.sample((N,))

    print("Head rate:", y_obs.mean())

    # running inference
    samples = MetropolisHastings().infer([weight()], {y(): y_obs}, num_samples=500)
    print("Estimated weight of the coin:", samples[weight()].mean())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters