In [12]:
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import jax
import optax

from jax import jacfwd

import matplotlib.pyplot as plt

In [13]:
import sys
import os

# Absolute path of the sibling `collocation2` directory, resolved against the
# current working directory.
collocation2_path = os.path.abspath(os.path.join(os.pardir, 'collocation2'))

# Register it on the import path only once, so re-running this cell does not
# pile up duplicate entries.
if sys.path.count(collocation2_path) == 0:
    sys.path.append(collocation2_path)

from interpolation import BarycentricInterpolation

In [14]:
# Define a simple neural network with Flax
class SimpleNN(nn.Module):
    """Fully connected network: Dense(128) -> ReLU -> Dense(64) -> ReLU -> Dense(2)."""

    @nn.compact
    def __call__(self, x):
        """Apply the network to `x` and return the 2-feature output."""
        # Hidden stack: each dense layer is followed by a ReLU.
        for width in (128, 64):
            x = nn.relu(nn.Dense(width)(x))
        # Linear read-out with two features (assumed to be u and v).
        return nn.Dense(2)(x)

# Initialize the neural network
def create_model():
    """Return a new, un-initialised `SimpleNN` module instance."""
    return SimpleNN()

# Initialize parameters
def init_model_params(model, rng, input_shape):
    """Initialise `model` by calling `model.init` on an all-ones array.

    Parameters:
    model : object exposing ``init(rng, x)``
        The module to initialise.
    rng : PRNG key
        Passed through unchanged to ``model.init``.
    input_shape : tuple of int
        Shape of the all-ones dummy input used for initialisation.

    Returns:
    Whatever ``model.init`` returns (the parameter collection).
    """
    dummy_input = jnp.ones(input_shape)
    return model.init(rng, dummy_input)

# Define the forward function
def model_forward(params, x):
    """Evaluate a fresh `SimpleNN` on `x` using the supplied `params`."""
    network = SimpleNN()
    return network.apply(params, x)


In [15]:
# Create a training state
def create_train_state(rng, learning_rate, model, input_shape):
    """Build a `TrainState` holding freshly initialised parameters and an Adam optimiser.

    Parameters:
    rng : PRNG key
        Key forwarded to `init_model_params`.
    learning_rate : float
        Learning rate passed to `optax.adam`.
    model : module
        Its `apply` method becomes the state's `apply_fn`.
    input_shape : tuple of int
        Shape of the dummy input used to initialise the parameters.

    Returns:
    The object returned by `train_state.TrainState.create`.
    """
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_model_params(model, rng, input_shape),
        tx=optax.adam(learning_rate),
    )

# Define a loss function
def compute_loss(params, batch):
    """Mean squared error between the network output on batch['inputs'] and batch['targets'].

    Parameters:
    params
        Network parameters passed to `model_forward`.
    batch : dict
        Must provide the keys 'inputs' and 'targets'.

    Returns:
    Scalar mean of the squared residuals.
    """
    residual = model_forward(params, batch['inputs']) - batch['targets']
    return jnp.mean(residual ** 2)

# Training step
@jax.jit
def train_step(state, batch):
    """Run one jit-compiled optimisation step and return the updated state.

    The gradient of `compute_loss` with respect to `state.params` is evaluated on
    `batch` and handed to `state.apply_gradients`.
    """
    def batch_loss(params):
        # Close over `batch` so that only the parameters are differentiated.
        return compute_loss(params, batch)

    gradients = jax.grad(batch_loss)(state.params)
    return state.apply_gradients(grads=gradients)

# Dummy data for demonstration
rng = jax.random.PRNGKey(0)
input_shape = (10, 2)  # Example input shape
learning_rate = 0.001
num_epochs = 100

# Initialize model and state
model = create_model()
state = create_train_state(rng, learning_rate, model, input_shape)

# Example training loop
for epoch in range(num_epochs):
    # JAX random functions are deterministic in their key: drawing `inputs` and
    # `targets` from the same key yields two identical arrays, and reusing that
    # key every epoch yields the same batch every time. Derive a distinct key
    # per epoch (fold_in) and per array (split) instead; `rng` itself is left
    # untouched.
    inputs_key, targets_key = jax.random.split(jax.random.fold_in(rng, epoch))

    # Dummy batch for demonstration.
    # NOTE(review): the targets are pure noise, unrelated to the inputs; replace
    # them with the real right-hand-side samples before relying on the model.
    batch = {
        'inputs': jax.random.normal(inputs_key, input_shape),
        'targets': jax.random.normal(targets_key, input_shape),
    }
    state = train_step(state, batch)

