In [None]:
!pip install flax




In [None]:
!pip install --upgrade jax jaxlib




In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad, value_and_grad
import flax.linen as nn
from sklearn.model_selection import train_test_split
import numpy as np
import optax

# Data generation function
def generate_data(num_samples=1000):
    """Sample a synthetic 3-feature regression dataset.

    Each feature is drawn independently and uniformly from [-2, 2) using
    NumPy's global random state. The target is
    ``sin(x0) + cos(x1)**2 + sin(x2**2)``.

    Args:
        num_samples: number of rows to generate.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(num_samples, 3)`` and ``y`` has
        shape ``(num_samples,)``.
    """
    features = np.random.uniform(-2, 2, size=(num_samples, 3))
    # Transposing yields one 1-D view per feature column.
    first, second, third = features.T
    target = np.sin(first) + np.cos(second) ** 2 + np.sin(third ** 2)
    return features, target

# Seed NumPy's global RNG before generating data. generate_data draws from it,
# so without a seed the dataset (and every loss printed below) changes on each
# run, even though the train/test split itself is seeded via random_state.
np.random.seed(42)
x, y = generate_data()
# Hold out 20% of the samples for evaluation.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# Neural Network Model
class ThreeLayerNet(nn.Module):
    """MLP: two 64-unit ReLU hidden layers followed by a 1-unit linear output.

    The output keeps a trailing singleton dimension, i.e. an input of shape
    ``(N, in_features)`` produces predictions of shape ``(N, 1)``.
    """

    @nn.compact
    def __call__(self, x):
        # Hidden layers are created in order, so Flax names them Dense_0 and
        # Dense_1; the output layer below becomes Dense_2.
        for width in (64, 64):
            x = nn.relu(nn.Dense(features=width)(x))
        return nn.Dense(features=1)(x)

# Initialize model
model = ThreeLayerNet()
# A fixed PRNG key keeps the initial weights reproducible. The dummy (1, 3)
# batch only tells Flax the input width (3 features) so it can shape the
# first Dense kernel.
rng = random.PRNGKey(0)
params = model.init(rng, jnp.ones(shape=(1, 3)))["params"]

# Training function
def train(params, x_train, y_train, epochs=1000, lr=0.01, log_every=100):
    """Fit the module-level ``model`` with full-batch Adam on mean squared error.

    Args:
        params: Flax parameter pytree for ``model`` (the ``'params'`` collection).
        x_train: input features, one row per sample.
        y_train: regression targets, one per sample. Must have the same number
            of elements as the model output (e.g. shape ``(N,)`` or ``(N, 1)``).
        epochs: number of full-batch gradient steps.
        lr: Adam learning rate.
        log_every: print the training loss every ``log_every`` epochs; a value
            <= 0 disables logging.

    Returns:
        The trained parameter pytree.
    """
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    def loss_fn(params, x, y):
        preds = model.apply({'params': params}, x)
        # The network returns shape (N, 1) while targets are usually (N,).
        # Without this reshape, `preds - y` broadcasts to an (N, N) matrix and
        # the "MSE" compares every prediction with every target, which makes
        # the loss plateau near var(y) no matter what the network learns.
        preds = jnp.reshape(preds, jnp.shape(y))
        return jnp.mean((preds - y) ** 2)

    # Compile one optimisation step so the 1000-epoch loop does not dispatch
    # every op eagerly from Python.
    @jax.jit
    def step(params, opt_state, x, y):
        loss, grads = value_and_grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    for epoch in range(epochs):
        # `loss` is measured before this epoch's update is applied.
        params, opt_state, loss = step(params, opt_state, x_train, y_train)

        if log_every > 0 and epoch % log_every == 0:
            print(f"Epoch {epoch}: Loss {loss}")

    return params

# Convert data to JAX arrays
x_train, y_train, x_test, y_test = map(jnp.array, [x_train, y_train, x_test, y_test])

# Train the model
final_params = train(params, x_train, y_train)

# Evaluate the model. model.apply returns shape (N, 1), but y_test has shape
# (N,). Reshape the predictions to y_test's shape so the MSE pairs each
# prediction with its own target instead of broadcasting to an (N, N) matrix.
preds = model.apply({'params': final_params}, x_test).reshape(y_test.shape)
loss = jnp.mean((preds - y_test) ** 2)
print(f"Test Loss: {loss}")


Epoch 0: Loss 2.766568899154663
Epoch 100: Loss 1.0411865711212158
Epoch 200: Loss 1.0308400392532349
Epoch 300: Loss 1.0302321910858154
Epoch 400: Loss 1.0301039218902588
Epoch 500: Loss 1.030031681060791
Epoch 600: Loss 1.0299917459487915
Epoch 700: Loss 1.0299668312072754
Epoch 800: Loss 1.029952049255371
Epoch 900: Loss 1.029942274093628
Test Loss: 0.8185634016990662
