Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Add constraint comparison utility to world.utils. #542

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/beanmachine/ppl/experimental/inference_compilation/ic_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from ...model.rv_identifier import RVIdentifier
from ...world import ProposalDistribution, Variable, World
from ...world.utils import is_constraint_eq
from . import utils


Expand Down Expand Up @@ -476,10 +477,16 @@ def _proposal_distribution_for_node(
f"Encountered node={node} with dim={ndim}"
)

if (
isinstance(support, dist.constraints._Real)
or isinstance(support, dist.constraints._Simplex)
or isinstance(support, dist.constraints._GreaterThan)
if any(
is_constraint_eq(
support,
(
dist.constraints.real,
dist.constraints.real_vector,
dist.constraints.simplex,
dist.constraints.greater_than,
),
)
):
k = self._GMM_NUM_COMPONENTS
if ndim == 0:
Expand Down Expand Up @@ -511,12 +518,12 @@ def _func(x):
return dist.MixtureSameFamily(mix, comp)

return (k + 2 * k * d, _func)
elif isinstance(support, dist.constraints._IntegerInterval) and isinstance(
distribution, dist.Categorical
):
elif is_constraint_eq(
support, dist.constraints.integer_interval
) and isinstance(distribution, dist.Categorical):
num_categories = distribution.param_shape[-1]
return (num_categories, lambda x: dist.Categorical(logits=x))
elif isinstance(support, dist.constraints._Boolean) and isinstance(
elif is_constraint_eq(support, dist.constraints.boolean) and isinstance(
distribution, dist.Bernoulli
):
return (1, lambda x: dist.Bernoulli(logits=x.item()))
Expand Down
22 changes: 14 additions & 8 deletions src/beanmachine/ppl/inference/compositional_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.model.utils import get_wrapper
from beanmachine.ppl.world.utils import is_constraint_eq


class CompositionalInference(AbstractMHInference):
Expand Down Expand Up @@ -82,17 +83,22 @@ def find_best_single_site_proposer(self, node: RVIdentifier):
# pyre-fixme
distribution = node_var.distribution
support = distribution.support
if (
isinstance(support, dist.constraints._Real)
or isinstance(support, dist.constraints._Simplex)
or isinstance(support, dist.constraints._GreaterThan)
if any(
is_constraint_eq(
support,
(
dist.constraints.real,
dist.constraints.simplex,
dist.constraints.greater_than,
),
)
):
self.proposers_per_rv_[node] = SingleSiteNewtonianMonteCarloProposer()
elif isinstance(support, dist.constraints._IntegerInterval) and isinstance(
distribution, dist.Categorical
):
elif is_constraint_eq(
support, dist.constraints.integer_interval
) and isinstance(distribution, dist.Categorical):
self.proposers_per_rv_[node] = SingleSiteUniformProposer()
elif isinstance(support, dist.constraints._Boolean) and isinstance(
elif is_constraint_eq(support, dist.constraints.boolean) and isinstance(
distribution, dist.Bernoulli
):
self.proposers_per_rv_[node] = SingleSiteUniformProposer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.world import ProposalDistribution, Variable, World
from beanmachine.ppl.world.utils import is_constraint_eq
from beanmachine.ppl.world.variable import TransformType
from torch import Tensor

Expand Down Expand Up @@ -84,18 +85,20 @@ def get_proposal_distribution(
node_distribution_support = node_var.distribution.support
if world.get_transforms_for_node(
node
).transform_type != TransformType.NONE or isinstance(
node_distribution_support, dist.constraints._Real
).transform_type != TransformType.NONE or is_constraint_eq(
node_distribution_support, dist.constraints.real
):
self.proposers_[node] = SingleSiteRealSpaceNewtonianMonteCarloProposer(
self.alpha_, self.beta_
)

elif isinstance(node_distribution_support, dist.constraints._GreaterThan):
elif is_constraint_eq(
node_distribution_support, dist.constraints.greater_than
):
self.proposers_[node] = SingleSiteHalfSpaceNewtonianMonteCarloProposer()

elif isinstance(
node_distribution_support, dist.constraints._Simplex
elif is_constraint_eq(
node_distribution_support, dist.constraints.simplex
) or isinstance(node_var.distribution, dist.Beta):
self.proposers_[node] = SingleSiteSimplexNewtonianMonteCarloProposer()
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.world import ProposalDistribution, Variable, World
from beanmachine.ppl.world.utils import is_constraint_eq
from beanmachine.ppl.world.variable import TransformType
from torch import Tensor, tensor

Expand Down Expand Up @@ -98,10 +99,10 @@ def get_proposal_distribution(
# for now, assume all transforms will transform distributions into the realspace
if world.get_transforms_for_node(
node
).transform_type != TransformType.NONE or isinstance(
).transform_type != TransformType.NONE or is_constraint_eq(
# pyre-fixme
node_distribution.support,
dist.constraints._Real,
dist.constraints.real,
):
return (
ProposalDistribution(
Expand All @@ -115,9 +116,7 @@ def get_proposal_distribution(
),
{},
)
elif isinstance(
node_distribution.support, dist.constraints._GreaterThan
) or isinstance(node_distribution.support, dist.constraints._GreaterThan):
elif is_constraint_eq(node_distribution.support, dist.constraints.greater_than):
lower_bound = node_distribution.support.lower_bound
proposal_distribution = self.gamma_distbn_from_moments(
node_var.value - lower_bound, self.step_size ** 2
Expand All @@ -136,7 +135,7 @@ def get_proposal_distribution(
),
{},
)
elif isinstance(node_distribution.support, dist.constraints._Interval):
elif is_constraint_eq(node_distribution.support, dist.constraints.interval):
lower_bound = node_distribution.support.lower_bound
width = node_distribution.support.upper_bound - lower_bound
# Compute first and second moments of the perturbation distribution
Expand All @@ -158,7 +157,7 @@ def get_proposal_distribution(
),
{},
)
elif isinstance(node_distribution.support, dist.constraints._Simplex):
elif is_constraint_eq(node_distribution.support, dist.constraints.simplex):
proposal_distribution = self.dirichlet_distbn_from_moments(
node_var.value, self.step_size
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.world import ProposalDistribution, Variable, World
from beanmachine.ppl.world.utils import is_constraint_eq


class SingleSiteUniformProposer(SingleSiteAncestralProposer):
Expand Down Expand Up @@ -40,10 +41,10 @@ def get_proposal_distribution(
"""
node_distribution = node_var.distribution
if (
isinstance(
is_constraint_eq(
# pyre-fixme
node_distribution.support,
dist.constraints._Boolean,
dist.constraints.boolean,
)
and isinstance(node_distribution, dist.Bernoulli)
):
Expand All @@ -58,8 +59,8 @@ def get_proposal_distribution(
),
{},
)
if isinstance(
node_distribution.support, dist.constraints._IntegerInterval
if is_constraint_eq(
node_distribution.support, dist.constraints.integer_interval
) and isinstance(node_distribution, dist.Categorical):
probs = torch.ones(node_distribution.param_shape)
# In Categorical distrbution, the samples are integers from 0-k
Expand Down
85 changes: 68 additions & 17 deletions src/beanmachine/ppl/world/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates
from typing import List
from collections.abc import Iterable
from typing import Iterable as IterableType, List, Type, Union, overload

import torch
import torch.distributions as dist
Expand All @@ -9,6 +10,9 @@
from torch.distributions.transforms import Transform


ConstraintType = Union[constraints.Constraint, Type]


class BetaDimensionTransform(Transform):
bijective = True

Expand Down Expand Up @@ -42,16 +46,63 @@ def is_discrete(distribution: Distribution) -> bool:
:returns: a boolean that is true if the distribution is discrete and false
otherwise
"""
# pyre-fixme
support = distribution.support
if (
isinstance(support, constraints._Boolean)
or isinstance(support, constraints._IntegerGreaterThan)
or isinstance(support, constraints._IntegerInterval)
or isinstance(support, constraints._IntegerLessThan)
):
return True
return False
return any(
is_constraint_eq(
# pyre-fixme
distribution.support,
(
constraints.boolean,
constraints.integer_interval,
constraints._IntegerGreaterThan,
constraints._IntegerLessThan,
),
)
)


def _unwrap(constraint: ConstraintType):
return constraint if isinstance(constraint, type) else constraint.__class__


def _is_constraint_eq(constraint1: ConstraintType, constraint2: ConstraintType):
return _unwrap(constraint1) == _unwrap(constraint2)


@overload
def is_constraint_eq(
constraint: ConstraintType, check_constraints: ConstraintType
) -> bool:
...


@overload
def is_constraint_eq(
constraint: ConstraintType, check_constraints: IterableType[ConstraintType]
) -> IterableType[bool]:
...


def is_constraint_eq(
constraint: ConstraintType,
check_constraints: Union[ConstraintType, IterableType[ConstraintType]],
) -> Union[bool, IterableType[bool]]:
"""
This provides an equality check that works for different constraints
specified in :mod:`torch.distributions.constraints`. If `check_constraints`
is a single `Constraint` type or instance this returns a `True` if the
given `constraint` matches `check_constraints`. Otherwise, if
`check_constraints` is an iterable, this returns a `bool` list that
represents an element-wise check.

:param constraint: A constraint class or instance.
:param check_constraints: A constraint class or instance or an iterable
containing constraint classes or instances to check against.
:returns: bool (or a list of bool) values indicating if the given constraint
equals the constraint in `check_constraints`.
"""
if isinstance(check_constraints, Iterable):
return [_is_constraint_eq(constraint, c) for c in check_constraints]
return _is_constraint_eq(constraint, check_constraints)


def get_default_transforms(distribution: Distribution) -> List:
Expand All @@ -69,10 +120,10 @@ def get_default_transforms(distribution: Distribution) -> List:
sample = distribution.sample()
if is_discrete(distribution):
return []
elif isinstance(support, constraints._Real):
elif is_constraint_eq(support, constraints.real):
return []

elif isinstance(support, constraints._Interval):
elif is_constraint_eq(support, constraints.interval):
lower_bound = support.lower_bound
if not isinstance(lower_bound, Tensor):
lower_bound = tensor(lower_bound, dtype=sample.dtype)
Expand All @@ -87,8 +138,8 @@ def get_default_transforms(distribution: Distribution) -> List:

return [lower_bound_zero, upper_bound_one, beta_dimension, stick_breaking]

elif isinstance(support, constraints._GreaterThan) or isinstance(
support, constraints._GreaterThanEq
elif is_constraint_eq(support, constraints.greater_than) or isinstance(
support, constraints.greater_than_eq
):
lower_bound = support.lower_bound
if not isinstance(lower_bound, Tensor):
Expand All @@ -98,7 +149,7 @@ def get_default_transforms(distribution: Distribution) -> List:

return [lower_bound_zero, log_transform]

elif isinstance(support, constraints._LessThan):
elif is_constraint_eq(support, constraints.less_than):
upper_bound = support.upper_bound
if not isinstance(upper_bound, Tensor):
upper_bound = tensor(upper_bound, dtype=sample.dtype)
Expand All @@ -109,7 +160,7 @@ def get_default_transforms(distribution: Distribution) -> List:

return [upper_bound_zero, flip_to_greater, log_transform]

elif isinstance(support, constraints._Simplex):
elif is_constraint_eq(support, constraints.simplex):
return [dist.StickBreakingTransform().inv]

return []
15 changes: 8 additions & 7 deletions src/beanmachine/ppl/world/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from beanmachine.ppl.world.utils import (
BetaDimensionTransform,
get_default_transforms,
is_constraint_eq,
is_discrete,
)
from torch import Tensor
Expand Down Expand Up @@ -135,30 +136,30 @@ def initialize_value(
return obs
if initialize_from_prior:
return sample_val
elif isinstance(support, dist.constraints._Real):
elif is_constraint_eq(support, dist.constraints.real):
return torch.zeros(sample_val.shape, dtype=sample_val.dtype)
elif isinstance(support, dist.constraints._Simplex):
elif is_constraint_eq(support, dist.constraints.simplex):
value = torch.ones(sample_val.shape, dtype=sample_val.dtype)
return value / sample_val.shape[-1]
elif isinstance(support, dist.constraints._GreaterThan):
elif is_constraint_eq(support, dist.constraints.greater_than):
return (
torch.ones(sample_val.shape, dtype=sample_val.dtype)
+ support.lower_bound
)
elif isinstance(support, dist.constraints._Boolean):
elif is_constraint_eq(support, dist.constraints.boolean):
return dist.Bernoulli(torch.ones(sample_val.shape) / 2).sample()
elif isinstance(support, dist.constraints._Interval):
elif is_constraint_eq(support, dist.constraints.interval):
lower_bound = torch.ones(sample_val.shape) * support.lower_bound
upper_bound = torch.ones(sample_val.shape) * support.upper_bound
return dist.Uniform(lower_bound, upper_bound).sample()
elif isinstance(support, dist.constraints._IntegerInterval):
elif is_constraint_eq(support, dist.constraints.integer_interval):
integer_interval = support.upper_bound - support.lower_bound
return dist.Categorical(
(torch.ones(integer_interval)).expand(
sample_val.shape + (integer_interval,)
)
).sample()
elif isinstance(support, dist.constraints._IntegerGreaterThan):
elif is_constraint_eq(support, dist.constraints.nonnegative_integer):
return (
torch.ones(sample_val.shape, dtype=sample_val.dtype)
+ support.lower_bound
Expand Down