In [184]:
import numpy as np
from typing import NamedTuple, Optional, Any, Type
from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
import inspect
from functools import partial, reduce
from itertools import starmap, accumulate
import types
import builtins as bis
import operator as op
from copy import deepcopy as dcp

In [202]:
# Short alias for itertools.starmap: smap(f, rows) calls f(*row) per row.
smap = starmap

def acc(fn, *xs, init=0):
  """Lazily scan `fn` over the zipped sequences `xs`, seeded with `init`.

  Each step calls ``fn(running, *row)`` for one row of ``zip(*xs)``. The
  returned iterator yields `init` first, then every intermediate value.
  """
  def step(running, row):
    return fn(running, *row)
  rows = zip(*xs)
  return accumulate(rows, step, initial=init)

def foldl(fn, *xs, init=0):
  """Left fold of `fn` over the zipped sequences `xs`, starting from `init`.

  Each row of ``zip(*xs)`` is splatted into ``fn(running, *row)``; with no
  rows the result is `init`.
  """
  rows = zip(*xs)
  return reduce(lambda running, row: fn(running, *row), rows, init)

def emap(f, *xs):
  """Eager `map`: apply `f` across the sequences `xs` and return a tuple."""
  mapped = map(f, *xs)
  return tuple(mapped)

def eacc(fn, *xs, init=0):
  """Eager `acc`: return every accumulator value (starting with `init`) as a tuple.

  `init` must be forwarded by keyword. `acc` collects every positional argument
  after `fn` into `*xs`, so a positional seed would be zipped as one more
  sequence (and fail for non-iterable seeds).
  """
  return tuple(acc(fn, *xs, init=init))

def esmap(f, *xs):
  """Eager `starmap`: call f(*row) for every argument row and return a tuple."""
  results = starmap(f, *xs)
  return tuple(results)

def ezip(*xs):
  """Eager `zip`: return the zipped rows of `xs` as a tuple of tuples."""
  return (*zip(*xs),)

def unzip2(assocs):
  """Transpose an iterable of pairs into a tuple of per-position tuples.

  E.g. [(1, 'a'), (2, 'b')] -> ((1, 2), ('a', 'b')); empty input gives ().
  """
  return tuple(zip(*assocs))

def swap(f):
  """Return a binary function that calls `f` with its two arguments reversed."""
  def flipped(a, b):
    return f(b, a)
  return flipped

def identity(x):
  """Return the argument unchanged (same object)."""
  result = x
  return result

def some(xs):
  """Return the first element produced by iterating `xs`.

  Raises StopIteration when `xs` is empty.
  """
  it = iter(xs)
  return next(it)

def all_same_type(xs):
  """True iff `xs` is non-empty and every element is an instance of the first element's type."""
  if len(xs) == 0:
    return False
  first_type = type(next(iter(xs)))
  return all(isinstance(x, first_type) for x in xs)

def ext(cls, name, fn):
  cp = dcp(cls)
  setattr(cp, name, fn)
  return cp

def exts(cls, names, fns):
  """Thread `cls` through `ext` once per (name, fn) pair and return the final class.

  Pairs come from ``zip(names, fns)``, so the shorter sequence bounds the work.
  """
  result = cls
  for attr, fn in zip(names, fns):
    result = ext(result, attr, fn)
  return result

def clext(cls1, cls2, names):
  """Install onto `cls1` (via `exts`) the attributes named `names` taken from `cls2`.

  Each attribute looked up on `cls2` must be callable (checked by assert).
  """
  fns = list(map(lambda attr: getattr(cls2, attr), names))
  assert all(map(callable, fns))
  base = dcp(cls1)
  return exts(base, names, fns)

def cons(x, xs):
  """Prepend `x` to the tuple `xs`, returning a new tuple."""
  head = (x,)
  return head + xs

def snoc(xs, x):
  """Append `x` to the tuple `xs`, returning a new tuple."""
  tail = (x,)
  return xs + tail

def zeros_like(x):
  """Return a NumPy array of zeros with the shape and dtype of `x`'s abstract value."""
  abstract = aval(x)
  return np.zeros(shape(abstract), dtype=dtype(abstract))

In [186]:
# globals

# Raw value types accepted as concrete inputs. Callers test
# `type(x) in JAX_TYPES`, so membership is by exact type (no subclasses).
JAX_TYPES = set([
  bool, int, float,                                      # Python scalars
  np.bool_, np.int32, np.int64, np.float32, np.float64,  # NumPy scalars
  np.ndarray,
])

In [187]:
# data structures

# Primitive = (name)

Primitive = namedtuple('Primitive', 'name')

