## Imports

In [1]:
import os
import gymnasium as gym
import numpy as np

from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

## Paths

In [2]:
logs_path = os.path.join('Logs')
models_path = os.path.join('Models')
Ant_TD3_NA_path = os.path.join(models_path,'Ant_Model_TD3_NA')
Ant_TD3_Orn_path = os.path.join(models_path,'Ant_Model_TD3_Orn')

## Training

### Learning

In [None]:
env = gym.make("Ant-v4", render_mode="rgb_array")

# The noise objects for TD3
n_actions = env.action_space.shape[-1]
# action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

# Evaluate the model frequantly and stop at satisfactory reward
stop_train_callback = StopTrainingOnRewardThreshold(reward_threshold=3000, verbose=1)
eval_callback = EvalCallback(eval_env=env, eval_freq=30_000, callback_after_eval=stop_train_callback, best_model_save_path=Ant_TD3_NA_path, verbose=1)

# Initialize the model
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1, tensorboard_log=os.path.join(logs_path,'Final Logs'))
model.learn(total_timesteps=int(1e6), log_interval=10, callback =eval_callback)

del model # remove to demonstrate saving and loading


### Testing trained model

#### Random Movements

In [6]:
episodes = 5
total_score = 0

render = True
render_frequency = 1

env = gym.make("Ant-v4", render_mode='rgb_array')
model = TD3.load(os.path.join(Ant_TD3_NA_path,'best_model'), env)
vec_env = model.get_env()

print(f'\nStarting < Random Movements >')
for episode in range(episodes):
    obs = vec_env.reset()
    done = False
    score = 0
    frame_count = 0
    
    while not done:
        action = [vec_env.action_space.sample()]
        obs, rewards, done, _ = vec_env.step(action)
        score += rewards
        if render:
            if frame_count % render_frequency == 0:
                vec_env.render('human')
            frame_count += 1
    
    total_score += score
    print(f'Episode: {episode} Score: {score}')

print(f'Mean Score for {episode + 1} episodes: {total_score / episodes}')
env.close()
vec_env.close()
del model

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Starting < Random Movements >
Episode: 0 Score: [6.629386]
Episode: 1 Score: [-2.8597863]
Episode: 2 Score: [-79.52603]
Episode: 3 Score: [-58.251022]
Episode: 4 Score: [-10.662735]
Mean Score for 5 episodes: [-28.934036]


#### Trained Movements

In [7]:
episodes = 1

render = True
render_frequency = 1

env = gym.make("Ant-v4", render_mode='rgb_array')
# models_to_try = [Ant_TD3_NA_path, Ant_TD3_Orn_path]
models_to_try = [Ant_TD3_Orn_path]

for current_model in models_to_try:
    
    print(f'\nStarting < {current_model.removeprefix(os.path.join(models_path,""))} >')
    model = TD3.load(os.path.join(current_model,'best_model'), env=env)
    vec_env = model.get_env()
    total_score = 0

    for episode in range(episodes):
        obs = vec_env.reset()
        done = False
        score = 0
        frame_count = 0
        
        while not done:
            action, _ = model.predict(obs)
            obs, rewards, done, info = vec_env.step(action)
            score += rewards
            if render:
                if frame_count % render_frequency == 0:
                    vec_env.render('human')
            frame_count += 1
        
        total_score += score
        print(f'Episode: {episode + 1} Score: {score}')
    
    print(f'For model < {current_model.removeprefix(os.path.join(models_path,""))} >')
    print(f'Mean Score for {episode + 1} episodes: {total_score / episodes}')

env.close()    
vec_env.close()
del model


Starting < Ant_Model_TD3_Orn >
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Episode: 1 Score: [3085.0332]
For model < Ant_Model_TD3_Orn >
Mean Score for 1 episodes: [3085.0332]
