In [1]:
# Convert the pattern that appears in ADT methods:
#
# if self.is_state1:
#     ...
# elif self.is_state2:
#     ...
# else:
#     ...
#    
# Into n implementations and distributed among different sub-classes

In [2]:
import inspect
import ast
import types
from textwrap import dedent
from sidekick import Union, opt, record

def example(foo, x, *args, y=None, **kwargs):
    if foo.is_just:
        first_expr()
        second_expr()
    elif foo.is_nothing:
        some_expr()
    else:
        ...

In [3]:
class Maybe(Union):
    Just = opt(object)
    Nothing = opt()

In [4]:
def extract_case(test):
    """
    Extracts (self variable name, case name) from test AST.
    
    AST represents code of the form: ``this.is_case``
    """
    if not isinstance(test, ast.Attribute):
        raise ValueError('case is not a simple attribute access')

    # Extract attr
    attr = test.attr
    if not attr.startswith('is_'):
        raise ValueError(f'not a valid case attribute name: {attr}')
    case = attr[3:]

    # Extract self name
    self_expr = test.value
    if not isinstance(self_expr, ast.Name):
        name = self_expr.__class__.__name__
        raise ValueError(f'invalid self expression: {name}')
    self_name = self_expr.id
    return (self_name, case)

In [5]:
def extract_method(body, func, py_func):
    """
    Extracts a method from a block of instructions, the original function
    expression and the original python function implementation.
    """
    func_ast = ast.FunctionDef(
        decorator_list=[], 
        name=func.name, args=func.args, body=body, 
        lineno=func.lineno, col_offset=func.col_offset,
    )
    mod = ast.Module(body=[func_ast])
    skip = py_func.__code__.co_firstlineno - 1
    mod = ast.Module(body=[func_ast])
    ast.increment_lineno(mod, skip)
    code = compile(mod, py_func.__code__.co_filename, 'exec', optimize=1)
    ns = {}
    exec(code, py_func.__globals__, ns)
    return ns[func.name]

In [6]:
def methods_map(function):
    """
    Convert case function into a mapping from cases to implementations.
    """
    methods = {}
    self_vars = set()

    tree = ast.parse(dedent(inspect.getsource(function)))
    func = tree.body[0]
    body = func.body

    # Check if body consists of a single If statement
    if not (len(body) == 1 and isinstance(body[0], ast.If)):
        raise ValueError('function body is not a single if statement')

    if_block = body[0]

    while if_block:
        self, case = extract_case(if_block.test)
        self_vars.add(self)
        method = extract_method(if_block.body, func, function)
        method.__name__ = f'{method.__name__}[{case}]'
        method.__qualname__ = f'{case}.{method.__qualname__}'
        methods[case] = method

        orelse = if_block.orelse
        if len(orelse) == 1 and isinstance(orelse[0], ast.If):
            if_block = orelse[0]
        elif not orelse:
            break
        else:
            method = extract_method(orelse, func, function)
            method.__name__ = f'{method.__name__}[else]'
            method.__qualname__ = f'else.{method.__qualname__}'
            methods['else'] = method
            break

    if len(self_vars) != 1:
        raise ValueError(f'inconsistent self variables: {self_vars}')

    return methods

In [7]:
def update_union(adt, attr, force=False):
    func = getattr(adt, attr)
    methods = methods_map(func)
    generic = methods.pop('else', None)
    cases = dict(adt._meta.cases)
    case_names = {x.lower(): x for x in cases}

    method_map = {case_names[case]: f for case, f in methods.items()}
    case_classes = set(cases.values())

    for name, func in method_map.items():
        mcs = cases[name]
        case_classes.remove(mcs)
        if attr not in mcs.__dict__ or force:
            setattr(mcs, attr, method_map[name])

    if generic:
        for mcs in case_classes:
            if attr not in mcs.__dict__ or force:
                setattr(mcs, attr, generic)

