In [2]:
import matplotlib
import numpy as np
import sklearn
import scipy
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [3]:
def f(x):
    """Piecewise target function used to generate the regression data.

    For negative inputs the value is ``5 + sum_{k=1..4} sin(k * x)``; for
    non-negative inputs it is ``cos(10 * x)``.

    Args:
        x: Scalar or array of any shape (converted with ``jnp.asarray``).

    Returns:
        Array with the same shape as ``x`` holding the function values.
    """
    x = jnp.asarray(x)
    harmonics = jnp.arange(1, 5)
    # Add a trailing axis for the harmonics and reduce over it again, so the
    # expression works for scalars and N-D inputs, not only 1-D arrays.
    negative_branch = 5 + jnp.sum(jnp.sin(harmonics * x[..., None]), axis=-1)
    nonnegative_branch = jnp.cos(10 * x)
    # jnp.where evaluates both branches everywhere and selects element-wise.
    return jnp.where(x < 0, negative_branch, nonnegative_branch)

In [4]:
def generate_data(n_train=200, n_test=500):
    """Build evenly spaced train and test samples of ``f`` on [-pi, pi].

    Args:
        n_train: Number of training points.
        n_test: Number of test points.

    Returns:
        Tuple ``(xs_train, ys_train, xs_test, ys_test)``.
    """
    lower, upper = -jnp.pi, jnp.pi

    def _sample(n_points):
        # Uniform grid over the interval together with the target values on it.
        grid = jnp.linspace(lower, upper, n_points)
        return grid, f(grid)

    train_x, train_y = _sample(n_train)
    test_x, test_y = _sample(n_test)
    return train_x, train_y, test_x, test_y

In [5]:
def relu(x):
    """Rectified linear unit: element-wise maximum of zero and ``x``."""
    floor = 0
    rectified = jnp.maximum(floor, x)
    return rectified

def init_network_params(layer_widths, key):
    """Draw random weights and biases for a fully connected network.

    Args:
        layer_widths: Sequence of layer sizes, input width first.
        key: JAX PRNG key from which the per-layer keys are derived.

    Returns:
        List with one ``(weights, biases)`` tuple per pair of consecutive
        widths ``(m, n)``: ``weights`` has shape ``(n, m)`` and ``biases`` has
        shape ``(n, 1)``. Both are standard-normal draws scaled by
        ``sqrt(2 / m)``.
    """
    # One key per entry of layer_widths; only the first len - 1 are consumed.
    layer_keys = random.split(key, len(layer_widths))

    def _init_layer(layer_key, fan_in, fan_out):
        # Independent keys for the weight matrix and the bias column.
        w_key, b_key = random.split(layer_key)
        weights = random.normal(w_key, (fan_out, fan_in)) * jnp.sqrt(2 / fan_in)
        biases = random.normal(b_key, (fan_out, 1)) * jnp.sqrt(2 / fan_in)
        return (weights, biases)

    consecutive_widths = zip(layer_widths[:-1], layer_widths[1:])
    return [
        _init_layer(layer_keys[idx], fan_in, fan_out)
        for idx, (fan_in, fan_out) in enumerate(consecutive_widths)
    ]

def predict(params, x):
    """Evaluate the fully connected ReLU network on a batch of inputs.

    Args:
        params: List of ``(w, b)`` pairs, one per layer. Each ``w`` is applied
            as ``activations @ w.T`` (i.e. laid out ``(n_out, n_in)``) and each
            ``b`` is flattened before being added.
        x: Inputs. A 1-D array is treated as a batch of scalar inputs and is
            reshaped to a single column; any other array is used as given.

    Returns:
        The output of the last layer, flattened to 1-D when the (reshaped)
        input has exactly one feature column, otherwise left 2-D.
    """
    # Ensure x is a 2D array with each input as a row
    x = x.reshape(-1, 1) if x.ndim == 1 else x

    activations = x
    # Every layer except the last one is a hidden layer: affine map + ReLU.
    # The slice must be [:-1]; [:-2] silently drops the final hidden layer.
    for w, b in params[:-1]:
        activations = jnp.dot(activations, w.T) + b.reshape(-1)
        # ReLU written inline (same expression as the notebook's relu helper).
        activations = jnp.maximum(0, activations)

    # Output layer: affine map only, no non-linearity.
    final_w, final_b = params[-1]
    y_pred = jnp.dot(activations, final_w.T) + final_b.reshape(-1)

    return y_pred.flatten() if x.shape[1] == 1 else y_pred


