In [None]:
class Node:
    """
    A scalar value in a computation graph with reverse-mode differentiation.

    Each arithmetic operation (`+`, `*`, `relu`) produces a new Node that
    remembers the operand nodes it was built from (`_parents`) and the
    operation that produced it (`_op`). Calling `calculate_derivatives` on
    the final node fills in `dloss_dvalue` on every node that contributed
    to it.
    """

    def __init__(self, value, _parents=(), _op=""):
        """
        Args:
            value: the scalar held by this node.
            _parents: the operand nodes this node was computed from
                (empty for a leaf node).
            _op: name of the operation that produced this node
                ("add", "mul", "relu", or "" for a leaf node).
        """
        self.value = value
        # Derivative of the root ("loss") node's value with respect to this
        # node's value; populated by calculate_derivatives.
        self.dloss_dvalue = 0
        self._parents = _parents
        self._op = _op

    def __repr__(self):
        return f"Node(value={self.value})"

    @staticmethod
    def _as_node(operand):
        """Return `operand` unchanged if it is a Node, else wrap it in a leaf Node."""
        return operand if isinstance(operand, Node) else Node(value=operand)

    def __add__(self, other):
        """Return a new node holding `self + other`; `other` may be a Node or a plain value."""
        other = Node._as_node(other)
        return Node(
            value=self.value + other.value,
            _parents=(self, other),
            _op="add",
        )

    # Lets a plain value appear on the left-hand side, e.g. `1 + node`.
    __radd__ = __add__

    def __mul__(self, other):
        """Return a new node holding `self * other`; `other` may be a Node or a plain value."""
        other = Node._as_node(other)
        return Node(
            value=self.value * other.value,
            _parents=(self, other),
            _op="mul",
        )

    # Lets a plain value appear on the left-hand side, e.g. `2 * node`.
    __rmul__ = __mul__

    def relu(self):
        """Return a new node holding `max(0, self.value)`."""
        return Node(
            value=max(0, self.value),
            _parents=(self,),
            _op="relu",
        )

    @staticmethod
    def update_parent_derivatives(node):
        """
        Propagate the loss gradient of the current child node
        to the loss gradient of each of its parents (chain rule).

        Gradients are accumulated with `+=` because one parent can feed
        several children (or the same child twice, as in `x * x`).

        Raises:
            ValueError: if the node has parents but its `_op` is not one of
                "add", "mul" or "relu" (its gradient could not be propagated).
        """
        if not node._parents:
            # Leaf node: nothing to propagate to.
            return

        if node._op == "add":
            parent1, parent2 = node._parents
            parent1.dloss_dvalue += node.dloss_dvalue
            parent2.dloss_dvalue += node.dloss_dvalue

        elif node._op == "mul":
            parent1, parent2 = node._parents
            parent1.dloss_dvalue += node.dloss_dvalue * parent2.value
            parent2.dloss_dvalue += node.dloss_dvalue * parent1.value

        elif node._op == "relu":
            (parent,) = node._parents
            # The output is positive exactly when the input was positive.
            dnodevalue_dparentvalue = 1.0 if node.value > 0 else 0.0
            parent.dloss_dvalue += node.dloss_dvalue * dnodevalue_dparentvalue

        else:
            raise ValueError(f"Cannot propagate gradients through unknown op {node._op!r}")

    @staticmethod
    def build_visit_order(root_node):
        """
        Depth-first search of parents, appending the current node LAST,
        so every node appears after all of the nodes it depends on.

        Note that we build by OPERATION, not by LAYER
        (e.g. one layer might have an add and a ReLU operation - so, two sets of parents)

        Implemented with an explicit stack rather than recursion so that
        deep graphs do not hit the interpreter's recursion limit.

        Returns:
            list of every node reachable from `root_node` (each exactly once),
            ending with `root_node` itself.
        """
        visited = {root_node}
        visit_order = []
        # Each stack entry pairs a node with an iterator over the parents
        # that still have to be explored for it.
        stack = [(root_node, iter(root_node._parents))]

        while stack:
            node, remaining_parents = stack[-1]
            for parent in remaining_parents:
                if parent not in visited:
                    visited.add(parent)
                    stack.append((parent, iter(parent._parents)))
                    # Descend into this parent first; the iterator remembers
                    # where to resume once we come back to `node`.
                    break
            else:
                # All parents handled: emit the node itself.
                stack.pop()
                visit_order.append(node)

        return visit_order

    def calculate_derivatives(self):
        """
        Fill in `dloss_dvalue` for every node reachable from this (root) node.

        Not DFS - just a linear pass over an ordered list, from the root
        back towards the leaves.

        Gradients left over from any earlier call are cleared first, so
        calling this repeatedly yields the same result instead of
        accumulating stale values.
        """
        visit_order = Node.build_visit_order(self)

        for node in visit_order:
            node.dloss_dvalue = 0
        self.dloss_dvalue = 1  # assumes that we only call calculate_derivatives from the root node

        for node in reversed(visit_order):
            if node._parents:
                Node.update_parent_derivatives(node)
    

In [8]:
# Two leaf nodes and their sum; the bare expression displays the result's repr.
x, y = Node(value=1.0), Node(value=2.0)
z = x + y
z

Node(value=3.0)

In [9]:
# Back-propagate from z, then show d(z)/d(node) for each node in the graph.
z.calculate_derivatives()
print([n.dloss_dvalue for n in (x, y, z)])

[1, 1, 1]
