In [1]:
import numpy as np
from typing import NamedTuple, Optional, Any, Type
from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
import inspect
from functools import partial, reduce
from itertools import starmap, accumulate, chain
import types
import builtins as bis
import operator as op
from copy import deepcopy as dcp

In [47]:
# utils

smap = starmap
cur = partial

def argmax(arg, default, xs):
  return max(xs, key=arg, default=default)

def inc(x):
  return x + 1

def dec(x):
  return x - 1

def even(x):
  return x % 2 == 0

def efilter(pred, xs):
  return tuple(filter(pred, xs))

def acc(fn, *xs, init=0):
  return accumulate(zip(*xs), lambda acc, xs: fn(acc, *xs), initial=init)

def foldl(fn, *xs, init=0):
  return reduce(lambda acc, xs: fn(acc, *xs), zip(*xs), init)

def compose2(f, g):
  return lambda *xs, **kws: f(g(*xs, **kws))

def compose(*fns):
  return reduce(compose2, fns)

def emap(f, *xs):
  return tuple(map(f, *xs))

def eacc(fn, *xs, init=0):
  return tuple(acc(fn, *xs, init))

def esmap(f, xs):
  return tuple(starmap(f, xs))

def ezip(*xs):
  return tuple(zip(*xs))

def unzip2(pairs):
  lst1, lst2 = [], []
  for x1, x2 in pairs:
    lst1.append(x1)
    lst2.append(x2)
  return lst1, lst2

def swap(f):
  return lambda a, b: f(b, a)

def identity(x):
  return x

def some(xs):
  return next(iter(xs))

def all_same_type(xs):
  return all(isinstance(x, type(some(xs))) for x in xs)

def ext(cls, name, fn):
  cp = dcp(cls)
  setattr(cp, name, fn)
  return cp

def exts(cls, names, fns):
  return foldl(ext, names, fns, init=cls)

def clext(cls1, cls2, names):
  fns = [getattr(cls2, name) for name in names]
  assert all(callable(fn) for fn in fns)
  return exts(dcp(cls1), names, fns)

def cons(x, xs):
  return (x,) + xs

def snoc(xs, x):
  return xs + (x,)

def zeros_like(x):
  av = aval(x)
  return np.zeros(shape(av), dtype(av))

In [3]:
# data structures

# Primitive = (name)

Primitive = namedtuple('Primitive', ['name'])

def name(p):
  match p:
    case Primitive(nm): return nm
    case _: raise ValueError(f"name: {p}")

In [4]:
# MainTrace = (lvl, tc_t, gd)

MainTrace = namedtuple('MainTrace', ['lvl', 'tc_t', 'gd'])

def lvl(mt):
  match mt:
    case MainTrace(l, _, _):
      return l
    case Trace(mn):
      return lvl(mn)
    case EvalTrace(mn):
      return lvl(mn)
    case JVPTrace(mn):
      return lvl(mn)
    case Tracer(tc, _):
      return lvl(tc)
    case JVPTracer(tc, _, _, _):
      return lvl(tc)
    case _: raise ValueError(f"lvl: {mt}")

def tc_t(mt):
  match mt:
    case MainTrace(_, tc_t, _):
      return tc_t
    case _: raise ValueError(f"tc_t: {mt}")

def gd(mt):
  match mt:
    case MainTrace(_, _, gd):
      return gd
    case _: raise ValueError(f"gd: {mt}")

In [5]:
# Trace = (main) | EvalTrace | JVPTrace
# EvalTrace = (main)
# JVPTrace = (main)

Trace = namedtuple('Trace', ['main'])
EvalTrace = namedtuple('EvalTrace', ['main'])
JVPTrace = namedtuple('JVPTrace', ['main'])

Tc = Trace | EvalTrace | JVPTrace

def main(tc):
  match tc:
    case Trace(mn):
      return mn
    case EvalTrace(mn):
      return mn
    case JVPTrace(mn):
      return mn
    case Tracer(tc, _):
      return main(tc)
    case JVPTracer(tc, _, _, _):
      return main(tc)
    case _: raise ValueError(f"main: {tc}")

map_main = cur(map, main)

In [6]:
# Tracer = (tc, arr_prio) | JVPTracer
# JVPTracer = (tc, primal, tangent, arr_prio)

Tracer = namedtuple('Tracer', ['tc', 'arr_prio'], defaults=[1000])
JVPTracer = namedtuple('JVPTracer', ['tc', 'primal', 'tangent', 'arr_prio'], defaults=[1000])

Tcr = Tracer | JVPTracer

