# Demonstration Informed Specification Search: Experiment

Let's take a look at how the DISS algorithm can search for specifications by leveraging expert demonstrations.
We'll focus on learning DFAs here, but note that this approach is not confined to any specific concept class. To start, consider an agent operating in the following stochastic gridworld.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_1.svg"
         style="height: 20em;"
     />
</figure>

## Agent Actions

The agent can attempt to move up, down, left, or right as illustrated below.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_2.svg"
         style="height: 20em;"
     />
</figure>

## Stochastic Transitions

However, there is some small probability that the agent will slip downward due to wind!

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_3.svg"
         style="height: 20em;"
     />
</figure>
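The slip dynamics can be sketched as a one-step transition function. This is a minimal illustration, not the `GridWorldNaive` implementation: `SLIP_PROB` is a hypothetical value (the notebook does not state the slip probability), coordinates are 1-indexed `(x, y)` with `y` growing downward (matching the row order used by `print_map`), and moves off the board are assumed to leave the agent in place. Under these assumptions, the demonstration's first step (a `←` from `(3, 1)` that ends at `(3, 2)`) is a slip.

```python
import random

# Hypothetical parameters: the notebook does not specify the slip probability.
SLIP_PROB = 0.1
DIM = 3
MOVES = {'↑': (0, -1), '↓': (0, 1), '←': (-1, 0), '→': (1, 0)}


def clamp(v, lo=1, hi=DIM):
    """Keep a coordinate on the 1-indexed DIM x DIM board (assumed: walls block moves)."""
    return max(lo, min(hi, v))


def step(x, y, action, rng, slip_prob=SLIP_PROB):
    """Apply `action`; with probability `slip_prob` the wind pushes the agent down instead."""
    if rng.random() < slip_prob:
        dx, dy = MOVES['↓']
    else:
        dx, dy = MOVES[action]
    return clamp(x + dx), clamp(y + dy)
```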



Let's assume the agent's task can be described in terms of the colors of the tiles. **What was the agent trying to do?**


## Probably avoiding the red tiles

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_4.svg"
         style="height: 20em;"
     />
</figure>

## Probably trying to reach the yellow tile

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/enter_lava_augmented_5.svg"
         style="height: 20em;"
     />
</figure>

In this notebook, we illustrate learning task representations in the form of Deterministic Finite Automata (DFAs), which describe temporal tasks and can be learned incrementally.

In particular, we consider a variation of the following gridworld from [this](https://mjvc.me/DISS/#/77) slide deck. Here, the agent's task is a composition of three subtasks.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_1.svg"
         style="height: 20em;"
     />
</figure>

Each subtask is a regular language represented as a DFA.

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_2_1.svg"
         style="height: 20em;"
     />
</figure>
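To make "a composition of DFAs" concrete, here is a minimal sketch assuming the subtasks are combined by conjunction, i.e. the task accepts a word only if every subtask DFA does. The product construction below runs all subtask DFAs in lockstep. `SimpleDFA` and `conjunction` are hypothetical helpers (not the `dfa` library API), and the two example subtasks are illustrative.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass(frozen=True)
class SimpleDFA:
    """A minimal DFA (hypothetical helper, not the `dfa` library)."""
    start: Any
    accepting: Callable[[Any], bool]
    transition: Callable[[Any, str], Any]

    def accepts(self, word: Iterable[str]) -> bool:
        state = self.start
        for symbol in word:
            state = self.transition(state, symbol)
        return self.accepting(state)


def conjunction(*dfas: SimpleDFA) -> SimpleDFA:
    """Product construction: the state is a tuple of component states; accept iff all accept."""
    return SimpleDFA(
        start=tuple(d.start for d in dfas),
        accepting=lambda s: all(d.accepting(q) for d, q in zip(dfas, s)),
        transition=lambda s, c: tuple(d.transition(q, c) for d, q in zip(dfas, s)),
    )


# Illustrative subtask: eventually see yellow (state = "yellow seen so far").
reach_yellow = SimpleDFA(False, lambda s: s, lambda s, c: s or c == 'yellow')
# Illustrative subtask: never see red (state = "no red so far").
avoid_red = SimpleDFA(True, lambda s: s, lambda s, c: s and c != 'red')

task = conjunction(reach_yellow, avoid_red)
```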

Further, we assume that the first two subtasks are known a priori (say, because they were learned in another workspace),
<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_1_2.svg"
         style="height: 20em;"
     />
</figure>

and our task is to learn the third task given a partial demonstration.
<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_3_2.svg"
         style="height: 20em;"
     />
</figure>


# Preamble

In [None]:
from functools import lru_cache

import attr
import funcy as fn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from bidict import bidict
from IPython.display import Image, display
import networkx as nx
import pydot

from collections import Counter

import dfa
from dfa.utils import find_subset_counterexample, find_equiv_counterexample
from dfa_identify import find_dfa, find_dfas

from diss.product_mc import ProductMC
from diss.dfa_concept import DFAConcept
from diss.domains.gridworld_naive import GridWorldNaive as World
from diss.domains.gridworld_naive import GridWorldState as State
from diss import search, LabeledExamples, GradientGuidedSampler, ConceptIdException
from pprint import pprint
from itertools import combinations
from tqdm import tqdm_notebook
from tqdm.notebook import trange
from IPython.display import clear_output
from IPython.display import HTML as html_print

sns.set_context('paper')
sns.set_style('whitegrid')

## Let's first visualize our gridworld and a demonstration within the gridworld.

In [None]:
COLOR_ALIAS = {
    'yellow': '#ffff00', 'brown': '#ffb081',
    'red': '#ff8b8b', 'blue': '#afafff', 'green' : '#67f7a1'
}


def tile(color='black'):
    """Return an inline HTML swatch for `color` (an alias from COLOR_ALIAS or any CSS color)."""
    color = COLOR_ALIAS.get(color, color)
    s = '&nbsp;'*4
    return f"<span style='border: solid 1px;background-color:{color}'>{s}</span>"


def ap_at_state(x, y, world):
    """Use sensor to create colored tile."""
    if (x, y) in world.overlay:
        color = world.overlay[(x,y)]

        if color in COLOR_ALIAS.keys():
            return tile(color)
    return tile('white')

def print_map(world):
    """Scan the board row by row and print colored tiles."""
    order = range(1, world.dim + 1)
    for y in order:
        chars = (ap_at_state(x, y, world) for x in order)
        display(html_print('&nbsp;'.join(chars)))

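from itertools import zip_longest


def print_trc(trc, world, idx=0):
    """Render a trace as colored tiles (ego states) interleaved with actions (env states)."""
    states = [s for s, kind in trc if kind == 'ego']
    actions = [s.action for s, kind in trc if kind == 'env']

    obs = [ap_at_state(pos.x, pos.y, world) for pos in states]
    # A trace has one more ego state than actions; zip_longest keeps the final tile.
    steps = ''.join(o + a for o, a in zip_longest(obs, actions, fillvalue=''))
    display(html_print(f'trc {idx}:&nbsp;&nbsp;&nbsp;' + steps))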

In [None]:
gw = World(
    dim=3,
    start=State(x=3, y=1),
    overlay={
      (1, 1): 'yellow',
      (1, 2): 'green',
      (1, 3): 'green',
      (2, 3): 'red',
      (3, 2): 'blue',
      (3, 3): 'blue',
    }
)

demos = [[
   (State(3, 1), 'ego'),
   (State(3, 1, '←'), 'env'),
   (State(3, 2), 'ego'),
   (State(3, 2, '←'), 'env'),
   (State(2, 2), 'ego'),
   (State(2, 2, '←'), 'env'),
   (State(1, 2), 'ego'),
   (State(1, 2, '↑'), 'env'),
   (State(1, 1), 'ego'),
]]

print_map(gw)
print_trc(demos[0], gw)

Compare against our target domain:

<figure style="padding: 1em; background: #494949;">
    <img src="imgs/example_domain_3_2.svg"
         style="height: 20em;"
     />
</figure>
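The specification never sees raw states: `lift_path` maps each visited state through the sensor, and `ignore_white` drops uncolored tiles. A minimal sketch of that lifting for the demonstration above, assuming only the ego positions are sensed (the actual `gw.sensor` may also be applied to environment states). `sense` and `lift` are hypothetical stand-ins; `OVERLAY` is copied from the `gw` definition.

```python
# Copied from the `gw` overlay above; positions are (x, y).
OVERLAY = {
    (1, 1): 'yellow', (1, 2): 'green', (1, 3): 'green',
    (2, 3): 'red', (3, 2): 'blue', (3, 3): 'blue',
}

# Positions visited by the ego agent in the demonstration above.
DEMO_POSITIONS = [(3, 1), (3, 2), (2, 2), (1, 2), (1, 1)]


def sense(pos):
    """Map a position to its tile color; uncolored tiles read as 'white'."""
    return OVERLAY.get(pos, 'white')


def lift(positions):
    """Turn a path into the color word seen by the spec, dropping white tiles."""
    return tuple(c for c in map(sense, positions) if c != 'white')


print(lift(DEMO_POSITIONS))  # ('blue', 'green', 'yellow')
```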

# Search procedure

The demonstration above is the expert evidence that guides our specification search.

We start with a few simple base examples to warm-start the search; the goal is to synthesize a spec that is consistent with all evidence observed so far.

The partial spec below encodes what is already known (reach yellow while never visiting red), and the search goes from this partial spec to a full spec: every candidate DFA must refine it, i.e., accept only words that the partial spec also accepts.

In [None]:
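def partial_dfa(inputs):
    """Partial spec over `inputs`: accept iff yellow has been seen and red never has."""
    def transition(s, c):
        # State is a 2-bit mask: bit 0 = "saw red", bit 1 = "saw yellow".
        if c == 'red':
            return s | 0b01
        elif c == 'yellow':
            return s | 0b10
        return s

    return dfa.DFA(
        start=0b00,
        inputs=inputs,
        label=lambda s: s == 0b10,  # yellow seen, red never seen
        transition=transition
    )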

def ignore_white(path):
    return tuple(x for x in path if x != 'white')
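To make the bitmask encoding of `partial_dfa` concrete, here is a dependency-free sketch of the same acceptance check (our own re-implementation for illustration; the notebook itself builds a `dfa.DFA`). Because the red bit can never be cleared, any word that visits red is rejected no matter what follows.

```python
RED, YELLOW = 0b01, 0b10


def step_partial(state: int, color: str) -> int:
    """Same update rule as `partial_dfa`: bit 0 records red, bit 1 records yellow."""
    if color == 'red':
        return state | RED
    if color == 'yellow':
        return state | YELLOW
    return state


def partial_accepts(word) -> bool:
    """Accept iff yellow was seen and red never was (final state == 0b10)."""
    state = 0b00
    for color in word:
        state = step_partial(state, color)
    return state == YELLOW
```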

Now we can outline the machinery for the search itself. We use the `dfa_identify` package to synthesize a minimal DFA (minimal in both the number of states and the number of non-stuttering edges) that is consistent with the examples observed so far.
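Here "consistent" means the DFA accepts every positive example and rejects every negative one. A minimal sketch of that requirement (not of the identification procedure itself, which additionally searches for the smallest such DFA), using a hypothetical transition-table representation rather than the `dfa` or `dfa_identify` API:

```python
from typing import Dict, Iterable, Set, Tuple

# Hypothetical representation: a complete transition table keyed by (state, symbol).
Table = Dict[Tuple[int, str], int]
COLORS = ('red', 'yellow', 'blue', 'green')


def run(table: Table, start: int, word: Iterable[str]) -> int:
    """Follow the transition table from `start` over `word`."""
    state = start
    for symbol in word:
        state = table[(state, symbol)]
    return state


def is_consistent(table: Table, start: int, accepting: Set[int], positives, negatives) -> bool:
    """True iff every positive word is accepted and every negative word is rejected."""
    def accepts(word):
        return run(table, start, word) in accepting

    return all(accepts(w) for w in positives) and not any(accepts(w) for w in negatives)


# Illustrative 3-state candidate for "reach yellow, never red": 0 = waiting, 1 = done, 2 = failed.
CANDIDATE: Table = {}
for c in COLORS:
    CANDIDATE[(0, c)] = 2 if c == 'red' else (1 if c == 'yellow' else 0)
    CANDIDATE[(1, c)] = 2 if c == 'red' else 1
    CANDIDATE[(2, c)] = 2

# The same words as the notebook's `base_examples`.
POSITIVE = [('yellow',), ('yellow', 'yellow')]
NEGATIVE = [(), ('red',), ('red', 'red'), ('red', 'yellow'), ('yellow', 'red'),
            ('yellow', 'red', 'yellow'), ('yellow', 'yellow', 'red')]

print(is_consistent(CANDIDATE, 0, {1}, POSITIVE, NEGATIVE))  # True
```

Making state 0 accepting as well would break consistency, since the empty word is a negative example.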

In [None]:
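# Cache on the example sets only; the alphabet and stutter ordering are fixed in this notebook.
@fn.memoize(key_func=lambda accepting, rejecting, alphabet, order_by_stutter: (accepting, rejecting))
def find_dfas2(accepting, rejecting, alphabet, order_by_stutter):
    """Return (at most) the first 10 minimal DFAs consistent with the examples."""
    dfas = find_dfas(accepting, rejecting, alphabet=alphabet, order_by_stutter=order_by_stutter)
    return fn.take(10, dfas)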

In [None]:
def subset_check_wrapper(dfa_candidate):
    partial = partial_dfa(dfa_candidate.inputs)
    ce = find_subset_counterexample(dfa_candidate, partial)
    return ce is None


ALPHABET = frozenset({'red', 'yellow', 'blue', 'green'})

base_examples = LabeledExamples(
    positive=[
        ('yellow',),
        ('yellow', 'yellow'),
    ],
    negative=[
        (), ('red',), ('red', 'red'),
        ('red', 'yellow'), ('yellow', 'red'),
        ('yellow', 'red', 'yellow'),
        ('yellow', 'yellow', 'red'),
    ]
)


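@fn.memoize
def subset_cegis(data):
    """Counterexample-guided loop: add negatives until the identified DFA refines the partial spec."""
    global base_examples  # counterexamples are also persisted for future calls

    for _ in range(20):  # bounded number of refinement rounds
        mydfa = find_dfa(data.positive, data.negative, order_by_stutter=True)
        if mydfa is None:
            raise ConceptIdException
        partial = partial_dfa(mydfa.inputs)
        ce = find_subset_counterexample(mydfa, partial)
        if ce is None:  # mydfa is a subset of the partial spec: done
            break
        base_examples @= LabeledExamples(negative=[ce])
        data @= LabeledExamples(negative=[ce])

        # Every prefix of ce that the partial spec rejects must be rejected as well.
        for k, lbl in enumerate(partial.transduce(ce)):
            prefix = ce[:k]
            if not lbl:
                base_examples @= LabeledExamples(negative=[prefix])
                data @= LabeledExamples(negative=[prefix])
    return data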


def to_concept(data):
    # Drop uninformative white tiles and merge in the accumulated base examples.
    data = data.map(ignore_white) @ base_examples
    data = subset_cegis(data)

    concept = DFAConcept.from_examples(data, subset_check_wrapper, alphabet=ALPHABET, find_dfas=find_dfas2)
    # Scale the size term down by 100 so that surprisal dominates the energy.
    return attr.evolve(concept, size=concept.size / 100)
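`find_subset_counterexample` looks for a word that the candidate accepts but the partial spec rejects. A minimal sketch of how such a search can work: breadth-first search over the product of the two automata, returning a shortest witness. This illustrates the idea only and is not necessarily how `dfa.utils` implements it; `subset_counterexample` and `partial_step` are hypothetical helpers. In `subset_cegis`, each witness, together with its prefixes that the partial spec rejects, is added as a negative example and identification is rerun.

```python
from collections import deque


def subset_counterexample(start_a, step_a, acc_a, start_b, step_b, acc_b, alphabet):
    """Return a shortest word accepted by A but rejected by B, or None if L(A) ⊆ L(B).

    Each automaton is a start state, a step function (state, symbol) -> state, and an
    accepting predicate. States must be hashable.
    """
    start = (start_a, start_b)
    parent = {start: None}  # product state -> (previous product state, symbol)
    queue = deque([start])
    while queue:
        node = queue.popleft()
        qa, qb = node
        if acc_a(qa) and not acc_b(qb):
            # Walk parent pointers back to the start to recover the witness word.
            word = []
            while parent[node] is not None:
                node, symbol = parent[node]
                word.append(symbol)
            return tuple(reversed(word))
        for symbol in sorted(alphabet):
            nxt = (step_a(qa, symbol), step_b(qb, symbol))
            if nxt not in parent:
                parent[nxt] = (node, symbol)
                queue.append(nxt)
    return None


def partial_step(s, c):
    """Bitmask update of the partial spec: bit 0 = saw red, bit 1 = saw yellow."""
    return s | (0b01 if c == 'red' else 0b10 if c == 'yellow' else 0)


COLORS = {'red', 'yellow', 'blue', 'green'}

# The candidate "eventually yellow" (state = yellow seen) is too permissive for the partial spec.
ce = subset_counterexample(
    False, lambda s, c: s or c == 'yellow', bool,
    0b00, partial_step, lambda s: s == 0b10,
    COLORS,
)
print(ce)  # ('red', 'yellow')
```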



In [None]:
from diss.dfa_concept import remove_stutter
from collections import defaultdict

# adapted from the dfa library
def get_dot(dfa_):
    dfa_dict, init = dfa.dfa2dict(dfa_)
    remove_stutter(dfa_dict)
    g = pydot.Dot(rankdir="LR")

    nodes = {}
    for i, (k, (v, _)) in enumerate(dfa_dict.items()):
        shape = "doublecircle" if v else "circle"
        nodes[k] = pydot.Node(i+1, label=f"{k}", shape=shape)
        g.add_node(nodes[k])

    edges = defaultdict(list)
    for start, (_, transitions) in dfa_dict.items():        
        for action, end in transitions.items():
            color = COLOR_ALIAS[str(action)]
            edges[start, end].append(color)
    
    init_node = pydot.Node(0, shape="point", label="")
    g.add_node(init_node)
    g.add_edge(pydot.Edge(init_node, nodes[init]))

    for (start, end), colors in edges.items():
        for color in colors:
            g.add_edge(pydot.Edge(nodes[start], nodes[end], label='⬛', fontcolor=color))
            
    return g

def view_pydot(pdot):
    """Render a pydot graph inline as a PNG (avoids shadowing matplotlib's `plt`)."""
    display(Image(pdot.create_png()))

# Simulated Annealing + SGGS

In [None]:
from diss import diss

In [None]:
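# Cache one chain per DFA; the tree `t` and `psat` are assumed fixed across calls.
@fn.memoize(key_func=lambda c, t, psat: c.dfa)
def to_chain(c, t, psat):
    """Build the product Markov chain (`ProductMC`) of concept `c` and the gridworld dynamics."""
    chain = ProductMC.construct(
        concept=c, tree=t, dyn=gw, max_depth=9,
        psat=psat, sensor=gw.sensor
    )
    return chain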

In [None]:
n_iters = 50


dfa_search = diss(
    demos=demos,
    to_concept=to_concept,
    to_chain=to_chain,
    competency=lambda *_: 0.8,
    lift_path=lambda x: ignore_white(map(gw.sensor, x)),
    n_iters=n_iters,
    reset_period=5,
)

concept2energy = {}
for _, (data, concept, metadata) in zip(trange(n_iters, desc='DISS'), dfa_search):
    concept2energy[concept] = metadata['energy']
    poi = metadata['poi']
    view_pydot(get_dot(concept.dfa))
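As a generic illustration of the simulated-annealing half of this search, here is the textbook Metropolis acceptance rule with a cooling schedule. This is a sketch under our own assumptions: `metropolis_accept` and `geometric_schedule` are hypothetical helpers, and we do not claim that `diss` accepts proposals or cools its temperature exactly this way. Lower energy is better, consistent with how energies are turned into probabilities (proportional to exp(-energy)) in the next cell.

```python
import math
import random


def metropolis_accept(energy_current, energy_proposed, temperature, rng):
    """Always accept downhill moves; accept uphill moves with probability exp(-dE / T)."""
    delta = energy_proposed - energy_current
    if delta <= 0:
        return True
    return rng.random() < math.exp(-delta / temperature)


def geometric_schedule(t0, alpha, n):
    """Hypothetical cooling schedule: T_k = t0 * alpha**k for k = 0..n-1."""
    return [t0 * alpha ** k for k in range(n)]
```

High temperatures let the search escape poor local minima by occasionally accepting worse concepts; as the temperature drops, the rule becomes nearly greedy.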

In [None]:
sorted_concepts = sorted(list(concept2energy), key=concept2energy.get)
pmf = np.array([np.exp(-concept2energy[c]) for c in sorted_concepts])
pmf /= pmf.sum()
cmf = np.cumsum(pmf)
idx = int(np.searchsorted(cmf, 0.8)) + 1  # fewest top concepts holding at least 80% of the probability mass.

for p, c in zip(pmf, sorted_concepts[:idx]):
    print(f'prob = {p:.3}')
    print(f'energy = {concept2energy[c]:.3}')
    print(f'size = {c.size:.3}')
    view_pydot(get_dot(c.dfa))

# Plot CDF of distribution.
sns.lineplot(x=list(range(len(cmf))), y=cmf)
plt.xlabel('DFA index (sorted by probability mass)')
plt.ylabel('CMF')
plt.show()
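The cell above turns energies into a Boltzmann distribution, p(c) proportional to exp(-E(c)), and keeps the most probable concepts covering roughly 80% of the mass. A stdlib-only sketch of the same arithmetic on made-up energies (the numbers are illustrative, not results from this notebook):

```python
import bisect
import math
from itertools import accumulate


def boltzmann(energies):
    """Normalize exp(-E) into a probability mass function."""
    e_min = min(energies)  # shift for numerical stability; cancels after normalization
    weights = [math.exp(-(e - e_min)) for e in energies]
    total = sum(weights)
    return [w / total for w in weights]


def top_mass_count(probs, mass=0.8):
    """Smallest k such that the first k entries (sorted by decreasing probability) hold >= `mass`."""
    cmf = list(accumulate(probs))
    return min(bisect.bisect_left(cmf, mass) + 1, len(probs))


energies = [1.0, 1.5, 3.0, 4.0]  # illustrative values, sorted ascending
probs = boltzmann(energies)
k = top_mass_count(probs)
```

Here the two lowest-energy concepts already carry about 90% of the mass, so `k` is 2.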

In [None]:
# Compute the current support's belief that each unlabeled string of interest is accepted.
p_accept = {}
for word in poi:
    votes = np.array([(word in c) for c in sorted_concepts])
    p_accept[word] = pmf @ votes
    
sorted_poi = sorted(poi, key=lambda x: -p_accept[x])

for word in sorted_poi:
    print(f'{word}'.ljust(50) + f'{p_accept[word]:.2}')
    print('-' * 54)

In [None]:
accepting, rejecting = set(), set()
for word in poi:
    belief = p_accept[word]
    if belief > 0.8:
        accepting.add(word)
    elif belief < 0.1:
        rejecting.add(word)
data = LabeledExamples(accepting, rejecting)
view_pydot(get_dot(to_concept(data).dfa))
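The two cells above amount to a weighted vote followed by thresholding: a word's acceptance belief is the total probability of the concepts that accept it, words above 0.8 become positives, words below 0.1 become negatives, and the rest stay unlabeled. A dependency-free sketch, where concepts are hypothetical stand-in predicates rather than `DFAConcept` objects:

```python
def acceptance_belief(word, concepts, probs):
    """P(word accepted) = total probability of the concepts that accept it."""
    return sum(p for accepts, p in zip(concepts, probs) if accepts(word))


def label_by_belief(words, concepts, probs, hi=0.8, lo=0.1):
    """Confident words become positives/negatives; uncertain ones stay unlabeled."""
    positive, negative = set(), set()
    for w in words:
        b = acceptance_belief(w, concepts, probs)
        if b > hi:
            positive.add(w)
        elif b < lo:
            negative.add(w)
    return positive, negative


# Illustrative support: two concepts with made-up probabilities.
concepts = [
    lambda w: 'yellow' in w and 'red' not in w,
    lambda w: 'yellow' in w,
]
probs = [0.9, 0.1]
words = [('yellow',), ('red', 'yellow'), ('red',)]
pos, neg = label_by_belief(words, concepts, probs)
```

With these numbers, `('red', 'yellow')` has belief 0.1, which is not strictly below the lower threshold, so it stays unlabeled.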

# Enumeration baseline

In [None]:
def sampler_factory(demos):
    return GradientGuidedSampler.from_demos(
        demos=demos,
        to_chain=to_chain,
        competency=lambda *_: 0.8,
    )

In [None]:
to_chain.invalidate_all()

In [None]:
from dfa.utils import minimize

def enumerate_dfas():
    data = LabeledExamples(
        positive=[
            ('yellow',),
            ('yellow', 'yellow'),
        ],
        negative=[
            (), ('red',), ('red', 'red'),
            ('red', 'yellow'), ('yellow', 'red'),
            ('yellow', 'red', 'yellow'),
            ('yellow', 'yellow', 'red'),
        ]
    )

    # CEGIS loop: add counterexamples as negatives until sampled candidates refine the partial spec.
    for _ in range(20):
        tests = fn.take(5, find_dfas(
            data.positive,
            data.negative,
            order_by_stutter=True,
            allow_unminimized=True,
        ))
        new_data = LabeledExamples()
        for test in tests:
            assert test is not None
            partial = partial_dfa(test.inputs)
            ce = find_subset_counterexample(test, partial)
            if ce is not None:
                new_data @= LabeledExamples(negative=[ce])
        if new_data.size == 0:  # no sampled candidate violated the partial spec
            break
        data @= new_data

    dfas = find_dfas(
        data.positive,
        data.negative,
        order_by_stutter=True,
        alphabet=ALPHABET,
        allow_unminimized=True,
    )
    # Keep only candidates that refine the partial spec, minimized before scoring.
    yield from map(minimize, filter(subset_check_wrapper, dfas))


In [None]:
sggs = sampler_factory(demos)

In [None]:
for i, d in enumerate(fn.distinct(enumerate_dfas())):
    concept = DFAConcept.from_dfa(d)
    print(f'             {i}                ')
    energy = sggs(concept)[1]['surprisal'] + concept.size / 100
    
    print(f'{energy=:.3}')
    print('--------------------------------')
    if i > 100:
        break
    view_pydot(get_dot(d))