In [1]:
import sys
sys.path.append('../')
from src.adt import ADT
from src.adt import memo as ADTmemo

In this file I'm going to sketch out a language for tensors.  To begin, let's create an initial IR to work with.  We will refine it as we go.

$$
\newcommand{\or}{\ |\ }
\begin{array}{rl}
  e &= x \or c \or e_0 + e_1 \or e_0 \cdot e_1 \\
  &\or \sum_i e \or \boxplus_i e \or e_i \or [i=j] \\
  &\or (e_0,e_1) \or \pi_0 e \or \pi_1 e \\
\end{array}
$$

In this basic expression grammar, we have three lines.  The first line loosely corresponds to polynomials: variables, constants, addition and multiplication.  The second line extends the basic computation language with tensors: summation over an index, construction of an array over an index ( $\boxplus_i e$ ), indexing a tensor ( $e_i$ ), and an Iverson bracket ( $[i=j]$ ) on indices, which I'll get to in a minute.  Finally, some basic surrogate data structuring is included in the form of pairs.  Named dictionaries and many other more ergonomic language features can be reduced to pairs, so pairs are a good middle ground between simply assuming that such features won't cause problems and getting bogged down in complicated but trivial details.

The Iverson bracket ( $[i=j]$ ) is a function which assumes the value $1$ if the predicate in the bracket ( $i=j$ here ) is true, and $0$ if it is false.  It forms a fundamental link between logic (Boolean Algebra) and numeric computation (Linear Algebra).  The fundamental equation relating these two worlds is the inclusion-exclusion principle.

$$ [A \vee B] + [A \wedge B] = [A] + [B] $$

If $A \wedge B$ is always false, we say that $A$ and $B$ are _mutually exclusive_ (disjoint); by inclusion-exclusion, $[A \vee B] = [A] + [B]$ in that case.  Given a decomposition of a predicate $A$ into mutually exclusive atoms, $A = \bigvee_i A_i$ with $i\neq j \implies \neg(A_i \wedge A_j)$, the corresponding Iverson brackets decompose the original term additively: $[A] = \sum_i [A_i]$.  This identity holds only because the atoms are mutually exclusive; if two atoms could be true at once, their overlap would be counted more than once.

### Matrix-Vector Multiplication Example

$$ y = \boxplus_i\  c_i + \sum_j A_{i,j} \cdot x_j$$

(note that I am assuming the $\boxplus_i$ encloses the remainder of the line)

### Matrix-Trace Example

$$ tr(A) = \sum_i A_{i,i} $$

### Diagonal Matrix From a Vector Example

$$ diag(x) = \boxplus_{i,j}\  [i=j] \cdot x_i $$

# Tensor Language v.0

In [2]:
# T0: the first version of the tensor-language IR, declared through the
# project's ADT helper.  The string holds the grammar (the `--` lines appear
# to be grammar-level comments); the dict supplies validators for the
# custom field kinds `int01` and `index` used in that grammar.
T0 = ADT("""
module T0 {
    expr = Var      ( string name, type? typ )
         | Const    ( float  val  )
         | Add      ( expr lhs, expr rhs )
         | Mul      ( expr lhs, expr rhs )
         | Pair     ( expr lhs, expr rhs )
         | Proj     ( int01 idx, expr arg )
         | Gen      ( string idxname, int range, expr body )
         | Sum      ( string idxname, int range, expr body )
         | Access   ( expr  base, index idx )
         -- implied multiplication of the bracket with body
         | Indicate ( pred  arg, expr body )
         -- important to express sharing of computation results
         | Let      ( string name, expr rhs, expr body )
    
    -- indices are drawn from a range s.t.
    -- 0 <= i < range
    
    pred    = Eq( index lhs, index rhs )
    
    type    = TNum    ()
            | TError  ()
            | TPair   (  type lhs, type rhs )
            | TTensor ( int range, type typ )
}
""", {
    # Proj.idx selects a pair component and is later used to index a
    # Python tuple and printed as `e.0` / `e.1`, so it must be exactly
    # the int 0 or 1 -- a float such as 1.0 compares equal but is invalid.
    'int01': lambda x: type(x) is int and (x == 0 or x == 1),
    # an index is either a constant int or the name of an index variable
    'index': lambda x: (type(x) is int) or (type(x) is str),
})
# NOTE(review): presumably memoizes these type constructors so structurally
# equal types are shared / compare equal -- confirm in src.adt.
ADTmemo(T0,['TNum','TError','TPair','TTensor'])
# shared instances of the nullary types, compared against throughout
T0.tnum = T0.TNum()
T0.terr = T0.TError()


Two issues confront us immediately.  First, we need to ensure that expressions are well-typed; second, we'd like a more concise way to input and display expressions in this language.  For instance, consider the encoding of the earlier examples:

In [3]:
# Hand-built T0 encodings of the three examples, using only the raw node
# constructors.  Free variables carry explicit type annotations.
n = 100
R = T0.tnum
x = T0.Var('x',T0.TTensor(n,R))                  # vector of length n
A = T0.Var('A',T0.TTensor(n, T0.TTensor(n,R) ))  # n x n matrix as a tensor of tensors
c = T0.Var('c',T0.TTensor(n,R))                  # vector of length n
i = 'i'  # index variables are referred to by name (plain strings)
j = 'j'

# matrix-vector: y[i] = c[i] + sum_j A[i][j] * x[j]
y = T0.Gen( i,n,T0.Add(
        T0.Access(c,i),
        T0.Sum( j,n,T0.Mul( 
            T0.Access(T0.Access(A,i),j),
            T0.Access(x,j)))))

