In [None]:
# default_exp dbn


In [None]:
# hide
from fastcore.all import *

# Support code for Dynamic Bayes Nets

Some specialized code for Dynamic Bayes Nets.

In [None]:
# export
import gtsam

To show DBNs, we employ a small trick: we check whether the key is a `gtsam.Symbol` with character A/X/Z (or u/x/z for continuous variables; B is also accepted, for the battery example) and then use its index to specify a position in the "DBN grid":

In [None]:
# export
def dbs_position(symbol):
    """Calculate the grid position of a DBN Symbol.

    The symbol's index becomes the first coordinate and its character selects
    the second one: 2 for 'A'/'u', 1 for 'X'/'x', 0 for 'Z'/'z' and for 'B'.
    Returns None when the character is not one of those.
    """
    # Character code -> second grid coordinate.
    levels = {
        ord('A'): 2, ord('u'): 2,
        ord('X'): 1, ord('x'): 1,
        ord('Z'): 0, ord('z'): 0,
        ord('B'): 0,  # for battery example
    }
    code = symbol.chr()
    index = symbol.index()
    level = levels.get(code)
    if level is None:
        # Not a recognised DBN character: no position.
        return None
    return (index, level)

def dbk_position(key):
    """Calculate the grid position of a DBN key.

    The key is wrapped in a `gtsam.Symbol` and handed to `dbs_position`, so the
    result is either a position tuple or None.
    """
    symbol = gtsam.Symbol(key)
    return dbs_position(symbol)

In [None]:
# Symbols built explicitly: A -> (index,2), X -> (index,1), B and Z -> (index,0).
test_eq(dbs_position(gtsam.Symbol('A', 9)), (9,2))
test_eq(dbs_position(gtsam.Symbol('X', 7)), (7,1))
test_eq(dbs_position(gtsam.Symbol('B', 2)), (2,0))
test_eq(dbs_position(gtsam.Symbol('Z', 2)), (2,0))
# Keys produced by the symbol shorthands are expected to land on the same grid
# positions when looked up through dbk_position.
U = gtsam.symbol_shorthand.U
X = gtsam.symbol_shorthand.X
Z = gtsam.symbol_shorthand.Z
test_eq(dbk_position(U(9)), (9,2))
test_eq(dbk_position(X(7)), (7,1))
test_eq(dbk_position(Z(2)), (2,0))

Code below uses this for (any) conditional:

In [None]:
# export
def dbc_position(conditional):
    """Return a (key, position) pair for a conditional with a DBN key.

    The key is the conditional's first frontal key; the position is whatever
    `dbk_position` yields for it (possibly None).
    """
    frontal = conditional.firstFrontalKey()
    position = dbk_position(frontal)
    return frontal, position

In [None]:
# A distribution over the single key U(9) should report that key together with
# its grid position (9,2).
key = U(9)
conditional = gtsam.DiscreteDistribution((key,2), "1/1")
expected  = key, (9,2)
test_eq(dbc_position(conditional), expected)

And (any) BayesNet derived class:

In [None]:
# export
def dbn_positions(bayesNet):
    """Calculate positions for a DBN as a {key: position} dictionary.

    Every conditional in the net is visited in order; those whose key has no
    DBN position (position is None) are left out of the result.
    """
    placed = {}
    for j in range(bayesNet.size()):
        key, position = dbc_position(bayesNet.at(j))
        if position is not None:
            placed[key] = position
    return placed

In [None]:
# Two-node net: U(1) added with an empty parent list, then X(1) with U(1) as parent.
dbn = gtsam.DiscreteBayesNet()
u1 = U(1),2
x4 = X(1),2  # NOTE(review): named x4 although the key is X(1)
dbn.add(u1, [], "1/1")
dbn.add(x4, [u1], "1/1 1/2")
# Both keys are positioned: U(1) at (1, 2) and X(1) at (1, 1).
test_eq(dbn_positions(dbn), {U(1): (1, 2), X(1): (1, 1)})

If a Bayes net does *not* contain these keys, we get an empty dictionary:

In [None]:
# The only key here is the plain integer 123, which has no DBN position, so the
# resulting dictionary is empty.
non_dbn = gtsam.DiscreteBayesNet()
discreteKey = 123,2
non_dbn.add(discreteKey, [], "1/1")
test_eq(dbn_positions(non_dbn), {})

Finally, either create or amend a `DotWriter` to be used in `show`:

In [None]:
# export
def dbn_writer(obj, **kwargs):
    """Create (or amend) a DotWriter with variable positions for showing DBNs.

    If `obj` is not a `gtsam.DiscreteBayesNet`, or none of its keys has a DBN
    position, the `writer` keyword argument is returned untouched (None when it
    was not supplied). Otherwise the supplied writer, or a new
    `gtsam.DotWriter` when none was given, gets the DBN positions added to its
    `variablePositions`; positions already present for a key are kept.
    """
    supplied = kwargs.get("writer")

    # Guard clauses: nothing to do unless this is a DBN with positioned keys.
    if not isinstance(obj, gtsam.DiscreteBayesNet):
        return supplied
    grid = dbn_positions(obj)
    if not grid:
        return supplied

    writer = supplied if supplied is not None else gtsam.DotWriter()

    # Read, update and assign back; presumably the property hands out a copy
    # rather than a live reference -- TODO confirm against the gtsam wrapper.
    merged = writer.variablePositions
    for key, position in grid.items():
        if key in merged:
            continue  # caller-provided position wins
        merged[key] = position
    writer.variablePositions = merged
    return writer

In [None]:
# non DBN: no writer is created, and a supplied writer is handed back
test_is(dbn_writer(non_dbn), None)
writer = gtsam.DotWriter()
test_eq(dbn_writer(non_dbn, writer=writer), writer)

# DBN: a fresh writer and the supplied one both end up with two positions
test_eq(len(dbn_writer(dbn).variablePositions), 2)
test_eq(len(dbn_writer(dbn, writer=writer).variablePositions), 2)
