Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
suppress errors in beanmachine/beanmachine/ppl
Browse files Browse the repository at this point in the history
Differential Revision: D33201580

fbshipit-source-id: 139d7ba56f8488da5192f44ce4e3f49262e4cb61
  • Loading branch information
Pyre Bot Jr authored and facebook-github-bot committed Dec 18, 2021
1 parent 9a9e859 commit f5cadae
Show file tree
Hide file tree
Showing 14 changed files with 13 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def do_adaptation(
loss = -(proposal_distribution.log_prob(node_var.value))
optimizer.zero_grad()
loss.backward()
# pyre-fixme[20]: Argument `closure` expected.
optimizer.step()


Expand Down Expand Up @@ -419,7 +420,6 @@ def _proposer_func(
raise Exception("No observation embedding network found!")

obs_vec = torch.stack(obs_nodes, dim=0).flatten()
# pyre-fixme
obs_embedding = obs_embedding_net.forward(obs_vec)

node_embedding_nets = self._node_embedding_nets
Expand All @@ -429,7 +429,6 @@ def _proposer_func(
mb_embedding = torch.zeros(self._MB_EMBEDDING_DIM)
mb_nodes = list(
map(
# pyre-fixme[29]: `Union[Tensor, nn.Module]` is not a function.
lambda mb_node: node_embedding_nets(mb_node).forward(
utils.ensure_1d(
world.get_node_in_world_raise_error(mb_node).value
Expand All @@ -449,15 +448,11 @@ def _proposer_func(
mb_vec = torch.stack(mb_nodes, dim=0).unsqueeze(1)
# TODO: try pooling rather than just slicing out last hidden
mb_embedding = utils.ensure_1d(
# pyre-fixme[29]: `Union[Tensor, nn.Module]` is not a function.
mb_embedding_nets(node)
.forward(mb_vec)[0][-1, :, :]
.squeeze()
mb_embedding_nets(node).forward(mb_vec)[0][-1, :, :].squeeze()
)
node_proposal_param_nets = self._node_proposal_param_nets
if node_proposal_param_nets is None:
raise Exception("No node proposal parameter networks found!")
# pyre-fixme[29]: `Union[Tensor, nn.Module]` is not a function.
param_vec = node_proposal_param_nets(node).forward(
torch.cat((mb_embedding, obs_embedding))
)
Expand All @@ -475,9 +470,7 @@ def _proposal_distribution_for_node(
"""
node_var = self.world_.get_node_in_world_raise_error(node)
distribution = node_var.distribution
# pyre-fixme
sample_val = distribution.sample()
# pyre-fixme
support = distribution.support

ndim = sample_val.dim()
Expand Down
4 changes: 4 additions & 0 deletions src/beanmachine/ppl/experimental/neutra/iaflayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def __init__(
self.loga_max_clip_ = loga_max_clip
self.stable_ = stable

# pyre-fixme[14]: `get_parameter` overrides method defined in `Module`
# inconsistently.
# pyre-fixme[15]: `get_parameter` overrides method defined in `Module`
# inconsistently.
def get_parameter(self, x: Tensor) -> Tuple[Tensor, Tensor]:

"""
Expand Down
1 change: 1 addition & 0 deletions src/beanmachine/ppl/experimental/neutra/maskedlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def set_mask(self, mask: Tensor) -> None:
raise ValueError("Dimension mismatches between mask and layer.")
self.mask.data.copy_(mask.t())

# pyre-fixme[14]: `forward` overrides method defined in `Linear` inconsistently.
def forward(self, input_: Tensor) -> Tensor:
"""
the forward method that does the masked linear computation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
if not base_args:
base_args = {}

event_shape = target_dist.event_shape # pyre-ignore[16]
event_shape = target_dist.event_shape
# form independent product distribution of `base_dist` for `event_shape`
if len(event_shape) == 0:
self.base_args = base_args
Expand Down Expand Up @@ -105,14 +105,14 @@ def _base_dist(**kwargs):
)

# unwrap nested independents before setting transform
support = target_dist.support # pyre-ignore[16]
support = target_dist.support
while isinstance(support, constraints.independent):
support = support.base_constraint
self._transform = biject_to(support)

super().__init__(
self.new_dist.batch_shape, # pyre-ignore
self.new_dist.event_shape, # pyre-ignore
self.new_dist.batch_shape,
self.new_dist.event_shape,
validate_args=validate_args,
)

Expand Down
1 change: 0 additions & 1 deletion src/beanmachine/ppl/inference/compositional_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def get_proposers(
proposers = []
for node in target_rvs:
if node not in self._proposers:
# pyre-ignore[16]
support = world.get_variable(node).distribution.support
if any(
is_constraint_eq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def propose(self, world: World):
"""
proposal_dist = forward_dist = self.get_proposal_distribution(world)
old_value = world[self.node]
# pyre-ignore[20]
proposed_value = proposal_dist.sample()
new_world = world.replace({self.node: proposed_value})
backward_dist = self.get_proposal_distribution(new_world)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def do_adaptation(self, world, accept_log_prob, *args, **kwargs) -> None:
def get_proposal_distribution(self, world: World) -> dist.Distribution:
"""Propose a new value for self.node using the prior distribution."""
node = world.get_variable(self.node)
node_support = node.distribution.support # pyre-ignore [16]
node_support = node.distribution.support

if is_constraint_eq(node_support, dist.constraints.real):
return dist.Normal(node.value, self.step_size)
Expand Down
2 changes: 1 addition & 1 deletion src/beanmachine/ppl/inference/single_site_nmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _init_nmc_proposer(self, node: RVIdentifier, world: World) -> BaseProposer:
of NMC proposer will be chosen based on a node's support.
"""
distribution = world.get_variable(node).distribution
support = distribution.support # pyre-ignore
support = distribution.support
if is_constraint_eq(support, dist.constraints.real):
return SingleSiteRealSpaceNMCProposer(node, self.alpha, self.beta)
elif is_constraint_eq(support, dist.constraints.greater_than):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def get_proposal_distribution(
that was used or needs to be used to find the proposal distribution
"""
if node not in self.proposers_:
# pyre-fixme
node_distribution_support = node_var.distribution.support
if world.get_transforms_for_node(
node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def get_proposal_distribution(
if world.get_transforms_for_node(
node
).transform_type != TransformType.NONE or is_constraint_eq(
# pyre-fixme
node_distribution.support,
dist.constraints.real,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def get_proposal_distribution(
node_distribution = node_var.distribution
if (
is_constraint_eq(
# pyre-fixme
node_distribution.support,
dist.constraints.boolean,
)
Expand Down
1 change: 0 additions & 1 deletion src/beanmachine/ppl/legacy/world/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def bar(self):

@property
def is_discrete(self) -> bool:
# pyre-fixme
return self.distribution.support.is_discrete

def __post_init__(self) -> None:
Expand Down
3 changes: 0 additions & 3 deletions src/beanmachine/ppl/world/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def get_default_transforms(distribution: Distribution) -> dist.Transform:
:returns: a Transform that need to be applied to the distribution
to transform it from constrained space into unconstrained space
"""
# pyre-fixme
if distribution.support.is_discrete:
return dist.transforms.identity_transform
else:
Expand All @@ -115,11 +114,9 @@ def initialize_value(distribution: Distribution, initialize_from_prior: bool = F
:param initialize_from_prior: if true, returns sample from prior
:returns: the value to the set the Variable value to
"""
# pyre-fixme
sample_val = distribution.sample()
if initialize_from_prior:
return sample_val
# pyre-fixme
support = distribution.support
if isinstance(support, dist.constraints.independent):
support = support.base_constraint
Expand Down
2 changes: 0 additions & 2 deletions src/beanmachine/ppl/world/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,8 @@ def enumerate_node(self, node: RVIdentifier) -> torch.Tensor:
A tensor enumerating the support of the node.
"""
distribution = self._variables[node].distribution
# pyre-ignore[16]
if not distribution.has_enumerate_support:
raise ValueError(str(node) + " is not enumerable")
# pyre-ignore[16]
return distribution.enumerate_support()

def _run_node(
Expand Down

0 comments on commit f5cadae

Please sign in to comment.