def name(p):
  """Return the `name` field of a Primitive; raise ValueError for anything else."""
  if isinstance(p, Primitive):
    return p.name
  raise ValueError(f"name: {p}")

In [188]:
# MainTrace = (lvl, tc_t, gd)

# One entry of the trace stack: its level, the trace class to instantiate,
# and global data (`gd`).
MainTrace = namedtuple('MainTrace', 'lvl tc_t gd')

def lvl(mt):
  match mt:
    case MainTrace(lvl, _, _):
      return lvl
    case Trace(main):
      return lvl(main)
    case EvalTrace(main):
      return lvl(main)
    case JVPTrace(main):
      return lvl(main)
    case Tracer(tc, _):
      return lvl(tc)
    case JVPTracer(tc, _, _, _):
      return lvl(tc)
    case _: raise ValueError(f"lvl: {mt}")

def tc_t(mt):
  """Return the trace class stored in a MainTrace; raise ValueError otherwise."""
  if not isinstance(mt, MainTrace):
    raise ValueError(f"tc_t: {mt}")
  return mt.tc_t

def gd(mt):
  """Return the `gd` (global data) field of a MainTrace; raise ValueError otherwise."""
  if not isinstance(mt, MainTrace):
    raise ValueError(f"gd: {mt}")
  return mt.gd


# Sanity check: a MainTrace reports its own level.
lvl(MainTrace(lvl=0, tc_t=Trace, gd=None))

0

In [189]:
# Trace = (main) | EvalTrace | JVPTrace
# EvalTrace = (main)
# JVPTrace = (main)

# Each trace kind wraps the MainTrace it belongs to in a single `main` field.
Trace, EvalTrace, JVPTrace = (
  namedtuple(type_name, ['main']) for type_name in ('Trace', 'EvalTrace', 'JVPTrace')
)

def main(tc): 
  match tc:
    case Trace(main):
      return main
    case EvalTrace(main):
      return main
    case JVPTrace(main):
      return main
    case Tracer(tc, _):
      return main(tc)
    case JVPTracer(tc, _, _, _):
      return main(tc)
    case _: raise ValueError(f"main: {tc}")

def pure(tc, x): 
  match tc:
    case Trace(_):
      raise ValueError(f"pure: {tc}")
    case EvalTrace(_):
      return x
    case JVPTrace(main):
      return JVPTracer(tc, x, zeros_like(x))
    case _: 
      raise ValueError(f"pure: {tc}")

def lift(tc, x):
  """Lift `x` (a value or lower-level tracer) into trace `tc`.

  EvalTrace returns `x` unchanged. JVPTrace wraps it as a JVPTracer whose
  tangent is ``zeros_like(x)``. Any other trace, including the base `Trace`,
  raises ValueError.
  """
  if isinstance(tc, EvalTrace):
    return x
  if isinstance(tc, JVPTrace):
    return JVPTracer(tc, x, zeros_like(x))
  raise ValueError(f"lift: {tc}")

def proc_prim(tc, prim, *xs, **kvs): 
  match tc:
    case Trace(_):
      raise ValueError(f"proc_prim: {tc}")
    case EvalTrace(_):
      return impl_rules(prim, *xs, **kvs)
    case JVPTrace(main): 
      ps_in, ts_in = unzip2((primal(x), tangent(x)) for x in xs)
      jvp_rule = jvp_rules(prim)
      ps_out, ts_out = jvp_rule(ps_in, ts_in, **kvs)
      return [JVPTracer(tc, p, t) for p, t in zip(ps_out, ts_out)]
    case _: 
      raise ValueError(f"proc_prim: {tc}")

In [200]:
# Tracer = (tc, arr_prio) | JVPTracer
# JVPTracer = (tc, primal, tangent, arr_prio)

Tracer = namedtuple('Tracer', ['tc', 'arr_prio'], defaults=[1000])
JVPTracer = namedtuple('JVPTracer', ['tc', 'primal', 'tangent', 'arr_prio'], defaults=[1000])

# Operator overloads installed on the tracer types, as (dunder name, impl)
# pairs. Each impl dispatches on the tracer's abstract value `aval(self)` to
# the matching `_op` helper. `__nonzero__` is Python 2's name for `__bool__`.
_names, _fns = map(list, zip(
  ('__neg__',     lambda self: _neg(aval(self), self)),
  ('__add__',     lambda self, o: _add(aval(self), self, o)),
  ('__radd__',    lambda self, o: _radd(aval(self), self, o)),
  ('__mul__',     lambda self, o: _mul(aval(self), self, o)),
  ('__rmul__',    lambda self, o: _rmul(aval(self), self, o)),
  ('__gt__',      lambda self, o: _gt(aval(self), self, o)),
  ('__lt__',      lambda self, o: _lt(aval(self), self, o)),
  ('__bool__',    lambda self: _bool(aval(self), self)),
  ('__nonzero__', lambda self: _nonzero(aval(self), self)),
))

