In [3]:
from mlp import MLP



In [4]:
from graphviz import Digraph

def trace(root):
  """Collect every node and edge reachable from ``root``.

  Starting at ``root``, follows each node's ``_children`` attribute and
  records the nodes visited plus one ``(child, parent)`` pair per link.

  The walk uses an explicit stack rather than recursion, so graphs deeper
  than the interpreter's recursion limit (e.g. a long chain of operations)
  do not raise ``RecursionError``.

  Args:
    root: the node to start from; it and everything reachable through
      ``_children`` must be hashable, since they are stored in sets.

  Returns:
    A ``(nodes, edges)`` tuple of sets: ``nodes`` holds every visited node
    and ``edges`` holds ``(child, parent)`` tuples.
  """
  nodes, edges = set(), set()
  pending = [root]
  while pending:
    v = pending.pop()
    if v in nodes:
      # already expanded via another path; its edges are recorded
      continue
    nodes.add(v)
    for child in v._children:
      edges.add((child, v))
      pending.append(child)
  return nodes, edges

def draw_dot(root):
  """Render the graph reachable from ``root`` as a graphviz ``Digraph``.

  Each value found by ``trace`` becomes a record-shaped node labelled with
  its ``name``, ``data`` and ``grad``. A value with a truthy ``_op`` also
  gets a separate node labelled with that op, wired into the value node.
  Every (child, parent) edge is drawn from the child to the parent's op node.

  Returns:
    The populated ``Digraph`` (SVG format, laid out left to right).
  """
  graph = Digraph(format='svg', graph_attr={'rankdir': 'LR'})  # LR = left to right

  values, links = trace(root)
  for value in values:
    value_id = str(id(value))
    # rectangular ('record') box describing the value itself
    record = "{ %s | data %.4f | grad %.4f }" % (value.name, value.data, value.grad)
    graph.node(name=value_id, label=record, shape='record')
    if not value._op:
      # not produced by an operation: no op node to add
      continue
    # the value came out of an operation: add the op node and feed it in
    op_id = value_id + value._op
    graph.node(name=op_id, label=value._op)
    graph.edge(op_id, value_id)

  for src, dst in links:
    # children point at the op node of the value they produced
    graph.edge(str(id(src)), str(id(dst)) + dst._op)

  return graph


In [5]:
# Four input rows, three features per row.
xs = [[2.0, 3.0, -1.0],
      [3.0, -1.0, 0.5],
      [0.5, 1.0, 1.0],
      [1.0, 1.0, -1.0]]
# Desired target for each row of xs, in the same order.
ys = [1.0, -1.0, -1.0, 1.0]

In [6]:
# 3 inputs feeding layers of 3, 3 and 1 neurons (see the repr printed below).
model = MLP(3, [3, 3, 1])
model

MLP:
0- Layer of [Activation = tanh Neuron(inputs: 3),Activation = tanh Neuron(inputs: 3),Activation = tanh Neuron(inputs: 3)]
1- Layer of [Activation = tanh Neuron(inputs: 3),Activation = tanh Neuron(inputs: 3),Activation = tanh Neuron(inputs: 3)]
2- Layer of [Activation = tanh Neuron(inputs: 3)]

In [9]:
# Evaluate the model on each input row of xs, keeping the outputs in order.
ypred = list(map(model, xs))
ypred

[Scalar(data=0.11140266601210351, grad=0),
 Scalar(data=-0.03737601863388727, grad=0),
 Scalar(data=0.13907986726606306, grad=0),
 Scalar(data=0.062294336776679116, grad=0)]

In [11]:
# Sum of squared differences between each prediction and its target.
loss = sum((pred - target)**2 for pred, target in zip(ypred, ys))
loss

Scalar(data=3.8930450063235034, grad=0)

In [13]:
# Display the list returned by model.parameters(); per the output below,
# each entry is a Scalar showing its data and grad.
model.parameters()

[Scalar(data=0.2073793417740526, grad=0),
 Scalar(data=0.11448718051788243, grad=0),
 Scalar(data=0.1803923521705857, grad=0),
 Scalar(data=0, grad=0),
 Scalar(data=-0.9507696730371364, grad=0),
 Scalar(data=-0.09362543039259519, grad=0),
 Scalar(data=0.5672261758981587, grad=0),
 Scalar(data=0, grad=0),
 Scalar(data=-0.6089887023634686, grad=0),
 Scalar(data=-0.9593376924025712, grad=0),
 Scalar(data=0.9794838998523143, grad=0),
 Scalar(data=0, grad=0),
 Scalar(data=0.6917593858863644, grad=0),
 Scalar(data=0.43072767241653365, grad=0),
 Scalar(data=0.7403801441384537, grad=0),
 Scalar(data=0, grad=0),
 Scalar(data=-0.18553489141073864, grad=0),
 Scalar(data=-0.3132570346633301, grad=0),
 Scalar(data=0.37883549170815267, grad=0),
 Scalar(data=0, grad=0),
 Scalar(data=0.014070907211675365, grad=0),
 Scalar(data=0.24530432622780207, grad=0),
 Scalar(data=0.4593752556675883, grad=0),
 Scalar(data=0, grad=0),
 Scalar(data=0.11225382392770444, grad=0),
 Scalar(data=-0.6857066348917109, gra