In [8]:
def case_method(mcs, name=None, force=False):
    # Decorating a method
    if isinstance(mcs, types.FunctionType):
        mcs.is_case_method = True
        return mcs
    
    def decorator(func):
        attr = name or func.__name__
        if hasattr(mcs, attr) and not force:
            raise TypeError(f'{mcs.__name__} already has method {attr}')
        setattr(mcs, attr, func)
        update_union(mcs, attr)
        return func
        
    return decorator


def optimize_case_methods(mcs):
    for attr in dir(mcs):
        value = getattr(mcs, attr, None)
        if isinstance(value, types.FunctionType) and getattr(value, 'is_case_method', False):
            update_union(mcs, attr)
    return mcs

In [9]:
@case_method(Maybe, force=True)
def foo(self):
    if self.is_just:
        return self.value + 2
    else:
        return 2
    
@case_method(Maybe, force=True)
def bar(self, x):
    if self.is_nothing:
        return x
    elif self.is_just:
        return 42
    

In [10]:
class Slist1(Union):
    Cons = opt(head=object, tail=object)
    Nil = opt()
    
    def size(self):
        if self.is_cons:
            return 1 + self.tail.size()
        else:
            return 0

In [11]:
@optimize_case_methods
class Slist2(Union):
    Cons = opt(head=object, tail=object)
    Nil = opt()
    
    @case_method
    def size(self):
        if self.is_cons:
            return 1 + self.tail.size()
        else:
            return 0

In [12]:
class Slist3(Union):
    class Cons(this):
        head: object
        tail: object 
            
        def size(self):
            return 1 + self.tail.size()
        
    class Nil(this):
        def size(self):
            return 0
    
    def generic(self):
        x = 0
        while self.is_cons:
            x += 1
            self = self.tail
        return x

In [13]:
@optimize_case_methods
class Slist4(Union):
    class Cons(this):
        head: object
        tail: object
    
    class Nil(this):
        pass
    
    @case_method
    def size(self):
        if self.is_cons:
            return 1 + self.tail.size()
        else:
            return 0

In [14]:
class Slist5:
    __slots__ = ()
    
class Cons(Slist5):
    __slots__ = ('head', 'tail')
    
    def __init__(self, head, tail):
        self.head = head
        self.tail = tail
    
    def size(self):
        return 1 + self.tail.size()

class Nil(Slist5):
    __slots__ = ()

    def size(self):
        return 0
    
Slist5.Cons = Cons
Slist5.Nil = Nil()

In [15]:
class Slist6:
    __slots__ = ()
    is_cons = False
    is_nil = False
    
    def size(self):
        if self.is_cons:
            return 1 + self.tail.size()
        else:
            return 0
    
class Cons(Slist6):
    __slots__ = ('head', 'tail')
    is_cons = True
    
    def __init__(self, head, tail):
        self.head = head
        self.tail = tail
    
class Nil(Slist6):
    __slots__ = ()
    is_nil = True

Slist6.Cons = Cons
Slist6.Nil = Nil()
Slist6._meta = record(cases={'Cons': Cons, 'Nil': Nil})

@case_method(Slist6)
def size2(self):
    if self.is_cons:
        return 1 + self.tail.size2()
    else:
        return 0

In [16]:
def to_list(mcs, data):
    return mcs.Cons(data[0], to_list(mcs, data[1:])) if data else mcs.Nil

In [17]:
n = 5
L1 = to_list(Slist1, list(range(n)))
L2 = to_list(Slist2, list(range(n)))
L3 = to_list(Slist3, list(range(n)))
L4 = to_list(Slist4, list(range(n)))
L5 = to_list(Slist5, list(range(n)))
L6 = to_list(Slist6, list(range(n)))

In [18]:
%timeit -n 100000 L1.size()
%timeit -n 100000 L2.size()
%timeit -n 100000 L3.size()
%timeit -n 100000 L4.size()
%timeit -n 100000 L5.size()
%timeit -n 100000 L6.size()
%timeit -n 100000 L6.size2()

2.12 µs ± 201 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.87 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.02 µs ± 282 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.82 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
907 ns ± 57.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.12 µs ± 73.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
879 ns ± 28.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [19]:
%timeit -n 1000000 L1.head
%timeit -n 1000000 L2.head
%timeit -n 1000000 L3.head
%timeit -n 1000000 L4.head
%timeit -n 1000000 L5.head
%timeit -n 1000000 L6.head

