In [1]:
import numpy as np

In [2]:
class Node:
    """A value in a computation graph together with its accumulated gradient.

    Attributes:
        value: forward result, stored as a float32 NumPy array.
        grad: gradient accumulated by ``backward``; starts as zeros shaped
            like ``value``.
        inputs: list of ``(parent, grad_fn)`` pairs. ``grad_fn`` maps the
            gradient arriving at this node to the gradient handed to
            ``parent``. A freshly constructed node has no inputs.
    """

    def __init__(self, value):
        self.value = np.array(value, dtype=np.float32)
        self.grad = np.zeros_like(self.value, np.float32)
        self.inputs = []

    @staticmethod
    def _as_node(operand):
        """Return ``operand`` unchanged if it is a Node, otherwise wrap it.

        Lets the arithmetic operators accept plain numbers, lists and arrays
        as the right-hand operand (``node * 2``) instead of failing on the
        missing ``.value`` attribute.
        """
        return operand if isinstance(operand, Node) else Node(operand)

    def zero_grad(self):
        """Reset ``grad`` to zeros on this node and, recursively, on its inputs."""
        self.grad = np.zeros_like(self.value)
        for parent, _ in self.inputs:
            parent.zero_grad()

    def backward(self, upstream_grad=None):
        """Accumulate ``upstream_grad`` here and propagate it to the inputs.

        Args:
            upstream_grad: gradient flowing into this node. When omitted, an
                array of ones shaped like ``value`` is used as the seed.

        Each input receives ``grad_fn(upstream_grad)`` through its own
        ``backward`` call, so a node reached along several paths is visited
        once per path and the contributions add up in its ``grad``.
        """
        if upstream_grad is None:
            upstream_grad = np.ones_like(self.value, np.float32)

        # Accumulate rather than overwrite: callers clear with zero_grad().
        self.grad = self.grad + upstream_grad
        for parent, grad_fn in self.inputs:
            parent.backward(grad_fn(upstream_grad))

    def __add__(self, other):
        return Add(self, self._as_node(other))

    def __sub__(self, other):
        return Subtract(self, self._as_node(other))

    def __mul__(self, other):
        return Multiply(self, self._as_node(other))

    def __matmul__(self, other):
        return MatMultiply(self, self._as_node(other))


class Add(Node):
    """Element-wise sum ``a + b`` (NumPy broadcasting applies to the values).

    The derivative of a sum with respect to either operand is 1, so each
    operand receives the upstream gradient. If an operand was broadcast in
    the forward pass (e.g. a ``(1, n)`` bias added to an ``(m, n)`` batch),
    the upstream gradient is summed over the broadcast axes so the gradient
    handed to that operand has the operand's own shape.
    """

    def __init__(self, a, b):
        super().__init__(a.value + b.value)
        # Bind the plain function so the closures below do not capture self.
        reduce_grad = self._sum_to_shape
        self.inputs = [
            (a, lambda upstream: reduce_grad(upstream, a.value.shape)),
            (b, lambda upstream: reduce_grad(upstream, b.value.shape)),
        ]

    @staticmethod
    def _sum_to_shape(grad, shape):
        """Sum ``grad`` over the axes along which ``shape`` was broadcast.

        Args:
            grad: upstream gradient (array-like).
            shape: shape of the operand the gradient is destined for.

        Returns:
            ``grad`` reduced to ``shape`` where broadcasting occurred. A
            gradient of lower rank than ``shape`` is returned unreduced.
        """
        grad = np.asarray(grad)
        shape = tuple(shape)
        if grad.shape == shape:
            return grad

        extra_dims = grad.ndim - len(shape)
        if extra_dims < 0:
            # Nothing to reduce; the caller's accumulation will broadcast it.
            return grad
        if extra_dims > 0:
            # Leading axes that broadcasting prepended to the operand.
            grad = grad.sum(axis=tuple(range(extra_dims)))

        # Axes where the operand had length 1 but the result did not.
        stretched = tuple(
            axis for axis, size in enumerate(shape)
            if size == 1 and grad.shape[axis] != 1
        )
        if stretched:
            grad = grad.sum(axis=stretched, keepdims=True)
        return grad


class Subtract(Node):
    """Element-wise difference ``a - b``.

    d(a - b)/da = +1 and d(a - b)/db = -1, so the upstream gradient is passed
    to ``a`` unchanged and to ``b`` negated.
    """

    def __init__(self, a, b):
        super().__init__(a.value - b.value)
        self.inputs = [
            (a, lambda upstream: upstream),
            # The subtrahend enters with a minus sign, so its gradient flips.
            (b, lambda upstream: np.negative(upstream)),
        ]


class Multiply(Node):
    """Element-wise product of two nodes.

    By the product rule, each operand receives the upstream gradient scaled
    by the other operand's value.
    """

    def __init__(self, a, b):
        product = a.value * b.value
        super().__init__(product)

        # Operand values are looked up when the gradient is requested,
        # not captured at construction time.
        def grad_wrt_a(upstream):
            return upstream * b.value

        def grad_wrt_b(upstream):
            return upstream * a.value

        self.inputs = [(a, grad_wrt_a), (b, grad_wrt_b)]


