# Demonstration Informed Specification Search: Experiment

Let's take a look at how the DISS algorithm can search for specifications by leveraging expert demonstrations. 
We'll focus on learning DFAs here, but the approach is not confined to any specific concept class. To start, consider an agent operating in the following stochastic gridworld.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_1.svg"
         style="height: 20em;"
     />
</figure>

## Agent Actions

The agent can attempt to move up, down, left, or right as illustrated below.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_2.svg"
         style="height: 20em;"
     />
</figure>

## Stochastic Transitions

However, there is a small probability that the agent will slip downward due to wind!
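The following is a minimal sketch of such a slip model. It rests on two illustrative assumptions: a slip replaces the intended move with a one-tile move down, and the slip probability is `0.1`. The notebook's `GridWorldNaive` may model wind differently. Rows are numbered top to bottom (as `print_map` later prints them), so moving down increases `y`.

```python
# Hypothetical wind model: with probability WIND_SLIP_PROB the intended move
# is replaced by a move one tile down. Both the value and the rule are assumptions.
WIND_SLIP_PROB = 0.1
MOVES = {'↑': (0, -1), '↓': (0, 1), '←': (-1, 0), '→': (1, 0)}


def clamp(v, dim):
    """Keep a coordinate on the board (moving into a wall leaves it unchanged)."""
    return min(max(v, 1), dim)


def step_distribution(x, y, action, dim=3, slip_prob=WIND_SLIP_PROB):
    """Return {(x', y'): probability} for taking `action` from (x, y)."""
    dx, dy = MOVES[action]
    intended = (clamp(x + dx, dim), clamp(y + dy, dim))
    slipped = (x, clamp(y + 1, dim))  # wind pushes the agent one tile down
    dist = {}
    for nxt, p in ((intended, 1 - slip_prob), (slipped, slip_prob)):
        dist[nxt] = dist.get(nxt, 0.0) + p
    return dist


print(step_distribution(3, 1, '←'))
```

When both outcomes coincide (for example, choosing `↓`), their probabilities are merged, so the distribution stays normalized.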

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_3.svg"
         style="height: 20em;"
     />
</figure>



Let's assume the agent's task can be described in terms of the colors of the tiles. **What was the agent trying to do?**


## Probably avoiding the red tiles

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_4.svg"
         style="height: 20em;"
     />
</figure>

## Probably trying to reach the yellow tile

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_5.svg"
         style="height: 20em;"
     />
</figure>

In this notebook, we illustrate how to learn task representations in the form of Deterministic Finite Automata (DFAs). These representations describe temporal tasks and can be learned incrementally.

In particular, we consider a variation of the following gridworld from [this](https://mjvc.me/DISS/#/77) slide deck. Here, the agent's task is a composition of three subtasks.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_1.svg"
         style="height: 20em;"
     />
</figure>

Each subtask is a regular language represented as a DFA.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_2_1.svg"
         style="height: 20em;"
     />
</figure>
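The notebook does not spell out how the subtask DFAs are combined. If "composition" means conjunction (a word must be accepted by every subtask), then the combined task is the standard product automaton. Below is a minimal sketch over hand-written dictionary DFAs. The helper names and the two toy subtasks are illustrative assumptions and are not part of the `dfa` library.

```python
def make_dfa(start, delta, accepting):
    """Hypothetical dict-based DFA: delta maps (state, symbol) -> state."""
    return {'start': start, 'delta': delta, 'accepting': set(accepting)}


def accepts(dfa, word):
    state = dfa['start']
    for sym in word:
        state = dfa['delta'][state, sym]
    return state in dfa['accepting']


def conjoin(d1, d2, alphabet):
    """Product DFA accepting exactly the words accepted by both d1 and d2."""
    start = (d1['start'], d2['start'])
    delta, frontier, seen = {}, [start], {start}
    while frontier:  # explore only reachable state pairs
        s1, s2 = frontier.pop()
        for sym in alphabet:
            nxt = (d1['delta'][s1, sym], d2['delta'][s2, sym])
            delta[(s1, s2), sym] = nxt
            if nxt not in seen:
                seen.add(nxt)
                frontier.append(nxt)
    accepting = {(a, b) for a, b in seen
                 if a in d1['accepting'] and b in d2['accepting']}
    return make_dfa(start, delta, accepting)


COLORS = ('red', 'yellow', 'blue')
avoid_red = make_dfa(
    'ok',
    {(s, c): 'dead' if s == 'dead' or c == 'red' else 'ok'
     for s in ('ok', 'dead') for c in COLORS},
    {'ok'},
)
reach_yellow = make_dfa(
    'no',
    {(s, c): 'yes' if s == 'yes' or c == 'yellow' else 'no'
     for s in ('no', 'yes') for c in COLORS},
    {'yes'},
)
both = conjoin(avoid_red, reach_yellow, COLORS)
print(accepts(both, ('blue', 'yellow')), accepts(both, ('yellow', 'red')))  # True False
```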

Further, we assume that the first two subtasks are known a priori (say, because they were learned in another workspace),
<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_1_2.svg"
         style="height: 20em;"
     />
</figure>

and our goal is to learn the third subtask from a partial demonstration.
<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_3_2.svg"
         style="height: 20em;"
     />
</figure>


# Preamble

In [None]:
from functools import lru_cache

import attr
import funcy as fn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from bidict import bidict
from IPython.display import Image, display
import networkx as nx
import pydot

from collections import Counter

import dfa
from dfa.utils import find_subset_counterexample, find_equiv_counterexample
from dfa_identify import find_dfa, find_dfas

from diss.product_mc import ProductMC
from diss.dfa_concept import DFAConcept
from diss.domains.gridworld_naive import GridWorldNaive as World
from diss.domains.gridworld_naive import GridWorldState as State
from diss import search, LabeledExamples, GradientGuidedSampler, ConceptIdException
from pprint import pprint
from itertools import combinations
from tqdm import tqdm_notebook
from tqdm.notebook import trange
from IPython.display import clear_output
from IPython.display import HTML as html_print

sns.set_context('paper')
sns.set_style('whitegrid')

## Let's first visualize our gridworld and a demonstration within the gridworld.

In [None]:
COLOR_ALIAS = {
    'yellow': '#ffff00', 'brown': '#ffb081',
    'red': '#ff8b8b', 'blue': '#afafff', 'green' : '#67f7a1'
}


def tile(color='black'):
    color = COLOR_ALIAS.get(color, color)
    s = '&nbsp;'*4
    return f"<text style='border: solid 1px;background-color:{color}'>{s}</text>"


def ap_at_state(x, y, world):
    """Use sensor to create colored tile."""
    if (x, y) in world.overlay:
        color = world.overlay[(x,y)]

        if color in COLOR_ALIAS.keys():
            return tile(color)
    return tile('white')

def print_map(world):
    """Scan the board row by row and print colored tiles."""
    order = range(1, world.dim + 1)
    for y in order:
        chars = (ap_at_state(x, y, world) for x in order)
        display(html_print('&nbsp;'.join(chars)))

def print_trc(trc, world, idx=0):
    states = [s for s, kind in trc if kind == 'ego']
    actions = [s.action for s, kind in trc if kind == 'env']

    obs = [ap_at_state(pos.x, pos.y, world) for pos in states]
    display(
        html_print(f'trc {idx}:&nbsp;&nbsp;&nbsp;' + ''.join(''.join(x) for x in zip(obs, actions)) + '\n')
    ) 

In [None]:
gw = World(
    dim=3,
    start=State(x=3, y=1),
    overlay={
      (1, 1): 'yellow',
      (1, 2): 'green',
      (1, 3): 'green',
      (2, 3): 'red',
      (3, 2): 'blue',
      (3, 3): 'blue',
    }
)

demos = [[
   (State(3, 1), 'ego'),
   (State(3, 1, '←'), 'env'),
   (State(3, 2), 'ego'),
   (State(3, 2, '←'), 'env'),
   (State(2, 2), 'ego'),
   (State(2, 2, '←'), 'env'),
   (State(1, 2), 'ego'),
   (State(1, 2, '↑'), 'env'),
   (State(1, 1), 'ego'),
]]

print_map(gw)
print_trc(demos[0], gw)

Compare against our target domain:

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_3_2.svg"
         style="height: 20em;"
     />
</figure>

# Search procedure

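The demonstration defined above is the expert evidence that will guide our specification search procedure.

Let's start with some very simple base examples to warm-start the search. We want to synthesize a spec that is consistent with the observed evidence thus far.

Going from the partial spec to a full spec: `partial_dfa` encodes what is already known about the task (reach yellow, never touch red). Its state records whether red (bit `0b01`) and yellow (bit `0b10`) have been seen, and it accepts exactly when yellow has been seen and red has not. Every learned DFA is required to accept a subset of this language.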

In [None]:
def partial_dfa(inputs):
    def transition(s, c):
        if c == 'red':
            return s | 0b01
        elif c == 'yellow':
            return s | 0b10
        return s

    return dfa.DFA(
        start=0b00,
        inputs=inputs,
        label=lambda s: s == 0b10,
        transition=transition
    )

def ignore_white(path):
    return tuple(x for x in path if x != 'white')
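To see which color sequences the partial spec accepts, you can replay the transition and label logic of `partial_dfa` using plain integers. The sketch below mirrors the cell above without building a `dfa.DFA`. Note that `partial_step` and `partial_accepts` are hypothetical helpers, and the sketch assumes white tiles have already been removed (as `ignore_white` does).

```python
SEEN_RED, SEEN_YELLOW = 0b01, 0b10


def partial_step(state, color):
    """Same transition as partial_dfa: set a sticky bit per color seen."""
    if color == 'red':
        return state | SEEN_RED
    if color == 'yellow':
        return state | SEEN_YELLOW
    return state


def partial_accepts(word):
    state = 0b00
    for color in word:
        state = partial_step(state, color)
    return state == SEEN_YELLOW  # yellow seen and red never seen


for word in [('yellow',), ('green', 'yellow'), ('yellow', 'red'), ()]:
    print(word, partial_accepts(word))
```

Because both bits are sticky, a single red tile rules a word out permanently.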

Now we can outline the machinery for the search process itself. We use the DFA identification solver from `dfa_identify` to synthesize a minimal DFA (minimal in both states and non-stuttering edges) that is consistent with the examples observed so far. The memoized wrapper `find_dfas2` caches the solver's output and keeps at most 10 candidate DFAs per query.

In [None]:
@fn.memoize(key_func=lambda accepting, rejecting, alphabet, order_by_stutter: hash((accepting, rejecting)))
def find_dfas2(accepting, rejecting, alphabet, order_by_stutter):
    dfas = find_dfas(accepting, rejecting, alphabet=alphabet, order_by_stutter=order_by_stutter)
    dfas = fn.take(10, dfas)
    return dfas
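To make "a minimal DFA consistent with the examples" concrete, here is a naive stand-in for `find_dfa`. It enumerates every DFA with 1, 2, 3, ... states and returns the first one that accepts all positive examples and rejects all negative ones. This is only an illustrative sketch. It does not reflect how `dfa_identify` searches, it ignores the stutter-based ordering, and it is feasible only for tiny alphabets. All helper names are hypothetical.

```python
from itertools import product


def consistent(start, delta, accepting, positives, negatives):
    """True iff the DFA accepts every positive and rejects every negative word."""
    def run(word):
        s = start
        for c in word:
            s = delta[s, c]
        return s in accepting
    return all(run(w) for w in positives) and not any(run(w) for w in negatives)


def brute_force_min_dfa(positives, negatives, alphabet, max_states=4):
    """Return (n, delta, accepting) for a smallest consistent DFA, or None.

    States are 0..n-1 and state 0 is the start state.
    """
    alphabet = sorted(alphabet)
    for n in range(1, max_states + 1):
        keys = [(s, c) for s in range(n) for c in alphabet]
        for targets in product(range(n), repeat=len(keys)):
            delta = dict(zip(keys, targets))
            for bits in product((False, True), repeat=n):
                accepting = {s for s in range(n) if bits[s]}
                if consistent(0, delta, accepting, positives, negatives):
                    return n, delta, accepting
    return None


toy_pos = [('yellow',), ('yellow', 'yellow')]
toy_neg = [(), ('red',), ('red', 'red'), ('red', 'yellow'), ('yellow', 'red'),
           ('yellow', 'red', 'yellow'), ('yellow', 'yellow', 'red')]
min_n, min_delta, min_accepting = brute_force_min_dfa(toy_pos, toy_neg, {'red', 'yellow'})
print(min_n)
```

On the same examples used for `base_examples` in the next cell, this search needs three states. A two-state DFA cannot reject `()` and `('red', 'yellow')` while accepting `('yellow',)`.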

In [None]:
def subset_check_wrapper(dfa_candidate):
    partial = partial_dfa(dfa_candidate.inputs)
    ce = find_subset_counterexample(dfa_candidate, partial)
    return ce is None


ALPHABET = frozenset({'red', 'yellow', 'blue', 'green'})

base_examples = LabeledExamples(
    positive=[
        ('yellow',),
        ('yellow', 'yellow'),
    ],
    negative=[
        (), ('red',), ('red', 'red'),
        ('red', 'yellow'), ('yellow', 'red'),
        ('yellow', 'red', 'yellow'),
        ('yellow', 'yellow', 'red'),
    ]
)


@fn.memoize
def subset_cegis(data):
    global base_examples

    for i in range(20):
        mydfa = find_dfa(data.positive, data.negative, order_by_stutter=True) 
        if mydfa is None:
            raise ConceptIdException
        partial = partial_dfa(mydfa.inputs)
        ce = find_subset_counterexample(mydfa, partial)
        if ce is None:
            break
        base_examples @= LabeledExamples(negative=[ce])
        data @= LabeledExamples(negative=[ce])

        for k, lbl in enumerate(partial.transduce(ce)):
            prefix = ce[:k]
            if not lbl:
                base_examples @= LabeledExamples(negative=[prefix])
                data @= LabeledExamples(negative=[prefix])
    return data


def to_concept(data):
    global base_examples
    
    data = data.map(ignore_white) @ base_examples
    data = subset_cegis(data)

    concept = DFAConcept.from_examples(data, subset_check_wrapper, alphabet=ALPHABET, find_dfas=find_dfas2) 
    # Adjust size to account for subset information.
    return attr.evolve(concept, size=concept.size - np.log(3))
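`subset_cegis` relies on `find_subset_counterexample(candidate, partial)` to return a word that the candidate accepts but the partial spec rejects. One standard way to find such a word is breadth-first search over the product of the two DFAs, which also yields a shortest counterexample. The sketch below assumes dictionary DFAs given as `(start, delta, accepting)` tuples. It illustrates the idea only and is not the `dfa.utils` implementation. `find_subset_ce` is a hypothetical name.

```python
from collections import deque


def find_subset_ce(a, b, alphabet):
    """Shortest word accepted by DFA `a` but rejected by DFA `b`, or None.

    Each DFA is (start, delta, accepting) with delta[(state, symbol)] -> state.
    """
    (sa, da, fa), (sb, db, fb) = a, b
    start = (sa, sb)
    queue, parents = deque([start]), {start: None}
    while queue:
        pair = queue.popleft()
        qa, qb = pair
        if qa in fa and qb not in fb:
            # Walk parent pointers back to the start to rebuild the word.
            word = []
            while parents[pair] is not None:
                pair, sym = parents[pair]
                word.append(sym)
            return tuple(reversed(word))
        for sym in sorted(alphabet):
            nxt = (da[qa, sym], db[qb, sym])
            if nxt not in parents:
                parents[nxt] = (pair, sym)
                queue.append(nxt)
    return None


# Partial spec over {red, yellow}: state bits record "seen red" / "seen yellow".
toy_partial = (0, {(s, c): s | (1 if c == 'red' else 2)
                   for s in range(4) for c in ('red', 'yellow')}, {0b10})
# Candidate "eventually yellow", which ignores the lava constraint.
toy_eventually_yellow = ('no', {('no', 'red'): 'no', ('no', 'yellow'): 'yes',
                                ('yes', 'red'): 'yes', ('yes', 'yellow'): 'yes'}, {'yes'})
print(find_subset_ce(toy_eventually_yellow, toy_partial, {'red', 'yellow'}))  # ('red', 'yellow')
```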



In [None]:
from diss.dfa_concept import remove_stutter
from collections import defaultdict

# adapted from the dfa library
def get_dot(dfa_):
    dfa_dict, init = dfa.dfa2dict(dfa_)
    remove_stutter(dfa_dict)
    g = pydot.Dot(rankdir="LR")

    nodes = {}
    for i, (k, (v, _)) in enumerate(dfa_dict.items()):
        shape = "doublecircle" if v else "circle"
        nodes[k] = pydot.Node(i+1, label=f"{k}", shape=shape)
        g.add_node(nodes[k])

    edges = defaultdict(list)
    for start, (_, transitions) in dfa_dict.items():        
        for action, end in transitions.items():
            color = COLOR_ALIAS[str(action)]
            edges[start, end].append(color)
    
    init_node = pydot.Node(0, shape="point", label="")
    g.add_node(init_node)
    g.add_edge(pydot.Edge(init_node, nodes[init]))

    for (start, end), colors in edges.items():
        for color in colors:
            g.add_edge(pydot.Edge(nodes[start], nodes[end], label='⬛', fontcolor=color))
            
    return g

def view_pydot(pdot):
    img = Image(pdot.create_png())
    display(img)

# Simulated Annealing + SGGS

In [None]:
from diss import diss

In [None]:
@fn.memoize(key_func=lambda c, t, psat: c.dfa)
def to_chain(c, t, psat):
    chain = ProductMC.construct(
        concept=c, tree=t, dyn=gw, max_depth=9,
        psat=0.8, sensor=gw.sensor
    )
    return chain

In [None]:
n_iters = 50


dfa_search = diss(
    demos=demos,
    to_concept=to_concept,
    to_chain=to_chain,
    competency=lambda *_: 0.8,
    lift_path=lambda x: ignore_white(map(gw.sensor, x)),
    n_iters=n_iters,
    reset_period=5,
    surprise_weight=20,  # Rescale surprise to make comparable to size.
    cmf_threshold=0.99,
)

concept2energy = {}
for _, (data, concept, metadata) in zip(trange(n_iters, desc='DISS'), dfa_search):
    concept2energy[concept] = metadata['energy']
    poi = metadata['poi']
    #view_pydot(get_dot(concept.dfa))

In [None]:
sorted_concepts = sorted(list(concept2energy), key=concept2energy.get)
pmf = np.array([np.exp(-concept2energy[c]) for c in sorted_concepts])
Z = pmf.sum()
print(f'{Z=}')
pmf /= Z
cmf = np.cumsum(pmf)
idx = (cmf < 0.99).sum()  # 99% of probability mass concentrated below this point.

for p, c in zip(pmf, sorted_concepts[:idx + 1]):
    print(f'prob = {p:.3}')
    print(f'energy = {concept2energy[c]:.3}')
    print(f'size = {c.size:.3}')
    view_pydot(get_dot(c.dfa))

# Plot CDF of distribution.
sns.lineplot(x=list(range(len(cmf))), y=cmf)
plt.xlabel('DFA index (sorted by probability mass)')
plt.ylabel('CMF')
plt.show()
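The cell above converts energies into a Gibbs distribution, p(c) = exp(-E(c)) / Z. When energies become large (the search rescales surprise by `surprise_weight=20`), `np.exp(-E)` can underflow to zero for every concept, which makes `Z` zero. A common remedy, sketched below in plain Python, is to subtract the minimum energy before exponentiating. This leaves the probabilities unchanged, though the reported `Z` is then rescaled by `exp(E_min)`. The function names are illustrative.

```python
import math


def gibbs_pmf(energies):
    """p_i = exp(-E_i) / sum_j exp(-E_j), computed stably by shifting energies."""
    e_min = min(energies)
    weights = [math.exp(-(e - e_min)) for e in energies]
    z = sum(weights)
    return [w / z for w in weights]


def support_size(probs, mass=0.99):
    """How many of the most probable entries are needed to reach `mass`.

    Matches keeping sorted_concepts[:idx + 1] with idx = (cmf < mass).sum().
    """
    cumulative, count = 0.0, 0
    for p in sorted(probs, reverse=True):
        count += 1
        cumulative += p
        if cumulative >= mass:
            break
    return count


print(gibbs_pmf([1000.0, 1001.0]))  # roughly [0.731, 0.269]; naive exp(-1000) is 0.0
print(support_size(gibbs_pmf([1.0, 2.0, 5.0])))
```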

In [None]:
# 4. Compute current support's belief on unlabeled strings of interest.
p_accept = {}
for word in poi:
    votes = np.array([(word in c) for c in sorted_concepts])
    p_accept[word] = pmf @ votes
    
sorted_poi = sorted(poi, key=lambda x: -p_accept[x])

for word in sorted_poi:
    print(f'{word}'.ljust(50) + f'{p_accept[word]:.2}')
    print('-' * 54)

In [None]:
accepting, rejecting = set(), set()
for word in poi:
    belief = p_accept[word]
    if belief > 0.9:
        accepting.add(word)
    elif belief < 0.1:
        rejecting.add(word)
data = LabeledExamples(accepting, rejecting)
view_pydot(get_dot(to_concept(data).dfa))
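The two cells above work in two steps. First, they compute the posterior probability that each path of interest is accepted, which is the total probability mass of the concepts that accept it. Second, they keep only confident labels: a belief above 0.9 becomes a positive example, a belief below 0.1 becomes a negative one, and anything in between stays unlabeled. The sketch below shows both steps in plain Python. For illustration, it writes toy concepts as predicates; the notebook itself tests membership with `word in c`.

```python
def acceptance_beliefs(probs, concepts, words):
    """P(word accepted) = total probability of the concepts that accept it."""
    return {w: sum(p for p, accepts_word in zip(probs, concepts) if accepts_word(w))
            for w in words}


def confident_labels(beliefs, hi=0.9, lo=0.1):
    """Split words into confident positives / negatives; leave the rest unlabeled."""
    positive = {w for w, b in beliefs.items() if b > hi}
    negative = {w for w, b in beliefs.items() if b < lo}
    return positive, negative


toy_concepts = [
    lambda w: 'yellow' in w,                     # "eventually yellow"
    lambda w: 'yellow' in w and 'red' not in w,  # "yellow, never red"
]
toy_probs = [0.3, 0.7]
toy_words = [('yellow',), ('red', 'yellow'), ('red',)]
toy_beliefs = acceptance_beliefs(toy_probs, toy_concepts, toy_words)
print(confident_labels(toy_beliefs))
```

In this example, `('red', 'yellow')` gets belief 0.3, so it stays unlabeled instead of being added to the data.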

# Enumeration baseline

In [None]:
def sampler_factory(demos):
    return GradientGuidedSampler.from_demos(
        demos=demos,
        to_chain=to_chain,
        competency=lambda *_: 0.8,
    )

In [None]:
to_chain.invalidate_all()

In [None]:
from dfa.utils import minimize

def enumerate_dfas():
    data = LabeledExamples(
        positive=[
            ('yellow',),
            ('yellow', 'yellow'),
        ],
        negative=[
            (), ('red',), ('red', 'red'),
            ('red', 'yellow'), ('yellow', 'red'),
            ('yellow', 'red', 'yellow'),
            ('yellow', 'yellow', 'red'),
        ]
    )

    # CEGIS loop: add subset counterexamples as negative examples until
    # the first few candidates all lie inside the partial spec.
    for _ in range(20):
        tests = fn.take(5, find_dfas(
            data.positive,
            data.negative,
            order_by_stutter=True,
            allow_unminimized=True,
        ))
        new_data = LabeledExamples()
        for test in tests:
            assert test is not None
            partial = partial_dfa(test.inputs)
            ce = find_subset_counterexample(test, partial)
            if ce is not None:
                new_data @= LabeledExamples(negative=[ce])
        if new_data.size == 0:
            break
        data @= new_data

    dfas = find_dfas(
        data.positive,
        data.negative,
        order_by_stutter=True,
        alphabet=ALPHABET,
        allow_unminimized=True,
    )
    # Keep only candidates inside the partial spec, minimized.
    yield from map(minimize, filter(subset_check_wrapper, dfas))
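`enumerate_dfas` requests possibly unminimized DFAs from `find_dfas` (`allow_unminimized=True`). It then passes each surviving candidate through `dfa.utils.minimize` before the state counts are tallied. To build intuition, here is a textbook Moore-style partition refinement on a dictionary DFA. States are first split by acceptance and then repeatedly re-split according to which blocks their successors land in, until the partition stops changing. This sketch assumes all states are reachable and the transition map is total. It is not the `dfa` library's algorithm, and `moore_blocks` is a hypothetical name.

```python
def moore_blocks(states, alphabet, delta, accepting):
    """Return the equivalence classes of a complete DFA with reachable states."""
    alphabet = sorted(alphabet)
    block_of = {s: int(s in accepting) for s in states}
    while True:
        # A state's signature: its current block plus its successors' blocks.
        signature = {s: (block_of[s],) + tuple(block_of[delta[s, c]] for c in alphabet)
                     for s in states}
        ids = {}
        refined = {s: ids.setdefault(signature[s], len(ids)) for s in states}
        if len(ids) == len(set(block_of.values())):
            break  # no block was split: the partition is stable
        block_of = refined
    blocks = {}
    for s, b in block_of.items():
        blocks.setdefault(b, set()).add(s)
    return list(blocks.values())


# y1 and y2 are both accepting and behave identically, so they should merge.
toy_delta = {
    ('start', 'yellow'): 'y1', ('start', 'red'): 'dead',
    ('y1', 'yellow'): 'y2', ('y1', 'red'): 'dead',
    ('y2', 'yellow'): 'y1', ('y2', 'red'): 'dead',
    ('dead', 'yellow'): 'dead', ('dead', 'red'): 'dead',
}
toy_blocks = moore_blocks(['start', 'y1', 'y2', 'dead'], {'red', 'yellow'},
                          toy_delta, {'y1', 'y2'})
print(len(toy_blocks))
```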


In [None]:
sggs = sampler_factory(demos)

In [None]:
#Z2 = 0
dist = Counter()
for i, d in enumerate(fn.distinct(enumerate_dfas())):
    concept = DFAConcept.from_dfa(d)
    dist.update([len(d.states())])
    print(dist)
    print(f'             {i}                ')
    #energy = 10 * sggs(concept)[1]['surprisal'] + concept.size
    #Z2 += np.exp(-energy)
    
    #print(f'{energy=:.3}')
    #print(f'Z2 / Z = {Z2 / Z:0.3}')
    print('--------------------------------')
    if len(d.states()) > 4:
        break
    #view_pydot(get_dot(d))

# Circuit + BDD based MaxEnt Policy (CAV '20)

# 1. Create Dynamical System

Here we use `py-aiger` to create a BitVector sequential circuit, `DYN`, which models an 8x8 gridworld built with `GW.gridworld`.

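Afterwards, the `SLIP` circuit introduces a slip probability of `1/32`. This is modeled by a biased coin `c` that comes up `1` with probability `31/32`: the 2-bit action `a` is ANDed with `c` (repeated to 2 bits), so when `c = 0` the action is forced to `0b00`.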
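In integer terms, `c.repeat(2) & a` keeps the 2-bit action when the coin is `1` and zeroes it when the coin is `0`. The sketch below computes the resulting distribution over effective actions. Like the BDD section later, it assumes that `c = 1` with probability `31/32`. Which direction `0b00` stands for depends on `aiger_gridworld`'s compressed action encoding and is not stated in this notebook. The helper names are hypothetical.

```python
from fractions import Fraction

COIN_SLIP_PROB = Fraction(1, 32)  # P(c = 0)


def gated_action(a, c):
    """Integer version of c.repeat(2) & a for a 2-bit action a and coin bit c."""
    mask = 0b11 if c else 0b00
    return a & mask


def effective_action_distribution(a):
    """Distribution over the action the gridworld actually receives."""
    dist = {}
    for c, p in ((1, 1 - COIN_SLIP_PROB), (0, COIN_SLIP_PROB)):
        b = gated_action(a, c)
        dist[b] = dist.get(b, 0) + p
    return dist


print(effective_action_distribution(0b10))
```

If the agent deliberately chooses the action encoded as `0b00`, both coin outcomes produce the same result, so a slip cannot be observed on that step.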

**Note that each coordinate is one-hot encoded:** the lower 8 bits of `state` hold x and the upper 8 bits hold y.
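To make the layout concrete, the sketch below repeats the arithmetic of `encode_state` on a bare integer and then inverts it. `encode_state_int` and `decode_state_int` are hypothetical helpers, named so that they do not shadow the notebook's `encode_state`, which returns a dict.

```python
def encode_state_int(x, y):
    """Same bit arithmetic as encode_state, for x, y in 1..8."""
    return (1 << (x - 1)) | (1 << ((9 - y) + 7))


def decode_state_int(state):
    """Invert encode_state_int: low byte is one-hot x, high byte is one-hot y."""
    x_bits, y_bits = state & 0xFF, state >> 8
    assert bin(x_bits).count('1') == 1 and bin(y_bits).count('1') == 1
    x = x_bits.bit_length()       # bit i set in the low byte  -> x = i + 1
    y = 9 - y_bits.bit_length()   # bit j set in the high byte -> y = 8 - j
    return x, y


print(bin(encode_state_int(2, 3)))               # 0b10000000000010
print(decode_state_int(encode_state_int(2, 3)))  # (2, 3)
```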

In [None]:
import aiger as A
import aiger_bv as BV
import aiger_gridworld as GW
import aiger_ptltl as LTL
from bidict import bidict
from aiger_bdd import to_bdd

In [None]:
STATE = BV.uatom(16, 'state')
X = STATE[:8]
Y = STATE[8:]
s0 = (3, 5)
DYN = GW.gridworld(8, start=(s0[0], 9 - s0[1]), compressed_inputs=True)
SLIP = BV.atom(1, 'c', signed=False).repeat(2) & BV.atom(2, 'a', signed=False)
SLIP = SLIP.with_output('a').aigbv
DYN <<= SLIP

def encode_state(x, y):
    x, y = x - 1, (9 - y) + 7
    return {'state': (1 << x) | (1 << y)}

In [None]:
GW.GridState(encode_state(5, 4)['state'], 8).y

In [None]:
DYN.latch2init

In [None]:
print(DYN({'a': GW.WEST, 'c': 0})[0]['state'].y)

In [None]:
print(DYN({'a': GW.WEST, 'c': 1})[0]['state'].x)

In [None]:
bin(encode_state(2, 3)['state'])

# 2. Create Sensor / Feature overlay

Next, we define the mapping from concrete states to sensor values (atomic predicates).
We use simple coordinate-wise bitvector masks to encode the color overlays.

In [None]:
def mask_test(xmask, ymask):
    return ((X & xmask) !=0) & ((Y & ymask) != 0)


APS = {       #            x-axis       y-axis
    'yellow': mask_test(0b1000_0001, 0b1000_0001),
    'blue':   mask_test(0b0001_1000, 0b0011_1000),
    'brown':  mask_test(0b0011_1100, 0b1000_0001),
    'red':    mask_test(0b1000_0001, 0b0011_0010) \
            | mask_test(0b0100_0010, 0b0011_0011),
}
def create_sensor(aps):
    """Combine the atomic-predicate circuits into one sensor with an output per color."""
    sensor = BV.aig2aigbv(A.empty())
    for name, ap in aps.items():
        sensor |= ap.with_output(name).aigbv
    return sensor

SENSOR = create_sensor(APS)
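`mask_test` marks a cell when its one-hot x bit hits `xmask` and its one-hot y bit hits `ymask`. The sketch below replays that test on plain integers. `mask_hit` is a hypothetical helper that follows the bit layout of `encode_state`, where x sets bit `x - 1` of `X` and y sets bit `8 - y` of `Y`. Running it shows, for example, that the yellow masks select the four corners of the board.

```python
def mask_hit(x, y, xmask, ymask):
    """Integer version of mask_test for cell (x, y), with x, y in 1..8."""
    x_bits = 1 << (x - 1)  # one-hot X = STATE[:8]
    y_bits = 1 << (8 - y)  # one-hot Y = STATE[8:]
    return (x_bits & xmask) != 0 and (y_bits & ymask) != 0


yellow_cells = [(x, y) for y in range(1, 9) for x in range(1, 9)
                if mask_hit(x, y, 0b1000_0001, 0b1000_0001)]
print(yellow_cells)  # [(1, 1), (8, 1), (1, 8), (8, 8)]
```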

## Visualizing Overlay

This can all seem pretty abstract, so let's visualize the way the sensor sees the board.

In [None]:
from IPython.display import HTML as html_print


COLOR_ALIAS = {
    'yellow': '#ffff8c', 'brown': '#ffb081',
    'red': '#ff5454', 'blue': '#9595ff'
}


def tile(color='black'):
    color = COLOR_ALIAS.get(color, color)
    s = '&nbsp;'*4
    return f"<text style='border: solid 1px;background-color:{color}'>{s}</text>"


def ap_at_state(x, y, in_ascii=False):
    """Use sensor to create colored tile."""
    state = encode_state(x, y)
    obs = SENSOR(state)[0]  # output dict: one 1-bit tuple per color

    for k in COLOR_ALIAS.keys():
        if obs[k][0]:
            return tile(k)
    return tile('white')

def print_map():
    """Scan the board row by row and print colored tiles."""
    order = range(1, 9)
    for y in order:
        chars = (ap_at_state(x, y, in_ascii=True) for x in order)
        display(html_print('&nbsp;'.join(chars)))
        
print_map()

# 3. Describe demonstrations

We now encode a collection of demonstrations from an agent attempting to:

1. Avoid the red tiles (lava).
2. Reach the yellow tiles (recharge).
3. If the agent touches a blue tile (water), then it must dry off (brown tile) before recharging.
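Before these rules are encoded as temporal-logic circuits, it can help to view them as a hand-written monitor over the sequence of tile colors that a trace visits. The sketch below is illustrative: the function name is hypothetical, and it assumes white tiles have already been dropped. It reads the sequence left to right and keeps two flags, one for "has recharged" and one for "currently wet".

```python
def satisfies_task(colors):
    """Check a color sequence against the three requirements listed above."""
    recharged, wet = False, False
    for color in colors:
        if color == 'red':
            return False          # 1. touched lava
        if color == 'blue':
            wet = True
        elif color == 'brown':
            wet = False
        elif color == 'yellow':
            if wet:
                return False      # 3. recharged while still wet
            recharged = True
    return recharged              # 2. must have reached yellow


print(satisfies_task(('blue', 'brown', 'yellow')), satisfies_task(('blue', 'yellow')))
```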

**Note:** Trace 5 corresponds to a very unlikely demonstration in which the agent enters the water but is unable to dry off due to wind.

In [None]:
def print_trc(trc, idx=0):
    actions, states = trc
    obs = (ap_at_state(*pos, in_ascii=True) for pos in states)
    display(
        html_print(f'trc {idx}:&nbsp;&nbsp;&nbsp;' + ''.join(''.join(x) for x in zip(actions, obs)) + '\n')
    )

ACTIONS0 = "→→↑↑↑↑→→→"
STATES0 = ((4, 5), (5, 5), (5, 4), (5, 3),(5, 2), (5, 1), (6, 1), (7, 1), (8, 1))
TRC0 = (ACTIONS0, STATES0)
print_trc(TRC0, 0)

ACTIONS1 = "↑↑↑↑←←←←←"
STATES1 = ((3, 4), (3, 3), (3, 2), (3, 1), (2, 1), (1, 1), (1, 1), (1, 1), (1, 1),)
TRC1 = (ACTIONS1, STATES1)
print_trc(TRC1, 1)

ACTIONS2 = "←→↑↑↑←↑←←"
STATES2 = ((2, 5), (3, 5), (3, 4), (3, 3), (3, 2), (2, 2), (2, 1), (1, 1), (1, 1))
TRC2 = (ACTIONS2, STATES2)
print_trc(TRC2, 2)

ACTIONS3 = "↑↑→←↑↑←←←"
STATES3 = ((3, 4), (3, 3), (4, 3), (3, 3), (3, 2), (3, 1), (2, 1), (1, 1), (1, 1))
TRC3 = (ACTIONS3, STATES3)
print_trc(TRC3, 3)

ACTIONS4 = "↑→↑↑↑←←←←"
STATES4 = ((3, 4), (4, 4), (4, 3), (4, 2), (4, 1), (3, 1), (2, 1), (1, 1), (1, 1))
TRC4 = (ACTIONS4, STATES4)
print_trc(TRC4, 4)

ACTIONS5 = "↑→↑↑→→→→→"
STATES5 = ((3, 4), (4, 4), (4, 3), (4, 2), (3, 2), (2, 2), (1, 2), (1, 2), (1, 2))
TRC5 = (ACTIONS5, STATES5)
print_trc(TRC5, 5)

TRACES = [TRC0, TRC1, TRC2, TRC3, TRC4]         # Variety of positive demos.
TRACES += [TRC5]                                # Unlucky, Negative Demonstration.
TRACES += 4 * [TRC4]                            # Additional "Safe" Demonstrations.

In [None]:
def encode_trace(trc):
    """Encode a trace as circuit inputs, inferring the slip coin `c` at each step."""
    actions, states = trc
    actions = [{'a': a} for a in actions]
    states = [encode_state(*s) for s in states]

    for s, a, s2 in zip([encode_state(*s0)] + states, actions, states):
        s, s2 = GW.GridState(s['state'], 8), GW.GridState(s2['state'], 8)
        action = a['a']
        print((s.x, s.y), (s2.x, s2.y), action)
        if action == GW.WEST:
            a['c'] = 1
        elif action == GW.EAST:
            a['c'] = int((s2.x > s.x) or s.x == 8)
        else:
            a['c'] = int(s.x == s2.x)
    actions[-1]['c'] = 1  # Last action needs some arbitrary assignment to slipping.

    return actions, states

encode_trace(TRC5)

# 4. Define Specification Circuits / Concept Class

First, we describe the properties over the colors of the map. This is done in past-time temporal logic using `py-aiger-ptltl`.

In [None]:
LAVA, RECHARGE, WATER, DRY = map(LTL.atom, ['red', 'yellow', 'blue', 'brown'])

EVENTUALLY_RECHARGE = RECHARGE.once()
AVOID_LAVA = (~LAVA).historically()

RECHARGED_AND_ONCE_WET = RECHARGE & WATER.once()
DRIED_OFF = (~WATER).since(DRY)

DIDNT_RECHARGE_WHILE_WET = (RECHARGED_AND_ONCE_WET).implies(DRIED_OFF)
DONT_RECHARGE_WHILE_WET = DIDNT_RECHARGE_WHILE_WET.historically()

CONST_TRUE = LTL.atom(True)


SPECS = [
    CONST_TRUE, AVOID_LAVA, EVENTUALLY_RECHARGE, DONT_RECHARGE_WHILE_WET,
    AVOID_LAVA & EVENTUALLY_RECHARGE & DONT_RECHARGE_WHILE_WET,
    AVOID_LAVA & EVENTUALLY_RECHARGE,
    AVOID_LAVA & DONT_RECHARGE_WHILE_WET,
    EVENTUALLY_RECHARGE & DONT_RECHARGE_WHILE_WET,
]

SPEC_NAMES = [
    "CONST_TRUE", "AVOID_LAVA", "EVENTUALLY_RECHARGE", "DONT_RECHARGE_WHILE_WET",
    "AVOID_LAVA & EVENTUALLY_RECHARGE & DONT_RECHARGE_WHILE_WET",
    "AVOID_LAVA & EVENTUALLY_RECHARGE",
    "AVOID_LAVA & DONT_RECHARGE_WHILE_WET",
    "EVENTUALLY_RECHARGE & DONT_RECHARGE_WHILE_WET",
]
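For readers new to past-time temporal logic, the operators used above have simple step-by-step semantics over a finite trace. The sketch below assumes the standard textbook definitions of `once`, `historically`, and a strong `since`, and uses them to evaluate `DONT_RECHARGE_WHILE_WET` on a color sequence in plain Python. It mirrors the formulas, not the circuits that `py-aiger-ptltl` builds, and the helper names are hypothetical.

```python
def once(phi):
    """(once phi)[t]: phi held at some step <= t."""
    out, seen = [], False
    for v in phi:
        seen = seen or v
        out.append(seen)
    return out


def historically(phi):
    """(historically phi)[t]: phi held at every step <= t."""
    out, ok = [], True
    for v in phi:
        ok = ok and v
        out.append(ok)
    return out


def since(phi, psi):
    """(phi since psi)[t]: psi held at some s <= t and phi at every step in (s, t]."""
    out, holds = [], False
    for p, q in zip(phi, psi):
        holds = q or (holds and p)
        out.append(holds)
    return out


def dont_recharge_while_wet(colors):
    water = [c == 'blue' for c in colors]
    dry = [c == 'brown' for c in colors]
    recharge = [c == 'yellow' for c in colors]
    dried_off = since([not w for w in water], dry)                     # DRIED_OFF
    wet_recharge = [r and w for r, w in zip(recharge, once(water))]    # RECHARGED_AND_ONCE_WET
    implication = [(not a) or b for a, b in zip(wet_recharge, dried_off)]
    return historically(implication)[-1] if colors else True


print(dont_recharge_while_wet(['blue', 'brown', 'yellow']))  # True
print(dont_recharge_while_wet(['blue', 'yellow']))           # False
```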

In [None]:
def spec2monitor(spec):
    spec = spec.with_output('SAT')
    monitor = spec.aig | A.sink(['red', 'yellow', 'brown', 'blue'])
    monitor = BV.aig2aigbv(monitor)
    return DYN >> SENSOR >> monitor
    
SPEC2MONITORS = { spec: spec2monitor(spec) for spec in SPECS }

# Creating the BDD game graph

In [None]:
spec2monitor(SPECS[4]).simulate(encode_trace(TRC4)[0])[-1]

In [None]:
unrolled = spec2monitor(SPECS[4]).aigbv \
                                 .cone('SAT') \
                                 .unroll(10, only_last_outputs=True)
causal_order = []
for t in range(10):
    causal_order.append(f'a##time_{t}[0]')
    causal_order.append(f'a##time_{t}[1]')
    causal_order.append(f'c##time_{t}[0]')
causal_order = {x: i for i, x in enumerate(causal_order)}

In [None]:
causal_order

In [None]:
from dd.cudd import BDD

In [None]:
manager = BDD()
manager.declare(*causal_order)
bexpr, *_ = to_bdd(unrolled, manager=manager, renamer=lambda _, x: x, levels=causal_order)
bexpr.dag_size

In [None]:
len(manager.vars)

In [None]:
import networkx as nx

In [None]:
def to_nx(bexpr):
    # DFS to translate an edge-complemented BDD into a networkx graph.
    dag = nx.DiGraph()

    stack, visited = [(bexpr, False, int(bexpr))], set()
    while stack:
        bexpr, parity, ref = stack.pop()

        if ref in visited:
            continue

        visited.add(ref)
        if bexpr in (bexpr.bdd.true, bexpr.bdd.false):
            label = bexpr == bexpr.bdd.true
            dag.add_node(ref, label=label, level=len(bexpr.bdd.vars))
            continue

        dag.add_node(ref, label=bexpr.var, level=bexpr.level)

        parity = bexpr.negated ^ parity
        for lbl, bexpr2 in [(0, bexpr.low), (1, bexpr.high)]:
            ref2 = int(bexpr2 if parity else ~bexpr2)
            dag.add_edge(ref, ref2, label=lbl)
            stack.append((bexpr2, parity, ref2))
    return dag

In [None]:
dag = to_nx(bexpr)

In [None]:
list(dag.neighbors(int(bexpr)))

In [None]:
for src, data in dag.nodes(data=True):
    label = data['label']
    if isinstance(label, bool):
        data['kind'] = label
    elif label.startswith('a'):
        data['kind'] = 'ego'
    else:
        data['kind'] = 'env'

for src, tgt, data in dag.edges(data=True):
    entropy = dag.nodes[tgt]['level'] - dag.nodes[src]['level'] - 1
    entropy /= np.log2(np.e)  # Convert bits to nats.
    data['entropy'] = entropy
    
    if dag.nodes[src]['kind'] == 'env':
        data['prob'] = 31/32 if data['label'] else 1/32
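The edge entropy above counts the BDD levels that an edge skips (variables the path leaves unconstrained) at one bit each, then converts the total from bits to nats. Dividing by `log2(e)` is the same as multiplying by `ln 2`, as this small check illustrates. `skipped_level_entropy` is a hypothetical helper name.

```python
import math


def skipped_level_entropy(src_level, tgt_level):
    """One bit per skipped BDD level, expressed in nats."""
    skipped_bits = tgt_level - src_level - 1
    return skipped_bits * math.log(2)


# The notebook's conversion for 3 skipped levels: bits / log2(e).
print(skipped_level_entropy(0, 4), 3 / math.log2(math.e))
```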

In [None]:
dag.nodes[int(bexpr)]

In [None]:
from diss.tabular import TabularPolicy

In [None]:
TabularPolicy.from_psat(dag, psat=0.92).psat()