# DyND Callable MultiDispatch in Python

This is a notebook which goes with the [DyND Callable Multi-Dispatch Design](callable-multidispatch-design.md).
It's an implementation of the described signature subset construction, with the aim of illustrating the algorithm
clearly. We will build up the implementation in stages, starting with simple cases and gradually increasing the
complexity to account for more of the pattern matching capabilities that DyND types have.

In [29]:
import dynd
from dynd import nd, ndt
from pprint import pprint
print("libdynd", dynd.__libdynd_version__, ", dynd-python", dynd.__version__)

libdynd v0.7.2-975-ga53df83 , dynd-python v0.7.2-278-gbbf2ace


In [108]:
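import functools

# Helper that gives a function a readable repr, so that the generated
# bytecode instructions print legibly further down.
class reprwrapper(object):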
    def __init__(self, reprstr, func):
        self._reprstr = reprstr
        self._func = func
        functools.update_wrapper(self, func)
    def __call__(self, *args, **kw):
        return self._func(*args, **kw)
    def __repr__(self):
        return self._reprstr

def withrepr(reprstr):
    def _wrap(func):
        return reprwrapper(reprstr, func)
    return _wrap

Let's start with a few scalar signatures.

In [109]:
sigs = [ndt.type("(int8, int8) -> int8"),
        ndt.type("(int16, int16) -> int16"),
        ndt.type("(float32, float32) -> float32")]

In [110]:
sigs_working = [s.pos_types for s in sigs]; pprint(sigs_working)

[[ndt.type('int8'), ndt.type('int8')],
 [ndt.type('int16'), ndt.type('int16')],
 [ndt.type('float32'), ndt.type('float32')]]


For this initial case we don't need to track any additional working state, so that part is done. We do need
a set of equations that must be satisfied for a match. At this point the only form of equation is
`S.TInp[i].matches(TInp[i])` (input type `i` of signature `S` must match the type of argument `i`), so we'll just store the index `i`.
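As a reference point, the meaning of these equations can be written as a naive linear scan: a signature matches when every equation holds. The sketch below is a hypothetical baseline, not DyND's API. Types are stand-in strings, and the assumed `matches` is plain equality, which is all that exact scalar matching needs at this stage.

```python
# Hypothetical baseline dispatcher: types are plain strings standing in for
# DyND types, and "matches" is exact equality (enough for scalar signatures).
def matches(sig_type, arg_type):
    return sig_type == arg_type

def linear_dispatch(sigs_working, argtypes, equations):
    """Return the indices of all signatures whose equations hold for argtypes."""
    return [n for n, sig in enumerate(sigs_working)
            if all(matches(sig[i], argtypes[i]) for i in equations)]

sigs_working = [["int8", "int8"], ["int16", "int16"], ["float32", "float32"]]
equations = [0, 1]
print(linear_dispatch(sigs_working, ["int16", "int16"], equations))  # [1]
print(linear_dispatch(sigs_working, ["int32", "int32"], equations))  # []
```

This checks every equation of every signature on each call. The decision tree constructed in the rest of this notebook is meant to replace that repeated scan with a single walk over the arguments.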

In [111]:
equations = [0, 1]

For each equation, we need to generate a set of edges and a filtered version of `sigs_working` for each edge.
Let's write some functions for this. We're starting really simple, so that we can build up the complexity of
the final algorithm incrementally.

In [112]:
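def child_edges(sigs_working, i):
    # Distinct type ids at argument position i among the surviving signatures
    return {sig[i].id for sig in sigs_working if sig is not None}

def filter_sigs(sigs_working, i, tid):
    # Keep signatures whose argument i has type id tid; replace the rest with None
    return [sig if sig is not None and sig[i].id == tid else None
            for sig in sigs_working]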
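To see what these two helpers do without a DyND installation, here they are applied to a stand-in type that carries only an `id`. `FakeType` is an assumption of this sketch, not part of DyND. The ids 6, 7 and 19 mirror the ones DyND reports for `int8`, `int16` and `float32` in the output of the next cell.

```python
from collections import namedtuple

# Hypothetical stand-in for ndt.type: only the .id attribute is used here.
FakeType = namedtuple("FakeType", ["name", "id"])
int8, int16, float32 = FakeType("int8", 6), FakeType("int16", 7), FakeType("float32", 19)

def child_edges(sigs_working, i):
    # Distinct type ids at argument position i among the surviving signatures
    return {sig[i].id for sig in sigs_working if sig is not None}

def filter_sigs(sigs_working, i, tid):
    # Keep signatures whose argument i has type id tid; replace the rest with None
    return [sig if sig is not None and sig[i].id == tid else None
            for sig in sigs_working]

sigs_working = [[int8, int8], [int16, int16], [float32, float32]]
print(sorted(child_edges(sigs_working, 0)))  # [6, 7, 19]
print(filter_sigs(sigs_working, 0, 7))
# [None, [FakeType(name='int16', id=7), FakeType(name='int16', id=7)], None]
```

Keeping `None` placeholders instead of dropping entries preserves each signature's original index, which is the number the leaves of the decision tree report.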

Now let's generate our two candidate decision tree nodes, one for each equation. Each candidate is a dict that
maps every type id seen at that equation's argument position to the list of signatures filtered down to that
type id. Because this case is so simple, looking at either argument alone fully distinguishes the signatures.

In [113]:
candidates = {i: {tid: filter_sigs(sigs_working, i, tid)
                  for tid in child_edges(sigs_working, i)}
              for i in equations}
pprint(candidates)

{0: {6: [[ndt.type('int8'), ndt.type('int8')], None, None],
     7: [None, [ndt.type('int16'), ndt.type('int16')], None],
     19: [None, None, [ndt.type('float32'), ndt.type('float32')]]},
 1: {6: [[ndt.type('int8'), ndt.type('int8')], None, None],
     7: [None, [ndt.type('int16'), ndt.type('int16')], None],
     19: [None, None, [ndt.type('float32'), ndt.type('float32')]]}}
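For contrast, here is a hypothetical signature set where the choice of equation does matter. With `(int8, int8)`, `(int8, int16)` and `(int8, float32)`, splitting on argument 0 leaves all three signatures behind a single edge, while argument 1 separates them completely. The helper `worst_case_survivors` is an illustrative heuristic assumed for this sketch; types are plain strings rather than DyND objects.

```python
# Hypothetical signatures using plain strings as type ids (not DyND objects).
sigs_working = [["int8", "int8"], ["int8", "int16"], ["int8", "float32"]]

def candidates_for(sigs_working, i):
    # Map each type seen at argument i to the signatures it keeps alive
    edges = {sig[i] for sig in sigs_working if sig is not None}
    return {t: [sig if sig is not None and sig[i] == t else None
                for sig in sigs_working]
            for t in edges}

def worst_case_survivors(edges):
    # Largest number of signatures still alive behind any single edge
    return max(sum(s is not None for s in sigs) for sigs in edges.values())

for i in (0, 1):
    print(i, worst_case_survivors(candidates_for(sigs_working, i)))
# 0 3
# 1 1
```

Scoring candidates like this is one possible way to choose between equations; the notebook's construction below keeps the simpler rule of taking the first one.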


Since nothing distinguishes the candidates, we'll arbitrarily take the first one. This gives us a decision
tree node that switches on the type id of argument zero. We now need to recursively apply the same procedure to
every outgoing edge of the node, and then our decision tree will be done.

In [114]:
def build_decision_tree_node(sigs_working, equations):
    if equations:
        # Each equation generates a candidate decision tree node
        candidates = {i: {tid: filter_sigs(sigs_working, i, tid)
                          for tid in child_edges(sigs_working, i)}
                      for i in equations}
        # For now, always taking first candidate
        i, equations = equations[0], equations[1:]
        edges = candidates[i]
        # Build the child nodes recursively
        return ('typeid', i, {tid: build_decision_tree_node(edges[tid], equations)
                              for tid in edges})
    else:
        # When there are no equations left, the arguments are fully matched
        results = [i for (i, sig) in enumerate(sigs_working) if sig is not None]
        if len(results) == 1:
            return ('match', results[0])
        elif results:
            return ('ambiguous', results)
        else:
            # Trailing comma matters: ('nomatch') would be a bare string
            return ('nomatch',)

def build_decision_tree(sigs):
    sigs_working = [s.pos_types for s in sigs]
    equations = list(range(len(sigs_working[0])))
    return build_decision_tree_node(sigs_working, equations)

In [115]:
pprint(build_decision_tree(sigs))

('typeid',
 0,
 {6: ('typeid', 1, {6: ('match', 0)}),
  7: ('typeid', 1, {7: ('match', 1)}),
  19: ('typeid', 1, {19: ('match', 2)})})
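Even before compiling it, the tree can be used directly. A minimal walker, sketched here on a copy of the tree printed above, takes a list of argument type ids and follows one edge per `typeid` node until it reaches a leaf. The function name `walk_tree` is hypothetical, and 6, 7 and 19 stand in for the `int8`, `int16` and `float32` type ids from this session.

```python
# Copy of the decision tree printed above; 6, 7 and 19 are the type ids DyND
# reported for int8, int16 and float32 in this session.
tree = ('typeid', 0,
        {6: ('typeid', 1, {6: ('match', 0)}),
         7: ('typeid', 1, {7: ('match', 1)}),
         19: ('typeid', 1, {19: ('match', 2)})})

def walk_tree(node, argids):
    """Follow the decision tree for a list of argument type ids."""
    while node[0] == 'typeid':
        _, i, edges = node
        if argids[i] not in edges:
            raise LookupError('No match found')
        node = edges[argids[i]]
    if node[0] == 'match':
        return node[1]
    if node[0] == 'ambiguous':
        raise LookupError('Ambiguous matches %s' % node[1])
    raise LookupError('No match found')

print(walk_tree(tree, [6, 6]))    # 0
print(walk_tree(tree, [19, 19]))  # 2
```

Compiling to bytecode, as the next cells do, trades this nested structure for a flat list of instructions addressed by index.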


The final thing we want to do is convert the decision tree into some bytecode that can be executed quickly.
We'll model the bytecode as a list of functions, each of which inspects a simple machine state and returns the index of the next instruction to run.

In [118]:
class MachineResult(Exception):
    def __init__(self, result):
        self.result = result

class MachineState:
    def __init__(self, argtypes):
        self.argtypes = argtypes
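    def run(self, instr):
        # Each instruction returns the index of the next one to execute.
        # A match raises MachineResult, which becomes the return value;
        # failures raise an ordinary Exception that reaches the caller.
        try:
            i = 0
            while True:
                i = instr[i](self)
        except MachineResult as e:
            return e.result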

def build_machine(decision_tree, instr):
    """Append instructions for decision_tree to instr; return its entry index."""
    i = len(instr)
    kind = decision_tree[0]
    if kind == 'match':
        @withrepr('<return match %d>' % decision_tree[1])
        def _match(state):
            raise MachineResult(decision_tree[1])
        instr.append(_match)
    elif kind == 'ambiguous':
        @withrepr('<raise ambiguous %r>' % decision_tree[1])
        def _ambiguous(state):
            raise Exception('Ambiguous matches %s' % decision_tree[1])
        instr.append(_ambiguous)
    elif kind == 'nomatch':
        @withrepr('<raise nomatch>')
        def _nomatch(state):
            raise Exception('No match found')
        instr.append(_nomatch)
    elif kind == 'typeid':
        # Reserve slot i, emit the children, then fill in the switch
        instr.append(None)
        nextinstr = {tid: build_machine(decision_tree[2][tid], instr)
                     for tid in decision_tree[2]}
        @withrepr('<switch typeid[%d] %r>' % (decision_tree[1], nextinstr))
        def _typeid(state):
            tid = state.argtypes[decision_tree[1]].id
            if tid not in nextinstr:
                raise Exception('No match found')
            return nextinstr[tid]
        instr[i] = _typeid
    return i

In [119]:
instr = []
build_machine(build_decision_tree(sigs), instr); pprint(instr)

[<switch typeid[0] {19: 1, 6: 3, 7: 5}>,
 <switch typeid[1] {19: 2}>,
 <return match 2>,
 <switch typeid[1] {6: 4}>,
 <return match 0>,
 <switch typeid[1] {7: 6}>,
 <return match 1>]
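The closures above use an exception to stop the machine. An alternative sketch, not taken from DyND, encodes each instruction as a plain tuple and lets the interpreter loop return directly. The opcode names (`'switch'`, `'match'`, `'fail'`) and the `run_tuples` function are hypothetical; the program is hand-translated from the listing above, with 6, 7 and 19 standing in for the `int8`, `int16` and `float32` type ids.

```python
# Hypothetical tuple-encoded version of the instruction listing above.
# ('switch', argidx, {typeid: target}) jumps on an argument's type id,
# ('match', n) returns signature index n, ('fail', msg) raises.
program = [
    ('switch', 0, {19: 1, 6: 3, 7: 5}),
    ('switch', 1, {19: 2}),
    ('match', 2),
    ('switch', 1, {6: 4}),
    ('match', 0),
    ('switch', 1, {7: 6}),
    ('match', 1),
]

def run_tuples(program, argids):
    pc = 0  # index of the current instruction
    while True:
        op = program[pc]
        if op[0] == 'switch':
            _, argidx, table = op
            if argids[argidx] not in table:
                raise LookupError('No match found')
            pc = table[argids[argidx]]
        elif op[0] == 'match':
            return op[1]
        else:
            raise LookupError(op[1])

print(run_tuples(program, [6, 6]))    # 0
print(run_tuples(program, [19, 19]))  # 2
```

Keeping instructions as data makes the program easy to print or serialize, at the cost of a branch on the opcode in every loop iteration.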


Now we can run these instructions on a few matching and non-matching argument types.

In [120]:
MachineState([ndt.int8, ndt.int8]).run(instr)

0

In [121]:
MachineState([ndt.float32, ndt.float32]).run(instr)

2

In [122]:
MachineState([ndt.int32, ndt.int32]).run(instr)

Exception: No match found