In [None]:
import jax
from jax import random, vmap, jit, grad, value_and_grad
import jax.numpy as jnp
import optax
from functools import partial
from tqdm import trange
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class FeedForwardNN:
    """Fully connected network whose parameters are a list of ``(W, b)`` tuples.

    Every layer except the last applies ``activation_fn`` after its affine
    map; the last layer is affine only.
    """

    def __init__(self, layer_sizes, key, activation_fn=jax.nn.tanh):
        """Record the architecture and draw the initial parameters from ``key``."""
        self.layer_sizes = layer_sizes
        self.activation_fn = activation_fn
        self.params = self.initialize_params(layer_sizes, key)

    def initialize_params(self, layer_sizes, key):
        """Build one ``(W, b)`` tuple per pair of consecutive layer sizes.

        Weights are drawn uniformly from ``[-limit, limit]`` with
        ``limit = sqrt(6 / (in_dim + out_dim))`` (Xavier/Glorot uniform);
        biases start at zero.
        """
        layer_keys = random.split(key, len(layer_sizes) - 1)
        fan_pairs = zip(layer_sizes[:-1], layer_sizes[1:])

        layers = []
        for layer_key, (in_dim, out_dim) in zip(layer_keys, fan_pairs):
            # Only the first half of the split is consumed; biases are deterministic.
            weight_key = random.split(layer_key)[0]
            limit = jnp.sqrt(6 / (in_dim + out_dim))
            weights = random.uniform(
                weight_key, (in_dim, out_dim), minval=-limit, maxval=limit
            )
            biases = jnp.zeros(out_dim)
            layers.append((weights, biases))
        return layers

    @partial(jit, static_argnums=(0,))
    def forward(self, params, x):
        """Evaluate the network on ``x`` using the explicitly supplied ``params``."""
        hidden_layers, output_layer = params[:-1], params[-1]

        # Hidden layers: affine map followed by the activation.
        for weights, biases in hidden_layers:
            x = self.activation_fn(jnp.dot(x, weights) + biases)

        # Output layer: affine map only.
        out_weights, out_biases = output_layer
        return jnp.dot(x, out_weights) + out_biases

    def predict(self, x):
        """Apply ``forward`` with the stored parameters, mapped over axis 0 of ``x``."""
        batched_forward = vmap(self.forward, in_axes=(None, 0))
        return batched_forward(self.params, x)

In [None]:
class DeepONet:
    """Pairs a branch network with a trunk network and combines their outputs.

    Both sub-networks are ``FeedForwardNN`` instances; the combined output is
    ``jnp.dot`` of the branch output with the transposed trunk output.
    """

    def __init__(self, branch_layer_sizes, trunk_layer_sizes, key, activation_fn=jax.nn.tanh):
        """Create the two sub-networks, each seeded with its own half of ``key``."""
        key_for_branch, key_for_trunk = random.split(key)
        self.branch_net = FeedForwardNN(branch_layer_sizes, key_for_branch, activation_fn)
        self.trunk_net = FeedForwardNN(trunk_layer_sizes, key_for_trunk, activation_fn)

    def params(self):
        """Return the current parameters as ``[branch_params, trunk_params]``."""
        return [self.branch_net.params, self.trunk_net.params]

    @partial(jit, static_argnums=(0,))
    def forward(self, params, branch_input, trunk_input):
        """Evaluate the operator network with explicit ``params``.

        ``params[0]`` is passed to the branch network and ``params[1]`` to the
        trunk network.
        """
        branch_features = self.branch_net.forward(params[0], branch_input)
        trunk_features = self.trunk_net.forward(params[1], trunk_input)
        # Inner product between the two feature vectors.
        return jnp.dot(branch_features, trunk_features.T)

    @partial(jit, static_argnums=(0,))
    def forward_squeeze(self, params, branch_input, trunk_input):
        """Same as ``forward`` but with all size-1 axes removed from the result."""
        return jnp.squeeze(self.forward(params, branch_input, trunk_input))

    def predict(self, branch_input, trunk_input):
        """Apply ``forward`` with the stored parameters, mapped over axis 0 of both inputs."""
        batched_forward = vmap(self.forward, in_axes=(None, 0, 0))
        return batched_forward(self.params(), branch_input, trunk_input)

