In [None]:
import numpy as np
from  numpy import deg2rad as d2r
from  numpy import array as arr

import numba
from scipy.integrate import solve_ivp, OdeSolution
from scipy.optimize import minimize, OptimizeResult
from functools import lru_cache
from models import dynamics, CtrlMode, STANDARD_GRAV
from transcription import MultiShootingTranscription

from plotting import plot

In [None]:
# Earth's standard gravitational parameter GM, in SI units (m^3/s^2).
mu_earth = 3.986004418e14
earth_r = 6_378_000 # m
# Target circular-orbit altitude, added to earth_r below (presumably metres, like earth_r).
circ_orbit_alt = 200_000 
# Circular-orbit speed at radius earth_r + circ_orbit_alt: sqrt(mu / r).
v_circ = np.sqrt(mu_earth / (earth_r + circ_orbit_alt))

In [None]:
# --- Propulsion: stage 1 (s1) and stage 2 (s2) ---
n_engines_s1 = 9
n_engines_s2 = 1
# Specific impulse; divided by STANDARD_GRAV below, so presumably in seconds.
isp_s1 = 311
engine_thrust_s1 = n_engines_s1*24910.04  # N, total stage thrust; per-engine value is the average between sea level and vacuum
isp_s2 = 343
engine_thrust_s2 = n_engines_s2* 25_000  # N, total stage thrust
# (thrust, Isp) tuples; these are what gets handed to problem.set_dynamics_params.
s1_vch_params = (engine_thrust_s1, isp_s1)
s2_vch_params = (engine_thrust_s2, isp_s2)

# --- Masses (presumably kg -- TODO confirm against models.dynamics) ---
fairing_mass = 50
# 22 s; used later as the fixed duration of "phase2".
# NOTE(review): the name is a misspelling of "fairing_timing"; it is kept because
# other cells reference it under this spelling.
farinig_timing = 184 - 162 # sec
payload = 250
s1_dry_mass = 1076.47308279  
s2_dry_mass = 257.90093739  

s1_wet_mass = 10047.082106064723
s2_wet_mass = 2602.454913676189
# Numerically equal to s1_wet_mass + s2_wet_mass + payload + fairing_mass.
total_mass = 12949.537019740912

# Propellant mass flow rate of each stage: thrust / (g0 * Isp).
mdot_s1 = engine_thrust_s1 / STANDARD_GRAV / isp_s1
mdot_s2 = engine_thrust_s2 / STANDARD_GRAV / isp_s2

# Bare expression: has no effect on later cells.
mdot_s2

In [None]:
# Stage-2 mass flow rate times the "phase2" duration, i.e. the propellant mass
# burned during that segment. Displayed for inspection only.
farinig_timing * mdot_s2

In [None]:

# Number of state components stored per node in the decision vector.
NUM_X= 5
# Total number of control parameters in the decision vector. Matches the u0 sizes
# configured per phase further down: 1 + 0 + 2 + 2.
NUM_U = 5
# Number of flight phases.
NUM_PHASE = 4

# Initial-state guesses, one row per phase (row i is passed as the x0 of "phase<i>").
# NOTE(review): the component order looks like [x, y, vx, vy, mass] -- confirm
# against models.dynamics.
x0 = arr([
    # phase0: at rest at the origin with the full vehicle mass.
    [0,0,0,0,total_mass], 
    [7.5,390,1.5,80,12200],
    # phase2: mass guess is stage 2 (wet) plus payload plus fairing.
    [30000,60000,2000,500,s2_wet_mass + payload + fairing_mass],
    # phase3: fairing no longer counted, and the phase2 propellant burn is subtracted.
    [45000,80000,2800,1000,s2_wet_mass + payload - farinig_timing * mdot_s2]
])

# Terminal-state guess: components 1..3 are the circular-orbit altitude, the
# circular speed and zero; the mass guess is stage-2 dry mass plus payload.
x_f = arr([200_000, circ_orbit_alt, v_circ, 0.0, s2_dry_mass + payload])

In [None]:
# Phase names in flight order, each mapped to the (thrust, Isp) tuple of the
# stage that is active during that phase.
_phase_vehicle_params = {
    "phase0": s1_vch_params,
    "phase1": s1_vch_params,
    "phase2": s2_vch_params,
    "phase3": s2_vch_params,
}

