In [1]:
import gymnasium
import push_box
import pybullet as p
import pybullet_data
import time
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
import os
from typing import Any, Dict
import torch as th

pybullet build time: Nov 28 2023 23:48:36


In [None]:
# Disconnect from the PyBullet physics server only when a connection exists.
# An unconditional p.disconnect() raises when nothing is connected (e.g. on a
# freshly restarted kernel), which would abort a top-to-bottom run here.
if p.isConnected():
    p.disconnect()

In [2]:
# Log directory; passed to PPO as `tensorboard_log` when the model is built.
log_path = os.path.join('Training', 'Logs')

In [3]:
# Vectorized training environment holding a single pushBox instance.
# The factory creates the env itself rather than capturing the notebook-level
# name `env`, which this cell rebinds to the vectorized wrapper.
env = DummyVecEnv([lambda: gymnasium.make('pushBox-v0')])

In [4]:
# Early stopping for the training stage.
# `stop_callback` is triggered by `eval_callback` whenever a new best mean
# reward is found, and uses a reward threshold of 100. The best model is
# written under `save_path`.
# NOTE(review): evaluation runs on the same `env` object that is used for
# training -- confirm this is intended; SB3 suggests a separate eval env.
save_path = os.path.join('Training', 'SavedModels', 'PPO_66_best')
stop_callback = StopTrainingOnRewardThreshold(verbose=1, reward_threshold=100)
eval_callback = EvalCallback(
    env,
    eval_freq=10000,
    best_model_save_path=save_path,
    callback_on_new_best=stop_callback,
    verbose=1,
)


In [5]:
# Learning rate schedule: linearly decreasing from 0.0007 to 0.0003
def linear_lr(progress_remaining: float,
              start_lr: float = 0.0007,
              end_lr: float = 0.0003) -> float:
    """Linearly interpolate the learning rate between two endpoints.

    Args:
        progress_remaining: Interpolation weight. 1.0 yields ``start_lr`` and
            0.0 yields ``end_lr``; values in between are blended linearly.
        start_lr: Learning rate returned when ``progress_remaining`` is 1.0.
        end_lr: Learning rate returned when ``progress_remaining`` is 0.0.

    Returns:
        ``end_lr + (start_lr - end_lr) * progress_remaining``.
    """
    # The defaults reproduce the values previously hard-coded in the body, so
    # a call with only `progress_remaining` behaves exactly as before.
    return end_lr + (start_lr - end_lr) * progress_remaining

In [7]:
# PPO agent with an MLP policy on `env`; the learning rate follows the
# `linear_lr` schedule and TensorBoard logs are written under `log_path`.
model = PPO(
    policy='MlpPolicy',
    env=env,
    learning_rate=linear_lr,
    tensorboard_log=log_path,
    verbose=1,
)

Using cuda device


In [8]:
# Train with a budget of 10M timesteps, with `eval_callback` attached.
model.learn(callback=eval_callback, total_timesteps=10_000_000)

Logging to Training/Logs/PPO_66
-----------------------------
| time/              |      |
|    fps             | 789  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 515         |
|    iterations           | 2           |
|    time_elapsed         | 7           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009840984 |
|    clip_fraction        | 0.0802      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.82       |
|    explained_variance   | 0.922       |
|    learning_rate        | 0.0007      |
|    loss                 | 0.0173      |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00397    |
|    std                  | 0.981       |
|    value_loss           | 0.00479     |
--



Eval num_timesteps=10000, episode_reward=-0.00 +/- 0.00
Episode length: 500.00 +/- 0.00
-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 500         |
|    mean_reward          | -0.005      |
| time/                   |             |
|    total_timesteps      | 10000       |
| train/                  |             |
|    approx_kl            | 0.009419713 |
|    clip_fraction        | 0.0848      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.83       |
|    explained_variance   | 0.906       |
|    learning_rate        | 0.0007      |
|    loss                 | -0.00659    |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.00562    |
|    std                  | 1           |
|    value_loss           | 0.00193     |
-----------------------------------------
New best mean reward!
------------------------------
| time/              |       |
|    fps             | 418   |

<stable_baselines3.ppo.ppo.PPO at 0x7fc1dea25b10>

In [9]:
# Path the trained model is saved to and later reloaded from.
PPO_Path = os.path.join('Training', 'SavedModels', 'PPO_66_10M_360_cont_actions')

In [10]:
# Persist the current agent at PPO_Path.
model.save(PPO_Path)

In [9]:
# Unbind the in-memory agent; `model` is re-created from the saved file below.
del model

In [10]:
# Reload the agent from PPO_Path and attach the existing `env` to it.
model = PPO.load(PPO_Path, env=env)

In [11]:
# Evaluate the reloaded agent over 10 episodes on `env`. The cell displays the
# returned pair; per the SB3 docs this is (mean reward, std of reward).
evaluate_policy(model, env, n_eval_episodes=10)

box reached target


(0.789357189137263, 1.8141790892853324)

In [17]:
# Final cleanup: disconnect from the PyBullet physics server.
# Guarded so running this cell again after the connection is already gone is
# a no-op rather than an error.
if p.isConnected():
    p.disconnect()

numActiveThreads = 0
stopping threads
Thread with taskId 0 exiting
Thread TERMINATED
destroy semaphore
semaphore destroyed
destroy main semaphore
main semaphore destroyed
finished
numActiveThreads = 0
btShutDownExampleBrowser stopping threads
Thread with taskId 0 exiting
Thread TERMINATED
destroy semaphore
semaphore destroyed
destroy main semaphore
main semaphore destroyed
