In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [2]:
class Node:
    """A value in a small reverse-mode autodiff graph.

    Attributes:
        value: float32 array holding the node's forward value.
        grad: float32 array accumulating the gradient of the differentiated output
            w.r.t. ``value``; it keeps the same shape as ``value``.
        inputs: list of ``(input_node, grad_fn)`` pairs. ``grad_fn`` maps the gradient
            w.r.t. this node to the gradient contribution for ``input_node``.
    """

    def __init__(self, value):
        self.value = np.array(value, dtype=np.float32)
        self.grad = np.zeros_like(self.value, np.float32)
        self.inputs = []

    def zero_grad(self):
        """Reset ``grad`` to zeros on this node and, recursively, on every input node."""
        self.grad = np.zeros_like(self.value)
        for source, _ in self.inputs:
            source.zero_grad()

    def backward(self, upstream_grad=None):
        """Accumulate ``upstream_grad`` into ``grad`` and propagate it to the inputs.

        Args:
            upstream_grad: gradient of the differentiated output w.r.t. this node.
                ``None`` means this node *is* that output, so the gradient is all ones.
                If this node was broadcast by a forward op (e.g. a (1, k) bias added to an
                (n, k) batch), the incoming gradient has the broadcast shape. In that case
                it is summed back down to ``value.shape`` so ``grad`` always matches
                ``value``.
        """
        if upstream_grad is None:
            upstream_grad = np.ones_like(self.value, np.float32)
        else:
            upstream_grad = self._sum_to_shape(upstream_grad, self.value.shape)

        self.grad = self.grad + upstream_grad
        for source, grad_fn in self.inputs:
            source.backward(grad_fn(upstream_grad))

    @staticmethod
    def _sum_to_shape(grad, shape):
        """Undo NumPy broadcasting: sum ``grad`` over the axes along which ``shape`` was expanded."""
        grad = np.asarray(grad)
        extra = grad.ndim - len(shape)
        if extra < 0:
            # Smaller than the target; NumPy broadcasts it when it is accumulated.
            return grad
        if extra:
            # Leading axes that broadcasting prepended.
            grad = grad.sum(axis=tuple(range(extra)))
        # Size-1 axes that broadcasting stretched.
        stretched = tuple(
            axis for axis, size in enumerate(shape) if size == 1 and grad.shape[axis] != 1
        )
        if stretched:
            grad = grad.sum(axis=stretched, keepdims=True)
        return grad

    @staticmethod
    def _as_node(other):
        """Return ``other`` unchanged if it is a Node, else wrap it as a constant Node."""
        return other if isinstance(other, Node) else Node(other)

    def __add__(self, other):
        return Add(self, self._as_node(other))

    def __sub__(self, other):
        return Subtract(self, self._as_node(other))

    def __mul__(self, other):
        return Multiply(self, self._as_node(other))

    def __matmul__(self, other):
        return MatMultiply(self, self._as_node(other))


class Add(Node):
    """Elementwise sum node: value = a + b."""

    def __init__(self, a, b):
        super().__init__(a.value + b.value)

        def passthrough(upstream):
            # d(a + b)/da = d(a + b)/db = 1, so the upstream gradient flows through unchanged.
            # If an operand was broadcast, this gradient has the output's shape, not the operand's.
            return upstream

        self.inputs = [(a, passthrough), (b, passthrough)]


class Subtract(Node):
    """Elementwise difference node: value = a - b."""

    def __init__(self, a, b):
        super().__init__(a.value - b.value)
        # d(a - b)/da = +1 and d(a - b)/db = -1: the subtrahend receives the negated gradient.
        self.inputs = [
            (a, lambda upstream: upstream),
            (b, lambda upstream: -upstream),
        ]


class Multiply(Node):
    """Elementwise (Hadamard) product node: value = a * b."""

    def __init__(self, a, b):
        super().__init__(a.value * b.value)

        # Product rule: each factor's gradient is the upstream gradient scaled by the
        # other factor. The other factor's ``.value`` is read when backward runs.
        def grad_wrt_a(upstream):
            return upstream * b.value

        def grad_wrt_b(upstream):
            return upstream * a.value

        self.inputs = [(a, grad_wrt_a), (b, grad_wrt_b)]


class MatMultiply(Node):
    """Matrix product node: value = a @ b."""

    def __init__(self, a, b):
        super().__init__(a.value @ b.value)

        # NOTE(review): the transposes below assume 2-D operands; for 1-D arrays `.T`
        # is a no-op and these formulas do not give the right gradient.
        def grad_wrt_a(upstream):
            # dL/dA = dL/dC @ B^T
            return upstream @ b.value.T

        def grad_wrt_b(upstream):
            # dL/dB = A^T @ dL/dC
            return a.value.T @ upstream

        self.inputs = [(a, grad_wrt_a), (b, grad_wrt_b)]


class ReLU(Node):
    """Rectified linear unit node: value = max(0, x), elementwise."""

    def __init__(self, x):
        super().__init__(np.maximum(0, x.value))

        def grad_wrt_x(upstream):
            # The gradient passes through where the input is strictly positive and is
            # zeroed elsewhere (including at exactly 0).
            active = (x.value > 0).astype(np.float32)
            return upstream * active

        self.inputs = [(x, grad_wrt_x)]