# Install the overloads on Tracer, then copy the same functions onto JVPTracer.
Tracer = exts(Tracer, _names, _fns)
JVPTracer = clext(JVPTracer, Tracer, _names)

def _tc(tcr):
  match tcr:
    case Tracer(tc, _):
      return tc
    case JVPTracer(tc, _, _, _):
      return tc
    case _: raise ValueError(f"tc: {tcr}")

def __arr_prio__(tcr):
  match tcr:
    case Tracer(_, arr_prio):
      return arr_prio
    case JVPTracer(_, _, _, arr_prio):
      return arr_prio
    case _: raise ValueError(f"arr_prio: {tcr}")

def aval(v, ts):
  match v:
    case Tracer(_, _):
      raise ValueError(f"aval: {v}")
    case JVPTracer(tc, primal, tangent, _):
      ...
    case x if type(x) in ts:
      return ConcreteArray(np.asarray(x))
    case _:
      raise TypeError(f"aval: {v}")

# Bind `ts` to the accepted raw value types so callers can write aval(x).
aval = partial(aval, ts=JAX_TYPES)

def full_lower(v):
  """Return `v` unchanged.

  Tracers and raw values alike are currently passed through as-is.
  """
  return v

def full_raise(tc, val, ts): 
  match val:
    case Tracer(_, _):
      lvl_tc = lvl(tc)
      if main(val) is main(tc):
        return val
      elif lvl(val) < lvl_tc:
        return lift(tc, val)
      elif lvl(val) > lvl_tc:
        raise Exception(f"can't lift level {lvl(val)} to {lvl_tc}")
      else:
        raise Exception(f"different traces at same level: {_tc(val)}, {tc}")
    case JVPTracer(_, _, _, _):
      ...
    case x if type(x) in ts:
      return pure(tc, x)
    case _:
      raise ValueError(f"full_raise: {val}")

# Bind `ts` to the accepted raw value types so callers can write full_raise(tc, val).
full_raise = partial(full_raise, ts=JAX_TYPES)

def primal(x):
  """Return the primal component of a JVPTracer; raise ValueError otherwise."""
  if isinstance(x, JVPTracer):
    return x.primal
  raise ValueError(f"primal: {x}")

def tangent(x):
  """Return the tangent component of a JVPTracer; raise ValueError otherwise."""
  if isinstance(x, JVPTracer):
    return x.tangent
  raise ValueError(f"tangent: {x}")

In [205]:
# Array = ShapedArray | ConcreteArray
# ShapedArray = (shape, dtype, ab_lvl)
# ConcreteArray = (val, ab_lvl)

# Abstract arrays: ShapedArray knows only shape/dtype, while ConcreteArray
# carries the value itself. `ab_lvl` defaults to 1 and 2 respectively.
ShapedArray = namedtuple('ShapedArray', 'shape dtype ab_lvl', defaults=(1,))
ConcreteArray = namedtuple('ConcreteArray', 'val ab_lvl', defaults=(2,))

def shape(sa):
  """Return the shape of an abstract array or of a raw value in JAX_TYPES.

  Raises ValueError for anything else.
  """
  if isinstance(sa, ShapedArray):
    return sa.shape
  if isinstance(sa, ConcreteArray):
    return sa.val.shape
  if type(sa) in JAX_TYPES:
    return np.shape(sa)
  raise ValueError(f"shape: {sa}")

def dtype(sa):
  match sa:
    case ShapedArray(_, dtype, _):
      return dtype
    case ConcreteArray(val, _):
      return val.dtype
    case x if type(x) in {bool, int, float}:
      return type(x)
    case x if type(x) in {np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}:
      return np.dtype(x)
    case _: raise ValueError(f"dtype: {sa}")

def ab_lvl(a): 
  match a:
    case ShapedArray(_, _, ab_lvl):
      return ab_lvl
    case ConcreteArray(_, ab_lvl):
      return ab_lvl
    case _: raise ValueError(f"ab_lvl: {sa}")

def ndim(a): 
  match a:
    case ShapedArray(shape, _, _):
      return len(shape)
    case _: 
      raise ValueError(f"ndim: {a}")
      
def val(a):
  """Return the value held by a ConcreteArray; raise ValueError otherwise."""
  if not isinstance(a, ConcreteArray):
    raise ValueError(f"val: {a}")
  return a.val

