In [None]:
import jax
from jax import random, vmap, jit, grad, value_and_grad
import jax.numpy as jnp
import optax
from functools import partial
from tqdm import trange
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class FeedForwardNN:
    """Fully connected network whose parameters are a list of ``(W, b)`` pairs.

    The activation function is applied after every layer except the last one.
    """

    def __init__(self, layer_sizes, key, activation_fn=jax.nn.tanh):
        """Store the architecture and create the initial parameters.

        Args:
            layer_sizes: Sequence of layer widths, input width first.
            key: JAX PRNG key used to initialise the weights.
            activation_fn: Callable applied after each hidden layer.
        """
        self.layer_sizes = layer_sizes
        self.activation_fn = activation_fn
        self.params = self.initialize_params(layer_sizes, key)

    def initialize_params(self, layer_sizes, key):
        """Return one ``(W, b)`` tuple per pair of consecutive layer widths.

        Weights are drawn uniformly from ``[-limit, limit]`` with
        ``limit = sqrt(6 / (fan_in + fan_out))`` (Xavier/Glorot uniform);
        biases start at zero.
        """
        n_layers = len(layer_sizes) - 1
        layer_keys = random.split(key, n_layers)
        layer_dims = zip(layer_sizes[:-1], layer_sizes[1:])

        params = []
        for layer_key, (fan_in, fan_out) in zip(layer_keys, layer_dims):
            # Only the first half of the split is consumed; biases need no randomness.
            weight_key, _ = random.split(layer_key)
            limit = jnp.sqrt(6 / (fan_in + fan_out))
            weights = random.uniform(
                weight_key, (fan_in, fan_out), minval=-limit, maxval=limit
            )
            params.append((weights, jnp.zeros(fan_out)))
        return params

    # `self` is declared static so jit can trace a bound method.
    @partial(jit, static_argnums=(0,))
    def forward(self, params, x):
        """Evaluate the network for one input ``x`` using the given ``params``."""
        hidden_layers, output_layer = params[:-1], params[-1]

        h = x
        for weights, bias in hidden_layers:
            # Affine map followed by the nonlinearity
            h = self.activation_fn(jnp.dot(h, weights) + bias)

        # Final layer is purely affine
        out_weights, out_bias = output_layer
        return jnp.dot(h, out_weights) + out_bias

    def predict(self, x):
        """Apply ``forward`` with the stored parameters to each entry along axis 0 of ``x``."""
        batched_forward = vmap(self.forward, in_axes=(None, 0))
        return batched_forward(self.params, x)

In [None]:
class DeepONet:
    """Pair of feed-forward networks (branch and trunk) combined by a dot product."""

    def __init__(self, branch_layer_sizes, trunk_layer_sizes, key, activation_fn=jax.nn.tanh):
        """Build the branch and trunk networks from two halves of ``key``."""
        key_for_branch, key_for_trunk = random.split(key)
        self.branch_net = FeedForwardNN(branch_layer_sizes, key_for_branch, activation_fn)
        self.trunk_net = FeedForwardNN(trunk_layer_sizes, key_for_trunk, activation_fn)

    def params(self):
        """Return ``[branch_params, trunk_params]`` as currently stored on the sub-networks."""
        return [self.branch_net.params, self.trunk_net.params]

    # `self` is declared static so jit can trace a bound method.
    @partial(jit, static_argnums=(0,))
    def forward(self, params, branch_input, trunk_input):
        """Evaluate the operator network for one (branch_input, trunk_input) pair.

        ``params[0]`` is fed to the branch network and ``params[1]`` to the
        trunk network; the two feature vectors are then contracted with
        ``jnp.dot(branch, trunk.T)``.
        """
        branch_features = self.branch_net.forward(params[0], branch_input)
        trunk_features = self.trunk_net.forward(params[1], trunk_input)
        return jnp.dot(branch_features, trunk_features.T)

    @partial(jit, static_argnums=(0,))
    def forward_squeeze(self, params, branch_input, trunk_input):
        """Same as ``forward`` but with all length-1 axes removed from the result."""
        output = self.forward(params, branch_input, trunk_input)
        return output.squeeze()

    def predict(self, branch_input, trunk_input):
        """Apply ``forward`` with the stored parameters to paired entries along axis 0."""
        batched_forward = vmap(self.forward, in_axes=(None, 0, 0))
        return batched_forward(self.params(), branch_input, trunk_input)