class MatMultiply(Node):
    """Matrix product ``a @ b``.

    For an upstream gradient ``G`` the left operand receives ``G @ b.T`` and
    the right operand receives ``a.T @ G``.
    """

    def __init__(self, a, b):
        result = a.value @ b.value
        super().__init__(result)

        def grad_wrt_left(upstream):
            return upstream @ b.value.T

        def grad_wrt_right(upstream):
            return a.value.T @ upstream

        self.inputs = [(a, grad_wrt_left), (b, grad_wrt_right)]


class ReLU(Node):
    """Rectified linear unit: ``max(0, x)`` applied element-wise."""

    def __init__(self, x):
        super().__init__(np.maximum(0, x.value))

        def grad_wrt_x(upstream):
            # Derivative is 1 where the input is strictly positive, else 0.
            active = (x.value > 0).astype(np.float32)
            return upstream * active

        self.inputs = [(x, grad_wrt_x)]

class Transpose(Node):
    """Transpose of a node's value; the gradient is transposed back."""

    def __init__(self, x):
        super().__init__(np.transpose(x.value))

        def grad_wrt_x(upstream):
            return upstream.T

        self.inputs = [(x, grad_wrt_x)]


def mse_loss(y_pred, y_true):
    """Squared-error loss between predictions and targets.

    Computes ``r.T @ r`` for the residual ``r = y_pred - y_true``. Despite the
    name, the result is NOT divided by the number of samples: it is the sum
    of squared errors, so callers wanting a mean must divide it themselves.

    Args:
        y_pred: node holding the predictions.
        y_true: node holding the targets.

    Returns:
        A node holding ``r.T @ r`` (a 1x1 matrix when the inputs are
        ``(n, 1)`` column vectors).
    """
    residual = y_pred - y_true
    return MatMultiply(Transpose(residual), residual)

In [36]:
# Synthetic regression task: learn y = cos(x0) + sin(x1) from random 2-D points.
n_samples = 100000

X = Node(np.random.randn(n_samples, 2))
y = Node(np.cos(X.value[:, 0:1]) + np.sin(X.value[:, 1:2]))

# Hyperparameters. Note that `epochs` counts mini-batch steps: every
# iteration below draws one random batch rather than sweeping the dataset.
hidden_layer_dim = 10
learning_rate = 0.0001
epochs = 10001
batch_size = 100

# Parameters of a 2 -> hidden_layer_dim -> 1 network with a ReLU in between.
n_features = X.value.shape[1]
weights_1 = Node(np.random.randn(n_features, hidden_layer_dim))
biases_1 = Node(np.random.randn(1, hidden_layer_dim))
weights_2 = Node(np.random.randn(hidden_layer_dim, 1))
biases_2 = Node(np.random.randn(1, 1))

for epoch in range(epochs):
    # Sample a mini-batch without replacement.
    random_indices = np.random.choice(n_samples, batch_size, replace=False)
    x_batch = Node(X.value[random_indices])
    y_batch = Node(y.value[random_indices])

    # Forward pass; a fresh graph is built on every step.
    out_1 = ReLU(x_batch @ weights_1 + biases_1)
    out_2 = out_1 @ weights_2 + biases_2
    loss = mse_loss(out_2, y_batch)

    # Clear gradients left over from the previous step, then backpropagate.
    loss.zero_grad()
    loss.backward()
    if epoch % 1000 == 0:
        # The loss is a sum over the batch; divide to report a per-sample value.
        print(f"Epoch {epoch}: {loss.value / batch_size}")

    # Plain gradient-descent update.
    for param in (weights_1, weights_2):
        param.value -= learning_rate * param.grad
    for param in (biases_1, biases_2):
        # Collapse axis 0 of the bias gradient; the result broadcasts
        # against the (1, n) bias value.
        param.value -= learning_rate * param.grad.sum(axis=0)


Epoch 0: [[24.314356]]
Epoch 1000: [[0.06819037]]
Epoch 2000: [[0.02172773]]
Epoch 3000: [[0.01131087]]
Epoch 4000: [[0.01069991]]
Epoch 5000: [[0.01001856]]
Epoch 6000: [[0.0124352]]
Epoch 7000: [[0.01677421]]
Epoch 8000: [[0.00629853]]
Epoch 9000: [[0.01197684]]
Epoch 10000: [[0.01631875]]


In [37]:
# Run the trained network on five fresh random points.
test_x = Node(np.random.randn(5, 2))
hidden_pre_activation = test_x @ weights_1 + biases_1
out_1 = ReLU(hidden_pre_activation)
out_2 = out_1 @ weights_2 + biases_2
out_2.value


array([[ 1.0936732 ],
       [ 1.9834421 ],
       [ 0.33278322],
       [-0.3883252 ],
       [ 0.18915081]], dtype=float32)

In [38]:
# Ground-truth targets for the same points, as a column vector for comparison.
np.cos(test_x.value[:, 0:1]) + np.sin(test_x.value[:, 1:2])

array([[ 1.1357993 ],
       [ 1.9952853 ],
       [ 0.29267114],
       [-0.43164307],
       [ 0.23005635]], dtype=float32)