In [62]:
import numpy as np
import plotly.graph_objects as go
from typing import NamedTuple, Optional, Any
from collections.abc import Sequence
from contextlib import contextmanager 

In [56]:
# helper functions

def plot(fn, xs):
  """Evaluate `fn` on `xs` and show the result as a Plotly scatter plot.

  Args:
    fn: callable applied once to the whole `xs` value (so it must accept
      array input if `xs` is an array).
    xs: x-coordinates; passed to `fn` unchanged and used as the x-values.

  Returns None; the figure is displayed as a side effect of `fig.show()`.
  """
  ys = fn(xs)
  # Build the figure once, directly from the scatter trace.
  fig = go.Figure(data=go.Scatter(x=xs, y=ys, mode='markers'))
  fig.update_layout(xaxis_title='x', yaxis_title='f(x)', template='plotly')
  fig.show()
  

In [57]:
def f(x):
  """Example function to plot: f(x) = x - 2*sin(x), computed with NumPy.

  `x` may be a scalar or an array; `np.sin` applies elementwise.
  """
  # Kept as (-(2 sin x)) + x so the operations match the original order.
  return -(np.sin(x) * 2.) + x
# Sample f at the integers -10, -9, ..., 9 (unit step) and scatter-plot it.
plot(f, np.arange(-10, 10))

In [61]:
# core primitives

# Functional NamedTuple form: a one-field, immutable tuple type.
Primitive = NamedTuple('Primitive', [('name', str)])
Primitive.__doc__ = """Operation tag identified by its `name`; passed as the first argument to bind1()."""

# arithmetic
add_p = Primitive(name='add')
mul_p = Primitive(name='mul')
neg_p = Primitive(name='neg')
# trigonometric
sin_p = Primitive(name='sin')
cos_p = Primitive(name='cos')
# reductions (metadata: axis)
reduce_sum_p = Primitive(name='reduce_sum')
# comparisons
greater_p = Primitive(name='greater')
less_p = Primitive(name='less')
# shape manipulation (metadata: perm / shape, axes)
transpose_p = Primitive(name='transpose')
broadcast_p = Primitive(name='broadcast')

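# Convention for the primitive wrappers below (they call bind1(), which forwards to bind()):
#   array data (operands) -> passed as positional args
#   metadata (e.g. perm, shape, axes, axis) -> passed as keyword args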
def add(x, y):
  """Bind the `add` primitive to operands `x` and `y`; return its single output."""
  return bind1(add_p, x, y)
def mul(x, y):
  """Bind the `mul` primitive to operands `x` and `y`; return its single output."""
  return bind1(mul_p, x, y)
def neg(x):
  """Bind the `neg` primitive to operand `x`; return its single output."""
  return bind1(neg_p, x)
def sin(x):
  """Bind the `sin` primitive to operand `x`; return its single output."""
  return bind1(sin_p, x)
def cos(x):
  """Bind the `cos` primitive to operand `x`; return its single output."""
  return bind1(cos_p, x)
def greater(x, y):
  """Bind the `greater` comparison primitive to `x` and `y`; return its single output."""
  return bind1(greater_p, x, y)
def less(x, y):
  """Bind the `less` comparison primitive to `x` and `y`; return its single output."""
  return bind1(less_p, x, y)
def transpose(x, perm):
  """Bind the `transpose` primitive to `x`.

  `perm` is metadata, so it is forwarded as the keyword param `perm`.
  Returns the primitive's single output.
  """
  return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes):
  """Bind the `broadcast` primitive to `x`.

  `shape` and `axes` are metadata, forwarded (in that order) as keyword params.
  Returns the primitive's single output.
  """
  return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  """Sum `x` over `axis` by binding the `reduce_sum` primitive.

  Args:
    x: operand to reduce.
    axis: None (reduce over every axis, using `np.ndim(x)`), a single integer
      axis (Python int or NumPy integer), or an iterable of axes.

  Returns:
    The single output of binding `reduce_sum_p`.

  `axis` is always normalized to a tuple before binding. The primitive then
  sees one canonical, hashable form no matter how the caller spelled it,
  e.g. `1`, `np.int64(1)`, or `[0, 1]`.
  """
  if axis is None:
    axis = tuple(range(np.ndim(x)))
  elif isinstance(axis, (int, np.integer)):
    # `type(axis) is int` would miss NumPy integers such as np.int64.
    axis = (int(axis),)
  else:
    axis = tuple(axis)
  return bind1(reduce_sum_p, x, axis=axis)

def bind1(prim, *args, **params):
  """Bind `prim` to `args`/`params` and return its one and only output.

  `bind` (defined elsewhere) must return an iterable of outputs; the
  single-element unpacking below raises ValueError unless it has exactly one.
  """
  [out] = bind(prim, *args, **params)
  return out

In [71]:
class MainTrace(NamedTuple):
  """One entry of `trace_stack`.

  Fields:
    level: the entry's position in `trace_stack`. `new_main` sets it to the
      stack length just before appending.
    trace_type: the Trace type associated with this level. 'Trace' is a
      forward reference to a class defined elsewhere.
    global_data: optional extra data for this level. Its meaning is
      presumably defined by `trace_type` — TODO confirm.
  """
  level: int
  trace_type: type['Trace']
  global_data: Optional[Any]

  # Read-only alias for the field's original, misspelled name, so any
  # existing `.glibal_data` readers keep working.
  @property
  def glibal_data(self) -> Optional[Any]:
    return self.global_data

# Stack of active MainTraces, innermost last (pushed/popped by new_main).
trace_stack: list[MainTrace] = []
# NOTE(review): not read or written in this section; presumably managed
# elsewhere in the notebook — confirm.
dynamic_trace: Optional[MainTrace] = None

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
  """Push a new MainTrace onto `trace_stack` for the duration of a `with` block.

  The entry's level is the current stack depth, i.e. the index it occupies
  once appended. It is yielded to the caller and popped on exit, including
  when the `with` body raises.

  Args:
    trace_type: Trace type recorded on the new MainTrace.
    global_data: optional value recorded on the new MainTrace (default None).
  """
  entry = MainTrace(len(trace_stack), trace_type, global_data)
  trace_stack.append(entry)
  try:
    yield entry
  finally:
    # Assumes properly nested use, so the top of the stack is `entry`.
    trace_stack.pop()