In [1]:
class Tensor:
    """Scalar node of a manually wired computation graph.

    Attributes:
        value: forward value held by this node.
        grad_fn: callable mapping the gradients of ``children`` (passed
            positionally, in list order) to this node's gradient, or ``None``.
        grad: last gradient assigned to this node (initially the ``grad`` arg, 0).
        requires_grad: flag read by ``backward_from`` to decide which
            ancestors to visit.
        parents: nodes this one was computed from.
        children: nodes computed from this one.
    """

    def __init__(self, value, grad_fn=None, grad=0, requires_grad=False, parents=None, children=None):
        self.value, self.grad_fn, self.grad = value, grad_fn, grad
        self.requires_grad = requires_grad
        # `or []` gives each instance its own fresh list when None (or an empty
        # list) is passed, so Tensors never share a mutable default.
        self.parents = parents or []
        self.children = children or []

    def backward(self, *output_grads):
        """Store ``grad_fn(*output_grads)`` as this node's gradient.

        ``output_grads`` are the children's gradients in ``children`` order;
        ``grad_fn`` must be set (calling with ``None`` raises ``TypeError``).
        """
        local_grad = self.grad_fn(*output_grads)
        self.grad = local_grad

In [2]:
def backward_from(start_node):
    """Run reverse-mode backpropagation from ``start_node`` over the hand-wired graph.

    ``start_node.grad`` is seeded with ``1.`` (the derivative of a value with
    respect to itself), whether or not it has a ``grad_fn``. Every ancestor with
    ``requires_grad`` true that is reachable through ``parents`` links then gets
    ``node.grad = node.grad_fn(*child_grads)``, with ``child_grads`` in the order
    of ``node.children``.

    A node is evaluated once all of its children that are ancestors of
    ``start_node`` have fresh gradients. Children that are not ancestors of
    ``start_node`` cannot influence it, so they contribute ``0.``. Nodes with
    ``requires_grad`` false are never evaluated, so propagation does not pass
    through them to their parents.

    Args:
        start_node: the node to differentiate (typically the loss).

    Raises:
        ValueError: if an ancestor that requires a gradient has no ``grad_fn``.
    """
    # Pass 1: every ancestor of start_node. Only these can influence it, so a
    # child outside this set has zero gradient for this call.
    ancestors = {start_node}
    stack = [start_node]
    while stack:
        for parent in stack.pop().parents:
            if parent not in ancestors:
                ancestors.add(parent)
                stack.append(parent)

    # Pass 2: propagate in reverse topological order, starting from the seed.
    start_node.grad = 1.
    done = {start_node}
    ready = [start_node]
    while ready:
        node = ready.pop()
        for parent in node.parents:
            if not parent.requires_grad or parent in done:
                continue
            # Wait until every child that feeds start_node has reported back.
            if not all(c in done for c in parent.children if c in ancestors):
                continue
            if parent.grad_fn is None:
                raise ValueError(
                    "backward_from: node with value %r requires a gradient "
                    "but has no grad_fn" % (parent.value,)
                )
            child_grads = tuple(c.grad if c in ancestors else 0. for c in parent.children)
            parent.grad = parent.grad_fn(*child_grads)
            done.add(parent)
            ready.append(parent)

In [3]:
# Leaves of y_hat = m*x + b: the parameters m and b require gradients,
# while the input x and the target y are constants.
m, b = (Tensor(v, requires_grad=True) for v in (2, -1))
x, y = Tensor(1), Tensor(4)

In [4]:
# v_1 = m * x. Only m requires a gradient, and d(v_1)/dm = x.
v_1 = Tensor(m.value * x.value, requires_grad=True)
m.children.append(v_1)
x.children.append(v_1)
v_1.parents.append(m)
v_1.parents.append(x)
# Freeze x.value into the closure (keyword-only default) instead of reading the
# global name `x` when backward runs, so rebinding `x` in a later cell cannot
# silently change this derivative. Keyword-only keeps the positional arity at 1.
m.grad_fn = lambda ddv, *, x_value=x.value: ddv * x_value

In [5]:
# y_hat = v_1 + b: addition routes the upstream gradient to each operand unchanged.
y_hat = Tensor(v_1.value + b.value, requires_grad=True)
v_1.children.append(y_hat)
b.children.append(y_hat)
y_hat.parents.extend([v_1, b])
v_1.grad_fn = lambda upstream: upstream * 1
b.grad_fn = lambda upstream: upstream * 1

In [6]:
# v_2 = y_hat - y. Only y_hat requires a gradient, and d(v_2)/d(y_hat) = 1.
v_2 = Tensor(y_hat.value - y.value, requires_grad=True)
y_hat.children.append(v_2)
y.children.append(v_2)
v_2.parents.extend([y_hat, y])
y_hat.grad_fn = lambda upstream: upstream

In [7]:
# loss = v_2 ** 2, so d(loss)/d(v_2) = 2 * v_2.
loss = Tensor(v_2.value ** 2, requires_grad=True)
v_2.children.append(loss)
loss.parents.append(v_2)
# Capture v_2.value now instead of looking up the global name `v_2` at backward
# time (a later cell rebinding `v_2` would otherwise change this derivative).
# Keyword-only default keeps the positional arity at 1.
v_2.grad_fn = lambda ddl, *, v_2_value=v_2.value: ddl * 2 * v_2_value

In [8]:
# Forward value: (m*x + b - y)**2 = (2*1 + (-1) - 4)**2 = 9
loss.value

9

In [38]:
# Manual reverse pass, output to input; the seed d(loss)/d(loss) is the integer 1.
v_2.grad = v_2.grad_fn(1)
y_hat.grad = y_hat.grad_fn(v_2.grad)
# y_hat = v_1 + b, so both addends are driven by y_hat's gradient.
v_1.grad, b.grad = v_1.grad_fn(y_hat.grad), b.grad_fn(y_hat.grad)
m.grad = m.grad_fn(v_1.grad)

In [39]:
# Hand-derived check: d(loss)/dm = 2 * (m*x + b - y) * x = 2 * (-3) * 1 = -6
m.grad

-6

In [9]:
# Automatic pass over the same graph; running top to bottom, this replaces the
# grads written by the manual pass above.
backward_from(loss)

In [10]:
# Same value as the manual pass; it is a float because backward_from seeds the loss gradient with 1.
m.grad

-6.0