In [5]:
from __future__ import annotations

import types
from collections.abc import Callable

In [62]:
class Node:
  """A scalar in a computation graph with reverse-mode automatic differentiation.

  Arithmetic (`+`, `-`, `*`, their reflected forms, and unary `-`) returns a
  new child Node whose `parents` are the two operands and whose `_backward`
  closure pushes the child's gradient into those operands by the chain rule.
  Plain numbers are wrapped in leaf Nodes automatically.
  """

  def __init__(self, value: float) -> None:
    self.name: str = ""  # label shown by __repr__ and by backward()'s printout
    self.value: float = value

    # Operands that produced this node; (None, None) for a leaf.
    self.parents: tuple[Node | None, Node | None] = (None, None)
    # Gradient of the node backward() was called on w.r.t. this node.
    # Accumulated (+=) by backward(); cleared by zero_grad().
    self.gradient: float = 0
    # Adds this node's contribution to its parents' gradients; no-op for leaves.
    self._backward: Callable[[], None] = lambda *args: None

    # topological_sort() no longer uses this flag (it tracks visits locally so
    # backward() can be re-run); kept so code that reads the flag still works.
    self.visited = False

  @staticmethod
  def _as_node(other: Node | float) -> Node:
    """Return `other` if it is already a Node (or subclass), else wrap it in a leaf."""
    return other if isinstance(other, Node) else Node(other)

  def __add__(self, other: Node | float) -> Node:
    """Return self + other; both local derivatives are 1."""
    other = self._as_node(other)
    child = Node(self.value + other.value)
    child.parents = (self, other)
    def _backward() -> None:
      # Accumulate (+=): a node may feed several children, or the same child
      # twice (e.g. g + g).
      self.gradient += child.gradient
      other.gradient += child.gradient
    child._backward = _backward
    return child

  def __radd__(self, other: Node | float) -> Node:
    """Return other + self for a plain-number left operand (addition commutes)."""
    return self.__add__(other)

  def __sub__(self, other: Node | float) -> Node:
    """Return self - other; local derivatives are +1 and -1."""
    other = self._as_node(other)
    child = Node(self.value - other.value)
    child.parents = (self, other)
    def _backward() -> None:
      self.gradient += child.gradient
      other.gradient += -child.gradient
    child._backward = _backward
    return child

  def __rsub__(self, other: Node | float) -> Node:
    """Return other - self for a plain-number left operand."""
    return self._as_node(other).__sub__(self)

  def __mul__(self, other: Node | float) -> Node:
    """Return self * other; each operand's local derivative is the other's value."""
    other = self._as_node(other)
    child = Node(self.value * other.value)
    child.parents = (self, other)
    def _backward() -> None:
      self.gradient += other.value * child.gradient
      other.gradient += self.value * child.gradient
    child._backward = _backward
    return child

  def __rmul__(self, other: Node | float) -> Node:
    """Return other * self for a plain-number left operand (multiplication commutes)."""
    return self.__mul__(other)

  def __neg__(self) -> Node:
    """Return -self, built as self * -1 so it backpropagates like any product."""
    return self.__mul__(-1)

  def topological_sort(self) -> list[Node]:
    """Return this node and all its ancestors, every node listed before its parents.

    This is the reverse of a depth-first post-order that explores parents[0]
    before parents[1]. Each node appears exactly once, even when shared
    (e.g. `g + g`). The sort uses an explicit stack, so deep graphs do not hit
    Python's recursion limit. It keeps a local visited set, so it can be called
    any number of times.
    """
    post_order: list[Node] = []
    seen: set[int] = set()  # id()s of nodes already expanded
    stack: list[tuple[Node, bool]] = [(self, False)]
    while stack:
      node, expanded = stack.pop()
      if expanded:
        post_order.append(node)  # all of its ancestors are already emitted
        continue
      if id(node) in seen:
        continue
      seen.add(id(node))
      stack.append((node, True))
      # Push parents[1] first so parents[0] is explored first (recursive order).
      for parent in reversed(node.parents):
        if parent is not None:
          stack.append((parent, False))
    post_order.reverse()
    return post_order

  def backward(self, verbose: bool = True) -> None:
    """Backpropagate from this node into the `gradient` of every ancestor.

    The method seeds self.gradient with 1. It then runs each node's `_backward`
    in topological order, so a node's gradient is complete before it is pushed
    to its parents. Ancestor gradients are added to, not overwritten. Call
    zero_grad() on them before re-running on the same graph.

    Args:
      verbose: if True (the default), print `name value gradient` for each
        node as it is processed.
    """
    self.gradient = 1  # d(self)/d(self)
    for node in self.topological_sort():
      node._backward()
      if verbose:
        print(node.name, node.value, node.gradient)

  def zero_grad(self) -> None:
    """Reset this node's accumulated gradient and its legacy `visited` flag."""
    self.gradient = 0
    self.visited = False

  def __repr__(self) -> str:
    # Shown as "Scalar(...)" rather than "Node(...)"; left unchanged so that
    # previously recorded notebook output still matches.
    return f"Scalar(name={self.name}, value={self.value})"

In [63]:
# Build l = 2*((a*b + c)*d) + a from four named leaves, then backpropagate
# from l; backward() prints each node's name, value and gradient.
a = Node(3); a.name = "a"
b = Node(2); b.name = "b"
c = Node(-1); c.name = "c"
d = Node(4); d.name = "d"
e = a * b; e.name = "e"
f = e + c; f.name = "f"
g = f * d; g.name = "g"
h = g + g; h.name = "h"  # g is used twice, so its gradient must accumulate
l = h + a; l.name = "l"  # a reaches l directly and through e
l.backward()

# Check against the closed-form partials of l = 2(ab + c)d + a.
assert a.gradient == 2 * b.value * d.value + 1          # 17
assert b.gradient == 2 * a.value * d.value              # 24
assert c.gradient == 2 * d.value                        # 8
assert d.gradient == 2 * (a.value * b.value + c.value)  # 10

l 43 1
h 40 1
g 20 2
d 4 10
f 5 8
c -1 8
e 6 8
b 2 24
a 3 17


In [30]:
# Stand-alone reverse-post-order DFS from `l`, illustrating the ordering
# that backprop needs.
dfs_traversal: list[Node] = []
dfs_seen: set[int] = set()  # id()s of nodes already expanded by this cell

def dfs(node: Node) -> None:
  """Append every not-yet-seen ancestor of `node`, then `node`, to dfs_traversal.

  Visits are tracked in `dfs_seen` rather than `node.visited`. Re-running this
  cell therefore neither depends on nor disturbs flags that Node's own methods
  may use. Shared sub-graphs (e.g. `g + g`) are explored only once.
  """
  if id(node) in dfs_seen:
    return
  dfs_seen.add(id(node))
  for parent in node.parents:
    if parent is not None:
      dfs(parent)
  dfs_traversal.append(node)

dfs(l)
print(list(reversed(dfs_traversal)))
# Reversed post-order lists every node before its parents, so a node's
# _backward() only runs once its gradient has been fully accumulated.

[Scalar(name=l, value=43), Scalar(name=h, value=40), Scalar(name=g, value=20), Scalar(name=d, value=4), Scalar(name=f, value=5), Scalar(name=c, value=-1), Scalar(name=e, value=6), Scalar(name=b, value=2), Scalar(name=a, value=3)]