# matrix-trace: sum_i A[i][i]
tr = T0.Sum( i,n, T0.Access(T0.Access(A,i),i) )

# diagonal matrix: entry (i,j) is [i=j] * x[i]
diag = T0.Gen( i,n,T0.Gen( j,n,
            T0.Indicate(T0.Eq(i,j),
                        T0.Access(x,i) )))

### Typechecking

Let's try to write a type-checking pass.

In [4]:
class _Context:
    def __init__(self):
        self._envs   = [{}]
    def push(self):
        self._envs.insert(0,{})
    def pop(self):
        self._envs.pop(0)
    def set(self, nm, val):
        self._envs[0][nm] = val
    def get(self, nm):
        for e in self._envs:
            if nm in e:
                return e[nm]
        return None

class TCError(Exception):
    """Raised when type-checking a T0 expression found errors.

    ``errs`` is either a ready-made message string, which is used verbatim,
    or an iterable of individual error messages, which are joined one per
    line under a common header.
    """

    def __init__(self, errs):
        # A plain string is passed through unchanged so callers that
        # pre-format the whole message (e.g. report_errors) keep working;
        # joining a str would otherwise interleave its characters.
        if isinstance(errs, str):
            errmsg = errs
        else:
            errmsg = ("errors during typechecking:\n" +
                      ('\n'.join(errs)))
        super().__init__(errmsg)
        
class TypeChecker:
    """Type-check a T0 expression, collecting errors rather than raising.

    The check runs during construction.  The inferred type is stored in
    ``self._typ`` (``T0.terr`` when checking failed) and error messages
    accumulate in ``self._errors``; call ``report_errors()`` to raise them
    as a ``TCError``.

    Free variables are typed by their ``Var.typ`` annotation, or by an
    entry in the optional ``initenv`` mapping.
    """

    def __init__(self, expr, initenv=None):
        """Check ``expr`` immediately.

        Args:
            expr:    the T0 expression to check.
            initenv: optional mapping from variable name to ``T0.type``
                     for variables bound outside ``expr``.

        Raises:
            TypeError: if an ``initenv`` value is not a ``T0.type``.
        """
        self._ctxt   = _Context()
        for nm, typ in (initenv or {}).items():
            if not isinstance(typ, T0.type):
                raise TypeError(f"initial binding for '{nm}' is not "
                                f"a T0 type: {typ}")
            self._ctxt.set(nm, typ)
        self._errors = []
        self._typ    = self.check(expr)

    def _err(self, node, msg):
        # `node` is accepted so error locations can be reported later;
        # for now only the message is recorded.
        self._errors.append(msg)

    def report_errors(self):
        """Raise ``TCError`` listing every collected error, if there are any."""
        if self._errors:
            raise TCError('Found errors during typechecking:\n  '+
                          '\n  '.join(self._errors))

    def _get_ivar(self, node, name):
        """Return the range bound to index variable ``name``.

        Records an error and returns None when ``name`` is unbound or is
        bound as a normal (non-index) variable.
        """
        idxrange = self._ctxt.get(name)
        if idxrange is None:
            self._err(node, f"index variable '{name}' was undefined")
        elif type(idxrange) is not int:
            self._err(node, f"variable '{name}' was "
                            f"not bound as an index variable")
        else:
            return idxrange
        # fail fall-through
        return None

    def _get_var(self, node, name):
        """Return the type bound to normal variable ``name``.

        Returns None if ``name`` is unbound, and ``T0.terr`` (after
        recording an error) if it is bound as an index variable.
        """
        typ = self._ctxt.get(name)
        if typ is None:
            return None
        if not isinstance(typ, T0.type):
            self._err(node, f"variable '{name}' was "
                            f"not bound as a normal variable")
            return T0.terr
        return typ

    def check(self, node):
        """Return the type of ``node``, recording any errors found.

        Ill-typed sub-expressions yield ``T0.terr``; parents seeing
        ``T0.terr`` generally return it without adding further errors.

        Raises:
            TypeError: if ``node`` is not a known T0 expression class.
        """
        nclass = type(node)
        if   nclass is T0.Var:
            lookup = self._get_var(node, node.name)
            if lookup is None and node.typ is None:
                self._err(node, f"expected free variable '{node.name}' "
                                f"to have a type annotated.")
                return T0.terr
            elif lookup is None:
                return node.typ
            elif lookup == T0.terr:
                # the binding itself is bad and was already reported
                return T0.terr
            elif node.typ is None:
                return lookup
            elif lookup != node.typ:
                self._err(node, f"expected variable '{node.name}' "
                                f"type annotation {node.typ} to "
                                f"match bound variable type {lookup}")
                return T0.terr
            else: # two mechanisms agree
                return lookup

        elif nclass is T0.Const:
            return T0.tnum

        elif nclass is T0.Add or nclass is T0.Mul:
            opname = "addition" if nclass is T0.Add else "multiplication"
            ltyp = self.check(node.lhs)
            rtyp = self.check(node.rhs)
            typ  = T0.tnum if (ltyp == T0.tnum and
                               rtyp == T0.tnum) else T0.terr
            if ltyp != T0.tnum and ltyp != T0.terr:
                self._err(node,
                          f"expected number on left-hand-side "
                          f"of {opname}: {node}")
            if rtyp != T0.tnum and rtyp != T0.terr:
                self._err(node,
                          f"expected number on right-hand-side "
                          f"of {opname}: {node}")
            return typ

        elif nclass is T0.Pair:
            ltyp = self.check(node.lhs)
            rtyp = self.check(node.rhs)
            if ltyp == T0.terr or rtyp == T0.terr:
                return T0.terr
            else:
                return T0.TPair(ltyp,rtyp)

        elif nclass is T0.Proj:
            subtyp = self.check(node.arg)
            if subtyp == T0.terr: return T0.terr
            elif type(subtyp) is not T0.TPair:
                self._err(node, f"Was expecting a pair as argument: {node}")
                return T0.terr
            elif node.idx == 0:   return subtyp.lhs
            else:                 return subtyp.rhs

        elif nclass is T0.Gen or nclass is T0.Sum:
            # the index variable is visible (bound to its range) only
            # inside the body
            self._ctxt.push()
            self._ctxt.set(node.idxname, node.range)
            bodytyp = self.check(node.body)
            self._ctxt.pop()
            if   bodytyp == T0.terr: return T0.terr
            elif nclass is T0.Sum:
                if bodytyp == T0.tnum: return T0.tnum
                # otherwise
                self._err(node, f"Was expecting a number as argument: {node}")
                return T0.terr
            else: # nclass is T0.Gen
                return T0.TTensor(node.range, bodytyp)

        elif nclass is T0.Access:
            basetyp  = self.check(node.base)
            if basetyp == T0.terr: return T0.terr
            if not isinstance(basetyp,T0.TTensor):
                self._err(node, f"Was expecting a tensor to index: {node}")
                return T0.terr
            if type(node.idx) is int:
                # constant index: must lie inside the tensor's range
                # (Python list indexing would silently wrap negatives)
                if not 0 <= node.idx < basetyp.range:
                    self._err(node, f"constant index {node.idx} is out of "
                                    f"bounds for a tensor of range "
                                    f"{basetyp.range}")
                    return T0.terr
                return basetyp.typ
            idxrange = self._get_ivar(node, node.idx)
            if idxrange is None: return T0.terr
            if idxrange != basetyp.range:
                self._err(node, f"index variable '{node.idx}' was bound "
                                f"to the range {idxrange}, but this tensor "
                                f"expects an index of range {basetyp.range}")
                return T0.terr
            # if reaching here, all checks passed
            # we can return the (accessed) de-tensor-ed type
            return basetyp.typ

        elif nclass is T0.Indicate:
            # Eq is the only predicate in T0.  Integer constants carry no
            # range of their own, so only named index variables are looked
            # up and compared.
            eqnode  = node.arg
            lrange  = (None if type(eqnode.lhs) is int
                       else self._get_ivar(node, eqnode.lhs))
            rrange  = (None if type(eqnode.rhs) is int
                       else self._get_ivar(node, eqnode.rhs))
            bodytyp = self.check(node.body)

            if (lrange is not None and rrange is not None
                    and lrange != rrange):
                self._err(node, f"index variables "
                                f"'{eqnode.lhs}' and '{eqnode.rhs}' "
                                f"in equality are drawn from different "
                                f"ranges: {lrange} and {rrange}")
            # can proceed to type-check regardless of errors
            # if we at least have the type of the body
            return bodytyp

        elif nclass is T0.Let:
            # the bound name is visible only inside the body
            rtyp    = self.check(node.rhs)
            self._ctxt.push()
            self._ctxt.set(node.name, rtyp)
            bodytyp = self.check(node.body)
            self._ctxt.pop()
            return bodytyp

        else:
            raise TypeError(f"Unexpected expression class for {node}")
            

