In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Notebook-wide matplotlib defaults. Only the background grid is switched on;
# further rc keys (e.g. "figure.figsize", "font.size") can be added to this dict.
custom_params = dict([("axes.grid", True)])  # draw a grid on every axes
for rc_key, rc_value in custom_params.items():
    plt.rcParams[rc_key] = rc_value

In [None]:
#| echo: false 
sizes = np.array([50.0, 75.5, 99.3, 149.8, 190.4, 200.8, 200.0, 300.0])
prices = np.array([83.3, 125.3, 189.2, 295.1, 644.0, 660.8, 693.6, 1189.5])

# Least-squares polynomial fit of price vs. size. np.polyfit's third argument is
# the polynomial degree, so with degree = 2 this is a quadratic (not linear) model.
degree = 2
coef = np.polyfit(sizes, prices, degree)

# Plot the data together with the fitted curve over the observed size range
fig, ax = plt.subplots(figsize=(6, 4))
ax.set_title("Real estate application")
ax.scatter(sizes, prices, label="Data")

x_line = np.linspace(sizes.min(), sizes.max(), 100)
y_line = np.polyval(coef, x_line)  # polyval uses the same highest-power-first order as polyfit
ax.plot(x_line, y_line, color="black", label="Model")

ax.set_xlabel("House size ($\\mathrm{m}^2$)")
ax.set_ylabel("Price [kEUR]")
ax.legend()
fig.tight_layout()

In [None]:
#| echo: false
n_x = 2  # number of inputs
n_y = 1  # number of outputs
a = -2.0  # lower bound of both x_1 and x_2
b = 2.0  # upper bound of both x_1 and x_2
n_samples = 500  # number of samples in each of the training and test datasets
sigma_e = 0.1  # standard deviation of the additive noise on the targets
grid_points = 100  # grid resolution per axis for the surface plots
seed = 0  # seed for the global NumPy RNG so the datasets are identical on every Run All

# The data-generation code in this cell samples through np.random.rand / np.random.randn,
# i.e. the legacy global RNG, so seed that one.
np.random.seed(seed)


def f(x):
    """Noise-free target function f(x) = 2 sin(x_1) - 3 cos(x_2).

    The last axis of `x` holds the two inputs (x_1, x_2). Indexing with an
    Ellipsis (`...`) keeps any leading batch dimensions, so an input of shape
    (..., 2) gives an output of shape (...).
    """
    return 2*np.sin(x[..., 0]) - 3*np.cos(x[..., 1])


# Sanity check on a single, unbatched point.
assert np.isclose(f(np.array([0.2, 0.4])), 2*np.sin(0.2) - 3*np.cos(0.4))


def _sample_dataset(num_samples):
    """Draw `num_samples` noisy observations of `f` uniformly on [a, b] x [a, b].

    Draws from the global NumPy RNG in the order x_1, x_2, noise (the same order
    the original copy-pasted train/test code used).

    Returns (x1, x2, X, y) with X of shape (num_samples, 2) and y of shape
    (num_samples, 1).
    """
    x1 = a + np.random.rand(num_samples) * (b - a)
    x2 = a + np.random.rand(num_samples) * (b - a)
    X = np.stack([x1, x2], axis=-1)
    y = f(X) + sigma_e * np.random.randn(num_samples)
    return x1, x2, X, y.reshape(-1, 1)


x1_train, x2_train, X_train, y_train = _sample_dataset(n_samples)
x1_test, x2_test, X_test, y_test = _sample_dataset(n_samples)

# Shape invariants. Bare `X.shape, y.shape` expressions in the middle of a cell
# are never displayed, so check the shapes explicitly instead.
assert X_train.shape == X_test.shape == (n_samples, n_x)
assert y_train.shape == y_test.shape == (n_samples, n_y)

## visualization
# Evaluate the noise-free target on a regular grid_points x grid_points mesh.
x1_grid = np.linspace(a, b, grid_points)
x2_grid = np.linspace(a, b, grid_points)
X1_mesh, X2_mesh = np.meshgrid(x1_grid, x2_grid)
X_grid = np.c_[X1_mesh.ravel(), X2_mesh.ravel()]
y_grid = f(X_grid)

fig = plt.figure(figsize=(5, 5))
ax1 = fig.add_subplot(111, projection='3d')  # single panel: use the whole figure
# Reshape to the mesh shape instead of a hard-coded (100, 100), so changing grid_points works.
ax1.plot_surface(X1_mesh, X2_mesh, y_grid.reshape(X1_mesh.shape), cmap='coolwarm', alpha=0.7)
# The scatter carries a label so that legend() has an entry to show.
ax1.scatter(X_train[:, 0], X_train[:, 1], y_train[:, 0], color='k', s=10, label="Training data")
ax1.set_xlabel("$x_1$")
ax1.set_ylabel("$x_2$")
ax1.set_zlabel("$y$")
ax1.legend()
plt.show()

In [None]:
# Initialize all parameters and organize them in a dictionary

import jax
import jax.random as jr
import jax.numpy as jnp
import optax

key = jr.key(4)
subkeys = jr.split(key, 6)
key_W1, key_b1, key_W2, key_b2, key_W3, key_b3 = subkeys
nx, ny, nh = 2, 1, 16  # NOTE(review): `nh` is not referenced by any visible cell
hidden_size = [16, 8]

# Layer widths from input to output: [nx, 16, 8, ny]. Layer k gets the weight
# matrix "Wk" of shape (fan_out, fan_in) and bias "bk" of shape (fan_out,),
# drawn from standard normals with subkeys 2k-2 and 2k-1 respectively.
layer_dims = [nx, *hidden_size, ny]
params = {}
for layer, (fan_in, fan_out) in enumerate(zip(layer_dims[:-1], layer_dims[1:]), start=1):
    params[f"W{layer}"] = jr.normal(subkeys[2 * layer - 2], shape=(fan_out, fan_in))
    params[f"b{layer}"] = jr.normal(subkeys[2 * layer - 1], shape=(fan_out,))