# Trained parameters
trained_params = state.params


In [16]:
def lagrange_basis(t, k, t_grid):
    """Evaluate the k-th Lagrange basis polynomial for the nodes `t_grid` at `t`.

    Parameters:
    t : scalar
        Evaluation point.
    k : int
        Index of the node whose basis polynomial is evaluated.
    t_grid : sequence
        Interpolation nodes.

    Returns:
    The product over all m != k of (t - t_grid[m]) / (t_grid[k] - t_grid[m]).
    """
    node_k = t_grid[k]
    factors = []
    for m in range(len(t_grid)):
        if m == k:
            # The node's own factor is skipped (it would divide by zero).
            continue
        factors.append((t - t_grid[m]) / (node_k - t_grid[m]))
    return jnp.prod(jnp.array(factors))


def interpolate(u, t, t_grid):
    """Evaluate the Lagrange interpolant of the nodal values `u` at the point `t`.

    Parameters:
    u : array, 1-D of shape (n,) or 2-D of shape (n, d)
        Function values at the grid points; row k belongs to node k.
    t : scalar
        The point at which to interpolate.
    t_grid : array
        The interpolation nodes.

    Returns:
    A scalar when `u` is 1-D, otherwise an array of shape (d,) holding the
    interpolated value of every component.
    """
    # The basis weights depend only on (t, k, t_grid), not on the component of u,
    # so evaluate each one once and reuse it for every component instead of
    # recomputing the full set per output column.
    weights = jnp.array([lagrange_basis(t, k, t_grid) for k in range(u.shape[0])])

    if u.ndim == 1:
        return jnp.sum(weights * u)
    # Broadcast the weights over the component axis and reduce over the nodes.
    return jnp.sum(weights[:, None] * u, axis=0)


In [17]:
# Use the trained neural network as the function f
def neural_net_f(t, uv, params):
    """Right-hand side f(t, uv) given by the network.

    `t` is accepted so the function fits the f(t, state, params) call signature,
    but it is not used: the output depends only on `uv` and `params`.
    """
    rhs = model_forward(params, uv)
    return rhs

# Modify the solver to pass the neural network parameters
def newton_method_nn(u_init, t_grid, f, params, u0=1, tol=1e-6, max_iter=20):
    """
    Solves a system of nonlinear equations using Newton's method with a neural network.

    Each iteration evaluates the residual F(u) = system(u, ...), its Jacobian with
    respect to u (forward-mode autodiff), and applies the update obtained from
    J(u) Δu = -F(u).

    Parameters:
    u_init : ndarray
        Initial guess for the solution.
    t_grid : ndarray
        Grid points where the solution is evaluated.
    f : function
        The function defining the system of differential equations.
    params : dict
        Trained parameters of the neural network.
    u0 : float or ndarray, optional
        Initial condition, forwarded to `system` (default is 1).
        NOTE(review): the `system` defined in this notebook replaces this value
        with a hard-coded initial condition, so it currently has no effect —
        confirm the intent.
    tol : float, optional
        Tolerance for the norm of the update (default is 1e-6).
    max_iter : int, optional
        Maximum number of iterations (default is 20).

    Returns:
    ndarray
        The solution vector, with the same shape as `u_init`.

    Raises:
    ValueError
        If Newton's method did not converge within `max_iter` iterations, or the
        update became non-finite (e.g. a singular Jacobian).
    """
    u = jnp.asarray(u_init)

    # The Jacobian transform does not depend on the iterate; build it once.
    jacobian_fn = jacfwd(system, argnums=0)

    for i in range(max_iter):
        F_u = system(u, t_grid, f, u0, params)

        # Compute the Jacobian matrix of the system at u
        J_u = jacobian_fn(u, t_grid, f, u0, params)

        # Flatten so the Jacobian becomes a square (n, n) matrix matching the
        # flattened residual, whatever the shape of u is.
        F_u_flat = F_u.reshape(-1)
        J_u_flat = J_u.reshape(F_u_flat.shape[0], -1)

        # Solve for the update Δu in the linear system J(u) Δu = -F(u)
        delta_u_flat = jnp.linalg.solve(J_u_flat, -F_u_flat)

        # Reshape the update and apply
        delta_u = delta_u_flat.reshape(u.shape)
        u = u + delta_u

        norm_delta_u = jnp.linalg.norm(delta_u)
        if norm_delta_u < tol:
            print(f"Converged at iteration {i+1}")
            return u

        # A NaN/inf update has already contaminated u; further iterations cannot
        # recover, so stop early and report the failure below.
        if not bool(jnp.isfinite(norm_delta_u)):
            break

    raise ValueError("Newton's method did not converge")