In [5]:
# Raises TCError if any type errors were collected; no output means y type-checks.
TypeChecker(y).report_errors()

### Input Syntax Sugar

Typechecking succeeded.  As a next step, let's look at building in some additional help for constructing the IR objects.  For instance, we ought to be able to take two nodes `a` and `b` and write `a + b` to get `T0.Add(a,b)`.  Likewise for multiplication.

In fact, we can do that by overloading the `__add__` and `__mul__` methods.  Since every expression node class derives from `T0.expr`, we can install these overloads on that one base class and cover all expressions at once, which keeps the matter simple.

We would also like to promote numeric literals to constants in the IR where appropriate.  With very little extra work, we can do this too.

In [6]:
def _T0_expr_add_(lhs, rhs):
    """Overload ``expr + rhs`` to build a ``T0.Add`` node.

    A right operand whose type is exactly ``int`` or ``float`` is promoted
    to a ``T0.Const``; another T0 expression is used as-is.  Any other
    operand yields ``NotImplemented`` so Python can try the reflected op.
    """
    if type(rhs) in (int, float):
        rhs = T0.Const(float(rhs))
    elif not isinstance(rhs, T0.expr):
        return NotImplemented
    return T0.Add(lhs, rhs)
T0.expr.__add__ = _T0_expr_add_

def _T0_expr_radd_(rhs, lhs):
    """Overload ``lhs + expr`` for a numeric literal on the left.

    Here ``rhs`` is the expression itself (Python's ``self``).  Only
    operands whose type is exactly ``int`` or ``float`` are handled; they
    become a leading ``T0.Const``.
    """
    if type(lhs) not in (int, float):
        return NotImplemented
    return T0.Add(T0.Const(float(lhs)), rhs)
T0.expr.__radd__ = _T0_expr_radd_
    

def _T0_expr_mul_(lhs, rhs):
    """Overload ``expr * rhs`` to build a ``T0.Mul`` node.

    A right operand whose type is exactly ``int`` or ``float`` is promoted
    to a ``T0.Const``; another T0 expression is used as-is.  Anything else
    (including a predicate) yields ``NotImplemented`` so Python can try the
    other operand's reflected method.
    """
    if type(rhs) in (int, float):
        rhs = T0.Const(float(rhs))
    elif not isinstance(rhs, T0.expr):
        return NotImplemented
    return T0.Mul(lhs, rhs)
