In [1]:
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, lax
import numpy as np

In [2]:
@jit
def dynamics_1order(state, input, input_derivs, g, m):
    """Time derivative of the 13-element augmented state.

    ``state`` is unpacked as
    ``(x, y, z, vx, vy, vz, roll, pitch, yaw, thrust, rolldot, pitchdot, yawdot)``
    and ``input_derivs`` as ``(thrust_dot, roll_dd, pitch_dd, yaw_dd)``.

    * Positions integrate the velocities.
    * Velocities integrate the thrust-induced acceleration. ``g`` is added to
      the z component only, so with ``g > 0`` the +z axis points along gravity.
    * The Euler angles integrate ``rolldot``/``pitchdot``/``yawdot`` directly.
    * The last four entries (thrust and angular rates) integrate ``input_derivs``.

    Args:
        state: 13-element augmented state (layout above).
        input: not read by this function; kept for call-site compatibility.
        input_derivs: 4-element vector of input time-derivatives.
        g: gravitational acceleration added to the z acceleration.
        m: mass that divides the thrust.

    Returns:
        13-element array holding the derivative of each state entry.
    """
    (_, _, _,
     vel_x, vel_y, vel_z,
     phi, theta, psi,
     thrust, phi_rate, theta_rate, psi_rate) = state
    thrust_rate, phi_acc, theta_acc, psi_acc = input_derivs

    sin_phi, cos_phi = jnp.sin(phi), jnp.cos(phi)
    sin_theta, cos_theta = jnp.sin(theta), jnp.cos(theta)
    sin_psi, cos_psi = jnp.sin(psi), jnp.cos(psi)

    # Thrust per unit mass, rotated into the world frame below.
    specific_thrust = thrust / m
    acc_x = -specific_thrust * (sin_phi * sin_psi + cos_phi * cos_psi * sin_theta)
    acc_y = -specific_thrust * (cos_phi * sin_psi * sin_theta - cos_psi * sin_phi)
    acc_z = g - specific_thrust * (cos_phi * cos_theta)

    rates = (
        vel_x, vel_y, vel_z,
        acc_x, acc_y, acc_z,
        phi_rate, theta_rate, psi_rate,
        thrust_rate, phi_acc, theta_acc, psi_acc,
    )
    return jnp.array(rates)

# Function to integrate dynamics over time
@jit
def integrate_dynamics_1order(state, inputs, input_derivs, integration_step, integrations_int, g, m):
    """Forward-Euler integration of ``dynamics_1order``.

    ``inputs`` is appended to ``state`` to form the augmented state that
    ``dynamics_1order`` unpacks (13 entries when ``state`` has 9 and
    ``inputs`` has 4). Then ``integrations_int`` explicit Euler steps of
    size ``integration_step`` are taken. ``input_derivs`` is held constant
    across all steps.

    Returns:
        The augmented state (``state`` followed by ``inputs``) after the final step.
    """
    augmented = jnp.hstack([state, inputs])

    def euler_step(_, x):
        x_dot = dynamics_1order(x, inputs, input_derivs, g, m)
        return x + x_dot * integration_step

    return lax.fori_loop(0, integrations_int, euler_step, augmented)

# Prediction function 1st order
@jit
def predict_states_1order(state, last_input, input_derivs, T_lookahead, g, m, integration_step=0.1):
    """Predict the 9 leading states ``T_lookahead`` ahead with forward Euler.

    The step count is ``round(T_lookahead / integration_step)``. When
    ``T_lookahead`` is not an exact multiple of the step, the effective
    horizon is therefore ``steps * integration_step``. ``input_derivs`` is
    held constant over the horizon.

    Args:
        state: 9-element state ``(x, y, z, vx, vy, vz, roll, pitch, yaw)``.
        last_input: input ``(thrust, rolldot, pitchdot, yawdot)``; flattened
            before use.
        input_derivs: 4-element input time-derivatives.
        T_lookahead: prediction horizon, in the same time unit as
            ``integration_step``.
        g: gravitational acceleration.
        m: mass.
        integration_step: Euler step size.

    Returns:
        The first 9 entries of the integrated augmented state.
    """
    inputs = last_input.flatten()
    # Derive the number of Euler steps from the requested horizon instead of
    # hard-coding it. Under jit this stays a traced int, so lax.fori_loop lowers
    # to a while_loop. That form still supports forward-mode (jacfwd) differentiation.
    integrations_int = jnp.round(T_lookahead / integration_step).astype(jnp.int32)
    pred_state = integrate_dynamics_1order(state, inputs, input_derivs, integration_step, integrations_int, g, m)
    return pred_state[0:9]