# Update system function to pass the neural network parameters
def system(u, t_grid, f, u0, params):
    """Assemble the residual of the integral-form equations at every grid node.

    Row 0 enforces the initial condition, u[0] - u0. Row k (k >= 1) is
    u[k] - u[k-1] - integral of f over [t_grid[k-1], t_grid[k]].

    Parameters:
    u : ndarray
        Nodal values of the unknown, one row per grid point.
    t_grid : ndarray
        Grid points.
    f : function
        Right-hand side, called by `integral` as f(s, u(s), params).
    u0 : float or ndarray
        NOTE(review): ignored — it is overwritten below with a hard-coded
        value. Confirm whether the argument should be honoured instead.
    params : dict
        Parameters forwarded to `f`.

    Returns:
    ndarray
        The stacked residuals, one row per grid point.
    """
    u0 = jnp.array([1.0, 0.0])  # Initial condition u(0) = 1, v(0) = 0

    residuals = [u[0] - u0]
    residuals.extend(
        u[k] - u[k - 1] - integral(f, t_grid[k - 1], t_grid[k], u, t_grid, params)
        for k in range(1, len(t_grid))
    )
    return jnp.array(residuals)

# Update integral function to pass the neural network parameters
def integral(f, t_k, t_k1, u, t_grid, params, num_points=100):
    """Approximate the integral of f(s, u(s), params) over [t_k, t_k1] with the trapezoidal rule.

    u(s) is the interpolant of the nodal values `u` on `t_grid`, evaluated with
    `interpolate`.

    Parameters:
    f : function
        Integrand, called as f(s, u_at_s, params).
    t_k, t_k1 : scalar
        Lower and upper integration limits.
    u : ndarray
        Nodal values passed to `interpolate`.
    t_grid : ndarray
        Interpolation nodes passed to `interpolate`.
    params : dict
        Parameters forwarded to `f`.
    num_points : int, optional
        Number of equally spaced sample points, end points included
        (default is 100).

    Returns:
    The trapezoidal estimate, reduced along the sample axis (axis 0), so any
    trailing axes of f's output are kept.
    """
    def integrand(s):
        return f(s, interpolate(u, s, t_grid), params)

    s_values = jnp.linspace(t_k, t_k1, num_points)
    # vmap evaluates the integrand at every sample point in one batched call.
    integrand_values = jax.vmap(integrand)(s_values)
    return jnp.trapezoid(integrand_values, s_values, axis=0)


In [18]:
# Problem setup
T = 4.0  # End time
N = 10  # Number of grid points
omega = 1  # Used by the reference solution in the plotting cell

# Time grid taken from the interpolator's nodes
# (presumably N nodes spanning [start, stop] — verify against BarycentricInterpolation)
interpolator = BarycentricInterpolation(N, start=0, stop=T)
t_grid = interpolator.nodes

# Initial guess: ones everywhere, except the second component (velocity) at the
# first node, which is set to 0
uv_init = jnp.ones((N, 2)).at[0, 1].set(0)

# Solve the system using the neural network as right-hand side
uv_solution_nn = newton_method_nn(uv_init, t_grid, f=neural_net_f, params=trained_params)

# Print the solution
print("Solution at grid points using neural network:", uv_solution_nn)

ValueError: Newton's method did not converge

In [None]:
# Reference curves cos(omega t) and -omega sin(omega t), evaluated on the solver grid
true_solution_u = jnp.cos(omega * t_grid)
true_solution_v = -omega * jnp.sin(omega * t_grid)

# (values, line format, legend label) for each curve, in drawing order
_curves = (
    (uv_solution_nn[:, 0], 'ro-', 'Predicted Solution u(t)'),
    (true_solution_u, 'b-', 'True Solution u(t)'),
    (uv_solution_nn[:, 1], 'go-', 'Predicted Solution v(t)'),
    (true_solution_v, 'm-', 'True Solution v(t)'),
)

plt.figure(figsize=(10, 6))
for _values, _fmt, _label in _curves:
    plt.plot(t_grid, _values, _fmt, label=_label)
plt.xlabel('Time')
plt.ylabel('u(t), v(t)')
plt.legend()
plt.title('Predicted vs. True Solution for Harmonic Oscillator using Neural Network')
plt.grid(True)
plt.show()