# <span style="color:#2E86C1;">Algorithm Comparison of SAC and PPO</span>

## <span style="color:#E74C3C;">Notebook Purpose</span>

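This notebook compares the performance of two actor-critic algorithms, SAC and PPO, on a waypoint navigation task (a trained DDPG model is also loaded and passed to the comparison run). It proceeds in three stages: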

- <span style="color:#E74C3C;"><b>Model Evaluation</b></span>  
  - Load the **trained models** for both actor-critic algorithms.

- <span style="color:#E74C3C;"><b>Waypoint Navigation Task</b></span>  
  - Generate a <span style="color:#8E44AD;">random waypoint</span> as the navigation target.  
  - Allow both agents to attempt reaching the <span style="color:#8E44AD;">designated waypoint</span>.  

- <span style="color:#E74C3C;"><b>Trajectory Analysis & Comparison</b></span>  
  - If both agents successfully reach the waypoint, their **trajectories are saved**.  
  - Visualize the agent **trajectories in a 3D vector field plot** for direct comparison.  
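
The task logic outlined above can be sketched in a few lines. This is a minimal illustration, not the code in `Evaluation.vis_model`: the helper names `sample_waypoint`, `reached`, and `keep_episode`, the sampling bounds, and the 0.5 tolerance are all assumptions made for the example.

```python
import math
import random

def sample_waypoint(rng, bounds=((-5.0, 5.0), (-5.0, 5.0), (1.0, 5.0))):
    """Draw one random (x, y, z) target inside an axis-aligned box (assumed bounds)."""
    return tuple(rng.uniform(lo, hi) for lo, hi in bounds)

def reached(position, waypoint, tolerance=0.5):
    """True if the position lies within `tolerance` of the waypoint (Euclidean distance)."""
    return math.dist(position, waypoint) <= tolerance

def keep_episode(final_positions, waypoint, tolerance=0.5):
    """Keep the trajectories only if every agent ended within tolerance of the waypoint."""
    return all(reached(p, waypoint, tolerance) for p in final_positions.values())

rng = random.Random(0)
target = sample_waypoint(rng)
# Hypothetical final positions for the two agents
finals = {"ppo": target, "sac": (target[0] + 0.3, target[1], target[2])}
print(keep_episode(finals, target))
```

Requiring every agent to succeed before saving keeps the later plot a like-for-like comparison, since each stored episode then contains a complete trajectory per agent.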


## <span style="color:#27AE60;">Imports:</span>

In [1]:
from Evaluation.vis_model import comp_model_performance, plotly_vector_field
from stable_baselines3 import PPO, SAC, DDPG
import gymnasium as gym
import json

## <span style="color:#27AE60;">Load the trained models:</span>

In [2]:
env_id = "SingleWaypointQuadXEnv-v0"
control_mode = "thrust"
assert control_mode in ["angular", "thrust"], "Invalid control mode"

# Load the best trained models
ddpg_path = f"../../models/{control_mode}_control/ddpg_{control_mode}_best"
ppo_path = f"../../models/{control_mode}_control/ppo_{control_mode}_best"
sac_path = f"../../models/{control_mode}_control/sac_{control_mode}_best"

# Initialize the environments and load the models.
# Note: `deterministic` is not a documented argument of `load()`; extra keyword
# arguments appear to be stored as plain attributes on the loaded model.
# Deterministic actions are requested per call via `model.predict(obs, deterministic=True)`.
flight_mode = -1 if control_mode == "thrust" else 1

ddpg_env = gym.make(env_id, render_mode=None, reward_shift=0.75, flight_mode=flight_mode)
_ = ddpg_env.reset()
ddpg = DDPG.load(ddpg_path, deterministic=True)
ddpg.set_env(ddpg_env)

ppo_env = gym.make(env_id, render_mode=None, reward_shift=0.75, flight_mode=flight_mode)
_ = ppo_env.reset()
ppo = PPO.load(ppo_path, deterministic=True)
ppo.set_env(ppo_env)

sac_env = gym.make(env_id, render_mode=None, reward_shift=0.75, flight_mode=flight_mode)
_ = sac_env.reset()
sac = SAC.load(sac_path, deterministic=True)
sac.set_env(sac_env)

pybullet build time: Oct  3 2024 08:55:45


Reward function initialized with r_LOS_weight: 1.0, r_smooth_weight: 0.0
Shifting reward to the range: [-0.25, 0.75]
Reward function initialized with r_LOS_weight: 1.0, r_smooth_weight: 0.0
Shifting reward to the range: [-0.25, 0.75]
Reward function initialized with r_LOS_weight: 1.0, r_smooth_weight: 0.0
Shifting reward to the range: [-0.25, 0.75]


## <span style="color:#27AE60;">Run the comparison and save the results:</span>

In [7]:
result = comp_model_performance(
    ddpg_model=ddpg,
    ppo_model=ppo,
    sac_model=sac,
    render=False,
    result_file_path=f"comparison_results_{control_mode}.json",
)

Reward function initialized with r_LOS_weight: 1.0, r_smooth_weight: 0.0
Shifting reward to the range: [-0.25, 0.75]
Reward function initialized with r_LOS_weight: 1.0, r_smooth_weight: 0.0
Shifting reward to the range: [-0.25, 0.75]
Reward function initialized with r_LOS_weight: 1.0, r_smooth_weight: 0.0
Shifting reward to the range: [-0.25, 0.75]


## <span style="color:#27AE60;">Visualising the trajectories:</span>

In [7]:
with open(f"comparison_results_{control_mode}.json", "r") as f:
    comp_results = json.load(f)

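# Index of the saved episode to visualise
ep_idx = 2
# Velocities and positions are passed in the same agent order: PPO first, then SAC
lin_vels = [comp_results["ppo"]["linear_velocity"][ep_idx], comp_results["sac"]["linear_velocity"][ep_idx]]
lin_pos = [comp_results["ppo"]["linear_position"][ep_idx], comp_results["sac"]["linear_position"][ep_idx]]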
    
plotly_vector_field(lin_pos, lin_vels, comp_results["targets"][ep_idx], size=[20.0, 30.0, 40.0], save_path=False, mode="compare")
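
Alongside the plot, a few scalar metrics can make the comparison easier to read. The sketch below computes the path length and the final distance to the target for one episode. It assumes that each `linear_position` entry is a list of `[x, y, z]` points (one per time step) and that each target is a single `[x, y, z]` point; this layout is inferred from how the results are indexed above and is not confirmed. `trajectory_metrics` is a hypothetical helper, and the `demo` data is made up.

```python
import math

def trajectory_metrics(positions, target):
    """Return the path length and final distance to the target for one trajectory."""
    path_length = sum(math.dist(a, b) for a, b in zip(positions, positions[1:]))
    final_error = math.dist(positions[-1], target)
    return {"path_length": path_length, "final_error": final_error}

# Made-up stand-in for comp_results (one episode per agent)
demo = {
    "targets": [[3.0, 4.0, 0.0]],
    "ppo": {"linear_position": [[[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [3.0, 4.0, 0.0]]]},
    "sac": {"linear_position": [[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]]},
}

demo_idx = 0
for name in ("ppo", "sac"):
    metrics = trajectory_metrics(demo[name]["linear_position"][demo_idx], demo["targets"][demo_idx])
    print(name, metrics)
```

With the real results, `demo` would be replaced by `comp_results` and `demo_idx` by `ep_idx`, provided the assumed layout holds.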