_names = [
  '__neg__',
  '__add__',
  '__radd__',
  '__mul__',
  '__rmul__',
  '__gt__',
  '__lt__',
  '__bool__',
  '__nonzero__'
]

_fns = [
  lambda self: _neg(aval(self), self),
  lambda self, o: _add(aval(self), self, o),
  lambda self, o: _radd(aval(self), self, o),
  lambda self, o: _mul(aval(self), self, o),
  lambda self, o: _rmul(aval(self), self, o),
  lambda self, o: _gt(aval(self), self, o),
  lambda self, o: _lt(aval(self), self, o),
  lambda self: _bool(aval(self), self),
  lambda self: _nonzero(aval(self), self)
]

Tracer = foldl(ext, _names, _fns, init=Tracer)
JVPTracer = clext(JVPTracer, Tracer, _names)

def is_tracer(x):
  match x:
    case Tracer(_, _):
      return True
    case JVPTracer(_, _, _, _):
      return True
    case _:
      return False

filt_tracer = cur(filter, is_tracer)
efilt_tracer = cur(efilter, is_tracer)

def _tc(tcr):
  match tcr:
    case Tracer(tc, _):
      return tc
    case JVPTracer(tc, _, _, _):
      return tc
    case _: raise ValueError(f"tc: {tcr}")

def __arr_prio__(tcr):
  match tcr:
    case Tracer(_, arr_prio):
      return arr_prio
    case JVPTracer(_, _, _, arr_prio):
      return arr_prio
    case _: raise ValueError(f"arr_prio: {tcr}")

def aval(x):
  match x:
    case Tracer(_, _):
      raise ValueError(f"aval: {x}")
    case JVPTracer(tc, primal, tangent, _):
      return aval(primal)
    case v if is_pure(v):
      return ConcreteArray(np.asarray(v))
    case _:
      raise TypeError(f"aval: {x}")

def primal(x):
  match x:
    case JVPTracer(_, primal, _, _):
      return primal
    case _:
      raise ValueError(f"primal: {x}")

def tangent(x):
  match x:
    case JVPTracer(_, _, tangent, _):
      return tangent
    case _:
      raise ValueError(f"tangent: {x}")

In [7]:
# Array = ShapedArray | ConcreteArray
# ShapedArray = (shape, dtype, ab_lvl)
# ConcreteArray = (val, ab_lvl)

ShapedArray = namedtuple('ShapedArray', ['shape', 'dtype', 'ab_lvl'], defaults=[1])
ConcreteArray = namedtuple('ConcreteArray', ['val', 'ab_lvl'], defaults=[2])

Array = ShapedArray | ConcreteArray

def shape(sa):
  match sa:
    case ShapedArray(shape, _, _):
      return shape
    case ConcreteArray(val, _):
      return val.shape
    case x if type(x) in JAX_TYPES:
      return np.shape(x)
    case _: 
      raise ValueError(f"shape: {sa}")

def dtype(sa):
  match sa:
    case ShapedArray(_, dtype, _):
      return dtype
    case ConcreteArray(val, _):
      return val.dtype
    case x if type(x) in {bool, int, float}:
      return type(x)
    case x if type(x) in {np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}:
      return np.dtype(x)
    case _: raise ValueError(f"dtype: {sa}")

def ab_lvl(a): 
  match a:
    case ShapedArray(_, _, ab_lvl):
      return ab_lvl
    case ConcreteArray(_, ab_lvl):
      return ab_lvl
    case _: raise ValueError(f"ab_lvl: {sa}")

def ndim(a): 
  match a:
    case ShapedArray(shape, _, _):
      return len(shape)
    case _: 
      raise ValueError(f"ndim: {a}")
      
def val(a):
  match a:
    case ConcreteArray(val, _):
      return val
    case _:
      raise ValueError(f"val: {a}")

def _neg(av, x):
  match av:
    case ShapedArray(_, _, _): 
      return neg(x)
    case ConcreteArray(_):
      return neg(x)
    case _:
      raise TypeError(f"_neg: {av}")

def _add(av, x, y):
  match av:
    case ShapedArray(_, _, _): 
      return add(x, y)
    case ConcreteArray(_):
      return add(x, y)
    case _:
      raise TypeError(f"_add: {av}")

def _radd(av, x, y):
  match av:
    case ShapedArray(_, _, _): 
      return swap(add)(x, y)
    case ConcreteArray(_):
      return swap(add)(x, y)
    case _:
      raise TypeError(f"_radd: {av}")

def _mul(av, x, y):
  match av:
    case ShapedArray(_, _, _): 
      return mul(x, y)
    case ConcreteArray(_):
      return mul(x, y)
    case _:
      raise TypeError(f"_mul: {av}")

