In [1]:
from __future__ import annotations

In [2]:
class Value:
  def __init__(self, data: float, father: Value = None, mother: Value = None) -> None:
    self.data = data
    self.father: Value = father
    self.mother: Value = mother
    self.grad = 0 # gradient w.r.t to the target
    self._backward = lambda: None # calculates gradients of parents
    self._zero_grad = lambda: None # zeros out gradients of parents


  def __repr__(self) -> str:
    return f"Value(data={self.data})"
  
  # when backwards called on new node, want to calculate gradients for parents
  # you know the derivative of the parent node wrt to the child node
  
  def __add__(self, other: Value):
    father = self; mother = other
    child = Value(self.data + other.data, father, mother)
    def _backward():
      dchild_dfather = 1
      dchild_dmother = 1
      father.grad += dchild_dfather * child.grad
      mother.grad += dchild_dmother * child.grad
    def _zero_grad():
      father.grad = 0
      mother.grad = 0
    child._backward = _backward
    child._zero_grad = _zero_grad
    return child
  
  def __mul__(self, other: Value):
    father = self; mother = other
    child = Value(father.data * mother.data, father, mother)
    def _backward():
      dchild_dfather = mother.data
      dchild_dmother = father.data
      father.grad += dchild_dfather * child.grad
      mother.grad += dchild_dmother * child.grad
    def _zero_grad():
      father.grad = 0
      mother.grad = 0
    child._backward = _backward
    child._zero_grad = _zero_grad
    return child
  
  def __neg__(self):
    father = self
    child = Value(-father.data, father)
    def _backward():
      dchild_dfather = -1
      father.grad += dchild_dfather * child.grad
    def _zero_grad():
      father.grad = 0
    child._backward = _backward
    child._zero_grad = _zero_grad
    return child

  def backward(self): # defines this node to be the target. derivative of all nodes in graph wrt to this node. 
    # edge cases to consider: 
    # 1. _backward() should only be called when a node's gradient wrt to target is fully computed. i.e. when all it's children have had their _backward's called. EX: a -> b -> c & a -> c. target order is c, b, a but _backward could be called on a before b. b is a's kid. the gradient of a wrt to c is not fully computed.
    # 2. _backward should only be called once per node. need some kind of cycle detection. a node could have multiple children, so it could be visited multiple times. just tracking count won't be good enough. need to consider if there's a duplicate visit from the same path. 

    # could do reverse dfs. root is last to be marked as traversed. if reversing order, it will be called before any of it's children can be called. dfs traverses all children before marking parent as traversed. reversing this means that parent is marked as traversed before any of it's children. this behavior is desirable for topological sorting. from perspective of root, but can be applied to any node b.c. recursive. 
    # a b c -> c b a. regardless of the path that is taken, b's child, a will always be traversed before b, even if a isn't reached through b. reversing order, a's parents will always be traversed before a.
    # adding a traversed variable, so you don't have to check the list. having extra indicator variables save so much computation. 

    dfs_order = []
    def dfs(node: Value):
      if node.father:
        dfs(node.father)
      if node.mother:
        dfs(node.mother)
      dfs_order.append(node)
    dfs(self)

    self.grad = 1 # derivative wrt to self is 1
    topological_order = dfs_order[::-1]

    print(dfs_order)
    print(topological_order)

    for node in topological_order:
      node._backward()

  def zero_grad(self): # zeros out gradients of ancestors
    def dfs(node: Value):
      if node.father:
        dfs(node.father)
      if node.mother:
        dfs(node.mother)
      node._zero_grad()

In [3]:
a = Value(2)
b = Value(3)
c = a * -b
a, b, c

(Value(data=2), Value(data=3), Value(data=-6))

In [4]:
c.grad = 1
c.backward()
a.grad, b.grad, c.grad # only goes one step back. as expected. -b node got calculated, but b node did not. 

[Value(data=2), Value(data=3), Value(data=-3), Value(data=-6)]
[Value(data=-6), Value(data=-3), Value(data=3), Value(data=2)]


(-3, -2, 1)