T0.expr.__mul__ = _T0_expr_mul_

def _T0_expr_rmul_(rhs, lhs):
    """Overload ``lhs * expr`` for a numeric literal on the left.

    Here ``rhs`` is the expression itself (Python's ``self``).  Only
    operands whose type is exactly ``int`` or ``float`` are handled; they
    become a leading ``T0.Const``.
    """
    if type(lhs) not in (int, float):
        return NotImplemented
    return T0.Mul(T0.Const(float(lhs)), rhs)
T0.expr.__rmul__ = _T0_expr_rmul_

In [7]:
# Exercise the arithmetic overloads, including literal promotion to
# Const(float) on either side of the operator.
print(x * x)
print(x + x)
print(x + 4)
print(4 + x)
print(x * 4)
print(4 * x)

Mul(lhs=Var(name=x,typ=TTensor(range=100,typ=TNum())),rhs=Var(name=x,typ=TTensor(range=100,typ=TNum())))
Add(lhs=Var(name=x,typ=TTensor(range=100,typ=TNum())),rhs=Var(name=x,typ=TTensor(range=100,typ=TNum())))
Add(lhs=Var(name=x,typ=TTensor(range=100,typ=TNum())),rhs=Const(val=4.0))
Add(lhs=Const(val=4.0),rhs=Var(name=x,typ=TTensor(range=100,typ=TNum())))
Mul(lhs=Var(name=x,typ=TTensor(range=100,typ=TNum())),rhs=Const(val=4.0))
Mul(lhs=Const(val=4.0),rhs=Var(name=x,typ=TTensor(range=100,typ=TNum())))


A bit more cleverness will let us avoid any odd syntax for the Iverson-bracket/indicator-function.  Simply, if we define multiplication for the type _pred_, then we can multiply those terms with expressions to implicitly form indicators.

In [8]:
def _T0_pred_mul_(lhs, rhs):
    """Overload ``pred * rhs`` to form an Iverson-bracket product.

    ``lhs`` is the predicate.  The result is ``T0.Indicate(lhs, body)``,
    where ``body`` is ``rhs`` itself for a T0 expression, or a
    ``T0.Const`` for a value whose type is exactly ``int``/``float``.
    """
    if isinstance(rhs, T0.expr):
        body = rhs
    elif type(rhs) in (int, float):
        body = T0.Const(float(rhs))
    else:
        return NotImplemented
    return T0.Indicate(lhs, body)
T0.pred.__mul__ = _T0_pred_mul_

def _T0_pred_rmul_(rhs, lhs):
    """Overload ``lhs * pred`` to form an Iverson-bracket product.

    Here ``rhs`` is the predicate itself (Python's ``self``).  This is
    reached for numeric literals and also for ``expr * pred``, because the
    expression ``__mul__`` returns ``NotImplemented`` for a predicate.
    """
    if isinstance(lhs, T0.expr):
        body = lhs
    elif type(lhs) in (int, float):
        body = T0.Const(float(lhs))
    else:
        return NotImplemented
    return T0.Indicate(rhs, body)
T0.pred.__rmul__ = _T0_pred_rmul_

In [9]:
# An Eq predicate multiplied on either side by an expression or a number
# becomes an Indicate node with the predicate as its bracket.
i1 = T0.Eq('i',1)
print(i1 * x)
print(x * i1)
print(i1 * 2)
print(2 * i1)

Indicate(arg=Eq(lhs=i,rhs=1),body=Var(name=x,typ=TTensor(range=100,typ=TNum())))
Indicate(arg=Eq(lhs=i,rhs=1),body=Var(name=x,typ=TTensor(range=100,typ=TNum())))
Indicate(arg=Eq(lhs=i,rhs=1),body=Const(val=2.0))
Indicate(arg=Eq(lhs=i,rhs=1),body=Const(val=2.0))


Finally, we can exploit array indexing syntax to make it more concise to index tensors.

In [10]:
def _T0_expr_getitem_(expr, key):
    """Overload ``expr[k0, k1, ...]`` as nested ``T0.Access`` nodes.

    Each key component must be exactly ``int`` (a constant index) or
    ``str`` (an index-variable name).  Accesses are applied left to right,
    so ``A[i, j]`` becomes ``Access(Access(A, i), j)``.

    Raises:
        TypeError: if a key component is neither ``int`` nor ``str``.
    """
    keys = key if type(key) is tuple else (key,)
    result = expr
    for k in keys:
        if type(k) not in (int, str):
            raise TypeError(f"expected int or string in tensor index: {k}")
        result = T0.Access(result, k)
    return result
T0.expr.__getitem__ = _T0_expr_getitem_

In [11]:
# Both a constant int and an index-variable name become an Access node.
print(x[1])
print(x['i'])

Access(base=Var(name=x,typ=TTensor(range=100,typ=TNum())),idx=1)
Access(base=Var(name=x,typ=TTensor(range=100,typ=TNum())),idx=i)


With this bag of tricks, let's rewrite the examples from before that we were playing with.

Notice that the expressions are not only more concise: thanks to ordinary arithmetic and indexing syntax, the construction code is also significantly easier to read.

We could certainly get more sophisticated by introducing overloading tricks for the types as well.  However, the trade-offs shift: such tricks usually make the code more concise, but there is little reason to believe that programmers would see the resulting code as "normal."  Most importantly, since this is just a sketch, a minimal amount of niceness to keep ourselves sane will suffice.

