In [9]:
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, lax
import numpy as np
from jax import random

In [10]:
# Output (selector) matrix: picks state components 0, 1, 2 and 8 -- i.e. x, y, z
# and yaw in the state layout unpacked by `dynamics` -- out of the 9-element state.
C = np.eye(9, dtype=int)[[0, 1, 2, 8]]

In [11]:
# Fixed seed 0 so the generated dataset is reproducible; two independent keys are
# derived from it (key2 is not used in the visible cells).
key1, key2 = random.split(random.PRNGKey(seed=0), num=2)

In [12]:
@jit
def dynamics(state, inputs, g, m):
    """Continuous-time quadrotor model: return the time derivative of ``state``.

    Args:
        state: 9 values (x, y, z, vx, vy, vz, roll, pitch, yaw).
        inputs: 4 values (thrust, rolldot, pitchdot, yawdot).
        g: gravitational acceleration.
        m: vehicle mass.

    Returns:
        jnp array of 9 derivatives (vx, vy, vz, ax, ay, az, rolldot, pitchdot, yawdot).
        The three rate inputs are passed straight through as Euler-angle rates.
        Gravity enters ``az`` with a positive sign, so +z points down
        (presumably an NED frame -- confirm).
    """
    _x, _y, _z, vx, vy, vz, roll, pitch, yaw = state
    thrust, rolldot, pitchdot, yawdot = inputs

    s_roll, c_roll = jnp.sin(roll), jnp.cos(roll)
    s_pitch, c_pitch = jnp.sin(pitch), jnp.cos(pitch)
    s_yaw, c_yaw = jnp.sin(yaw), jnp.cos(yaw)

    # Thrust per unit mass; at zero attitude it acts purely along -z, opposing g.
    accel = thrust / m
    ax = -accel * (s_roll * s_yaw + c_roll * c_yaw * s_pitch)
    ay = -accel * (c_roll * s_yaw * s_pitch - c_yaw * s_roll)
    az = g - accel * (c_roll * c_pitch)

    return jnp.array([vx, vy, vz, ax, ay, az, rolldot, pitchdot, yawdot])

# Forward-Euler integration of `dynamics` over a fixed number of steps
@jit
def integrate_dynamics(state, inputs, integration_step, integrations_int, g, m):
    """Advance ``state`` by ``integrations_int`` forward-Euler steps of size ``integration_step``.

    The control ``inputs`` are held constant across all steps; ``g`` and ``m``
    are forwarded to ``dynamics``. Returns the state after the last step.
    """
    def euler_step(_step_index, s):
        # s_{k+1} = s_k + f(s_k, u) * dt
        derivative = dynamics(s, inputs, g, m)
        return s + derivative * integration_step

    return lax.fori_loop(0, integrations_int, euler_step, state)

# State prediction: integrate the dynamics over the look-ahead horizon
@jit
def predict_states(state, last_input, T_lookahead, g, m, integration_step=0.1):
    """Predict the full state ``T_lookahead`` into the future.

    Integrates ``dynamics`` with forward Euler while holding ``last_input`` constant.

    Args:
        state: current 9-element state vector.
        last_input: the 4 control values (thrust, rolldot, pitchdot, yawdot). Any
            array shape holding those 4 values works because it is flattened.
        T_lookahead: prediction horizon (same time unit as ``integration_step``).
        g: gravitational acceleration, forwarded to ``dynamics``.
        m: vehicle mass, forwarded to ``dynamics``.
        integration_step: Euler step size.

    Returns:
        Predicted state after ``round(T_lookahead / integration_step)`` Euler steps.
    """
    inputs = last_input.flatten()
    # Python int() cannot be applied to traced values under @jit, so the step count
    # is computed with jnp. Rounding (rather than truncating) keeps float error such
    # as 0.3 / 0.1 = 2.999... from dropping a step. For 0.8 / 0.1 this yields 8.
    integrations_int = jnp.round(T_lookahead / integration_step).astype(jnp.int32)
    return integrate_dynamics(state, inputs, integration_step, integrations_int, g, m)