In [6]:
def mse_loss(params, xs, ys):
    """Mean squared error between the network's predictions and ``ys``.

    Args:
        params: Network parameters, passed straight to ``predict``.
        xs: Inputs to evaluate the network on.
        ys: Target values compared against the predictions.

    Returns:
        Mean of the squared residuals.
    """
    residuals = predict(params, xs) - ys
    squared = residuals ** 2
    return jnp.mean(squared)

@jit
def update(params, xs, ys, lr=0.01):
    """Take one gradient-descent step on ``mse_loss``.

    Args:
        params: List of ``(w, b)`` pairs.
        xs: Training inputs.
        ys: Training targets.
        lr: Step size.

    Returns:
        New list of ``(w, b)`` pairs after subtracting ``lr`` times the
        gradient of the loss with respect to each parameter.
    """
    # grad differentiates with respect to the first argument, i.e. params.
    gradients = grad(mse_loss)(params, xs, ys)

    stepped = []
    for (weights, biases), (d_weights, d_biases) in zip(params, gradients):
        stepped.append((weights - lr * d_weights, biases - lr * d_biases))
    return stepped

In [7]:
def l2_relative_error(params, xs, ys):
    """Relative L2 error of the network's predictions against ``ys``.

    Computed as ``||predict(params, xs) - ys||_2 / ||ys||_2``.
    """
    residual = predict(params, xs) - ys
    error_norm = jnp.sqrt(jnp.sum(residual ** 2))
    target_norm = jnp.sqrt(jnp.sum(ys ** 2))
    return error_norm / target_norm

In [12]:
def train_network(params, xs, ys, n_epochs=1000, lr=0.01):
    """Run ``n_epochs`` full-batch gradient-descent steps.

    Args:
        params: Initial network parameters.
        xs: Training inputs.
        ys: Training targets.
        n_epochs: Number of calls to ``update``.
        lr: Step size forwarded to ``update``.

    Returns:
        The parameters after the final step.
    """
    trained = params
    for _ in range(n_epochs):
        trained = update(trained, xs, ys, lr)
    return trained

def test_network(params, xs_test, ys_test):
    """Score the network on the test set using the relative L2 error."""
    relative_error = l2_relative_error(params, xs_test, ys_test)
    return relative_error

def run_experiment(widths, n_train=200, n_test=500, n_epochs=20000, lr=0.0005, n_runs=3):
    """Train two-hidden-layer networks of several widths and report test error.

    For every width, ``n_runs`` networks are initialised from the seeds
    ``0 .. n_runs - 1``, trained on the generated training set and scored with
    the relative L2 error on the test set. After all training has finished,
    the mean and standard deviation per width are printed.

    Args:
        widths: Hidden-layer widths to try.
        n_train: Number of training points.
        n_test: Number of test points.
        n_epochs: Gradient-descent steps per network.
        lr: Step size.
        n_runs: Number of differently seeded networks per width.
    """
    train_x, train_y, test_x, test_y = generate_data(n_train, n_test)

    def _train_and_score(hidden_width, seed):
        # Fresh network for this (width, seed) pair, trained then evaluated.
        rng = random.PRNGKey(seed)
        architecture = [1, hidden_width, hidden_width, 1]  # input layer, 2 hidden layers, output layer
        net = init_network_params(architecture, rng)
        net = train_network(net, train_x, train_y, n_epochs, lr)
        return test_network(net, test_x, test_y)

    errors = {hidden_width: [] for hidden_width in widths}

    for hidden_width in widths:
        for seed in range(n_runs):
            errors[hidden_width].append(_train_and_score(hidden_width, seed))

    # Report only after every network has been trained.
    for hidden_width in widths:
        scores = errors[hidden_width]
        mean_error = np.mean(scores)
        std_error = np.std(scores)
        print(f"Width: {hidden_width}, Mean L^2 Relative Error: {mean_error}, Std Dev: {std_error}")

# Run the width sweep over five hidden-layer widths, using run_experiment's
# default settings for everything else.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, so the
# recorded run appears to have been stopped before it finished; confirm.
run_experiment(widths=[10, 30, 100, 300, 1000])

KeyboardInterrupt: 