In [1]:
import os
import gymnasium as gym
from stable_baselines3 import PPO
from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize,VecMonitor
from stable_baselines3.common.env_util import make_vec_env
from deform_rl.envs.Cable_reshape_env.environment import CableReshapeV2,CableReshape
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.callbacks import EvalCallback,BaseCallback

import pygame

In [10]:
def make_env(rank,seed=0):
    """
    Build a zero-argument factory that creates one wrapped CableReshapeV2 env.

    The returned callable is what ``DummyVecEnv`` expects in its list of
    environment constructors.

    :param rank: (int) index of the environment; added to ``seed`` so each
        environment is reset with its own seed
    :param seed: (int) base seed for the RNG
    :return: (callable) function that constructs and returns the environment
    """
    # Executed once per factory (at make_env call time), not per env build.
    set_random_seed(seed)

    # Every environment needs its own seed; with identical seeds they would
    # all generate the same experiences.
    env_seed = seed + rank

    def _build():
        cable_env = CableReshapeV2(
            render_mode='human',
            seg_num=10,
            cable_length=300,
            scale_factor=800,
        )
        # Cap episodes at 1000 steps, then wrap the result in Monitor.
        wrapped = Monitor(TimeLimit(cable_env, max_episode_steps=1000))
        # Seed the environment for reproducibility.
        wrapped.reset(seed=env_seed)
        return wrapped

    return _build

# Four training environments (make_env receives ranks 4..7), with running
# observation/reward normalization on top.
env0 = DummyVecEnv([make_env(i+4) for i in range(4)])
training_env = VecNormalize(env0)

# Single evaluation environment.
# NOTE(review): this uses rank 4, i.e. the same seed as the first training
# environment — confirm that is intended.
env1 = DummyVecEnv([make_env(i+4) for i in range(1)])
# The evaluation wrapper must not learn its own statistics or rescale rewards:
# training=False freezes the running mean/std and norm_reward=False keeps the
# rewards unnormalized. Per the SB3 docs, EvalCallback synchronizes the
# statistics from the training env when both are VecNormalize.
validation_env = VecNormalize(env1, training=False, norm_reward=False)

# Output locations for saved models / normalization stats and TensorBoard logs.
save_dir = os.path.join("saved_models", "reshape")
log_dir = os.path.join("logs", "reshape")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

class SaveNormalizeCallback(BaseCallback):
    """
    Save the VecNormalize statistics of the model being trained.

    Each invocation of the callback writes the statistics to ``save_path``.

    :param verbose: (int) verbosity level forwarded to ``BaseCallback``
    :param save_path: (str or None) target file for the statistics; when
        ``None`` it defaults to ``<save_dir>/vecnorms.pkl``
    """
    def __init__(self, verbose = 0, save_path = None):
        super().__init__(verbose)
        if save_path is None:
            # Keep the original default location.
            save_path = os.path.join(save_dir, "vecnorms.pkl")
        self.save_path = save_path

    def _on_step(self):
        """
        Write the current normalization statistics to disk.

        :return: (bool) always True, so training is never interrupted
        """
        # Take the VecNormalize wrapper from the model this callback is
        # attached to, rather than from a notebook-level global that may be
        # stale or undefined when cells are run out of order.
        vec_normalize = self.model.get_vec_normalize_env()
        if vec_normalize is None:
            # Nothing to save when the model's env is not normalized.
            if self.verbose > 0:
                print("SaveNormalizeCallback: no VecNormalize wrapper found, skipping save")
            return True
        vec_normalize.save(self.save_path)
        return True
# Passed to EvalCallback below as `callback_on_new_best`.
save_callback = SaveNormalizeCallback()

# Evaluation on validation_env: 15 episodes per evaluation, without rendering;
# best-model files go to save_dir.
# NOTE(review): per the SB3 docs eval_freq is counted in callback calls, and
# each call advances every parallel training env by one step — confirm that
# 10000 is the intended spacing for the 4-env training setup.
eval_callback = EvalCallback(
    validation_env,
    callback_on_new_best=save_callback,
    best_model_save_path=save_dir,
    eval_freq=10000,
    n_eval_episodes=15,
    render=False,
    verbose=1,
)


In [4]:
# PPO with an MLP policy on the normalized training envs, running on CPU;
# TensorBoard logs are written under log_dir.
model = PPO(
    policy='MlpPolicy',
    env=training_env,
    tensorboard_log=log_dir,
    verbose=1,
    device='cpu',
)
# First training run: 800000 timesteps, no callback attached.
model.learn(total_timesteps=800000, tb_log_name='reshape_v2')


Using cpu device
Logging to logs/reshape/reshape_v2_17
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 1e+03     |
|    ep_rew_mean     | -4.71e+03 |
| time/              |           |
|    fps             | 2241      |
|    iterations      | 1         |
|    time_elapsed    | 3         |
|    total_timesteps | 8192      |
----------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1e+03       |
|    ep_rew_mean          | -4.99e+03   |
| time/                   |             |
|    fps                  | 1693        |
|    iterations           | 2           |
|    time_elapsed         | 9           |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.016673803 |
|    clip_fraction        | 0.21        |
|    clip_range           | 0.2         |
|    entropy_loss         | -28.4       |
|    ex

<stable_baselines3.ppo.ppo.PPO at 0x7349531ffbe0>

In [13]:

# Continue training the existing model, now with periodic evaluation.
model.learn(
    total_timesteps=1600000,
    callback=eval_callback,
    tb_log_name='reshape_v2',
    reset_num_timesteps=False,  # keep the timestep counter from the earlier run
)


Logging to logs/reshape/reshape_v2_17
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 237      |
|    ep_rew_mean     | 52.5     |
| time/              |          |
|    fps             | 2248     |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 4222272  |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 272         |
|    ep_rew_mean          | -5.02       |
| time/                   |             |
|    fps                  | 1392        |
|    iterations           | 2           |
|    time_elapsed         | 11          |
|    total_timesteps      | 4230464     |
| train/                  |             |
|    approx_kl            | 0.052151516 |
|    clip_fraction        | 0.456       |
|    clip_range           | 0.2         |
|    entropy_loss         | -22.7       |
|    explained_variance   | 0.6771

<stable_baselines3.ppo.ppo.PPO at 0x7349531ffbe0>