def _neg(av, x):
  """Unary negation dispatched on abstract value `av`: both array kinds bind `neg`."""
  if isinstance(av, (ShapedArray, ConcreteArray)):
    return neg(x)
  raise TypeError(f"_neg: {av}")

def _add(av, x, y):
  """`x + y` dispatched on abstract value `av`: both array kinds bind `add`."""
  if isinstance(av, (ShapedArray, ConcreteArray)):
    return add(x, y)
  raise TypeError(f"_add: {av}")

def _radd(av, x, y):
  """Reflected `y + x` dispatched on abstract value `av`; operands are swapped."""
  if isinstance(av, (ShapedArray, ConcreteArray)):
    return add(y, x)
  raise TypeError(f"_radd: {av}")

def _mul(av, x, y):
  """`x * y` dispatched on abstract value `av`: both array kinds bind `mul`."""
  if isinstance(av, (ShapedArray, ConcreteArray)):
    return mul(x, y)
  raise TypeError(f"_mul: {av}")

def _rmul(av, x, y):
  """Reflected `y * x` dispatched on abstract value `av`; operands are swapped."""
  if isinstance(av, (ShapedArray, ConcreteArray)):
    return mul(y, x)
  raise TypeError(f"_rmul: {av}")

def _gt(av, x, y):
  """`x > y` dispatched on abstract value `av`: both array kinds bind `gt`."""
  if isinstance(av, (ShapedArray, ConcreteArray)):
    return gt(x, y)
  raise TypeError(f"_gt: {av}")

def _lt(av, x, y):
  """`x < y` dispatched on abstract value `av`: both array kinds bind `lt`."""
  if isinstance(av, (ShapedArray, ConcreteArray)):
    return lt(x, y)
  raise TypeError(f"_lt: {av}")

def _bool(av, x):
  """Truthiness of tracer `x` given its abstract value `av`.

  Only a ConcreteArray has a value to test; a ShapedArray raises Exception and
  anything else raises TypeError.
  """
  if isinstance(av, ShapedArray):
    raise Exception("ShapedArray can't be unambiguously converted to bool")
  if isinstance(av, ConcreteArray):
    return bool(val(aval(x)))
  raise TypeError(f"_bool: {av}")

def _nonzero(av, x):
  """Python 2 style truthiness hook; same logic as `_bool` with its own error tag."""
  if isinstance(av, ShapedArray):
    raise Exception("ShapedArray can't be unambiguously converted to bool")
  if isinstance(av, ConcreteArray):
    return bool(val(aval(x)))
  raise TypeError(f"_nonzero: {av}")

In [192]:
# One Primitive per operation. The wrappers below bind them, and
# impl_rules / jvp_rules dispatch on their names.
(add_p, mul_p, neg_p, sin_p, cos_p, reduce_sum_p,
 gt_p, lt_p, transpose_p, broadcast_p) = (
  Primitive(prim_name) for prim_name in (
    'add', 'mul', 'neg', 'sin', 'cos', 'reduce_sum',
    'gt', 'lt', 'transpose', 'broadcast'))