# NOTE(review): the second argument (5) is presumably the state dimension
# (same value as NUM_X) -- confirm against MultiShootingTranscription.
problem = MultiShootingTranscription(list(_phase_vehicle_params), 5)

for _phase_name, _vehicle_params in _phase_vehicle_params.items():
    problem.set_dynamics_params(_phase_name, _vehicle_params)

In [None]:
# phase0: ANGLE_STEER control with a single angle parameter, guessed at 89.5 deg
# and bounded to [85, 89.8] deg (converted to radians). The initial state and the
# 10 s duration are passed as their own bounds -- presumably that pins them; confirm
# against MultiShootingTranscription.
problem.set_phase_init_x("phase0", x0 = x0[0], bounds= x0[0])
problem.set_phase_control("phase0", CtrlMode.ANGLE_STEER , u0 = d2r(89.5), bounds = [(d2r(85), d2r(89.8))])
problem.set_phase_time("phase0", t0 = 10, bounds = 10)

# phase1: ZERO_ALPHA mode has no control parameters (empty u0); the duration is
# guessed at 100 s and bounded to 60..180 s.
problem.set_phase_init_x("phase1", x0 = x0[1])
problem.set_phase_control("phase1", CtrlMode.ZERO_ALPHA , u0 = [])
problem.set_phase_time("phase1", t0 = 100, bounds=(60, 180))

# phase2: LTS mode with two control parameters (presumably linear tangent steering
# coefficients -- TODO confirm). The duration guess and its bound are both
# farinig_timing.
problem.set_phase_init_x("phase2", x0 = x0[2])
problem.set_phase_control("phase2", CtrlMode.LTS , u0 = arr([-0.001,1]), bounds = [(-0.1,0.1), (-3,3)])
problem.set_phase_time("phase2", t0 = farinig_timing, bounds = farinig_timing)

# phase3: same LTS setup as phase2; the duration is guessed at 320 s and no bound
# is passed.
problem.set_phase_init_x("phase3", x0 = x0[3])
problem.set_phase_control("phase3", CtrlMode.LTS , u0 = arr([-0.001,1]), bounds = [(-0.1,0.1), (-3,3)])
problem.set_phase_time("phase3", t0 = 320)

# Non-zero offsets at two phase boundaries, applied only to the last state
# component: s1_dry_mass between phase1 and phase2, fairing_mass between phase2
# and phase3.
problem.set_non_zero_defect(("phase1", "phase2"), arr([0,0,0,0,s1_dry_mass]))
problem.set_non_zero_defect(("phase2", "phase3"), arr([0,0,0,0,fairing_mass]))

# Terminal state: components 1..3 are given the target values (orbit altitude,
# circular speed, 0). The None entries presumably leave the first and last
# components unbounded -- confirm against set_terminal_state.
problem.set_terminal_state(x_final = x_f, bounds = arr([None, circ_orbit_alt, v_circ, 0, None ]))

In [None]:
# Assemble the transcription. The three results are used further down as:
#   d0          -> initial guess passed to minimize
#   d_bounds    -> bounds passed to minimize
#   full_params -> extra argument of the objective / constraint, and the
#                  per-phase config_list of full_traj_rollout
d0, d_bounds, full_params = problem.build()

In [None]:
def unpack_decision_var(decision_var, config, phase_id):
    """Extract one phase's controls, start state and duration from the flat decision vector.

    Layout read here: the first ``NUM_U`` entries are control parameters, the
    next ``(NUM_PHASE + 1) * NUM_X`` entries are node states stored row by row,
    and the last ``NUM_PHASE`` entries are the phase durations.

    Args:
        decision_var: Flat decision vector (must support slicing, indexing
            with a ``range`` and ``reshape``).
        config: Per-phase configuration. ``config[0]`` is returned unchanged as
            the control law; ``config[1]`` holds the ``range(...)`` arguments
            that select this phase's entries inside the control segment.
        phase_id: Zero-based index of the phase.

    Returns:
        Tuple ``(u, x, t_terminal, control_law)``.
    """
    law = config[0]

    # This phase's share of the leading control segment.
    control_segment = decision_var[:NUM_U]
    phase_controls = control_segment[range(*config[1])]

    # Node states sit right after the controls, one row of NUM_X values per node.
    states_end = NUM_U + (NUM_PHASE + 1) * NUM_X
    node_states = decision_var[NUM_U:states_end].reshape(NUM_PHASE + 1, NUM_X)

    # The phase durations close the decision vector.
    phase_durations = decision_var[-NUM_PHASE:]

    return (phase_controls, node_states[phase_id], phase_durations[phase_id], law)