In [None]:
# # Parameter omega
def generate_omega(omega_min, omega_max, n_samples, key):
    """Draw three sets of omega samples on ``[omega_min, omega_max]``.

    Each distribution gets its own PRNG subkey derived from ``key``, so the
    three sample sets are statistically independent of one another.

    Args:
        omega_min: Lower bound of the interval.
        omega_max: Upper bound of the interval (must exceed ``omega_min``).
        n_samples: Number of samples to draw per distribution.
        key: JAX PRNG key.

    Returns:
        Tuple ``(omega_normal, omega_uniform, omega_beta)``:
        ``omega_normal`` is a NumPy array of normal draws (mean at the interval
        centre, std 2) truncated to the interval by rejection sampling;
        ``omega_uniform`` is a JAX array of uniform draws;
        ``omega_beta`` is a JAX array of Beta(0.4, 0.4) draws rescaled to the
        interval.

    Raises:
        ValueError: If ``omega_min >= omega_max``; the rejection loop could
            never terminate on an empty or degenerate interval.
    """
    if not omega_min < omega_max:
        raise ValueError(
            f"omega_min ({omega_min}) must be smaller than omega_max ({omega_max})"
        )

    # One independent stream per distribution. Reusing a single key for
    # several draws would make the sample sets share the same random bits.
    normal_key, uniform_key, beta_key = random.split(key, 3)

    # --- Truncated normal via rejection sampling ---
    mean = (omega_max + omega_min) / 2
    std = 2  # Smaller std = tighter around the centre

    samples = []
    while len(samples) < n_samples:
        # Fresh subkey for every proposal batch; `normal_key` itself is never
        # used for drawing.
        normal_key, draw_key = random.split(normal_key)
        proposed = random.normal(draw_key, shape=(n_samples,)) * std + mean
        proposed = np.array(proposed)  # NumPy for boolean masking
        accepted = proposed[(proposed >= omega_min) & (proposed <= omega_max)]
        samples.extend(accepted.tolist())

    omega_normal = np.array(samples[:n_samples])

    # --- Uniform ---
    omega_uniform = random.uniform(
        uniform_key, shape=(n_samples,), minval=omega_min, maxval=omega_max
    )

    # --- Beta(0.4, 0.4), rescaled from [0, 1] to [omega_min, omega_max] ---
    alpha = 0.4
    beta = 0.4
    beta_samples = random.beta(beta_key, alpha, beta, shape=(n_samples,))
    omega_beta = omega_min + (omega_max - omega_min) * beta_samples

    return omega_normal, omega_uniform, omega_beta

# Interval on which all omega samples are generated
omega_min = -1.0
omega_max = 3 * np.pi

# Draw the three omega sample sets from a fixed seed
n_samples = 5000
random_key = random.PRNGKey(42)  # Random key for JAX
omega_normal, omega_uniform, omega_beta = generate_omega(
    omega_min, omega_max, n_samples, random_key
)

# --- Overlay the three histograms (outline only) in a single figure ---
plt.figure(figsize=(10, 6))

for hist_samples, hist_label in (
    (omega_normal, "Normal"),
    (omega_uniform, "Uniform"),
    (omega_beta, "Beta(0.4, 0.4)"),
):
    plt.hist(
        hist_samples,
        bins=50,
        density=True,
        histtype="step",
        linewidth=2,
        label=hist_label,
    )

