<a href="https://colab.research.google.com/github/lizhieffe/nn_grad_engine/blob/main/Auto_Differentiator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Tutorial: https://bclarkson-code.com/posts/llm-from-scratch-scalar-autograd/post.html

In [4]:
from typing import Optional

In [53]:
class Tensor:
  """A scalar node in a computation graph used for reverse-mode autodiff.

  A Tensor created directly via ``Tensor(value)`` is a leaf: it has no
  ``args`` and no ``local_derivatives``. A Tensor produced by one of the
  operation helpers (``_add``, ``_sub``, ``_mul``) records the operands it
  was computed from and the derivative of the result with respect to each
  operand, which is what ``backward`` walks to apply the chain rule.
  """

  value: float

  # If the tensor is the result of an operation, the operation arguments are
  # stored in the args. Empty for a leaf Tensor.
  args: tuple["Tensor", ...] = ()

  # If the Tensor is the result of an operation, the derivatives to the
  # operation arguments are stored in local_derivatives.
  #
  # length == # of the input Tensors; local_derivatives[i] pairs with args[i].
  local_derivatives: tuple["Tensor", ...] = ()

  # The derivative of the loss to the current Tensor in back prop.
  # None until a backward pass has reached this Tensor.
  derivative: Optional["Tensor"] = None

  def __init__(self, value: float) -> None:
    """Create a leaf Tensor holding ``value``."""
    self.value = value

  def __repr__(self) -> str:
    """Return ``Tensor(<value>)``; graph links and derivative are omitted."""
    return f"Tensor({self.value})"

  def backward(self) -> None:
    """Back-propagate from this Tensor, treating it as the loss.

    Sets this Tensor's own ``derivative`` to ``Tensor(1)`` (dL/dL = 1) and then
    hands off to the module-level ``dfs`` helper, which fills in ``derivative``
    on every Tensor reachable through ``args``.

    Derivatives already stored on those Tensors are added to rather than
    reset, so running a second backward pass over Tensors that were already
    visited accumulates on top of the earlier result.
    """
    self.derivative = Tensor(1)

    dfs(self, self.derivative)

Tensor(5)

Tensor(5)

In [54]:
def _add(a: Tensor, b: Tensor):
  """Return a new Tensor holding ``a.value + b.value``.

  The result remembers ``(a, b)`` as its operands. Both local derivatives are
  fresh ``Tensor(1)`` leaves, since d(a + b)/da = d(a + b)/db = 1.
  """
  total = Tensor(a.value + b.value)
  # One local derivative per operand, in the same order as `args`.
  total.local_derivatives = (Tensor(1), Tensor(1))
  total.args = (a, b)
  return total

def _sub(a: Tensor, b: Tensor):
  """Return a new Tensor holding ``a.value - b.value``.

  The result remembers ``(a, b)`` as its operands. The local derivatives are
  fresh leaves ``Tensor(1)`` and ``Tensor(-1)``, since d(a - b)/da = 1 and
  d(a - b)/db = -1.
  """
  difference = Tensor(a.value - b.value)
  # One local derivative per operand, in the same order as `args`.
  difference.local_derivatives = (Tensor(1), Tensor(-1))
  difference.args = (a, b)
  return difference

def _mul(a: Tensor, b: Tensor):
  """Return a new Tensor holding ``a.value * b.value``.

  The result remembers ``(a, b)`` as its operands. The local derivatives are
  fresh leaf Tensors built from the operands' current values:
  d(a * b)/da = b and d(a * b)/db = a.
  """
  product = Tensor(a.value * b.value)
  # Each operand's local derivative is the *other* operand's value.
  wrt_a, wrt_b = Tensor(b.value), Tensor(a.value)
  product.args = (a, b)
  product.local_derivatives = (wrt_a, wrt_b)
  return product

# Sanity checks: each operation helper computes the expected forward value.
assert _add(Tensor(1), Tensor(2)).value == 3
assert _sub(Tensor(1), Tensor(2)).value == -1
assert _mul(Tensor(1), Tensor(2)).value == 2