# Output prediction: the output matrix C applied to the predicted state
@jit
def predict_outputs(state, last_input, T_lookahead, g, m, C, integration_step=0.1):
    """Predict the measured outputs ``C @ x(t + T_lookahead)``.

    Integrates ``dynamics`` with forward Euler while holding ``last_input`` constant,
    then applies the output matrix ``C`` to the predicted state.

    Args:
        state: current 9-element state vector.
        last_input: the 4 control values (thrust, rolldot, pitchdot, yawdot). Any
            array shape holding those 4 values works because it is flattened.
        T_lookahead: prediction horizon (same time unit as ``integration_step``).
        g: gravitational acceleration, forwarded to ``dynamics``.
        m: vehicle mass, forwarded to ``dynamics``.
        C: output matrix with 9 columns, applied to the predicted state.
        integration_step: Euler step size.

    Returns:
        ``C @ pred_state`` after ``round(T_lookahead / integration_step)`` Euler steps.
    """
    inputs = last_input.flatten()
    # Python int() cannot be applied to traced values under @jit, so the step count
    # is computed with jnp. Rounding (rather than truncating) keeps float error such
    # as 0.3 / 0.1 = 2.999... from dropping a step. For 0.8 / 0.1 this yields 8.
    integrations_int = jnp.round(T_lookahead / integration_step).astype(jnp.int32)
    pred_state = integrate_dynamics(state, inputs, integration_step, integrations_int, g, m)
    return C @ pred_state

In [15]:
def get_outputs(num_data_gen=100, key=None):
    """Generate random (state, input) samples together with their predicted outputs.

    Each row is ``[state (9), input (4), outputs (4)]`` = 17 columns, where
    ``outputs = predict_outputs(state, input, T_lookahead, g, m, C, integration_step)``.
    ``T_lookahead``, ``g``, ``m``, ``C`` and ``integration_step`` are read as notebook
    globals and must be defined before this is called.

    Args:
        num_data_gen: number of rows to generate; ``<= 0`` yields an empty (0, 17) array.
        key: JAX PRNG key. Defaults to the notebook-level ``key1``, so repeated calls
            with the default are reproducible.

    Returns:
        float64 NumPy array of shape (num_data_gen, 17).
    """
    if key is None:
        key = key1
    if num_data_gen <= 0:
        return np.empty((0, 17))

    # Half-widths of the uniform sampling ranges for position, velocity and attitude.
    xx, vv, ll = 2.0, 1.0, jnp.pi
    rows = []
    # Draw from an independent subkey per sample and per variable group: reusing a
    # single key makes every draw identical (all rows equal, and x == y == z etc.).
    for sample_key in random.split(key, num_data_gen):
        k_thrust, k_rates, k_pos, k_vel, k_att = random.split(sample_key, 5)
        curr_thrust = random.uniform(k_thrust, minval=0.0, maxval=27)
        rates = random.uniform(k_rates, (3,), minval=-.9, maxval=.9)    # rolldot, pitchdot, yawdot
        position = random.uniform(k_pos, (3,), minval=-xx, maxval=xx)   # x, y, z
        velocity = random.uniform(k_vel, (3,), minval=-vv, maxval=vv)   # vx, vy, vz
        attitude = random.uniform(k_att, (3,), minval=-ll, maxval=ll)   # roll, pitch, yaw

        STATE = jnp.concatenate([position, velocity, attitude])
        INPUT = jnp.concatenate([jnp.atleast_1d(curr_thrust), rates])

        outputs = predict_outputs(STATE, INPUT, T_lookahead, g, m, C, integration_step)
        rows.append(np.concatenate([np.asarray(STATE), np.asarray(INPUT),
                                    np.asarray(outputs).ravel()]))

    # Stack once at the end rather than np.vstack inside the loop (quadratic copying).
    return np.asarray(rows, dtype=np.float64)

In [16]:
# Physical / prediction parameters; get_outputs reads these as notebook globals.
T_lookahead, integration_step = 0.8, 0.1  # horizon and Euler step (presumably seconds -- confirm)
g, m = 9.806, 1.53                        # gravity and vehicle mass (presumably m/s^2 and kg)
# NOTE(review): 500 samples are generated, but the save cell's file names say "50k" -- confirm intended size.
data = get_outputs(500)

In [17]:
# Selects the output file name in the save cell:
# True -> '50k_iris_sim_expanded_more5', False -> '50k_holybro_expanded_more4'.
sim = True

In [18]:
# Choose the file name once so the saved file and the log message cannot diverge.
# Note: np.save appends ".npy" to a name that lacks that extension.
save_name = '50k_iris_sim_expanded_more5' if sim else '50k_holybro_expanded_more4'
np.save(save_name, data)
print(f"Saved under: {save_name} because sim is {sim}")

Saved under: 50k_iris_sim_expanded_more5 because sim is True