class Transpose(Node):
    """Transpose node: value = x.T."""

    def __init__(self, x):
        super().__init__(x.value.T)

        def grad_wrt_x(upstream):
            # `.T` reverses the axes, and reversing twice restores them,
            # so the gradient is simply transposed back.
            return upstream.T

        self.inputs = [(x, grad_wrt_x)]


def mse_loss(y_pred, y_true):
    """Sum of squared errors between two column vectors, returned as a (1, 1) Node.

    Despite the name, the result is NOT divided by the number of samples: it is
    ``diff.T @ diff`` = sum_i (y_pred_i - y_true_i) ** 2. Callers wanting the mean
    must divide by the batch size themselves.

    Args:
        y_pred: Node of shape (n, 1) with predictions.
        y_true: Node of shape (n, 1) with targets.

    Returns:
        A Node of shape (1, 1) holding the sum of squared errors.

    Raises:
        ValueError: if ``y_pred - y_true`` is not a 2-D column vector. For an (n, k)
            difference with k > 1, ``diff.T @ diff`` would be a k x k matrix, not a
            scalar loss. For 1-D inputs, the backward matmul cannot be formed.
    """
    diff = y_pred - y_true
    if diff.value.ndim != 2 or diff.value.shape[1] != 1:
        raise ValueError(
            f"mse_loss expects (n, 1) column vectors, got a difference of shape {diff.value.shape}"
        )
    return Transpose(diff) @ diff

In [59]:
# A seeded Generator makes data, initialisation and mini-batch sampling reproducible.
seed = 0
rng = np.random.default_rng(seed)

n_samples = 100000

X = Node(rng.standard_normal((n_samples, 1)) * 2)
y = Node(np.sin(X.value))


hidden_layer_dim = 10
# mse_loss returns a *sum* over the batch (not a mean), which is presumably why the
# learning rate is this small.
learning_rate = 0.0001
epochs = 10001  # 10001 so that epoch 10000 is logged and snapshotted
batch_size = 100
log_every = 1000
snapshot_every = 100  # the animation cell labels frames assuming 100 epochs per snapshot

weights_1 = Node(rng.standard_normal((X.value.shape[1], hidden_layer_dim)))
biases_1 = Node(rng.standard_normal((1, hidden_layer_dim)))
weights_2 = Node(rng.standard_normal((hidden_layer_dim, 1)))
biases_2 = Node(rng.standard_normal((1, 1)))


def forward(inputs):
    """Two-layer MLP: ReLU(inputs @ weights_1 + biases_1) @ weights_2 + biases_2."""
    hidden = ReLU(inputs @ weights_1 + biases_1)
    return hidden @ weights_2 + biases_2


y_preds = []  # predictions on the full X, one snapshot every `snapshot_every` epochs
for epoch in range(epochs):
    # Generator.choice draws a small sample without replacement efficiently; legacy
    # np.random.choice(replace=False) permutes the whole population each call.
    batch_idx = rng.choice(n_samples, batch_size, replace=False)
    loss = mse_loss(forward(Node(X.value[batch_idx])), Node(y.value[batch_idx]))
    loss.zero_grad()
    loss.backward()
    if epoch % log_every == 0:
        # loss.value is a (1, 1) array holding the batch SSE; report the per-sample mean.
        print(f"Epoch {epoch}: {loss.value.item() / batch_size:.6f}")

    if epoch % snapshot_every == 0:
        y_preds.append(forward(X).value.flatten())

    for weights in (weights_1, weights_2):
        weights.value -= learning_rate * weights.grad
    # `.sum(axis=0)` collapses a batch axis if the bias gradient carries one;
    # for a (1, k) gradient it leaves the values unchanged.
    for biases in (biases_1, biases_2):
        biases.value -= learning_rate * biases.grad.sum(axis=0)


Epoch 0: [[10.633076]]
Epoch 1000: [[0.02033566]]
Epoch 2000: [[0.00827521]]
Epoch 3000: [[0.00688277]]
Epoch 4000: [[0.0065132]]
Epoch 5000: [[0.00709296]]
Epoch 6000: [[0.0159478]]
Epoch 7000: [[0.00676529]]
Epoch 8000: [[0.00521723]]
Epoch 9000: [[0.00895854]]
Epoch 10000: [[0.00695995]]


In [60]:
from IPython import display

# Explicit figure/axes so the animation callback never depends on pyplot's "current figure".
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Predictions")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.grid()
ax.set_xlim((-4, 4))
ax.set_ylim((-1.5, 1.5))

x_flat = X.value.flatten()

ax.scatter(x_flat, y.value.flatten(), s=0.5, color="C0", label="true y")
line_plotted = ax.scatter(x_flat, y_preds[0], s=0.5, color="C1", label="pred y")
# Points are tiny (s=0.5); enlarge them in the legend only so the colours are visible.
ax.legend(markerscale=10)


def animation_function(frame):
    """Move the prediction scatter to snapshot ``frame`` and label it with its epoch."""
    line_plotted.set_offsets(np.column_stack((x_flat, y_preds[frame])))
    # The training cell records one prediction snapshot every 100 epochs.
    ax.set_title(f"Epoch {frame * 100}")
    return (line_plotted,)


anim_created = FuncAnimation(fig, animation_function, frames=len(y_preds), interval=100)
video = anim_created.to_html5_video()
html = display.HTML(video)
display.display(html)

# Close this specific figure so the static last frame is not rendered again below the video.
plt.close(fig)
