In [None]:
import jax.numpy as jnp
from jax import jit
from jax import grad , jacfwd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from space_traj_opt.plotting import plot

In [None]:
# Physical constants and vehicle parameters (SI units assumed -- confirm).
STANDARD_GRAV = 9.81    # g used to turn Isp into an exhaust velocity for mdot
g0 = 1.61544            # gravitational acceleration applied along -y in `dynamics`
thrust = 210000.0       # engine thrust magnitude
Isp = 265.2             # specific impulse
mdot = thrust / STANDARD_GRAV / Isp   # constant propellant mass-flow rate

# Parameter tuple consumed by `dynamics` as (thrust, g0, mdot).
vch_params = (thrust, g0, mdot)

# Initial state [x, y, v_x, v_y, mass].
x0 = jnp.array([0, 0, 0, 0, 50000])

In [None]:
@jit
def dynamics(x, u, params):
    """Planar point-mass rocket dynamics evaluated at every collocation node.

    Args:
        x: state array with one row per node and columns
            ``[x, y, v_x, v_y, mass]``; it must have ``len(u) + 1`` rows.
        u: pitch command per interval (radians, presumably). The last command
            is repeated so that the final node also has a control.
        params: ``(thrust, g0, mdot)`` -- thrust magnitude, gravitational
            acceleration acting along -y, and propellant mass-flow rate.

    Returns:
        Array with the same number of rows as ``x`` and five columns holding
        the time derivative of each state component.
    """
    thrust, g0, mdot = params[0], params[1], params[2]
    # One control per interval, but the trapezoid rule needs f at every node:
    # hold the last command at the final node.
    theta = jnp.append(u, u[-1])
    vel_x, vel_y, mass = x[:, 2], x[:, 3], x[:, 4]
    acc_x = thrust * jnp.cos(theta) / mass
    acc_y = thrust * jnp.sin(theta) / mass - g0
    # Constant mass depletion rate taken from params, so it stays consistent
    # with thrust/Isp instead of relying on a pasted numeric literal.
    mass_rate = -mdot * jnp.ones_like(mass)
    return jnp.stack([vel_x, vel_y, acc_x, acc_y, mass_rate], axis=1)


# Define the dynamics defects
# @jit
def state_defects(decision_variables: jnp.ndarray, args):
    """Trapezoidal-collocation dynamics defects for the packed decision vector.

    Layout of ``decision_variables`` (1-D):
      * ``[0, N)``  -- controls ``u``, one per interval
      * ``[N, -1)`` -- states, row-major ``(N + 1, states_dim)``, one row per node
      * ``[-1]``    -- time step ``dt`` shared by all intervals

    Args:
        decision_variables: packed decision vector described above.
        args: ``(N, states_dim, params)``; ``params`` is forwarded to ``dynamics``.

    Returns:
        1-D array of length ``N * states_dim`` holding
        ``x[k+1] - x[k] - dt / 2 * (f_k + f_{k+1})``, grouped by state
        component (all defects of component 0, then component 1, ...).

    Raises:
        ValueError: if the vector length does not match ``N`` and ``states_dim``.
    """
    N, states_dim, params = args[0], args[1], args[2]
    expected_size = N + (N + 1) * states_dim + 1
    if decision_variables.shape[0] != expected_size:
        raise ValueError(
            f"expected {expected_size} decision variables "
            f"(N={N}, states_dim={states_dim}), got {decision_variables.shape[0]}"
        )
    u = decision_variables[:N]                                   # controls
    x = decision_variables[N:-1].reshape((N + 1, states_dim))   # node states
    dt = decision_variables[-1]                                  # time step
    x_dot = dynamics(x, u, params)
    # Trapezoidal quadrature of the dynamics over each interval.
    integral = (x_dot[:-1] + x_dot[1:]) / 2 * dt
    defects = x[1:] - x[:-1] - integral
    # Transpose so the output is grouped by state component, then flatten.
    return defects.transpose().flatten()

In [None]:
# Forward-mode Jacobian of the defect vector w.r.t. the packed decision vector
# (argument 0); the `args` tuple is passed through without differentiation.
state_defects_jac = jacfwd(state_defects, argnums=0)

In [None]:
# Boundary conditions of the transcription.
x0 = jnp.array([0, 0, 0, 0, 50000])                   # initial state [x, y, v_x, v_y, mass]
terminal_x = [None, 1.8500e+05, 1.6270e+03, 0, None]  # target final state; None = not fixed here
_, xf_y, xf_vx, xf_vy, _ = terminal_x

# Discretization: N intervals (one control each) and N + 1 state nodes.
states_dim = 5
N = 25
full_params = (N, states_dim, vch_params)

# Pitch-angle bounds (upper / lower).
u_b = jnp.pi / 2
l_b = -jnp.pi / 2

In [None]:
# Inspection only: echo the propellant mass-flow rate computed in the parameter cell.
mdot

