Fix slow InferenceEngine construction due to early query (#176)
* Speed up query by caching baseline marginals

* Update type hint in inference.py

* Refactor query() and _single_query(); handle invalid observations type

* Fix type hint for observations argument in query() and _single_query()
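
The gist of the fix: InferenceEngine.__init__ previously ran a full self.query() just to populate default marginals, which made construction slow for large networks. This commit defers that work to the first query() call and caches the result. A minimal sketch of the lazy-caching pattern (illustrative only: LazyEngine and _expensive_query are hypothetical stand-ins, not CausalNex code):

from typing import Any, Dict, Optional

class LazyEngine:
    def __init__(self) -> None:
        # before the fix, the equivalent of self.query() ran here, slowing construction
        self._baseline: Optional[Dict[str, Any]] = None

    def _expensive_query(self) -> Dict[str, Any]:
        # stand-in for exact inference over the whole network
        return {"a": {0: 0.5, 1: 0.5}}

    def query(self) -> Dict[str, Any]:
        # compute the baseline at most once, on first use
        if self._baseline is None:
            self._baseline = self._expensive_query()
        return self._baseline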
oentaryorj committed Sep 6, 2021
1 parent 13aa747 commit b89c6d3
Showing 2 changed files with 81 additions and 44 deletions.
98 changes: 56 additions & 42 deletions causalnex/inference/inference.py
@@ -34,7 +34,8 @@
 import inspect
 import re
 import types
-from typing import Callable, Dict, Hashable, List, Tuple, Union
+from collections import defaultdict
+from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union
 
 import networkx as nx
 import pandas as pd
@@ -109,7 +110,7 @@ def __init__(self, bn: BayesianNetwork):
         if bad_nodes:
             raise ValueError(
                 "Variable names must match ^[0-9a-zA-Z_]+$ - please fix the "
-                "following nodes: {0}".format(bad_nodes)
+                f"following nodes: {bad_nodes}"
             )
 
         if not bn.cpds:
@@ -119,17 +120,16 @@ def __init__(self, bn: BayesianNetwork):
             )
 
         self._cpds = None
-        self._upstream_cpds = {}
+        self._detached_cpds = {}
+        self._baseline_marginals = None
 
         self._create_cpds_dict_bn(bn)
         self._generate_domains_bn(bn)
         self._generate_bbn()
 
-        # TODO: can we do it without a query() call? # pylint: disable=fixme
-        self._default_marginals = self.query()
-
     def _single_query(
-        self, observations: Dict[str, Hashable] = None
+        self,
+        observations: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Dict[Hashable, float]]:
         """
         Queries the ``BayesianNetwork`` for marginals given some observations.
@@ -146,49 +146,62 @@ def _single_query(
         bbn_results = (
             self._bbn.query(**observations) if observations else self._bbn.query()
         )
-        results = {node: {} for node in self._cpds}
+        results = defaultdict(dict)
 
         for (node, state), prob in bbn_results.items():
             results[node][state] = prob
 
-        # the upstream nodes are set to the default marginals based on the
-        # original cpds of the bn
-        for detached_node in self._upstream_cpds:
-            results[detached_node] = self._default_marginals[detached_node]
+        # the detached nodes are set to the baseline marginals based on original CPDs
+        for node in self._detached_cpds:
+            results[node] = self._baseline_marginals[node]
 
         return results
 
     def query(
         self,
-        observations: Union[Dict[str, Hashable], List[Dict[str, Hashable]]] = None,
+        observations: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
         parallel: bool = False,
-        num_cores: int = None,
+        num_cores: Optional[int] = None,
     ) -> Union[
-        Dict[str, Dict[Hashable, float]], List[Dict[str, Dict[Hashable, float]]]
+        Dict[str, Dict[Hashable, float]],
+        List[Dict[str, Dict[Hashable, float]]],
     ]:
         """
         Queries the ``BayesianNetwork`` for marginals given one or more observations.
 
         Args:
             observations: one or more observations of states of nodes in the Bayesian Network.
             parallel: if True, run the query using multiprocessing
-            num_cores: only applicable if paralle=True. The number of cores used during multiprocessing.
+            num_cores: only applicable if parallel=True. The number of cores used during multiprocessing.
                 If num_cores is not provided, number of processors will be autodetected and used
 
         Returns:
             A dictionary or a list of dictionaries of marginal probabilities of the network.
-        """
-        if isinstance(observations, dict) or observations is None:
-            return self._single_query(observations)
-        result = []
-
-        if parallel:
-            with multiprocessing.Pool(num_cores) as p:
-                result = p.map(self._single_query, observations)
+
+        Raises:
+            TypeError: if observations is neither None nor a dictionary nor a list
+        """
+        if observations is not None and not isinstance(observations, (dict, list)):
+            raise TypeError("Expecting observations to be a dict, list or None")
+
+        # initialise baseline marginals if not done previously
+        if self._baseline_marginals is None:
+            self._baseline_marginals = self._single_query(None)
+
+        if observations is None:
+            # perform single query if there was a do-intervention before, else return baseline marginals
+            if self._detached_cpds:
+                result = self._single_query(None)
+            else:
+                result = self._baseline_marginals
+        elif isinstance(observations, dict):
+            result = self._single_query(observations)
         else:
-            for obs in observations:
-                result.append(self._single_query(obs))
+            if parallel:
+                with multiprocessing.Pool(num_cores) as p:
+                    result = p.map(self._single_query, observations)
+            else:
+                result = [self._single_query(obs) for obs in observations]
 
         return result
 
@@ -203,7 +216,6 @@ def _do(self, observation: str, state: Dict[Hashable, float]):
         Raises:
             ValueError: if states do not match original states of the node, or probabilities do not sum to 1.
         """
-
         if sum(state.values()) != 1.0:
             raise ValueError("The cpd for the provided observation must sum to 1")
 
@@ -213,19 +225,18 @@ def _do(self, observation: str, state: Dict[Hashable, float]):
             )
 
         if not set(state.keys()) == set(self._cpds_original[observation]):
+            expected = set(self._cpds_original[observation])
+            found = set(state.keys())
             raise ValueError(
-                "The cpd states do not match expected states: expected {expected}, found {found}".format(
-                    expected=set(self._cpds_original[observation]),
-                    found=set(state.keys()),
-                )
+                f"The cpd states do not match expected states: expected {expected}, found {found}"
             )
 
         self._cpds[observation] = {s: {(): p} for s, p in state.items()}
 
     def do_intervention(
         self,
         node: str,
-        state: Union[Hashable, Dict[Hashable, float]] = None,
+        state: Optional[Union[Hashable, Dict[Hashable, float]]] = None,
     ):
         """
         Makes an intervention on the Bayesian Network.
@@ -269,21 +280,27 @@ def reset_do(self, observation: str):
         """
         self._cpds[observation] = self._cpds_original[observation]
 
-        for upstream_node, original_cpds in self._upstream_cpds.items():
+        for upstream_node, original_cpds in self._detached_cpds.items():
            self._cpds[upstream_node] = original_cpds
 
-        self._upstream_cpds = {}
+        self._detached_cpds = {}
         self._generate_bbn()
 
     def _generate_bbn(self):
         """Re-creates the _bbn."""
         self._node_functions = self._create_node_functions()
         self._bbn = build_bbn(
-            list(self._node_functions.values()), domains=self._domains
+            list(self._node_functions.values()),
+            domains=self._domains,
         )
 
     def _generate_domains_bn(self, bn: BayesianNetwork):
-        """Generates domains from Bayesian network"""
+        """
+        Generates domains from Bayesian network
+
+        Args:
+            bn: Bayesian network
+        """
         self._domains = {
             variable: list(cpd.index.values) for variable, cpd in bn.cpds.items()
         }
@@ -387,8 +404,7 @@ def _create_node_functions(self) -> Dict[str, Callable]:
             condition_nodes = [n for n, v in state_conditions]
 
             node_args = tuple([node] + condition_nodes)  # type: Tuple[str]
-            function_name = "f_{node}".format(node=node)
-            node_function = self._create_node_function(function_name, node_args)
+            node_function = self._create_node_function(f"f_{node}", node_args)
             node_functions[node] = node_function
 
         return node_functions
@@ -413,12 +429,10 @@ def _remove_disconnected_nodes(self, var: str):
         # construct graph from CPDs
         g = nx.DiGraph()
 
-        # add nodes as there could be isolates (e.g. A->B->C intervening on B
-        # makes A an isolate)
         for node, states in self._cpds.items():
             sample_state = next(iter(states.values()))
             parents = next(iter(sample_state.keys()))
-            g.add_node(node)
+            g.add_node(node)  # add nodes as there could be isolates
 
             for parent, _ in parents:
                 g.add_edge(parent, node)
@@ -427,5 +441,5 @@ def _remove_disconnected_nodes(self, var: str):
        for sub_graph in nx.weakly_connected_components(g):
            if var not in sub_graph:
                for node in sub_graph:
-                    self._upstream_cpds[node] = self._cpds[node]
+                    self._detached_cpds[node] = self._cpds[node]
                    self._cpds.pop(node)
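
Taken together, the inference.py changes move work out of the constructor without changing the public API. A hedged usage sketch (assumes bn is an already-fitted causalnex BayesianNetwork; the node name "a" and its states are illustrative):

from causalnex.inference import InferenceEngine

ie = InferenceEngine(bn)  # returns quickly now: no baseline query at construction
baseline = ie.query()  # first call computes and caches the baseline marginals
single = ie.query({"a": 0})  # dict in, dict of marginals out
batch = ie.query([{"a": 0}, {"a": 1}], parallel=True, num_cores=2)  # list in, list out
# ie.query("a=0")  # would raise TypeError: Expecting observations to be a dict, list or None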
27 changes: 25 additions & 2 deletions tests/test_inference.py
@@ -47,6 +47,7 @@ def test_create_inference_from_bn(self, train_model, train_data_idx):
     def test_create_inference_with_bad_variable_names_fails(
         self, train_model, train_data_idx
     ):
+        """Test creation of InferenceEngine with bad variable names"""
 
         model = StructureModel()
         model.add_edges_from(
@@ -64,6 +65,28 @@ def test_create_inference_with_bad_variable_names_fails(
         with pytest.raises(ValueError, match="Variable names must match.*"):
             InferenceEngine(bn)
 
+    def test_invalid_observations(self, train_model, train_data_idx):
+        """Test with invalid observations type"""
+
+        bn = BayesianNetwork(train_model)
+        bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx)
+        ie = InferenceEngine(bn)
+
+        with pytest.raises(
+            TypeError, match="Expecting observations to be a dict, list or None"
+        ):
+            ie.query("123")
+
+        with pytest.raises(
+            TypeError, match="Expecting observations to be a dict, list or None"
+        ):
+            ie.query({"123", "abc"})
+
+        with pytest.raises(
+            TypeError, match="Expecting observations to be a dict, list or None"
+        ):
+            ie.query(("123", "abc"))
+
     def test_empty_query_returns_marginals(
         self, train_model, train_data_idx, train_data_idx_marginals
     ):
@@ -437,7 +460,7 @@ def test_query_after_do_intervention_has_split_graph(self, chain_network):
 
         # assert the _cpds of the upstream nodes are stored correctly
         orig_cpds = ie._cpds_original  # pylint: disable=protected-access
-        upstream_cpds = ie._upstream_cpds  # pylint: disable=protected-access
+        upstream_cpds = ie._detached_cpds  # pylint: disable=protected-access
         assert orig_cpds["a"] == upstream_cpds["a"]
         assert orig_cpds["b"] == upstream_cpds["b"]
 
@@ -466,7 +489,7 @@ def test_query_after_do_intervention_has_split_graph(self, chain_network):
 
         # assert the _cpds of the upstream nodes are stored correctly
         orig_cpds = ie._cpds_original  # pylint: disable=protected-access
-        upstream_cpds = ie._upstream_cpds  # pylint: disable=protected-access
+        upstream_cpds = ie._detached_cpds  # pylint: disable=protected-access
         assert orig_cpds["a"] == upstream_cpds["a"]
 
         ie.reset_do(var_b)
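
The _detached_cpds rename that these tests exercise covers the same behaviour as the old _upstream_cpds: a do-intervention on a mid-chain node splits the graph, parks the CPDs of the detached upstream nodes, and reports their marginals from the cached baseline. A sketch of that flow (assumes a fitted chain network a -> b -> c wrapped in an InferenceEngine, as in the chain_network fixture; states 0/1 are illustrative):

ie.do_intervention("b", {0: 0.2, 1: 0.8})  # detaches upstream node "a" from the BBN
assert "a" in ie._detached_cpds  # the original CPD of "a" is parked here
marginals = ie.query()  # "a" now falls back to the cached baseline marginals
ie.reset_do("b")  # restores all CPDs and clears _detached_cpds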
