Skip to content

Commit

Permalink
Added sanity check: distribution should have expected shape
Browse files Browse the repository at this point in the history
  • Loading branch information
Elie Wolfe committed May 31, 2023
1 parent 47f4c33 commit d3f4359
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions inflation/lp/InflationLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self,
self.network_scenario = inflationproblem.is_network
self.inflation_levels = inflationproblem.inflation_level_per_source
self.setting_cardinalities = inflationproblem.settings_per_party
self.private_setting_cardinalities = inflationproblem.private_settings_per_party
self.rectify_fake_setting = inflationproblem.rectify_fake_setting
self.factorize_monomial = inflationproblem.factorize_monomial
self._is_knowable_q_non_networks = \
Expand All @@ -90,6 +91,10 @@ def __init__(self,

# The following depends on the form of CG notation
self.outcome_cardinalities = inflationproblem.outcomes_per_party + 1
self.expected_distro_shape = tuple(np.hstack(
(inflationproblem.outcomes_per_party,
self.private_setting_cardinalities)).tolist())

self._lexorder = inflationproblem._lexorder
self._nr_operators = inflationproblem._nr_operators
self.lexorder_symmetries = inflationproblem.inf_symmetries
Expand Down Expand Up @@ -318,6 +323,9 @@ def set_distribution(self,
``True``), only atomic monomials are assigned numerical values.
"""
if prob_array is not None:
assert prob_array.shape == self.expected_distro_shape, f"Cardinalities mismatch: \n" \
f"expected {self.expected_distro_shape}, \n " \
f"got {prob_array.shape}"
knowable_values = {atom: atom.compute_marginal(prob_array)
for atom in self.knowable_atoms}
else:
Expand Down

0 comments on commit d3f4359

Please sign in to comment.