def _rmul(av, x, y):
  match av:
    case ShapedArray(_, _, _): 
      return swap(mul)(x, y)
    case ConcreteArray(_):
      return swap(mul)(x, y)
    case _:
      raise TypeError(f"_rmul: {av}")

def _gt(av, x, y):
  match av:
    case ShapedArray(_, _, _):
      return gt(x, y)
    case ConcreteArray(_):
      return gt(x, y)
    case _:
      raise TypeError(f"_gt: {av}")

def _lt(av, x, y):
  match av:
    case ShapedArray(_, _, _): 
      return lt(x, y)
    case ConcreteArray(_):
      return lt(x, y)
    case _:
      raise TypeError(f"_lt: {av}")

def _bool(av, x):
  match av:
    case ShapedArray(_, _, _):
      raise Exception("ShapedArray can't be unambiguously converted to bool")
    case ConcreteArray(_):
      return bool(val(aval(x)))
    case _:
      raise TypeError(f"_bool: {av}")

def _nonzero(av, x):
  match av:
    case ShapedArray(_, _, _):
      raise Exception("ShapedArray can't be unambiguously converted to bool")
    case ConcreteArray(_):
      return bool(val(aval(x)))
    case _:
      raise TypeError(f"_nonzero: {av}")

In [8]:
def prim_to_rule(tc, prim):
  match tc: 
    case EvalTrace(_):
      return _impl_rules(prim)
    case JVPTrace(_):
      return _jvp_rules(prim)

def _impl_rules(prim): 
  match prim: 
    case Primitive('add'):
      return lambda x, y: [np.add(x, y)]
    case Primitive('mul'):
      return lambda x, y: [np.multiply(x, y)]
    case Primitive('neg'):
      return lambda x: [np.negative(x)]
    case Primitive('sin'):
      return lambda x: [np.sin(x)]
    case Primitive('cos'):
      return lambda x: [np.cos(x)]
    case Primitive('reduce_sum'):
      return lambda x, *, axis: [np.sum(x, axis)]
    case Primitive('gt'):
      return lambda x, y: [np.greater(x, y)]
    case Primitive('lt'):
      return lambda x, y: [np.less(x, y)]
    case Primitive('transpose'):
      return lambda x, *, perm: [np.transpose(x, perm)]
    case Primitive('broadcast'):
      return lambda x, *, shape, axes: [np.broadcast_to(foldl(np.expand_dims, axes, init=x), shape)]

def _jvp_rules():
  def add_jvp(ps, ts): 
    (x, y), (x_dot, y_dot) = ps, ts
    return [x + y], [x_dot + y_dot]

  def mul_jvp(ps, ts): 
    (x, y), (x_dot, y_dot) = ps, ts
    return [x * y], [x_dot * y + x * y_dot]

  def sin_jvp(ps, ts): 
    (x,), (x_dot,) = ps, ts
    return [sin(x)], [cos(x) * x_dot]
  
  def cos_jvp(ps, ts):
    (x,), (x_dot,) = ps, ts
    return [cos(x)], [-sin(x) * x_dot]

  def neg_jvp(ps, ts):
    (x,), (x_dot,) = ps, ts
    return [neg(x)], [neg(x_dot)]

  def reduce_sum_jvp(ps, ts, *, axis):
    (x,), (x_dot,) = ps, ts
    return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]

  def gt_jvp(ps, ts):
    (x, y), _ = ps, ts
    out_p = gt(x, y) 
    return [out_p], [zeros_like(out_p)]

  def lt_jvp(ps, ts): 
    (x, y), _ = ps, ts
    out_p = lt(x, y)
    return [out_p], [zeros_like(out_p)]

  def jvp_rules(prim):
    match prim:
      case Primitive('add'):
        return add_jvp
      case Primitive('mul'):
        return mul_jvp
      case Primitive('sin'):
        return sin_jvp
      case Primitive('cos'):
        return cos_jvp
      case Primitive('neg'):
        return neg_jvp
      case Primitive('reduce_sum'):
        return reduce_sum_jvp
      case Primitive('gt'):
        return gt_jvp
      case Primitive('lt'):
        return lt_jvp
      case _:
        raise ValueError(f"jvp_rules: {prim}")
  return jvp_rules

_jvp_rules = _jvp_rules()

In [9]:
def f(x):
  # y = mul(sin(x), 2.0)
  # z = add(neg(y), x)
  # return z
  y = sin(x) * 2.
  z = -y + x
  return z

def max_tc(xs, default=None):
  return compose(main_to_tc, cur(argmax, lvl, default), map_main, filt_tracer)(xs)