In [None]:
# Upper bound on the time step: burning the whole initial mass at the constant
# mass-flow rate takes t_max, spread evenly over the N intervals.
mdot = thrust / STANDARD_GRAV / Isp
initial_mass = float(x0[4])  # mass column of the initial state (was a duplicated literal)
t_max = initial_mass / mdot
dt_max = t_max / N

In [None]:
# Initial guess for the optimizer.
# Controls: pitch ramps linearly from 1.3 down to -0.76 over the N intervals.
u_init = jnp.linspace(1.3, -0.76, N)

# Time grid of the guess: N + 1 nodes spanning [0, 450]; dt is its uniform spacing.
t_init = jnp.linspace(0, 450, N + 1)
dt_init = t_init[1] - t_init[0]

# States: each column is a straight line from its start value to a guessed end value.
x_init = jnp.stack(
    [
        jnp.linspace(0., 215e3, N + 1),              # x position
        jnp.linspace(0., xf_y, N + 1),               # y position -> target
        jnp.linspace(0., xf_vx, N + 1),              # v_x -> target
        jnp.linspace(0., xf_vy, N + 1),              # v_y -> target
        jnp.linspace(50000., 0.2 * 50000., N + 1),   # mass -> 20 % of the start value
    ],
    axis=1,
)

# Pack controls, node states (row-major) and dt into one decision vector.
initial_guess = jnp.concatenate([u_init, x_init.flatten(), jnp.array([dt_init])])

In [None]:
def objective(decision_variables, args):
    """Minimum-time cost: scaled total flight time ``100 * dt * N``.

    Args:
        decision_variables: packed decision vector; only its last entry (the
            shared time step ``dt``) is read.
        args: tuple whose first element is ``N``, the number of intervals.

    Returns:
        Scalar ``100 * dt * N``. The factor 100 is a fixed cost scaling
        (presumably for SLSQP conditioning -- confirm).
    """
    N = args[0]
    dt = decision_variables[-1]  # time step is the last decision variable
    return 100. * dt * float(N)


# Gradient w.r.t. decision_variables (argument 0); nonzero only in the dt slot.
objective_grad = grad(objective)


# Box bounds for every decision variable, in packed order:
# N controls, then (N + 1) * states_dim node states (row-major), then dt.
bounds = [(l_b, u_b)] * N  # pitch command bounds
num_state_bounds = states_dim * (N + 1)

# Default state bounds: every state component non-negative.
# NOTE(review): this also allows mass == 0 at interior nodes while `dynamics`
# divides by mass; a small positive lower bound may be safer -- confirm.
state_bounds = [(0, np.inf)] * num_state_bounds
bounds = bounds + state_bounds

# Initial node: pin every component to x0 (single source of truth for the start state).
for k, value in enumerate(np.asarray(x0)):
    bounds[N + k] = (float(value), float(value))

# Final node: index of its first component in the packed vector.
last_node = N + num_state_bounds - states_dim
bounds[last_node + 0] = (0.0, np.inf)      # x: free (non-negative)
bounds[last_node + 1] = (xf_y, xf_y)       # y
bounds[last_node + 2] = (xf_vx, xf_vx)     # v_x
bounds[last_node + 3] = (xf_vy, xf_vy)     # v_y
bounds[last_node + 4] = (0, 30000)         # mass
bounds = bounds + [(dt_max / 2, dt_max)]   # time step dt

arguments = (full_params,)

# Equality constraints: all collocation defects must vanish.
constraints = [{'type': 'eq', 'fun': state_defects, 'jac': state_defects_jac, 'args': arguments}]


In [None]:
# Solve the NLP with SLSQP (loose tolerance 1e-2, at most 500 iterations).
result = minimize(
    objective,
    initial_guess,
    args=arguments,
    method='SLSQP',
    jac=objective_grad,
    bounds=bounds,
    constraints=constraints,
    tol=1e-2,
    options={"maxiter": 500, "disp": True},
)
print(result)


In [None]:
# Unpack the optimal decision vector (same packing as `initial_guess`).
dt = result.x[-1]                                 # optimal time step
u = result.x[:N]                                  # pitch commands, one per interval
x = result.x[N:-1].reshape((N + 1, states_dim))   # state trajectory, one row per node

In [None]:
# Time of each of the N + 1 nodes for the optimal uniform step dt.
t = jnp.linspace(0, N * dt, num=N + 1)

In [None]:
# `u` has one command per interval (N values) while `t` holds the N + 1 node
# times; pair each command with the start of its interval so lengths match.
plot(
    t[:len(u)], [u],
    title="Time vs Pitch steering",
    xlabel="Time",
    ylabel="theta",
    )

In [None]:
# Positions on the primary axis, velocities on the secondary axis.
plot(
    t,
    [x[:, 0], x[:, 1]],
    y2=[x[:, 2], x[:, 3]],
    trace_names=("pos_x", "pos_y", "vel_x", "vel_y"),
    xlabel="Time",
    ylabel=("Pos", "Vel"),
    title="Time vs States",
)

In [None]:
# Mass history at the node times.
plot(
    t,
    [x[:, 4]],
    xlabel="Time",
    ylabel="Mass",
    title="time vs Mass",
)