In [1]:
import matplotlib.pyplot as plt
import numpy as np
import math

In [2]:
from graphviz import Digraph

def trace(root):
    """Collect every node and edge of the graph that ends at ``root``.

    The graph is walked backwards through each node's ``_prev`` collection.
    An explicit stack is used instead of recursion so that very deep graphs
    (for example a long chain of additions) do not hit Python's recursion
    limit.

    Args:
        root: the output node; it and every node reachable from it must expose
            an iterable ``_prev`` attribute and be hashable.

    Returns:
        A ``(nodes, edges)`` tuple of sets. ``nodes`` holds every reachable
        node (including ``root``); ``edges`` holds ``(child, parent)`` pairs.
    """
    nodes, edges = set(), set()
    pending = [root]
    while pending:
        v = pending.pop()
        if v in nodes:
            # already expanded via another path through the graph
            continue
        nodes.add(v)
        for child in v._prev:
            edges.add((child, v))
            pending.append(child)
    return nodes, edges


def draw_dot(root):
    """Build a Graphviz ``Digraph`` picturing the graph that ends at ``root``.

    Each node returned by ``trace(root)`` becomes a record-shaped box showing
    its label, data and grad. A node with a non-empty ``_operation`` also gets
    a small operation node that feeds into it, and every edge is routed from
    the child to the operation node of its parent.

    Returns:
        The populated ``Digraph`` (PNG format, laid out left to right).
    """
    graph = Digraph(format="png", graph_attr={"rankdir":"LR"})  # LR = left to right
    value_nodes, links = trace(root)

    for value in value_nodes:
        node_id = str(id(value))
        # one rectangular ('record') box per value in the graph
        record = "{%s | data %.4f | grad %.4f}" % (value.label, value.data, value.grad)
        graph.node(name=node_id, label=record, shape='record')
        if not value._operation:
            # plain input value: nothing produced it, so no operation node
            continue
        # the value is the result of an operation: add an op node for it
        op_id = node_id + value._operation
        graph.node(name=op_id, label=value._operation)
        # and wire the op node into the value it produced
        graph.edge(op_id, node_id)

    for src, dst in links:
        # connect the child to the op node of its parent
        graph.edge(str(id(src)), str(id(dst)) + dst._operation)

    return graph

In [52]:
class Value:
    """A scalar that remembers how it was computed, for reverse-mode autodiff.

    Every arithmetic operation returns a new ``Value`` whose ``_prev`` holds
    its operands and whose ``_backward`` closure adds the local derivative
    (times the result's gradient) onto each operand's ``grad``. Calling
    ``backward()`` on the final result then fills in ``grad`` for every node
    in the graph.

    Attributes:
        data: the numeric value.
        grad: derivative of the node ``backward()`` was called on with
            respect to this value; 0.0 until ``backward()`` runs.
        label: optional display name.
        _prev: set of operand ``Value`` objects this one was computed from.
        _operation: short text naming the operation that produced this value
            ("" for a leaf).
    """

    def __init__(self, data, _children = (), _operation = "", label = ""):
        self.data = data
        self._prev = set(_children)
        self._operation = _operation
        self.label = label
        # Every _backward closure accumulates with ``+=``, so the gradient
        # has to start at zero; starting at 1.0 would offset every result.
        self.grad = 0.0
        self._backward = lambda : None

    def __repr__(self):
        return f"Value(data = {self.data})"

    def __add__(self, other):
        """Return ``self + other``; plain numbers are wrapped in a ``Value``."""
        other = other if isinstance(other, Value) else Value(other)
        result = Value(data = self.data + other.data, _children = (self, other), _operation = " + ")
        def _backward():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += 1.0 * result.grad
            other.grad += 1.0 * result.grad
        result._backward = _backward
        return result

    def __radd__(self, other):
        """Return ``other + self`` for a non-``Value`` left operand."""
        return self + other

    def __mul__(self, other):
        """Return ``self * other``; plain numbers are wrapped in a ``Value``."""
        other = other if isinstance(other, Value) else Value(other)
        result = Value(data = self.data * other.data, _children = (self, other), _operation = " x ")
        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * result.grad
            other.grad += self.data * result.grad
        result._backward = _backward
        return result

    def __rmul__(self, other):
        """Return ``other * self`` for a non-``Value`` left operand."""
        return self * other

    def __pow__(self, other):
        """Return ``self ** other`` for a constant int/float exponent.

        Raises:
            AssertionError: if ``other`` is not an int or float.
        """
        assert isinstance(other, (int, float)), "Only integer or float dtype numbers accepted"
        result = Value(data = self.data ** other, _children = (self, ), _operation = " pow ")
        def _backward():
            # d(a ** n)/da = n * a ** (n - 1)
            self.grad += other * self.data ** (other - 1) * result.grad
        result._backward = _backward
        return result

    def __neg__(self):
        """Return ``-self`` as a ``Value`` so the negation stays in the graph."""
        return self * -1.0

    def __sub__(self, other):
        """Return ``self - other``; ``other`` may be a ``Value`` or a number."""
        return self + (-other)

    def __rsub__(self, other):
        """Return ``other - self`` for a non-``Value`` left operand."""
        return (-self) + other

    def __truediv__(self, other):
        """Return ``self / other``, computed as ``self * other ** -1``."""
        return self * other ** -1.0

    def __rtruediv__(self, other):
        """Return ``other / self`` for a non-``Value`` left operand."""
        return (self ** -1.0) * other

    def exp(self):
        """Return ``e ** self``."""
        result = Value(data = math.exp(self.data), _children = (self, ), _operation = " e ")
        def _backward():
            # d(e ** a)/da = e ** a, which is the result itself
            self.grad += result.grad * result.data
        result._backward = _backward
        return result

    def tanh(self):
        """Return the hyperbolic tangent of ``self``."""
        # math.tanh saturates to +/-1 for large inputs, whereas the explicit
        # (e^2x - 1) / (e^2x + 1) form overflows in math.exp.
        tanh_n = math.tanh(self.data)
        result = Value(data = tanh_n, _children = (self, ), _operation = " tanh(x) ")
        def _backward():
            # d tanh(a)/da = 1 - tanh(a) ** 2
            self.grad += (1 - tanh_n ** 2) * result.grad
        result._backward = _backward
        return result

    def backward(self):
        """Back-propagate from this node, filling ``grad`` on every ancestor.

        Sets ``self.grad`` to 1.0 and then runs each node's ``_backward`` once,
        in reverse topological order, so a node's gradient is complete before
        it is pushed on to its operands. Gradients of the other nodes are not
        reset first, so repeated calls accumulate.
        """
        topo = []
        visited = set()
        # Iterative post-order DFS: a node is emitted only after all of its
        # operands, and each node is emitted exactly once even when shared.
        stack = [(self, False)]
        while stack:
            node, children_done = stack.pop()
            if children_done:
                topo.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            for child in node._prev:
                if child not in visited:
                    stack.append((child, False))
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()