# Libraries

In [2]:
import jax
import jax.numpy as jnp
from flax import linen as nn # Flax's neural network API
from jax import random

# Dense Network (Fully Connected Layer)

In [16]:
from jax.nn.initializers import constant as jnp_constant

class SimpleDenseNetwork(nn.Module):
    # Defining the number of features (neurons) for our layer
    features: int

    @nn.compact
    def __call__(self, x):
        # nn.Dense is Flax's fully connected layer
        # jnp_constant(1.0) will make all biases start at 1.0
        x = nn.Dense(
            features=self.features, bias_init=jnp_constant(1.0)
        )(x)
        return x

In [14]:
# Create an instance of our model
model = SimpleDenseNetwork(features=10)

# Generate a random key for initialization
# Using a fixed seed for reproducibility
key = random.PRNGKey(0)

# Create a dummy input array
# 1 sample, 5 input features
dummy_input = jnp.ones((1, 5))

# Initialize the model's parameters
params = model.init(key, dummy_input)['params']

print("Model parameters:")
print(jax.tree.map(lambda x: x.shape, params))

Model parameters:
{'Dense_0': {'bias': (10,), 'kernel': (5, 10)}}


In [17]:
# Create some actual input data
# Batch of 2 samples, each with 5 features
input_data = jnp.array([
    [1.0, 2.0, 3.0, 4.0, 5.0],
    [6.0, 7.0, 8.0, 9.0, 10.0],
])
print(f"Input data shape: {input_data.shape}")

# Perform the forward pass
output_data = model.apply({'params': params}, input_data)

print("\nOutput data:")
print(output_data)
print(f"\nOutput data shape: {output_data.shape}")

Input data shape: (2, 5)

Output data:
[[  2.534062    -1.966347    -2.2287571    4.9156427    3.616924
    0.08119702   1.7199428   -5.069695     5.1260605    0.6473913 ]
 [ -0.8696096   -3.6501474  -10.47789      8.267586     8.88211
    0.43612492   2.7938604  -14.02945     10.69361      2.9485097 ]]

Output data shape: (2, 10)


In [12]:
print("--- Weights (kernel) ---")
print(params['Dense_0']['kernel'])
print(f"Shape: {params['Dense_0']['kernel'].shape}")

print("\n--- Biases ---")
print(params['Dense_0']['bias'])
print(f"Shape: {params['Dense_0']['bias'].shape}")

--- Weights (kernel) ---
[[-0.7387071   0.33744565 -0.61614853 -0.38103622  0.4111639   0.5740531
   0.4669155  -0.02079375  0.18908004  0.34764266]
 [-0.81367314  0.38204673 -0.31585133  0.22778194  0.54632616 -0.12296208
  -0.6633667  -0.59038585 -0.31810617 -0.01491392]
 [-0.15151945 -0.2792575  -0.94288903 -0.24101867 -0.5139249  -0.27185342
   0.24369092 -0.32793495  0.5703741   0.5477993 ]
 [ 0.76115257 -0.6548614   0.27755     0.75913733  0.39247712 -0.1098888
  -0.01096913 -0.37985826  0.49873945  0.21229894]
 [ 0.2620127  -0.12213357 -0.05248778  0.30552435  0.2169948   0.00163671
   0.1785129  -0.47297826  0.17342253 -0.6326034 ]]
Shape: (5, 10)

--- Biases ---
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Shape: (10,)


# Activation Function

In [18]:
# Define a Multi-Layer Perceptron (MLP)
class MLP(nn.Module):
    # Number of neurons in the hidden layer
    features_hidden: int

    # Number of neurons in the output layer
    features_output: int

    @nn.compact
    def __call__(self, x):
        # 1. Hidden Layer
        # Apply the first Dense Layer
        x = nn.Dense(features=self.features_hidden)(x)
        # Apply a ReLU activation function
        x = nn.relu(x)

        # 2. Output Layer
        # Apply the second Dense Layer (output layer)
        x = nn.Dense(features=self.features_output)(x)

        return x

In [19]:
# Create an instance of our MLP model
# 32 neurons in the hidden layer
# 10 layers in the output (e.g., for 10 classes)
mlp_model = MLP(features_hidden=32, features_output=10)

# Generate a random key for initialization (seed - reproducibility)
key = random.PRNGKey(132)

# Dummy input (batch of 1 sample, 5 input features)
dummy_input = jnp.ones((1, 5))

# Initialize the MLP's parameters
mlp_params = mlp_model.init(key, dummy_input)['params']

print("MLP Model parameters:")
print(jax.tree.map(lambda x: x.shape, mlp_params))

# --- Perform a forward pass with the new data ---
input_data_mlp = jnp.array([
    [1.0, 2.0, 3.0, 4.0, 5.0],
    [-1.0, -7.0, -8.0, 9.0, 10.0],
])

print(f"\n Input data shape: {input_data_mlp.shape}")

# Apply the model with its parameters and the input data
output_data_mlp = mlp_model.apply({'params': mlp_params}, input_data_mlp)

print("\nOutput data from MLP:")
print(output_data_mlp)
print(f"\nOutput data shape: {output_data_mlp.shape}")

MLP Model parameters:
{'Dense_0': {'bias': (32,), 'kernel': (5, 32)}, 'Dense_1': {'bias': (10,), 'kernel': (32, 10)}}

 Input data shape: (2, 5)

Output data from MLP:
[[ 0.10567813 -0.68505734  2.253239   -2.7651613  -1.3348109  -0.9340192
   0.35664153 -1.9763117  -1.1373323   2.0080562 ]
 [ 3.539614   -6.495874   -0.8235873   2.8437166   0.9035323  -6.36709
   1.1798904  -5.179701    1.3909954   7.832258  ]]

