In [2]:
import jax
import numpy as np
import jax.numpy as jnp
import time
import optax
import matplotlib.pyplot as plt 
from jax.flatten_util import ravel_pytree
from functools import partial

import lagrangian as lgr
import util
import plotting
from bnn import (
    BaselineNN, 
    compute_loss, 
    train_step, 
    create_trajectory,
)

In [3]:
@jax.jit
def L_analytical(state, m=1.0, k=1.0):
    """Analytical Lagrangian L = T - V of the reference harmonic oscillator.

    Takes the same ``m`` / ``k`` keyword parameters as ``H_analytical`` so the
    Lagrangian and the Hamiltonian describe the same physical system. The
    defaults (1.0, 1.0) reproduce the original unit-mass, unit-stiffness
    system exactly, so existing calls ``L_analytical(state)`` are unaffected.

    Args:
        state: system state. Positions and velocities are extracted with
            ``lgr.coordinate`` / ``lgr.velocity`` and flattened with
            ``ravel_pytree``, so any pytree layout is accepted.
        m: mass (default 1.0).
        k: spring constant (default 1.0).

    Returns:
        Scalar ``0.5 * m * |v|^2 - 0.5 * k * |q|^2``.
    """
    q = lgr.coordinate(state)
    v = lgr.velocity(state)
    # Flatten arbitrary pytrees so the quadratic forms work for any layout.
    q_flat, _ = ravel_pytree(q)
    v_flat, _ = ravel_pytree(v)

    T = 0.5 * m * jnp.sum(v_flat**2)
    V = 0.5 * k * jnp.sum(q_flat**2)
    return T - V

In [4]:
@jax.jit
def H_analytical(state, m=1.0, k=1.0):
    """Total energy H = T + V of the reference harmonic oscillator.

    Args:
        state: system state. Positions and velocities are pulled out with
            ``lgr.coordinate`` / ``lgr.velocity`` and flattened to 1-D with
            ``ravel_pytree``, so any pytree layout is accepted.
        m: mass (default 1.0).
        k: spring constant (default 1.0).

    Returns:
        Scalar ``0.5 * m * |v|^2 + 0.5 * k * |q|^2``.
    """
    positions = ravel_pytree(lgr.coordinate(state))[0]
    velocities = ravel_pytree(lgr.velocity(state))[0]

    kinetic = 0.5 * m * jnp.sum(velocities**2)
    potential = 0.5 * k * jnp.sum(positions**2)
    return kinetic + potential

In [5]:
# Model / optimisation settings.
q_dim = 1          # dimensionality of the coordinate q
hidden_dim = 128   # hidden width passed to the network (presumably BaselineNN)
num_epochs = 5000
key = jax.random.PRNGKey(0)  # fixed seed so runs are reproducible

# Trajectory dataset: `num_trajectories` runs, each sampled at
# `N_points_per_traj` evenly spaced times on [0, t_end].
num_trajectories = 50
t_end = 25.0
N_points_per_traj = 500
t_eval = jnp.linspace(0.0, t_end, N_points_per_traj)

# Number of time points per trajectory assigned to the training split
# (fraction `split_ratio`; presumably the remainder is used for testing).
split_ratio = 0.5
N_points_train = int(split_ratio * N_points_per_traj)

In [6]:
# Reference dynamics built from the analytical Lagrangian.
ds_true = lgr.state_derivative(L_analytical)
# Given a time axis and an initial state, solver_true computes the time evolution.
solver_true = util.ode_solver(ds_true)
a_true_func = lgr.lagrangian_to_acceleration(L_analytical)


def _a_true_from_components(t, q, v):
    """Adapter: pack (t, q, v) into the single state tuple a_true_func expects."""
    return a_true_func((t, q, v))


# Batched version: maps over axis 0 of t, q and v simultaneously.
vmap_a_true_func = jax.vmap(_a_true_from_components, in_axes=(0, 0, 0))

# Per-trajectory accumulators for the train split (t, q, v, a).
train_t_list = []
train_q_list = []
train_v_list = []
train_a_list = []
# Per-trajectory accumulators for the test split (t, q, v, a).
test_t_list = []
test_q_list = []
test_v_list = []
test_a_list = []
# Initial energy of each generated trajectory.
initial_energies_list = []