In [15]:
# Install first if needed: %pip install pygraphviz  (also requires the Graphviz system package, which provides `dot`)
import pygraphviz as pgv

In [16]:
from micrograd.engine import Value

In [17]:
def trace(root):
    """
    Collect every node and edge of the computation graph that ends at `root`.

    Each node is expected to expose `_prev`, an iterable of the nodes it was
    computed from. The walk uses an explicit stack instead of recursion so
    that very deep graphs do not hit Python's recursion limit.

    Returns:
        (nodes, edges): `nodes` is the set of all nodes reachable from `root`
        through `_prev` links (including `root` itself); `edges` is a set of
        `(child, parent)` pairs, one per `_prev` link.
    """
    nodes, edges = set(), set()
    stack = [root]

    while stack:
        v = stack.pop()
        if v in nodes:
            # Already expanded via another path; its edges are recorded.
            continue
        nodes.add(v)
        for child in v._prev:
            edges.add((child, v))
            stack.append(child)

    return nodes, edges

def draw_graph(root, format='svg', rankdir='LR', filename=None):
    """
    Render the computation graph ending at `root` to an image file.

    format: png | svg | ...
    rankdir: TB (top to bottom graph) | LR (left to right)
    filename: output path; defaults to "graph.<format>" in the working
        directory. Pass distinct names to keep several graphs instead of
        overwriting the same file on every call.

    Each value becomes a record-shaped node labelled with its data and grad.
    A value with a non-empty `_op` additionally gets an op node pointing at
    it, and its operands are connected to that op node.

    The path of the written file is printed; nothing is returned.
    """
    assert rankdir in ['LR', 'TB']
    nodes, edges = trace(root)

    G = pgv.AGraph(directed=True, rankdir=rankdir)

    for n in nodes:
        # id() gives a node name that is unique among the live graph objects
        uid = str(id(n))
        label = "{ data %.4f | grad %.4f }" % (n.data, n.grad)
        G.add_node(uid, label=label, shape='record')
        if n._op:
            # Separate node for the operation that produced this value
            op_uid = uid + n._op
            G.add_node(op_uid, label=n._op)
            G.add_edge(op_uid, uid)

    for n1, n2 in edges:
        # Operands connect to the op node of the value they feed into
        G.add_edge(str(id(n1)), str(id(n2)) + n2._op)

    output_filename = f"graph.{format}" if filename is None else filename
    G.draw(output_filename, format=format, prog='dot')
    print(f"Graph saved as {output_filename}")



In [18]:
# A very simple example: build y = relu(x * 2 + 1) at x = 1.0 and draw
# the computation graph that produced it.
x = Value(1.0)
y = (x * 2 + 1).relu()
# Run backpropagation from y before drawing; draw_graph prints each node's
# .grad (presumably filled in by backward() -- see micrograd.engine).
y.backward()
# Writes graph.svg to the working directory and prints a confirmation line.
draw_graph(y)

Graph saved as graph.svg


In [21]:
# A simple 2D neuron: feed a two-element input through a single
# nn.Neuron and draw the resulting computation graph.
import random
from micrograd import nn

# Fixed seed so the run is repeatable (presumably Neuron draws its initial
# parameters from `random` -- confirm in micrograd.nn).
random.seed(1337)
n = nn.Neuron(2)
x = [Value(1.0), Value(-2.0)]
y = n(x)
# Run backpropagation from y before drawing; draw_graph prints each node's .grad.
y.backward()

# Default output name is graph.svg, so this overwrites any earlier graph.svg.
draw_graph(y)

Graph saved as graph.svg
