In [17]:
from numba import jit, float64
import numpy as np
import pickle
import sys
sys.path.append('../')

In [18]:
@jit(nopython=True)
def _dynamics(r, v, z, u, dt, g, alpha):
    dt22 = dt ** 2 / 2.0
    mass = np.exp(z)
    a = u / mass + g

    # Compute next state
    r_next = r + dt * v + dt22 * a
    v_next = v + dt * a
    z_next = z - dt * alpha * np.linalg.norm(u) / mass

    return r_next, v_next, z_next

@jit(nopython=True)
def _propagate_state(x0, u, N, dt, g, alpha):
    r = np.zeros((N + 1, 3))
    v = np.zeros((N + 1, 3))
    z = np.zeros(N + 1)

    r[0] = x0[:3]
    v[0] = x0[3:6]
    z[0] = x0[6]

    for i in range(N):
        r[i + 1], v[i + 1], z[i + 1] = _dynamics(r[i], v[i], z[i], u[i], dt, g, alpha)
    
    return r, v, z

@jit(nopython=True)
def _get_cstr(r, v, z, u, N, rho1, rho2, pa, gsa, vmax):
    cstr_eq = np.zeros(4)
    cstr_ineq = np.zeros(5 * N + 3)

    # Equality constraints
    cstr_eq[0] = r[N, 2]
    cstr_eq[1:] = v[N, :]

    # Inequality constraints
    i_ieq = 0
    # Thrust bounds
    for i in range(N):
        cstr_ineq[i_ieq] = rho1 - np.linalg.norm(u[i])
        cstr_ineq[i_ieq + 1] = np.linalg.norm(u[i]) - rho2
        i_ieq += 2

    # Pointing angle constraint
    for i in range(N):
        cstr_ineq[i_ieq] = np.linalg.norm(u[i]) * np.cos(pa) - u[i, 2]
        i_ieq += 1

    # Glide slope constraint
    for i in range(N+1):
        cstr_ineq[i_ieq] = np.linalg.norm(r[i, :2] - r[-1, :2]) * np.tan(gsa) - (r[i, 2] - r[-1, 2])
        i_ieq += 1

    # Velocity constraint
    for i in range(N):
        cstr_ineq[i_ieq] = np.linalg.norm(v[i]) - vmax
        i_ieq += 1

    return cstr_eq, cstr_ineq



In [20]:
with open('../saved/controllable_set/lander.pkl', 'rb') as f:
    lander = pickle.load(f)

N = 60
tgo = 60.0
dt = tgo / N
alt = 1500.0
mass = 1800.0
x0 = np.array([0, 0, alt, -30.0, 0, -55.0, np.log(mass)])

u = np.zeros((N, 3))


# measure time
import time
t0 = time.time()

for i in range(1000):
    r, v, z = _propagate_state(x0, u, N, dt, lander.g, lander.alpha)
    cstr_eq, cstr_ineq = _get_cstr(r, v, z, u, N, lander.rho1, lander.rho2, lander.pa, lander.gsa, lander.vmax)
t1 = time.time()
print("Mean time: ", (t1 - t0) / 1000.0)

Mean time:  0.0012925522327423096
