In [1]:
import jax
import jax.numpy as jnp
from discretise.fn_3 import discretise_integral

In [2]:
def lagrangian(q, dq, t):
    return 0.5 * jnp.dot(q, q) - 0.5 * jnp.dot(dq, dq)

In [8]:
def fill_out_initial(initial, r):
    return jnp.repeat(initial[jnp.newaxis, :], r + 2, axis=0)

In [9]:
r = 1
dt = 0.1

In [10]:
lagrangian_d = discretise_integral(
    fn=lagrangian,
    r=r,
    dt=dt
)

In [11]:
from form_eoms.form_and_solve import single_step

qiv = single_step(
    qi=fill_out_initial(jnp.array([1.0, 0.0]), r),
    t0=0.0,
    r=r,
    f_d=lagrangian_d
)
qiv

Solver: GaussNewton, Error: 0.0010214717751522112
Solver: GaussNewton, Error: 4.300430588401906e-17


Array([[0.99958264, 0.        ],
       [1.00083368, 0.        ],
       [0.99958264, 0.        ]], dtype=float64)

In [12]:
single_step(
    qi=fill_out_initial(qiv[-1], r),
    t0=0.0 + dt,
    r=r,
    f_d=lagrangian_d
)

Solver: GaussNewton, Error: 0.0010210454518177693
Solver: GaussNewton, Error: 3.986372859861846e-17


Array([[0.99916545, 0.        ],
       [1.00041597, 0.        ],
       [0.99916545, 0.        ]], dtype=float64)

In [13]:
from jax import Array

def iterate(
        q0: Array,
        t0: float,
        dt: float,
        t_samples: int,
        r: int,
        lagrangian: callable
):
    lagrangian_d = discretise_integral(
        fn=lagrangian,
        r=r,
        dt=dt
    )

    return jax.lax.fori_loop(
        0,
        t_samples,
        lambda i, q: jax.lax.concatenate([
            q[:-1],
            single_step(
                q0=q[-1],
                t0=t0 + i * dt,
                r=r,
                f_d=lagrangian_d
            )
        ], dimension=0),
        jax.numpy.array([q0])
    )

In [14]:
iterate(
    q0=jnp.array([1.0, 0.0]),
    t0=0.0,
    dt=0.1,
    t_samples=10,
    r=1,
    lagrangian=lagrangian
)

TypeError: single_step() got an unexpected keyword argument 'q0'

In [15]:
t_sample_count = 10
dt = 0.1
t0 = 0
t_samples = t0 + jnp.arange(t_sample_count) * dt

In [16]:
jnp.arange(t_sample_count) * dt

Array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],      dtype=float64, weak_type=True)

In [19]:
lagrangian_d = discretise_integral(
    fn=lagrangian,
    r=r,
    dt=dt
)

def scan_body(
        initial_qi_values,
        t_value
):
    jax.debug.print("previous_state {}", initial_qi_values)
    jax.debug.print("t_value {}", t_value)

    qi_values = single_step(
        qi=initial_qi_values,
        t0=t_value,
        r=r,
        f_d=lagrangian_d
    )
    
    next_qi_values = fill_out_initial(qi_values[-1], r=r)
    
    return next_qi_values, qi_values

jax.lax.scan(
    f=scan_body,
    xs=t_samples,
    init=fill_out_initial(
        initial=jnp.array([1.0,1.0,0.0]),
        r=r
    )
)

previous_state [[1. 1. 0.]
 [1. 1. 0.]
 [1. 1. 0.]]
t_value 0.0
Solver: GaussNewton, Error: 0.0014445792380015777
Solver: GaussNewton, Error: 6.081727262162085e-17
previous_state [[0.99958264 0.99958264 0.        ]
 [0.99958264 0.99958264 0.        ]
 [0.99958264 0.99958264 0.        ]]
t_value 0.1
Solver: GaussNewton, Error: 0.0014439763257600535
Solver: GaussNewton, Error: 5.637582563092644e-17
previous_state [[0.99916545 0.99916545 0.        ]
 [0.99916545 0.99916545 0.        ]
 [0.99916545 0.99916545 0.        ]]
t_value 0.2
Solver: GaussNewton, Error: 0.001443373665151052
Solver: GaussNewton, Error: 6.94702425749523e-17
previous_state [[0.99874844 0.99874844 0.        ]
 [0.99874844 0.99874844 0.        ]
 [0.99874844 0.99874844 0.        ]]
t_value 0.30000000000000004
Solver: GaussNewton, Error: 0.0014427712560696247
Solver: GaussNewton, Error: 7.788478667053632e-17
previous_state [[0.9983316 0.9983316 0.       ]
 [0.9983316 0.9983316 0.       ]
 [0.9983316 0.9983316 0.       ]]