In [12]:
# Same sizes and free variables as the hand-built version.
n = 100
R = T0.tnum
x = T0.Var('x',T0.TTensor(n,R))
A = T0.Var('A',T0.TTensor(n, T0.TTensor(n,R) ))
c = T0.Var('c',T0.TTensor(n,R))
i = 'i'
j = 'j'

# matrix-vector: `+`/`*` build Add/Mul, and A[i,j] expands to
# Access(Access(A,i),j)
y = T0.Gen(i,n, c[i] + T0.Sum(j,n, A[i,j] * x[j] ))
# old version for comparison
# y = T0.Gen( i,n,T0.Add(
#     T0.Access(c,i),
#       T0.Sum( j,n,T0.Mul( 
#           T0.Access(T0.Access(A,i),j),
#           T0.Access(x,j)))))

# matrix-trace
tr = T0.Sum(i,n, A[i,i])
# old version for comparison
# tr = T0.Sum( i,n, T0.Access(T0.Access(A,i),i) )

# diagonal matrix: Eq(i,j) * expr builds the Indicate node
diag = T0.Gen(i,n,T0.Gen(j,n, T0.Eq(i,j) * x[i] ))
# old version for comparison
# diag = T0.Gen( i,n,T0.Gen( j,n,
#           T0.Indicate(T0.Eq(i,j),
#                       T0.Access(x,i) )))
print(y)
print(tr)
print(diag)

Gen(idxname=i,range=100,body=Add(lhs=Access(base=Var(name=c,typ=TTensor(range=100,typ=TNum())),idx=i),rhs=Sum(idxname=j,range=100,body=Mul(lhs=Access(base=Access(base=Var(name=A,typ=TTensor(range=100,typ=TTensor(range=100,typ=TNum()))),idx=i),idx=j),rhs=Access(base=Var(name=x,typ=TTensor(range=100,typ=TNum())),idx=j)))))
Sum(idxname=i,range=100,body=Access(base=Access(base=Var(name=A,typ=TTensor(range=100,typ=TTensor(range=100,typ=TNum()))),idx=i),idx=i))
Gen(idxname=i,range=100,body=Gen(idxname=j,range=100,body=Indicate(arg=Eq(lhs=i,rhs=j),body=Access(base=Var(name=x,typ=TTensor(range=100,typ=TNum())),idx=i))))


### Output / Display Formats

Consider our original grammar sketch.  Wouldn't it be nice if we could display tensor expressions in a manner closer to this?

$$
\newcommand{\or}{\ |\ }
\begin{array}{rl}
  e &= x \or c \or e_0 + e_1 \or e_0 \cdot e_1 \\
  &\or \sum_i e \or \boxplus_i e \or e_i \or [i=j] \\
  &\or (e_0,e_1) \or \pi_0 e \or \pi_1 e \\
\end{array}
$$

Additionally, we'd like a simple plain-text approximation of the above.  Along those lines, consider the following translation:

$$
\begin{array}{r|r|l}
6 & x & \texttt{x}\\
  & c & \texttt{c} \\
  \hline
5 & e[i] & \texttt{e[i]} \\
  \hline
4 & \pi_0 e & \texttt{e.0} \\
  & \pi_1 e & \texttt{e.1} \\
  \hline
3 & a \cdot b & \texttt{a * b} \\
  & [i=j]\cdot e & \texttt{[i=j]*e} \\
  \hline
2 & a + b & \texttt{a + b} \\
  \hline
1 & \sum_{i:n}\ e & \texttt{+(i:n) e} \\
  & \boxplus_{i:n}\ e & \texttt{Gen(i:n) e} \\
  \hline
0 & (a,b) & \texttt{(a,b)} \\
\end{array}
$$

The various forms are arranged in precedence order (6 binds tightest, 0 loosest).  The exception is pair construction, which behaves as if it were a maximum-precedence operator when acting as a sub-expression, but as a minimum-precedence operator when enclosing sub-expressions itself.

We will suppress type-annotations but preserve range annotations on $\sum$ and $\boxplus$.

In [13]:
def T0_str_rep(e,prec=0):
    """Pretty-print a T0 expression in the plain-text surface syntax.

    ``prec`` is the binding strength demanded by the enclosing context.
    Each node has its own level -- 0 let, 1 ``+(i:n)``/``Gen(i:n)``,
    2 ``a + b``, 3 ``a * b`` and ``[i=j]*e``, 4 ``e.0``/``e.1``,
    5 ``e[i]`` -- and is parenthesized when ``prec`` exceeds it.
    Variables, constants and pairs are never parenthesized.  Unknown node
    classes render as the placeholder ``"ERROR"``.
    """
    def paren(text, level):
        # wrap only when the surrounding context binds more tightly
        return f"({text})" if prec > level else text

    kind = type(e)
    if kind is T0.Var:
        return e.name
    if kind is T0.Const:
        return str(e.val)
    if kind is T0.Add:
        lhs, rhs = T0_str_rep(e.lhs, 2), T0_str_rep(e.rhs, 2)
        return paren(f"{lhs} + {rhs}", 2)
    if kind is T0.Mul:
        lhs, rhs = T0_str_rep(e.lhs, 3), T0_str_rep(e.rhs, 3)
        return paren(f"{lhs} * {rhs}", 3)
    if kind is T0.Pair:
        # the pair's own brackets already delimit it
        return f"({T0_str_rep(e.lhs, 0)},{T0_str_rep(e.rhs, 0)})"
    if kind is T0.Proj:
        return paren(f"{T0_str_rep(e.arg, 4)}.{e.idx}", 4)
    if kind is T0.Gen or kind is T0.Sum:
        op = "+" if kind is T0.Sum else "Gen"
        body = T0_str_rep(e.body, 1)
        return paren(f"{op}({e.idxname}:{e.range}) {body}", 1)
    if kind is T0.Access:
        return paren(f"{T0_str_rep(e.base, 5)}[{e.idx}]", 5)
    if kind is T0.Indicate:
        assert isinstance(e.arg, T0.Eq), 'sanity: pred is Eq'
        body = T0_str_rep(e.body, 3)
        return paren(f"[{e.arg.lhs}={e.arg.rhs}]*{body}", 3)
    if kind is T0.Let:
        # NOTE: a let nested inside a larger expression still breaks the
        # line after `in`, so the layout is only tidy at top level
        rhs, body = T0_str_rep(e.rhs, 0), T0_str_rep(e.body, 0)
        return paren(f"let {e.name} = {rhs} in\n{body}", 0)
    return "ERROR"