# Prediction function 1st order (outputs): C applied to the predicted states
@jit
def predict_outputs_1order(state, last_input, input_derivs, T_lookahead, g, m, C, integration_step=0.1):
    """Predict ``C @ states`` ``T_lookahead`` ahead with forward Euler.

    The step count is ``round(T_lookahead / integration_step)``, so the
    horizon follows ``T_lookahead`` rather than a fixed number of steps.
    ``input_derivs`` is held constant over the horizon.

    Args:
        state: 9-element state ``(x, y, z, vx, vy, vz, roll, pitch, yaw)``.
        last_input: input ``(thrust, rolldot, pitchdot, yawdot)``; flattened
            before use.
        input_derivs: 4-element input time-derivatives.
        T_lookahead: prediction horizon, in the same time unit as
            ``integration_step``.
        g: gravitational acceleration.
        m: mass.
        C: output matrix applied to the 9 predicted states (``(n_out, 9)``).
        integration_step: Euler step size.

    Returns:
        ``C @ predicted_state[0:9]``.
    """
    inputs = last_input.flatten()
    # Horizon-derived step count. It stays traced under jit, which keeps the
    # function differentiable by jacfwd (fori_loop lowers to while_loop).
    integrations_int = jnp.round(T_lookahead / integration_step).astype(jnp.int32)
    pred_state = integrate_dynamics_1order(state, inputs, input_derivs, integration_step, integrations_int, g, m)
    return C @ pred_state[0:9]


# Compute Jacobian
@jit
def compute_jacobian_1order(state, last_input, input_derivs, T_lookahead, g, m, C, integration_step):
    """Forward-mode Jacobian of ``predict_outputs_1order`` w.r.t. ``last_input``.

    Every argument other than ``last_input`` is held fixed. Following
    ``jax.jacfwd`` conventions, the result has shape
    ``outputs.shape + last_input.shape``.
    """
    def outputs_of(candidate_input):
        return predict_outputs_1order(state, candidate_input, input_derivs, T_lookahead, g, m, C, integration_step)

    return jacfwd(outputs_of)(last_input)

# Compute adjusted inverse Jacobian
@jit
def compute_adjusted_invjac_1order(state, last_input, input_derivs, T_lookahead, g, m, C, integration_step):
    """Pseudo-inverse of the output Jacobian with its third column negated.

    Computes ``pinv(d outputs / d last_input)``, then flips the sign of
    column 2. With the demo ``C``, that column corresponds to the z output.
    Assumes a 1-D ``last_input``, so the Jacobian is 2-D
    (``(n_out, n_in)``) and the result is ``(n_in, n_out)`` — TODO confirm
    for other shapes.

    NOTE(review): the z sign flip is presumably a frame-convention
    adjustment (the dynamics treat +z as the gravity direction). Confirm the
    intent before reusing this with a different ``C``.
    """
    pseudo_inverse = jnp.linalg.pinv(
        compute_jacobian_1order(state, last_input, input_derivs, T_lookahead, g, m, C, integration_step)
    )
    flipped_column = -pseudo_inverse[:, 2]
    return pseudo_inverse.at[:, 2].set(flipped_column)

In [3]:
# Demo constants
MASS = 1.535  # Mass of the object
GRAVITY = 9.806  # Gravitational acceleration
# Output matrix: rows of the 9x9 identity that pick the predicted states at
# indices (0, 1, 2, 8), i.e. (x, y, z, yaw) in the dynamics' state layout.
C = jnp.eye(9, dtype=int)[jnp.array([0, 1, 2, 8])]
T_LOOKAHEAD = 0.8  # prediction horizon passed as T_lookahead
# Initial 9-element state: (x, y, z, vx, vy, vz, roll, pitch, yaw).
INITIAL_STATE = jnp.array([
    -0.06161616, 0.14445689, -0.00542868,
    -0.02907665, 0.05602227, 0.0983548,
    0.02056395, 0.02023059, 0.01487323,
])

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# Example usage: hover thrust (MASS * GRAVITY) with zero angular-rate inputs,
# and all input derivatives set to zero.
last_input = jnp.array([MASS * GRAVITY, 0.0, 0.0, 0.0])
input_dervs = jnp.zeros(4)

In [5]:
# Predicted (x, y, z, yaw) after the lookahead horizon from the example operating point.
outputs = predict_outputs_1order(
    state=INITIAL_STATE, last_input=last_input, input_derivs=input_dervs,
    T_lookahead=T_LOOKAHEAD, g=GRAVITY, m=MASS, C=C, integration_step=0.1,
)
outputs


Array([-0.1412422 ,  0.24490063,  0.07439704,  0.01487323], dtype=float32)

In [7]:
# Adjusted pseudo-inverse of d(outputs)/d(last_input) at the example operating point.
inv_jac = compute_adjusted_invjac_1order(
    state=INITIAL_STATE, last_input=last_input, input_derivs=input_dervs,
    T_lookahead=T_LOOKAHEAD, g=GRAVITY, m=MASS, C=C, integration_step=0.1,
)

In [9]:
# Show the adjusted pseudo-inverse Jacobian via the cell's rich display (last expression).
inv_jac

[[-1.12540476e-01  1.11065082e-01  5.47986317e+00  0.00000000e+00]
 [-2.63205767e-02  1.82046723e+00 -3.74374166e-02  2.52865143e-02]
 [-1.82085347e+00 -2.70842314e-02 -3.68456393e-02 -2.57033985e-02]
 [-1.10827386e-07 -1.02445483e-08 -6.98491931e-10  1.24999952e+00]]