def neural_net(params, x):
    """Forward pass of a two-hidden-layer tanh MLP for a single input.

    params: dict with weight matrices "W1", "W2", "W3" and bias vectors
        "b1", "b2", "b3"; Wk maps the previous layer's activations to layer k.
    x: one input vector (this notebook calls it on a single row of X_train
        and vmaps it for batches).

    Returns the linear output layer W3 . h2 + b3 (no output activation).
    """
    hidden = x
    for k in ("1", "2"):
        hidden = jnp.tanh(jnp.dot(params["W" + k], hidden) + params["b" + k])
    return jnp.dot(params["W3"], hidden) + params["b3"]

In [None]:
# Sanity check: one unbatched input row must give a single output vector of shape (ny,).
single_pred = neural_net(params, X_train[0])
assert single_pred.shape == (ny,), single_pred.shape
single_pred.shape

In [None]:
# Batched forward pass: `params` is shared (None = not mapped) while the inputs are
# mapped over their leading axis, so an (N, nx) array yields an (N, ny) output.
batched_neural_net = jax.vmap(neural_net, (None, 0))

In [None]:
# Sanity check: the vmapped network returns one (ny,)-vector per training row.
batch_pred_shape = batched_neural_net(params, X_train).shape
assert batch_pred_shape == (X_train.shape[0], ny), batch_pred_shape
batch_pred_shape

In [None]:
def loss_fn(params, y, x):
    """Mean squared error of the batched network on inputs `x` against targets `y`.

    `y` is reshaped to the prediction shape (N, ny) before subtracting. A flat
    (N,) target vector would otherwise broadcast against the (N, 1) predictions
    into an (N, N) error matrix and silently give a wrong loss. A target with
    the wrong number of elements now raises instead.
    """
    y_pred = batched_neural_net(params, x)
    y = jnp.reshape(y, y_pred.shape)
    return jnp.mean((y - y_pred) ** 2)


# Jitted (loss, d loss / d params); argnums=0 differentiates w.r.t. `params` only.
loss_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=0))

In [None]:
# Training MSE of the randomly initialised network, before any optimisation step.
initial_loss, _initial_grads = loss_grad_fn(params, y_train, X_train)
initial_loss

In [None]:
lr = 1e-2  # Adam learning rate
iters = 5000  # number of full-batch gradient steps

# Setup optimizer
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)

# Full-batch training loop.
# - The loop variable is `step`, not `iter`: a module-level `iter` would shadow the
#   built-in for the rest of the kernel session.
# - LOSS[k] is the training MSE evaluated at the parameters *before* update k.
# - NOTE: `params` is overwritten in place, so re-running this cell continues
#   training from the current parameters rather than from the initialisation.
LOSS = []
for step in range(iters):
    loss_val, grads = loss_grad_fn(params, y_train, X_train)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    LOSS.append(loss_val)

In [None]:
# Training-loss curve. A log y-axis keeps the small, late-training values readable.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(LOSS)
ax.set_yscale("log")
ax.set_title("Training loss")
ax.set_xlabel("iteration")
ax.set_ylabel("training MSE");

In [None]:
# Predictions of the trained network on the held-out test inputs, shape (N_test, ny).
y_test_pred = batched_neural_net(params, X_test)

# Test MSE: the same formula as the training loss, so it is directly comparable
# to the final values of the training-loss curve.
test_mse = float(jnp.mean((y_test - y_test_pred) ** 2))
test_mse

In [None]:
test_residuals = y_test - y_test_pred  # true minus predicted, per test sample

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
# Left panel overlays two series against the true target, so both are labelled.
ax[0].set_title("Test residual plot")
ax[0].plot(y_test, y_test_pred, 'C0o', label="prediction")
ax[0].plot(y_test, test_residuals, 'ro', label="residual (true - predicted)")
ax[0].set_xlabel("true $y$")
ax[0].legend()
ax[1].set_title("Test residuals histogram")
ax[1].hist(test_residuals)
ax[1].set_xlabel("residual (true - predicted)")
ax[1].set_ylabel("count");

In [None]:
# Network prediction at every grid point, with the trailing output axis dropped to match y_grid's 1-D layout.
y_pred_grid = jnp.squeeze(batched_neural_net(params, X_grid), axis=-1)

In [None]:
def _plot_surface(ax, z_flat, title, z_limits):
    """Draw `z_flat` (one value per grid point) as a surface over X1_mesh/X2_mesh.

    `z_limits` fixes both the z-axis range and the colour range, so several
    panels drawn with the same limits are visually comparable.
    """
    z_min, z_max = z_limits
    ax.plot_surface(X1_mesh, X2_mesh, z_flat.reshape(X1_mesh.shape), cmap='coolwarm',
                    edgecolor='none', vmin=z_min, vmax=z_max)
    ax.set_title(title)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    ax.set_zlabel("y")
    ax.set_zlim(z_min, z_max)


# Shared z/colour range across both panels. Independent autoscaling would make
# the true and learned surfaces look more (or less) alike than they are.
z_limits = (float(min(y_grid.min(), y_pred_grid.min())),
            float(max(y_grid.max(), y_pred_grid.max())))

fig = plt.figure(figsize=(12, 5))
# 3D Plot of True Function
ax1 = fig.add_subplot(121, projection='3d')
_plot_surface(ax1, y_grid, "True Function", z_limits)
# 3D Plot of NN Predictions
ax2 = fig.add_subplot(122, projection='3d')
_plot_surface(ax2, y_pred_grid, "NN Function", z_limits)