In [1]:
import numpy
import jax
from jax import grad, jit, vmap
import jax.numpy as np
from jax import random
from jax.ops import index, index_add, index_update

In [9]:
def init_network_params_ones(sizes):
    """Build all-ones parameters for a fully connected network.

    Args:
        sizes: sequence of layer widths, input width first. Each consecutive
            pair ``(m, n)`` yields one layer.

    Returns:
        A list with one ``(weights, bias)`` tuple per layer, where ``weights``
        is an ``(n, m)`` array of ones and ``bias`` is an ``(n, 1)`` array of
        ones (plain NumPy arrays). Fewer than two sizes gives an empty list.
    """
    layers = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        weights = numpy.ones((fan_out, fan_in))
        bias = numpy.ones((fan_out, 1))
        layers.append((weights, bias))
    return layers

def make_batch(N, n_features=3, frequency=12, rng=None):
    """Draw a random regression batch for the target ``Y = sin(frequency * X)``.

    Args:
        N: number of examples; each example is one column of X and Y.
        n_features: number of rows of X and Y. Defaults to 3, the value that
            was previously hard-coded.
        frequency: multiplier applied to X inside the sine. Defaults to 12,
            the value that was previously hard-coded.
        rng: optional ``numpy.random.Generator`` used to draw X, for
            reproducible batches. When None, the global ``numpy.random``
            state is used, exactly as before.

    Returns:
        ``(X, Y)``: two arrays of shape ``(n_features, N)``. X is uniform in
        ``[0, 1)`` and ``Y = sin(X * frequency)`` elementwise.
    """
    if rng is None:
        # Legacy global-state path: identical draws to the original for a
        # given numpy.random.seed(...).
        X = numpy.random.rand(n_features, N)
    else:
        X = rng.random((n_features, N))
    Y = numpy.sin(X * frequency)
    return X, Y

# init_network_params_ones([3,20,20,20,3])
# make_batch(100)

In [None]:
@jit
def predict(params, x):
    """Forward pass of a fully connected ReLU network.

    Args:
        params: non-empty sequence of ``(w, b)`` pairs, one per layer. Every
            layer but the last is followed by a ReLU; the last is affine only.
        x: network input. Each layer computes ``np.dot(w, x) + b``, so with
            ``(n, 1)`` biases (as built by ``init_network_params_ones``) a
            ``(features, batch)`` input is handled in one call via
            broadcasting.

    Returns:
        The output of the final affine layer.
    """
    # `relu` was never defined or imported in this notebook, so the original
    # raised NameError on a fresh kernel; jax.nn.relu is available through the
    # existing `import jax`.
    # NOTE(review): if a custom `relu` lived in a since-deleted cell, confirm
    # it was the standard max(x, 0).
    activation = x
    for w, b in params[:-1]:
        activation = jax.nn.relu(np.dot(w, activation) + b)
    w_out, b_out = params[-1]
    return np.dot(w_out, activation) + b_out

@jit
def loss(params, x, y):
    """Sum of squared errors between ``predict(params, x)`` and ``y``.

    Args:
        params: network parameters, forwarded unchanged to ``predict``.
        x: network inputs, forwarded unchanged to ``predict``.
        y: targets; must broadcast against the prediction.

    Returns:
        A scalar: the squared residuals summed over every element (a sum,
        not a mean).
    """
    residual = y - predict(params, x)
    return np.sum(residual ** 2)

@jit
def update(params, data, LR):
    """Apply one plain gradient-descent step to ``params``.

    Args:
        params: sequence of ``(w, b)`` pairs, one per layer.
        data: an ``(x, y)`` pair of inputs and targets for this step.
        LR: learning rate (step size).

    Returns:
        A new list of ``(w, b)`` pairs with ``LR * gradient`` subtracted from
        each entry; ``params`` itself is not modified.
    """
    # `loss` takes (params, x, y); the original passed `data` as a single
    # positional argument, which raised a TypeError for the missing `y`.
    x, y = data
    # grad differentiates with respect to the first argument (params), so
    # `grads` has the same list-of-pairs structure as `params`.
    grads = grad(loss)(params, x, y)
    return [
        (w - LR * dw, b - LR * db)
        for (w, b), (dw, db) in zip(params, grads)
    ]


In [None]:
# Training loop: 10 gradient steps, each on a freshly sampled batch.
# NOTE(review): this cell reads `params`, `loss_list`, `param_list`, `LR_0`,
# `decay`, `LR_min`, `epoch_size` and `plot_fn`, none of which are defined in
# the visible cells; they must be set earlier for Restart & Run All to work.
for i in range(10):
    X, Y = make_batch(100)

    # Loss is measured on the new batch *before* the parameter update.
    l = loss(params, X, Y).item()
    loss_list.append(l)
    # `l` is already a host Python float, so check it with plain NumPy rather
    # than dispatching a jax.numpy op.
    if numpy.isnan(l):
        print("loss is nan")
        break

    # Exponentially decayed learning rate, floored at LR_min.
    LR = max(LR_0 * decay ** i, LR_min)

    # The original passed an undefined name `data` here (NameError); the batch
    # drawn above is what the step should train on.
    params = update(params, (X, Y), LR)
    # end="\r" rewrites the same output line on every step.
    print(f"batch {i}, loss {l}, LR {LR}", end="\r")

    # Every `epoch_size` steps: snapshot the parameters and optionally plot.
    if (i + 1) % epoch_size == 0:
        param_list.append(params)
        if plot_fn is not None:
            plot_fn(params, loss_list, i)