# DyND Callable MultiDispatch in Python

This notebook accompanies the [DyND Callable Multi-Dispatch Design](callable-multidispatch-design.md).
It implements the signature subset construction described there, with the aim of illustrating the algorithm
clearly. We will build up an implementation in stages, starting with simple cases and gradually increasing the
complexity to account for more of the pattern matching capabilities that DyND types have.

In [29]:
import dynd
from dynd import nd, ndt
from pprint import pprint
print("libdynd", dynd.__libdynd_version__, ", dynd-python", dynd.__version__)

libdynd v0.7.2-975-ga53df83 , dynd-python v0.7.2-278-gbbf2ace


In [108]:
import functools
class reprwrapper(object):
    def __init__(self, reprstr, func):
        self._reprstr = reprstr
        self._func = func
        functools.update_wrapper(self, func)
    def __call__(self, *args, **kw):
        return self._func(*args, **kw)
    def __repr__(self):
        return self._reprstr

def withrepr(reprstr):
    def _wrap(func):
        return reprwrapper(reprstr, func)
    return _wrap

## Simple Scalar Signatures

Let's start with a few scalar signatures.

In [170]:
sigs = [ndt.type("(int8, int8) -> int8"),
        ndt.type("(int16, int16) -> int16"),
        ndt.type("(float32, float32) -> float32")]

In [171]:
sigs_working = [s.pos_types for s in sigs]; pprint(sigs_working)

[[ndt.type('int8'), ndt.type('int8')],
 [ndt.type('int16'), ndt.type('int16')],
 [ndt.type('float32'), ndt.type('float32')]]


For this initial case, we don't need to track any additional working state, so that's all the setup required.
We do need a set of equations that must be satisfied for a match. The only form of equation at this point is
`S.TInp[i].matches(TInp[i])`, so we'll just store the index `i`.

In [172]:
equations = [0, 1]

Each node of the decision tree will process an equation, either completely or partially by decomposing
it into additional simpler equations. We're starting really simple, so that we can build up the complexity of
the final algorithm incrementally.

In [173]:
def child_edges(sigs_working, i):
    return {sig[i].id for sig in sigs_working if sig is not None}

def filter_sigs(sigs_working, i, tid):
    return [sig if sig and sig[i].id == tid else None for sig in sigs_working]
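
To see why `filter_sigs` leaves `None` placeholders instead of dropping signatures, here is a minimal sketch that does not need DyND. It assumes each argument type can be stood in for by a plain string that plays the role of `.id`; the names `toy_sigs`, `child_edges_str` and `filter_sigs_str` are hypothetical and exist only for this illustration.

```python
# Stand-in signatures: strings replace ndt.type objects and act as type ids.
toy_sigs = [("int8", "int8"),
            ("int16", "int16"),
            ("float32", "float32"),
            ("int16", "float32")]

def child_edges_str(sigs_working, i):
    # Distinct type ids seen at argument i among the signatures still alive
    return {sig[i] for sig in sigs_working if sig is not None}

def filter_sigs_str(sigs_working, i, tid):
    # Keep matching signatures in place and blank out the rest with None,
    # so that list positions keep identifying the original signatures
    return [sig if sig is not None and sig[i] == tid else None
            for sig in sigs_working]

after_arg0 = filter_sigs_str(toy_sigs, 0, "int16")
print(after_arg0)
print(sorted(child_edges_str(after_arg0, 1)))
```

Because the surviving signatures stay at positions 1 and 3, a later match can be reported as an index into the original signature list.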

For now, we'll simply process the first equation at each step. Eventually, we'll want to pick the
equation that shrinks the set of working signatures the fastest, but we won't do that yet.

In [175]:
edges = {tid: filter_sigs(sigs_working, equations[0], tid)
         for tid in child_edges(sigs_working, equations[0])}
pprint(edges)

{6: [[ndt.type('int8'), ndt.type('int8')], None, None],
 7: [None, [ndt.type('int16'), ndt.type('int16')], None],
 19: [None, None, [ndt.type('float32'), ndt.type('float32')]]}


This gave us a decision tree node which looks at the type id of argument zero. We now need to recursively apply
what we just did to every outgoing edge of the node to complete the tree.

In [176]:
def build_decision_tree_node(sigs_working, equations):
    if equations:
        # For now, always processing the first equation
        i, equations = equations[0], equations[1:]
        # Each equation generates a candidate decision tree node
        edges = {tid: filter_sigs(sigs_working, i, tid)
                 for tid in child_edges(sigs_working, i)}
        # Build the child nodes recursively
        return ('typeid', i, {tid: build_decision_tree_node(edges[tid], equations)
                              for tid in edges})
    else:
        # When there are no equations left, the arguments are fully matched
        results = [i for (i, sig) in enumerate(sigs_working) if sig is not None]
        if len(results) == 1:
            return ('match', results[0])
        elif results:
            return ('ambiguous', results)
        else:
            # One-element tuple, so that node[0] == 'nomatch' holds in build_machine
            return ('nomatch',)
        

def build_decision_tree(sigs):
    sigs_working = [s.pos_types for s in sigs]
    equations = list(range(len(sigs_working[0])))
    return build_decision_tree_node(sigs_working, equations)

In [177]:
pprint(build_decision_tree(sigs))

('typeid',
 0,
 {6: ('typeid', 1, {6: ('match', 0)}),
  7: ('typeid', 1, {7: ('match', 1)}),
  19: ('typeid', 1, {19: ('match', 2)})})
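
Before compiling the tree, it may help to see that the tuple structure can already be used for dispatch by walking it directly. The following is a minimal sketch, assuming the tree printed above and taking argument type ids as plain integers; the ids 6, 7 and 19 are the ones shown in that output and may differ in other builds of libdynd. The helper `walk_tree` is a hypothetical name, not part of DyND.

```python
# Decision tree copied from the output above, keyed by integer type id.
tree = ('typeid', 0,
        {6: ('typeid', 1, {6: ('match', 0)}),
         7: ('typeid', 1, {7: ('match', 1)}),
         19: ('typeid', 1, {19: ('match', 2)})})

def walk_tree(node, arg_ids):
    """Follow edges by argument type id until a leaf is reached.

    Returns the index of the matched signature, or None when an argument's
    type id has no outgoing edge.
    """
    while node[0] == 'typeid':
        _, i, children = node
        node = children.get(arg_ids[i])
        if node is None:
            return None
    if node[0] == 'match':
        return node[1]
    # 'ambiguous' and 'nomatch' leaves are not handled in this sketch
    raise LookupError('unexpected leaf: %r' % (node,))

print(walk_tree(tree, [6, 6]))    # the int8 signature
print(walk_tree(tree, [19, 19]))  # the float32 signature
print(walk_tree(tree, [7, 19]))   # no signature takes these argument ids
```

Walking nested tuples and dicts works, but each step chases pointers through Python objects, which is one motivation for flattening the tree into a linear instruction list.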


The final thing we want to do is convert the decision tree into some bytecode that can be executed quickly.
We'll model the bytecode as a list of functions that modify a simple machine state.

In [178]:
class MachineResult(Exception):
    def __init__(self, result):
        self.result = result

class MachineState:
    def __init__(self, argtypes):
        self.argtypes = argtypes
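    def run(self, instr):
        # Each instruction returns the index of the next instruction to run;
        # terminal instructions raise MachineResult to deliver the match.
        try:
            i = 0
            while True:
                i = instr[i](self)
        except MachineResult as e:
            return e.result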

def build_machine(decision_tree, instr=None):
    if instr is None:
        instr = []
        build_machine(decision_tree, instr)
        return instr
    i = len(instr)
    if(decision_tree[0] == 'match'):
        @withrepr('%d <return match %d>' % (i, decision_tree[1]))
        def _match(state):
            raise MachineResult(decision_tree[1])
        instr.append(_match)
    elif(decision_tree[0] == 'ambiguous'):
        @withrepr('%d <raise ambiguous %r>' % (i, decision_tree[1]))
        def _ambiguous(state):
            raise Exception('Ambiguous matches %s' % decision_tree[1])
        instr.append(_ambiguous)
    elif(decision_tree[0] == 'nomatch'):
        @withrepr('%d <raise nomatch>' % (i,))
        def _nomatch(state):
            raise Exception('No match found')
        instr.append(_nomatch)
    elif(decision_tree[0] == 'typeid'):
        instr.append(None)
        nextinstr = {tid: build_machine(decision_tree[2][tid], instr)
                     for tid in sorted(decision_tree[2])}
        @withrepr('%d <switch typeid[%d] {%s}>' %
                  (i, decision_tree[1], ', '.join('%d: %d' % (k, nextinstr[k]) for k in sorted(nextinstr))))
        def _typeid(state):
            tid = state.argtypes[decision_tree[1]].id
            if tid not in nextinstr:
                raise Exception('No match found')
            else:
                return nextinstr[tid]
        instr[i] = _typeid
    return i

In [179]:
instr = build_machine(build_decision_tree(sigs)); pprint(instr)

[0 <switch typeid[0] {6: 1, 7: 3, 19: 5}>,
 1 <switch typeid[1] {6: 2}>,
 2 <return match 0>,
 3 <switch typeid[1] {7: 4}>,
 4 <return match 1>,
 5 <switch typeid[1] {19: 6}>,
 6 <return match 2>]


Now we can run these instructions on a few different matching and non-matching argument types.

In [180]:
MachineState([ndt.int8, ndt.int8]).run(instr)

0

In [181]:
MachineState([ndt.float32, ndt.float32]).run(instr)

2

In [182]:
MachineState([ndt.int32, ndt.int32]).run(instr)

Exception: No match found
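
The instructions above are Python closures, which is convenient here but is not how bytecode is usually stored. As a rough sketch of an alternative, the same program can be written as plain data and run by a small loop. The tuple encoding and the name `run_program` are assumptions made for this illustration, and argument types are again represented by their integer type ids as printed above.

```python
# Hypothetical data-only encoding of the first program printed above:
#   ('switch', arg_index, {type_id: next_pc})   or   ('match', sig_index)
program = [('switch', 0, {6: 1, 7: 3, 19: 5}),
           ('switch', 1, {6: 2}),
           ('match', 0),
           ('switch', 1, {7: 4}),
           ('match', 1),
           ('switch', 1, {19: 6}),
           ('match', 2)]

def run_program(program, arg_ids):
    pc = 0  # index of the instruction being executed
    while True:
        op = program[pc]
        if op[0] == 'match':
            return op[1]
        _, i, table = op
        if arg_ids[i] not in table:
            raise LookupError('No match found')
        pc = table[arg_ids[i]]

print(run_program(program, [6, 6]))
print(run_program(program, [19, 19]))
```

With this layout, a type id that is missing from a switch table raises `LookupError`, mirroring the "No match found" exception of the closure-based machine.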

What we've created already works pretty well for simple scalar signatures. Let's give the signatures we used a slight twist.

In [183]:
sigs2 = [ndt.type("(int8, int8) -> int8"),
         ndt.type("(int16, int16) -> int16"),
         ndt.type("(float32, float32) -> float32"),
         ndt.type("(int16, float32) -> float32")]
instr = build_machine(build_decision_tree(sigs2)); pprint(instr)

[0 <switch typeid[0] {6: 1, 7: 3, 19: 6}>,
 1 <switch typeid[1] {6: 2}>,
 2 <return match 0>,
 3 <switch typeid[1] {7: 4, 19: 5}>,
 4 <return match 1>,
 5 <return match 3>,
 6 <switch typeid[1] {19: 7}>,
 7 <return match 2>]


This results in a slightly modified decision tree and machine, with a branch at the second level when the
first argument is `int16`.
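
Earlier in this section we deferred choosing which equation to process first. One possible heuristic, sketched here under the assumption that a good equation is one whose largest child bucket is smallest, can be tried without DyND by using strings as stand-in type ids. The names `largest_bucket` and `pick_equation` are hypothetical, and other cost measures may well suit the real algorithm better.

```python
from collections import Counter

def largest_bucket(sigs_working, i):
    # Size of the biggest group of live signatures sharing a type id at argument i
    counts = Counter(sig[i] for sig in sigs_working if sig is not None)
    return max(counts.values(), default=0)

def pick_equation(sigs_working, equations):
    # Prefer the argument whose worst-case child keeps the fewest signatures;
    # ties go to the equation listed first
    return min(equations, key=lambda i: largest_bucket(sigs_working, i))

# Every signature shares the first argument type, so only argument 1 discriminates
toy_sigs = [("int8", "int8"), ("int8", "int16"), ("int8", "float32")]
print(largest_bucket(toy_sigs, 0), largest_bucket(toy_sigs, 1))
print(pick_equation(toy_sigs, [0, 1]))
```

In this toy case, switching on argument 1 first separates all three signatures at once, whereas switching on argument 0 would leave them in a single bucket.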

## Signatures With Type Variables