(Array([[0.99583421, 0.99583421, 0.        ],
        [0.99583421, 0.99583421, 0.        ],
        [0.99583421, 0.99583421, 0.        ]], dtype=float64),
 Array([[[0.99958264, 0.99958264, 0.        ],
         [1.00083368, 1.00083368, 0.        ],
         [0.99958264, 0.99958264, 0.        ]],
 
        [[0.99916545, 0.99916545, 0.        ],
         [1.00041597, 1.00041597, 0.        ],
         [0.99916545, 0.99916545, 0.        ]],
 
        [[0.99874844, 0.99874844, 0.        ],
         [0.99999844, 0.99999844, 0.        ],
         [0.99874844, 0.99874844, 0.        ]],
 
        [[0.9983316 , 0.9983316 , 0.        ],
         [0.99958107, 0.99958107, 0.        ],
         [0.9983316 , 0.9983316 , 0.        ]],
 
        [[0.99791493, 0.99791493, 0.        ],
         [0.99916389, 0.99916389, 0.        ],
         [0.99791493, 0.99791493, 0.        ]],
 
        [[0.99749844, 0.99749844, 0.        ],
         [0.99874687, 0.99874687, 0.        ],
         [0.99749844, 0.9974984

In [10]:
lagrangian_d(
    jnp.array([1.0, 0.0, 0.0]),
    0.0
)

Array(-11.65833333, dtype=float64)

In [21]:
def scan_body_2(
        previous_q,
        t_value
):
    jax.debug.print("previous_state {}", previous_q)
    jax.debug.print("t_value {}", t_value)

    qi_values = single_step(
        qi=fill_out_initial(previous_q, r=r),
        t0=t_value,
        r=r,
        f_d=lagrangian_d
    )

    return qi_values[-1], qi_values


jax.lax.scan(
    f=scan_body_2,
    xs=t_samples,
    init=jnp.array([1.0, 1.0, -0.1]),
)

previous_state [ 1.   1.  -0.1]
t_value 0.0
Solver: GaussNewton, Error: 0.001448186183037093
Solver: GaussNewton, Error: 6.085530614726997e-17
previous_state [ 0.99958264  0.99958264 -0.09995826]
t_value 0.1
Solver: GaussNewton, Error: 0.00144758176539437
Solver: GaussNewton, Error: 5.641837645972793e-17
previous_state [ 0.99916545  0.99916545 -0.09991655]
t_value 0.2
Solver: GaussNewton, Error: 0.0014469776000124678
Solver: GaussNewton, Error: 6.962214183774937e-17
previous_state [ 0.99874844  0.99874844 -0.09987484]
t_value 0.30000000000000004
Solver: GaussNewton, Error: 0.0014463736867861728
Solver: GaussNewton, Error: 7.800050506885826e-17
previous_state [ 0.9983316   0.9983316  -0.09983316]
t_value 0.4
Solver: GaussNewton, Error: 0.0014457700256101992
Solver: GaussNewton, Error: 4.076956855911779e-17
previous_state [ 0.99791493  0.99791493 -0.09979149]
t_value 0.5
Solver: GaussNewton, Error: 0.001445166616379374
Solver: GaussNewton, Error: 7.639355691558635e-17
previous_state [ 0.

(Array([ 0.99583421,  0.99583421, -0.09958342], dtype=float64),
 Array([[[ 0.99958264,  0.99958264, -0.09995826],
         [ 1.00083368,  1.00083368, -0.10008337],
         [ 0.99958264,  0.99958264, -0.09995826]],
 
        [[ 0.99916545,  0.99916545, -0.09991655],
         [ 1.00041597,  1.00041597, -0.1000416 ],
         [ 0.99916545,  0.99916545, -0.09991655]],
 
        [[ 0.99874844,  0.99874844, -0.09987484],
         [ 0.99999844,  0.99999844, -0.09999984],
         [ 0.99874844,  0.99874844, -0.09987484]],
 
        [[ 0.9983316 ,  0.9983316 , -0.09983316],
         [ 0.99958107,  0.99958107, -0.09995811],
         [ 0.9983316 ,  0.9983316 , -0.09983316]],
 
        [[ 0.99791493,  0.99791493, -0.09979149],
         [ 0.99916389,  0.99916389, -0.09991639],
         [ 0.99791493,  0.99791493, -0.09979149]],
 
        [[ 0.99749844,  0.99749844, -0.09974984],
         [ 0.99874687,  0.99874687, -0.09987469],
         [ 0.99749844,  0.99749844, -0.09974984]],
 
        [[ 0.99708