In [None]:

# # Parameter omega
def generate_omega(omega_min, omega_max, n_samples, key, std=2.0, alpha=0.4, beta=0.4):
    """Draw three sets of ``n_samples`` frequencies inside ``[omega_min, omega_max]``.

    ``key`` is split into one independent sub-key per distribution, so no
    PRNG key is ever consumed twice and the three sample sets are not
    derived from shared random bits.

    Args:
        omega_min: Lower bound of the frequency range.
        omega_max: Upper bound; must be strictly greater than ``omega_min``.
        n_samples: Number of samples to return per distribution.
        key: JAX PRNG key. It is only split, never used for sampling directly.
        std: Standard deviation of the normal distribution, which is centred
            on the midpoint of the range (smaller = tighter around the centre).
        alpha: First shape parameter of the Beta distribution.
        beta: Second shape parameter of the Beta distribution.

    Returns:
        Tuple ``(omega_normal, omega_uniform, omega_beta)``. ``omega_normal``
        is a NumPy array produced by rejection sampling; the other two are
        JAX arrays.

    Raises:
        ValueError: If ``omega_min`` is not strictly smaller than
            ``omega_max`` (the rejection loop could otherwise never finish).
    """
    if not omega_min < omega_max:
        raise ValueError(
            f"omega_min must be smaller than omega_max, got {omega_min} and {omega_max}"
        )

    normal_key, uniform_key, beta_key = random.split(key, 3)

    # --- Normal distribution truncated to the range (rejection sampling) ---
    mean = (omega_max + omega_min) / 2
    samples = []
    while len(samples) < n_samples:
        # Fresh sub-key for every batch of proposals.
        normal_key, draw_key = random.split(normal_key)
        proposed = random.normal(draw_key, shape=(n_samples,)) * std + mean
        proposed = np.array(proposed)  # NumPy supports the boolean masking below
        inside = (proposed >= omega_min) & (proposed <= omega_max)
        samples.extend(proposed[inside].tolist())
    omega_normal = np.array(samples[:n_samples])

    # --- Uniform distribution over the range ---
    omega_uniform = random.uniform(
        uniform_key, shape=(n_samples,), minval=omega_min, maxval=omega_max
    )

    # --- Beta distribution on [0, 1], rescaled to the range ---
    beta_samples = random.beta(beta_key, alpha, beta, shape=(n_samples,))
    omega_beta = omega_min + (omega_max - omega_min) * beta_samples

    return omega_normal, omega_uniform, omega_beta

# Frequency range used for sampling omega.
omega_min = -1.0
omega_max = 3 * np.pi

random_key = random.PRNGKey(42)  # Random key for JAX
n_samples = 1000
omega_normal, omega_uniform, omega_beta = generate_omega(omega_min, omega_max, n_samples, random_key)

# --- Plot all three histograms in one figure ---
plt.figure(figsize=(10, 6))

# Plot histograms with histtype='step' (outline only)
plt.hist(omega_normal, bins=50, density=True, histtype='step', linewidth=2, label='Normal')
plt.hist(omega_uniform, bins=50, density=True, histtype='step', linewidth=2, label='Uniform')
# The label states the shape parameters generate_omega actually uses (0.4, 0.4).
plt.hist(omega_beta, bins=50, density=True, histtype='step', linewidth=2, label='Beta(0.4, 0.4)')

