In [1]:
import sys
sys.path.append('../')
sys.path.append('../src/')
from adt import ADT
from adt import memo as ADTmemo
from atlv0 import IR as IRv0
from atlv0 import Func
import ctypes
import numpy as np
import time
from halide import *

Recall our Tensor Language:

$$
\newcommand{\or}{\ |\ }
\begin{array}{rl}
  e &= x \or c \or e_0 + e_1 \or e_0 \cdot e_1 \\
  &\or \sum_i e \or \boxplus_i e \or e_i \or [i=j] \\
  &\or (e_0,e_1) \or \pi_0 e \or \pi_1 e \\
\end{array}
$$

When we last left off, we had designed an IR, with:
* some operator overloading to make input slightly more convenient
* string and LaTeX display routines
* typechecking
* a slow Python interpreter

We took a detour to wrap the Halide language in Python.  Now that we've done that, we can import Halide as a high-quality, JIT-ready code generator.  The main goal of this notebook is to work out the simplest, most direct way to compile the tensor language to that target.

### Note...
In transferring the tensor language work to a Python script in the `src/` directory, `T0` was renamed to `IR`.  Some other changes were also made to clean up the results and the function interface, moving them toward a more canonical usage pattern.  I will not review all of those changes here.

-----

# NumPy Arrays

In Python, large amounts of numeric data are most commonly stored, managed and accessed via `NumPy`.  For this reason, we'll try to bind in these arrays as the representation of our tensor data.  Doing so will give us a high degree of interoperability with the Python ecosystem.

To begin, let's just restate some of the constants we were working with before.

In [2]:
xv = np.array([4.,7.,1.], order='F')
Av = np.array([[5.,2.,0.],[2.2,0.,4.5],[0.,6.1,3.3]], order='F')
cv = np.array([0.,0.,1.], order='F')

xv, Av, cv

(array([4., 7., 1.]), array([[5. , 2. , 0. ],
        [2.2, 0. , 4.5],
        [0. , 6.1, 3.3]]), array([0., 0., 1.]))

Now, observe the following metadata about the nd-array that we can retrieve via introspection.

In [3]:
def nparr_data(a):
    print(f'flags:    {a.flags}')
    print(f'ndim:     {a.ndim}')
    print(f'shape:    {a.shape}')
    print(f'strides:  {a.strides}')
    print(f'data:     {a.data}')
    print(f'size:     {a.size}')
    print(f'itemsize: {a.itemsize}')
    print(f'nbytes:   {a.nbytes}')
    
#nparr_data(xv)
nparr_data(Av)
#nparr_data(cv)

flags:      C_CONTIGUOUS : False
  F_CONTIGUOUS : True
  OWNDATA : True
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False
ndim:     2
shape:    (3, 3)
strides:  (8, 24)
data:     <memory at 0x10e527990>
size:     9
itemsize: 8
nbytes:   72


`ndim` gives us the order of the tensor, which NumPy calls _the number of dimensions_.  `shape` then gives us those _dimensions_ themselves, so this says we have a $3\times 3$ matrix.  The `strides` matter little for our conceptual picture, but they are essential for translating between mathematical indices and the actual addresses of data.  They play the same role as the strides in Halide's buffers.  Great!  This should make things easy.

`size` is just the product of the dimensions, and `itemsize` tells us how many bytes each entry takes.  *Note* that the strides above are given in bytes, whereas Halide's buffers expect strides counted in elements, i.e. divided by `itemsize`.  That is, we would describe the strides here to Halide as `(8/8,24/8)`, i.e. `(1,3)`.  `nbytes` might be useful if we start doing manual memory management (e.g. allocating data on the GPU).
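To make the byte-versus-element distinction concrete, here is a small sketch using only NumPy; `element_strides` is a helper name introduced for illustration. It converts byte strides into element strides and then uses them to locate an entry in the underlying buffer.

```python
import numpy as np

def element_strides(a):
    """Convert NumPy's byte strides into element strides (hypothetical helper)."""
    # a byte stride that is not a whole number of elements has no element stride
    assert all(s % a.itemsize == 0 for s in a.strides)
    return tuple(s // a.itemsize for s in a.strides)

Av = np.array([[5., 2., 0.], [2.2, 0., 4.5], [0., 6.1, 3.3]], order='F')
print(Av.strides, element_strides(Av))       # (8, 24) (1, 3)

# The same values laid out in C order get the transposed strides.
Ac = np.ascontiguousarray(Av)
print(Ac.strides, element_strides(Ac))       # (24, 8) (3, 1)

# Entry (i, j) lives at element offset i*stride_0 + j*stride_1 from the start.
# For an F-ordered array, memory order is the column-major flattening.
i, j = 1, 2
offset = sum(idx * s for idx, s in zip((i, j), element_strides(Av)))
flat = Av.ravel(order='F')
print(offset, flat[offset], Av[i, j])        # 7 4.5 4.5
```

The `(3,1)` pattern belongs to the C-ordered copy; for the F-ordered `Av` the first index is the unit-stride one.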

Next, consider the flags.  These are _not_ the same as `halide_buffer_t.flags`, which appear to be used just for keeping track of whether data is dirty.  Here, the `C_CONTIGUOUS` and `F_CONTIGUOUS` flags are redundant: they can be recomputed from the shape and strides, so a more thorough analysis of the strided layout subsumes them.
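As a sketch of that redundancy, the function below (a hypothetical helper, not part of NumPy) recomputes both contiguity flags from `shape`, `strides` and `itemsize` alone and compares them against NumPy's own flags on a few arrays. It skips extent-1 dimensions, whose stride never affects an address, and does not handle zero-size arrays.

```python
import numpy as np

def is_contiguous(shape, strides, itemsize, order='C'):
    """Recompute a contiguity flag from shape and byte strides (illustrative sketch)."""
    dims = list(zip(shape, strides))
    if order == 'C':
        dims.reverse()          # in C order the last index varies fastest
    expected = itemsize         # a dense layout steps one item in its fastest dimension
    for extent, stride in dims:
        if extent != 1 and stride != expected:
            return False
        expected *= extent
    return True

Av = np.array([[5., 2., 0.], [2.2, 0., 4.5], [0., 6.1, 3.3]], order='F')
samples = [Av, np.ascontiguousarray(Av), Av[:, ::2]]   # F, C, and neither
for a in samples:
    ours = (is_contiguous(a.shape, a.strides, a.itemsize, 'C'),
            is_contiguous(a.shape, a.strides, a.itemsize, 'F'))
    numpys = (bool(a.flags['C_CONTIGUOUS']), bool(a.flags['F_CONTIGUOUS']))
    print(a.strides, ours, numpys)
```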

`OWNDATA` tells us whether the array owns its memory; when it is `False`, NumPy is wrapping data borrowed from another source, such as another array (a view) or an externally allocated buffer.  For instance, we could choose to manage memory ourselves but expose NumPy handles onto our internally managed memory.  (There are some good reasons to do this, but it's more responsibility than we want right now.)

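`WRITEABLE` and `WRITEBACKIFCOPY` control writing policies. (`UPDATEIFCOPY` is deprecated.)  `WRITEBACKIFCOPY` marks an array that is a temporary copy of some base array, behaving much like a write-back cache in a memory hierarchy: the copy's contents are written back to the base array when the copy is resolved, which NumPy's C API requires (via `PyArray_ResolveWritebackIfCopy`) before the copy is deallocated.  The implication is that such an array does _not_ own the data it ultimately writes to.  `WRITEABLE` is... well, it says whether the array is writable.  Importantly, clearing it allows data to be wrapped as read-only.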
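A short sketch of `OWNDATA` and `WRITEABLE` in practice, using only standard NumPy calls: a view and an array wrapped around a `bytearray` both borrow memory, and clearing `writeable` on a view exposes the same data read-only without affecting the original array.

```python
import numpy as np

Av = np.array([[5., 2., 0.], [2.2, 0., 4.5], [0., 6.1, 3.3]], order='F')
print(Av.flags['OWNDATA'])                      # True: Av allocated its own buffer

col = Av[:, 1]                                  # a view borrowing Av's memory
print(col.flags['OWNDATA'], col.base is Av)     # False True

raw = bytearray(8 * 3)                          # memory allocated outside NumPy
wrapped = np.frombuffer(raw, dtype=np.float64)  # NumPy borrows it
print(wrapped.flags['OWNDATA'])                 # False

ro = Av.view()                                  # same data, separate flags
ro.flags.writeable = False                      # expose it as read-only
try:
    ro[0, 0] = 1.0
    rejected = False
except ValueError:
    rejected = True
print(rejected, Av.flags['WRITEABLE'])          # True True
```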

Lastly `ALIGNED` tells us whether the data is allocated in a way that will play nicely with the memory/cache system.  (Does Halide ensure anything about this?)

What's up with `data`?

In [4]:
type(Av.data)

memoryview

That's not a pointer...  Well, we can get a `uint8_t *` pointer pretty simply.

In [5]:
Av.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))

<halide.LP_c_ubyte at 0x10e531ea0>
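The pointer above is typed as `uint8_t *`.  As a sketch of what such a pointer gives us, the snippet below (plain `ctypes` and NumPy, no Halide) asks for a `double *` instead, reads the buffer in memory order, and confirms that the pointer aliases the array rather than copying it.  That aliasing is what would let a Halide buffer built this way read and write NumPy's data in place.

```python
import ctypes
import numpy as np

Av = np.array([[5., 2., 0.], [2.2, 0., 4.5], [0., 6.1, 3.3]], order='F')

# A typed pointer (double *) onto the memory NumPy manages; no copy is made.
p = Av.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

# Indexing the raw pointer walks memory order: column by column for order='F'.
first_column = [p[k] for k in range(3)]
print(first_column)                              # [5.0, 2.2, 0.0]

# The pointer aliases the array, so writes through it are visible in Av.
p[0] = 5.5
print(Av[0, 0])                                  # 5.5

# The integer address behind the pointer matches Av.ctypes.data.
print(ctypes.cast(p, ctypes.c_void_p).value == Av.ctypes.data)   # True
```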

And the type of the array?

In [6]:
Av.dtype, type(Av.dtype)

(dtype('float64'), numpy.dtype)

In [7]:
Av.dtype.type, type(Av.dtype.type)

(numpy.float64, type)

Poking around reveals that the following NumPy types correspond to all the basic types we need.

In [8]:
(np.int8, np.int16, np.int32, np.int64,
 np.uint8, np.uint16, np.uint32, np.uint64,
 np.float32, np.float64)

(numpy.int8,
 numpy.int16,
 numpy.int32,
 numpy.int64,
 numpy.uint8,
 numpy.uint16,
 numpy.uint32,
 numpy.uint64,
 numpy.float32,
 numpy.float64)

In [9]:
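# np.float and np.int were plain aliases of the builtins; NumPy 1.20 deprecated
# them and NumPy 1.24 removed them, so this line fails on current NumPy versions.
np.float is float, np.float32, np.double, np.byte, np.int is int, np.uint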

(True, numpy.float32, numpy.float64, numpy.int8, True, numpy.uint64)

## Wrap a NumPy Array for Halide

Recall the `halide_buffer_t` structure, along with the defaults we worked out for some of its fields.

```
struct halide_buffer_t {
    uint64_t             device;            // default: 0
    void *               device_interface;  // default: NULL
    uint8_t *            host;
    uint64_t             flags;             // default: 0
    halide_type_t        type;
    int32_t              dimensions;
    halide_dimension_t * dim;               // array of `dimensions` entries
    void *               padding;           // default: NULL
};
```

`host` will receive the pointer we just sketched out, `type` will need to be introspected on, and `dim` ought to be deducible from the shape and strides.

In [10]:
def ndarray_to_halide_buf(a):
    def typ_convert(dt):
        t = dt.type
        # remapping to prevent some pointless errors
        if t is float:
            t = np.float64 if (sys.float_info.max_exp == 1024) else np.float32
        if t is int:   t = np.int32
        # main case switch
        if   t is np.int8:    return halide_type_t(C.type_int,8,1)
        elif t is np.int16:   return halide_type_t(C.type_int,16,1)
        elif t is np.int32:   return halide_type_t(C.type_int,32,1)
        elif t is np.int64:   return halide_type_t(C.type_int,64,1)
        elif t is np.uint8:   return halide_type_t(C.type_uint,8,1)
        elif t is np.uint16:  return halide_type_t(C.type_uint,16,1)
        elif t is np.uint32:  return halide_type_t(C.type_uint,32,1)
        elif t is np.uint64:  return halide_type_t(C.type_uint,64,1)
        elif t is np.float32: return halide_type_t(C.type_float,32,1)
        elif t is np.float64: return halide_type_t(C.type_float,64,1)
        else:
            raise TypeError(f"unexpected type {t}")
    
    buf  = halide_buffer_t()
    buf.device              = 0
    buf.device_interface    = None
    buf.host                = a.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))
    buf.flags               = 0
    buf.type                = typ_convert(a.dtype)
    buf.dimensions          = a.ndim
    buf.dim                 = (halide_dimension_t * a.ndim)()
    # now loop through and sort out each dimension
    for k in range(0,a.ndim):
        # NumPy strides count bytes; Halide strides count elements
        assert a.strides[k] % a.itemsize == 0
        s = a.strides[k] // a.itemsize
        buf.dim[k] = halide_dimension_t(0,a.shape[k],s,0)
    buf.padding             = None
    
    return buf


In [11]:
Ahv = ndarray_to_halide_buf(Av)
Ahv

<halide.halide_buffer_t at 0x10bac4d90>
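The `if`/`elif` chain in `typ_convert` could also be driven by the dtype's `kind` and `itemsize` attributes.  Below is a sketch of that alternative.  The helper name is hypothetical, and the numeric codes are assumed to mirror Halide's `halide_type_code_t` enum (int = 0, uint = 1, float = 2), so check them against `HalideRuntime.h` before relying on them.  It returns plain `(code, bits, lanes)` tuples rather than `halide_type_t` objects so that it runs without the Halide wrapper.

```python
import numpy as np

# Assumed to mirror halide_type_code_t (verify against HalideRuntime.h).
TYPE_INT, TYPE_UINT, TYPE_FLOAT = 0, 1, 2
_KIND_TO_CODE = {'i': TYPE_INT, 'u': TYPE_UINT, 'f': TYPE_FLOAT}

def dtype_to_halide_triple(dt):
    """Map a NumPy dtype to a (code, bits, lanes) triple (hypothetical helper)."""
    dt = np.dtype(dt)
    if dt.kind not in _KIND_TO_CODE or dt.itemsize > 8:
        raise TypeError(f"unexpected type {dt}")
    if not dt.isnative:
        # assumption: the generated code reads host memory in native byte order
        raise TypeError(f"non-native byte order in {dt}")
    return (_KIND_TO_CODE[dt.kind], 8 * dt.itemsize, 1)   # scalar: 1 lane

print(dtype_to_halide_triple(np.float64))   # (2, 64, 1)
print(dtype_to_halide_triple(np.uint8))     # (1, 8, 1)
try:
    dtype_to_halide_triple(np.complex128)
    complex_rejected = False
except TypeError:
    complex_rejected = True
print(complex_rejected)                     # True
```

Unlike the chain above, this version also accepts `np.float16`; whether that is desirable is a judgment call.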

----

# Revisiting the IR

Recall the definition of our IR.

In [12]:
print(IRv0._defstr)


module IR_v0 {
    expr = Var      ( string name )
         | Const    ( float  val  )
         | Add      ( expr lhs, expr rhs )
         | Mul      ( expr lhs, expr rhs )
         | Pair     ( expr lhs, expr rhs )
         | Proj     ( int01 idx, expr arg )
         | Gen      ( string idxname, int range, expr body )
         | Sum      ( string idxname, int range, expr body )
         | Access   ( expr  base, index idx )
         -- implied multiplication of the bracket with body
         | Indicate ( pred  arg, expr body )
         -- important to express sharing of computation results
         | Let      ( string name, expr rhs, expr body )
    
    -- indices are drawn from a range s.t.
    -- 0 <= i < range
    
    pred    = Eq( index lhs, index rhs )
    
    type    = TNum    ()
            | TError  ()
            | TPair   (  type lhs, type rhs )
            | TTensor ( int range, type typ )
}



## Eliminating Pairs... Detour into Typechecking

Halide does not have a first-class construct for pair-values.  Therefore, we need to eliminate our dependence on pairs (as much as possible).  Doing this will require pushing $\pi_k$ down, and pulling $(\ ,\ )$ up.  How does this work?  We must figure out how to exchange ("commute") these two operations with every other operation.

The most distinctive rule is the one where the two forms cancel against each other:

$$ \pi_k ( e_0, e_1 ) \to e_k $$

Besides that, we have a set of useful rules:

$$
\begin{array}{rcl}
\pi_k [i=j] e &\to& [i=j] \pi_k e \\
\pi_k \boxplus_i e &\to& \boxplus_i \pi_k e \\
\pi_k e[i] &\to& (\pi_k e)[i] \\
\end{array}
$$

Typing tells us we don't have to worry about projecting constants, additions, multiplications, or big-sums: those always produce numbers, so a projection applied to them would be ill-typed.  But by that same reasoning, we oughtn't be able to exchange projection with indexing or array generation.  Those rewrites change the type!  Likewise, we don't know what to do about `Let` or `Var`.  These examples reveal a deficiency of the prior IR: it lacks typing information.  What we need to do is develop a new _typed_ IR, and revisit type-checking to convert one IR to the other.


------
When we change the IR below, we'll take the opportunity to replace name `string`s with `Sym`bol objects.  Doing this will allow us to create copies of names with the same original "string name" but distinct identities.  Generally, it's easier to get program re-writes correct once you adopt identifiers-as-symbols instead of identifiers-as-strings.

In [13]:
class Sym:
    unq_count   = 1
    
    def __init__(self,nm):
        self._nm    = nm
        self._id    = Sym.unq_count
        Sym.unq_count += 1
    
    def __str__(self):
        return self._nm
    
    def __repr__(self):
        return f"{self._nm}${self._id}"
    
    def name(self):
        return self._nm
    
    def copy(self):
        return Sym(self._nm)
    

IR = ADT("""
module IR {
    expr = Var      ( symbol name )
         | Const    ( float  val  )
         | Add      ( expr lhs, expr rhs )
         | Mul      ( expr lhs, expr rhs )
         | Pair     ( expr lhs, expr rhs )
         | Proj     ( int01 idx, expr arg )
         | Gen      ( symbol idxname, int range, expr body )
         | Sum      ( symbol idxname, int range, expr body )
         | Access   ( expr  base, index idx )
         -- implied multiplication of the bracket with body
         | Indicate ( pred  arg, expr body )
         -- important to express sharing of computation results
         | Let      ( symbol name, expr rhs, expr body )
         attributes( type typ )
    
    -- indices are drawn from a range s.t.
    -- 0 <= i < range
    
    pred    = Eq( index lhs, index rhs )
}
""", {
    'int01':  lambda x:  x == 0 or x == 1,
    'index':  lambda x:  (type(x) is int) or (type(x) is Sym),
    'symbol': lambda x: type(x) is Sym,
    'type':   lambda x:  isinstance(x, IRv0.type)
})
# copy over types from the old module...
for nm in ['tnum','terr','type','TNum','TError','TPair','TTensor']:
    setattr(IR,nm,getattr(IRv0,nm))
del nm


In [14]:
from atlv0 import TCError
from atlv0 import _Context

class TypeChecker:
    def __init__(self, expr, initsymtyps):
        self._ctxt   = _Context()
        for nm,symtyp in initsymtyps.items():
            assert type(symtyp[0]) is Sym
            assert isinstance(symtyp[1], IR.type)
            self._ctxt.set(nm,symtyp)
        self._errors = []
        self._out_ir = self.check(expr)
        self.report_errors()

    def get_out_ir(self):
        return self._out_ir
    
    def _err(self, node, msg):
        # might want to discern location
        # via `node` eventually
        self._errors.append(msg)
    
    def report_errors(self):
        if len(self._errors) > 0:
            raise TCError('Found errors during typechecking:\n  '+
                          '\n  '.join(self._errors))
    
    def _get_ivar(self, node, name):
        symtyp  = self._ctxt.get(name)
        if symtyp == None:
            self._err(node, f"index variable '{name}' was undefined")
        elif type(symtyp[1]) is not int:
            self._err(node, f"variable '{name}' was "
                            f"not bound as an index variable")
        else: return symtyp
        # on failure fallthrough
        return (Sym(name),None)
    
    def _get_var(self, node, name):
        symtyp  = self._ctxt.get(name)
        if symtyp == None: 
            self._err(node, f"variable '{name}' was undefined")
        elif not isinstance(symtyp[1], IR.type):
            self._err(node, f"variable '{name}' was "
                            f"not bound as a variable")
        else: return symtyp
        # on failure fallthrough
        return (Sym(name),IR.terr)
    
    def check(self, node):
        nclass = type(node)
        if   nclass is IRv0.Var:
            sym,typ = self._get_var(node, node.name)
            if typ == None: typ = IR.terr
            return IR.Var(sym, typ)

        elif nclass is IRv0.Const:
            return IR.Const(node.val,IR.tnum)
        
        elif nclass is IRv0.Add or nclass is IRv0.Mul:
            lhs = self.check(node.lhs)
            rhs = self.check(node.rhs)
            typ = IR.tnum if (lhs.typ == IR.tnum and
                              rhs.typ == IR.tnum) else IR.terr
            opname = "addition" if nclass is IRv0.Add else "multiplication"
            if lhs.typ != IR.tnum and lhs.typ != IR.terr:
                self._err(node,
                          f"expected number on left-hand-side "
                          f"of {opname}: {node}")
            if rhs.typ != IR.tnum and rhs.typ != IR.terr:
                self._err(node,
                          f"expected number on right-hand-side "
                          f"of {opname}: {node}")
            if nclass is IRv0.Add:
                return IR.Add(lhs,rhs,typ)
            else:
                return IR.Mul(lhs,rhs,typ)
        
        elif nclass is IRv0.Pair:
            lhs = self.check(node.lhs)
            rhs = self.check(node.rhs)
            typ = IR.terr
            if lhs.typ != IR.terr and rhs.typ != IR.terr:
                typ = IR.TPair(lhs.typ,rhs.typ)
            return IR.Pair(lhs,rhs,typ)
        
        elif nclass is IRv0.Proj:
            arg = self.check(node.arg)
            typ = IR.terr
            if   arg.typ == IR.terr: pass
            elif type(arg.typ) is not IR.TPair:
                self._err(node, f"Was expecting a pair as argument: {node}")
            elif node.idx == 0: typ = arg.typ.lhs
            else:               typ = arg.typ.rhs
            return IR.Proj(node.idx,arg,typ)
        
        elif nclass is IRv0.Gen or nclass is IRv0.Sum:
            self._ctxt.push()
            newsym  = Sym(node.idxname)
            self._ctxt.set(node.idxname, (newsym,node.range) )
            body    = self.check(node.body)
            self._ctxt.pop()
            if   nclass is IRv0.Sum:
                typ = body.typ
                if typ != IR.tnum and typ != IR.terr:
                    self._err(node, f"Was expecting a number as body: {node}")
                    typ = IR.terr
                return IR.Sum(newsym,node.range,body,typ)
            else: # nclass is IRv0.Gen
                typ = IR.terr
                if body.typ != IR.terr:
                    typ = IR.TTensor(node.range, body.typ)
                return IR.Gen(newsym,node.range,body,typ)
        
        elif nclass is IRv0.Access:
            base    = self.check(node.base)
            sym,rng = self._get_ivar(node, node.idx)
            typ     = IR.terr
            if base.typ == IR.terr: pass
            elif not isinstance(base.typ,IR.TTensor):
                self._err(node, f"Was expecting a tensor to index: {node}")
            elif rng == None:  pass
            elif rng != base.typ.range:
                self._err(node, f"index variable '{node.idx}' was bound "
                                f"to the range {rng}, but this tensor "
                                f"expects an index of range {base.typ.range}")
            else:
                typ = base.typ.typ
            return IR.Access(base,sym,typ)
        
        elif nclass is IRv0.Indicate:
            # need to check the predicate
            eqnode      = node.arg
            lsym, lrng  = self._get_ivar(node, eqnode.lhs)
            rsym, rrng  = self._get_ivar(node, eqnode.rhs)
            body        = self.check(node.body)
            
            if   lrng == None or rrng == None: pass
            elif lrng != rrng:
                self._err(node, f"index variables "
                                f"'{eqnode.lhs}' and '{eqnode.rhs}' "
                                f"in equality are drawn from different "
                                f"ranges: {lrng} and {rrng}")
            # can proceed with the body type regardless of errors
            return IR.Indicate( IR.Eq(lsym,rsym), body, body.typ)
        
        elif nclass is IRv0.Let:
            rhs     = self.check(node.rhs)
            self._ctxt.push()
            newsym  = Sym(node.name)
            self._ctxt.set(node.name, (newsym,rhs.typ) )
            body    = self.check(node.body)
            self._ctxt.pop()
            return IR.Let(newsym,rhs,body,body.typ)
        
        else:
            assert False, f"Unexpected expression class for {node}"

# wrap type-checking in an object that we can use to keep track of
# the arguments etc. with
class TypedFunc:
    def __init__(self,argtyps,expr):
        self._orig_expr = expr
        
        # re-pack the args a few ways
        self._arglist = []
        self._argsyms = {}
        self._symtyps = {}
        assert type(argtyps) is list
        for nm,typ in argtyps:
            self._arglist.append( (nm,typ) )
            s = Sym(nm)
            self._argsyms[nm] = s
            self._symtyps[nm] = (s,typ)
        
        # run type-checking to get the modified IR
        self._expr    = TypeChecker(expr,self._symtyps).get_out_ir()

    def __str__(self):
        args = ", ".join([ f"{st[0]}:{st[1]}"
                           for n,st in self._symtyps.items() ])
        return (f"Function({args}) : {self._expr.typ}\n"
                f"    {self._expr.__str__(0,'    ')}")

Let's define a quick string representation just so that we're not flying totally blind on what the results of the typechecking were.

In [15]:
def IR_typ_str_rep(t):
    tclass = type(t)
    s      = "ERROR"
    if tclass is IR.TNum:
        s  = "R"
    elif tclass is IR.TPair:
        s  = f"({IR_typ_str_rep(t.lhs)},{IR_typ_str_rep(t.rhs)})"
    elif tclass is IR.TTensor:
        s  = f"[{t.range}]{IR_typ_str_rep(t.typ)}"
    return s
IR.type.__str__ = IR_typ_str_rep
    
def IR_str_rep(e,prec=0,indent=""):
    def sub(e,p):
        return IR_str_rep(e,p,indent)
    eclass = type(e)
    s      = "ERROR"
    if   eclass is IR.Var:
        s = str(e.name)
    elif eclass is IR.Const:
        s = str(e.val)
    elif eclass is IR.Add:
        s = f"{sub(e.lhs,2)} + {sub(e.rhs,2)}"
        if prec > 2: s = f"({s})"
    elif eclass is IR.Mul:
        s = f"{sub(e.lhs,3)} * {sub(e.rhs,3)}"
        if prec > 3: s = f"({s})"
    elif eclass is IR.Pair:
        s = f"({sub(e.lhs,0)},{sub(e.rhs,0)})"
    elif eclass is IR.Proj:
        s = f"{sub(e.arg,4)}.{e.idx}"
        if prec > 4: s = f"({s})"
    elif eclass is IR.Gen or eclass is IR.Sum:
        op = "+" if eclass is IR.Sum else "Gen"
        s = f"{op}({e.idxname}:{e.range}) {sub(e.body,1)}"
        if prec > 1: s = f"({s})"
    elif eclass is IR.Access:
        s = f"{sub(e.base,5)}[{e.idx}]"
        if prec > 5: s = f"({s})"
    elif eclass is IR.Indicate:
        assert isinstance(e.arg, IR.Eq), 'sanity: pred is Eq'
        s = f"[{e.arg.lhs}={e.arg.rhs}]*{sub(e.body,3)}"
        if prec > 3: s = f"({s})"
    elif eclass is IR.Let:
        # note that this is ill-behaved formatting
        # for lets nested inside of expressions
        rhs  = sub(e.rhs,0)
        body = sub(e.body,0)
        s = f"let {e.name} : {e.rhs.typ} = {rhs} in\n{indent}{body}"
        if prec > 0: s = f"({s})"
    return s

IR.expr.__str__ = IR_str_rep

In [16]:
n   = 3
R   = IR.tnum
Rn  = IR.TTensor(n,R)
Rnn = IR.TTensor(n,Rn)
x   = IRv0.Var('x')
A   = IRv0.Var('A')
c   = IRv0.Var('c')
i   = 'i'
j   = 'j'

y       = IRv0.Gen(i,n, c[i] + IRv0.Sum(j,n, A[i,j] * x[j] ))
tr      = IRv0.Sum(i,n, A[i,i])
diag    = IRv0.Gen(i,n,IRv0.Gen(j,n, IRv0.Eq(i,j) * x[i] ))

Ty      = TypedFunc([('A',Rnn),('x',Rn),('c',Rn)], y)
Ttr     = TypedFunc([('A',Rnn)], tr)
Tdiag   = TypedFunc([('x',Rn)], diag)

print(Ty)
print(Ttr)
print(Tdiag)

Function(A:[3][3]R, x:[3]R, c:[3]R) : [3]R
    Gen(i:3) c[i] + (+(j:3) A[i][j] * x[j])
Function(A:[3][3]R) : R
    +(i:3) A[i][i]
Function(x:[3]R) : [3][3]R
    Gen(i:3) Gen(j:3) [i=j]*x[i]


## Eliminating Pairs (now with Typing!)

We had three big problems when we last left off.

1. If we try to eliminate pairs from an expression whose input or output types involve pairs, we may change the type of the expression.
2. We are going to change the type of the expression internally.  Is this safe?
3. We need to work through `Let` bindings; can types help with that?

Key to all of these problems is the ability to analyze a type that mixes tensor-ing and pair-ing.  Let $A \times B$ be a pair type and $[n]A$ be a tensor type.  Then, in some sense $[n](A \times B) \cong [n]A \times [n]B$; that is, they define the same data.  In fact, this change-in-type is what we would call the AoS-to-SoA (array-of-structs to struct-of-arrays) transform.  Pushing pair constructs outwards will tend to induce this transformation on our code.
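NumPy can illustrate the isomorphism $[n](A \times B) \cong [n]A \times [n]B$ directly.  In this sketch a structured dtype with fields `re` and `im` (names chosen only for illustration) plays the role of $[n](R \times R)$, and a pair of plain arrays plays $[n]R \times [n]R$.

```python
import numpy as np

# AoS, i.e. [n](R x R): one array whose elements are (re, im) records.
aos = np.array([(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)],
               dtype=[('re', np.float64), ('im', np.float64)])

# Selecting a field gives a strided view: it skips over the other field.
print(aos['re'].strides)                 # (16,)

# SoA, i.e. [n]R x [n]R: a pair of dense arrays.
soa = (aos['re'].copy(), aos['im'].copy())
print(soa[0].strides)                    # (8,)

# And back again: both layouts describe exactly the same data.
back = np.empty(len(soa[0]), dtype=aos.dtype)
back['re'], back['im'] = soa
print(bool((back == aos).all()))         # True
```

The strides show the practical payoff: in the AoS layout each component is only reachable with a stride of two elements, while the SoA layout gives each component unit stride.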

On the type level, we can ask two simple questions.  First, does this type contain pairs somewhere?  Second, if we were to fully perform this transform, what would the resulting type be?  Both are useful functions to have around.


In [17]:
def typ_has_pairs(t):
    tclass = type(t)
    if tclass is IR.TPair: return True
    elif tclass is IR.TTensor:
        return typ_has_pairs(t.typ)
    else: # tnum or terr
        return False

def typ_SoA_transform(t,rngs=[]):
    tclass = type(t)
    if t is IR.tnum or t is IR.terr:
        # possibly unroll the ranges stack here
        if len(rngs) > 0:
            for r in reversed(rngs):
                t = IR.TTensor(r,t)
        return t
    elif tclass is IR.TTensor:
        rngs = rngs.copy()
        rngs.append(t.range)
        return typ_SoA_transform(t.typ,rngs)
    elif tclass is IR.TPair:
        return IR.TPair( typ_SoA_transform(t.lhs,rngs),
                         typ_SoA_transform(t.rhs,rngs) )
    else: assert False, "impossible case"


In [18]:
print(typ_has_pairs(Rn))
print(typ_SoA_transform( Rn ))
print(typ_has_pairs(Rnn))
print(typ_SoA_transform( Rnn ))
Complex = IR.TPair(R,R)
print(typ_has_pairs(Complex))
Cn      = IR.TTensor(n,Complex)
print(Cn)
print(typ_has_pairs(Cn))
print(typ_SoA_transform(Cn))

False
[3]R
False
[3][3]R
True
[3](R,R)
True
([3]R,[3]R)


We can use this type transformation to guide our destructuring of `Let`-bound variables.  Suppose we have a binding $\textrm{let } x = e_1 \textrm{ in } e_2$ where the type of $x$ contains pairs.  Then we can destructure it as follows, where $\{ var \mapsto expr \}$ denotes substituting $expr$ for $var$ in the term that follows.

$$\begin{array}{l}
\textrm{let } x_0 = \pi_0 e_1 \textrm{ in } \\
\textrm{let } x_1 = \pi_1 e_1 \textrm{ in } \\
(\{x \mapsto (x_0,x_1)\}e_2)
\end{array}$$

Using the SoA transform as a guide, we can fully expand the let-binding into pair-free bindings.  Of course, these bindings will then have to be processed by pushing down the new $\pi$ projections themselves.  However, doing this will suffice to ensure that all variable names become free of pairs before processing the definition and body separately.
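To see how the SoA-transformed type drives the destructuring, here is a toy sketch over a hypothetical tuple encoding of types (`'R'`, `('tensor', n, T)`, `('pair', A, B)`), not the notebook's `IR` classes.  `soa` pushes tensor ranges beneath pairs, and `let_leaves` lists the pair-free bindings together with the projection path each one takes.  The naming scheme mimics the `c3_00`, `c3_01`, `c3_1` names that the pair eliminator produces in the example further below.

```python
# Toy type encoding (hypothetical, for illustration only):
#   'R'              a number
#   ('tensor', n, T) the tensor type [n]T
#   ('pair', A, B)   the pair type (A, B)

def soa(t, rngs=()):
    """Push every tensor range beneath the pairs: [n](A,B) becomes ([n]A,[n]B)."""
    if t == 'R':
        for r in reversed(rngs):
            t = ('tensor', r, t)
        return t
    if t[0] == 'tensor':
        return soa(t[2], rngs + (t[1],))
    return ('pair', soa(t[1], rngs), soa(t[2], rngs))

def let_leaves(name, t, path=()):
    """Pair-free bindings (name, path, type); path lists projections outermost first."""
    if t != 'R' and t[0] == 'pair':
        return (let_leaves(name + '0', t[1], path + (0,)) +
                let_leaves(name + '1', t[2], path + (1,)))
    return [(name, path, t)]

c3_typ = ('tensor', 3, ('pair', ('pair', 'R', 'R'), 'R'))   # [3]((R,R),R)
for leaf in let_leaves('c3_', soa(c3_typ)):
    print(leaf)
# ('c3_00', (0, 0), ('tensor', 3, 'R'))
# ('c3_01', (0, 1), ('tensor', 3, 'R'))
# ('c3_1', (1,), ('tensor', 3, 'R'))
```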

Note that this is only possible because we can now query our IR for the type in the middle, at the let-bindings.

Together with the previous rules (restated below), this gives us a complete rewrite system.  So long as none of the input or output types of an expression contain pairs, this system should totally eliminate pairs from the code without otherwise significantly damaging the code structure.

$$
\begin{array}{rcl}
\pi_k ( e_0, e_1 ) &\to& e_k \\
\pi_k [i=j] e &\to& [i=j] \pi_k e \\
\pi_k \boxplus_i e &\to& \boxplus_i \pi_k e \\
\pi_k e[i] &\to& (\pi_k e)[i] \\
\end{array}
$$

We can adopt a push-down strategy here, where all of the projections are pushed down until they hit and annihilate the pair constructors.  Accordingly, the algorithm keeps track of a stack of projections currently being pushed down.  It must also keep around a variable substitution environment to handle the let-binding rule.
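The stack discipline is easiest to see on a toy expression language.  The sketch below uses a hypothetical tuple encoding with only variables, pairs, projections and additions, none of the notebook's `IR` machinery or substitution environment.  It passes a fresh tuple down each path instead of mutating one shared stack, trading a little copying for not having to undo pushes and pops by hand.

```python
# Toy expression encoding (hypothetical, for illustration only):
#   ('var', name) | ('pair', e0, e1) | ('proj', k, e) | ('add', e0, e1)

def pushdown(e, projstk=()):
    """Push projections down until they annihilate the pairs they meet."""
    tag = e[0]
    if tag == 'proj':
        # Defer the projection: remember it and keep descending.
        return pushdown(e[2], projstk + (e[1],))
    if tag == 'pair':
        if not projstk:
            raise ValueError("a pair survived: the result type contains a pair")
        # The most recently pushed (innermost) projection cancels this pair.
        k = projstk[-1]
        return pushdown(e[1 + k], projstk[:-1])
    if projstk:
        # the real eliminator would substitute let-bound variables here
        raise ValueError(f"cannot project out of '{tag}'")
    if tag == 'add':
        return ('add', pushdown(e[1]), pushdown(e[2]))
    return e   # a variable, with no pending projections

a, b, c = ('var', 'a'), ('var', 'b'), ('var', 'c')
nested = ('pair', ('pair', a, b), c)
# (nested.0).1 + nested.1
expr = ('add', ('proj', 1, ('proj', 0, nested)), ('proj', 1, nested))
print(pushdown(expr))   # ('add', ('var', 'b'), ('var', 'c'))
```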


In [19]:
class PairEliminator:
    def __init__(self,expr,arg_symtyps,rettyp):
        self._expr      = expr
        self._ctxt      = _Context()
        
        if typ_has_pairs(rettyp):
            raise TypeError(f"Cannot eliminate pairs from a function "
                            f"whose return type is a pair")
        
        for sym,typ in arg_symtyps.values():
            if typ_has_pairs(typ):
                raise TypeError(f"Cannot eliminate pairs from a function "
                                f"with pair-typed arguments")
            self._ctxt.set(sym,IR.Var(sym,typ))
        
        # do elimination
        self._out_expr  = self.pushdown(expr)
    
    def get_result(self):
        return self._out_expr
    
    def _get_ivar_sub(self,nm):
        i = self._ctxt.get(nm)
        assert not i is None, f'should be caught earlier: {nm}'
        return i
    
    def _get_var_sub(self,nm):
        e = self._ctxt.get(nm)
        assert not e is None, f'should be caught earlier: {nm}'
        return e
    
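    def pushdown(self, e, projstk = None):
        # fresh stack per top-level call; a mutable default list would be shared
        if projstk is None: projstk = []
        eclass = type(e)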
        if   eclass is IR.Var:
            subst_e = self._get_var_sub(e.name)
            # prevent infinite recursion and needless duplications
            if (type(subst_e) is IR.Var and
                subst_e.name == e.name and
                subst_e.typ  == e.typ):
                    assert(len(projstk) == 0)
                    return e
            else: # need to continue pushdown
                return self.pushdown(subst_e, projstk)
        
        elif eclass is IR.Const:
            assert len(projstk) == 0
            return e
        
        elif eclass is IR.Add or eclass is IR.Mul:
            assert len(projstk) == 0
            lhs = self.pushdown(e.lhs, projstk)
            rhs = self.pushdown(e.rhs, projstk)
            return eclass(lhs, rhs, e.typ)
        
        # deconstruct it!
        elif eclass is IR.Pair:
            assert len(projstk) > 0
            proj_i = projstk.pop()
            if proj_i == 0:
                return self.pushdown(e.lhs, projstk)
            else:
                return self.pushdown(e.rhs, projstk)
        
        elif eclass is IR.Proj:
            projstk.append(e.idx)
            return self.pushdown(e.arg, projstk)
        
        elif eclass is IR.Gen or eclass is IR.Sum:
            self._ctxt.push()
            idxname = e.idxname.copy()
            self._ctxt.set(e.idxname,idxname)
            body    = self.pushdown(e.body, projstk)
            self._ctxt.pop()
            if eclass is IR.Gen:
                TensorType = IR.TTensor(e.range, body.typ)
                return IR.Gen(idxname, e.range, body, TensorType)
            else:
                assert body.typ == IR.tnum
                return IR.Sum(idxname, e.range, body, body.typ)
            
        elif eclass is IR.Access:
            base    = self.pushdown(e.base, projstk)
            idx     = self._get_ivar_sub(e.idx)
            assert type(base.typ) is IR.TTensor
            return IR.Access(base, idx, base.typ.typ)
            
        elif eclass is IR.Indicate:
            lidx    = self._get_ivar_sub(e.arg.lhs)
            ridx    = self._get_ivar_sub(e.arg.rhs)
            body    = self.pushdown(e.body, projstk)
            return IR.Indicate(IR.Eq(lidx,ridx), body, body.typ)
            
        elif eclass is IR.Let:
            soa_typ = typ_SoA_transform(e.rhs.typ)

            # unpack the soa_typ into projections
            binds   = []
            def soa_unpack(nm,T,projstk=None):
                if projstk is None: projstk = []
                if type(T) is IR.TPair:
                    # deeper projections go at the front, so undo with pop(0)
                    projstk.insert(0,0)
                    lhs     = soa_unpack(nm+'0',T.lhs,projstk)
                    projstk.pop(0)
                    projstk.insert(0,1)
                    rhs     = soa_unpack(nm+'1',T.rhs,projstk)
                    projstk.pop(0)
                    assert lhs.typ == T.lhs
                    assert rhs.typ == T.rhs
                    return IR.Pair(lhs,rhs,T)
                else:
                    rval    = self.pushdown(e.rhs,projstk.copy())
                    sym     = Sym(nm)
                    assert rval.typ == T
                    binds.append((sym,rval))
                    return IR.Var(sym,T)
            nm = e.name.name() + ('_' if (type(soa_typ) is IR.TPair) else '')
            subst = soa_unpack(nm, soa_typ)
            
            # bind the soa-transformed variables, and rewrite body
            self._ctxt.push()
            self._ctxt.set(e.name,subst)
            # ensure termination at the new let-bound variables
            for sym,rhs in binds:
                self._ctxt.set(sym,IR.Var(sym,rhs.typ))
            body    = self.pushdown(e.body,projstk)
            self._ctxt.pop()
            
            # construct the resulting let-binding chain
            for sym,rhs in reversed(binds):
                body = IR.Let(sym,rhs,body,body.typ)
            return body

def _TypedFunc_eliminate_pairs(self):
    e       = self._expr
    atyps   = self._symtyps
    self._expr = PairEliminator(e,atyps,e.typ).get_result()
TypedFunc.eliminate_pairs = _TypedFunc_eliminate_pairs


We need an example to test.  How about this one?

In [20]:
a   = IRv0.Var('a')
b   = IRv0.Var('b')
i,j = 'i','j'
ab3 = IRv0.Let('c3',IRv0.Gen(j,n, IRv0.Pair( IRv0.Pair(a[j], b[j]),
                                             a[j]*b[j] )),
               IRv0.Sum(j,n, IRv0.Proj(0, IRv0.Proj(0, IRv0.Var('c3')[j])) +
                             IRv0.Proj(1, IRv0.Proj(0, IRv0.Var('c3')[j])) +
                             IRv0.Proj(1, IRv0.Var('c3')[j]) ))

Fab3 = TypedFunc( [('a',Rn),('b',Rn)], ab3 )
print(Fab3)
Fab3.eliminate_pairs()
print(Fab3)

Function(a:[3]R, b:[3]R) : R
    let c3 : [3]((R,R),R) = Gen(j:3) ((a[j],b[j]),a[j] * b[j]) in
    +(j:3) c3[j].0.0 + c3[j].0.1 + c3[j].1
Function(a:[3]R, b:[3]R) : R
    let c3_00 : [3]R = Gen(j:3) a[j] in
    let c3_01 : [3]R = Gen(j:3) b[j] in
    let c3_1 : [3]R = Gen(j:3) a[j] * b[j] in
    +(j:3) c3_00[j] + c3_01[j] + c3_1[j]


## Let-Flattening

We have now eliminated two forms from the IR, `Pair` and `Proj`, that we weren't sure how to compile into Halide.  However, doing this hardly takes us down to the point where we can simply emit Halide code.  In Halide, all statements occur at the top level (i.e. not within other statements), whereas our language allows for nested `Let` bindings.  For instance, consider the following program.

```
Gen(i:n) Sum(j:n) let m = Sum(k:n) D[i,j,k]*a[k] in  m * b[j]
```

How can we move the `let m = ...` definition to the outermost level?  If we did that naively, we would violate the scoping of the `i` and `j` variables.  Our only option is to close the right-hand side of the binding over those variables, turning `m` into a tensor indexed by `i` and `j`.  That is,

```
let m = Gen(i:n) Gen(j:n) Sum(k:n) D[i,j,k]*a[k] in
Gen(i:n) Sum(j:n) m[i,j] * b[j]
```
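As a quick sanity check of this closure step, here is a NumPy sketch using arbitrary random data. It is not part of the compiler. It shows that the nested form and the hoisted form compute the same vector.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
D = rng.standard_normal((n, n, n))
a = rng.standard_normal(n)
b = rng.standard_normal(n)

# Nested form: m is recomputed inside the i and j loops.
nested = np.zeros(n)
for i in range(n):
    for j in range(n):
        m = sum(D[i, j, k] * a[k] for k in range(n))
        nested[i] += m * b[j]

# Hoisted form: m is closed over i and j, so it becomes an n-by-n tensor
# that is computed once and then indexed as m[i, j].
m_tensor = np.einsum('ijk,k->ij', D, a)
hoisted = np.array([sum(m_tensor[i, j] * b[j] for j in range(n))
                    for i in range(n)])

print(np.allclose(nested, hoisted))
```

The price of hoisting is the extra `n * n` temporary.  That temporary is exactly what lets the binding become a top-level statement.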

In [21]:
k = 'k'
D = IRv0.Var('D')
Rnnn = IR.TTensor(n,Rnn)

contract    = IRv0.Gen(i,n, IRv0.Sum(j,n,
                IRv0.Let('m', IRv0.Sum(k,n, D[i,j,k] * a[k]),
                         IRv0.Var('m') * b[j] )))

Fcontract   = TypedFunc([('D',Rnnn),('a',Rn),('b',Rn)],contract)
print(Fcontract)

Function(D:[3][3][3]R, a:[3]R, b:[3]R) : [3]R
    Gen(i:3) +(j:3) (let m : R = +(k:3) D[i][j][k] * a[k] in
    m * b[j])


We can think of `Let`-flattening as a process that "pulls up" the `Let` bindings past all the other constructs in the language.  On a coding level, our goal will be to transform the program into a "block" consisting of a sequence of name/expression pairs (i.e. statements) terminated by a "return" expression.  Sketching this transformation means looking at how `Let` commutes with every other expression when it is pulled up:

$$\begin{array}{rcl}
e_0 + (\textrm{let } x = e_1 \textrm{ in } e_2)
&\to& \textrm{let } x = e_1 \textrm{ in } (e_0 + e_2) \\
e_0 \cdot (\textrm{let } x = e_1 \textrm{ in } e_2)
&\to& \textrm{let } x = e_1 \textrm{ in } (e_0 \cdot e_2) \\
[i=j] \cdot (\textrm{let } x = e_1 \textrm{ in } e_2)
&\to& \textrm{let } x = e_1 \textrm{ in } ([i=j] \cdot e_2) \\
\boxplus_i\ (\textrm{let } x = e_1 \textrm{ in } e_2)
&\to& \textrm{let } x = (\boxplus_i\ e_1) \textrm{ in }
      (\boxplus_i\ \{x \mapsto x[i] \}e_2) \\
\sum_i\ (\textrm{let } x = e_1 \textrm{ in } e_2)
&\to& \textrm{let } x = (\boxplus_i\ e_1) \textrm{ in }
      (\sum_i\ \{x \mapsto x[i] \}e_2) \\
(\textrm{let } x = e_1 \textrm{ in } e_2)[i]
&\to& \textrm{let } x = e_1 \textrm{ in } (e_2[i]) \\
\textrm{let } x = (\textrm{ let } y = e_0 \textrm{ in } e_1) \textrm{ in } e_2
&\to& \textrm{let } y = e_0 \textrm{ in } (\textrm{ let } x = e_1 \textrm{ in } e_2) \\
\end{array}$$

(_Note that the $+$ and $\cdot$ rules also apply symmetrically, with the let on the left-hand side of the operator._)

If a `Let` occurs in the body of another `Let`, that's fine; that's exactly what we're aiming for when we talk about a "block".  In fact, all of these rules generalize to blocks: read each pulled-up `Let` as a whole block of bindings rather than a single one.
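To make these rules concrete, here is a minimal, self-contained sketch of let-lifting.  It uses a hypothetical tuple encoding of the IR (`('gen', i, n, body)`, `('access', base, idx)`, and so on) rather than the notebook's `IR` classes.  It covers only the forms needed for the `contract` example: no indicators, and no handling of name capture.

The key case is `Gen`/`Sum`.  Each lifted binding is wrapped in `Gen(i:n)`, and later uses are rewritten by $\{x \mapsto x[i]\}$.

```python
# Hypothetical tuple encoding of a small IR subset:
#   ('var', name) ('const', c) ('add', l, r) ('mul', l, r)
#   ('access', base, idx) ('gen', i, n, body) ('sum', i, n, body)
#   ('let', name, rhs, body)

def subst(env, e):
    """Replace variables according to env; e is assumed to be Let-free."""
    tag = e[0]
    if tag == 'var':
        return env.get(e[1], e)
    if tag == 'const':
        return e
    if tag in ('add', 'mul'):
        return (tag, subst(env, e[1]), subst(env, e[2]))
    if tag == 'access':
        return ('access', subst(env, e[1]), e[2])
    if tag in ('gen', 'sum'):
        return (tag, e[1], e[2], subst(env, e[3]))
    raise ValueError(f"unexpected form {tag!r}")

def letlift(e):
    """Return (binds, ret): a list of (name, rhs) statements and a Let-free result."""
    tag = e[0]
    if tag in ('var', 'const'):
        return [], e
    if tag in ('add', 'mul'):
        lb, lhs = letlift(e[1])
        rb, rhs = letlift(e[2])
        return lb + rb, (tag, lhs, rhs)
    if tag == 'access':
        binds, base = letlift(e[1])
        return binds, ('access', base, e[2])
    if tag == 'let':
        b0, rhs = letlift(e[2])
        b1, body = letlift(e[3])
        return b0 + [(e[1], rhs)] + b1, body
    if tag in ('gen', 'sum'):
        _, i, n, body = e
        binds, body = letlift(body)
        env, new_binds = {}, []
        for nm, rhs in binds:
            # close the binding over i: nm = Gen(i:n) rhs, and later uses become nm[i]
            new_binds.append((nm, ('gen', i, n, subst(env, rhs))))
            env[nm] = ('access', ('var', nm), i)
        return new_binds, (tag, i, n, subst(env, body))
    raise ValueError(f"unexpected form {tag!r}")

def flatten(e):
    """Rebuild the lifted statements as a single chain of top-level lets."""
    binds, ret = letlift(e)
    for nm, rhs in reversed(binds):
        ret = ('let', nm, rhs, ret)
    return ret

def show(e):
    """Print in roughly the same style as the notebook's TypedFunc output."""
    tag = e[0]
    if tag in ('var', 'const'):
        return str(e[1])
    if tag in ('add', 'mul'):
        op = ' + ' if tag == 'add' else ' * '
        return show(e[1]) + op + show(e[2])
    if tag == 'access':
        return f"{show(e[1])}[{e[2]}]"
    if tag == 'gen':
        return f"Gen({e[1]}:{e[2]}) {show(e[3])}"
    if tag == 'sum':
        return f"+({e[1]}:{e[2]}) {show(e[3])}"
    if tag == 'let':
        return f"let {e[1]} = {show(e[2])} in\n{show(e[3])}"
    raise ValueError(f"unexpected form {tag!r}")

def acc(base, *idxs):
    """Build base[i0][i1]... as nested accesses."""
    for i in idxs:
        base = ('access', base, i)
    return base

V = lambda nm: ('var', nm)
contract = ('gen', 'i', 3, ('sum', 'j', 3,
    ('let', 'm', ('sum', 'k', 3, ('mul', acc(V('D'), 'i', 'j', 'k'), acc(V('a'), 'k'))),
     ('mul', V('m'), acc(V('b'), 'j')))))
print(show(flatten(contract)))
```

The printed block has the same shape as the `LetFlatten` output shown for `Fcontract` further down.  `m` is closed over both `i` and `j`, and it is read back as `m[i][j]`.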

In [22]:
class LetFlatten:
    def __init__(self,expr,arg_symtyps,rettyp):
        self._expr      = expr
        #self._ctxt      = _Context()
        
        # lift
        binds, ret_e    = self.letlift(expr)
        
        # construct the final let-chained expression
        e   = ret_e
        for nm,rhs in reversed(binds):
            e = IR.Let(nm,rhs,e,e.typ)
        self._out_expr  = e
    
    def get_result(self):
        return self._out_expr
    
    def letlift(self, e):
        eclass = type(e)
        assert not eclass is IR.Pair, "pairs should be eliminated"
        assert not eclass is IR.Proj, "pairs should be eliminated"
        if   eclass is IR.Var:
            return [],e
        elif eclass is IR.Const:
            return [],e
        
        elif eclass is IR.Add or eclass is IR.Mul:
            lbind, lhs  = self.letlift(e.lhs)
            rbind, rhs  = self.letlift(e.rhs)
            
            return lbind + rbind, eclass(lhs, rhs, e.typ)
        
        elif eclass is IR.Gen or eclass is IR.Sum:
            binds, body = self.letlift(e.body)
            ctxt        = _Context()
            i           = e.idxname
            rng         = e.range
            
            new_binds   = []
            for nm,rhs in binds:
                rhs     = self.subst(ctxt,rhs)
                T       = rhs.typ
                TensorT = IR.TTensor(rng,T)
                new_rhs = IR.Gen(i,rng,rhs,TensorT)
                ctxt.set( nm, IR.Access(IR.Var(nm,TensorT), i, T) )
                new_binds.append( (nm, new_rhs) )
            new_body    = eclass(i,rng, self.subst(ctxt,body), e.typ)
            
            return new_binds, new_body
            
        elif eclass is IR.Access:
            binds, base = self.letlift(e.base)
            return binds, IR.Access(base, e.idx, e.typ)
            
        elif eclass is IR.Indicate:
            binds, body = self.letlift(e.body)
            return binds, IR.Indicate(e.arg, body, e.typ)
            
        elif eclass is IR.Let:
            binds0, rhs     = self.letlift(e.rhs)
            binds1, body    = self.letlift(e.body)
            binds           = binds0 + [(e.name,rhs)] + binds1
            return binds, body
    
    def subst(self, env, e):
        eclass = type(e)
        assert not eclass is IR.Pair, "pairs should be eliminated"
        assert not eclass is IR.Proj, "pairs should be eliminated"
        assert not eclass is IR.Let
        if   eclass is IR.Var:
            sub = env.get(e.name)
            return e if sub is None else sub
        
        elif eclass is IR.Const:
            return e
        
        elif eclass is IR.Add or eclass is IR.Mul:
            lhs = self.subst(env, e.lhs)
            rhs = self.subst(env, e.rhs)
            return eclass(lhs,rhs,e.typ)
        
        elif eclass is IR.Gen or eclass is IR.Sum:
            body    = self.subst(env, e.body)
            return eclass(e.idxname, e.range, body, e.typ)
            
        elif eclass is IR.Access:
            base    = self.subst(env, e.base)
            return IR.Access(base, e.idx, e.typ)
            
        elif eclass is IR.Indicate:
            body    = self.subst(env, e.body)
            return IR.Indicate(e.arg, body, e.typ)

def _TypedFunc_lift_lets(self):
    e       = self._expr
    atyps   = self._symtyps
    self._expr = LetFlatten(e,atyps,e.typ).get_result()
TypedFunc.lift_lets = _TypedFunc_lift_lets


In [23]:
Fcontract   = TypedFunc([('D',Rnnn),('a',Rn),('b',Rn)],contract)
print(Fcontract)
Fcontract.lift_lets()
print(Fcontract)
# check for fixed-point behaviors
Fcontract.eliminate_pairs()
print(Fcontract)
Fcontract.lift_lets()
print(Fcontract)

Function(D:[3][3][3]R, a:[3]R, b:[3]R) : [3]R
    Gen(i:3) +(j:3) (let m : R = +(k:3) D[i][j][k] * a[k] in
    m * b[j])
Function(D:[3][3][3]R, a:[3]R, b:[3]R) : [3]R
    let m : [3][3]R = Gen(i:3) Gen(j:3) +(k:3) D[i][j][k] * a[k] in
    Gen(i:3) +(j:3) m[i][j] * b[j]
Function(D:[3][3][3]R, a:[3]R, b:[3]R) : [3]R
    let m : [3][3]R = Gen(i:3) Gen(j:3) +(k:3) D[i][j][k] * a[k] in
    Gen(i:3) +(j:3) m[i][j] * b[j]
Function(D:[3][3][3]R, a:[3]R, b:[3]R) : [3]R
    let m : [3][3]R = Gen(i:3) Gen(j:3) +(k:3) D[i][j][k] * a[k] in
    Gen(i:3) +(j:3) m[i][j] * b[j]


## Gen-Normalization

Flattening out all of the `Let` bindings into a statement block is a significant step towards compilation.  Is it the case that we can now translate all of these statements into Halide statements and be done with everything?  No.

We would like to compile tensor bindings of the form
```
let f = Gen(i,j,...) e in
...
```
where `e` contains no `Gen`s.

To accomplish this, we must push down any `Access` occurrences to eliminate spurious `Gen`s in the middle of expressions.  Then, we must also push down any `Indicate` wrappers so that they only occur inside the generators.

Here are two examples to play with.  Note how the latter example includes an indicator function that might allow further simplifications but need not be used in that way.

In [24]:
prod    = IRv0.Gen(i,n, IRv0.Sum(j,n, A[i,j] * x[j] ))
qprod   = IRv0.Sum(k,n, prod[k]*x[k] )

Fqprod  = TypedFunc([('A',Rnn),('x',Rn)],qprod)
print(Fqprod)

diag    = IRv0.Gen(i,n, IRv0.Gen(j,n, IRv0.Eq(i,j) * x[i] ))
dprod   = IRv0.Sum(j,n, IRv0.Sum(i,n, a[j] * diag[j,i] * a[i] ))

Fdprod  = TypedFunc([('x',Rn),('a',Rn)],dprod)
print(Fdprod)

Function(A:[3][3]R, x:[3]R) : R
    +(k:3) (Gen(i:3) +(j:3) A[i][j] * x[j])[k] * x[k]
Function(x:[3]R, a:[3]R) : R
    +(j:3) +(i:3) a[j] * (Gen(i:3) Gen(j:3) [i=j]*x[i])[j][i] * a[i]


Doing the minimum necessary simplifications will require invoking at least the following rewrite rules to lift all generators outwards.

$$\begin{array}{rcl}
(\boxplus_i\ e)[j]
  &\to& \{i \mapsto j\} e \\
[j=k] \cdot (\boxplus_i\ e)
  &\to& \boxplus_i\ [j=k] \cdot e \\
\end{array}$$
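The first rule is ordinary index substitution.  Below is a small NumPy check on the `qprod` example, using arbitrary random data rather than the compiler.  Indexing the materialized generator gives the same value as substituting the index directly.  The second form, however, never builds the intermediate vector.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 3
A = rng.standard_normal((n, n))
x = rng.standard_normal(n)

# Before: materialize prod = Gen(i) Sum(j) A[i,j]*x[j], then index it with k.
prod = np.array([sum(A[i, j] * x[j] for j in range(n)) for i in range(n)])
before = sum(prod[k] * x[k] for k in range(n))

# After (Gen_i e)[k] -> {i -> k} e: the row sum is computed in place, no temporary.
after = sum(sum(A[k, j] * x[j] for j in range(n)) * x[k] for k in range(n))

print(np.isclose(before, after), np.isclose(after, x @ A @ x))
```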

Another set of rewrites that can be handled concurrently tries to lift indicators out as far as possible without escaping `Let` bindings.  This can drastically reduce computational overhead in some cases.
  
$$\begin{array}{rcl}
([i=j]\cdot e)[k] &\to& [i=j]\cdot(e[k]) \\
([i=j]\cdot e_1) \cdot e_2 &\to& [i=j]\cdot(e_1 \cdot e_2) \\
\sum_i ([j=k]\cdot e) &\to& [j=k]\cdot(\sum_i e) \\
\sum_i ([i=j]\cdot e) &\to& \{i \mapsto j\}e \\
\end{array}$$
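The last rule is the one that pays off on `dprod`: the indicator coming from `diag` collapses the inner sum.  The check below uses random data and is purely illustrative.  It shows that the collapsed single sum (the normalized form `+(j:3) a[j] * x[j] * a[j]` shown further down) agrees with the original double sum.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 3
x = rng.standard_normal(n)
a = rng.standard_normal(n)

# diag = Gen(i) Gen(j) [i=j]*x[i], materialized as a dense n-by-n array
diag = np.array([[(1.0 if i == j else 0.0) * x[i] for j in range(n)]
                 for i in range(n)])

# Before: a double sum over every (j, i) pair, most of which multiply by zero
before = sum(a[j] * diag[j, i] * a[i] for j in range(n) for i in range(n))

# After sum_i ([i=j] * e) -> {i -> j} e: one sum of n terms instead of n*n
after = sum(a[j] * x[j] * a[j] for j in range(n))

print(np.isclose(before, after))
```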

Both sets of rewrites work by lifting out the focused form (`Gen` or `Indicate`).  However, the second rule of the first set ensures that every `Gen` ends up outside all `Indicate` forms.  Therefore, a combination of the two strategies should percolate both forms outwards together, with the generators ending up outermost.  Some indicators may get trapped under binary sums ($e_0 + e_1$), since an indicator on one summand cannot be factored out of the whole sum.  Others will interact with big sums ($\sum_i$), either being lifted past them or collapsing them entirely.
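Here is a small sketch of how indicators percolate.  It uses a hypothetical tuple encoding (`('ind', (i, j), body)`, `('mul', l, r)`, `('add', l, r)`, `('leaf', name)`) rather than the notebook's `IR`.  It mirrors the `Add`/`Mul`/`Indicate` cases of `lift_gen` in the next cell: products let indicators pass through, while binary sums wrap them back in place.

```python
def wrap(inds, e):
    """Re-attach a list of indicators around an expression."""
    for p in inds:
        e = ('ind', p, e)
    return e

def lift_ind(e):
    """Return (indicators, body) with as many indicators lifted out as possible."""
    tag = e[0]
    if tag == 'leaf':
        return [], e
    if tag == 'ind':
        inds, body = lift_ind(e[2])
        return inds + [e[1]], body
    lind, lhs = lift_ind(e[1])
    rind, rhs = lift_ind(e[2])
    if tag == 'mul':
        # ([p]*e1) * ([q]*e2) == [p]*[q]*(e1*e2): both indicators escape
        return lind + rind, ('mul', lhs, rhs)
    if tag == 'add':
        # [p]*e1 + e2 cannot be factored, so the indicators stay under the +
        return [], ('add', wrap(lind, lhs), wrap(rind, rhs))
    raise ValueError(f"unexpected form {tag!r}")

prod = ('mul', ('ind', ('i', 'j'), ('leaf', 'a')), ('leaf', 'b'))
summ = ('add', ('ind', ('i', 'j'), ('leaf', 'a')), ('leaf', 'b'))
print(lift_ind(prod))   # the indicator is lifted above the product
print(lift_ind(summ))   # the indicator stays trapped under the sum
```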

In [25]:
class GenNormalize:
    def __init__(self,expr,arg_symtyps,rettyp):
        self._expr = expr
        
        # break out into a list of bindings and return expr
        binds, ret = self.to_block(expr)
        
        # normalize each expr
        rebinds    = []
        for sym,rhs in binds:
            r      = self.final_lift_gen(rhs)
            rebinds.append((sym,r))
        ret        = self.final_lift_gen(ret)
        
        self._out_expr  = self.from_block(rebinds,ret)
    
    def get_result(self):
        return self._out_expr
    
    def to_block(self, e):
        binds = []
        while type(e) is IR.Let:
            binds.append((e.name,e.rhs))
            e = e.body
        return binds, e
    
    def from_block(self, binds, ret):
        e = ret
        for sym,rhs in reversed(binds):
            e = IR.Let(sym,rhs,e,e.typ)
        return e
    
    def wrap_ind(self, inds, e):
        for p in inds:
            e = IR.Indicate(p, e, e.typ)
        return e
    
    def final_lift_gen(self, e):
        gen, ind, e = self.lift_gen(e)
        e           = self.wrap_ind(ind,e)
        for idx,rng in gen:
            T       = IR.TTensor(rng,e.typ)
            e       = IR.Gen(idx,rng,e,T)
        return e
    
    def lift_gen(self, e):
        eclass = type(e)
        assert not eclass is IR.Pair, "pairs should be eliminated"
        assert not eclass is IR.Proj, "pairs should be eliminated"
        assert not eclass is IR.Let
        if   eclass is IR.Var or eclass is IR.Const:
            return [], [], e
        
        elif eclass is IR.Add or eclass is IR.Mul:
            lgen, lind, lhs = self.lift_gen(e.lhs)
            rgen, rind, rhs = self.lift_gen(e.rhs)
            assert len(lgen) == 0
            assert len(rgen) == 0
            if eclass is IR.Add:
                lhs = self.wrap_ind(lind,lhs)
                rhs = self.wrap_ind(rind,rhs)
                return [], [], IR.Add(lhs,rhs,e.typ)
            else:
                return [], lind + rind, IR.Mul(lhs,rhs,e.typ)
        
        elif eclass is IR.Gen:
            gen, ind, body  = self.lift_gen(e.body)
            gen.append((e.idxname, e.range))
            return gen, ind, body
        
        elif eclass is IR.Sum:
            gen, ind, body  = self.lift_gen(e.body)
            assert len(gen) == 0
            # go through indicators and maybe find sum collapse
            i = e.idxname
            j = None
            for p in ind:
                if   p.lhs == i: j = p.rhs; break
                elif p.rhs == i: j = p.lhs; break
            if j is None:
                body = IR.Sum(i,e.range,body,body.typ)
            else:
                ind, body   = self.subst(i,j,ind,body)
            return [], ind, body
            
        elif eclass is IR.Access:
            gen, ind, base  = self.lift_gen(e.base)
            if len(gen) > 0:
                i,r         = gen.pop()
                ind, base   = self.subst(i,e.idx,ind,base)
                return gen, ind, base
            else:
                return gen, ind, IR.Access(base, e.idx, e.typ)
            
        elif eclass is IR.Indicate:
            gen, ind, body  = self.lift_gen(e.body)
            # sanity!
            for i,r in gen:
                assert i != e.arg.lhs and i != e.arg.rhs
            ind.append(e.arg)
            return gen, ind, body
            
    def sub_ind(self, old, new, p):
        assert type(p) is IR.Eq
        lhs = new if p.lhs == old else p.lhs
        rhs = new if p.rhs == old else p.rhs
        return None if lhs == rhs else IR.Eq(lhs,rhs)
    
    def subst(self, old, new, ind, e):
        # if there are indicators, process separately
        if len(ind) > 0:
            new_ind = []
            for p in ind:
                p = self.sub_ind(old,new,p)
                if not p is None: new_ind.append(p)
            # substitute the expression without indicators
            _, e = self.subst(old,new,[],e)
            return new_ind, e
        
        # the usual case for things other than indicator lists
        eclass = type(e)
        assert not eclass is IR.Pair, "pairs should be eliminated"
        assert not eclass is IR.Proj, "pairs should be eliminated"
        assert not eclass is IR.Let
        assert not eclass is IR.Gen
        if   eclass is IR.Var or eclass is IR.Const:
            return [], e
        
        elif eclass is IR.Add or eclass is IR.Mul:
            _, lhs  = self.subst(old,new,ind,e.lhs)
            _, rhs  = self.subst(old,new,ind,e.rhs)
            return [], eclass(lhs,rhs,e.typ)
        
        elif eclass is IR.Sum:
            # a Sum that rebinds `old` shadows it, so leave it untouched
            # (this probably shouldn't happen in practice)
            if e.idxname == old:
                return [], e
            _, body = self.subst(old,new,[],e.body)
            return [], eclass(e.idxname, e.range, body, e.typ)
            
        elif eclass is IR.Access:
            idx     = new if e.idx == old else e.idx
            _, base = self.subst(old,new,[],e.base)
            return [], IR.Access(base, idx, e.typ)
            
        elif eclass is IR.Indicate:
            arg     = self.sub_ind(old,new,e.arg)
            _, body = self.subst(old,new,[],e.body)
            if arg is None: return [], body
            else:           return [], IR.Indicate(arg,body,body.typ)

def _TypedFunc_gen_normalize(self):
    e       = self._expr
    atyps   = self._symtyps
    self._expr = GenNormalize(e,atyps,e.typ).get_result()
TypedFunc.gen_normalize = _TypedFunc_gen_normalize


In [26]:

Fqprod  = TypedFunc([('A',Rnn),('x',Rn)],qprod)
print(Fqprod)
Fqprod.gen_normalize()
print(Fqprod)
print()

Fdprod  = TypedFunc([('x',Rn),('a',Rn)],dprod)
print(Fdprod)
Fdprod.gen_normalize()
print(Fdprod)

Function(A:[3][3]R, x:[3]R) : R
    +(k:3) (Gen(i:3) +(j:3) A[i][j] * x[j])[k] * x[k]
Function(A:[3][3]R, x:[3]R) : R
    +(k:3) (+(j:3) A[k][j] * x[j]) * x[k]

Function(x:[3]R, a:[3]R) : R
    +(j:3) +(i:3) a[j] * (Gen(i:3) Gen(j:3) [i=j]*x[i])[j][i] * a[i]
Function(x:[3]R, a:[3]R) : R
    +(j:3) a[j] * x[j] * a[j]


# An Implied Normal Form

At this point, all of our valid programs (except those with pair-typed inputs and outputs) have been aggressively normalized.  However, it may be hard to keep track of exactly what that normal form is, so let's write down a loose BNF we can refer to.

$$\begin{array}{rcl}
  a &=&    a[i]
    \ |\   x \\
  e &=&    a
    \ |\   c
    \ |\   e_0 + e_1
    \ |\   e_0 \cdot e_1
    \ |\   \sum_i e
    \ |\   [i=j] \cdot e \\
  g &=&    \boxplus_i g
  \ \ |\ \ e \\
  s &=&    \textrm{let } x = g \textrm{ in } s
  \ \ |\ \ g \\
\end{array}$$

Observe how successive layers of forms have been peeled back into a standard nesting.  We have a sequence of statements whose right-hand sides begin with all of the generators, followed by an expression totally free of `Let` and `Gen`.  This can now be translated to Halide, because each statement is a pure definition of a `Func`, with the final expression defining the output `Func`.
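To see why this form maps so directly onto Halide, here is a NumPy stand-in that executes a block.  It is not Halide and not the notebook's code generator.  Each statement is evaluated as a pure definition over the full domain of its generators, and later statements read earlier results by name.

The statement encoding is an assumption made for illustration: `(name, gens, body)`, where `body` is a Python function of the environment and the index values.

```python
import itertools
import numpy as np

def realize(gens, body, env):
    """Evaluate f(i0, i1, ...) = body at every point of the generators' domain."""
    names = [i for i, _ in gens]
    shape = tuple(r for _, r in gens)
    out = np.zeros(shape)  # a 0-d array when there are no generators
    for point in itertools.product(*(range(r) for r in shape)):
        out[point] = body(env, dict(zip(names, point)))
    return out

def run_block(stmts, ret, env):
    """Realize each statement in order, then the return expression."""
    env = dict(env)
    for name, gens, body in stmts:
        env[name] = realize(gens, body, env)
    ret_gens, ret_body = ret
    return realize(ret_gens, ret_body, env)

# The flattened `contract` example:
#   let m = Gen(i:3) Gen(j:3) +(k:3) D[i][j][k] * a[k] in
#   Gen(i:3) +(j:3) m[i][j] * b[j]
n = 3
rng = np.random.default_rng(3)
inputs = {'D': rng.standard_normal((n, n, n)),
          'a': rng.standard_normal(n),
          'b': rng.standard_normal(n)}
stmts = [('m', [('i', n), ('j', n)],
          lambda env, ix: sum(env['D'][ix['i'], ix['j'], k] * env['a'][k]
                              for k in range(n)))]
ret = ([('i', n)],
       lambda env, ix: sum(env['m'][ix['i'], j] * env['b'][j] for j in range(n)))

result = run_block(stmts, ret, inputs)
print(np.allclose(result, np.einsum('ijk,k,j->i', inputs['D'], inputs['a'], inputs['b'])))
```

In the Halide version, each `realize` loop is replaced by a `Func` pure definition over `Var`s.  Each inner Python `sum` is replaced by a reduction over an `RDom`, as the code generator below does.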

This normal form can be extracted after all the preceding transformations by using the following function.

In [27]:
def _get_block_norm(e):
    stmts   = []
    expr    = e
    while type(expr) is IR.Let:
        nm      = expr.name
        e       = expr.rhs
        gens    = []
        while type(e) is IR.Gen:
            i   = e.idxname
            r   = e.range
            gens.append((i,r))
            e   = e.body
        stmts.append((nm,gens,e))
        expr    = expr.body
    body_gens   = []
    while type(expr) is IR.Gen:
        i       = expr.idxname
        r       = expr.range
        body_gens.append((i,r))
        expr    = expr.body
    
    # stmts has form [( var_name, [( i_name, range )], body_expr )];
    # the second component is (body_gens, body_expr) for the return expression
    return stmts, (body_gens,expr)


----
----

# Code Generation

To generate code, we need to manage a context of name bindings and, finally, the pipeline's inputs and outputs.
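`_Context` is defined earlier in the notebook and isn't shown here.  For readers jumping in, here is a minimal sketch of the behavior the passes rely on: `push`, `pop`, `set`, and a `get` that returns `None` for unbound names.  It is a hypothetical stand-in and not the notebook's actual class.

```python
class ScopedContext:
    """A stack of dictionaries; lookups search from the innermost scope outward."""
    def __init__(self):
        self._scopes = [{}]

    def push(self):
        self._scopes.append({})

    def pop(self):
        self._scopes.pop()

    def set(self, name, value):
        self._scopes[-1][name] = value

    def get(self, name):
        for scope in reversed(self._scopes):
            if name in scope:
                return scope[name]
        return None

ctxt = ScopedContext()
ctxt.set('A', 'func A')
ctxt.push()
ctxt.set('i', 'var i')      # e.g. an index variable bound for one statement
print(ctxt.get('i'), ctxt.get('A'))
ctxt.pop()
print(ctxt.get('i'))        # the index binding is gone after pop()
```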

In [28]:
class Halide_CodeGen:
    def __init__(self,expr,arg_symtyps,rettyp):        
        # break out into a list of bindings and return expr
        stmts, ret  = _get_block_norm(expr)
        self._orig_stmts = stmts
        self._orig_ret  = ret
        
        self._ctxt      = _Context()
        self._first_run = True
        
        # helper function to process tensor types
        def shape_dim(T):
            if not type(T) is IR.TTensor: return []
            else: return [T.range] + shape_dim(T.typ)
        
        # create Halide inputs for each argument
        # inputs may be 'param's (scalars) or
        #               'img_param's (tensors)
        Hf64        = halide_type_t(C.type_float,64,1)
        arg_params  = {}
        for nm,typ in arg_symtyps.values():
            nm_bytes    = repr(nm).encode('utf-8')
            shape       = shape_dim(typ)
            if len(shape) == 0: # scalar case
                P = C.hwrap_new_param(nm_bytes,Hf64)
                arg_params[nm.name()] = P
                # also bind the arg symbol to an Expr
                E = C.hwrap_param_to_expr(P)
                self._ctxt.set( nm, E )
            else: # tensor case
                Img = C.hwrap_new_img(nm_bytes,len(shape),Hf64)
                arg_params[nm.name()] = Img
                # also bind the arg symbol to a Func
                F   = C.hwrap_img_to_func(Img)
                self._ctxt.set( nm, F )
                # Halide auto-scheduling requires an estimate
                # of tensor size, which we extract from the type
                for i,r in enumerate(shape):
                    C.hwrap_set_img_bound_estimate(Img,i,
                        C.hwrap_i32_to_expr(0),C.hwrap_i32_to_expr(r))
        self._arg_params    = arg_params
        self._arg_typs      = { nm : symtyp[1]
                                for nm,symtyp in arg_symtyps.items() }
        
        # compile each statement
        for name,gens,body in stmts:
            self._compile_stmt(name,gens,body)
        
        # compile the return expression as a statement
        self._ret_sym       = Sym('return')
        self._ret_Func      = self._compile_stmt(self._ret_sym,
                                                 ret[0], ret[1])
        self._ret_typ       = rettyp
        # also provide output estimates for auto-scheduling
        if type(rettyp) is IR.TTensor:
            for i,r in enumerate(shape_dim(rettyp)):
                C.hwrap_set_func_bound_estimate(self._ret_Func,i,
                    C.hwrap_i32_to_expr(0),C.hwrap_i32_to_expr(r))
        else: # scalar temporaries encoded as length 1 arrays
            C.hwrap_set_func_bound_estimate(self._ret_Func,0,
                C.hwrap_i32_to_expr(0),C.hwrap_i32_to_expr(1))
        
    
    def _compile_stmt(self, name, gens, expr):
        name_bytes  = repr(name).encode('utf-8')
        F           = C.hwrap_new_func(name_bytes)
        
        self._ctxt.push()
        
        # create index variables
        n_dims      = len(gens)
        i_var_arr   = None
        handles     = [] # to prevent garbage collection
        if n_dims == 0:
            i       = Sym(name.name()+"_0idx")
            i_bytes = repr(i).encode('utf-8')
            V       = C.hwrap_new_var(i_bytes)
            handles.append(V)
            i_var_arr   = (hw_var_t * 1)(V)
            n_dims      = 1
        else:
            i_var_arr   = (hw_var_t * n_dims)()
            for k,ir in enumerate(gens):
                i, r    = ir
                i_bytes = repr(i).encode('utf-8')
                V       = C.hwrap_new_var(i_bytes)
                E       = C.hwrap_var_to_expr(V)
                handles.append(V)
                handles.append(E)
                # pack var into lhs array
                i_var_arr[k] = V
                # store expr in context
                self._ctxt.set(i,E)
        
        # compile rhs expr
        rhs = self._compile_expr(expr)
        self._ctxt.pop()
        
        # add the statement to the program
        C.hwrap_pure_def(F,n_dims,i_var_arr,rhs)
        
        # add the function to the context
        if len(gens) == 0:
            zero    = C.hwrap_i32_to_expr(0)
            z_arr   = (hw_expr_t * 1)(zero)
            E       = C.hwrap_access_func(F,1,z_arr)
            self._ctxt.set(name,E)
        else:
            self._ctxt.set(name,F)
        return F
    
    def _compile_expr(self,e):
        eclass = type(e)
        assert not eclass is IR.Pair, "pairs should be eliminated"
        assert not eclass is IR.Proj, "pairs should be eliminated"
        assert not eclass is IR.Let
        assert not eclass is IR.Gen
        if   eclass is IR.Var:
            expr = self._ctxt.get(e.name)
            assert expr is not None
            return expr
        
        elif eclass is IR.Const:
            expr = C.hwrap_f64_to_expr(e.val)
            return expr
        
        elif eclass is IR.Add or eclass is IR.Mul:
            lhs  = self._compile_expr(e.lhs)
            rhs  = self._compile_expr(e.rhs)
            C_op = C.hwrap_add if eclass is IR.Add else C.hwrap_mul
            res  = C_op(lhs,rhs)
            return res
        
        elif eclass is IR.Sum:
            self._ctxt.push()
            # create rdom
            rb   = repr(e.idxname).encode('utf-8')
            lo   = C.hwrap_i32_to_expr(0)
            hi   = C.hwrap_i32_to_expr(e.range)
            rng  = (hw_expr_t * 2)(lo,hi)
            rdom = C.hwrap_new_rdom(rb,1,rng)
            # bind rdom in context as expr
            rexp = C.hwrap_rdom_to_expr(rdom)
            self._ctxt.set(e.idxname,rexp)
            # finally compile body...
            body = self._compile_expr(e.body)
            self._ctxt.pop()
            
            res  = C.hwrap_big_sum(rdom,body)
            return res
        
        elif eclass is IR.Access:
            return self._compile_access(e)
        
        elif eclass is IR.Indicate:
            li      = self._ctxt.get(e.arg.lhs)
            ri      = self._ctxt.get(e.arg.rhs)
            assert type(li) == hw_expr_t
            assert type(ri) == hw_expr_t
            
            eq      = C.hwrap_eq(li,ri)
            body    = self._compile_expr(e.body)
            zero    = C.hwrap_f64_to_expr(0.0)
            res     = C.hwrap_select(eq, body, zero)
            return res
            
        else: assert False, "unrecognized IR case"
    
    def _compile_access(self,e):
        idxs    = []
        while type(e) is IR.Access:
            # note that we pull off accesses right-to-left
            idxs.insert(0,e.idx)
            e = e.base
        assert type(e) is IR.Var
        F       = self._compile_expr(e)
        assert type(F) is hw_func_t
        
        # lookup index expressions, and create access
        n_idx   = len(idxs)
        exprs   = [ self._ctxt.get(i)
                    for i in idxs ]
        idx_arr = (hw_expr_t * n_idx)(*exprs)
        a   = C.hwrap_access_func(F,n_idx,idx_arr)
        return a
        
    def run(self,inputs,outputs):
        pass
    

Ok, so this defines a Halide pipeline.  We still need a way to run it.

## Handling Parameters

Recall that we defined `ndarray_to_halide_buf(ndarray)` as a function to help convert from NumPy to Halide data descriptors.  Now, we need a few more details to help out.

* We need a way to handle scalar inputs and outputs.
* We need a way to type-check inputs and outputs.
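As a reminder of what the NumPy-to-Halide conversion has to read off an array, here is a small sketch.  For each dimension, it extracts a starting index, an extent, and a stride counted in elements rather than bytes.  These correspond to the `min`, `extent`, and `stride` fields of Halide's `halide_dimension_t`.  The helper `dims_of` is hypothetical and is not `ndarray_to_halide_buf`.

```python
import numpy as np

def dims_of(arr):
    """List (min, extent, stride_in_elements) for each axis of an ndarray."""
    return [(0, extent, stride // arr.itemsize)
            for extent, stride in zip(arr.shape, arr.strides)]

Av = np.array([[5., 2., 0.], [2.2, 0., 4.5], [0., 6.1, 3.3]], order='F')
print(dims_of(Av))                        # [(0, 3, 1), (0, 3, 3)]: axis 0 is contiguous
print(dims_of(np.ascontiguousarray(Av)))  # [(0, 3, 3), (0, 3, 1)]: axis 1 is contiguous
```

This is one reason the notebook builds its arrays with `order='F'`: in column-major storage the first index has unit stride, matching Halide's convention that dimension 0 is innermost.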


In [29]:
def _type_value_match(typ,val):
    if type(typ) is IR.Pair:
        raise TypeError("Pairs unsupported")
    elif type(typ) is IR.TTensor:
        # unroll tensor
        shape = []
        while type(typ) is IR.TTensor:
            shape.append(typ.range)
            typ = typ.typ
        if type(typ) is IR.Pair:
            raise TypeError("Pairs unsupported")
        assert typ is IR.tnum
        
        if type(val) != np.ndarray:
            raise TypeError("Expected 'numpy.ndarray' type value")
        # check shape
        if len(shape) != len(val.shape):
            raise TypeError(f"Expected {len(shape)} dims, but got "
                            f"an ndarray with {len(val.shape)}")
        for i,d in enumerate(shape):
            if d != val.shape[i]:
                raise TypeError(f"expected dimension {i} of tensor shape "
                                f"to be {d}, but it was {val.shape[i]}")
        
    elif typ is IR.tnum:
        if type(val) != float and type(val) != int:
            raise TypeError("Expected 'float' or 'int' type value")

In [30]:
print(repr(Av))
_type_value_match(Rnn,Av)
_type_value_match(Rn,cv)
_type_value_match(R,3)
try: _type_value_match(Rn,Av)
except TypeError as e: print(e)
try: _type_value_match(IR.TTensor(4,Rn),Av)
except TypeError as e: print(e)
try: _type_value_match(Rn,3.4)
except TypeError as e: print(e)


array([[5. , 2. , 0. ],
       [2.2, 0. , 4.5],
       [0. , 6.1, 3.3]])
Expected 1 dims, but got an ndarray with 2
expected dimension 0 of tensor shape to be 4, but it was 3
Expected 'numpy.ndarray' type value


The preceding gives us a simple way to type-check.  To bind inputs, we simply need to decide which case we're in (scalar or tensor) and bind to a param or image-param appropriately.  The output is a bit trickier, because we need to bind an otherwise spurious 1-entry tensor to handle scalar output.  This also raises the question of how output allocations should be handled.  It seems a given that the inputs should be allocated by the caller; the output, however, is often allocated by the callee.

To avoid possibly unwanted allocations, we'll accept an optional pre-allocated output parameter, and allocate a fresh output array only when none is supplied.

In [31]:

def _Halide_CodeGen_run_pipeline(self,inputs,output=None):
    # check and bind input arguments
    hbufs = []
    for nm in self._arg_params: # check that all inputs are supplied
        if not nm in inputs:
            raise TypeError(f"expected input argument '{nm}'")
    for nm,val in inputs.items():
        if not nm in self._arg_typs:
            raise TypeError(f"unexpected input, named '{nm}'")
        T = self._arg_typs[nm]
        _type_value_match(T,val)
        P = self._arg_params[nm]
        if T is IR.tnum: # scalar
            assert type(P) is hw_param_t
            C.hwrap_set_param(P, ctypes.byref(ctypes.c_double(val)))
        else:
            assert type(P) is hw_img_t
            hbuf = ndarray_to_halide_buf(val)
            hbufs.append(hbuf) # prevent early de-allocation
            C.hwrap_set_img(P, ctypes.byref(hbuf))
    
    # handle unsupplied output...
    if output is None:
        typ = self._ret_typ
        if typ is IR.tnum: # scalar case
            output = np.array([0.0])
        else: # tensor case
            shape = []
            while type(typ) is IR.TTensor:
                shape.append(typ.range)
                typ = typ.typ
            assert typ is IR.tnum
            
            output = np.ndarray(dtype='double', shape=shape, order='F')
    # check output
    if self._ret_typ is IR.tnum:
        if (type(output) != np.ndarray or
            len(output.shape) != 1 or
            output.shape[0] != 1):
                raise TypeError("Expected numpy.ndarray of shape [1]")
    else:
        _type_value_match(self._ret_typ, output)
    # bind output
    outbuf  = ndarray_to_halide_buf(output)
    hbufs.append(outbuf)
    
    # auto-schedule the pipeline on its first run (call currently disabled)
    if self._first_run:
        self._first_run = False
        #C.hwrap_autoschedule_func( self._ret_Func )
        
    # run the pipeline
    C.hwrap_realize_func( self._ret_Func, outbuf )
    
    # potentially extract the output
    return_val = output
    if self._ret_typ is IR.tnum:
        return_val = output[0]
    return return_val

Halide_CodeGen.run = _Halide_CodeGen_run_pipeline

## Putting it Together

We've defined a compilation strategy that looks like the following, for some tensor-language expression `e` written in `IRv0` and a signature `argtyps` given as a list of `(name, type)` pairs (e.g. `[('A',Rnn),('x',Rn)]`):

```
TF = TypedFunc(argtyps,e)
# normalization passes
TF.eliminate_pairs()
TF.lift_lets()
TF.gen_normalize()
# compilation pass
TF.compile()

# only this need be invoked on each further execution
TF.run(argvals)
```

What we need now is to create `compile()` and `run()` wrappers on the typed-func class.  We'll have compilation take care of all the normalization while we're at it.
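The "protect against re-compilation" pattern in the next cell boils down to caching the compiled object on first use.  Below is a stripped-down sketch.  The class, its `compile_count` counter, and the placeholder "compiled" function are all hypothetical.

```python
class LazyCompiled:
    """Run expensive normalization/compilation once, then reuse the result."""
    def __init__(self, fn):
        self._fn = fn
        self._compiled = None
        self.compile_count = 0

    def compile(self):
        if self._compiled is not None:   # already compiled: nothing to do
            return
        self.compile_count += 1
        # stand-in for eliminate_pairs / lift_lets / gen_normalize / Halide codegen
        self._compiled = lambda inputs: self._fn(**inputs)

    def run(self, inputs):
        self.compile()
        return self._compiled(inputs)

f = LazyCompiled(lambda x, y: x * y)
print(f.run({'x': 3, 'y': 4}), f.run({'x': 5, 'y': 6}), f.compile_count)
```

Only the first `run` pays the compilation cost; every later call goes straight to the cached object.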


In [32]:

def _TypedFunc_compile(self):
    # protect against re-compilation
    if hasattr(self,'_compiled_obj'):
        return
    # run normalization
    self.eliminate_pairs()
    self.lift_lets()
    self.gen_normalize()
    # do the compilation via the CodeGen object
    e       = self._expr
    atyps   = self._symtyps
    CG      = Halide_CodeGen(e,atyps,e.typ)
    self._compiled_obj = CG

def _TypedFunc_run(self,inputs,output=None):
    self.compile()
    # actual execution
    return self._compiled_obj.run(inputs,output)

TypedFunc.compile   = _TypedFunc_compile
TypedFunc.run       = _TypedFunc_run

We can pack this into the original Func, neatly hiding everything behind that interface.

In [33]:
def _Func_jit_compile(self):
    if hasattr(self,'_typed_func'):
        return
    
    self._typed_func = TypedFunc(self._arglist,self._expr)
    self._typed_func.compile()

def Func_jit_exec(self, *args, **kwargs):
    self._jit_compile()
    call_args   = {}

    # check that the right number of arguments were used
    n_call      = len(args) + len(kwargs)
    n_args      = len(self._arglist)
    if n_call != n_args:
        raise TypeError(f"expected {n_args} arguments, "
                        f"but was called with {n_call}")

    # fill out call_args with supplied named arguments
    for nm in kwargs:
        typ     = self._argdict.get(nm)
        if typ is None:
            raise TypeError(f"argument '{nm}' is not an argument of "
                            f"this tensor function")
        else:
            call_args[nm] = kwargs[nm]

    # then fill in the remainder with the unnamed arguments
    arg_i   = 0
    for nm,typ in self._arglist:
        if not nm in call_args:
            assert(arg_i < len(args))
            val     = args[arg_i]
            arg_i   = arg_i + 1
            call_args[nm] = val

    # finally, execute
    return self._typed_func.run(call_args)

Func._jit_compile = _Func_jit_compile
Func.jit_exec = Func_jit_exec
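The argument handling in `Func_jit_exec` re-implements Python's own positional/keyword binding rules.  As an alternative that the notebook does not use, the standard library's `inspect.Signature.bind` can apply those rules.  It produces the same kind of name-to-value dictionary and raises `TypeError` for missing, duplicated, or unknown arguments.

The `arglist` format below mirrors `Func._arglist`, a list of `(name, type)` pairs, with the types ignored.  The sketch assumes every argument name is a valid Python identifier.

```python
import inspect

def bind_call_args(arglist, args, kwargs):
    """Map positional and keyword arguments onto the names in arglist."""
    # assumes each name is a valid Python identifier (required by inspect.Parameter)
    sig = inspect.Signature([
        inspect.Parameter(nm, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for nm, _typ in arglist
    ])
    bound = sig.bind(*args, **kwargs)   # raises TypeError on a bad call
    return dict(bound.arguments)

arglist = [('A', 'Rnn'), ('x', 'Rn'), ('c', 'Rn')]
print(bind_call_args(arglist, ('Av',), {'c': 'cv', 'x': 'xv'}))
try:
    bind_call_args(arglist, ('Av',), {'y': 'yv'})
except TypeError as err:
    print('TypeError:', err)
```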

Let's try to test what we've got here.

In [34]:
n       = 3
R       = IR.tnum
Rn      = IR.TTensor(n,R)
Rnn     = IR.TTensor(n,Rn)
Rnnn    = IR.TTensor(n,Rnn)
x       = IRv0.Var('x')
A, D    = IRv0.Var('A'), IRv0.Var('D')
a, b, c = IRv0.Var('a'), IRv0.Var('b'), IRv0.Var('c')
i, j, k = 'i', 'j', 'k'

store_order = 'F'
xv = np.array([4.,7.,1.], order=store_order)
Av = np.array([[5.,2.,0.],[2.2,0.,4.5],[0.,6.1,3.3]], order=store_order)
av = np.array([9.2,5.4,7.1], order=store_order)
bv = np.array([0.3,2.1,1.6], order=store_order)
cv = np.array([0.,0.,1.], order=store_order)
Dv = np.array([[[ 1., 2., 3.],[ 4., 5., 6.],[ 7., 8., 9.]],
               [[10.,11.,12.],[13.,14.,15.],[16.,17.,18.]],
               [[19.,20.,21.],[22.,23.,24.],[25.,26.,27.]]], order=store_order)

pp2     = IRv0.Pair
p0      = lambda x: IRv0.Proj(0, x)
p1      = lambda x: IRv0.Proj(1, x)
let     = lambda x,r,b: IRv0.Let(x,r,b)
Gen     = lambda i,r,e: IRv0.Gen(i,r,e)
Sum     = lambda i,r,e: IRv0.Sum(i,r,e)


Axc     = Gen(i,n, c[i] + Sum(j,n, A[i,j] * x[j] ))
tr      = Sum(i,n, A[i,i])
diag    = Gen(i,n, Gen(j,n, IRv0.Eq(i,j) * x[i] ))
ab3     = let('x', Gen(j,n, pp2( pp2(a[j], b[j]), a[j]*b[j] )),
              Sum(j,n, p0(p0(x[j])) + p1(p0(x[j])) + p1(x[j]) ))
ctrct   = Gen(i,n, Sum(j,n, let('x', Sum(k,n, D[i,j,k] * a[k]),
                                 x * b[j] )))
prod    = Gen(i,n, Sum(j,n, A[i,j] * x[j] ))
qprod   = Sum(k,n, prod[k] * x[k])
dprod   = Sum(j,n, Sum(i,n, a[j] * diag[j,i] * a[i] ))

FAxc    = Func([('A',Rnn),('x',Rn),('c',Rn)], Axc)
Ftr     = Func([('A',Rnn)], tr)
Fdiag   = Func([('x',Rn)], diag)
Fab3    = Func([('a',Rn),('b',Rn)], ab3)
Fctrct  = Func([('D',Rnnn),('a',Rn),('b',Rn)], ctrct)
Fqprod  = Func([('A',Rnn),('x',Rn)], qprod)
Fdprod  = Func([('x',Rn),('a',Rn)], dprod)

print(FAxc)
print(Ftr)
print(Fdiag)
print(Fab3)
print(Fctrct)
print(Fqprod)
print(Fdprod)

<atlv0.Func object at 0x10e5e4208>
<atlv0.Func object at 0x10e5e43c8>
<atlv0.Func object at 0x10e5e4438>
<atlv0.Func object at 0x10e5e44a8>
<atlv0.Func object at 0x10e5193c8>
<atlv0.Func object at 0x10e519400>
<atlv0.Func object at 0x10e5a1198>


In [35]:
FAxc._jit_compile()
Ftr._jit_compile()
Fdiag._jit_compile()
Fab3._jit_compile()
Fctrct._jit_compile()
Fqprod._jit_compile()
Fdprod._jit_compile()

In [36]:
def test(name,F,*args):
    print('Run Test: ',name)
    print(*args)
    l_args = [ a.tolist() for a in args ]
    print('compiled:    ', F.jit_exec(*args))
    print('interpreted: ', F.interpret(*l_args))

test("Axc",FAxc,Av,xv,cv)
test("tr",Ftr,Av)
test("diag",Fdiag,xv)
test("ab3",Fab3,av,bv)
test("ctrct",Fctrct,Dv,av,bv)
test("qprod",Fqprod,Av,xv)
test("dprod",Fdprod,xv,av)


Run Test:  Axc
[[5.  2.  0. ]
 [2.2 0.  4.5]
 [0.  6.1 3.3]] [4. 7. 1.] [0. 0. 1.]
compiled:     [34.  13.3 47. ]
interpreted:  [34.0, 13.3, 46.99999999999999]
Run Test:  tr
[[5.  2.  0. ]
 [2.2 0.  4.5]
 [0.  6.1 3.3]]
compiled:     8.3
interpreted:  8.3
Run Test:  diag
[4. 7. 1.]
compiled:     [[4. 0. 0.]
 [0. 7. 0.]
 [0. 0. 1.]]
interpreted:  [[4.0, 0.0, 0.0], [0.0, 7.0, 0.0], [0.0, 0.0, 1.0]]
Run Test:  ab3
[9.2 5.4 7.1] [0.3 2.1 1.6]
compiled:     51.16
interpreted:  51.16
Run Test:  ctrct
[[[ 1.  2.  3.]
  [ 4.  5.  6.]
  [ 7.  8.  9.]]

 [[10. 11. 12.]
  [13. 14. 15.]
  [16. 17. 18.]]

 [[19. 20. 21.]
  [22. 23. 24.]
  [25. 26. 27.]]] [9.2 5.4 7.1] [0.3 2.1 1.6]
compiled:     [ 510.23 1291.43 2072.63]
interpreted:  [510.23, 1291.43, 2072.6299999999997]
Run Test:  qprod
[[5.  2.  0. ]
 [2.2 0.  4.5]
 [0.  6.1 3.3]] [4. 7. 1.]
compiled:     275.1
interpreted:  275.1
Run Test:  dprod
[4. 7. 1.] [9.2 5.4 7.1]
compiled:     593.0899999999999
interpreted:  593.0899999999999


# Comments

This was a long slog, but we got all the way to a compiled version of the tensor language.  We don't yet have a good sense of its performance, but that needn't be an overriding concern right now.

Rather, the big limitation of this experiment was the complexity that piled up from all the unanticipated hoop-jumping.  To make this cleaner, we need to start thinking about how to break our previously compact compiler apart into multiple, distinct IRs, each of which has a particular role to play.

What might this look like?

1. We ought to clean up the front-end IR a bit more.  We can now expect type-checking to mark the boundary between the front-end and the back-end.  We ought to be able to reverse this mapping, at least for the sake of serializing functions.
2. We ought to keep the internal IR in as normalized a form as possible.  Is there some way to avoid needing so many different passes to lower it?
3. We might want to create an explicit Halide IR to simplify code generation.

However, the _type system_ is shared among the different IRs.  This suggests we need to factor out the type system so that each IR can use it.
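One way to factor the types out is a small module that every IR imports, so each IR's `expr` refers to the same type classes. A sketch assuming plain frozen dataclasses rather than the `adt` module used in the notebook; the class names mirror the `type` sum of `IR_v0`, and `tensor_shape` is a hypothetical helper showing how a back-end could read a NumPy-style shape off a type:

```python
from dataclasses import dataclass
from typing import Union

# Hypothetical shared type module, mirroring `type` in IR_v0
@dataclass(frozen=True)
class TNum:
    pass

@dataclass(frozen=True)
class TError:
    pass

@dataclass(frozen=True)
class TPair:
    lhs: "Type"
    rhs: "Type"

@dataclass(frozen=True)
class TTensor:
    range: int
    typ: "Type"

Type = Union[TNum, TError, TPair, TTensor]

def tensor_shape(t):
    """Peel off nested TTensor layers; return (shape, element type)."""
    shape = []
    while isinstance(t, TTensor):
        shape.append(t.range)
        t = t.typ
    return tuple(shape), t

# usage: the type of D in the tests above, [3][3][3]R
Rnnn = TTensor(3, TTensor(3, TTensor(3, TNum())))
print(tensor_shape(Rnnn))   # ((3, 3, 3), TNum())
```

Frozen dataclasses compare structurally, so two IRs built from this module can check type equality with `==` without sharing any other code.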

In addition to separating out these IRs, we also ought to have more robust testing.  We certainly forgot things in our haste (e.g. integer-literal indices in accesses).  Testing infrastructure could help us catch such omissions quickly and cheaply.
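Rather than eyeballing printed output, the comparison in `test` could be automated with a numerical tolerance, since the compiled and interpreted paths round differently (e.g. `47.` versus `46.99999999999999` for `Axc`). A minimal sketch; `differential_test`, `compiled_axc` and `reference_axc` are hypothetical stand-ins for `F.jit_exec` and `F.interpret`, so the block runs without `atlv0` or Halide, and the tolerances are arbitrary choices:

```python
import numpy as np

def differential_test(compiled, reference, inputs, rtol=1e-9, atol=1e-12):
    """Compare a compiled function against a reference interpreter.

    `compiled` receives the NumPy inputs directly; `reference` receives them
    converted to nested lists, as the notebook's `test` does for `interpret`.
    Returns True when shapes match and values agree within tolerance.
    """
    got = np.asarray(compiled(*inputs), dtype=float)
    want = np.asarray(reference(*[a.tolist() for a in inputs]), dtype=float)
    return bool(got.shape == want.shape
                and np.allclose(got, want, rtol=rtol, atol=atol))

# Hypothetical stand-ins for FAxc.jit_exec and FAxc.interpret
def compiled_axc(A, x, c):
    return A @ x + c

def reference_axc(A, x, c):
    n = len(x)
    return [c[i] + sum(A[i][j] * x[j] for j in range(n)) for i in range(n)]

Av = np.array([[5., 2., 0.], [2.2, 0., 4.5], [0., 6.1, 3.3]], order='F')
xv = np.array([4., 7., 1.], order='F')
cv = np.array([0., 0., 1.], order='F')

print(differential_test(compiled_axc, reference_axc, [Av, xv, cv]))   # True
```

The same harness could loop over a table of `(name, F, inputs)` cases, so that adding a test is one line rather than a new printout to inspect.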

Lastly, we ought to clean up the way compiler passes get written.  We seemed to settle on a "pass-as-class" style.  However, certain objects such as `_Context` tended to be very important.  Is there any way we can standardize our pass-writing further?
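One common way to standardize pass-writing is a base class that dispatches on the node's class name and owns the scoped context, so each individual pass only supplies per-node cases. A minimal sketch over hypothetical toy node classes (`Var`, `Const`, `Add`, `Let`); it does not reproduce the notebook's `_Context` or IR classes:

```python
from dataclasses import dataclass

# Hypothetical toy expression nodes standing in for the IR's `expr` classes
@dataclass
class Var:
    name: str

@dataclass
class Const:
    val: float

@dataclass
class Add:
    lhs: object
    rhs: object

@dataclass
class Let:
    name: str
    rhs: object
    body: object

class Pass:
    """Dispatch to `visit_<ClassName>`; keep a stack of scopes as context."""
    def __init__(self):
        self.ctx = [{}]

    def visit(self, node):
        method = getattr(self, 'visit_' + type(node).__name__, None)
        if method is None:
            raise NotImplementedError(f"no case for {type(node).__name__}")
        return method(node)

    def lookup(self, name):
        # innermost binding wins
        for scope in reversed(self.ctx):
            if name in scope:
                return scope[name]
        raise KeyError(name)

class Evaluate(Pass):
    def visit_Var(self, e):   return self.lookup(e.name)
    def visit_Const(self, e): return e.val
    def visit_Add(self, e):   return self.visit(e.lhs) + self.visit(e.rhs)

    def visit_Let(self, e):
        val = self.visit(e.rhs)
        self.ctx.append({e.name: val})   # open a scope for the body
        try:
            return self.visit(e.body)
        finally:
            self.ctx.pop()               # always restore the context

print(Evaluate().visit(Let('x', Const(2.0), Add(Var('x'), Const(3.0)))))   # 5.0
```

With scoping handled once in the base class, a pass that forgets a node kind fails loudly with `NotImplementedError` instead of silently mis-handling it.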


## Extensions

In addition to refactoring, we need to fundamentally extend the basic IR.  Recall its definition.

In [40]:
print(IRv0._defstr)


module IR_v0 {
    expr = Var      ( string name )
         | Const    ( float  val  )
         | Add      ( expr lhs, expr rhs )
         | Mul      ( expr lhs, expr rhs )
         | Pair     ( expr lhs, expr rhs )
         | Proj     ( int01 idx, expr arg )
         | Gen      ( string idxname, int range, expr body )
         | Sum      ( string idxname, int range, expr body )
         | Access   ( expr  base, index idx )
         -- implied multiplication of the bracket with body
         | Indicate ( pred  arg, expr body )
         -- important to express sharing of computation results
         | Let      ( string name, expr rhs, expr body )
    
    -- indices are drawn from a range s.t.
    -- 0 <= i < range
    
    pred    = Eq( index lhs, index rhs )
    
    type    = TNum    ()
            | TError  ()
            | TPair   (  type lhs, type rhs )
            | TTensor ( int range, type typ )
}



We will need to expand the `pred` class with other concepts in order to capture tensor sparsity patterns.

```
    pred    = Peq     ( iexp lhs, iexp rhs )
            | Pneq    ( iexp lhs, iexp rhs )
            | Plt     ( iexp lhs, iexp rhs )
            | Pleq    ( iexp lhs, iexp rhs )
            | Prel    ( relation r, iexp* args )
            | Pand    ( pred lhs, pred rhs )
            | Por     ( pred lhs, pred rhs )
            | Pnot    ( pred arg )
            
    iexp    = Ivar    ( symbol name )
            | Imul    ( int    c, iexp arg )
            | Iadd    ( iexp lhs, iexp rhs )
            | Iconst  ( int val )
```

This set of concepts includes affine index expressions, built from `Imul` and `Iadd`, plus the comparison predicates.  Additionally, there are Boolean connectives and `Prel`, which allows for data-defined predicates; these are relations in the database sense.
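To make the meaning of the extended grammar concrete, here is a small evaluator that decides a `pred` under a binding of index variables to integers. A sketch assuming plain dataclasses (not the `adt` module), and assuming a `Prel` relation is referenced by name and looked up as a set of index tuples in a `relations` dict; neither representation comes from the notebook:

```python
import operator
from dataclasses import dataclass

# --- index expressions (iexp) ---
@dataclass(frozen=True)
class Ivar:
    name: str

@dataclass(frozen=True)
class Iconst:
    val: int

@dataclass(frozen=True)
class Imul:
    c: int          # constant coefficient keeps expressions affine
    arg: object

@dataclass(frozen=True)
class Iadd:
    lhs: object
    rhs: object

# --- predicates (pred) ---
@dataclass(frozen=True)
class _Cmp:
    lhs: object
    rhs: object

class Peq(_Cmp): pass
class Pneq(_Cmp): pass
class Plt(_Cmp): pass
class Pleq(_Cmp): pass

@dataclass(frozen=True)
class Prel:
    r: str          # assumption: relation referenced by name
    args: tuple

@dataclass(frozen=True)
class Pand:
    lhs: object
    rhs: object

@dataclass(frozen=True)
class Por:
    lhs: object
    rhs: object

@dataclass(frozen=True)
class Pnot:
    arg: object

_CMP = {Peq: operator.eq, Pneq: operator.ne, Plt: operator.lt, Pleq: operator.le}

def eval_iexp(e, env):
    """Evaluate an affine index expression under env: name -> int."""
    if isinstance(e, Ivar):   return env[e.name]
    if isinstance(e, Iconst): return e.val
    if isinstance(e, Imul):   return e.c * eval_iexp(e.arg, env)
    if isinstance(e, Iadd):   return eval_iexp(e.lhs, env) + eval_iexp(e.rhs, env)
    raise TypeError(f"not an iexp: {e!r}")

def eval_pred(p, env, relations):
    """Decide p; `relations` maps a relation name to a set of index tuples."""
    if type(p) in _CMP:
        return _CMP[type(p)](eval_iexp(p.lhs, env), eval_iexp(p.rhs, env))
    if isinstance(p, Prel):
        return tuple(eval_iexp(a, env) for a in p.args) in relations[p.r]
    if isinstance(p, Pand):
        return eval_pred(p.lhs, env, relations) and eval_pred(p.rhs, env, relations)
    if isinstance(p, Por):
        return eval_pred(p.lhs, env, relations) or eval_pred(p.rhs, env, relations)
    if isinstance(p, Pnot):
        return not eval_pred(p.arg, env, relations)
    raise TypeError(f"not a pred: {p!r}")

# usage: lower-triangular entries that are also present in a relation S
i, j = Ivar('i'), Ivar('j')
lower_and_S = Pand(Pleq(j, i), Prel('S', (i, j)))
rels = {'S': {(0, 1), (2, 2)}}
print(eval_pred(lower_and_S, {'i': 2, 'j': 2}, rels))   # True
print(eval_pred(lower_and_S, {'i': 0, 'j': 1}, rels))   # False
```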

Consider, as a suggestive example, matrix-matrix multiplication of sparse matrices whose sparsity patterns are governed by relations.

$$\begin{array}{rcl}
S &:& [n,m]\mathbb{2} \\
T &:& [m,p]\mathbb{2} \\
A &:& [n,m]\mathbb{R} <: S \\
B &:& [m,p]\mathbb{R} <: T \\
M &=& \boxplus_{i,k} \sum_{j} A[i,j] \cdot B[j,k] \\
\end{array}$$

We can propagate the sparsity constraints using the Iverson-bracket indicator function:

$$\begin{array}{rcl}
M &=& \boxplus_{i,k} \sum_{j} A[i,j] \cdot B[j,k] \\
  &=& \boxplus_{i,k} \sum_{j} [S(i,j)] \cdot A[i,j] \cdot [T(j,k)] \cdot B[j,k] \\
  &=& \boxplus_{i,k} \sum_{j} [S(i,j) \wedge T(j,k)] \cdot A[i,j] \cdot B[j,k] \\
  &=& \boxplus_{i,k} [\exists_j S(i,j)\wedge T(j,k)] \sum_{j} A[i,j] \cdot B[j,k] \\
\end{array}$$
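A quick numerical check of this derivation, using NumPy boolean arrays for $S$ and $T$ and random values for $A$ and $B$ restricted to those patterns. It compares the plain sum, the sum weighted by $[S(i,j)\wedge T(j,k)]$ (the product of the two brackets), and the form with the existential indicator pulled outside. The sizes, seed, and density are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, p = 4, 5, 3

# Boolean sparsity relations S : [n,m]2 and T : [m,p]2
S = rng.random((n, m)) < 0.4
T = rng.random((m, p)) < 0.4

# A <: S and B <: T: entries are zero wherever the relation is false
A = np.where(S, rng.standard_normal((n, m)), 0.0)
B = np.where(T, rng.standard_normal((m, p)), 0.0)

# M[i,k] = sum_j A[i,j] * B[j,k]
M = np.einsum('ij,jk->ik', A, B)

# sum_j [S(i,j) and T(j,k)] * A[i,j] * B[j,k], conjunction as an (n,m,p) mask
mask = (S[:, :, None] & T[None, :, :]).astype(float)
M_conj = np.einsum('ijk,ij,jk->ik', mask, A, B)
assert np.allclose(M, M_conj)

# [exists_j S(i,j) and T(j,k)]: project the mask onto (i,k) by reducing over j
structure = mask.any(axis=1)
assert np.array_equal(structure, (S.astype(int) @ T.astype(int)) > 0)

# Pulling the indicator out of the sum: M vanishes outside the structure
assert np.all(M[~structure] == 0.0)
assert np.allclose(M, structure * M)
```

The second assertion also shows that the structure relation is just the Boolean matrix product of $S$ and $T$.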

The last step holds because $A <: S$ and $B <: T$: when no $j$ satisfies $S(i,j)\wedge T(j,k)$, every term of the sum is zero, so the indicator can be pulled outside.  This final structure predicate, $\exists_j S(i,j)\wedge T(j,k)$, is a join on $j : m$ between $S$ and $T$, followed by a projection onto $i,k$ that drops $j$.  This new structure relation can be pre-computed separately from $M$ or computed at the same time as $M$.  However, our language IR mostly lacks the subtlety needed to capture these distinctions.
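Read relationally, the structure predicate can be computed from the coordinate lists alone, without touching the values of $A$ or $B$. A minimal sketch with $S$ and $T$ as Python sets of index pairs, using a hash join on $j$ (one possible join strategy among several); `join_project` is a hypothetical name:

```python
def join_project(S, T):
    """Return {(i, k) | exists j. (i, j) in S and (j, k) in T}.

    S and T are sets of index pairs.  T is indexed by its j column (a hash
    join on j); the output keeps only (i, k), so j is projected away and
    duplicate (i, k) pairs collapse in the result set.
    """
    T_by_j = {}
    for j, k in T:
        T_by_j.setdefault(j, []).append(k)
    result = set()
    for i, j in S:
        for k in T_by_j.get(j, []):
            result.add((i, k))
    return result

S = {(0, 0), (0, 2), (1, 1)}
T = {(0, 1), (2, 1), (2, 0)}
print(sorted(join_project(S, T)))   # [(0, 0), (0, 1)]
```

Here $(0,1)$ is reached through both $j=0$ and $j=2$; the projection keeps it once, which is exactly the $\exists_j$ in the predicate.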

For instance, suppose we simply want to restrict the looping using the sparsity structure.  We need a way to decouple the loop index ordering from the summation (reduction) behavior, while also accounting for the join behavior of the relations.
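As an illustration of what such decoupling could buy, consider a loop nest driven by the join itself: it iterates over the stored entries of $A$, looks up the matching row of $B$ by $j$, and accumulates into $M[i,k]$. The sum over $j$ is no longer an inner loop nested under $(i,k)$; it is spread across the outer iteration. This is the familiar row-wise (Gustavson-style) sparse matrix product, shown with dict-based storage as an assumption; it is not meant as the IR's eventual lowering:

```python
from collections import defaultdict

def sparse_matmul(A, B):
    """A, B: dicts mapping (row, col) -> value for stored entries.

    Returns M as a dict mapping (i, k) -> value, with entries only at
    positions in the join-projection of the two sparsity patterns.
    """
    # group B's entries by row j so the join on j is a dictionary lookup
    B_rows = defaultdict(list)
    for (j, k), b in B.items():
        B_rows[j].append((k, b))
    M = defaultdict(float)
    # loop over (i, j) from A, then k from B: the reduction over j is
    # accumulated into M[i, k] rather than run as an innermost loop
    for (i, j), a in A.items():
        for k, b in B_rows.get(j, ()):
            M[(i, k)] += a * b
    return dict(M)

A = {(0, 0): 1.0, (0, 2): 2.0, (1, 1): 3.0}
B = {(0, 1): 4.0, (1, 0): 5.0, (2, 0): 6.0, (2, 1): 7.0}
print(sorted(sparse_matmul(A, B).items()))
# [((0, 0), 12.0), ((0, 1), 18.0), ((1, 0), 15.0)]
```

The position $(1,1)$ never appears: no $j$ joins row $1$ of $A$ to column $1$ of $B$, so the loop nest never visits it.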
