diff --git a/bayespy/inference/vmp/nodes/categorical_graph.py b/bayespy/inference/vmp/nodes/categorical_graph.py
index 39b448860..7760e4742 100644
--- a/bayespy/inference/vmp/nodes/categorical_graph.py
+++ b/bayespy/inference/vmp/nodes/categorical_graph.py
@@ -153,6 +153,10 @@ def __init__(self, dag, plates={}, marginals={}):
             for (name, config) in dag.items()
         ]
 
+        # Inform parents about this new child node
+        for cpt in cpts:
+            cpt.table._add_child(self, cpt.variable)
+
         # Validate plates (children must have those plates that the parents have)
 
         # Validate shapes of the CPTs
@@ -213,6 +217,10 @@ def get_potential_function(node):
             cpt.variable: cpt.plates
             for cpt in cpts
         }
+        self._parent_shapes = {
+            cpt.variable: cpt.table.plates + cpt.table.dims[0]
+            for cpt in cpts
+        }
 
         # Sizes of all axes (variables and plates), that is, just combine the
         # two size dicts
@@ -237,6 +245,22 @@ def get_potential_function(node):
         return
 
 
+    def _message_to_parent(self, variable, u_parent):
+        """
+        Return the message sent to the parent behind the CPT of `variable`.
+
+        ``self.u[variable]`` is reduced with ``misc.sum_to_shape`` to the
+        shape stored for that parent in ``self._parent_shapes`` (the table
+        node's plates followed by its first ``dims`` entry).
+
+        `u_parent` is not used by this implementation. The result is a
+        single-element list.
+        """
+        shape = self._parent_shapes[variable]
+        m0 = misc.sum_to_shape(self.u[variable], shape)
+        return [m0]
+
+
     def lower_bound_contribution(self):
         raise NotImplementedError()
 
diff --git a/bayespy/inference/vmp/nodes/tests/test_categorical_graph.py b/bayespy/inference/vmp/nodes/tests/test_categorical_graph.py
index 6ce964e14..e898ddc9b 100644
--- a/bayespy/inference/vmp/nodes/tests/test_categorical_graph.py
+++ b/bayespy/inference/vmp/nodes/tests/test_categorical_graph.py
@@ -760,4 +760,195 @@ def _check(X, y):
 
 
     def test_message_to_parent(self):
+        """
+        Check messages from CategoricalGraph to its Dirichlet parents.
+        """
+
+        def _run(parents, dag, messages, observations, **kwargs):
+            """
+            Build a graph from `dag(parents)` and verify each parent's message.
+
+            `messages` maps the dict of CPT arrays (masked by the current
+            observations) to the expected messages, one per parent. The
+            check runs once before any observation and again after each
+            entry of `observations` has been observed. Extra keyword
+            arguments are forwarded to ``CategoricalGraph``.
+            """
+
+            # Construct the DAG
+            dag = dag(parents)
+
+            def to_cpt(X):
+                """Exponentiate the first Dirichlet moment of `X`."""
+                return np.exp(
+                    Dirichlet._ensure_moments(
+                        X,
+                        DirichletMoments
+                    ).get_moments()[0]
+                )
+
+
+
+            def _check(X, y):
+                """Update `X`, then compare parent messages given `y`."""
+                X.update()
+                cpts = {
+                    name: to_cpt(config["table"])
+                    for (name, config) in dag.items()
+                }
+                for (name, ind) in y.items():
+                    cpts[name] = cpts[name] * onehot(
+                        ind,
+                        cpts[name].shape[-1],
+                        extradims=len(dag[name].get("given", []))
+                    )
+                msgs = messages(cpts)
+                assert len(msgs) == len(parents)
+                for (parent, msg) in zip(parents, msgs):
+                    self.assertMessage(parent._message_from_children(), [msg])
+                return
+
+
+            X = CategoricalGraph(dag, **kwargs)
+            _check(X, {})
+            for y in observations:
+                X.observe(y)
+                _check(X, y)
+
+            return
+
+
+        # Simple case
+        _run(
+            parents=[
+                Dirichlet(np.random.rand(2)),
+            ],
+            dag=lambda parents: {
+                "x": {
+                    "table": parents[0],
+                },
+            },
+            messages=lambda cpts: [
+                normalize(cpts["x"])
+            ],
+            observations=[
+                {"x": 1},
+            ]
+        )
+
+        # Child has plates
+        _run(
+            parents=[
+                Dirichlet(np.random.rand(2)),
+            ],
+            dag=lambda parents: {
+                "x": {
+                    "table": parents[0],
+                    "plates": ["trials"],
+                },
+            },
+            messages=lambda cpts: [
+                np.einsum("ax->x", normalize(np.broadcast_to(cpts["x"], (10, 2)), axis=-1)),
+            ],
+            observations=[
+                {"x": np.ones(10, dtype=int)},
+            ],
+            plates={
+                "trials": 10,
+            },
+        )
+
+        # Both have plates
+        _run(
+            parents=[
+                Dirichlet(np.random.rand(10, 2), plates=(10,)),
+            ],
+            dag=lambda parents: {
+                "x": {
+                    "table": parents[0],
+                    "plates": ["trials"],
+                },
+            },
+            messages=lambda cpts: [
+                normalize(np.broadcast_to(cpts["x"], (10, 2)), axis=-1),
+            ],
+            observations=[
+                {"x": np.ones(10, dtype=int)},
+            ],
+            plates={
+                "trials": 10,
+            },
+        )
+
+        # Both have plates but parent is currently broadcasting them
+        _run(
+            parents=[
+                Dirichlet(np.random.rand(2), plates=(10,)),
+            ],
+            dag=lambda parents: {
+                "x": {
+                    "table": parents[0],
+                    "plates": ["trials"],
+                },
+            },
+            messages=lambda cpts: [
+                normalize(np.broadcast_to(cpts["x"], (10, 2)), axis=-1),
+            ],
+            observations=[
+                {"x": np.ones(10, dtype=int)},
+            ],
+            plates={
+                "trials": 10,
+            },
+        )
+
+        # Same parent in multiple CPTs
+        _run(
+            parents=[
+                Dirichlet(np.random.rand(2)),
+            ],
+            dag=lambda parents: {
+                "x": {
+                    "table": parents[0],
+                },
+                "y": {
+                    "table": parents[0],
+                },
+            },
+            messages=lambda cpts: [
+                normalize(cpts["x"]) + normalize(cpts["y"])
+            ],
+            observations=[
+                {"x": 1},
+                {"y": 1},
+                {"x": 1, "y": 1},
+            ]
+        )
+
+        # Multiple parents
+        _run(
+            parents=[
+                Dirichlet(np.random.rand(3)),
+                Dirichlet(np.random.rand(3, 4)),
+            ],
+            dag=lambda parents: {
+                "x": {
+                    "table": parents[0],
+                },
+                "y": {
+                    "table": parents[1],
+                    "given": ["x"],
+                },
+            },
+            messages=lambda cpts: [
+                sumproduct("x,xy->x", cpts["x"], cpts["y"]),
+                sumproduct("x,xy->xy", cpts["x"], cpts["y"]),
+            ],
+            observations=[
+                {"x": 0},
+                {"y": 0},
+                {"x": 0, "y": 0},
+            ]
+        )
+
         pass