# Tree Search to Differentiable First-Order Planning
This is a notebook describing some of my thoughts on how to take a classic Artificial Intelligence planning technique and augment it using continuous representations. The resulting design will allow us to backpropagate through a significant portion of the planner.

For this article, we'll be considering the example of a robot operating in a typical supermarket.

We can model a store like that as a fixed number of locations (aisles, front segments, and back segments):

```python
from pprint import pprint
A = 5
aisles = [f'Aisle{a}' for a in range(A)]
front = [f'Front{s}' for s in range(A)]
back = [f'Back{s}' for s in range(A)]
# Each aisle connects to one front and one back segment; the fronts and backs form chains.
adjacent = (frozenset(zip(aisles, front))
            .union(zip(aisles, back))
            .union((front[i], front[i+1]) for i in range(len(front) - 1))
            .union((back[i], back[i+1]) for i in range(len(back) - 1)))
# Store both directions of every edge.
adjacent = adjacent.union((e, s) for s, e in adjacent)
#pprint(adjacent)
```
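As a quick sanity check on this model, the sketch below rebuilds the same relation with set operators and counts its edges: each of the `A` aisles touches one front and one back segment, and the front and back rows are chains of `A - 1` edges each. The names `edges`, `directed`, and `neighbours` are illustrative only and are not used elsewhere in the notebook.

```python
A = 5
aisles = [f'Aisle{a}' for a in range(A)]
front = [f'Front{s}' for s in range(A)]
back = [f'Back{s}' for s in range(A)]

# Undirected edges: aisle-front, aisle-back, and the two chains along the front and back.
edges = (set(zip(aisles, front))
         | set(zip(aisles, back))
         | {(front[i], front[i + 1]) for i in range(A - 1)}
         | {(back[i], back[i + 1]) for i in range(A - 1)})
# Store both directions, as `adjacent` does above.
directed = edges | {(e, s) for s, e in edges}

def neighbours(loc):
    '''Locations reachable from `loc` in one move, sorted for readability.'''
    return sorted(e for s, e in directed if s == loc)

# 2A aisle edges plus (A - 1) edges along each of the front and back rows.
assert len(edges) == 2 * A + 2 * (A - 1)
assert len(directed) == 2 * len(edges)
print(len(edges), len(directed))  # 18 36
print(neighbours('Front1'))       # ['Aisle1', 'Front0', 'Front2']
```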

If we want to plan in such an environment, we could implement a simple fixed-depth tree search as follows. It uses depth-first search with early termination, and represents plans as sequences of states. For practical purposes, we could use a wide variety of more efficient search methods (MCTS, A*, etc.). However, the focus of this post isn't the search method, but the state transition function, `step`. For any search method there is an analogous function, so coming up with alternative formulations of the function is relevant to all search methods.

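```python
def search(state, max_len):
    '''Enumerate every plan of at most max_len steps, reporting how many states were visited.'''
    global visit_count
    visit_count = 0
    all_plans = list(plans(state, max_len))
    print('visited', visit_count, 'states')
    print(len(all_plans), 'total plans')
    return all_plans

visit_count = 0

def plans(state, max_len):
    '''Depth-limited DFS: yield every plan (a tuple of states) from `state` that reaches a goal.'''
    global visit_count
    visit_count += 1
    if goal(state):
        yield (state,)
    if max_len > 0:
        # `step` yields one generator of tail plans per successor state.
        for tails in step(state, max_len - 1):
            for tail in tails:
                yield (state,) + tail

def step(state, max_len):
    '''The state is just a location; its successors are all adjacent locations.'''
    return (plans(e, max_len) for (s, e) in adjacent
            if s == state)

def goal(state):
    return state == aisles[3]

search(aisles[0], 7)[0]
```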

```
visited 589 states
26 total plans

('Aisle0', 'Back0', 'Aisle0', 'Back0', 'Back1', 'Back2', 'Back3', 'Aisle3')
```
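For comparison, here is a minimal breadth-first search over the same single-robot map. This is a swapped-in technique, not the notebook's `search`, and `bfs_plan` is a hypothetical helper. Because it records every location it has already reached, it never expands a location twice and returns a shortest plan, unlike the first plan above, which wanders back and forth. The trade-off is that it keeps the whole visited set in memory and finds one plan rather than enumerating all of them.

```python
from collections import deque

A = 5
aisles = [f'Aisle{a}' for a in range(A)]
front = [f'Front{s}' for s in range(A)]
back = [f'Back{s}' for s in range(A)]
edges = (set(zip(aisles, front)) | set(zip(aisles, back))
         | {(front[i], front[i + 1]) for i in range(A - 1)}
         | {(back[i], back[i + 1]) for i in range(A - 1)})
directed = edges | {(e, s) for s, e in edges}

def successors(loc):
    return sorted(e for s, e in directed if s == loc)

def bfs_plan(start, is_goal, successors):
    '''Breadth-first search: return a shortest plan (tuple of states), or None.'''
    parent = {start: None}           # doubles as the visited set
    frontier = deque([start])
    while frontier:
        state = frontier.popleft()
        if is_goal(state):
            plan = []
            while state is not None:  # follow parent links back to the start
                plan.append(state)
                state = parent[state]
            return tuple(reversed(plan))
        for nxt in successors(state):
            if nxt not in parent:
                parent[nxt] = state
                frontier.append(nxt)
    return None

plan = bfs_plan('Aisle0', lambda loc: loc == 'Aisle3', successors)
print(plan)  # ('Aisle0', 'Back0', 'Back1', 'Back2', 'Back3', 'Aisle3')
```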

It may seem that the `step` function plays a relatively small role and can't affect performance. However, since it controls which states get expanded, it can.

For example, here is a small improvement that prevents needless backtracking:

```python
def step(state, max_len):
    # A state is now (last action, location).
    last_act, loc = state
    return (plans((('Move', s, e), e), max_len)
            for s, e in adjacent
            if s == loc
                and last_act != ('Move', e, s) # Don't allow moving backwards.
           )

def goal(state):
    return state[-1] == aisles[3]

search(('Start', aisles[0]), 7); # Remove the semicolon to see the output.
```

```
visited 61 states
6 total plans
```


Partly, this method does well because the branching factor is very small. If we were managing multiple robots, that would make the branching factor much higher, and the search method less efficient.

Here's how we might model that:

```python
from collections import defaultdict
def update_state(state, to_add):
    '''Return a copy of `state` with the entries in `to_add` overwritten.'''
    state = state.copy()
    state.update(to_add)
    return state

R = 8
robots = list(range(R))
def step(state, max_len):
    # The state is a dict keyed by (attribute, robot); missing keys default to ().
    # Exactly one robot moves per step, so the branching factor grows with the number of robots.
    p = lambda a: update_state(state, a)
    return (plans(p({('last_act', rob): ('Move', s, e), ('loc', rob): e}), max_len)
            for rob in robots
            for s, e in adjacent
            if state[('loc', rob)] == s and state[('last_act', rob)] != ('Move', e, s))

def goal(state):
    return any(state[('loc', r)] == aisles[3] for r in robots)
# Robots start spread over aisles 0, 1, and 2.
start_state = defaultdict(tuple, {('loc', r): aisles[r % 3] for r in robots})

search(start_state, 5); # Remove the semicolon to see the output.
```

```
visited 910311 states
8260 total plans
```
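A rough count shows why the number of visited states explodes. The sketch below assumes, as holds in this map once backtracking is forbidden, that every robot has at most 2 legal moves (aisles have 2 neighbours, and other segments have at most 3, one of which is the way back), and that exactly one robot moves per step. The helper `tree_size` is illustrative only: it counts the nodes of a complete search tree with that branching factor, which upper-bounds what the depth-limited search can visit.

```python
def tree_size(branching, depth):
    '''Nodes in a complete search tree with uniform branching (root at depth 0).'''
    return sum(branching ** k for k in range(depth + 1))

# Assumption: at most 2 non-backtracking moves per robot, one robot moving per step.
MOVES_PER_ROBOT = 2

single = tree_size(1 * MOVES_PER_ROBOT, 7)   # one robot, plans of up to 7 moves
multi = tree_size(8 * MOVES_PER_ROBOT, 5)    # eight robots, plans of up to 5 moves
print(single, multi)  # 255 1118481
```

These bounds are consistent with the 61 and 910311 visits reported above: multiplying the number of robots multiplies the branching factor, and the tree grows with that factor raised to the search depth.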


Clearly, the method performs much worse with a higher branching factor. Fortunately, finding one plan can still be done quickly:

```python
def sample(state, max_len):
    '''Return only the first plan the depth-first search finds (or None).'''
    global visit_count
    visit_count = 0
    try:
        result = next(plans(state, max_len))
    except StopIteration:
        result = None
    print('visited', visit_count, 'states')
    return result
sample(start_state, 5); # Remove the semicolon to see the output.
```

```
visited 260 states
```


We'll be taking single plans (or "sampling") from now on to avoid waiting. In this case, we get a non-optimal plan. However, if we used a more advanced search method (e.g. branch-and-bound), we would have similar performance and get a near-optimal plan.
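The branch-and-bound idea mentioned above can be sketched in a few lines for the single-robot map, assuming plan cost is simply the number of moves. This is a hypothetical `branch_and_bound` helper, not code from the notebook: it is the same depth-first traversal, except that it remembers the best complete plan found so far and prunes any partial plan that can no longer beat it.

```python
A = 5
aisles = [f'Aisle{a}' for a in range(A)]
front = [f'Front{s}' for s in range(A)]
back = [f'Back{s}' for s in range(A)]
edges = (set(zip(aisles, front)) | set(zip(aisles, back))
         | {(front[i], front[i + 1]) for i in range(A - 1)}
         | {(back[i], back[i + 1]) for i in range(A - 1)})
directed = edges | {(e, s) for s, e in edges}

def successors(loc):
    return sorted(e for s, e in directed if s == loc)

def branch_and_bound(start, is_goal, successors, max_len):
    '''Depth-first branch-and-bound returning (shortest plan, states visited).

    Cost is the number of moves. A partial plan is pruned as soon as one more
    move could no longer beat the best complete plan found so far.
    '''
    best = None
    visited = 0

    def expand(path):
        nonlocal best, visited
        visited += 1
        if is_goal(path[-1]):
            if best is None or len(path) < len(best):
                best = path                 # new incumbent
            return                          # extending a finished plan only adds cost
        if len(path) - 1 >= max_len:
            return                          # depth limit reached
        if best is not None and len(path) + 1 >= len(best):
            return                          # bound: cannot beat the incumbent
        for nxt in successors(path[-1]):
            if nxt not in path:             # never revisit a state within one plan
                expand(path + (nxt,))

    expand((start,))
    return best, visited

plan, visited = branch_and_bound('Aisle0', lambda loc: loc == 'Aisle3', successors, 7)
print(plan, visited)
```

The pruning only pays off once a first plan has been found, which is why good move ordering (or a heuristic) matters so much for this family of methods.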

We can also imagine modelling more complex environments, with more actions.
Let's add some carts, which impede movement unless pushed to a destination.

```python
def step(state, max_len):
    p = lambda a: update_state(state, a)
    # Move: a robot steps to an adjacent location that holds no cart.
    yield from (plans(p({('last_act', rob): ('Move', s, e), ('loc', rob): e}), max_len)
             for rob in robots
             for s, e in adjacent
             if state[('loc', rob)] == s
                and state[('last_act', rob)] != ('Move', e, s)
                and not any(state[('cart_loc', c)] == e for c in carts))
    # PushCart: a robot at s1 pushes the cart at e1 onward to e2, and takes its place at e1.
    yield from (plans(p({('last_act', rob): ('PushCart', e1, e2),
                      ('loc', rob): e1,
                      ('cart_loc', c1): e2}), max_len)
             for rob in robots
             for s1, e1 in adjacent
             for s2, e2 in adjacent
             for c1 in carts
             if state[('loc', rob)] == s1
                and state[('cart_loc', c1)] == e1
                and e1 == s2
                and s1 != e2
                and not any(state[('loc', r2)] == e2 for r2 in robots)
                and not any(state[('cart_loc', c2)] == e2 for c2 in carts))

robots = [0, 1, 2]
carts = [0, 1, 2]
cart_start_state = defaultdict(tuple,
                               {('loc', 0): aisles[0],
                                ('loc', 1): aisles[1],
                                ('loc', 2): aisles[2],
                                ('cart_loc', 0): aisles[3],
                                ('cart_loc', 1): back[3],
                                ('cart_loc', 2): front[3]
                               })
sample(cart_start_state, 7)
```

```
visited 61495 states

(defaultdict(tuple,
             {('loc', 0): 'Aisle0',
              ('loc', 1): 'Aisle1',
              ('loc', 2): 'Aisle2',
              ('cart_loc', 0): 'Aisle3',
              ('cart_loc', 1): 'Back3',
              ('cart_loc', 2): 'Front3',
              ('last_act', 0): (),
              ('last_act', 1): ()}),
 defaultdict(tuple,
             {('loc', 0): 'Aisle0',
              ('loc', 1): 'Back1',
              ('loc', 2): 'Aisle2',
              ('cart_loc', 0): 'Aisle3',
              ('cart_loc', 1): 'Back3',
              ('cart_loc', 2): 'Front3',
              ('last_act', 0): (),
              ('last_act', 1): ('Move', 'Aisle1', 'Back1')}),
 defaultdict(tuple,
             {('loc', 0): 'Aisle0',
              ('loc', 1): 'Back2',
              ('loc', 2): 'Aisle2',
              ('cart_loc', 0): 'Aisle3',
              ('cart_loc', 1): 'Back3',
              ('cart_loc', 2): 'Front3',
              ('last_act', 0): (),
              ('last_act', 1): ('Move', 'Back1',
```

This has some neat properties. The robots automatically "collaborate," since the planner can reason about all of them.

However, unless your computer is extremely fast, you probably noticed a delay before a single plan was sampled. Adding more robots would make this even worse. Notably, the number of states we're visiting is still small (around 60000) compared to what we were looking at before (910311), so the slowdown comes from `step`, which has become much more expensive to compute.

`step` has also gotten rather large, even though it has a regular structure.

This particular regular structure is easy to translate into a set of logic clauses (what a coincidence!).
Not only would this allow expressing our actions more succinctly, it would allow some very non-obvious optimizations.
For example, avoiding impossible parts of the search space, searching "backwards" from all goal states, searching abstractly over all robots' actions at once, etc. We could even specify heuristics in the same relational language, making all of these optimizations more effective.

A very simple optimization this allows is using hash-joins in place of our naive nested-loop joins (the nested `for` loops). We'll implement that below.

The last `step` function above could be expressed in an extended Prolog / Datalog syntax as:
```prolog
-last_act(R, _, _, _), last_act(R, "move", S, E), -loc(R, _), loc(R, E) :-
    loc(R, S),
    adjacent(S, E),
    -last_act(R, "move", E, S),
    -cart_loc(C, E).

-last_act(R, _, _, _), last_act(R, "push_cart", E1, E2), -loc(R, _), loc(R, E1),
        -cart_loc(C, _), cart_loc(C, E2) :-
    loc(R, L),
    adjacent(L, E1),
    adjacent(E1, E2),
    L != E2,
    cart_loc(C, E1),
    -loc(R2, E2),
    -cart_loc(C2, E2).
```
This syntax is not ideal for this application, but is the most common syntax for logic programming. For more commonly used syntaxes for planning specifically, I recommend reading about [STRIPS](https://en.wikipedia.org/wiki/STRIPS#A_sample_STRIPS_problem), [Action Description Language](https://en.wikipedia.org/wiki/Action_description_language#Example), or [Planning Domain Definition Language](https://en.wikipedia.org/wiki/Planning_Domain_Definition_Language#Example).
In this syntax, the effects are on the left of the implication sign (`:-`): a negated atom such as `-loc(R, _)` deletes every matching fact, and a positive atom such as `loc(R, E)` adds a fact. On the right of the implication sign, all positive atoms (e.g. `loc(R, L)` or `L != E2`) must be true, all negative atoms (e.g. `-loc(R2, E2)`) must not be true for any values of the atom's unique variables, and all repeated uses of variables are implicit equality constraints.
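To make the head (effect) semantics concrete, here is a minimal sketch that assumes a state is a set of ground facts, each a tuple whose first element is the predicate name. Deletions are applied before additions, which is the same order the compiled `step_move` below uses. `WILD`, `matches`, and `apply_effects` are hypothetical names; `WILD` stands in for Prolog's `_`.

```python
WILD = '_'  # hypothetical wildcard marker standing in for Prolog's `_`

def matches(pattern, fact):
    '''True if `fact` matches `pattern`, where WILD matches any value.'''
    return len(pattern) == len(fact) and all(
        p == WILD or p == f for p, f in zip(pattern, fact))

def apply_effects(facts, deletes, adds):
    '''Apply a rule head: drop facts matching any `-` atom, then add the `+` atoms.'''
    kept = {f for f in facts if not any(matches(d, f) for d in deletes)}
    return kept | set(adds)

# Robot 0 moves from Aisle0 to Front0 (the first clause with R=0, S=Aisle0, E=Front0).
state = {('loc', 0, 'Aisle0'), ('last_act', 0, 'start', None, None),
         ('loc', 1, 'Aisle1')}
new_state = apply_effects(
    state,
    deletes=[('last_act', 0, WILD, WILD, WILD), ('loc', 0, WILD)],
    adds=[('last_act', 0, 'move', 'Aisle0', 'Front0'), ('loc', 0, 'Front0')])
print(sorted(new_state, key=repr))
```

Robot 1's fact survives untouched because the delete patterns are bound to `R = 0`.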

Below is how an extremely simple Datalog frontend would have compiled our step function into backend operations defined further below. It isn't necessary to understand these functions in detail. Essentially, each clause is its own subfunction, where all constraints are enforced using joins. Then, effects are generated by iterating over all of the rows.

I've commented each line with the constraint it enforces, the resulting layout of the temporary table (if it changed), and each effect (when relevant).

I've also randomized the order of the next states. The hash function makes the search order non-deterministic, so we may as well make it fully random. This wasn't easy to do with the previous definition of the step function, since we never built a representation of "all next states." However, this is still depth-first search.

```python
def step(state, max_len):
    # Each clause of the logic program compiles to its own sub-function.
    yield from step_move(state, max_len)
    yield from step_push(state, max_len)

def step_move(state, max_len):
    t = outer_eq_join(gen_index(state['loc'], 1), fixed['adj'][0]) # L == S: R,L,S,E
    t = outer_eq_join(gen_index(t, 0), gen_index(state['last_act'], 0)) # R == Q: R,L,S,E,Q,A,T,Y
    t = constant(t, 'move') # R,L,S,E,Q,A,T,Y,"move"
    # -last_act(R, "move", E, S). After a move, Y == S always holds (the robot is where it
    # last moved to), so it suffices to compare the last start T with the new end E.
    t = filter_ne(t, (3, 5), (6, 8))  # E != T || A != "move"
    t = left_eq_ajoin(gen_index(t, 3), gen_index(state['cart_loc'], 1)) # E == no D
    shuffle_table(t)
    for row in t:
        new_state = state.copy()
        # -last_act(R, _, _, _)
        new_state['last_act'] = filter_eq_key(new_state['last_act'], row, 0, 0)
        append_row(new_state['last_act'], row, (0, 8, 2, 3)) # last_act(R, "move", S, E)
        new_state['loc'] = filter_eq_key(new_state['loc'], row, 0, 0) # -loc(R, _)
        append_row(new_state['loc'], row, (0, 3)) # loc(R, E)
        yield plans(new_state, max_len)

def step_push(state, max_len):
    t = outer_eq_join(gen_index(state['loc'], 1), fixed['adj'][0]) # L == S: R,L,S,E
    t = outer_eq_join(gen_index(t, 3), gen_index(state['cart_loc'], 1)) # E == D: R,L,S,E,C,D
    t = outer_eq_join(gen_index(t, 3), fixed['adj'][0]) # E1 == S2: R,L,S1,E1,C,D,S2,E2
    t = filter_ne(t, (2,), (7,)) # S1 != E2
    t = left_eq_ajoin(gen_index(t, 7), gen_index(state['loc'], 1)) # E2 == no L2
    t = left_eq_ajoin(gen_index(t, 7), gen_index(state['cart_loc'], 1)) # E2 == no D2
    t = constant(t, 'push') # R,L,S1,E1,C,D,S2,E2,"push"
    shuffle_table(t)
    for row in t:
        new_state = state.copy()
        # -last_act(R, _, _, _)
        new_state['last_act'] = filter_eq_key(new_state['last_act'], row, 0, 0)
        append_row(new_state['last_act'], row, (0, 8, 3, 7)) # last_act(R, "push", E1, E2)
        new_state['loc'] = filter_eq_key(new_state['loc'], row, 0, 0) # -loc(R, _)
        append_row(new_state['loc'], row, (0, 3)) # loc(R, E1)
        new_state['cart_loc'] = filter_eq_key(new_state['cart_loc'], row, 0, 4) # -cart_loc(C, _)
        append_row(new_state['cart_loc'], row, (4, 7)) # cart_loc(C, E2)
        yield plans(new_state, max_len)
```

Below is how this backend works. It's extremely simple, using [hash-joins](https://en.wikipedia.org/wiki/Hash_join) to impose all constraints, and searching forwards (from the start state). In this implementation, "databases" (which are equivalent to states in our search process) are `dict`s mapping from `str`s to "tables." Tables are `list`s of "rows," and rows in turn are `tuple`s of "symbols" (`int`s or `str`s). Of particular note is that the implementation works by creating temporary "indices," which are `dict`s mapping each value of a specific column to the list of rows that have that value in that column.

For example, a start state in this implementation looks like this:
```python
{'cart_loc': [(0, 'Aisle3'), (1, 'Front3'), (2, 'Back3')],
 'last_act': [(0, 'start', None, None),
              (1, 'start', None, None),
              (2, 'start', None, None)],
 'loc': [(0, 'Aisle0'), (1, 'Aisle1'), (2, 'Aisle2')]}
```

An index over the `1` column of the `loc` table would look like this:
```python
{'Aisle0': [(0, 'Aisle0')],
 'Aisle1': [(1, 'Aisle1')],
 'Aisle2': [(2, 'Aisle2')]}
```

And an index over the `0` column of the `adj` table would look like this:
```python
{'Aisle0': [('Aisle0', 'Front0'), ('Aisle0', 'Back0')],
 'Aisle1': [('Aisle1', 'Back1'), ('Aisle1', 'Front1')],
 'Aisle2': [('Aisle2', 'Front2'), ('Aisle2', 'Back2')],
 'Aisle3': [('Aisle3', 'Back3'), ('Aisle3', 'Front3')],
 'Aisle4': [('Aisle4', 'Front4'), ('Aisle4', 'Back4')],
 'Back0': [('Back0', 'Aisle0'), ('Back0', 'Back1')],
 'Back1': [('Back1', 'Back2'), ('Back1', 'Aisle1'), ('Back1', 'Back0')],
 'Back2': [('Back2', 'Back3'), ('Back2', 'Aisle2'), ('Back2', 'Back1')],
 'Back3': [('Back3', 'Back2'), ('Back3', 'Aisle3'), ('Back3', 'Back4')],
 'Back4': [('Back4', 'Back3'), ('Back4', 'Aisle4')],
 'Front0': [('Front0', 'Aisle0'), ('Front0', 'Front1')],
 'Front1': [('Front1', 'Aisle1'), ('Front1', 'Front0'), ('Front1', 'Front2')],
 'Front2': [('Front2', 'Front3'), ('Front2', 'Front1'), ('Front2', 'Aisle2')],
 'Front3': [('Front3', 'Front2'), ('Front3', 'Front4'), ('Front3', 'Aisle3')],
 'Front4': [('Front4', 'Aisle4'), ('Front4', 'Front3')]}
```

We could then compute all the places all robots can move by joining these two indices on equal keys (`outer_eq_join` below; despite its name, it behaves as an inner equi-join, since rows without a match are dropped). The result would be a new table, which (in this case) shows that each robot can move to the front or back of its current aisle.
```python
[(0, 'Aisle0', 'Aisle0', 'Front0'),
 (0, 'Aisle0', 'Aisle0', 'Back0'),
 (1, 'Aisle1', 'Aisle1', 'Back1'),
 (1, 'Aisle1', 'Aisle1', 'Front1'),
 (2, 'Aisle2', 'Aisle2', 'Front2'),
 (2, 'Aisle2', 'Aisle2', 'Back2')]
```

The advantage of this method over a naive join is that it is linear time in all inputs (and the output). For large tables (such as the adjacency table), this is more efficient.
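The same idea in a stand-alone sketch: build an index on one table in a single pass, then probe it once per row of the other table, and check that the result matches the naive nested-loop join on a few rows taken from the tables above. `build_index`, `hash_join`, and `nested_loop_join` are illustrative names; `build_index` mirrors `gen_index` below.

```python
def build_index(table, i):
    '''Map each value of column i to the list of rows holding that value.'''
    index = {}
    for row in table:
        index.setdefault(row[i], []).append(row)
    return index

def hash_join(left, i, right, j):
    '''Equi-join on left[i] == right[j]: one pass to build, one pass to probe.'''
    index = build_index(right, j)
    return [l + r for l in left for r in index.get(l[i], ())]

def nested_loop_join(left, i, right, j):
    '''The naive join: compares every pair of rows, O(len(left) * len(right)).'''
    return [l + r for l in left for r in right if l[i] == r[j]]

loc = [(0, 'Aisle0'), (1, 'Aisle1'), (2, 'Aisle2')]
adj = [('Aisle0', 'Front0'), ('Aisle0', 'Back0'), ('Front0', 'Aisle0'),
       ('Aisle1', 'Back1'), ('Aisle1', 'Front1'), ('Back1', 'Aisle1'),
       ('Aisle2', 'Front2'), ('Aisle2', 'Back2')]
joined = hash_join(loc, 1, adj, 0)
assert sorted(joined) == sorted(nested_loop_join(loc, 1, adj, 0))
print(joined)
```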

```python
import random

def gen_index(table, i):
    '''Create a "hash index" from a table (a list of row tuples): column i's value -> rows.'''
    index = {}
    for row in table:
        index.setdefault(row[i], []).append(row)
    return index

def outer_eq_join(a, b):
    '''Given two hash indices, concatenate every pair of rows whose keys are equal.'''
    # Iterate over the index with fewer keys and probe the other; rows are always a_row + b_row.
    if len(a) <= len(b):
        return [a_v + b_v for k, v in a.items()
                for a_v in v
                for b_v in b.get(k, ())]
    else:
        return [a_v + b_v for k, v in b.items()
                for b_v in v
                for a_v in a.get(k, ())]

def left_eq_ajoin(a, b):
    '''Given two hash indices, return the rows of the left index whose key is absent from the right (an anti-join).'''
    return [a_v for k, v in a.items()
            for a_v in v
            if k not in b]

def filter_ne(table, ix, jx):
    '''Keep rows where at least one column pair (ix[n], jx[n]) differs; drop rows where every pair is equal.'''
    return [row for row in table if any(row[i] != row[j] for i, j in zip(ix, jx))]

def constant(table, c):
    '''Add a constant to each row of the table.'''
    return [row + (c,) for row in table]

def filter_eq_key(table, key_row, column, key):
    '''Remove all rows whose `column` equals key_row[key].'''
    return [row for row in table if row[column] != key_row[key]]

def append_row(table, row, columns):
    '''Append a new row built from the given columns of `row`.'''
    table.append(tuple([row[c] for c in columns]))

def gen_fixed():
    '''Generate indices for all the fixed tables (just the adjacency table).'''
    adj_list = list(adjacent)
    return {'adj': [gen_index(adj_list, i) for i in (0, 1)]}

def gen_start():
    '''Generate a database representing the start state.'''
    loc_tables = list(enumerate(robot_locations))
    cart_loc_tables = list(enumerate(cart_locations))
    return {
        'loc': loc_tables,
        'cart_loc': cart_loc_tables,
        'last_act': [(r, 'start', None, None) for r in range(len(robot_locations))],
    }

def shuffle_table(table):
    '''Randomize (in place) the order in which successor states are expanded.'''
    random.shuffle(table)
```
And here's that implementation in action. Note that the number of states visited varies from run to run, ranging from around 10000 to 200000. If we were to use heuristics, we could keep it toward the lower end of that range.

```python
robot_locations = [aisles[0], aisles[1], aisles[2]]
cart_locations = [aisles[3], front[3], back[3]]
fixed = gen_fixed()
start = gen_start()
def goal(state):
    return any((r, aisles[3]) in state['loc'] for r in robots)
pprint(sample(start, 7))
```

```
visited 78055 states
({'cart_loc': [(0, 'Aisle3'), (1, 'Front3'), (2, 'Back3')],
  'last_act': [(0, 'start', None, None),
               (1, 'start', None, None),
               (2, 'start', None, None)],
  'loc': [(0, 'Aisle0'), (1, 'Aisle1'), (2, 'Aisle2')]},
 {'cart_loc': [(0, 'Aisle3'), (1, 'Front3'), (2, 'Back3')],
  'last_act': [(0, 'start', None, None),
               (2, 'start', None, None),
               (1, 'move', 'Aisle1', 'Front1')],
  'loc': [(0, 'Aisle0'), (2, 'Aisle2'), (1, 'Front1')]},
 {'cart_loc': [(0, 'Aisle3'), (1, 'Front3'), (2, 'Back3')],
  'last_act': [(0, 'start', None, None),
               (2, 'start', None, None),
               (1, 'move', 'Front1', 'Front2')],
  'loc': [(0, 'Aisle0'), (2, 'Aisle2'), (1, 'Front2')]},
 {'cart_loc': [(0, 'Aisle3'), (1, 'Front3'), (2, 'Back3')],
  'last_act': [(0, 'start', None, None),
               (1, 'move', 'Front1', 'Front2'),
               (2, 'move', 'Aisle2', 'Back2')],
  'loc': [(0, 'Aisle0'), (1, 'Front2'), (2, '
```

So far, everything I've described is well understood. However, it assumes that our state space is discrete: each robot location is exactly one aisle, front segment, or back segment. Typically, if we wanted to apply this method to a continuous space, we would discretize the space into a set of cells, maintain weights for each cell, apply our rules as usual over the cells, and add some noise. However, I would like to propose a different method: define the logic over continuous values instead.

Instead of our "symbols" being `int`s or `str`s (which could represent robots, locations, carts, or grid cells), our "symbols" will be vectors in a continuous space. Instead of using exact equality to compare these "symbols", we will use a similarity measure. A similar approach was taken for query answering in [End-to-End Differentiable Proving](https://arxiv.org/abs/1705.11040).

```python
import torch

def similarity(key1, key2, sigma=.25):
    '''Measures key similarity, returning 1 for identical keys and approaching 0 for distant keys.'''
    return torch.exp(-((key1 - key2) ** 2).sum() / sigma)
```
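To get a feel for the scale of `sigma`, here is the same formula in plain Python. It is renamed `similarity_py` (a hypothetical name) so that it doesn't shadow the differentiable PyTorch version. With `sigma=.25` and the 0.5 cutoff used by the continuous backend, two keys count as equal only within about 0.42 units of each other, and cells one unit apart score exp(-4), roughly 0.018.

```python
import math

def similarity_py(key1, key2, sigma=0.25):
    '''Plain-Python version of `similarity`: exp(-||key1 - key2||^2 / sigma).'''
    sq_dist = sum((a - b) ** 2 for a, b in zip(key1, key2))
    return math.exp(-sq_dist / sigma)

cutoff = 0.5  # the EQ_CUTOFF value used by the continuous backend
same = similarity_py([3.0, 0.0], [3.0, 0.0])            # identical keys
adjacent_cells = similarity_py([3.0, 0.0], [3.0, 1.0])  # one unit apart
# Largest distance still counted as "equal": exp(-d**2 / sigma) > cutoff.
max_equal_dist = math.sqrt(0.25 * math.log(1 / cutoff))
print(same, round(adjacent_cells, 4), round(max_equal_dist, 3))  # 1.0 0.0183 0.416
```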

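Since our continuous vectors can't be usefully hashed with a `dict` (nearby vectors should land in the same bucket, but an exact hash ignores closeness), we could use [locality sensitive hashing](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions) in place of a `dict`. That is implemented below, but the joins in the rest of the notebook use "full lookup," applying the similarity between the lookup key and every key in the index. This is much slower, but guarantees we don't miss any of the search space. (The anti-join is the exception: `left_eq_ajoin` calls `best_lookup`, which only scans the hash buckets the key falls into.)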

```python
class LocalitySensitiveHashIndex:
    '''An index over one column of a weighted table, bucketed with p-stable (Gaussian) LSH.'''

    def __init__(self, table, i, num_hashes=32, num_buckets=4):
        self.entries = defaultdict(list)  # (hash number, bucket) -> [(key, row, weight)]
        self.num_entries = 0
        self.num_hashes = num_hashes
        self.num_buckets = num_buckets
        D = table[0][0][i].shape[0]
        self.projections = torch.randn(num_hashes, D)  # one random projection per hash
        self.r = 0.5                                   # bucket width
        self.bs = self.r * torch.rand(num_hashes)      # random offsets in [0, r)
        for row, weight in table:
            self.insert(row[i], row, weight)

    def __iter__(self):
        # Every entry is stored exactly once under hash 0, so these buckets cover the whole index.
        for b in range(self.num_buckets):
            yield from self.entries[(0, b)]

    def __len__(self):
        return self.num_entries

    def hash_key(self, v):
        projected = self.projections.matmul(v)
        offset = projected + self.bs
        return (offset / self.r).floor() % self.num_buckets

    def insert(self, key, row, weight):
        hs = self.hash_key(key)
        self.num_entries += 1
        for ih, h in enumerate(hs):
            self.entries[(ih, int(h.item()))].append((key, row, weight))

    def sample_lookups(self, key, n=8):
        '''Samples a random lookup given the key, returning a list of rows and weights.

        This method isn't actually used in this notebook, but illustrates how
        locality sensitive hashing can accelerate lookup (with some caveats).
        '''
        result = []
        for i in range(n):
            hs = self.hash_key(key)
            ih = random.randrange(len(hs))
            h = int(hs[ih].item())
            entries = self.entries[(ih, h)]
            if entries:
                key2, value, weight = random.choice(entries)
                result.append((key2, value, weight * similarity(key, key2)))
        return result

    def full_lookup(self, key):
        '''Looks up all entries given the key, returning a list of rows and weights.'''
        for key2, value, weight in self:
            yield (key2, value, weight * similarity(key, key2))

    def best_lookup(self, key):
        '''Finds the row with highest lookup weight using key, returning a row and weight.'''
        # Only the buckets this key hashes to are scanned, so distant rows are never considered.
        hs = self.hash_key(key)
        best_row = None
        best_weight = 0.0
        for ih, h in enumerate(hs):
            for row_key, row, weight in self.entries[(ih, int(h.item()))]:
                w = weight * similarity(key, row_key)
                if w > best_weight:
                    best_row = row
                    best_weight = w
        return best_row, best_weight
```

We can then replace "the backend" so that our step function (which we "compiled" from first-order logic) can run unmodified. This mostly involves replacing our "hash indices" with the `LocalitySensitiveHashIndex` defined above, and replacing equality comparison with thresholding on our similarity measure. Note that in a more fully developed implementation, we would likely want to learn the parameters of the similarity measure and an appropriate cutoff.

```python
def gen_index(table, i):
    if len(table) > 0:
        return LocalitySensitiveHashIndex(table, i)
    else:
        return None

# Similarities below this are considered irrelevant.
EQ_CUTOFF = 0.5

def outer_eq_join(a, b):
    '''Soft equi-join: pair rows whose keys are similar enough, multiplying their weights.'''
    # full_lookup already multiplies the probed row's weight by the key similarity.
    if a is None or b is None:
        return []
    if len(a) <= len(b):
        return [(a_v + b_v, a_w * b_w)
                for a_k, a_v, a_w in a
                for b_k, b_v, b_w in b.full_lookup(a_k)
                if similarity(a_k, b_k) > EQ_CUTOFF]
    else:
        return [(a_v + b_v, a_w * b_w)
                for b_k, b_v, b_w in b
                for a_k, a_v, a_w in a.full_lookup(b_k)
                if similarity(a_k, b_k) > EQ_CUTOFF]

def left_eq_ajoin(a, b):
    '''Soft anti-join: scale each left row's weight by 1 - (weight of its best match on the right).'''
    if a is None:
        return []
    if b is None:
        return [(a_v, a_w) for k_a, a_v, a_w in a]
    return [(a_v, a_w * (1 - b.best_lookup(k_a)[1])) for k_a, a_v, a_w in a]

def filter_ne(table, ix, jx):
    '''Keep rows where at least one column pair (ix[n], jx[n]) is dissimilar; drop rows where every pair matches.'''
    return [(row, w) for row, w in table
            if any(similarity(row[i], row[j]) < EQ_CUTOFF for i, j in zip(ix, jx))]

def filter_eq_key(table, key_row, column, key):
    '''Remove all rows whose `column` is similar to key_row's `key` column.'''
    return [row for row in table if similarity(row[0][column], key_row[0][key]) < EQ_CUTOFF]

def append_row(table, row, columns):
    row, weight = row
    table.append((tuple([row[c] for c in columns]), weight))

# In a real implementation, we would probably support non-vectors in the backend.
CONSTANT_VECTORS = {
    'start': torch.tensor([0.0, 0.0, 0.0]),
    'move': torch.tensor([1.0, 1.0, 1.0]),
    'push': torch.tensor([2.0, 2.0, 2.0]),
}

def constant(table, c):
    '''Add a constant to each row of the table.'''
    return [(row + (CONSTANT_VECTORS[c],), w) for row, w in table]

WEIGHT_CUTOFF = 0.15

def shuffle_table(table):
    random.shuffle(table)
    # Discard actions that are very unlikely to succeed.
    table[:] = [(row, w) for row, w in table if w > WEIGHT_CUTOFF]

aisles_t = [torch.tensor([i, 0.0]) for i in range(A)]
back_t = [torch.tensor([i, -1.0]) for i in range(A)]
front_t = [torch.tensor([i, 1.0]) for i in range(A)]

def gen_fixed():
    '''Generate indices for all the fixed tables (just the adjacency table).'''
    adjacent = (list(zip(aisles_t, front_t)) +
                list(zip(aisles_t, back_t)) +
                [(front_t[i], front_t[i+1]) for i in range(len(front_t) - 1)] +
                [(back_t[i], back_t[i+1]) for i in range(len(back_t) - 1)])
    adjacent = adjacent + [(e, s) for s, e in adjacent]
    adj_table = [(row, 1.0) for row in adjacent]
    return {'adj': [gen_index(adj_table, i) for i in (0, 1)]}

def gen_start():
    '''Generate a database representing the start state.'''
    loc_table = [((torch.tensor([float(rob)]), loc), 1.0)
                 for rob, loc in enumerate(robot_locations)]
    cart_loc_table = [((torch.tensor([float(cart)]), loc), 1.0)
                      for cart, loc in enumerate(cart_locations)]
    return {
        'loc': loc_table,
        'cart_loc': cart_loc_table,
        'last_act': [((rob, CONSTANT_VECTORS['start'], loc, loc), 1.0)
                     for (rob, loc), _ in loc_table],
    }
```
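The weight bookkeeping above can be hard to follow through the LSH plumbing, so here is a minimal pure-Python sketch of just that part. `sim`, `soft_join`, and `soft_antijoin` are hypothetical helpers; keys are plain tuples rather than tensors, and unlike `left_eq_ajoin` the anti-join scans every row instead of consulting hash buckets. A joined row's weight is the product of both row weights and the key similarity, and the anti-join scales a row by one minus the weight of its best match.

```python
import math

def sim(k1, k2, sigma=0.25):
    '''exp(-||k1 - k2||^2 / sigma) on plain tuples of floats.'''
    return math.exp(-sum((a - b) ** 2 for a, b in zip(k1, k2)) / sigma)

def soft_join(left, i, right, j, cutoff=0.5):
    '''Weighted equi-join of (row, weight) lists on left_row[i] ~ right_row[j].'''
    out = []
    for lrow, lw in left:
        for rrow, rw in right:
            s = sim(lrow[i], rrow[j])
            if s > cutoff:                      # "equal enough" to join
                out.append((lrow + rrow, lw * rw * s))
    return out

def soft_antijoin(left, i, right, j):
    '''Scale each left row by 1 - (weight of its best match in `right`).'''
    out = []
    for lrow, lw in left:
        best = max((rw * sim(lrow[i], rrow[j]) for rrow, rw in right), default=0.0)
        out.append((lrow, lw * (1 - best)))
    return out

# Rows are tuples of keys; keys are tuples of floats (robot ids are 1-D).
loc = [(((0.0,), (1.0, 0.0)), 1.0)]                  # robot 0 in Aisle1
adj = [(((1.0, 0.0), (1.0, 1.0)), 1.0),              # Aisle1 -> Front1
       (((1.0, 0.0), (1.0, -1.0)), 1.0)]             # Aisle1 -> Back1
cart_rows = [(((0.0,), (1.0, 1.0)), 1.0)]            # a cart parked in Front1

moves = soft_join(loc, 1, adj, 0)                    # columns: R, L, S, E
free = soft_antijoin(moves, 3, cart_rows, 1)         # E must not hold a cart
for row, w in free:
    print(row[3], w)
```

The move into the cart's cell ends up with weight 0 and would be discarded by `shuffle_table`, while the move away from it keeps almost all of its weight.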

We can then search with the same step function we defined before. Since there isn't any noise in the planning system, most vectors match exactly. However, since our similarity measure has infinite support, we still end up with weights less than 1 due to "collisions" with nearby objects. For example, moving next to a cart decreases the weight / confidence in that action.
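The size of that effect follows directly from `similarity`. Assuming `best_lookup` finds the nearby cart (it only scans the hash buckets the key falls into), the anti-join multiplies a move's weight by 1 - exp(-d²/σ), where d is the distance between the move's target and the cart. The sketch below (with a hypothetical `antijoin_factor` helper) evaluates that factor for a target one cell away from a cart and for a target occupied by the cart.

```python
import math

sigma = 0.25          # bandwidth of `similarity` in the notebook
weight_cutoff = 0.15  # shuffle_table keeps only rows with weight above this

def antijoin_factor(dist):
    '''Factor left_eq_ajoin multiplies a row's weight by when the best-matching
    blocker (e.g. a cart) sits `dist` away from the row's key.'''
    return 1 - math.exp(-dist ** 2 / sigma)

next_to_cart = antijoin_factor(1.0)  # target cell one unit away from a cart
onto_cart = antijoin_factor(0.0)     # target cell occupied by the cart itself
print(round(next_to_cart, 4), onto_cart)  # 0.9817 0.0
```

So a move next to a cart loses under 2% of its weight and stays well above `WEIGHT_CUTOFF`, while a move onto a cart is driven to zero and pruned, which recovers the discrete anti-join's behaviour.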

```python
robot_locations = [aisles_t[1], aisles_t[2]]
cart_locations = [aisles_t[3], front_t[3], back_t[3]]
fixed = gen_fixed()
start = gen_start()
def goal(state):
    return any(similarity(loc, aisles_t[3]) > EQ_CUTOFF for (rob, loc), w in state['loc'])

pprint(sample(start, 7))
```

```
visited 3335 states
({'cart_loc': [((tensor([0.]), tensor([3., 0.])), 1.0),
               ((tensor([1.]), tensor([3., 1.])), 1.0),
               ((tensor([2.]), tensor([ 3., -1.])), 1.0)],
  'last_act': [((tensor([0.]),
                 tensor([0., 0., 0.]),
                 tensor([1., 0.]),
                 tensor([1., 0.])),
                1.0),
               ((tensor([1.]),
                 tensor([0., 0., 0.]),
                 tensor([2., 0.]),
                 tensor([2., 0.])),
                1.0)],
  'loc': [((tensor([0.]), tensor([1., 0.])), 1.0),
          ((tensor([1.]), tensor([2., 0.])), 1.0)]},
 {'cart_loc': [((tensor([0.]), tensor([3., 0.])), 1.0),
               ((tensor([1.]), tensor([3., 1.])), 1.0),
               ((tensor([2.]), tensor([ 3., -1.])), 1.0)],
  'last_act': [((tensor([0.]),
                 tensor([0., 0., 0.]),
                 tensor([1., 0.]),
                 tensor([1., 0.])),
                1.0),
               ((tensor([1.]),
```

Ideally, we would also like to be able to handle noise in the planning system. Below, we offset the endpoints throughout the adjacency table using a normal distribution with a standard deviation of 5% of the spacing between cells. We also offset our robot and cart starting locations by the same distribution, and planning still works. In the real world, this might correspond to a 10cm difference between our perception system's expectation of an object's location and our navigation system's response to it.

```python
def noise(length=2):
    '''Gaussian noise with a standard deviation of 0.05 (5% of the cell spacing) per coordinate.'''
    return 0.05 * torch.randn(length)

def gen_fixed():
    '''Generate indices for all the fixed tables (just the adjacency table).'''
    adjacent = ([(noise() + a, noise() + f) for a, f in zip(aisles_t, front_t)] +
                [(noise() + a, noise() + b) for a, b in zip(aisles_t, back_t)] +
                [(noise() + front_t[i], noise() + front_t[i+1]) for i in range(len(front_t) - 1)] +
                [(noise() + back_t[i], noise() + back_t[i+1]) for i in range(len(back_t) - 1)])
    adjacent = adjacent + [(e, s) for s, e in adjacent]
    adj_table = [(row, 1.0) for row in adjacent]
    return {'adj': [gen_index(adj_table, i) for i in (0, 1)]}

robot_locations = [noise() + aisles_t[1], noise() + aisles_t[2]]
cart_locations = [noise() + aisles_t[3], noise() + front_t[3], noise() + back_t[3]]
fixed = gen_fixed()
start = gen_start()
pprint(sample(start, 7))
```

```
visited 643 states
({'cart_loc': [((tensor([0.]), tensor([3.0019, 0.0041])), 1.0),
               ((tensor([1.]), tensor([2.9095, 1.0383])), 1.0),
               ((tensor([2.]), tensor([ 2.9711, -1.0211])), 1.0)],
  'last_act': [((tensor([0.]),
                 tensor([0., 0., 0.]),
                 tensor([0.8795, 0.0142]),
                 tensor([0.8795, 0.0142])),
                1.0),
               ((tensor([1.]),
                 tensor([0., 0., 0.]),
                 tensor([1.9720, 0.0879]),
                 tensor([1.9720, 0.0879])),
                1.0)],
  'loc': [((tensor([0.]), tensor([0.8795, 0.0142])), 1.0),
          ((tensor([1.]), tensor([1.9720, 0.0879])), 1.0)]},
 {'cart_loc': [((tensor([0.]), tensor([3.0019, 0.0041])), 1.0),
               ((tensor([1.]), tensor([2.9095, 1.0383])), 1.0),
               ((tensor([2.]), tensor([ 2.9711, -1.0211])), 1.0)],
  'last_act': [((tensor([1.]),
                 tensor([0., 0., 0.]),
                 tensor([1.9720, 0.0879]),
```
Hopefully this notebook makes some of the details of my recent research ideas clear. Thanks for reading all the way to the end! If you have any questions, feel free to email me at kzentner@usc.edu. If you're at USC, I would also be happy to talk with potential collaborators on research topics in this direction. My desk is in the back of RTH 422.