In [1]:
import numpy as np
import plotly.graph_objects as go
from typing import NamedTuple, Optional, Any, Type
from collections.abc import Sequence
from contextlib import contextmanager 
import inspect



In [2]:
# helper functions

def plot(fn, xs):
    """Scatter-plot ``fn(xs)`` against ``xs`` with Plotly and display the figure.

    Args:
        fn: Callable applied once to the whole ``xs`` collection; its result is
            used as the y-values.
        xs: The x-values passed to ``fn`` and used for the x-axis.

    Returns:
        None. The figure is shown via ``fig.show()``.
    """
    ys = fn(xs)
    # Build the figure in one step; the original created an empty Figure first
    # and then discarded it.
    fig = go.Figure(data=go.Scatter(x=xs, y=ys, mode='markers'))
    fig.update_layout(xaxis_title='x', yaxis_title='f(x)', template='plotly')
    fig.show()
  

In [3]:
def f(x):
  """Return ``-(2 * sin(x)) + x`` for a scalar or NumPy array ``x``."""
  doubled_sine = np.sin(x) * 2.
  return -doubled_sine + x
# Sample f at 0.5-spaced points in [-10, 10) and scatter-plot the result.
plot(f, np.arange(-10, 10, 0.5))

In [4]:
# core primitives

class Primitive(NamedTuple): 
  """Tag identifying a primitive operation by name.

  Instances are passed as the first argument to `bind1` by the wrapper
  functions defined after them.
  """
  name: str 

# One module-level Primitive per operation; the wrapper functions below refer to
# these objects rather than to raw strings.
add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive('neg')
sin_p = Primitive('sin')
cos_p = Primitive('cos')
reduce_sum_p = Primitive('reduce_sum')
greater_p = Primitive('greater')
less_p = Primitive('less')
transpose_p = Primitive('transpose')
broadcast_p = Primitive('broadcast')

# convention:
  # array data -> pass as positional args to bind()
  # metadata -> pass as keyword args to bind()
def add(x, y):
  """Bind `add_p` to the operands `x` and `y` and return the single result."""
  return bind1(add_p, x, y)


def mul(x, y):
  """Bind `mul_p` to the operands `x` and `y` and return the single result."""
  return bind1(mul_p, x, y)


def neg(x):
  """Bind `neg_p` to the operand `x` and return the single result."""
  return bind1(neg_p, x)


def sin(x):
  """Bind `sin_p` to the operand `x` and return the single result."""
  return bind1(sin_p, x)


def cos(x):
  """Bind `cos_p` to the operand `x` and return the single result."""
  return bind1(cos_p, x)


def greater(x, y):
  """Bind `greater_p` to the operands `x` and `y` and return the single result."""
  return bind1(greater_p, x, y)


def less(x, y):
  """Bind `less_p` to the operands `x` and `y` and return the single result."""
  return bind1(less_p, x, y)


def transpose(x, perm):
  """Bind `transpose_p` to `x`, forwarding `perm` as keyword metadata."""
  return bind1(transpose_p, x, perm=perm)


def broadcast(x, shape, axes):
  """Bind `broadcast_p` to `x`, forwarding `shape` and `axes` as keyword metadata."""
  return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  """Bind `reduce_sum_p` to `x`, normalising `axis` before forwarding it.

  Args:
    x: Operand passed positionally to `bind1`.
    axis: Which axes to reduce over.
      * ``None`` -> every axis of `x`, i.e. ``tuple(range(np.ndim(x)))``.
      * a Python int or NumPy integer scalar -> wrapped into a 1-tuple.
      * anything else is forwarded unchanged.

  Returns:
    The single output produced by `bind1`.
  """
  if axis is None:
    axis = tuple(range(np.ndim(x)))
  elif isinstance(axis, (int, np.integer)):
    # `type(axis) is int` missed NumPy integer scalars such as np.int64, which
    # were then forwarded bare instead of as a tuple.
    axis = (int(axis),)
  return bind1(reduce_sum_p, x, axis=axis)

def bind1(prim, *args, **params):
  """Call `bind` and unwrap its single output.

  The result of `bind` is unpacked into exactly one name, so anything other
  than a one-element result raises ``ValueError`` from the unpacking.

  NOTE(review): `bind` is not defined in the cells visible here — confirm it is
  defined before any wrapper is called.
  """
  (only_output,) = bind(prim, *args, **params)
  return only_output

