Basic environemnt usage with Gymnasium, Stable Baselines3 and SAC

In [9]:
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
import os

In [2]:
# Import environment creator function
from Environment_Creator import env_creator
# Instantiate one of the custom environments
config = None
orbital_env = env_creator(config, "interplanetary")

  self.initial_orbit = Orbit.from_ephem(Earth, Ephem.from_body(Earth, self.epoch), self.epoch)  # Earth's current heliocentric orbit
  self.target_orbit = Orbit.from_ephem(Mars, Ephem.from_body(Mars, self.epoch), self.epoch)    # Mars' orbit around the Sun
  gym.logger.warn(
  gym.logger.warn(


In [3]:
# Register the custom environment with Gym
gym.register(
    id="SpacecraftInterplanetaryEnv-v0",
    entry_point=lambda config: env_creator(config, "interplanetary")
)
# Verify environment registry
gym.pprint_registry()

===== classic_control =====
Acrobot-v1                  CartPole-v0                 CartPole-v1
MountainCar-v0              MountainCarContinuous-v0    Pendulum-v1
===== phys2d =====
phys2d/CartPole-v0          phys2d/CartPole-v1          phys2d/Pendulum-v0
===== box2d =====
BipedalWalker-v3            BipedalWalkerHardcore-v3    CarRacing-v3
LunarLander-v3              LunarLanderContinuous-v3
===== toy_text =====
Blackjack-v1                CliffWalking-v0             FrozenLake-v1
FrozenLake8x8-v1            Taxi-v3
===== tabular =====
tabular/Blackjack-v0        tabular/CliffWalking-v0
===== mujoco =====
Ant-v2                      Ant-v3                      Ant-v4
Ant-v5                      HalfCheetah-v2              HalfCheetah-v3
HalfCheetah-v4              HalfCheetah-v5              Hopper-v2
Hopper-v3                   Hopper-v4                   Hopper-v5
Humanoid-v2                 Humanoid-v3                 Humanoid-v4
Humanoid-v5                 HumanoidStandup-v2    

In [11]:
# Make our environment
gym_env = gym.make("SpacecraftInterplanetaryEnv-v0", config=config)
gym_env = Monitor(gym_env)  # Logs training progress

# Instantiate the env
vec_env = make_vec_env("SpacecraftInterplanetaryEnv-v0", n_envs=1, env_kwargs=dict(config=config))

  self.initial_orbit = Orbit.from_ephem(Earth, Ephem.from_body(Earth, self.epoch), self.epoch)  # Earth's current heliocentric orbit
  self.target_orbit = Orbit.from_ephem(Mars, Ephem.from_body(Mars, self.epoch), self.epoch)    # Mars' orbit around the Sun
  self.initial_orbit = Orbit.from_ephem(Earth, Ephem.from_body(Earth, self.epoch), self.epoch)  # Earth's current heliocentric orbit
  self.target_orbit = Orbit.from_ephem(Mars, Ephem.from_body(Mars, self.epoch), self.epoch)    # Mars' orbit around the Sun


In [8]:
# Check env support with baselines3
from stable_baselines3.common.env_checker import check_env
check_env(gym_env)

In [None]:
# Define SAC model with hyperparameters
# RUN_01
model = SAC(
    "MlpPolicy",  # Uses a neural network to approximate the policy
    gym_env,
    learning_rate=3e-4,
    buffer_size=1000000,  # Large replay buffer
    batch_size=256,  # Larger batch size helps with training stability
    tau=0.005,  # Polyak averaging coefficient (for target network updates)
    gamma=0.99,  # Discount factor
    train_freq=10,  # Train every 10 steps
    gradient_steps=10,  # Update the model for 10 steps per training iteration
    verbose=1,  # Print training info
    device="auto",  # Use GPU if available
)

In [16]:
# RUN_02
model = SAC(
    "MlpPolicy",  # Uses a neural network to approximate the policy
    gym_env,
    learning_rate=3e-4,
    buffer_size=100000,  # Large replay buffer
    batch_size=500,  # Larger batch size helps with training stability
    tau=0.005,  # Polyak averaging coefficient (for target network updates)
    gamma=0.99,  # Discount factor
    train_freq=10,  # Train every 10 steps
    gradient_steps=10,  # Update the model for 10 steps per training iteration
    verbose=1,  # Print training info
    device="auto",  # Use GPU if available
)

Using cpu device
Wrapping the env in a DummyVecEnv.


In [17]:
# Setup model checkpointing
checkpoint_callback = CheckpointCallback(
    save_freq=50000,  # Save model every 50,000 steps
    save_path="./checkpoints/",
    name_prefix="sac_spacecraft",
    save_replay_buffer=True,
    save_vecnormalize=True,
)

# Train the agent for 1 million timesteps (longer training for better results)
model.learn(total_timesteps=100_000, callback=checkpoint_callback)

# Save final model
model.save("sac_spacecraft_final")
print("Training complete. Model saved.")

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 21.2     |
|    ep_rew_mean     | -2.4e+03 |
| time/              |          |
|    episodes        | 4        |
|    fps             | 906      |
|    time_elapsed    | 0        |
|    total_timesteps | 85       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 17.4     |
|    ep_rew_mean     | -2e+03   |
| time/              |          |
|    episodes        | 8        |
|    fps             | 73       |
|    time_elapsed    | 1        |
|    total_timesteps | 139      |
| train/             |          |
|    actor_loss      | 2.29e+06 |
|    critic_loss     | 4.55e+11 |
|    ent_coef        | 1.01     |
|    ent_coef_loss   | -0.538   |
|    learning_rate   | 0.0003   |
|    n_updates       | 30       |
---------------------------------
----------------------------------
| rollout/           |           |
|    ep_len_

KeyboardInterrupt: 

In [None]:
# Load trained model [if fully finished]
model = SAC.load("sac_spacecraft_final")

In [18]:
# Load last checkpoint
model = SAC.load("checkpoints/sac_spacecraft_50000_steps.zip")

In [19]:
obs, _ = gym_env.reset()
done = False

while not done:
    action, _ = model.predict(obs, deterministic=True)  # Use deterministic actions for evaluation
    obs, reward, done, _, _ = gym_env.step(action)
    print(f"Reward: {reward}")

print("Evaluation complete.")

Reward: -108.5066099634368
Reward: -109.3109497601294
Reward: -110.14604156855623
Reward: -110.99364056511632
Reward: -111.84747458822928
Reward: -112.70484861632626
Reward: -113.56442067474606
Reward: -114.42544942386017
Reward: -115.28749258393411
Reward: -116.15027018663285
Reward: -117.0135964734116
Reward: -217.87734338829438
Evaluation complete.


RUN_01:
Reward: -107.76274152837807
Reward: -108.56436027711736
Reward: -109.39830038893808
Reward: -110.24529011169618
Reward: -111.09875427400894
Reward: -111.95588236623631
Reward: -112.8152800644034
Reward: -113.67617920864811
Reward: -114.53812247680536
Reward: -115.40082085733859
Reward: -116.26408284858648
Reward: -217.12777658172516
Evaluation complete.