Skip to content

Commit

Permalink
Support CategoricalGraph->Dirichlet messaging
Browse files Browse the repository at this point in the history
  • Loading branch information
jluttine committed Apr 5, 2018
1 parent 85960fb commit b4e1200
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 0 deletions.
14 changes: 14 additions & 0 deletions bayespy/inference/vmp/nodes/categorical_graph.py
Expand Up @@ -153,6 +153,10 @@ def __init__(self, dag, plates={}, marginals={}):
for (name, config) in dag.items()
]

# Inform parents about this new child node
for cpt in cpts:
cpt.table._add_child(self, cpt.variable)

# Validate plates (children must have those plates that the parents have)

# Validate shapes of the CPTs
Expand Down Expand Up @@ -213,6 +217,10 @@ def get_potential_function(node):
cpt.variable: cpt.plates
for cpt in cpts
}
self._parent_shapes = {
cpt.variable: cpt.table.plates + cpt.table.dims[0]
for cpt in cpts
}

# Sizes of all axes (variables and plates), that is, just combine the
# two size dicts
Expand All @@ -237,6 +245,12 @@ def get_potential_function(node):
return


def _message_to_parent(self, variable, u_parent):
shape = self._parent_shapes[variable]
m0 = misc.sum_to_shape(self.u[variable], shape)
return [m0]


def lower_bound_contribution(self):
raise NotImplementedError()

Expand Down
177 changes: 177 additions & 0 deletions bayespy/inference/vmp/nodes/tests/test_categorical_graph.py
Expand Up @@ -760,4 +760,181 @@ def _check(X, y):


def test_message_to_parent(self):

def _run(parents, dag, messages, observations, **kwargs):

# Construct the DAG
dag = dag(parents)

def to_cpt(X):
return np.exp(
Dirichlet._ensure_moments(
X,
DirichletMoments
).get_moments()[0]
)



def _check(X, y):
X.update()
cpts = {
name: to_cpt(config["table"])
for (name, config) in dag.items()
}
for (name, ind) in y.items():
cpts[name] = cpts[name] * onehot(
ind,
cpts[name].shape[-1],
extradims=len(dag[name].get("given", []))
)
msgs = messages(cpts)
assert len(msgs) == len(parents)
for (parent, msg) in zip(parents, msgs):
self.assertMessage(parent._message_from_children(), [msg])
return


X = CategoricalGraph(dag, **kwargs)
_check(X, {})
for y in observations:
X.observe(y)
_check(X, y)

return


# Simple case
_run(
parents=[
Dirichlet(np.random.rand(2)),
],
dag=lambda parents: {
"x": {
"table": parents[0],
},
},
messages=lambda cpts: [
normalize(cpts["x"])
],
observations=[
{"x": 1},
]
)

# Child has plates
_run(
parents=[
Dirichlet(np.random.rand(2)),
],
dag=lambda parents: {
"x": {
"table": parents[0],
"plates": ["trials"],
},
},
messages=lambda cpts: [
np.einsum("ax->x", normalize(np.broadcast_to(cpts["x"], (10, 2)), axis=-1)),
],
observations=[
{"x": np.ones(10, dtype=np.int)},
],
plates={
"trials": 10,
},
)

# Both have plates
_run(
parents=[
Dirichlet(np.random.rand(10, 2), plates=(10,)),
],
dag=lambda parents: {
"x": {
"table": parents[0],
"plates": ["trials"],
},
},
messages=lambda cpts: [
normalize(np.broadcast_to(cpts["x"], (10, 2)), axis=-1),
],
observations=[
{"x": np.ones(10, dtype=np.int)},
],
plates={
"trials": 10,
},
)

# Both have plates but parent is currently broadcasting them
_run(
parents=[
Dirichlet(np.random.rand(2), plates=(10,)),
],
dag=lambda parents: {
"x": {
"table": parents[0],
"plates": ["trials"],
},
},
messages=lambda cpts: [
normalize(np.broadcast_to(cpts["x"], (10, 2)), axis=-1),
],
observations=[
{"x": np.ones(10, dtype=np.int)},
],
plates={
"trials": 10,
},
)

# Same parent in multiple CPTs
_run(
parents=[
Dirichlet(np.random.rand(2)),
],
dag=lambda parents: {
"x": {
"table": parents[0],
},
"y": {
"table": parents[0],
},
},
messages=lambda cpts: [
normalize(cpts["x"]) + normalize(cpts["y"])
],
observations=[
{"x": 1},
{"y": 1},
{"x": 1, "y": 1},
]
)

# Multiple parents
_run(
parents=[
Dirichlet(np.random.rand(3)),
Dirichlet(np.random.rand(3, 4)),
],
dag=lambda parents: {
"x": {
"table": parents[0],
},
"y": {
"table": parents[1],
"given": ["x"],
},
},
messages=lambda cpts: [
sumproduct("x,xy->x", cpts["x"], cpts["y"]),
sumproduct("x,xy->xy", cpts["x"], cpts["y"]),
],
observations=[
{"x": 0},
{"y": 0},
{"x": 0, "y": 0},
]
)

pass

0 comments on commit b4e1200

Please sign in to comment.