In [7]:
class MainTrace(NamedTuple): 
  """Record describing one entry of `trace_stack`, as created by `new_main`."""
  level: int                    # stack depth at creation (new_main passes len(trace_stack))
  trace_type: type['Trace']     # the Trace class given to new_main
  global_data: Optional[Any]    # optional payload given to new_main (None by default)

  @property
  def glibal_data(self):
    """Read-only alias for the field's former misspelled name.

    Kept so attribute access written against the old spelling keeps working.
    """
    return self.global_data

# Stack of active MainTrace records: new_main() appends on entry and pops on exit.
trace_stack: list[MainTrace] = []
# NOTE(review): nothing in the visible cells reads or assigns this after
# initialisation — presumably reserved for later use; confirm.
dynamic_trace: Optional[MainTrace] = None

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
  """Context manager that pushes a fresh `MainTrace` onto `trace_stack`.

  The new record's level is the stack depth before the push. The record is
  yielded to the `with` body and popped again on exit, whether the body
  finishes normally or raises.
  """
  main = MainTrace(len(trace_stack), trace_type, global_data)
  trace_stack.append(main)
  try:
    yield main
  finally:
    # Unwind even when the with-body raised.
    trace_stack.pop()

In [6]:
class Trace: 
  """Base class holding the `MainTrace` it was created for.

  `pure`, `lift` and `process_primitive` are placeholders that subclasses must
  override. They raise ``NotImplementedError`` rather than using ``assert False``
  so the failure is not stripped when Python runs with ``-O``.
  """
  main: MainTrace 

  def __init__(self, main: MainTrace): 
    self.main = main 

  def pure(self, val):
    """Placeholder; subclasses must override."""
    raise NotImplementedError("must override")

  def lift(self, val):
    """Placeholder; subclasses must override."""
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params): 
    """Placeholder; subclasses must override."""
    raise NotImplementedError("must override")

In [None]:
class Tracer:
  """Base class for traced values.

  Subclasses must provide `aval`; the operator methods forward to `_neg` and
  `_add` on that object, passing the tracer(s) through.
  """
  _trace: Trace 

  # NumPy consults __array_priority__ when an ndarray meets another object in a
  # binary operation; a high value asks NumPy to defer to this class's operators.
  __array_priority__ = 1000

  @property 
  def aval(self): 
    """Placeholder; subclasses must override.

    Raises ``NotImplementedError`` instead of ``assert False`` so the check
    survives ``python -O``.
    """
    raise NotImplementedError("must override")

  def full_lower(self): 
    """Return `self`; this is the default implementation."""
    return self
  
  def __neg__(self): return self.aval._neg(self)
  def __add__(self, other): return self.aval._add(self, other)

In [48]:
# Alternative input matrices; uncomment exactly one to swap it in for A below.
# A = np.array([[0, 4, 3], [3, 1, -5], [-2, 1, 3]])
# A = np.array([[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]])
# A = np.array([[1000, -1, -1], [-1, -1, -1], [-1, -1, -1]])
A = np.array([[-1, 2], [3, 4]])
n = len(A)  # number of rows of A

In [52]:
def f(R, I, grid=None):
  """Recursively maximise accumulated column/row prefix sums over a square grid.

  At index pair (R, I) the result is the maximum of four candidates:
    * the result of advancing R by one,
    * that result plus the sum of column R over rows 0..I,
    * the result of advancing I by one,
    * that result plus the sum of row I over columns 0..R.
  Once either index reaches ``len(grid)`` the result is 0.

  NOTE(review): this rebinds the notebook-level name ``f`` that the earlier
  plotting example also defines; a distinct name would avoid the shadowing.

  Args:
    R: Column index.
    I: Row index.
    grid: Optional 2-D array-like to evaluate on. Defaults to the notebook
      global ``A``. Assumed square, since one bound is used for both indices.

  Returns:
    The maximum described above (0 at the boundary).
  """
  if grid is None:
    grid = A  # fall back to the notebook-level matrix
  grid = np.asarray(grid)
  size = len(grid)
  if R >= size or I >= size:
    return 0

  column_prefix = grid[:I + 1, R].sum()  # rows 0..I of column R
  row_prefix = grid[I, :R + 1].sum()     # columns 0..R of row I

  # Evaluate each recursive branch once; the original called each one twice,
  # doubling the work at every level.
  advance_r = f(R + 1, I, grid)
  advance_i = f(R, I + 1, grid)
  return max(advance_r, advance_r + column_prefix,
             advance_i, advance_i + row_prefix)

# Evaluate the recursion starting from R = 0, I = 0; the value is the cell's output.
f(0, 0)


#  1 2
#  3 4 

IndentationError: unexpected indent (220998899.py, line 17)