# @lru_cache(maxsize=128, typed=True)
def traj_rollout(t_terminal, x0, params: tuple) -> OptimizeResult:
    """Integrate the vehicle dynamics over one phase, from t = 0 to ``t_terminal``.

    Args:
        t_terminal: Phase duration; the integration span is ``[0.0, t_terminal]``.
        x0: State at the start of the phase.
        params: Extra argument forwarded unchanged to ``dynamics`` (which is
            therefore called as ``dynamics(t, y, params)``).

    Returns:
        The result bunch of ``scipy.integrate.solve_ivp`` (an ``OdeResult``,
        which subclasses ``scipy.optimize.OptimizeResult``). ``.t`` holds the
        sample times and ``.y`` the states with one column per sample. This is
        not an ``OdeSolution``: that class is only the optional ``.sol``
        interpolant, which is not requested here.
    """
    # Fixed output grid: np.linspace's default of 50 evenly spaced samples,
    # end points included. Original author's note: sampling on this grid
    # greatly improves convergence and stability of the Jacobian.
    sample_times = np.linspace(0.0, t_terminal)

    return solve_ivp(
        dynamics,
        t_span=[0.0, t_terminal],
        t_eval=sample_times,
        y0=x0,
        args=(params,),
    )

def full_traj_rollout(d0, config_list):
    """Integrate every phase on its own, each from the start state stored in ``d0``.

    Args:
        d0: Flat decision vector, decoded per phase by ``unpack_decision_var``.
        config_list: One configuration per phase, in flight order; element 3
            of each entry is that phase's dynamics parameter set.

    Returns:
        List with one ``traj_rollout`` result per phase, in phase order.
    """
    phase_solutions = []
    for phase_id, phase_config in enumerate(config_list):
        controls, start_state, duration, law = unpack_decision_var(
            d0, phase_config, phase_id=phase_id
        )
        # Argument handed to the dynamics: (phase params, (control law, control values)).
        dynamics_args = (phase_config[3], (law, controls))
        # Each phase starts from its own decision-vector state, not from the
        # end point of the previous phase's integration.
        phase_solutions.append(traj_rollout(duration, start_state, dynamics_args))
    return phase_solutions

def dynamics_knot_constrant(d0, config_list): 
    """Stack the continuity defects of all phase boundaries plus the terminal defect.

    For every boundary between consecutive phases the defect is
    ``start of later phase - end of earlier phase + knot offset``, where the
    offset is element 2 of the later phase's configuration. The terminal defect
    is the last node state stored in ``d0`` minus the end state of the last
    integrated phase.

    Args:
        d0: Flat decision vector.
        config_list: Per-phase configurations (4-element entries).

    Returns:
        1-D array holding all defects back to back; the defects are zero when
        the phases join up and the final phase ends on the stored terminal state.
    """
    phase_solutions = full_traj_rollout(d0, config_list)

    defects = []
    for later in range(1, NUM_PHASE):
        _, _, knot_offset, _ = config_list[later]
        later_start = phase_solutions[later].y[:, 0]
        earlier_end = phase_solutions[later - 1].y[:, -1]
        defects.append(later_start - earlier_end + knot_offset)

    # Terminal defect: stored final node versus the integrated end of the last phase.
    states_end = NUM_U + (NUM_PHASE + 1) * NUM_X
    node_states = d0[NUM_U:states_end].reshape(NUM_PHASE + 1, NUM_X)
    defects.append(node_states[-1] - phase_solutions[-1].y[:, -1])

    return arr(defects).flatten()

In [None]:
# Single equality constraint: the defect vector returned by
# dynamics_knot_constrant must be zero. full_params is forwarded to it as the
# config_list argument.
constraints = [
    dict(type='eq', fun=dynamics_knot_constrant, args=(full_params,)),
]