227 ns ± 8.83 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
257 ns ± 17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
255 ns ± 5.75 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
260 ns ± 22.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
56.1 ns ± 2.23 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
63.4 ns ± 3.23 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [20]:
%timeit -n 1000000 L1.tail
%timeit -n 1000000 L2.tail
%timeit -n 1000000 L3.tail
%timeit -n 1000000 L4.tail
%timeit -n 1000000 L5.tail
%timeit -n 1000000 L6.tail

236 ns ± 25.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
240 ns ± 13.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
261 ns ± 11.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
248 ns ± 14 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
56.5 ns ± 5.28 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
57.9 ns ± 5.77 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [21]:
%timeit -n 1000000 L1.size
%timeit -n 1000000 L2.size
%timeit -n 1000000 L3.size
%timeit -n 1000000 L4.size
%timeit -n 1000000 L5.size
%timeit -n 1000000 L6.size

73.1 ns ± 12.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
75.1 ns ± 7.48 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
72.4 ns ± 3.61 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
73.7 ns ± 8.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
68.1 ns ± 7.82 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
71.8 ns ± 8.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [22]:
import dis
dis.dis(Slist5.Cons.size)

print()
dis.dis(Slist6.Cons.size2)

 12           0 LOAD_CONST               1 (1)
              2 LOAD_FAST                0 (self)
              4 LOAD_ATTR                0 (tail)
              6 LOAD_METHOD              1 (size)
              8 CALL_METHOD              0
             10 BINARY_ADD
             12 RETURN_VALUE

 31           0 LOAD_CONST               1 (1)
              2 LOAD_FAST                0 (self)
              4 LOAD_ATTR                0 (tail)
              6 LOAD_METHOD              1 (size2)
              8 CALL_METHOD              0
             10 BINARY_ADD
             12 RETURN_VALUE


In [23]:
f2 = Slist2.Nil.size
f3 = Slist3.Nil.size
c2 = f2.__code__
c3 = f3.__code__

In [24]:
d2 = {attr: getattr(c2, attr) for attr in dir(c2) if attr.startswith('co_')}
d3 = {attr: getattr(c3, attr) for attr in dir(c3) if attr.startswith('co_')}

In [25]:
{k1: (v1, v2) for (k1, v1), (k2, v2) in zip(d2.items(), d3.items()) if v2 != v1}

{'co_filename': ('<ipython-input-11-b95d254fa8b0>',
  '<ipython-input-12-027eeb8dd036>'),
 'co_firstlineno': (6, 10),
 'co_lnotab': (b'\x00\x05', b'\x00\x01')}

In [26]:
f = lambda self: 1 + self.tail.size()
Slist3.Cons.size = f 
dis.dis(f)

  1           0 LOAD_CONST               1 (1)
              2 LOAD_FAST                0 (self)
              4 LOAD_ATTR                0 (tail)
              6 LOAD_METHOD              1 (size)
              8 CALL_METHOD              0
             10 BINARY_ADD
             12 RETURN_VALUE


In [27]:
ns2 = Slist2.Cons.__dict__
ns3 = Slist3.Cons.__dict__
ns4 = Slist4.Cons.__dict__

In [28]:
key = 'size'
%timeit -n 1000000 ns2[key]
%timeit -n 1000000 ns3[key]
%timeit -n 1000000 ns4[key]

print()
key = 'head'
%timeit -n 1000000 ns2[key]
%timeit -n 1000000 ns3[key]
%timeit -n 1000000 ns4[key]

73.2 ns ± 6.69 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
74.8 ns ± 3.08 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
105 ns ± 21 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

75.1 ns ± 6.56 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
71.6 ns ± 2.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
80 ns ± 8.92 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [29]:
n2.keys(), ns3.keys()

NameError: name 'n2' is not defined

In [None]:
exec?

In [None]:
f = lambda x: x
x = 42.0
%timeit f(...)
%timeit ... if x.real else ...
%timeit ...