T0.expr.__str__ = T0_str_rep

In [14]:
# str() now goes through T0_str_rep, giving the compact surface syntax.
print(y)
print(tr)
print(diag)

Gen(i:100) c[i] + (+(j:100) A[i][j] * x[j])
+(i:100) A[i][i]
Gen(i:100) Gen(j:100) [i=j]*x[i]


Likewise, with LaTeX for the notebook...

In [15]:
def T0_latex_str(e,prec=0):
    """Render a T0 expression as a LaTeX math string (no `$` delimiters).

    ``prec`` is the binding strength demanded by the enclosing context.
    Node levels are 0 let, 1 sum/boxplus, 2 addition, 3 product and
    Iverson bracket, 4 projection, 5 indexing; a node is wrapped in
    left/right parentheses when ``prec`` exceeds its level.  Unknown node
    classes render as the placeholder ``"ERROR"``.
    """
    def wrap(tex, level):
        # parenthesize only when the enclosing context binds more tightly
        return f"\\left({tex}\\right)" if prec > level else tex

    kind = type(e)
    if kind is T0.Var:
        return e.name
    if kind is T0.Const:
        return str(e.val)
    if kind is T0.Add:
        return wrap(f"{T0_latex_str(e.lhs,2)} + {T0_latex_str(e.rhs,2)}", 2)
    if kind is T0.Mul:
        return wrap(f"{T0_latex_str(e.lhs,3)} \\cdot {T0_latex_str(e.rhs,3)}", 3)
    if kind is T0.Pair:
        return f"\\left({T0_latex_str(e.lhs,0)},{T0_latex_str(e.rhs,0)}\\right)"
    if kind is T0.Proj:
        return wrap(f"\\pi_{{{e.idx}}} {T0_latex_str(e.arg,4)}", 4)
    if kind is T0.Gen or kind is T0.Sum:
        op = "\\sum" if kind is T0.Sum else "\\boxplus"
        return wrap(f"{op}_{{{e.idxname}:{e.range}}}\\ {T0_latex_str(e.body,1)}", 1)
    if kind is T0.Access:
        return wrap(f"{T0_latex_str(e.base,5)}[{e.idx}]", 5)
    if kind is T0.Indicate:
        assert isinstance(e.arg, T0.Eq), 'sanity: pred is Eq'
        return wrap(f"[{e.arg.lhs}={e.arg.rhs}]\\cdot {T0_latex_str(e.body,3)}", 3)
    if kind is T0.Let:
        # two-row array: "let name = rhs in" above the body; this only
        # lays out well for a let at the top of an expression
        tex = (f"\\begin{{array}}{{l}}"
               f" \\textrm{{let }} {e.name} = {T0_latex_str(e.rhs,0)}\\textrm{{ in}}\\\\"
               f" {T0_latex_str(e.body,0)}"
               f"\\end{{array}}")
        return wrap(tex, 0)
    return "ERROR"

def T0_latex_repr(e):
    """Rich-display hook: the expression as inline `$...$` LaTeX."""
    body = T0_latex_str(e)
    return "$" + body + "$"

T0.expr._repr_latex_ = T0_latex_repr

In [16]:
# Final expression of the cell: the notebook displays it via its display hooks.
y

Gen(idxname=i,range=100,body=c[i] + (+(j:100) A[i][j] * x[j]))

In [17]:
# Final expression of the cell: the notebook displays it via its display hooks.
tr

Sum(idxname=i,range=100,body=A[i][i])

In [18]:
# Final expression of the cell: the notebook displays it via its display hooks.
diag

Gen(idxname=i,range=100,body=Gen(j:100) [i=j]*x[i])

## Execution Semantics (via interpretation)

First, let's develop a way to check whether a Python object/value inhabits a given type in our language.  We'll need this check to make sure that inputs to execution are not going to cause problems.

In [20]:
def T0_check_value(typ, val):
    """Validate that the Python value ``val`` inhabits the T0 type ``typ``.

    Numbers must be exactly ``float``, pairs 2-element tuples, and tensors
    lists whose length equals the tensor's range; components are checked
    recursively.  Returns None on success.

    Raises:
        TypeError: describing the first mismatch found.
        AssertionError: (via ``assert``) if ``typ`` is ``T0.terr`` or is
            not a T0 type at all.
    """
    assert typ != T0.terr, "No values of type TErr"
    if typ == T0.tnum:
        if type(val) is not float:
            raise TypeError("Expected floating point value")
        return
    if isinstance(typ, T0.TPair):
        if type(val) is not tuple or len(val) != 2:
            raise TypeError("Expected pair value")
        first, second = val
        T0_check_value(typ.lhs, first)
        T0_check_value(typ.rhs, second)
        return
    if isinstance(typ, T0.TTensor):
        if type(val) is not list:
            raise TypeError("Expected list value")
        if len(val) != typ.range:
            raise TypeError(f"Expected list of {typ.range} "
                            f"entries, but got {len(val)} entries")
        # length already matches the range, so every entry is checked
        for entry in val:
            T0_check_value(typ.typ, entry)
        return
    assert False, f"{typ} should be a tensor-language type"