In [None]:
def objective(decision_var: np.ndarray, params: tuple) -> float:
    """Cost for the minimum-propellant problem: the negative squared terminal mass.

    Minimising ``-m_f**2`` pushes the terminal mass ``m_f`` up (for positive
    mass), which is the same as burning as little propellant as possible.

    Args:
        decision_var: Flat decision vector. Entries ``NUM_U`` up to
            ``NUM_U + (NUM_PHASE + 1) * NUM_X`` hold the node states row by
            row; the last component of the last row is the terminal mass.
        params: Unused. Accepted because ``minimize`` forwards its ``args``
            to the objective.

    Returns:
        ``-(terminal_mass * terminal_mass)``.
    """
    states_end = NUM_U + (NUM_PHASE + 1) * NUM_X
    node_states = decision_var[NUM_U:states_end].reshape(NUM_PHASE + 1, NUM_X)
    terminal_mass = node_states[-1, -1]
    return -(terminal_mass * terminal_mass)


# def jac_objective(decision_var: tuple, params: tuple):
#         jac = np.zeros_like(decision_var)
#         val = -decision_var[29] - decision_var[29]
#         jac[29]= val
#         return jac

In [None]:
# Solve the transcribed problem with SLSQP, starting from the guess d0, subject
# to the variable bounds d_bounds and the equality constraints defined above.
# full_params is forwarded to objective as its `params` argument.
result = minimize(
    objective, 
    d0, 
    # No analytic gradient is supplied, so the solver estimates it numerically.
    # jac= jac_objective,
    method='SLSQP', 
    bounds=d_bounds, 
    constraints=constraints,
    # At most 500 iterations; "disp" prints the convergence message.
    options = {"maxiter": 500, "disp": True},
    args=(full_params,)
)
# Show the full result object as the cell output.
result

In [None]:
# Re-integrate every phase at the optimiser's final decision vector; the
# resulting per-phase solutions feed the plots below.
sol_list = full_traj_rollout(result.x, full_params)

In [None]:
def unpack_sol_list(sol_list_in , state_index):
    """Turn per-phase solutions into plot-ready traces on one continuous time axis.

    Each phase's ``t`` array is shifted by the summed final times of all the
    phases that come before it.

    Args:
        sol_list_in: Iterable of solutions exposing ``t`` (sample times) and
            ``y`` (indexed by state component first).
        state_index: Which row of ``y`` to extract.

    Returns:
        ``(times, values)``: two lists with one entry per phase.
    """
    shifted_times = []
    state_traces = []
    elapsed = 0  # total duration of the phases handled so far
    for phase_sol in sol_list_in:
        shifted_times.append(phase_sol.t + elapsed)
        state_traces.append(phase_sol.y[state_index])
        elapsed += phase_sol.t[-1]

    return shifted_times, state_traces

In [None]:
# State index 0 of every phase against cumulative flight time.
plot(
    *unpack_sol_list(sol_list, state_index=0),
    title="Time vs States",
    xlabel="Time",
    ylabel="Pos x",
    trace_names=("phase0", "phase1", "phase2", "phase3"),
)

In [None]:
# State index 1 of every phase against cumulative flight time.
plot(
    *unpack_sol_list(sol_list, state_index=1),
    title="Time vs States",
    xlabel="Time",
    ylabel="Pos y",
    trace_names=("phase0", "phase1", "phase2", "phase3"),
)


In [None]:
# State index 2 of every phase against cumulative flight time.
# The label names the velocity component so this figure can be told apart from
# the plot of state index 3, which used the identical label "Vel".
# NOTE(review): state layout inferred as [x, y, vx, vy, mass] from x_f and the
# terminal bounds (index 2 is driven to v_circ) -- confirm against models.dynamics.
plot(
    *unpack_sol_list(sol_list, 2),
    title="Time vs States",
    xlabel="Time",
    ylabel="Vel x",
    trace_names=("phase0", "phase1", "phase2", "phase3"),
)

In [None]:
# State index 3 of every phase against cumulative flight time.
# The label names the velocity component so this figure can be told apart from
# the plot of state index 2, which used the identical label "Vel".
# NOTE(review): state layout inferred as [x, y, vx, vy, mass] from x_f and the
# terminal bounds (index 3 is driven to 0) -- confirm against models.dynamics.
plot(
    *unpack_sol_list(sol_list, 3),
    title="Time vs States",
    xlabel="Time",
    ylabel="Vel y",
    trace_names=("phase0", "phase1", "phase2", "phase3"),
)

In [None]:
# State index 4 of every phase against cumulative flight time.
plot(
    *unpack_sol_list(sol_list, state_index=4),
    title="Time vs States",
    xlabel="Time",
    ylabel="Mass",
    trace_names=("phase0", "phase1", "phase2", "phase3"),
)