<br><br><br><br><br>

# Type checking

<br><br><br><br><br>

<br>

Most compiled languages perform an additional tree-to-tree transformation: **type checking**.

Generally, an **untyped AST** (such as the ones we've been dealing with) is replaced by a **typed AST**, in which each node is labeled with a data type, such as `double` or `boolean`. (It's also possible to label an AST in place, but if so, be sure that node instances are unique: a node object shared by two parts of the tree can carry only one label.)
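The warning about unique node instances can be seen in a small sketch. Here a single `Node` object (a hypothetical minimal class, not the `AST` classes defined later) is shared between two expressions that are checked with different symbol tables. Because the label is written onto the object itself, the second check overwrites the first.

```python
# Hypothetical minimal node class, used only to illustrate in-place type labels.
class Node:
    def __init__(self, name):
        self.name = name
        self.type = None          # filled in by label_in_place

def label_in_place(node, symbols):
    """Attach a type label directly to the node object (mutates it)."""
    node.type = symbols[node.name]
    return node

shared = Node("x")                # one instance reused in two places

first = label_in_place(shared, {"x": int})     # checked where x is an int
second = label_in_place(shared, {"x": float})  # checked where x is a float

# Both references point at the same object, so the first label was overwritten.
print(first.type, second.type)    # <class 'float'> <class 'float'>
```

Building a separate typed tree (or copying nodes before labeling them) avoids this problem.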

Type checking was traditionally motivated by the need to generate the right instructions in the output language (e.g. `__add_int32__` vs `__add_float32__` on unlabeled 32-bit registers), but it can be much more general than that:

<center style="margin-top: 20px; margin-bottom: 20px"><b>type checking is a formal proof that the program satisfies certain properties.</b></center>

The properties to prove are encoded in the **type system**, which can be specialized to a domain like particle physics.
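The traditional motivation is easy to demonstrate. The sketch below uses Python's standard `struct` module to simulate two 32-bit additions on the same raw bits: one interprets them as integers, the other as IEEE 754 single-precision floats. The helper names `bits_of_float32` and `float32_of_bits` are made up for this example.

```python
import struct

def bits_of_float32(x):
    """Raw 32-bit pattern of x stored as a single-precision float."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def float32_of_bits(b):
    """Reinterpret a 32-bit pattern as a single-precision float."""
    return struct.unpack("<f", struct.pack("<I", b))[0]

one = bits_of_float32(1.0)                    # 0x3F800000

# "__add_int32__": treat the register contents as integers
int_sum = (one + one) & 0xFFFFFFFF            # 0x7F000000
# "__add_float32__": treat the same contents as floats
float_sum = bits_of_float32(float32_of_bits(one) + float32_of_bits(one))

print(hex(int_sum), float32_of_bits(int_sum))      # 0x7f000000 1.7014118346046923e+38
print(hex(float_sum), float32_of_bits(float_sum))  # 0x40000000 2.0
```

With the wrong instruction, 1.0 + 1.0 comes out as 2**127: nothing in the bits themselves says which interpretation was intended.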

_What properties do we want particle physics analysis scripts to satisfy?_

<br>

<br>

**Some terminology:**

   * A **type** is a _set of possible values_ that a symbol or expression can have at runtime. Types may be
      * **abstract** if they're specified without reference to a bit-representation, like "all non-negative integers less than `2**32`"
      * **concrete** if a bit-representation is given, like "unsigned 32-bit binary integers."
   
   
   * A **strongly typed** language stops processing if it encounters values that do not match function argument types: it halts either compilation or runtime execution.
   
   * A **weakly typed** language either passes bits without checking them or converts values to fit expectations.
   
   * A **statically typed** language undergoes a type-checking pass before programs are run, usually as part of a compilation.
   
   * A **dynamically typed** language checks types at runtime. Types may be valid at one time and invalid at another.
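Here is a small illustration of the abstract/concrete distinction, using Python's standard `struct` module. The abstract set of non-negative integers below `2**32` can be given more than one concrete bit representation, here little-endian and big-endian unsigned 32-bit integers. The function name `in_abstract_uint32` is made up for the example.

```python
import struct

# Abstract type: all non-negative integers less than 2**32 (no bits mentioned).
def in_abstract_uint32(n):
    return isinstance(n, int) and 0 <= n < 2**32

# Two concrete types for that same set: little- and big-endian unsigned 32-bit.
little = struct.pack("<I", 1)   # b'\x01\x00\x00\x00'
big = struct.pack(">I", 1)      # b'\x00\x00\x00\x01'

print(in_abstract_uint32(1), little, big)

# Values outside the abstract set have no representation in the concrete type.
try:
    struct.pack("<I", 2**32)
except struct.error as err:
    print("not representable:", err)
```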

<br>

<p style="margin-bottom: 0px"><b>Weakly typed (values are assumed to fit operations)</b></p>
<ul style="margin-top: 0px">
  <li>Most assembly languages treat all values as raw bits; the programmer has to keep track of types and call the right instructions.
  <li>C is often used as a weakly typed language (e.g. passing everything as <tt>void*</tt>).
</ul>

<p style="margin-bottom: 0px"><b>Weakly typed (values are converted to fit operations)</b></p>
<ul style="margin-top: 0px">
  <li>Perl: <tt>"2" + 8 → 10</tt>, and undefined variables or unconvertible strings are treated as zero.
  <li>JavaScript: <tt>"2" + 8 → "28"</tt>
  <li>MATLAB: <tt>'2' + 8 → 58</tt> (because the ASCII code of the character <tt>'2'</tt> is <tt>50</tt>)
  <li>Python predicates: <tt>None</tt> or <tt>[]</tt> resolves to <tt>False</tt>, <tt>[0]</tt> resolves to <tt>True</tt> when used with <tt>if/and/or/not</tt>.
  <li>Python 2's handling of byte-strings vs unicode.
  <li>Most languages promote integers to floating-point values in mixed arithmetic.
</ul>

<p style="margin-bottom: 0px"><b>Strongly but dynamically typed</b></p>
<ul style="margin-top: 0px">
  <li>Everything else in Python (<tt>"2" + 8</tt> is a <tt>TypeError</tt>).
  <li>Lisp, Ruby, R, Erlang, Lua, Tcl, Smalltalk, PostScript...
</ul>
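The Python entries in these lists can be checked directly. This snippet (plain Python 3, no libraries) shows the converting behaviors, the `TypeError` for `"2" + 8`, and dynamic typing, where one function is valid for some arguments and fails at runtime for others. The function name `double` is just an example.

```python
# Weak (converting) behavior in Python: truthiness and numeric promotion.
print(bool(None), bool([]), bool([0]))   # False False True
print(type(2 + 0.5).__name__)            # float: the int is promoted

# Strong behavior: no implicit str/int conversion.
try:
    "2" + 8
except TypeError as err:
    print("TypeError:", err)

# Dynamic typing: the same function is well-typed for some calls, not others.
def double(x):
    return x + x

print(double(2), double("ab"))           # 4 abab
try:
    double(None)
except TypeError:
    print("double(None) fails only when it runs")
```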

<p style="margin-bottom: 0px"><b>Strongly and statically typed</b></p>
<ul style="margin-top: 0px">
  <li>C++, Java, C#, Rust, Go, Swift, Fortran, Haskell, ML, Scala, mypy (a static type checker for Python), LLVM's assembly language...
</ul>

In [1]:
import lark
grammar = """
start: or
or:       and -> pass | and "or" and
and:      not -> pass | not "and" not
not:  compare -> pass | "not" not
compare: term -> pass | term "==" term -> eq | term "!=" term -> ne
                      | term  "<" term -> lt | term "<=" term -> le
                      | term  ">" term -> gt | term ">=" term -> ge
term:  factor -> pass | factor "+" term -> add | factor "-" term -> sub
factor:  atom -> pass | atom "*" factor -> mul | atom "/" factor -> truediv
atom:      "(" or ")" | CNAME -> symbol | INT -> int | FLOAT -> float

%import common.CNAME
%import common.INT
%import common.FLOAT
%import common.WS
%ignore WS
"""
parser = lark.Lark(grammar)

In [2]:
print(parser.parse("not x > 0.0 and 2 + 2").pretty())

start
  pass
    and
      not
        pass
          gt
            pass
              pass
                symbol	x
            pass
              pass
                float	0.0
      pass
        pass
          add
            pass
              int	2
            pass
              pass
                int	2



In [4]:
# Define AST nodes, as before. This is the untyped AST.

class AST:
    _fields = ()
    def __init__(self, *args):
        for n, x in zip(self._fields, args):
            setattr(self, n, x)

class Literal(AST):                                 # a literal always knows its type,
    _fields = ("value", "type")                     # even in the untyped AST
    def __str__(self): return "{0}({1})".format(self.type.__name__, str(self.value))

class Symbol(AST):
    _fields = ("symbol",)
    def __str__(self): return self.symbol

class Call(AST):
    _fields = ("function", "arguments")
    def __str__(self):
        return "{0}({1})".format(str(self.function), ", ".join(str(x) for x in self.arguments))

In [5]:
# Simplify the Parsing Tree (PT) into an Abstract Syntax Tree (AST), as before.

def toast(ptnode):
    if ptnode.data == "start" or ptnode.data == "pass" or ptnode.data == "atom":
        return toast(ptnode.children[0])
    elif ptnode.data == "int":
        return Literal(int(ptnode.children[0]), int)
    elif ptnode.data == "float":
        return Literal(float(ptnode.children[0]), float)
    elif ptnode.data == "symbol":
        return Symbol(str(ptnode.children[0]))
    else:
        return Call(str(ptnode.data), [toast(x) for x in ptnode.children])

print(toast(parser.parse("not x > 0.0 and 2 + 2")))

and(not(gt(x, float(0.0))), add(int(2), int(2)))


In [6]:
# The typed AST is just like the untyped AST except that each node is labeled with a type.

class Typed:
    def __init__(self, thetype, *args):
        self.type = thetype
        super(Typed, self).__init__(*args)
    def __str__(self):
        return "{0} as {1}".format(super(Typed, self).__str__(), self.type.__name__)

class TypedLiteral(Typed, Literal): pass

class TypedSymbol(Typed, Symbol): pass

class TypedCall(Typed, Call): pass

In [7]:
def totyped(ast, symbols):
    if isinstance(ast, Literal):
        return TypedLiteral(ast.type, ast.value)
    elif isinstance(ast, Symbol):
        return TypedSymbol(symbols[ast.symbol], ast.symbol)
    else:
        arguments = [totyped(x, symbols) for x in ast.arguments]
        if ast.function in ("add", "sub", "mul", "truediv"):           # number · number → number
            if any(x.type != int and x.type != float for x in arguments):
                raise TypeError("{0} requires numerical arguments".format(repr(ast.function)))
            # simplification: every arithmetic result is labeled float, even int + int
            return TypedCall(float, ast.function, arguments)
        elif ast.function in ("eq", "ne", "lt", "le", "gt", "ge"):     # number · number → boolean
            if any(x.type != int and x.type != float for x in arguments):
                raise TypeError("{0} requires numerical arguments".format(repr(ast.function)))
            return TypedCall(bool, ast.function, arguments)
        elif ast.function in ("and", "or", "not"):                     # boolean · boolean → boolean
            if any(x.type != bool for x in arguments):
                raise TypeError("{0} requires boolean arguments".format(repr(ast.function)))
            return TypedCall(bool, ast.function, arguments)

In [10]:
# Our syntactically correct example has a type error.

code = "not x > 0.0 and 2 + 2"
print(toast(parser.parse(code)))
print(totyped(toast(parser.parse(code)), symbols={"x": float}))

and(not(gt(x, float(0.0))), add(int(2), int(2)))


TypeError: 'and' requires boolean arguments

In [11]:
# The totyped handling can be simplified by looking up a signature from a list.

def totyped(ast, signatures, symbols):
    if isinstance(ast, Literal):
        return TypedLiteral(ast.type, ast.value)

    elif isinstance(ast, Symbol):
        return TypedSymbol(symbols[ast.symbol], ast.symbol)

    else:
        arguments = [totyped(x, signatures, symbols) for x in ast.arguments]
        types = [x.type for x in arguments]

        # search for a (name, args) match; apply the corresponding ret
        for name, args, ret in signatures:
            if name == ast.function and args == types:
                return TypedCall(ret, ast.function, arguments)

        raise TypeError("illegal arguments: {0}({1})".format(
            ast.function, ", ".join(x.__name__ for x in types)))

In [14]:
# Short exercise: add and test truediv. How does its signature differ from add's?

signatures = [("add", [int, int], int),
              ("add", [int, float], float),
              ("add", [float, int], float),
              ("add", [float, float], float),
              ("gt",  [int, int], bool),
              ("gt",  [int, float], bool),
              ("gt",  [float, int], bool),
              ("gt",  [float, float], bool),
              ("not", [bool], bool),
              ("and", [bool, bool], bool),
              ("or",  [bool, bool], bool)]

code = "not x > 0.0 and 3 / 2"
print(toast(parser.parse(code)))
print(totyped(toast(parser.parse(code)), signatures, symbols={"x": float}))

and(not(gt(x, float(0.0))), truediv(int(3), int(2)))


TypeError: illegal arguments: truediv(int, int)
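One possible answer to the exercise, written as a standalone sketch that mirrors the signature search in `totyped` but skips the parser. Because `truediv` is true division, its result is `float` even when both arguments are `int`, whereas `add` keeps `int + int` as `int`. The helper `result_type` is hypothetical, and the `signatures` list here replaces the one above only within this sketch.

```python
numbers = [int, float]

# add: int + int stays int; any float argument makes the result float
signatures = [("add", [a, b], int if a is int and b is int else float)
              for a in numbers for b in numbers]
# truediv: true division always produces a float, e.g. 3 / 2 == 1.5
signatures += [("truediv", [a, b], float) for a in numbers for b in numbers]

def result_type(name, types, signatures):
    """Return the declared result type for a call, as totyped does."""
    for fname, args, ret in signatures:
        if fname == name and args == types:
            return ret
    raise TypeError("illegal arguments: {0}({1})".format(
        name, ", ".join(t.__name__ for t in types)))

print(result_type("add", [int, int], signatures).__name__)      # int
print(result_type("truediv", [int, int], signatures).__name__)  # float
```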

## Parameterized types

To support more data structures, we can consider "functions of types." Like functions in a programming language, they allow us to build what we need from simpler primitives.

Examples include:

   * C++ templates: think of the `<` `>` brackets as `(` `)` around a function's arguments.
   * Arrays, structs, and unions in C, which don't have a function-like syntax.
   * `tuple<T1, T2, T3>` values have one component in each of `T1`, `T2`, and `T3` (**product type**).
   * Each `variant<T1, T2, T3>` value is in `T1` or `T2` or `T3`, along with a tag saying which (**sum type**).
   * `None` or `null` is the only member of a single-element set (a **unit type**).
   * Functions that never return (because they always raise an exception or loop forever) have the empty set as their return type (a **bottom type**).
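Treating types as finite sets makes the names "product" and "sum" concrete: a tuple type has as many values as the product of its components' sizes, and a tagged variant has as many as their sum. The sketch below models types as small Python sets. `Color`, `Unit`, `Empty`, `product_type`, and `sum_type` are made-up names for illustration.

```python
from itertools import product

Bool = {False, True}
Color = {"red", "green", "blue"}     # hypothetical three-valued type
Unit = {None}                        # one value: the type of None/null
Empty = set()                        # no values: the bottom type

def product_type(*types):
    """tuple<T1, T2, ...>: one component from each type."""
    return set(product(*types))

def sum_type(*types):
    """variant<T1, T2, ...>: a value from one type, tagged with which one."""
    return {(tag, v) for tag, t in enumerate(types) for v in t}

print(len(product_type(Bool, Color)))   # 6 = 2 * 3
print(len(sum_type(Bool, Color)))       # 5 = 2 + 3
print(len(product_type(Bool, Unit)))    # 2: pairing with Unit adds no information
print(len(product_type(Bool, Empty)))   # 0: no value can be built
print(len(sum_type(Bool, Empty)))       # 2: Empty contributes no alternatives
```

The tag matters: without it, `sum_type(Bool, Bool)` would collapse to 2 values instead of 4.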

Type-checking is implemented in some programming language; **type functions** are functions in the type-checking language.

The type-checking language is usually dynamically typed (ironically enough), and the type of a type is called a **kind**.
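Python's own type hints give a small instance of this. Since Python 3.9, built-in classes such as `list` and `tuple` can be subscripted to build parameterized types at runtime, so an ordinary (dynamically typed) Python function can act as a type function. `Pair` below is a hypothetical example, not a standard name.

```python
from typing import get_args, get_origin

def Pair(T1, T2):
    """A type function: takes two types, returns the type of 2-tuples."""
    return tuple[T1, T2]

P = Pair(int, float)
print(P)                              # tuple[int, float]
print(get_origin(P), get_args(P))     # <class 'tuple'> (<class 'int'>, <class 'float'>)

# Type functions compose like ordinary functions.
print(Pair(str, list[Pair(int, int)]))   # tuple[str, list[tuple[int, int]]]
```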

In [15]:
import lark

# matrix multiplication language, like bra-ket without syntactic constraints
grammar = """
start: term
term:  factor -> pass | term "+" factor -> add
factor:  atom -> pass | atom factor -> mul
atom:    "(" term ")" | CNAME -> symbol

%import common.CNAME
%import common.WS
%ignore WS
"""

parser = lark.Lark(grammar)

print(toast(parser.parse("x + A y")))

add(x, mul(A, y))


In [16]:
class Matrix:
    def __init__(self, rows, cols):
        self.rows, self.cols = rows, cols
        self.__name__ = str(self)
    def __str__(self):
        return "({0}×{1})".format(self.rows, self.cols)

def totyped(ast, symbols):
    if isinstance(ast, Symbol):
        return TypedSymbol(symbols[ast.symbol], ast.symbol)
    else:
        arguments = [totyped(x, symbols) for x in ast.arguments]
        left, right = [x.type for x in arguments]
        if ast.function == "add":
            if left.rows != right.rows or left.cols != right.cols:
                raise TypeError("cannot add {0} to {1}".format(left, right))
            return TypedCall(left, ast.function, arguments)
        elif ast.function == "mul":
            if left.cols != right.rows:
                raise TypeError("cannot multiply {0} by {1}".format(left, right))
            return TypedCall(Matrix(left.rows, right.cols), ast.function, arguments)

In [19]:
code = "x + A y"
print(toast(parser.parse(code)))
print(totyped(toast(parser.parse(code)),
              symbols={"x": Matrix(1, 5), "A": Matrix(5, 4), "y": Matrix(4, 1)}))

add(x, mul(A, y))


TypeError: cannot add (1×5) to (5×1)
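The same dimension rules can be restated without the parser by reducing each matrix type to a `(rows, cols)` pair. This standalone sketch (the helper names `mul_type` and `add_type` are made up) shows that the expression type-checks once `x` is a 5×1 column rather than a 1×5 row.

```python
def mul_type(left, right):
    """(r×k) times (k×c) gives (r×c); inner dimensions must agree."""
    (r1, c1), (r2, c2) = left, right
    if c1 != r2:
        raise TypeError("cannot multiply {0} by {1}".format(left, right))
    return (r1, c2)

def add_type(left, right):
    """Addition requires identical shapes."""
    if left != right:
        raise TypeError("cannot add {0} to {1}".format(left, right))
    return left

A, y = (5, 4), (4, 1)
print(add_type((5, 1), mul_type(A, y)))      # (5, 1)
try:
    add_type((1, 5), mul_type(A, y))         # the original x: a 1×5 row
except TypeError as err:
    print(err)                               # cannot add (1, 5) to (5, 1)
```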

## Dependent types

Until recently, most of the focus of type-checking was to ensure that machine instructions were valid—for instance, a floating-point instruction is never called on integer data. Now there's increased emphasis on "type safety," ensuring that the programmer doesn't make certain kinds of mistakes.

   * To ensure machine code validity, you only need **concrete types**: what do these bytes represent?
   * For type-safety, you need **abstract types**: what mathematical set is this value guaranteed to be in?

**Dependent types** are types that depend on values, so they can restrict the range of allowed values. For instance, a type system might distinguish empty lists from non-empty lists, so that a request for the first item is accepted only for lists known to be non-empty.
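Python cannot express dependent types, but the empty/non-empty distinction can at least be sketched by making non-empty lists a separate class whose constructor requires a first element. This is only an illustration of the idea, not a dependent type: Python rejects `NonEmptyList()` when the call runs, although a static checker such as mypy would flag the missing argument before the program runs. `NonEmptyList` is a hypothetical name.

```python
class NonEmptyList:
    """A list type with no empty value: construction needs at least one item."""
    def __init__(self, head, *tail):
        self.items = [head, *tail]

    def first(self):
        # Always safe: __init__ guarantees at least one element.
        return self.items[0]

xs = NonEmptyList(3, 1, 4)
print(xs.first())            # 3

try:
    NonEmptyList()           # no head given
except TypeError as err:
    print("rejected:", err)
```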

The Rust programming language builds memory management into its type system (through ownership and borrowing), so that double-deallocation and use-after-free are compile-time errors. Memory leaks, however, are still possible in safe Rust.

We'll type numbers as intervals of possible values (interval arithmetic), so that dividing by an expression whose interval includes zero is a type error.
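Here is a minimal sketch of interval-typed numbers, assuming closed intervals `[lo, hi]` and ignoring floating-point rounding. Each operation computes the interval of possible results from the intervals of its arguments, and division reports a type error whenever the divisor's interval contains zero. The `Interval` class is illustrative and not necessarily how this notebook implements it.

```python
class Interval:
    """Closed interval [lo, hi] of possible values (an abstract numeric type)."""
    def __init__(self, lo, hi):
        assert lo <= hi
        self.lo, self.hi = lo, hi

    def __repr__(self):
        return "[{0}, {1}]".format(self.lo, self.hi)

    def __add__(self, other):
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def __sub__(self, other):
        return Interval(self.lo - other.hi, self.hi - other.lo)

    def __mul__(self, other):
        products = [a * b for a in (self.lo, self.hi) for b in (other.lo, other.hi)]
        return Interval(min(products), max(products))

    def __truediv__(self, other):
        # The "type error": the divisor might be zero somewhere in its range.
        if other.lo <= 0 <= other.hi:
            raise TypeError("divisor {0} may be zero".format(other))
        return self * Interval(1 / other.hi, 1 / other.lo)

x, y, z = Interval(1, 2), Interval(0, 3), Interval(1, 4)
print((x + y) / z)        # [0.25, 5.0]
try:
    x / (z - x)           # z - x = [-1, 3] includes zero
except TypeError as err:
    print("TypeError:", err)
```

Interval arithmetic over-approximates (for example, `x - x` is `[-1, 1]` here, not `[0, 0]`), so a checker built this way can reject some programs that would never actually divide by zero.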

In [20]:
import lark

grammar = """
start: term
term:  factor -> pass | term "+" factor -> add | term "-" factor -> sub
factor:  atom -> pass | atom "*" factor -> mul | atom "/" factor -> truediv
atom:    "(" term ")" | CNAME -> symbol

%import common.CNAME
%import common.WS
%ignore WS
"""

parser = lark.Lark(grammar)

print(toast(parser.parse("(x + y) / z")))

truediv(add(x, y), z)
