In [1]:
import asdl
import sys
sys.path.append('../')
from src.adt import ADT

Frequently when working with intermediate representations in a compiler, we want each node to have a unique identity: constructing a structurally identical node twice should give back the very same object.  We can do this very neatly using memoization.  Therefore, it's worth the trouble to see if we can incorporate memoization directly into the constructors from the ASDL module.

One complication we will run into is that a memoization `dict` will, by default, hold a strong reference to every object it has constructed, so none of them can ever be garbage collected.  This is bad behavior — essentially a memory leak.  It can be avoided by treating all references from the memoization `dict` as weak, so that they are not—in and of themselves—sufficient to keep the object alive.

In [2]:
from weakref import WeakValueDictionary

Let us start with an example IR:

In [3]:
# Example IR: a small expression language written in ASDL.  `ADT` (from
# src.adt) turns it into `P`, whose constructor classes (e.g. `P.Var`,
# `P.Mul`) are used throughout the rest of this notebook.
P = ADT("""
module P
{
  expr = Var(string name)
       | Const(float val)
       | Add(expr lhs, expr rhs)
       | Mul(expr lhs, expr rhs)
}
""")

Now consider the expression `(x*x)*(x*x)`.  If properly memoized, there should be exactly 3 nodes: `x`, `x*x` and the final product.  First, let's consider memoizing `Var`.  We will try a wrapper function.

In [4]:
# Cache of name -> Var node.  Values are held weakly, so an entry disappears
# once nothing else references that node.
var_v0_cache = WeakValueDictionary()

def memo_var_v0(name):
    """Return the live `P.Var` node for `name`, creating and caching it if needed."""
    v = var_v0_cache.get(name)
    # `is None` rather than `== None`: `==` would dispatch to the node's
    # __eq__, which need not be an identity test.
    if v is None:
        v = P.Var(name)
        var_v0_cache[name] = v
    return v

In [5]:
x  = memo_var_v0('x')
x0 = memo_var_v0('x')
# Memoization promises *identity*, so check with `is`; `==` could also be
# True for two distinct nodes if the node class defines value equality.
x is x0

True

We could even hide the fact that we've done this manipulation.

In [6]:
# Shadow the name `Var` with the memoizing wrapper function...
Var = memo_var_v0
# ...but `Var` is now a function, not the class of the nodes it returns,
# so this type check evaluates to False.
type(x) is Var

False

Hmmm, except that causes another problem: `Var` is now a plain function rather than the class of the nodes it returns, so type checks like the one above fail.  Does this mean that we will have to use special constructor functions whenever we want to memoize IR objects?  Possibly not.  To avoid that, we need a way to hijack the object creation process and (when a cache lookup succeeds) return a different object than the one that would normally be created.

Python gives us a way to do this via the `__new__` method on classes, which is called to create the instance before `__init__` runs.  We could directly muck with `Var.__new__`!

In [7]:
# Give P.Var its own weak-valued cache: name -> live Var instance.
P.Var._memo_cache = WeakValueDictionary()

def _Var_new_memo(cls, name):
    """Replacement `__new__` for P.Var that reuses the live instance for `name`.

    On a cache miss a fresh instance is allocated via the next `__new__` in
    the MRO and recorded in `P.Var._memo_cache`.
    """
    v = P.Var._memo_cache.get(name)
    # identity test: `== None` would dispatch to the node's __eq__
    if v is None:
        v = super(P.Var, cls).__new__(cls)
        P.Var._memo_cache[name] = v
    return v

P.Var.__new__ = _Var_new_memo

In [8]:
x  = P.Var('x')
x0 = P.Var('x')
# Both are still genuine P.Var instances...
print(type(x) is P.Var, type(x0) is P.Var)
# ...and, thanks to the patched __new__, the very same object.
x is x0

True True


True

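Now, one question is whether we waste time re-initializing the object, and what that does.  (When `__new__` returns an instance of the class, Python still calls `__init__` on it, even if that instance came from the cache.)  Let's create our own little test object to see.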

In [9]:
class MyVar:
    """Toy class that memoizes its instances by `name` via `__new__`.

    `__new__` returns the live cached instance for `name` when there is one.
    Python then still calls `__init__` on it; the print in `__init__` makes
    each (re-)initialization visible.
    """
    # name -> instance; weak values, so the cache alone keeps nothing alive
    _memo_cache = WeakValueDictionary()

    def __new__(cls, name):
        v = MyVar._memo_cache.get(name)
        # `is None`, not `== None`: a subclass may override __eq__
        if v is None:
            v = super().__new__(cls)
            MyVar._memo_cache[name] = v
        return v

    def __init__(self, name):
        print('init with name: ' + name)
        self.name = name

# First construction: cache miss, so a new object is allocated and initialized.
x = MyVar('x')
# Second construction: __new__ returns the cached object (the first `x` is
# still alive during the call), yet __init__ runs on it again.
x = MyVar('x')
# An attribute added after construction...
x.foo = 3
print(x.__dict__)
# ...is still there when the same key is constructed again, because `y` is
# the very same object as `x`.
y = MyVar('x')
print(y.__dict__)

init with name: x
init with name: x
{'name': 'x', 'foo': 3}
init with name: x
{'name': 'x', 'foo': 3}


So, we are definitely "wasting" time doing re-initialization.  However, we're also not disrupting anything else about the object: `__init__` re-assigns `name` to the same value, and other attributes (like `foo`) survive.  So maybe that's ok for a first pass at this.

We now have a strategy for how to sneak the memoization into the constructor classes pretty transparently.  However, we need a way to code-generate the appropriate memoization for potentially multi-argument constructors.  This also means the memoization dictionary needs keys built from several arguments of different kinds.  Let's investigate how that might work out.

In [10]:
dd = {}
# Two separately written equal tuples: their ids are not guaranteed to match...
print(id( (2,5) ))
print(id( (2,5) ))
# ...yet an equal tuple still finds the entry, because tuples hash and
# compare by value.
dd[(2,5)] = 24
print(dd.get( (2,5) ))

4438257096
4438253896
24


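Tuples don't have stable `id`s — two separately constructed tuples with equal contents are generally distinct objects — but they behave well as keys, because equal tuples hash and compare equal.  Other objects generally do not compare by value (and some are not hashable at all), so we key them by their `id`, i.e. by identity.  Storing the `id` rather than the object also keeps the (strongly held) dictionary keys from keeping those objects alive.  It's a bit confusing.  But some searching around reveals that the following types should work fine as dictionary keys.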

In [11]:
# Confirm the exact type objects that `_to_key` below compares against
# (`type(v) is ...`, so subclasses of these types will not match).
( type(3) is int,
  type(5.4) is float,
 type("ibasdf") is str,
 type(True) is bool,
 type((1,2)) is tuple )

(True, True, True, True, True)

Therefore, let's build a `_to_key` helper function.

In [12]:
def _to_key(v):
    tv = type(v)
    if ( tv is int or tv is float or
         tv is str or tv is bool or
         tv is tuple ):
        return v
    else:
        return id(v)

We need to deal with optional values as well.  What happens if we try to memoize on `None`?

In [13]:
dd = {}
# None is hashable, so it can be a key on its own...
dd[None] = 3
print(dd[None])
# ...and inside tuple keys, which is how absent optional fields are keyed.
dd[(None,None)] = 4
print(dd[(None,None)])

3
4


And what about sequences (i.e. lists)?  Lists are unhashable, so they do not work as keys.  Therefore, we need a special way to memoize them.  It turns out that we can just convert them to tuples (whoa!)

In [14]:
# Two distinct lists with equal contents.
xs = [1,2,3]
ys = [1,2,3]
dd = {}
# Lists are unhashable, but their tuple conversions hash and compare by
# value, so either list finds the same entry.
dd[tuple(xs)] = 5
print(dd.get(tuple(xs)))
print(dd.get(tuple(ys)))

5
5


Now we're set to create proper, robust memoization support for the IR.

-----

In [21]:
_builtin_keymap = {
    'string'  : lambda x: x,
    'int'     : lambda x: x,
    'object'  : id,
    'float'   : lambda x: x,
    'bool'    : lambda x: x,
}

def _add_memoization(mod,whitelist,ext_key):
    asdl_mod = mod._ast
    
    keymap = _builtin_keymap.copy()
    for nm,fn in ext_key.items():
        keymap[nm] = fn
    for nm in asdl_mod.types:
        keymap[nm] = id
    
    def create_listkey(f):
        i = 'i' if f.name != i else 'ii'
        return (f"tuple([ K['{f.type}']({i}) "
                f"for {i} in {f.name} ]),")
    def create_optkey(f):
        return (f"None if {f.name} == None else "
                f"K['{f.type}']({f.name}),")
    
    def create_newfn(name, fields):
        if not name in whitelist: return
        T       = getattr(mod,name)
        
        argstr  = ', '.join([ f.name for f in fields ])
        keystr  = '('+(''.join([
            create_listkey(f) if f.seq else
            create_optkey(f)  if f.opt else
            f"K['{f.type}']({f.name}),"
            for f in fields
        ]))+')'
        
        exec_out = { 'T': T, 'K': keymap }
        exec_str = (f"def {name}_new(cls,{argstr}):\n"
                    f"    key = {keystr}\n"
                    f"    val = T._memo_cache.get(key)\n"
                    f"    if val == None:\n"
                    f"        val = super(T,cls).__new__(cls)\n"
                    f"        T._memo_cache[key] = val\n"
                    f"    return val")
        # un-comment this line to see what's
        # really going on
        print(exec_str)
        exec(exec_str, exec_out)
        
        T._memo_cache = WeakValueDictionary({})
        T.__new__     = exec_out[name + '_new']
        
    def expand_sum(typ_name,t):
        T          = getattr(mod,typ_name)
        afields    = t.attributes
        for c in t.types:
            create_newfn(c.name, c.fields + afields)
    
    for nm,t in asdl_mod.types.items():
        if isinstance(t,asdl.Product):
            create_newfn(nm,t.fields)
        elif isinstance(t,asdl.Sum):
            expand_sum(nm,t)
        else: assert false, "unexpected kind of asdl type"

def memo(mod, whitelist, ext_key={}):
    """Memoize the whitelisted constructors of ADT module `mod`, in place.

    Thin public wrapper around `_add_memoization`.  `ext_key` maps extra
    field-type names to key functions; it is only read, never mutated, so
    the shared default dict is harmless here.
    """
    _add_memoization(mod=mod, whitelist=whitelist, ext_key=ext_key)

In [23]:
memo(P, ['Var', 'Const', 'Add', 'Mul'])

x = P.Var('x')
xx = P.Mul(x, x)
xxxx = P.Mul(xx, xx)
print(id(xxxx))
print(id(xxxx.lhs), id(xxxx.rhs))
print(id(xxxx.lhs.lhs), id(xxxx.lhs.rhs), id(xxxx.rhs.lhs), id(xxxx.rhs.rhs))

# Check the claim directly: (x*x)*(x*x) is built from exactly 3 nodes.
assert xxxx.lhs is xxxx.rhs is xx
assert xx.lhs is xx.rhs is x
assert len({id(n) for n in (x, xx, xxxx)}) == 3

def Var_new(cls,name):
    key = (K['string'](name),)
    val = T._memo_cache.get(key)
    if val == None:
        val = super(T,cls).__new__(cls)
        T._memo_cache[key] = val
    return val
def Const_new(cls,val):
    key = (K['float'](val),)
    val = T._memo_cache.get(key)
    if val == None:
        val = super(T,cls).__new__(cls)
        T._memo_cache[key] = val
    return val
def Add_new(cls,lhs, rhs):
    key = (K['expr'](lhs),K['expr'](rhs),)
    val = T._memo_cache.get(key)
    if val == None:
        val = super(T,cls).__new__(cls)
        T._memo_cache[key] = val
    return val
def Mul_new(cls,lhs, rhs):
    key = (K['expr'](lhs),K['expr'](rhs),)
    val = T._memo_cache.get(key)
    if val == None:
        val = super(T,cls).__new__(cls)
        T._memo_cache[key] = val
    return val
4444128872
4443989720 4443989720
4444131168 4444131168 4444131168 4444131168