Output data shape: (2, 10)


# Loss Function

In [20]:
# A simple MSE loss function
def mse_loss(predictions, targets):
    return jnp.mean((predictions - targets)**2)

# Example usage
dummy_predictions = jnp.array([1.0, 2.0, 3.0])
dummy_targets = jnp.array([1.1, 1.9, 3.2])

print(f"Dummy predictions: {dummy_predictions}")
print(f"Dummy targets: {dummy_targets}")
print(f"MSE Loss: {mse_loss(dummy_predictions, dummy_targets)}")

Dummy predictions: [1. 2. 3.]
Dummy targets: [1.1 1.9 3.2]
MSE Loss: 0.02000000886619091


# Optimizer

In [24]:
import optax

# Define the optimizer
learning_rate = 0.001
optimizer = optax.adam(learning_rate)

# This state is tied to our model's initial parameters
opt_state = optimizer.init(mlp_params)

print(f"\nOptimizer state:\n")
print(jax.tree.map(lambda x:x.shape, opt_state))


Optimizer state:

(ScaleByAdamState(count=(), mu={'Dense_0': {'bias': (32,), 'kernel': (5, 32)}, 'Dense_1': {'bias': (10,), 'kernel': (32, 10)}}, nu={'Dense_0': {'bias': (32,), 'kernel': (5, 32)}, 'Dense_1': {'bias': (10,), 'kernel': (32, 10)}}), EmptyState())


# Training

In [28]:
@jax.jit
def train_step(params, opt_state, batch_inputs, batch_targets):
    # This is the function we want to differentiate
    def loss_fn(curr_params):
        pred = mlp_model.apply({'params': curr_params}, batch_inputs)
        loss = mse_loss(predictions, batch_targets)
        return loss

    # Use value_and_grad to get both loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(params)

    # Apply gradients to update the optimizer state and parameters
    updt, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updt)

    return new_params, new_opt_state, loss

In [26]:
# Generate synthetic data
key_data = random.PRNGKey(567) # For data generation randomness

num_samples = 1000
input_features = 5
output_features = 10 # Matches our MLP output

# Generate random inputs
X = random.normal(key_data, (num_samples, input_features))

# This creates a non-linear relationship for the MLP to learn
key_noise, key_target_func = random.split(key_data)
true_weights = random.normal(key_target_func, (input_features, output_features))
true_bias = random.normal(key_target_func, (output_features,))

# A slightly more complex non-linear target
Y_true = jnp.dot(X**2, true_weights) + true_bias + random.normal(key_noise, (num_samples, output_features)) * 0.1

print(f"Generated X shape: {X.shape}")
print(f"Generated Y_true shape: {Y_true.shape}")

Generated X shape: (1000, 5)
Generated Y_true shape: (1000, 10)


In [27]:
num_epochs = 500

print("\n Starting training...")
for epoch in range(num_epochs):
    # In a real scenario, we'd shuffle and create mini-batches here
    # For simplicity, we'll use the whole dataset as one 'batch' for now
    batch_inputs = X
    batch_targets = Y_true

    mlp_params, opt_state, loss = train_step(mlp_params, opt_state, batch_inputs, batch_targets)

    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")

print("\nTraining complete.")
print(f"Final Loss: {loss:.4f}")


 Starting training...
Epoch 50/500, Loss: 15.6692
Epoch 100/500, Loss: 12.7086
Epoch 150/500, Loss: 10.3679
Epoch 200/500, Loss: 8.6159
Epoch 250/500, Loss: 7.3812
Epoch 300/500, Loss: 6.5138
Epoch 350/500, Loss: 5.8393
Epoch 400/500, Loss: 5.2482
Epoch 450/500, Loss: 4.7049
Epoch 500/500, Loss: 4.2103

Training complete.
Final Loss: 4.2103


# Inference

In [29]:
# Generate some new, unseen input data for inference
key_inference = jax.random.PRNGKey(999)
num_inferences_samples = 5

X_new = jax.random.normal(key_inference, (num_inferences_samples, input_features))

print(f"New input data for inference: \n{X_new}")
print(f"Shape of new input data: {X_new.shape}")

# Perform inference: apply the trained parameters to the new data
predictions = mlp_model.apply({'params': mlp_params}, X_new)

print(f"\nModel predictions on new data:\n{predictions}")
print(f"Shape of new predictions: {predictions.shape}")

New input data for inference: 
[[-1.898162    1.3126683  -0.93570435 -0.40194905  1.1873449 ]
 [ 0.8682871   0.05332198 -1.2655848  -0.32980207  0.24419203]
 [-1.7266408  -0.6284628   2.944953   -1.3948829   0.62099993]
 [ 0.90469366  0.6898744  -2.0962944  -0.9236466  -0.19609946]
 [ 0.38934988  0.43965247 -1.1499423  -0.9559521   1.2161872 ]]
Shape of new input data: (5, 5)

Model predictions on new data:
[[ -7.53619     -8.692844    -1.9360822   -2.9083357   -1.6106429
    2.0235376    1.2659616    1.0644199   -3.0131845   -1.6381968 ]
 [ -5.230658    -5.1156707   -2.6157258   -2.3324254   -1.6410153
    1.3128057    2.5697365    0.5646908   -0.8015749   -1.9574323 ]
 [-10.263551   -10.581922    -4.466531    -5.645301    -3.8386297
    3.1153164    5.514456     0.7701984   -2.111065    -4.586107  ]
 [ -8.351736    -7.9151163   -4.1317945   -4.149193    -2.7879117
    2.2939289    3.5006623   -0.6195143    0.1408867   -3.268793  ]
 [ -6.4102254   -6.752006    -2.709612    -0.9659612 