# Thin user-facing wrappers: each binds its primitive through `bind1`.
def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
# gt/lt previously referenced the undefined names `greater_p` / `less_p`
# (NameError on every call); the primitives are defined as gt_p / lt_p.
def gt(x, y): return bind1(gt_p, x, y)
def lt(x, y): return bind1(lt_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
# reduce_sum(x, axis=None) is defined on its own right after this; a shadowed
# one-argument duplicate that used to sit here was removed.
def reduce_sum(x, axis=None):
  match axis: 
    case None: 
      return reduce_sum(x, axis=tuple(range(np.ndim(x))))
    case int(v):
      return reduce_sum(x, axis=(v,))
    case tuple(_):
      return bind1(reduce_sum_p, x, axis=axis)
    case _: raise ValueError(f"reduce_sum: {axis}")

def _bind1(tcs, dyn_tc):
  def bind1(prim, *xs, **kvs):
    match bind(tcs, dyn_tc, prim, *xs, **kvs):
      case (ret,):
        return ret
      case _: 
        raise ValueError(f"bind1: {ret}")
  return bind1

# The trace stack starts with a single MainTrace: level 0, handled by
# EvalTrace (primitives evaluated directly through impl_rules).
tcs = (MainTrace(0, EvalTrace, None),)
# No dynamic trace is installed.
dyn_tc = None

# Module-level `bind1` closes over this initial trace stack.
bind1 = _bind1(tcs, dyn_tc)

In [193]:
def bind(tcs, dyn_tc, prim, *xs, **kvs):
  """Apply `prim` to `xs` on the top trace involved and return a list of outputs.

  Picks the trace via `top_tc`, raises every input into it with `full_raise`,
  processes the primitive there, and lowers each output with `full_lower`.
  """
  trace = top_tc(tcs, dyn_tc, xs)
  raised = [full_raise(trace, x) for x in xs]
  outs = proc_prim(trace, prim, *raised, **kvs)
  return list(map(full_lower, outs))

In [194]:
def top_tc(tcs, dyn_tc, xs):
  """Instantiate the trace for the highest-level MainTrace involved in `xs`.

  The candidates are the MainTraces of every Tracer/JVPTracer in `xs`, plus
  `dyn_tc` when it is set. When `xs` holds no tracers, the lowest-level entry
  of `tcs` (the base evaluation trace) is used. The result is
  ``tc_t(main)(main)``: the MainTrace's trace class applied to itself.
  """
  # JVPTracer is a separate namedtuple, not a Tracer subclass, so it must be
  # listed explicitly; previously JVP inputs were ignored here.
  tracer_mains = (main(x) for x in xs if isinstance(x, (Tracer, JVPTracer)))
  # Entries pushed with `cons` land at the front of `tcs`, so tcs[0] is not
  # necessarily the base trace; select it by level instead of position.
  base = min(tcs, key=lvl)
  top_main = max(tracer_mains, default=base, key=lvl)
  if dyn_tc is not None:
    top_main = max(top_main, dyn_tc, key=lvl)
  return tc_t(top_main)(top_main)

In [201]:
def impl_rules(prim, *xs, **kvs): 
  match prim: 
    case Primitive('add'):
      return [np.add(*xs)]
    case Primitive('mul'):
      return [np.multiply(*xs)]
    case Primitive('neg'):
      return [np.negative(*xs)]
    case Primitive('sin'):
      return [np.sin(*xs)]
    case Primitive('cos'):
      return [np.cos(*xs)]
    case Primitive('reduce_sum'):
      return [np.sum(*xs, **kvs)]
    case Primitive('gt'):
      return [np.greater(*xs)]
    case Primitive('lt'):
      return [np.less(*xs)]
    case Primitive('transpose'):
      return [np.transpose(*xs, **kvs)]
    case Primitive('broadcast'):
      x = foldl(np.expand_dims, kvs['axes'], init=xs[0])
      return [np.broadcast_to(x, kvs['shape'])]

def _jvp_rules():
  def add_jvp(ps, ts): 
    (x, y), (x_dot, y_dot) = ps, ts
    return [x + y], [x_dot + y_dot]

  def mul_jvp(ps, ts): 
    (x, y), (x_dot, y_dot) = ps, ts
    return [x * y], [x_dot * y + x * y_dot]

  def sin_jvp(ps, ts): 
    (x,), (x_dot,) = ps, ts
    return [sin(x)], [cos(x) * x_dot]
  
  def cos_jvp(ps, ts):
    (x,), (x_dot,) = ps, ts
    return [cos(x)], [-sin(x) * x_dot]

  def neg_jvp(ps, ts):
    (x,), (x_dot,) = ps, ts
    return [neg(x)], [neg(x_dot)]

  def reduce_sum_jvp(ps, ts, *, axis):
    (x,), (x_dot,) = ps, ts
    return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]

  def gt_jvp(ps, ts):
    (x, y), _ = ps, ts
    out_p = gt(x, y) 
    return [out_p], [zeros_like(out_p)]

  def lt_jvp(ps, ts): 
    (x, y), _ = ps, ts
    out_p = lt(x, y)
    return [out_p], [zeros_like(out_p)]

  def jvp_rules(prim):
    match prim:
      case Primitive('add'):
        return add_jvp
      case Primitive('mul'):
        return mul_jvp
      case Primitive('sin'):
        return sin_jvp
      case Primitive('cos'):
        return cos_jvp
      case Primitive('neg'):
        return neg_jvp
      case Primitive('reduce_sum'):
        return reduce_sum_jvp
      case Primitive('gt'):
        return gt_jvp
      case Primitive('lt'):
        return lt_jvp
      case _:
        raise ValueError(f"jvp_rules: {prim}")
  return jvp_rules

jvp_rules = _jvp_rules()

In [196]:
def f(x):
  """Example function f(x) = x - 2*sin(x), built from the traced `sin` primitive."""
  return -(sin(x) * 2.) + x

# Expected: 3 - 2*sin(3) ≈ 2.71776
f(x=3.0)

2.7177599838802657