plt.title("Comparison of Omega Distributions", fontsize=14)
plt.xlabel("Omega", fontsize=12)
plt.ylabel("Density", fontsize=12)
plt.grid(True, linestyle="--", alpha=0.6)
plt.legend(fontsize=12)
plt.show()


In [None]:

# Inputs x: n_samples uniform draws between -1 and 1, seeded with its own key (seed 2)
key_x = random.PRNGKey(2)
x = random.uniform(key_x, shape=(n_samples,), minval=-1, maxval=1)

# Targets y = sin(omega * x), one set per omega distribution
y_normal, y_uniform, y_beta = (
    jnp.sin(omega_values * x)
    for omega_values in (omega_normal, omega_uniform, omega_beta)
)


In [None]:

def sample_training_data(omega, x, y, n_batch, iter):
    """Pick one random mini-batch, using the same indices for all three arrays.

    Args:
        omega (jnp.ndarray): Omega value of every sample.
        x (jnp.ndarray): Input value of every sample.
        y (jnp.ndarray): Target value of every sample.
        n_batch (int): Number of samples to pick (drawn without replacement).
        iter (int): Current iteration; used as the PRNG seed, so the same
            iteration always yields the same batch.

    Returns:
        Tuple ``(omega_batch, x_batch, y_batch)`` holding the selected entries.
    """
    # Seed derived from the iteration number
    batch_key = random.PRNGKey(iter)

    # Distinct positions in [0, len(x))
    picked = random.choice(batch_key, len(x), shape=(n_batch,), replace=False)

    return omega[picked], x[picked], y[picked]

def initialize_DON(key=None):
    """Create a DeepONet whose branch and trunk nets both use layers [1, 50, 50, 10].

    Args:
        key: JAX PRNG key for the weight initialisation; ``random.PRNGKey(0)``
            is used when omitted.

    Returns:
        The newly constructed ``DeepONet``.
    """
    init_key = random.PRNGKey(0) if key is None else key

    # Scalar input, two hidden layers of width 50, 10 output features
    architecture = (1, 50, 50, 10)
    return DeepONet(list(architecture), list(architecture), init_key)

