In [3]:
#| default_exp dbn

In [4]:
#| include: false
from fastcore.all import *

In [5]:
#| export
import gtsam

Create or amend a `DotWriter` for use in `show`:

In [10]:
#| export
from gtsam import GraphvizFormatting
Axis = GraphvizFormatting.Axis

def dbn_writer(writer=None, hints: dict = None, positions: dict = None,
               boxes: set = None, factor_positions: dict = None,
               binary_edges=False, **kwargs):
    """ Create a DotWriter depending on input arguments:
        If writer is supplied, we will add but not overwrite hints or positions.
    """
    if writer is None and hints is None and positions is None and boxes is None and factor_positions is None and binary_edges==False:
        return None
    writer = GraphvizFormatting() if writer is None else writer

    writer.paperHorizontalAxis = Axis.X
    writer.paperVerticalAxis = Axis.Y

    # Copy hints without overwriting
    if hints is not None:
        assert isinstance(hints, dict)
        ph: dict = writer.positionHints
        for key, y in hints.items():
            if key not in ph:
                ph[key] = y
        writer.positionHints = ph
    # Copy positions without overwriting
    if positions is not None:
        assert isinstance(positions, dict)
        kp: dict = writer.variablePositions
        for key, position in positions.items():
            if key not in kp:
                kp[key] = position
        writer.variablePositions = kp
    # Add boxes
    if boxes is not None:
        assert isinstance(boxes, set)
        bx: set = writer.boxes
        for key in boxes:
            bx.add(key)
        writer.boxes = bx
    # Copy factor positions without overwriting
    if factor_positions is not None:
        assert isinstance(factor_positions, dict)
        kp: dict = writer.factorPositions
        for i, position in factor_positions.items():
            if i not in kp:
                kp[i] = position
        writer.factorPositions = kp
    writer.binaryEdges = binary_edges
    return writer


def has_positions(writer):
    """Report whether `writer` holds any position information.

    Intended for choosing a layout engine. Returns True when at least one of
    the writer's position hints, variable positions or factor positions is
    non-empty, and False for a missing (None) writer.
    """
    if writer is not None:
        layout_fields = ("positionHints", "variablePositions", "factorPositions")
        # Generator keeps the original left-to-right, short-circuit evaluation.
        return any(len(getattr(writer, field)) > 0 for field in layout_fields)
    return False

In [11]:
# Check None cases
assert dbn_writer() is None
assert dbn_writer(exact=True) is None  # unknown kwargs are ignored
assert dbn_writer(binary_edges=False) is None
test_eq(has_positions(None), False)

# Check passthrough
writer = GraphvizFormatting()
test_eq(dbn_writer(writer), writer)
test_eq(has_positions(writer), False)

# Check binary_edges, and that the paper axes are set on the produced writer
writer = dbn_writer(binary_edges=True)
test_eq(writer.binaryEdges, True)
test_eq(writer.paperHorizontalAxis, Axis.X)
test_eq(writer.paperVerticalAxis, Axis.Y)
test_eq(has_positions(writer), False)

# Check boxes, and that they don't stomp
writer = dbn_writer(boxes={1, 2})
test_eq(writer.boxes, {1, 2})
writer = dbn_writer(writer, boxes={2, 3})
test_eq(writer.boxes, {1, 2, 3})
test_eq(has_positions(writer), False)

# Check hints, and that they don't stomp
writer = dbn_writer(hints={"A": 2})
test_eq(writer.positionHints, {"A": 2})
writer = dbn_writer(writer, hints={"A": 3})
test_eq(writer.positionHints, {"A": 2})
test_eq(has_positions(writer), True)

# Check positions, and that they don't stomp
key = 123
writer = dbn_writer(positions={key: (2, 0)})
test_eq(len(writer.variablePositions), 1)
writer = dbn_writer(writer, positions={key: (3, 0)})
test_eq(len(writer.variablePositions), 1)
test_eq(writer.variablePositions[key], (2, 0))
test_eq(has_positions(writer), True)

# Check factor positions, and that they don't stomp
i = 0
writer = dbn_writer(factor_positions={i: (2, 0)})
test_eq(len(writer.factorPositions), 1)
writer = dbn_writer(writer, factor_positions={i: (3, 0)})
test_eq(len(writer.factorPositions), 1)
test_eq(writer.factorPositions[i], (2, 0))
test_eq(has_positions(writer), True)