def main_to_tc(x):
  return tc_t(x)(x)

def bind(tcs, dyn_tc, prim, *xs, **kvs):
  T = max_tc(xs, tcs[0])
  rs = map(cur(box, T), xs)
  os = proc_prim(T, prim, *rs, **kvs)
  ret, = [unbox(o) for o in os]
  return ret

tcs = [MainTrace(0, EvalTrace, None)]
dyn_tc = None
bind = partial(bind, tcs, dyn_tc)

# f(3.0)

In [10]:
add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive('neg')
sin_p = Primitive('sin')
cos_p = Primitive('cos')
reduce_sum_p = Primitive('reduce_sum')
gt_p = Primitive('gt')
lt_p = Primitive('lt')
transpose_p = Primitive('transpose')
broadcast_p = Primitive('broadcast')

def add(x, y): return bind(add_p, x, y)
def mul(x, y): return bind(mul_p, x, y)
def neg(x): return bind(neg_p, x)
def sin(x): 
  return bind(sin_p, x)
def cos(x): return bind(cos_p, x)
def reduce_sum(x): return bind(reduce_sum_p, x)
def gt(x, y): return bind(gt_p, x, y)
def lt(x, y): return bind(lt_p, x, y)
def transpose(x, perm): return bind(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  match axis: 
    case None: 
      return reduce_sum(x, axis=tuple(range(ndim(x))))
    case int(v):
      return reduce_sum(x, axis=(v,))
    case tuple(_):
      return bind(reduce_sum_p, x, axis=axis)
    case _: raise ValueError(f"reduce_sum: {axis}")

In [11]:
def _is_pure():
  JAX_TYPES = {bool, int, float, np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
  return lambda x: type(x) in JAX_TYPES

is_pure = _is_pure()

# unbox Tracer x
def unbox(tcr):
  match tcr:
    case Tracer(_, _) | JVPTracer(_, _, _, _):
      return tcr
    case _:
      return tcr

# box x into a tracer wrt the Tracer tc
def box(tc, x): 
  # print(x)
  match x:
    case Tracer(_, _) | JVPTracer(_, _, _, _):
      lvl_tc = lvl(tc)
      if main(x) is main(tc):
        return x
      elif lvl(x) < lvl_tc:
        return _make_tcr(tc)(x)
      elif lvl(x) > lvl_tc:
        raise Exception(f"can't lift level {lvl(x)} to {lvl_tc}")
      else:
        raise Exception(f"different traces at same level: {_tc(x)}, {tc}")
    case y if is_pure(y):
      return _make_tcr(tc)(y)
    case _:
      raise ValueError(f"box: {x}")

def _make_tcr(tc):
  match tc:
    case EvalTrace(_):
      return identity
    case JVPTrace(_):
      return lambda x: JVPTracer(tc, x, zeros_like(x))
    case _:
      raise ValueError(f"_make_tcr: {tc}")

In [12]:
def proc_prim(tc, prim, *tcrs, **kvs): 
  match tc:
    case EvalTrace(_):
      return prim_to_rule(tc, prim)(*tcrs, **kvs)
    case JVPTrace(_):
      ps_in, ts_in = unzip2((primal(tcr), tangent(tcr)) for tcr in tcrs)
      jvp_rule = _jvp_rules(prim)
      ps_out, ts_out = jvp_rule(ps_in, ts_in, **kvs)
      return [JVPTracer(tc, p, t) for p, t in zip(ps_out, ts_out)]
    case _: 
      raise ValueError(f"proc_prim: {tc}")

In [13]:
@contextmanager
def new_main(tcs, tc_t, gd=None): 
  mn = MainTrace(len(tcs), tc_t, gd)
  tcs.append(mn)
  try:
    yield mn
  finally:
    tcs.pop()

In [14]:
def jvp_v1(f, ps, ts):
  tcs = [MainTrace(0, EvalTrace, None)]
  with new_main(tcs, JVPTrace) as mn:
    tc = JVPTrace(mn)
    tcrs_in = [JVPTracer(tc, p, t) for p, t in zip(ps, ts)]
    out = f(*tcrs_in)
    tcr_out = box(tc, out)
    p_out, t_out = primal(tcr_out), tangent(tcr_out)
  return p_out, t_out

def jvp_flat(f, ps, ts):
  tcs = [MainTrace(0, EvalTrace, None)]
  with new_main(tcs, JVPTrace) as mn: 
    tc = JVPTrace(mn)
    tcrs_in = [JVPTracer(tc, p, t) for p, t in zip(ps, ts)]
    outs = f(*tcrs_in)
    tcrs_out = emap(cur(box, tc), [outs])
    ps_out, ts_out = unzip2((primal(t), tangent(t)) for t in tcrs_out)
  return ps_out, ts_out

def jvp(f, ps, ts):
  # tcs = [MainTrace(0, EvalTrace, None)]
  ps_flat, in_tree = tree_flatten(ps)
  ts_flat, in_tree2 = tree_flatten(ts)
  assert in_tree == in_tree2
  f, out_tree = flatten_fun(f, in_tree)
  ps_out_flat, ts_out_flat = jvp_flat(f, ps_flat, ts_flat)
  ps_out = tree_unflatten(out_tree(), ps_out_flat)
  ts_out = tree_unflatten(out_tree(), ts_out_flat)
  return ps_out, ts_out

In [16]:
def deriv(f):
  tcs = [MainTrace(0, EvalTrace, None)]
  return lambda x: jvp_v1(tcs, f, (x,), (1.0,))[1]

def f2(x):
  if x > 0.:
    return 2. * x
  else:
    return x

In [53]:
def flatten_fun(f, in_tree):
  store = []
  def flat_fun(*args_flat): 
    pytree_args = tree_unflatten(in_tree, args_flat)
    out = f(*pytree_args)
    out_flat, out_tree = tree_flatten(out)
    store.append(out_tree)
    return out_flat 
  return flat_fun, store

In [145]:
NodeType = namedtuple('NodeType', ['name', 'to_iter', 'from_iter'])

def name(nt):
  return nt.name

def to_iter(nt): 
  return nt.to_iter 

def from_iter(nt): 
  return nt.from_iter

def register_pytree_node(node_types, ty, to_iter, from_iter): 
  ret = dcp(node_types)
  ret[ty] = NodeType(str(ty), to_iter, from_iter)
  return ret

node_types = {}
node_types = register_pytree_node(node_types, tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
node_types = register_pytree_node(node_types, list, lambda l: (None, l), lambda _, xs: list(xs))
node_types = register_pytree_node(node_types, 
                                  dict, 
                                  lambda d: emap(tuple, unzip2(sorted(d.items()))),
                                  lambda ks, vs: dict(zip(ks, vs)))

def nt(nts, x): 
  return nts.get(type(x))

In [161]:
PyTreeDef = namedtuple('PyTreeDef', ['node_type', 'node_metadata', 'child_treedefs'])

node_type = lambda ptd: ptd.node_type
node_metadata = lambda ptd: ptd.node_metadata
child_treedefs = lambda ptd: ptd.child_treedefs

def tree_flatten(x): 
  children_iter, treedef = _tree_flatten(x) 
  return list(children_iter), treedef

def __tree_flatten(nts):
  def _tree_flatten(x):
    _nt = nt(nts, x)
    if _nt: 
      node_metadata, cs = to_iter(_nt)(x)
      cs_flat, c_trees = unzip2(map(_tree_flatten, cs))
      flattened = chain.from_iterable(cs_flat)
      return flattened, PyTreeDef(_nt, node_metadata, tuple(c_trees))
    else: 
      return [x], 'leaf'
  return _tree_flatten

_tree_flatten = __tree_flatten(node_types)

def tree_unflatten(treedef, xs): 
  return _tree_unflatten(treedef, iter(xs))

def _tree_unflatten(treedef, xs): 
  if treedef == 'leaf':
    return next(xs)
  else:
    children = (_tree_unflatten(t, xs) for t in child_treedefs(treedef))
    return from_iter(node_type(treedef))(node_metadata(treedef), children)

In [162]:
def h(x):
  y = sin(x) * 2.
  z = -y + x
  return {'hi': z, 'there': [x, y]}

x, xdot = 3., 1.
y, ydot = jvp(h, (x,), (xdot,))
print(y, ydot)

ValueError: box: [JVPTracer(tc=JVPTrace(main=MainTrace(lvl=1, tc_t=<class '__main__.JVPTrace'>, gd=None)), primal=2.7177599838802657, tangent=2.979984993200891, arr_prio=1000), JVPTracer(tc=JVPTrace(main=MainTrace(lvl=1, tc_t=<class '__main__.JVPTrace'>, gd=None)), primal=3.0, tangent=1.0, arr_prio=1000), JVPTracer(tc=JVPTrace(main=MainTrace(lvl=1, tc_t=<class '__main__.JVPTrace'>, gd=None)), primal=0.2822400161197344, tangent=-1.9799849932008908, arr_prio=1000)]

I don't think that JAX needs all these complicated data structures to override primitive application. It may be the case that we can make judicious use of the eval() function and metaprogramming to transform primitive application.