def train_DON(nn,omega,x,y,max_iter,batch_sampling=False):
    """Train a DeepONet with Adam on a combined derivative-residual and data loss.

    The loss is the mean of ``(du/dx - omega * cos(omega * x)) ** 2`` plus the
    mean of ``(u - y) ** 2``, where ``u = nn.forward(params, omega, x)`` and
    ``du/dx`` is obtained by differentiating ``nn.forward_squeeze`` with
    respect to its trunk input.

    Args:
        nn: DeepONet to train; its branch/trunk parameters are overwritten
            after every optimisation step.
        omega: 1-D array of branch inputs.
        x: 1-D array of trunk inputs, same length as ``omega``.
        y: 1-D array of targets, same length as ``omega``.
        max_iter: Maximum number of optimisation steps.
        batch_sampling: When True, each step uses a random tenth of the data
            (selected by ``sample_training_data`` seeded with the step index);
            otherwise every step uses the full data set.

    Returns:
        Tuple ``(nn, losses)`` with the same (now trained) network object and
        the list of per-step loss values.
    """
    # Convert once so NumPy inputs are not re-transferred on every step and
    # index cleanly with JAX index arrays during batch sampling.
    omega, x, y = (jnp.asarray(values) for values in (omega, x, y))

    # du/dx: derivative of the scalar network output w.r.t. the trunk input.
    u_x = grad(nn.forward_squeeze, argnums=2)

    def squared_residual(params, omega, x):
        return (u_x(params, omega, x) - omega * jnp.cos(omega * x)) ** 2

    def squared_error(params, omega, x, y):
        return (nn.forward(params, omega, x) - y) ** 2

    # Per-sample terms mapped over the data axis; no inner jit is needed
    # because everything is traced inside the jitted train_step.
    squared_residual_batch = vmap(squared_residual, in_axes=(None, 0, 0))
    squared_errors_batch = vmap(squared_error, in_axes=(None, 0, 0, 0))

    def loss_fn(params, omega, x, y):
        return jnp.mean(squared_residual_batch(params, omega, x)) + jnp.mean(squared_errors_batch(params, omega, x, y))

    # Adam with a piecewise-constant learning rate: starts at 0.01 and is
    # multiplied by 0.2 at step 300 and by 0.5 at step 2000.
    learning_rate_schedule = optax.piecewise_constant_schedule(
        init_value=0.01,
        boundaries_and_scales={300: 0.2, 2000: 0.5}
    )
    optimizer = optax.adam(learning_rate=learning_rate_schedule)

    @jit
    def train_step(params, opt_state, omega, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, omega, x, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, loss

    don_params = nn.params()
    opt_state = optimizer.init(don_params)
    losses = []

    # Mini-batch size comes from the data actually passed in, not from a
    # notebook-level global, so the function works for any data set size.
    batch_ratio = 10
    batch_size = max(1, round(len(x) / batch_ratio))

    pbar = trange(max_iter, desc="Training", leave=True)
    for step in pbar:
        if batch_sampling:
            batch = sample_training_data(omega, x, y, batch_size, step)
        else:
            batch = (omega, x, y)

        don_params, opt_state, current_loss = train_step(don_params, opt_state, *batch)

        # Keep the network object in sync after every step so it holds the
        # latest parameters even if training is interrupted.
        nn.branch_net.params = don_params[0]
        nn.trunk_net.params = don_params[1]
        losses.append(current_loss)

        if step % 100 == 0:
            pbar.set_postfix({'loss': float(current_loss)})
        if current_loss < 1.0e-5:
            break
    # Closing explicitly finalises the bar when the loop stops early.
    pbar.close()

    return nn, losses


In [None]:
# Iteration budget shared by the three training runs
max_iterations = 5000

# One DeepONet per omega distribution; each model is created by
# initialize_DON() without an explicit key and trained on the full data set.
nn_normal, losses_normal = train_DON(
    initialize_DON(), omega_normal, x, y_normal, max_iterations
)

nn_uniform, losses_uniform = train_DON(
    initialize_DON(), omega_uniform, x, y_uniform, max_iterations
)

nn_beta, losses_beta = train_DON(
    initialize_DON(), omega_beta, x, y_beta, max_iterations
)

In [None]:
# Evaluation grid for x
x_test = jnp.linspace(-1.0, 1.0, 1000)

# Equally spaced test frequencies, reaching `delta` beyond both ends of
# [omega_min, omega_max]
n_test_omegas = 50
delta = 1
omega_tests = jnp.linspace(omega_min - delta, omega_max + delta, n_test_omegas)

# One MSE value per test omega and model
mse_normal_results, mse_uniform_results, mse_beta_results = [], [], []

for omega_test in omega_tests:
    omega_test_array = jnp.ones_like(x_test) * omega_test
    y_test = jnp.sin(omega_test_array * x_test)

    for trained_model, mse_results in (
        (nn_normal, mse_normal_results),
        (nn_uniform, mse_uniform_results),
        (nn_beta, mse_beta_results),
    ):
        prediction = trained_model.predict(omega_test_array, x_test).reshape(
            y_test.shape
        )
        mse_results.append(np.mean((y_test - prediction) ** 2))


# Side-by-side figure: omega histograms (left) and MSE per test omega (right)

colors = ["b", "g", "r", "c", "m"]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# --- Left panel: omega distributions ---
for hist_samples, hist_label, hist_color in zip(
    (omega_normal, omega_uniform, omega_beta),
    ("Normal", "Uniform", "Beta"),
    colors,
):
    ax1.hist(hist_samples, bins=50, density=True, histtype="step",
             linewidth=2, label=hist_label, color=hist_color)

ax1.set_xlabel("Omega", fontsize=12)
ax1.set_ylabel("Density", fontsize=12)
ax1.set_title("Omega Distributions", fontsize=14)
ax1.legend(fontsize=12)
ax1.grid(True, linestyle="--", alpha=0.6)

# --- Right panel: MSE on a logarithmic y axis ---
for mse_curve, curve_label, curve_color in zip(
    (mse_normal_results, mse_uniform_results, mse_beta_results),
    ("Normal", "Uniform", "Beta"),
    colors,
):
    ax2.semilogy(omega_tests, mse_curve, "o-", linewidth=2,
                 markersize=5, label=curve_label, color=curve_color)

# Dotted vertical lines mark omega_min and omega_max
for boundary in (omega_min, omega_max):
    ax2.axvline(boundary, color='k', linestyle=':', alpha=0.7, linewidth=1.5)

ax2.set_xlabel("Omega", fontsize=12)
ax2.set_ylabel("MSE (log scale)", fontsize=12)
ax2.set_title("Prediction Error (Log Scale)", fontsize=14)
ax2.legend(fontsize=12)
ax2.grid(True, linestyle="--", alpha=0.6, which='both')

# Minor ticks and a fainter minor grid for the log axis
ax2.minorticks_on()
ax2.grid(which='minor', linestyle=':', alpha=0.4)

plt.tight_layout()
plt.show()


In [None]:

# Four stacked panels: training losses, then predictions of each model
fig, axs = plt.subplots(4, 1, figsize=(14, 20))

# Panel 0: loss per iteration on a logarithmic y axis
axs[0].set_yscale("log")
for loss_history, loss_label in (
    (losses_normal, "Loss (normal)"),
    (losses_uniform, "Loss (uniform)"),
    (losses_beta, "Loss (beta)"),
):
    axs[0].plot(range(len(loss_history)), loss_history, label=loss_label)
axs[0].set_xlabel("Iteration")
axs[0].set_ylabel("Loss")
axs[0].legend()

# Five equally spaced test frequencies on [omega_min, omega_max]; the second
# one is overwritten with 0
omega_tests = jnp.linspace(omega_min, omega_max, 5).at[1].set(0)

colors = ["b", "g", "r", "c", "m"]

for omega_test, color in zip(omega_tests, colors):
    omega_test_array = jnp.ones_like(x_test) * omega_test
    y_test = jnp.sin(omega_test_array * x_test)

    for panel, trained_model in zip(axs[1:], (nn_normal, nn_uniform, nn_beta)):
        # Reference sine curve (dashed)
        panel.plot(
            x_test,
            y_test,
            color=color,
            linestyle="--",
            label=f"Sine function (ω={float(omega_test):.2f})",  # Two decimals
        )

        # Network prediction (solid)
        panel.plot(
            x_test,
            trained_model.predict(omega_test_array, x_test).reshape(y_test.shape),
            color=color,
            label=f"Neural network (ω={float(omega_test):.2f})",  # Two decimals
        )

# Legends sit to the right of each prediction panel
for panel in axs[1:]:
    panel.legend(loc='center left', bbox_to_anchor=(1, 0.5))

# Panel titles
for panel, panel_title in zip(axs, ("Loss", "Normal", "Uniform", "Beta")):
    panel.set_title(panel_title)

plt.show()

In [None]:
# Repeat the data generation + training experiment several times and average
# the test MSE per omega across repetitions.

# Evaluation grid for x
x_test = jnp.linspace(-1.0, 1.0, 1000)

# Equally spaced test frequencies, reaching `delta` beyond both ends of
# [omega_min, omega_max]
n_test_omegas = 50
delta = 1
omega_tests = jnp.linspace(omega_min - delta, omega_max + delta, n_test_omegas)

# Number of repetitions for training
num_repeats = 5

# Per-repetition MSE curves (one list of n_test_omegas values per repetition)
all_mse_normal_results = []
all_mse_uniform_results = []
all_mse_beta_results = []

n_samples = 500
max_iterations = 4000

# Base seed; repetition i uses seed key_adjust + i
key_adjust = 42

for i in range(num_repeats):
    print(f"Training repetition {i + 1}/{num_repeats}")

    # Derive every random stream of this repetition from one key. Drawing x
    # from the same key as the omega samples would feed identical random bits
    # into both, making x a deterministic function of the omega proposals.
    random_key = random.PRNGKey(key_adjust + i)
    omega_key, key_x, init_key_normal, init_key_uniform, init_key_beta = random.split(
        random_key, 5
    )

    omega_normal, omega_uniform, omega_beta = generate_omega(
        omega_min, omega_max, n_samples, omega_key
    )

    # Random x values between -1 and 1, from their own subkey
    x = random.uniform(key_x, shape=(n_samples,), minval=-1, maxval=1)

    # Targets for each omega distribution
    y_normal = jnp.sin(omega_normal * x)
    y_uniform = jnp.sin(omega_uniform * x)
    y_beta = jnp.sin(omega_beta * x)

    # Train one DeepONet per distribution, each from its own initialisation key
    nn_normal = initialize_DON(init_key_normal)
    nn_normal, losses_normal = train_DON(
        nn_normal, omega_normal, x, y_normal, max_iterations
    )

    nn_uniform = initialize_DON(init_key_uniform)
    nn_uniform, losses_uniform = train_DON(
        nn_uniform, omega_uniform, x, y_uniform, max_iterations
    )

    nn_beta = initialize_DON(init_key_beta)
    nn_beta, losses_beta = train_DON(nn_beta, omega_beta, x, y_beta, max_iterations)

    # MSE against sin(omega * x) on the test grid, per test omega and model
    mse_normal_results = []
    mse_uniform_results = []
    mse_beta_results = []

    for omega_test in omega_tests:
        omega_test_array = jnp.ones_like(x_test) * omega_test
        y_test = jnp.sin(omega_test_array * x_test)

        for trained_model, mse_results in (
            (nn_normal, mse_normal_results),
            (nn_uniform, mse_uniform_results),
            (nn_beta, mse_beta_results),
        ):
            prediction = trained_model.predict(omega_test_array, x_test).reshape(
                y_test.shape
            )
            mse_results.append(np.mean((y_test - prediction) ** 2))

    # Store results for this repetition
    all_mse_normal_results.append(mse_normal_results)
    all_mse_uniform_results.append(mse_uniform_results)
    all_mse_beta_results.append(mse_beta_results)

# Mean MSE across repetitions for each test omega
mean_mse_normal = np.mean(all_mse_normal_results, axis=0)
mean_mse_uniform = np.mean(all_mse_uniform_results, axis=0)
mean_mse_beta = np.mean(all_mse_beta_results, axis=0)

# Mean MSE per test omega for the three models, on a logarithmic y axis

colors = ["b", "g", "r"]

fig, ax = plt.subplots(figsize=(10, 6))

for mean_curve, curve_label, curve_color in zip(
    (mean_mse_normal, mean_mse_uniform, mean_mse_beta),
    ("Normal", "Uniform", "Beta"),
    colors,
):
    ax.semilogy(
        omega_tests,
        mean_curve,
        "o-",
        linewidth=2,
        markersize=5,
        label=curve_label,
        color=curve_color,
    )

# Dotted vertical lines mark omega_min and omega_max
for boundary in (omega_min, omega_max):
    ax.axvline(boundary, color="k", linestyle=":", alpha=0.7, linewidth=1.5)

ax.set_xlabel("Omega", fontsize=12)
ax.set_ylabel("Mean MSE (log scale)", fontsize=12)
ax.set_title("Mean Prediction Error (Log Scale)", fontsize=14)
ax.legend(fontsize=12)
ax.grid(True, linestyle="--", alpha=0.6, which="both")

# Minor ticks and a fainter minor grid for the log axis
ax.minorticks_on()
ax.grid(which="minor", linestyle=":", alpha=0.4)

plt.tight_layout()
plt.show()