In [2]:
import jax.numpy as jnp
from jax import grad, jit, random, value_and_grad
from jax.scipy.linalg import solve
from jax.experimental import optimizers
import matplotlib.pyplot as plt

def chebyshev_nodes_second_kind(n, start, stop):
    k = jnp.arange(n)
    x = jnp.cos(jnp.pi * k / (n - 1))
    nodes = 0.5 * (stop - start) * x + 0.5 * (start + stop)
    return jnp.sort(nodes)

def lagrange_basis_node(xi):
    n = len(xi)
    L = jnp.eye(n)
    return L

def lagrange_basis(xi, x):
    n = len(xi)
    L = jnp.ones((n, len(x)))
    for i in range(n):
        for j in range(n):
            if i != j:
                L = L.at[i, :].set(L[i, :] * (x - xi[j]) / (xi[i] - xi[j]))
    return L

def lagrange_basis_single(xi, x, j):
    n = len(xi)
    L = 1.0
    for m in range(n):
        if m != j:
            L *= (x - xi[m]) / (xi[j] - xi[m])
    return L

def derivative_at_node(xi, j, h=1e-5):
    x_j = xi[j]
    forward = lagrange_basis_single(xi, x_j + h, j)
    backward = lagrange_basis_single(xi, x_j - h, j)
    derivative = (forward - backward) / (2 * h)
    return derivative

def lagrange_derivative(xi, weights):
    n = len(xi)
    D = jnp.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:
                D = D.at[i, j].set(weights[j] / weights[i] / (xi[i] - xi[j]))
            else:
                approx_derivative = derivative_at_node(xi, j)
                D = D.at[i, j].set(approx_derivative)
    return D

def compute_weights(xi):
    n = len(xi)
    xi = jnp.array(xi)
    weights = jnp.zeros(n)
    for j in range(n):
        terms = xi[j] - jnp.delete(xi, j)
        product = jnp.prod(terms)
        weights = weights.at[j].set(1.0 / product)
    return weights

def collocation_ode_solver(ode_system, initial_conditions, t, N, spacing="Chebyshev", extra_params={}):
    T_start, T_end = t[0], t[-1]
    if spacing == "Chebyshev":
        collocation_points = chebyshev_nodes_second_kind(N, T_start, T_end)
    else:
        collocation_points = jnp.linspace(T_start, T_end, N)
        
    phi = jnp.eye(N)
    weights = compute_weights(collocation_points)
    dphi_dt = lagrange_derivative(collocation_points, weights)

    A = ode_system(dphi_dt, phi, extra_params)

    b = jnp.zeros(2 * N)
    b_aug = jnp.concatenate([b, jnp.array(initial_conditions)])

    I_x1 = jnp.zeros((1, 2 * N))
    I_x1 = I_x1.at[0, :N].set(phi[0, :])

    I_x2 = jnp.zeros((1, 2 * N))
    I_x2 = I_x2.at[0, N:].set(phi[0, :])

    A_aug = jnp.vstack([A, I_x1, I_x2])

    c = solve(A_aug.T @ A_aug, A_aug.T @ b_aug)

    c1 = c[:N]
    c2 = c[N:]
    
    lb = jnp.transpose(lagrange_basis(collocation_points, t))
    x1 = lb @ c1
    x2 = lb @ c2

    solution = jnp.vstack([x1, x2]).T
    return solution

def init_mlp(layer_sizes, key):
    keys = random.split(key, len(layer_sizes))
    params = [(random.normal(k, (m, n)) * jnp.sqrt(2.0 / m), jnp.zeros(n)) for k, m, n in zip(keys, layer_sizes[:-1], layer_sizes[1:])]
    return params

def forward(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(jnp.dot(x, w) + b)
    final_w, final_b = params[-1]
    return jnp.dot(x, final_w) + final_b

def neural_ode(params, t, state, extra_params):
    return forward(params, state)

@jit
def loss_fn(params, t, states, true_states, extra_params):
    preds = jnp.array([collocation_ode_solver(lambda dphi_dt, phi, extra_params: neural_ode(params, t, state, extra_params),
                                               state, t, len(t), extra_params=extra_params) for state in states])
    return jnp.mean((preds - true_states) ** 2)

def main():
    key = random.PRNGKey(0)
    
    layer_sizes = [2, 64, 64, 2]
    params = init_mlp(layer_sizes, key)

    example_ode_system = lambda dphi_dt, phi, params: jnp.block([
        [dphi_dt, -phi],
        [params['omega']**2 * phi, dphi_dt]
    ])
    
    t_span = jnp.linspace(0, 10, 100)
    N = 20
    omega = 2.0
    initial_conditions = (1.0, 0.0)
    extra_params = {'omega': omega}
    
    true_solution = collocation_ode_solver(example_ode_system, initial_conditions, t_span, N, extra_params=extra_params)
    
    opt_init, opt_update, get_params = optimizers.adam(1e-3)
    opt_state = opt_init(params)
    
    @jit
    def step(i, opt_state):
        params = get_params(opt_state)
        value, grads = value_and_grad(loss_fn)(params, t_span, [initial_conditions], true_solution, extra_params)
        opt_state = opt_update(i, grads, opt_state)
        return opt_state, value

    num_steps = 1000
    for i in range(num_steps):
        opt_state, value = step(i, opt_state)
        if i % 100 == 0:
            print(f"Step {i}, Loss: {value}")

    trained_params = get_params(opt_state)

    pred_solution = collocation_ode_solver(lambda dphi_dt, phi, params: neural_ode(trained_params, t_span, (dphi_dt, phi), params),
                                           initial_conditions, t_span, N, extra_params=extra_params)
    
    plt.plot(t_span, true_solution[:, 0], label='True x1')
    plt.plot(t_span, true_solution[:, 1], label='True x2')
    plt.plot(t_span, pred_solution[:, 0], '--', label='Predicted x1')
    plt.plot(t_span, pred_solution[:, 1], '--', label='Predicted x2')
    plt.legend()
    plt.show()

if __name__ == "__main__":
    main()


ImportError: cannot import name 'optimizers' from 'jax.experimental' (/Users/mariiashapo/anaconda3/envs/collocation_env/lib/python3.9/site-packages/jax/experimental/__init__.py)