In [55]:
def dfs(t: "Tensor", curr_derivative: "Tensor") -> None:
  """Depth-first back-propagation from ``t`` into the Tensors it was built from.

  Two facts drive the traversal:

  1. A node's derivative equals the sum of the derivatives arriving along
     every path that back-propagates to it.
  2. By the chain rule, a node's derivative along one single path equals the
     derivative of its descendant on that path multiplied by the local
     derivative of the descendant with respect to the node.

  Args:
    t: The Tensor whose operands should receive derivatives.
    curr_derivative: The derivative of the loss with respect to ``t`` along
      the path currently being walked. This is a Tensor (not a float): it is
      combined with the local derivatives via ``_mul``.

  Every path through the graph is walked separately, so a Tensor that is
  reachable along several paths is visited once per path.
  """
  if not t.args:
    # Leaf Tensor: nothing further to propagate into.
    return

  for arg, local_der in zip(t.args, t.local_derivatives):
    # Chain rule for this one path: d(loss)/d(arg) = d(t)/d(arg) * d(loss)/d(t).
    derivative = _mul(local_der, curr_derivative)
    if arg.derivative is None:
      arg.derivative = derivative
    else:
      # `arg` was already reached along another path; sum the contributions.
      arg.derivative = _add(arg.derivative, derivative)

    # Recurse with this path's contribution only, not arg's running total;
    # otherwise earlier paths would be propagated (and counted) again.
    dfs(arg, derivative)

In [56]:
# Unit Test: _add records the forward value, its two operands, and a local
# derivative of 1 for each operand.
add_tensor = _add(Tensor(1), Tensor(2))
assert add_tensor.value == 3
assert add_tensor.args[0].value == 1
assert add_tensor.args[1].value == 2
assert add_tensor.local_derivatives[0].value == 1
assert add_tensor.local_derivatives[1].value == 1

In [57]:
# Unit test back prop
#
# L = a + b * c
# dL/da = 1
# dL/db = c = 3
# dL/dc = b = 2
a = Tensor(1)
b = Tensor(2)
c = Tensor(3)

L = _add(a, _mul(b, c))
L.backward()

assert a.derivative.value == 1
assert b.derivative.value == 3
assert c.derivative.value == 2

In [58]:
# Unit test back prop: `a` is reached along two paths, so its derivative is
# the sum of both contributions.
#
# L = a + b * (a + c)
# dL/da = 1 + b = 1 + 2
# dL/db = a + c = 1 + 3
# dL/dc = b     = 2
a = Tensor(1)
b = Tensor(2)
c = Tensor(3)

L = _add(a, _mul(b, _add(a, c)))
L.backward()

assert a.derivative.value == 1 + 2
assert b.derivative.value == 1 + 3
assert c.derivative.value == 2

In [64]:
# Unit test back prop on a graph with shared nodes: H feeds both F and G,
# and a feeds both I and H.
#
# L = d * F
# F = H + G
# H = a * I
# I = a + b
# G = c * H
#
# L     = a**2*d + a*b*d + a**2*c*d + a*b*c*d
# dL/da = 2*a*d + b*d + 2*a*c*d + b*c*d
# dL/db = a*d + a*c*d
# dL/dc = a**2*d + a*b*d
# dL/dd = a**2 + a*b + a**2*c + a*b*c

a = Tensor(1)
b = Tensor(2)
c = Tensor(3)
d = Tensor(4)

I = _add(a, b)
H = _mul(a, I)
G = _mul(c, H)
F = _add(H, G)
L = _mul(d, F)

L.backward()

# Each expected value below is the matching formula above with
# a=1, b=2, c=3, d=4 substituted term by term.
assert L.value == 4*(1**2) + 1*2*4 + (1**2)*3*4 + 1*2*3*4
assert a.derivative.value == (2*1*4 + 2*4 + 2*1*3*4 + 2*3*4), f"{a.derivative=}"
assert b.derivative.value == (1*4 + 1*3*4), f"{b.derivative=}"
assert c.derivative.value == ((1**2)*4 + 1*2*4), f"{c.derivative=}"
assert d.derivative.value == (1**2 + 1*2 + (1**2)*3 + 1*2*3), f"{d.derivative=}"

In [65]:
# Unit test back-prop
#
# Example in https://bclarkson-code.com/posts/llm-from-scratch-scalar-autograd/post.html
#
# L = (y - (m * x + c)) ** 2, built as the product of two separately
# constructed (but numerically identical) residual sub-graphs that share the
# same leaf Tensors.
# dL/dx = -2 * m * (y - (m * x + c)) = -2 * 2 * (1 - 10) = 36

y = Tensor(1)
m = Tensor(2)
x = Tensor(3)
c = Tensor(4)

left = _sub(y, _add(_mul(m, x), c))
right = _sub(y, _add(_mul(m, x), c))

L = _mul(left, right)
L.backward()

assert x.derivative.value == 36