In [7]:
from graphviz import Digraph
from autograd import Value
from MLP import Neuron, Layer, MLP

"""
Graph-visualization helpers adapted from Andrej Karpathy's micrograd (https://github.com/karpathy/micrograd).
"""
# Function to trace the computational graph
def trace(root):
    """Collect every node and edge of the computational graph ending at ``root``.

    Walks backwards from ``root`` through each node's ``leftChild`` and
    ``rightChild`` links. An explicit stack is used instead of recursion so
    that deep graphs (for example a long chain of sequential operations)
    cannot exceed Python's recursion limit.

    Args:
        root: The node to start from, or ``None`` (which yields empty sets).

    Returns:
        A ``(nodes, edges)`` tuple. ``nodes`` is the set of every reachable
        node, including ``root``. ``edges`` is a set of ``(child, parent)``
        pairs, one per non-``None`` child link.
    """
    nodes, edges = set(), set()
    stack = [] if root is None else [root]
    while stack:
        v = stack.pop()
        # A node can be reachable along several paths; expand it only once.
        if v in nodes:
            continue
        nodes.add(v)
        for child in (v.leftChild, v.rightChild):
            if child is not None:
                edges.add((child, v))
                stack.append(child)
    return nodes, edges

# Function to draw the computational graph
def draw_dot(root, format='svg', rankdir='LR'):
    """Build a Graphviz digraph of the computational graph behind ``root``.

    Every value node becomes a record box showing its data and gradient.
    A node that has an ``operator`` also gets a companion node labelled with
    that operator, wired into its value box. Edges coming from operands then
    point at the operator node instead of the value box.

    Args:
        root: The output node whose ancestry is drawn.
        format: Graphviz output format (png | svg | ...).
        rankdir: Layout direction, 'LR' (left to right) or 'TB' (top to bottom).

    Returns:
        The populated ``Digraph``.
    """
    assert rankdir in ['LR', 'TB']
    all_nodes, all_edges = trace(root)
    graph = Digraph(format=format, graph_attr={'rankdir': rankdir})

    def key(node):
        # Graphviz identifiers are strings; object identity keeps them unique.
        return str(id(node))

    # Value boxes, plus an operator node feeding each box that has one.
    for node in all_nodes:
        graph.node(name=key(node),
                   label="{ data %.4f | grad %.4f }" % (node.value, node.gradient),
                   shape='record')
        if node.operator is None:
            continue
        op_name = key(node) + node.operator
        graph.node(name=op_name, label=node.operator)
        graph.edge(op_name, key(node))

    # Operand -> consumer links, routed through the consumer's operator node.
    for child, parent in all_edges:
        target = key(parent) if parent.operator is None else key(parent) + parent.operator
        graph.edge(key(child), target)

    return graph

In [8]:
# Build a small MLP and train it on a single input vector.
# NOTE(review): MLP(5, [10, 5, 5]) presumably means 5 inputs followed by layers
# of 10, 5 and 5 neurons; the printed summary below reports "Total Layers: 3".
# Confirm against MLP.__init__.
train_data = [1, 1, 1, 1, 1]
Neural_Network = MLP(5, [10, 5, 5])
print(Neural_Network)

# NOTE(review): `output` is computed BEFORE training, and the next cell draws
# the graph of `output[0]`. If `train` runs its own forward/backward pass and
# updates weights in place, that drawing will mix pre-training activations
# with post-training weights/gradients. Confirm this is the intended snapshot.
output = Neural_Network.forward(train_data)
# NOTE(review): no target values are passed; the meaning of the positional `1`
# (epoch count? target?) is not visible here -- confirm against MLP.train.
# The log below prints a single "Epoch: 0" line.
# NOTE(review): if weights are randomly initialised, no seed is set here, so
# the reported loss will differ between runs.
Neural_Network.train(train_data, 1, lr = 1)

Total Layers: 3
Epoch: 0 
 Loss: Value(0.2980495458631124)


In [None]:

# Visualize the computational graph of the first output neuron.
dot = draw_dot(output[0])
# Save the graph as mlp_graph.svg. draw_dot already builds the graph with
# format='svg'; passing it explicitly here keeps the file type in sync with
# this comment (the previous call wrote a PDF instead). render() returns the
# path of the rendered file, which the notebook shows as the cell result.
dot.render('mlp_graph', format='svg')



'mlp_graph.svg'