In [1]:
import jax
import jax.numpy as jnp
import immrax as irx
 
# jax.config.update('jax_enable_x64', True)

In [2]:
class Bicycle (irx.System) :
    """Kinematic bicycle model as a continuous-time immrax system.

    State  x = [px, py, psi, v]: planar position, heading angle and speed
    (px, py advance along the direction psi + beta at rate v).
    Input  u = [u1, u2]: u1 is dv/dt directly; u2 is the front steering angle
    in radians (it enters only through tan(u2)).
    The disturbance w is accepted to match the System.f signature but is unused.

    Parameters
    ----------
    lf, lr : float
        Distances from the centre of mass to the front and rear axles.
        The defaults (1.0, 1.0) reproduce the original hard-coded model:
        beta = arctan(tan(u2) / 2) and dpsi/dt = v * sin(beta).
    """
    def __init__ (self, lf=1.0, lr=1.0) :
        self.evolution = 'continuous'
        self.xlen = 4
        self.lf = lf
        self.lr = lr

    def f (self, t, x, u, w) :
        """Vector field dx/dt = f(t, x, u, w); t and w do not enter the dynamics."""
        px, py, psi, v = x.ravel()
        u1, u2 = u.ravel()
        # Kinematic-bicycle slip angle: arctan(lr / (lf + lr) * tan(steer)).
        beta = jnp.arctan(jnp.tan(u2) * self.lr / (self.lf + self.lr))
        return jnp.array([
            v*jnp.cos(psi + beta),
            v*jnp.sin(psi + beta),
            # Yaw rate of the kinematic bicycle: v / lr * sin(beta).
            v*jnp.sin(beta) / self.lr,
            u1,
        ])

In [3]:
# Open-loop bicycle dynamics (default axle geometry) ...
olsys = Bicycle()
# ... and the embedding system that immrax's natemb builds from them; presumably
# via natural inclusion functions — confirm against the immrax docs.
embsys = irx.natemb(olsys)

In [4]:
# Initial state x0 = [px, py, psi, v] and a zero-width interval around it.
# For a non-degenerate initial set, use e.g.
#   irx.interval([-1., -1., 0., 0.], [1., 1., 0.1, 0.1])
x0 = jnp.array([10., 10., 0., 0.])
x0int = irx.icentpert(x0, 0.0)
x0ut = irx.i2ut(x0int)
print(f"x0int = {x0int}\n")
print(f"irx.i2ut(x0int)= {x0ut}")


# Constant input u = [u1, u2] and its zero-width interval.
# NOTE(review): u2 is a steering angle that Bicycle.f feeds through tan(u2);
# 10 rad is far outside a physical steering range — confirm this is intended.
u = jnp.array([10., 10.])
uint = irx.icentpert(u, 0.0)
print(f"uint = {uint}\n")

# Bicycle.f never reads w; this zero vector is only a placeholder argument.
w = jnp.array([0., 0., 0., 0.])
print(f"olsys.f(0., x0, u, w): {olsys.f(0., x0, u, w)}\n")

# Evaluate the embedding at the same scalar time t = 0. used for olsys.f above
# (and by u_map / compute_trajectory), rather than a one-element list.
print(f"embsys.E(0., irx.i2ut(x0int), uint, w): {embsys.E(0., x0ut, uint, w)}\n")

x0int = [[(10., 10.)]
 [(10., 10.)]
 [( 0.,  0.)]
 [( 0.,  0.)]]

irx.i2ut(x0int)= [10. 10.  0.  0. 10. 10.  0.  0.]
uint = [[(10., 10.)]
 [(10., 10.)]]

olsys.f(0., x0, u, w): [ 0.  0.  0. 10.]

embsys.E([0.], irx.i2ut(x0int), uint, w): [ 0.  0.  0. 10.  0.  0.  0. 10.]



In [10]:
def u_map (t, x) :
    """Control map for compute_trajectory: the constant input interval ``uint``.

    ``uint`` is built from ``u`` in the setup cell. Returning it here, instead of
    re-typing [10., 10.], keeps the simulated input in sync with the input
    checked there. Neither t nor x affects the result.
    """
    return uint
 
def w_map (t, x) :
    """Disturbance map for compute_trajectory: an interval built from [0.].

    Presumably a degenerate (zero-width) interval. Bicycle.f never reads its
    w argument, so this value does not change the dynamics.
    """
    zero_disturbance = [0.]
    return irx.interval(zero_disturbance)

In [11]:
# Sanity check: the control map evaluated at the initial time and state.
print(u_map(t=0., x=x0))

[[(10., 10.)]
 [(10., 10.)]]


In [12]:
# Simulation horizon [t0, tf] and step size, all passed to compute_trajectory.
t0, tf, dt = 0., 10., 0.1

In [13]:
# Integrate the embedding system from the interval initial condition (in
# irx.i2ut form), with the control and disturbance maps supplied as a pair.
traj = embsys.compute_trajectory(
    t0,
    tf,
    irx.i2ut(x0int),
    (u_map, w_map),
    dt,
    solver='rk45',
)
 

In [14]:
# Keep only the steps whose time stamps are finite. The solver output appears
# to pad unused steps with non-finite times — verify against immrax.
tfinite = jnp.where(jnp.isfinite(traj.ts))
tt, xx = traj.ts[tfinite], traj.ys[tfinite]

print(xx.shape)

# First and last embedding states, one per line, in irx.i2ut layout.
print(xx[0], xx[-1], sep='\n')

(101, 8)
[10. 10.  0.  0. 10. 10.  0.  0.]
[  7.2649727  15.820192  154.1903    100.          7.2649727  15.820192
 154.1903    100.       ]