In [21]:
# Small concrete values: a length-2 vector and a 3x2 matrix (a list of rows).
xv = [4.,7.]
Av = [[5.,2.],[2.2,4.5],[6.1,3.3]]
vec2f = T0.TTensor(2,T0.tnum)
mat32f = T0.TTensor(3,T0.TTensor(2,T0.tnum))
R = T0.tnum
# Each call returns None on success and raises TypeError on a mismatch.
# NOTE(review): nothing in this cell prints or yields a non-None value, so
# the TNum()/TTensor(...) lines captured below look like stale output from
# an earlier version -- confirm by re-running.
T0_check_value(R,3.2)
T0_check_value(vec2f,xv)
T0_check_value(mat32f,Av)

TNum()
TTensor(range=2,typ=TNum())
TTensor(range=3,typ=TTensor(range=2,typ=TNum()))


Given an environment binding names to typed values, we want to have a way to evaluate a tensor expression in that environment.  That is what we'll call the interpreter.

In [22]:
class Interpreter():
    """Tree-walking evaluator for T0 expressions over plain Python values.

    Numbers are floats, pairs are 2-tuples and tensors are (nested) lists.
    ``env`` supplies the values of free variables.  It is copied at
    construction, and every ``run`` starts from a fresh scope stack seeded
    with that copy.
    """

    def __init__(self,env):
        self._init_vals = {name: env[name] for name in env}

    def run(self,e):
        """Evaluate the expression ``e`` and return its Python value."""
        self._vals = _Context()
        for nm in self._init_vals:
            self._vals.set(nm, self._init_vals[nm])
        return self._exec(e)

    def _get_val(self,nm):
        """Look up ``nm``; raise KeyError if it is unbound (or bound to None)."""
        found = self._vals.get(nm)
        if found is None:
            raise KeyError(f"Did not find variable '{nm}'")
        return found

    def _get_ival(self,idx):
        """Resolve an index: int literals as-is, names through the scopes."""
        return idx if type(idx) is int else self._get_val(idx)

    def _iterate_body(self, e):
        """Yield ``e.body`` evaluated at each index 0..e.range-1 (Gen/Sum).

        ``e.idxname`` is bound in a scope of its own for the duration.
        """
        self._vals.push()
        for k in range(e.range):
            self._vals.set(e.idxname, k)
            yield self._exec(e.body)
        self._vals.pop()

    def _exec(self,e):
        """Evaluate ``e`` in the current scope stack."""
        kind = type(e)
        if kind is T0.Var:
            return self._get_val(e.name)
        if kind is T0.Const:
            return e.val
        if kind is T0.Add:
            return self._exec(e.lhs) + self._exec(e.rhs)
        if kind is T0.Mul:
            return self._exec(e.lhs) * self._exec(e.rhs)
        if kind is T0.Pair:
            return (self._exec(e.lhs), self._exec(e.rhs))
        if kind is T0.Proj:
            return self._exec(e.arg)[e.idx]
        if kind is T0.Gen:
            return list(self._iterate_body(e))
        if kind is T0.Sum:
            # accumulate left to right from 0.0; builtin sum() is avoided
            # because its float algorithm differs across Python versions
            total = 0.0
            for term in self._iterate_body(e):
                total += term
            return total
        if kind is T0.Access:
            return self._exec(e.base)[self._get_ival(e.idx)]
        if kind is T0.Indicate:
            assert type(e.arg) is T0.Eq, 'expect only Eq pred'
            lhs = self._get_ival(e.arg.lhs)
            rhs = self._get_ival(e.arg.rhs)
            return self._exec(e.body) if lhs == rhs else 0.0
        if kind is T0.Let:
            bound = self._exec(e.rhs)
            self._vals.push()
            self._vals.set(e.name, bound)
            result = self._exec(e.body)
            self._vals.pop()
            return result
        assert False, "Unexpected Exec Case"


In [23]:
# Concrete 3x3 inputs (plain floats / nested lists) for running the examples.
xv = [4.,7.,1.]
Av = [[5.,2.,0.],[2.2,0.,4.5],[0.,6.1,3.3]]
cv = [0.,0.,1.]
# Rebuild the example expressions at size n = 3 to match those inputs.
n = 3
R = T0.tnum
x = T0.Var('x',T0.TTensor(n,R))
A = T0.Var('A',T0.TTensor(n, T0.TTensor(n,R) ))
c = T0.Var('c',T0.TTensor(n,R))
i = 'i'
j = 'j'

y = T0.Gen(i,n, c[i] + T0.Sum(j,n, A[i,j] * x[j] ))
tr = T0.Sum(i,n, A[i,i])
diag = T0.Gen(i,n,T0.Gen(j,n, T0.Eq(i,j) * x[i] ))

# Each call raises TCError if the expression is ill-typed.
TypeChecker(y).report_errors()
TypeChecker(tr).report_errors()
TypeChecker(diag).report_errors()

