# Demonstration Informed Specification Search: Experiment

Let's take a look at how the DISS algorithm can search for specifications by leveraging expert demonstrations.
We'll focus on learning DFAs in this case, but note that the approach is not confined to any specific concept class. To start, consider an agent operating in the following stochastic gridworld.

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/enter_lava_augmented_1.svg"
         style="height: 20em;"
     />
</figure>

## Agent Actions

The agent can attempt to move up, down, left, or right as illustrated below.

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/enter_lava_augmented_2.svg"
         style="height: 20em;"
     />
</figure>

## Stochastic Transitions

However, there is some small probability that the agent will slip downward due to wind!

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/enter_lava_augmented_3.svg"
         style="height: 20em;"
     />
</figure>
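
As a rough illustration of these dynamics, here is a minimal sketch of a single slippery step. It is not the `diss` gridworld implementation: the `step` function, the `SLIP_PROB` value, and the clamping at the walls are all assumptions made for illustration. The coordinate convention (x and y run from 1 to `dim`, with y growing downward) follows the demonstration used later in this notebook.

```python
import random

# Hypothetical value: the notebook only says the slip probability is "small".
SLIP_PROB = 1 / 32
MOVES = {'↑': (0, -1), '↓': (0, 1), '←': (-1, 0), '→': (1, 0)}


def clamp(v, dim):
    """Keep a coordinate on the board (assumed behavior at the walls)."""
    return min(max(v, 1), dim)


def step(x, y, action, dim=3, slip_prob=SLIP_PROB, rng=random):
    """Apply `action`, except that with probability `slip_prob` the agent moves down."""
    slipped = rng.random() < slip_prob
    dx, dy = MOVES['↓'] if slipped else MOVES[action]
    return clamp(x + dx, dim), clamp(y + dy, dim)


print(step(3, 2, '←', slip_prob=0.0))  # no wind: the move succeeds
print(step(3, 1, '←', slip_prob=1.0))  # certain slip: the agent is pushed down
```

Setting `slip_prob` to 0 or 1 makes the two branches deterministic, which is handy for checking the convention before sampling.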



Let's assume the agent's task can be described in terms of the colors of the tiles. **What was the agent trying to do?**


## Probably avoiding the red tiles

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/enter_lava_augmented_4.svg"
         style="height: 20em;"
     />
</figure>

## Probably trying to reach the yellow tile

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/enter_lava_augmented_5.svg"
         style="height: 20em;"
     />
</figure>

In this notebook, we illustrate learning task representations in the form of Deterministic Finite Automata (DFAs). These representations can be learned incrementally and can describe temporal tasks.

In particular, we consider a variation of the gridworld from [this slide deck](https://mjvc.me/DISS/#/77). Here the agent's task is a composition of three subtasks.

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/example_domain_1.svg"
         style="height: 20em;"
     />
</figure>

Each subtask is a regular language represented as a DFA.

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/example_domain_2_1.svg"
         style="height: 20em;"
     />
</figure>

Further, we assume that the first two subtasks are known a priori (say, from learning in another workspace),
<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/example_domain_1_2.svg"
         style="height: 20em;"
     />
</figure>

and our task is to learn the third subtask given a partial demonstration.
<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/example_domain_3_2.svg"
         style="height: 20em;"
     />
</figure>


# Preamble

In [None]:
from functools import lru_cache

import funcy as fn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from bidict import bidict
from IPython.display import Image, display
import networkx as nx
import pydot

from collections import Counter

import dfa
from dfa.utils import find_subset_counterexample, find_equiv_counterexample
from dfa_identify import find_dfa, find_dfas

from diss.product_mc import ProductMC
from diss.dfa_concept import DFAConcept
from diss.domains.gridworld_naive import GridWorldNaive as World
from diss.domains.gridworld_naive import GridWorldState as State
from diss import search, LabeledExamples, GradientGuidedSampler, ConceptIdException
from pprint import pprint
from itertools import combinations
from tqdm import tqdm_notebook
from tqdm.notebook import trange
from IPython.display import clear_output
from IPython.display import HTML as html_print

sns.set_context('paper')
sns.set_style('whitegrid')

## Let's first visualize our gridworld and a demonstration within the gridworld.

In [None]:
COLOR_ALIAS = {
    'yellow': '#ffff00', 'brown': '#ffb081',
    'red': '#ff8b8b', 'blue': '#afafff', 'green' : '#67f7a1'
}


def tile(color='black'):
    color = COLOR_ALIAS.get(color, color)
    s = '&nbsp;'*4
    return f"<text style='border: solid 1px;background-color:{color}'>{s}</text>"


def ap_at_state(x, y, world):
    """Use sensor to create colored tile."""
    if (x, y) in world.overlay:
        color = world.overlay[(x,y)]

        if color in COLOR_ALIAS.keys():
            return tile(color)
    return tile('white')

def print_map(world):
    """Scan the board row by row and print colored tiles."""
    order = range(1, world.dim + 1)
    for y in order:
        chars = (ap_at_state(x, y, world) for x in order)
        display(html_print('&nbsp;'.join(chars)))

def print_trc(trc, world, idx=0):
    states = [s for s, kind in trc if kind == 'ego']
    actions = [s.action for s, kind in trc if kind == 'env']

    obs = [ap_at_state(pos.x, pos.y, world) for pos in states]
    display(
        html_print(f'trc {idx}:&nbsp;&nbsp;&nbsp;' + ''.join(''.join(x) for x in zip(obs, actions)) + '\n')
    ) 

In [None]:
gw = World(
    dim=3,
    start=State(x=3, y=1),
    overlay={
      (1, 1): 'yellow',
      (1, 2): 'green',
      (1, 3): 'green',
      (2, 3): 'red',
      (3, 2): 'blue',
      (3, 3): 'blue',
    }
)

demos = [[
   (State(3, 1), 'ego'),
   (State(3, 1, '←'), 'env'),
   (State(3, 2), 'ego'),
   (State(3, 2, '←'), 'env'),
   (State(2, 2), 'ego'),
   (State(2, 2, '←'), 'env'),
   (State(1, 2), 'ego'),
   (State(1, 2, '↑'), 'env'),
   (State(1, 1), 'ego'),
]]

print_map(gw)
print_trc(demos[0], gw)

Compare against our target domain:

<figure style="padding: 1em; background: #191919;">
    <img src="http://mjvc.me/DISS/imgs/example_domain_3_2.svg"
         style="height: 20em;"
     />
</figure>

# Search procedure

Next, we use the expert demonstrations for this gridworld to guide our specification search procedure.

Let's start with some very simple base examples to warm-start the search. We want to synthesize a spec that is consistent with the evidence observed so far:

In [None]:
@fn.memoize(key_func=lambda c, t: c.dfa)
def to_chain(c, t):
    chain = ProductMC.construct(
        concept=c, tree=t, dyn=gw, max_depth=9, psat=0.8
    )
    return chain


def sampler_factory(demos):
    return GradientGuidedSampler.from_demos(
        demos=demos,
        to_chain=to_chain,
    )

base_examples = LabeledExamples(
    positive=[
        ('yellow',),
        ('yellow', 'yellow'),
    ],
    negative=[
        (), ('red',), ('red', 'red'),
        ('red', 'yellow'), ('yellow', 'red'),
        ('yellow', 'red', 'yellow'),
        ('yellow', 'yellow', 'red'),
    ]
)
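
To make "consistent with the observed evidence" concrete, here is a small standalone sketch. It does not use `diss` or `dfa`: the `is_consistent` helper and the two candidate predicates are hypothetical stand-ins for DFAs, and the word lists simply copy the base examples above.

```python
def is_consistent(accepts, positive, negative):
    """True if `accepts` labels every positive word True and every negative word False."""
    return all(accepts(w) for w in positive) and not any(accepts(w) for w in negative)


positive = [('yellow',), ('yellow', 'yellow')]
negative = [
    (), ('red',), ('red', 'red'),
    ('red', 'yellow'), ('yellow', 'red'),
    ('yellow', 'red', 'yellow'),
    ('yellow', 'yellow', 'red'),
]


def reach_yellow_avoid_red(word):
    # Hypothetical candidate: see yellow and never see red.
    return 'yellow' in word and 'red' not in word


def reach_yellow(word):
    # Hypothetical candidate that ignores red entirely.
    return 'yellow' in word


print(is_consistent(reach_yellow_avoid_red, positive, negative))
print(is_consistent(reach_yellow, positive, negative))
```

The second candidate fails because it accepts `('red', 'yellow')`, which the base examples label as negative.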

Going from the partial spec to a full spec:

In [None]:
def partial_dfa(inputs):
    def transition(s, c):
        if c == 'red':
            return s | 0b01
        elif c == 'yellow':
            return s | 0b10
        return s

    return dfa.DFA(
        start=0b00,
        inputs=inputs,
        label=lambda s: s == 0b10,
        transition=transition
    )

def trace(path):
    return tuple(x for x in map(gw.sensor, path) if x != 'white')
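
The state of `partial_dfa` is a two-bit mask: bit 0 records that red has been seen and bit 1 records that yellow has been seen. The only accepting state is `0b10`, i.e. yellow seen and red never seen. The sketch below replays the same transitions in plain Python, without the `dfa` library, so the encoding can be checked by hand. The `run_partial` name is ours.

```python
def run_partial(word):
    """Replay the bitmask transitions of the partial spec on a word of colors."""
    state = 0b00
    for color in word:
        if color == 'red':
            state |= 0b01   # remember that red was visited
        elif color == 'yellow':
            state |= 0b10   # remember that yellow was visited
        # any other color leaves the state unchanged
    return state == 0b10    # accept: yellow seen, red never seen


for word in [('yellow',), ('blue', 'yellow'), ('red', 'yellow'), ('yellow', 'red'), ()]:
    print(word, run_partial(word))
```

Because the bits are never cleared, touching red at any point, even after reaching yellow, makes the word rejected for good.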

Now we can outline the machinery for the search process itself. We use the DFA identification solver to synthesize a minimal DFA (in both states and non-stuttering edges) that is consistent with the examples observed so far.

In [None]:
@fn.memoize(key_func=lambda accepting, rejecting, alphabet, order_by_stutter: hash((accepting, rejecting)))
def find_dfas2(accepting, rejecting, alphabet, order_by_stutter):
    dfas = find_dfas(accepting, rejecting, alphabet=alphabet, order_by_stutter=order_by_stutter)
    dfas = fn.take(10, dfas)
    return dfas

In [None]:
def subset_check_wrapper(dfa_candidate):
    partial = partial_dfa(dfa_candidate.inputs)
    ce = find_subset_counterexample(dfa_candidate, partial)
    return ce is None


ALPHABET = frozenset({'red', 'yellow', 'blue', 'green'})


@fn.memoize
def subset_cegis(data):
    global base_examples
    for i in range(20):
        mydfa = find_dfa(data.positive, data.negative, order_by_stutter=True) 
        if mydfa is None:
            raise ConceptIdException
        partial = partial_dfa(mydfa.inputs)
        ce = find_subset_counterexample(mydfa, partial)
        if ce is None:
            break
        base_examples @= LabeledExamples(negative=[ce])
        data @= LabeledExamples(negative=[ce])

        partial = partial_dfa(mydfa.inputs)
        for k, lbl in enumerate(partial.transduce(ce)):
            prefix = ce[:k]
            if not lbl:
                base_examples @= LabeledExamples(negative=[prefix])
                data @= LabeledExamples(negative=[prefix])
    return data


def to_concept(data, skip_trace=False):
    global base_examples
    
    if not skip_trace:
        data = LabeledExamples(
            positive = frozenset([trace(x) for x in data.positive]),
            negative = frozenset([trace(x) for x in data.negative]),
        )
    data @= base_examples
    data = subset_cegis(data)

    concept = DFAConcept.from_examples(data, gw.sensor, subset_check_wrapper, alphabet=ALPHABET, find_dfas=find_dfas2) 
    return concept
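
The `subset_cegis` loop above follows a counterexample-guided pattern: propose a candidate consistent with the data, check that its language is contained in the partial spec, and if it is not, add the offending word as a negative example and try again. The toy below sketches that pattern in plain Python. Everything in it is an assumption for illustration: the candidates are a short hand-written list of predicates rather than DFAs from `dfa_identify`, the subset check is a bounded enumeration, and the extra step of adding rejected prefixes of the counterexample is omitted.

```python
from itertools import product

ALPHABET = ('red', 'yellow')


def spec(word):
    """Stand-in for the known partial specification."""
    return 'yellow' in word and 'red' not in word


# Hypothetical hypothesis class, ordered from simplest to most specific.
CANDIDATES = [
    lambda w: True,
    lambda w: len(w) > 0,
    lambda w: 'yellow' in w,
    lambda w: 'yellow' in w and 'red' not in w,
]


def learn(positive, negative):
    """Return the first candidate consistent with the labeled words, or None."""
    for h in CANDIDATES:
        if all(h(w) for w in positive) and not any(h(w) for w in negative):
            return h
    return None


def subset_counterexample(h, max_len=3):
    """Shortest word (up to max_len) accepted by h but rejected by the spec."""
    for n in range(max_len + 1):
        for w in product(ALPHABET, repeat=n):
            if h(w) and not spec(w):
                return w
    return None


def cegis(positive, negative, max_rounds=20):
    negative = set(negative)
    for _ in range(max_rounds):
        h = learn(positive, negative)
        if h is None:
            raise ValueError('no consistent candidate')
        ce = subset_counterexample(h)
        if ce is None:
            return h, negative
        negative.add(ce)  # rule out this word in the next round
    raise RuntimeError('round budget exhausted')


h, negatives = cegis(positive=[('yellow',)], negative=[])
print(sorted(negatives))
```

Each round removes at least one candidate, so on this finite toy class the loop stops once a candidate inside the spec is found.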



In [None]:
from diss.dfa_concept import remove_stutter
from collections import defaultdict

# adapted from the dfa library
def get_dot(dfa_):
    dfa_dict, init = dfa.dfa2dict(dfa_)
    remove_stutter(dfa_dict)
    g = pydot.Dot(rankdir="LR")

    nodes = {}
    for i, (k, (v, _)) in enumerate(dfa_dict.items()):
        shape = "doublecircle" if v else "circle"
        nodes[k] = pydot.Node(i+1, label=f"{k}", shape=shape)
        g.add_node(nodes[k])

    edges = defaultdict(list)
    for start, (_, transitions) in dfa_dict.items():        
        for action, end in transitions.items():
            color = COLOR_ALIAS[str(action)]
            edges[start, end].append(color)
    
    init_node = pydot.Node(0, shape="point", label="")
    g.add_node(init_node)
    g.add_edge(pydot.Edge(init_node, nodes[init]))

    for (start, end), colors in edges.items():
        for color in colors:
            g.add_edge(pydot.Edge(nodes[start], nodes[end], label='⬛', fontcolor=color))
            
    return g

def view_pydot(pdot):
    # Avoid reusing the name `plt`, which refers to matplotlib.pyplot at module level.
    img = Image(pdot.create_png())
    display(img)

# Simulated Annealing + SGGS
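
The inner loop of the cell below combines a linear cooling schedule with a Metropolis-style accept/reject test on the energy difference between the new concept and the previous one. Here is that test in isolation, as a minimal sketch. The function names `temperature` and `accept` are ours; the constants mirror the ones used in the loop.

```python
import math
import random


def temperature(t, total, t_max=10.0, t_min=0.01):
    """Linear cooling: starts near t_max and ends at t_min when t == total."""
    return t_max * (1 - t / total) + t_min


def accept(d_energy, temp, rng=random):
    """Always accept improvements; accept a worse neighbor with probability exp(-dE / temp)."""
    if d_energy <= 0:
        return True
    return rng.random() <= math.exp(-d_energy / temp)


total = 25  # n_iters * n_sggs_trials in the loop
for t in (1, 13, 25):
    temp = temperature(t, total)
    print(f't={t:2d}  temp={temp:.2f}  P(accept dE=1)={math.exp(-1 / temp):.3f}')
```

Early on, the high temperature lets the search accept worse concepts and explore. As the temperature drops, the test becomes nearly greedy.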

In [None]:
example_sampler = sampler_factory(demos)

unlabeled = set()
dfa_dist = {}
n_iters = 5
n_sggs_trials = 5
t = 0
labeled = LabeledExamples()
for i in trange(n_iters, desc="Number of times to restart ----"):
    prev_energy = float('inf')
    for j in trange(n_sggs_trials, desc='SGGS + Simulated Annealing', leave=False):
        t += 1
        # 1. Set temperature.
        temp = 10*(1 - t / (n_iters*n_sggs_trials)) + 0.01
        
        # 2. Pick Neighbor.
        try:
            concept = to_concept(labeled, skip_trace=True)
            new_data, metadata = example_sampler(concept)
            dfa_dist[concept.dfa] = metadata['surprisal'] + concept.size / 100
            new_data = LabeledExamples(
                positive=frozenset(map(trace, new_data.positive)),
                negative=frozenset(map(trace, new_data.negative)),
            )
            unlabeled |= new_data.positive | new_data.negative
        except ConceptIdException:
            break

        # Compute neighbor energy and energy difference.
        energy = metadata['surprisal'] + concept.size / 100
        dE = energy - prev_energy
            
        # Accept/Reject based on energy delta.
        if (dE > 0) and (np.exp(-dE/temp) < np.random.rand()):
            energy = prev_energy
        else:
            labeled @= new_data
            #view_pydot(get_dot(concept.dfa))

        prev_energy = energy
        
    # 2. Compute CDF + Normalizer
    sorted_dfas = sorted(list(dfa_dist), key=lambda x: dfa_dist[x])
    Z1 = sum(np.exp(-x) for x in dfa_dist.values())
    cdf = [0]
    r = None  # Index of the first DFA whose cumulative mass exceeds 80%.
    for k, dfa_ in enumerate(sorted_dfas):
        pdf = np.exp(-dfa_dist[dfa_]) / Z1
        cdf.append(cdf[-1] + pdf)
        if r is None and cdf[-1] > .8:
            r = k

    # 3. Compute distinguishing strings for top 80%.
    for dfa1, dfa2 in combinations(sorted_dfas[:r], 2):
        ce = find_equiv_counterexample(dfa1, dfa2)
        if ce is not None:  # Skip pairs with no distinguishing string.
            unlabeled.add(ce)

    # 4. Compute current support's belief on unlabeled strings of interest.
    weighted_words = defaultdict(lambda: 0)
    beta =  (1 - i / (n_iters)) + 0.01
    Z2 = sum(np.exp(-x / beta) for x in dfa_dist.values())
    for word in unlabeled:
        for dfa_, energy in dfa_dist.items():
            pdfa = np.exp(-energy / beta) / Z2
            weighted_words[word] += pdfa * dfa_.label(word)

    # 5. Restart based on marginalizing over current concept class.
    positive, negative = set(), set()
    for x, weight in weighted_words.items():
        confidence = 2*(weight - 0.5 if weight > 0.5 else 0.5 - weight)
        if np.random.rand() > confidence:
            continue
        if weight < 0.5:
            negative.add(x)
        elif weight > 0.5:
            positive.add(x)
    labeled = LabeledExamples(positive, negative)  


    #continue  # Comment out to show live CDF
    # Plot CDF of distribution.
    sns.lineplot(x=list(range(len(cdf))), y=cdf)
    plt.xlabel('DFA index (sorted by probability mass)')
    plt.ylabel('CDF')
    plt.show()


print('=====================================')
print('   Predicting Labeled Examples       ')
print('=====================================')
positive, negative = set(), set()
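for x, weight in weighted_words.items():
    # Recompute the confidence per word instead of reusing the stale value from the loop above.
    confidence = 2 * abs(weight - 0.5)
    if np.random.rand() > confidence:
        continue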
    if weight < 0.2:
        negative.add(x)
    elif weight > 0.8:
        positive.add(x)
labeled = LabeledExamples(positive, negative)
print(labeled)


print('=====================================')
print('          Top 80% of DFAs            ')
print('=====================================')
for dfa_, cd in zip(sorted_dfas, cdf[1:]):
    if cd > 0.8:
        break
    view_pydot(get_dot(dfa_))
    print(f'probability = {np.exp(-dfa_dist[dfa_]) / Z1:.4}')
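
Steps 4 and 5 of the loop above turn the energies in `dfa_dist` into a Boltzmann distribution and use it to vote on each unlabeled word. The standalone sketch below shows that computation on a toy example. The two named predicates and their energies are made-up stand-ins for the DFAs found by the search, and the helper names are ours.

```python
import math


def boltzmann(energies, beta=1.0):
    """Map each hypothesis name to exp(-E / beta) / Z."""
    weights = {name: math.exp(-e / beta) for name, e in energies.items()}
    z = sum(weights.values())
    return {name: w / z for name, w in weights.items()}


def belief(word, hypotheses, probs):
    """Total probability mass of the hypotheses that accept `word`."""
    return sum(probs[name] for name, accepts in hypotheses.items() if accepts(word))


def confidence(weight):
    """0 when the vote is split evenly, 1 when all of the mass agrees."""
    return 2 * abs(weight - 0.5)


# Hypothetical hypotheses and energies, for illustration only.
hypotheses = {
    'avoid_red': lambda w: 'yellow' in w and 'red' not in w,
    'any_yellow': lambda w: 'yellow' in w,
}
energies = {'avoid_red': 1.0, 'any_yellow': 3.0}

probs = boltzmann(energies)
for word in [('yellow',), ('red', 'yellow'), ('red',)]:
    weight = belief(word, hypotheses, probs)
    print(word, round(weight, 3), round(confidence(weight), 3))
```

Words on which the hypotheses agree get a confidence of 1 and are always relabeled, while contested words are relabeled only some of the time.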