plt.title('Comparison of Omega Distributions', fontsize=14)
plt.xlabel('Omega', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend(fontsize=12)
plt.show()


In [None]:

def initialize_DON(n_input_sensor, key=None):
    """Create a freshly initialised DeepONet for ``n_input_sensor`` sensor values.

    Args:
        n_input_sensor: Size of the branch network's input layer.
        key: Optional JAX PRNG key for the weight initialisation. When
            omitted, ``random.PRNGKey(0)`` is used, so repeated calls return
            identically initialised networks.

    Returns:
        A ``DeepONet`` whose branch layers are ``[n_input_sensor, 50, 50, 10]``
        and whose trunk layers are ``[1, 50, 50, 10]``.
    """
    if key is None:
        key = random.PRNGKey(0)

    branch_layer_sizes = [n_input_sensor, 50, 50, 10]
    trunk_layer_sizes = [1, 50, 50, 10]
    return DeepONet(branch_layer_sizes, trunk_layer_sizes, key)

def train_DON(nn, u, omega, x, y, max_iter, tol=1.0e-5):
    """Train ``nn`` in place with Adam on a combined residual + data loss.

    The loss is the mean squared residual of
    ``d/dx nn(u, x) - omega * cos(omega * x)`` plus the mean squared error
    between ``nn(u, x)`` and ``y``. Every per-sample term is mapped over
    axis 0 of ``u``, ``omega``, ``x`` and ``y``.

    Args:
        nn: Network exposing ``params()``, ``forward``, ``forward_squeeze``
            and writable ``branch_net.params`` / ``trunk_net.params``.
        u: Branch inputs, one entry per sample along axis 0.
        omega: Frequencies, one per sample.
        x: Trunk inputs, one per sample. Assumed to be a scalar per sample so
            that ``grad`` with respect to it is defined -- TODO confirm.
        y: Target values, one per sample.
        max_iter: Maximum number of optimisation steps.
        tol: Training stops early once the loss falls below this value.

    Returns:
        Tuple ``(nn, losses)``: the same network object with updated
        parameters, and the loss of every executed step as Python floats.
    """
    # Derivative of the squeezed network output with respect to the trunk input.
    u_x = grad(nn.forward_squeeze, argnums=2)

    # The helpers below are only called inside the jitted train_step, so they
    # need no jit wrapper of their own.
    def squared_residual(params, u, omega, x):
        return (u_x(params, u, x) - omega * jnp.cos(omega * x)) ** 2

    def squared_error(params, u, x, y):
        return (nn.forward(params, u, x) - y) ** 2

    squared_residual_batch = vmap(squared_residual, in_axes=(None, 0, 0, 0))
    squared_errors_batch = vmap(squared_error, in_axes=(None, 0, 0, 0))

    def loss_fn(params, u, omega, x, y):
        residual_term = jnp.mean(squared_residual_batch(params, u, omega, x))
        data_term = jnp.mean(squared_errors_batch(params, u, x, y))
        return residual_term + data_term

    # Adam with a piecewise-constant learning rate: 0.01, scaled by 0.2 at
    # step 300 and by a further 0.5 at step 2000.
    learning_rate_schedule = optax.piecewise_constant_schedule(
        init_value=0.01,
        boundaries_and_scales={300: 0.2, 2000: 0.5}
    )
    optimizer = optax.adam(learning_rate=learning_rate_schedule)

    @jit
    def train_step(params, opt_state, u, omega, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, u, omega, x, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, loss

    params = nn.params()
    opt_state = optimizer.init(params)
    losses = []

    # Training loop with tqdm progress bar
    pbar = trange(max_iter, desc="Training", leave=True)
    for epoch in pbar:
        params, opt_state, loss_value = train_step(params, opt_state, u, omega, x, y)
        # Write back every step so nn always holds the latest parameters.
        nn.branch_net.params = params[0]
        nn.trunk_net.params = params[1]

        # Convert once to a Python float: keeps the history free of device
        # arrays and gives tqdm a plain number to format.
        current_loss = float(loss_value)
        losses.append(current_loss)
        if epoch % 100 == 0:
            pbar.set_postfix({'loss': current_loss})
        if current_loss < tol:
            break

    return nn, losses


In [None]:
# Train on the Beta-distributed frequencies.
omega = omega_beta

# Trunk inputs: x drawn uniformly from [-1, 1] with a key of its own.
key_x = random.PRNGKey(4)
x = random.uniform(key_x, shape=(n_samples,), minval=-1, maxval=1)

# Targets: y = sin(omega * x).
y = jnp.sin(omega * x)

# Sensor locations for the branch input. With very few sensors the points
# should not be symmetric about zero, hence the asymmetric interval.
n_input_sensor = 2
x_sensor_min = -0.7
x_sensor_max = 0.9
x_input_sensor = np.linspace(x_sensor_min, x_sensor_max, n_input_sensor)

# Branch input: omega * cos(omega * x_sensor), one row per sample.
u = omega[:, None] * jnp.cos(omega[:, None] * x_input_sensor)

max_iterations = 5000

nn = initialize_DON(n_input_sensor)
nn, losses = train_DON(nn, u, omega, x, y, max_iterations)


In [None]:
# Evaluation grid in x.
x_test = jnp.linspace(-1.0, 1.0, 1000)

# Equally spaced test frequencies, extending `delta` beyond the training range
# on both sides.
n_test_u = 50
delta = 1
omega_tests = jnp.linspace(omega_min - delta, omega_max + delta, n_test_u)

# Branch inputs for the test frequencies, one row per frequency.
u_tests = omega_tests[:, None] * jnp.cos(omega_tests[:, None] * x_input_sensor)

# Mean squared error over the x grid, one value per test frequency.
mse_results = []
for omega_value, u_row in zip(omega_tests, u_tests):
    y_test = jnp.sin(omega_value * x_test)
    # Repeat the sensor readings so every x gets the same branch input.
    u_test = jnp.tile(u_row, (len(x_test), 1))
    y_pred = nn.predict(u_test, x_test).reshape(y_test.shape)
    mse_results.append(np.mean((y_test - y_pred) ** 2))

colors = ["b", "g", "r", "c", "m"]

# Two panels: training distribution on the left, test error on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# --- Left panel: distribution of the training frequencies ---
ax1.hist(
    omega,
    bins=50,
    density=True,
    histtype="step",
    linewidth=2,
    color=colors[0],
)
ax1.set_title("Omega Distribution", fontsize=14)
ax1.set_xlabel("Omega", fontsize=12)
ax1.set_ylabel("Density", fontsize=12)
ax1.grid(True, linestyle="--", alpha=0.6)

# --- Right panel: MSE per test frequency on a log axis ---
ax2.semilogy(
    omega_tests,
    mse_results,
    "o-",
    linewidth=2,
    markersize=5,
    color=colors[0],
)

# Dotted vertical lines mark the edges of the training range.
for boundary in (omega_min, omega_max):
    ax2.axvline(boundary, color='k', linestyle=':', alpha=0.7, linewidth=1.5)

ax2.set_title("Prediction Error (Log Scale)", fontsize=14)
ax2.set_xlabel("Omega", fontsize=12)
ax2.set_ylabel("MSE (log scale)", fontsize=12)
ax2.grid(True, linestyle="--", alpha=0.6, which='both')

# Finer grid lines for the log scale.
ax2.minorticks_on()
ax2.grid(which='minor', linestyle=':', alpha=0.4)

plt.tight_layout()
plt.show()


In [None]:

# Two stacked panels: training loss on top, predicted vs. true solutions below.
fig, axs = plt.subplots(2, 1, figsize=(14, 10))

# --- Top panel: loss history on a log axis ---
axs[0].set_title("Loss")
axs[0].set_yscale("log")
axs[0].plot(range(len(losses)), losses, label="Loss")
axs[0].set_xlabel("Iteration")
axs[0].set_ylabel("Loss")
axs[0].legend()

# --- Bottom panel: five frequencies spanning the training range ---
omega_tests = jnp.linspace(omega_min, omega_max, 5)
u_tests = omega_tests[:, None] * jnp.cos(omega_tests[:, None] * x_input_sensor)

colors = ["b", "g", "r", "c", "m"]

for i, (omega_test, color) in enumerate(zip(omega_tests, colors)):
    y_test = jnp.sin(omega_test * x_test)
    # Same sensor readings for every x on the grid.
    u_test = jnp.tile(u_tests[i], (len(x_test), 1))
    y_pred = nn.predict(u_test, x_test).reshape(y_test.shape)

    # Reference solution (dashed).
    axs[1].plot(
        x_test,
        y_test,
        color=color,
        linestyle="--",
        label=f"Sine function (ω={float(omega_test):.2f})",
    )

    # Network prediction (solid).
    axs[1].plot(
        x_test,
        y_pred,
        color=color,
        label=f"Neural network (ω={float(omega_test):.2f})",
    )

# Dotted vertical lines mark the sensor locations.
for x_sensor in x_input_sensor:
    axs[1].axvline(x=x_sensor, color='k', linestyle=':', alpha=0.7, linewidth=1.5)

axs[1].set_title("Solutions")
# Legend sits outside the axes, to the right.
axs[1].legend(loc='center left', bbox_to_anchor=(1, 0.5))

plt.show()

In [None]:
# Base seed: repeat j derives all of its random data from PRNGKey(key_adjust + j).
key_adjust = 42

# Evaluation grid in x.
x_test = jnp.linspace(-1.0, 1.0, 1000)

# Equally spaced test frequencies, extending `delta` beyond the training range.
n_test_u = 50
delta = 1
omega_tests = jnp.linspace(omega_min - delta, omega_max + delta, n_test_u)

# Sensor counts to compare, and the training budget per run.
n_input_sensor_list = [1, 2, 4, 8]
max_iterations = 4000

# Number of independent training repetitions per sensor count.
num_repeats = 5

mean_mse_results = []

for i, n_input_sensor in enumerate(n_input_sensor_list):
    print(f"Sensor configuration {i + 1}/{len(n_input_sensor_list)} (n_input_sensor={n_input_sensor})")

    # Sensor grid and test branch inputs depend only on the sensor count,
    # so they are computed once per configuration.
    x_input_sensor = np.linspace(x_sensor_min, x_sensor_max, n_input_sensor)
    u_tests = jnp.cos(omega_tests[:, None] * x_input_sensor) * omega_tests[:, None]

    # MSE curves of every repeat for this sensor count.
    all_mse_results = []

    for j in range(num_repeats):
        print(f"Repetition {j + 1}/{num_repeats}")

        # Split the per-repeat key so omega and x never share a PRNG key.
        random_key, key_x = random.split(random.PRNGKey(key_adjust + j))

        omega_normal, omega_uniform, omega_beta = generate_omega(omega_min, omega_max, n_samples, random_key)
        omega = omega_beta

        # Trunk inputs drawn uniformly from [-1, 1] and matching targets.
        x = random.uniform(key_x, shape=(n_samples,), minval=-1, maxval=1)
        y = jnp.sin(omega * x)

        # Branch inputs for the training frequencies.
        u = jnp.cos(omega[:, None] * x_input_sensor) * omega[:, None]

        nn = initialize_DON(n_input_sensor)
        nn, losses = train_DON(nn, u, omega, x, y, max_iterations)

        # Distinct index name: the repeat counter `j` is no longer shadowed.
        mse_results = []
        for k in range(n_test_u):
            y_test = jnp.sin(omega_tests[k] * x_test)
            u_test = jnp.tile(u_tests[k], (len(x_test), 1))
            y_pred = nn.predict(u_test, x_test).reshape(y_test.shape)
            mse_results.append(float(jnp.mean((y_test - y_pred) ** 2)))

        # Save the MSE results for this repeat
        all_mse_results.append(mse_results)

    # Average over repeats, keeping one value per test frequency.
    mean_mse_results.append(np.mean(all_mse_results, axis=0))

# Plot all MSE results together
colors = ["b", "g", "r", "c", "m"]

plt.figure(figsize=(10, 6))

for i, mean_mse_result in enumerate(mean_mse_results):
    plt.semilogy(
        omega_tests,
        mean_mse_result,
        label=f"Input sensor points={n_input_sensor_list[i]}",
        color=colors[i % len(colors)],
        marker="o",
        linewidth=2,
        markersize=5,
    )

plt.xlabel("Omega", fontsize=12)
plt.ylabel("MSE (log scale)", fontsize=12)
plt.title("MSE results for different numbers of sensor points", fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, linestyle="--", alpha=0.6, which="both")

# Add vertical dotted lines at boundaries
plt.axvline(omega_min, color='k', linestyle=':', alpha=0.7, linewidth=1.5)
plt.axvline(omega_max, color='k', linestyle=':', alpha=0.7, linewidth=1.5)

plt.tight_layout()
plt.show()