# Every run gets the same environment; bindings an expression never
# references are simply never looked up.
y_out = Interpreter({'A':Av,'x':xv,'c':cv}).run(y)
tr_out = Interpreter({'A':Av,'x':xv,'c':cv}).run(tr)
diag_out = Interpreter({'A':Av,'x':xv,'c':cv}).run(diag)

print(y_out)
print(tr_out)
print(diag_out)

[34.0, 13.3, 46.99999999999999]
8.3
[[4.0, 0.0, 0.0], [0.0, 7.0, 0.0], [0.0, 0.0, 1.0]]


## Putting it Together (Functions)

At this point, we have all of the components we need to package an expression in the tensor language into a function, and we can have this packaging perform type-checking reliably.  However, because free variables are typed through explicit annotations (a somewhat strange choice), we additionally need a way to extract the free variables, with their types, from an expression.  Given that, we should be good to proceed.

In [29]:
class FreeVarAnalysis():
    """Collect the free variables of a T0 expression with their annotated types.

    A variable occurrence is free when no enclosing Let binds its name and
    no enclosing Gen/Sum uses that name as its index.  Every free occurrence
    must carry a type annotation, and all occurrences of the same free name
    must agree.

    Raises (from the constructor):
        TypeError: if one free name is annotated with different types.
        AssertionError: (via ``assert``) if a free occurrence has no
            type annotation.
    """

    def __init__(self,e):
        self._env  = _Context()
        self._free = {}
        self._do(e)

    def vars(self):
        """Return a fresh dict mapping each free variable name to its type."""
        return dict(self._free)

    def _note_var(self, e):
        """Record ``e`` in the free set unless an enclosing binder covers it."""
        if self._env.get(e.name) is not None:
            return  # bound by an enclosing Let/Gen/Sum
        assert e.typ, "free variable must have type"
        known = self._free.get(e.name)
        if known is None:
            self._free[e.name] = e.typ
        elif known != e.typ:
            raise TypeError(f"free variable '{e.name}' "
                            f"was inconsistently typed")

    def _within_binding(self, name, body):
        """Analyse ``body`` in a new scope where ``name`` is bound."""
        self._env.push()
        self._env.set(name, True)
        self._do(body)
        self._env.pop()

    def _do(self,e):
        kind = type(e)
        if kind is T0.Var:
            self._note_var(e)
        elif kind is T0.Const:
            pass
        elif kind is T0.Add or kind is T0.Mul or kind is T0.Pair:
            self._do(e.lhs)
            self._do(e.rhs)
        elif kind is T0.Proj:
            self._do(e.arg)
        elif kind is T0.Gen or kind is T0.Sum:
            self._within_binding(e.idxname, e.body)
        elif kind is T0.Access:
            # indices are names/ints, not expressions; only the base matters
            self._do(e.base)
        elif kind is T0.Indicate:
            self._do(e.body)
        elif kind is T0.Let:
            # the right-hand side is outside the new binding's scope
            self._do(e.rhs)
            self._within_binding(e.name, e.body)
        else:
            assert False, "Unexpected expr Case"


In [31]:
# Each expression's free variables, mapped to their annotated types.
print(FreeVarAnalysis(y).vars())
print(FreeVarAnalysis(tr).vars())
print(FreeVarAnalysis(diag).vars())

{'c': TTensor(range=3,typ=TNum()), 'A': TTensor(range=3,typ=TTensor(range=3,typ=TNum())), 'x': TTensor(range=3,typ=TNum())}
{'A': TTensor(range=3,typ=TTensor(range=3,typ=TNum()))}
{'x': TTensor(range=3,typ=TNum())}


Now to get back to packaging this up.

In [33]:
class FuncT0():
    """Package a T0 expression as a function of its free variables.

    The expression is type-checked once at construction (raising
    ``TCError`` if it is ill-typed), and its free variables and their types
    are recorded.  Calling the object validates the keyword arguments
    against those types and evaluates the expression with the interpreter.
    """

    def __init__(self,expr):
        self._expr = expr
        # run basic type-checking
        TypeChecker(expr).report_errors()
        # and get the free variables
        self._free = FreeVarAnalysis(expr).vars()

    def __call__(self, *args, **kwargs):
        """Evaluate the expression with free variables bound from ``kwargs``.

        Raises:
            TypeError: on positional arguments, on a keyword that is not a
                free variable, on a missing (or None) argument, or on a
                value that does not inhabit the variable's type.
        """
        if args:
            raise TypeError("Tensor functions must be called with named arguments")
        # check that every kwarg is present in free; a membership test does
        # not depend on the truthiness of the stored type object
        for nm in kwargs:
            if nm not in self._free:
                raise TypeError(f"argument '{nm}' is not a free variable "
                                f"in this tensor function")
        # then check that every free variable is bound with a value of
        # the correct type
        for nm,typ in self._free.items():
            val = kwargs.get(nm)
            if val is None:
                raise TypeError(f"expected argument '{nm}'")
            T0_check_value(typ,val)
        # finally, execute
        return Interpreter(kwargs).run(self._expr)


In [38]:
# Wrap each example as a callable; construction type-checks the expression.
Fy = FuncT0(y)
Ftr = FuncT0(tr)
Fdiag = FuncT0(diag)
# Exactly the free variables of each expression must be passed, by keyword.
print( Fy(A = Av, x = xv, c = cv) )
print( Ftr(A = Av) )
print( Fdiag(x = xv) )

[34.0, 13.3, 46.99999999999999]
8.3
[[4.0, 0.0, 0.0], [0.0, 7.0, 0.0], [0.0, 0.0, 1.0]]
