
Add constraint comparison utility to world.utils. (#542)
Summary:
Pull Request resolved: #542

This adds an `isinstance`-like check for `torch.distributions.constraints` objects/classes, so that these comparisons can be consolidated within a single utility function, `is_constraint_eq`.

Usage:
```
is_constraint_eq(dist.support, (constraints.real, constraints.greater_than))
```

instead of:

```
isinstance(dist.support, (constraints._Real, constraints._GreaterThan))
```
or the more opaque:

```
dist.support is constraints.real or isinstance(dist.support, constraints.greater_than)
```

See pytorch/pytorch#50616 for more details (note that the changes suggested in the issue are complementary).

 - This avoids use of the non-public constraint classes (such as `constraints._Real` and `constraints._Interval`).
 - It makes it possible to consolidate future changes (e.g. those arising from the introduction of an `Independent` constraint, pytorch/pytorch#50547) within a single function.
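
As a quick illustration (not part of this commit), the snippet below sketches how the new helper behaves, based on the `is_constraint_eq`/`_unwrap` implementation added to `src/beanmachine/ppl/world/utils.py` in this diff; the specific distributions and constraint values here are only hypothetical example inputs:

```
import torch.distributions as dist
from torch.distributions import constraints

from beanmachine.ppl.world.utils import is_constraint_eq

# An instance can be checked against a public class alias, with no need to
# reference the private _GreaterThan class directly.
assert is_constraint_eq(constraints.greater_than(0.0), constraints.greater_than)

# Singleton instances such as constraints.real also match themselves.
assert is_constraint_eq(dist.Normal(0.0, 1.0).support, constraints.real)

# A tuple of candidates behaves like the tuple form of isinstance.
assert is_constraint_eq(
    dist.Normal(0.0, 1.0).support,
    (constraints.real, constraints.greater_than),
)

# Non-matching constraints are rejected.
assert not is_constraint_eq(dist.Normal(0.0, 1.0).support, constraints.simplex)
```

Because the comparison unwraps each argument to its underlying constraint class, parameterized instances such as `constraints.greater_than(0.0)` and `constraints.greater_than(1.0)` are treated as equivalent, which matches the `isinstance`-style behavior of the old checks.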

This is a prerequisite for some other fixes that are currently blocking D25918330; I will add those small fixes when I merge D25918330.

Differential Revision: D25935106

fbshipit-source-id: aa2a63fbc5d550ba1ed8a9abd772513e11bc2437
Neeraj Pradhan authored and facebook-github-bot committed Jan 19, 2021
1 parent 1bedc9a commit bba916b
Showing 7 changed files with 100 additions and 56 deletions.
21 changes: 13 additions & 8 deletions src/beanmachine/ppl/experimental/inference_compilation/ic_infer.py
@@ -21,6 +21,7 @@
)
from ...model.rv_identifier import RVIdentifier
from ...world import ProposalDistribution, Variable, World
from ...world.utils import is_constraint_eq
from . import utils


@@ -476,10 +477,14 @@ def _proposal_distribution_for_node(
f"Encountered node={node} with dim={ndim}"
)

if (
isinstance(support, dist.constraints._Real)
or isinstance(support, dist.constraints._Simplex)
or isinstance(support, dist.constraints._GreaterThan)
if is_constraint_eq(
support,
(
dist.constraints.real,
dist.constraints.real_vector,
dist.constraints.simplex,
dist.constraints.greater_than,
),
):
k = self._GMM_NUM_COMPONENTS
if ndim == 0:
@@ -511,12 +516,12 @@ def _func(x):
return dist.MixtureSameFamily(mix, comp)

return (k + 2 * k * d, _func)
elif isinstance(support, dist.constraints._IntegerInterval) and isinstance(
distribution, dist.Categorical
):
elif is_constraint_eq(
support, dist.constraints.integer_interval
) and isinstance(distribution, dist.Categorical):
num_categories = distribution.param_shape[-1]
return (num_categories, lambda x: dist.Categorical(logits=x))
elif isinstance(support, dist.constraints._Boolean) and isinstance(
elif is_constraint_eq(support, dist.constraints.boolean) and isinstance(
distribution, dist.Bernoulli
):
return (1, lambda x: dist.Bernoulli(logits=x.item()))
20 changes: 12 additions & 8 deletions src/beanmachine/ppl/inference/compositional_infer.py
@@ -15,6 +15,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.model.utils import get_wrapper
from beanmachine.ppl.world.utils import is_constraint_eq


class CompositionalInference(AbstractMHInference):
@@ -82,17 +83,20 @@ def find_best_single_site_proposer(self, node: RVIdentifier):
# pyre-fixme
distribution = node_var.distribution
support = distribution.support
if (
isinstance(support, dist.constraints._Real)
or isinstance(support, dist.constraints._Simplex)
or isinstance(support, dist.constraints._GreaterThan)
if is_constraint_eq(
support,
(
dist.constraints.real,
dist.constraints.simplex,
dist.constraints.greater_than,
),
):
self.proposers_per_rv_[node] = SingleSiteNewtonianMonteCarloProposer()
elif isinstance(support, dist.constraints._IntegerInterval) and isinstance(
distribution, dist.Categorical
):
elif is_constraint_eq(
support, dist.constraints.integer_interval
) and isinstance(distribution, dist.Categorical):
self.proposers_per_rv_[node] = SingleSiteUniformProposer()
elif isinstance(support, dist.constraints._Boolean) and isinstance(
elif is_constraint_eq(support, dist.constraints.boolean) and isinstance(
distribution, dist.Bernoulli
):
self.proposers_per_rv_[node] = SingleSiteUniformProposer()
@@ -17,6 +17,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.world import ProposalDistribution, Variable, World
from beanmachine.ppl.world.utils import is_constraint_eq
from beanmachine.ppl.world.variable import TransformType
from torch import Tensor

@@ -84,18 +85,20 @@ def get_proposal_distribution(
node_distribution_support = node_var.distribution.support
if world.get_transforms_for_node(
node
).transform_type != TransformType.NONE or isinstance(
node_distribution_support, dist.constraints._Real
).transform_type != TransformType.NONE or is_constraint_eq(
node_distribution_support, dist.constraints.real
):
self.proposers_[node] = SingleSiteRealSpaceNewtonianMonteCarloProposer(
self.alpha_, self.beta_
)

elif isinstance(node_distribution_support, dist.constraints._GreaterThan):
elif is_constraint_eq(
node_distribution_support, dist.constraints.greater_than
):
self.proposers_[node] = SingleSiteHalfSpaceNewtonianMonteCarloProposer()

elif isinstance(
node_distribution_support, dist.constraints._Simplex
elif is_constraint_eq(
node_distribution_support, dist.constraints.simplex
) or isinstance(node_var.distribution, dist.Beta):
self.proposers_[node] = SingleSiteSimplexNewtonianMonteCarloProposer()
else:
@@ -8,6 +8,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.world import ProposalDistribution, Variable, World
from beanmachine.ppl.world.utils import is_constraint_eq
from beanmachine.ppl.world.variable import TransformType
from torch import Tensor, tensor

@@ -98,10 +99,10 @@ def get_proposal_distribution(
# for now, assume all transforms will transform distributions into the realspace
if world.get_transforms_for_node(
node
).transform_type != TransformType.NONE or isinstance(
).transform_type != TransformType.NONE or is_constraint_eq(
# pyre-fixme
node_distribution.support,
dist.constraints._Real,
dist.constraints.real,
):
return (
ProposalDistribution(
@@ -115,9 +116,10 @@
),
{},
)
elif isinstance(
node_distribution.support, dist.constraints._GreaterThan
) or isinstance(node_distribution.support, dist.constraints._GreaterThan):
elif is_constraint_eq(
node_distribution.support,
(dist.constraints.greater_than, dist.constraints.greater_than),
):
lower_bound = node_distribution.support.lower_bound
proposal_distribution = self.gamma_distbn_from_moments(
node_var.value - lower_bound, self.step_size ** 2
@@ -136,7 +138,7 @@
),
{},
)
elif isinstance(node_distribution.support, dist.constraints._Interval):
elif is_constraint_eq(node_distribution.support, dist.constraints.interval):
lower_bound = node_distribution.support.lower_bound
width = node_distribution.support.upper_bound - lower_bound
# Compute first and second moments of the perturbation distribution
@@ -158,7 +160,7 @@
),
{},
)
elif isinstance(node_distribution.support, dist.constraints._Simplex):
elif is_constraint_eq(node_distribution.support, dist.constraints.simplex):
proposal_distribution = self.dirichlet_distbn_from_moments(
node_var.value, self.step_size
)
@@ -8,6 +8,7 @@
)
from beanmachine.ppl.model.rv_identifier import RVIdentifier
from beanmachine.ppl.world import ProposalDistribution, Variable, World
from beanmachine.ppl.world.utils import is_constraint_eq


class SingleSiteUniformProposer(SingleSiteAncestralProposer):
@@ -40,10 +41,10 @@ def get_proposal_distribution(
"""
node_distribution = node_var.distribution
if (
isinstance(
is_constraint_eq(
# pyre-fixme
node_distribution.support,
dist.constraints._Boolean,
dist.constraints.boolean,
)
and isinstance(node_distribution, dist.Bernoulli)
):
@@ -58,8 +59,8 @@
),
{},
)
if isinstance(
node_distribution.support, dist.constraints._IntegerInterval
if is_constraint_eq(
node_distribution.support, dist.constraints.integer_interval
) and isinstance(node_distribution, dist.Categorical):
probs = torch.ones(node_distribution.param_shape)
# In Categorical distrbution, the samples are integers from 0-k
62 changes: 45 additions & 17 deletions src/beanmachine/ppl/world/utils.py
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates
from typing import List
from collections.abc import Iterable
from typing import Iterable as IterableType, List, Type, Union

import torch
import torch.distributions as dist
@@ -9,6 +10,9 @@
from torch.distributions.transforms import Transform


ConstraintType = Union[constraints.Constraint, Type]


class BetaDimensionTransform(Transform):
bijective = True

@@ -42,16 +46,40 @@ def is_discrete(distribution: Distribution) -> bool:
:returns: a boolean that is true if the distribution is discrete and false
otherwise
"""
# pyre-fixme
support = distribution.support
if (
isinstance(support, constraints._Boolean)
or isinstance(support, constraints._IntegerGreaterThan)
or isinstance(support, constraints._IntegerInterval)
or isinstance(support, constraints._IntegerLessThan)
):
return True
return False
return is_constraint_eq(
# pyre-fixme
distribution.support,
(
constraints.boolean,
constraints.integer_interval,
constraints._IntegerGreaterThan,
constraints._IntegerLessThan,
),
)


def _unwrap(constraint: ConstraintType):
return constraint if isinstance(constraint, type) else constraint.__class__


def is_constraint_eq(
constraint, check_constraints: Union[ConstraintType, IterableType[ConstraintType]]
) -> bool:
"""
This provides an `isinstance` like check that works for different constraints
specified in :mod:`torch.distributions.constraints`. Returns `True` if the
given `constraint` matches one of the constraints in `check_constraints`.
:param constraint: A constraint class or instance.
:param check_constraints: An iterable containing constraint classes or instances
to check against.
:returns: bool value indicating if the check is successful.
"""
if isinstance(check_constraints, Iterable):
union = {_unwrap(c) for c in check_constraints}
else:
union = {_unwrap(check_constraints)}
return _unwrap(constraint) in union


def get_default_transforms(distribution: Distribution) -> List:
@@ -69,10 +97,10 @@ def get_default_transforms(distribution: Distribution) -> List:
sample = distribution.sample()
if is_discrete(distribution):
return []
elif isinstance(support, constraints._Real):
elif is_constraint_eq(support, constraints.real):
return []

elif isinstance(support, constraints._Interval):
elif is_constraint_eq(support, constraints.interval):
lower_bound = support.lower_bound
if not isinstance(lower_bound, Tensor):
lower_bound = tensor(lower_bound, dtype=sample.dtype)
@@ -87,8 +115,8 @@

return [lower_bound_zero, upper_bound_one, beta_dimension, stick_breaking]

elif isinstance(support, constraints._GreaterThan) or isinstance(
support, constraints._GreaterThanEq
elif is_constraint_eq(support, constraints.greater_than) or isinstance(
support, constraints.greater_than_eq
):
lower_bound = support.lower_bound
if not isinstance(lower_bound, Tensor):
@@ -98,7 +126,7 @@

return [lower_bound_zero, log_transform]

elif isinstance(support, constraints._LessThan):
elif is_constraint_eq(support, constraints.less_than):
upper_bound = support.upper_bound
if not isinstance(upper_bound, Tensor):
upper_bound = tensor(upper_bound, dtype=sample.dtype)
@@ -109,7 +137,7 @@

return [upper_bound_zero, flip_to_greater, log_transform]

elif isinstance(support, constraints._Simplex):
elif is_constraint_eq(support, constraints.simplex):
return [dist.StickBreakingTransform().inv]

return []
15 changes: 8 additions & 7 deletions src/beanmachine/ppl/world/variable.py
@@ -11,6 +11,7 @@
from beanmachine.ppl.world.utils import (
BetaDimensionTransform,
get_default_transforms,
is_constraint_eq,
is_discrete,
)
from torch import Tensor
@@ -135,30 +136,30 @@ def initialize_value(
return obs
if initialize_from_prior:
return sample_val
elif isinstance(support, dist.constraints._Real):
elif is_constraint_eq(support, dist.constraints.real):
return torch.zeros(sample_val.shape, dtype=sample_val.dtype)
elif isinstance(support, dist.constraints._Simplex):
elif is_constraint_eq(support, dist.constraints.simplex):
value = torch.ones(sample_val.shape, dtype=sample_val.dtype)
return value / sample_val.shape[-1]
elif isinstance(support, dist.constraints._GreaterThan):
elif is_constraint_eq(support, dist.constraints.greater_than):
return (
torch.ones(sample_val.shape, dtype=sample_val.dtype)
+ support.lower_bound
)
elif isinstance(support, dist.constraints._Boolean):
elif is_constraint_eq(support, dist.constraints.boolean):
return dist.Bernoulli(torch.ones(sample_val.shape) / 2).sample()
elif isinstance(support, dist.constraints._Interval):
elif is_constraint_eq(support, dist.constraints.interval):
lower_bound = torch.ones(sample_val.shape) * support.lower_bound
upper_bound = torch.ones(sample_val.shape) * support.upper_bound
return dist.Uniform(lower_bound, upper_bound).sample()
elif isinstance(support, dist.constraints._IntegerInterval):
elif is_constraint_eq(support, dist.constraints.integer_interval):
integer_interval = support.upper_bound - support.lower_bound
return dist.Categorical(
(torch.ones(integer_interval)).expand(
sample_val.shape + (integer_interval,)
)
).sample()
elif isinstance(support, dist.constraints._IntegerGreaterThan):
elif is_constraint_eq(support, dist.constraints._IntegerGreaterThan):
return (
torch.ones(sample_val.shape, dtype=sample_val.dtype)
+ support.lower_bound
