In [1]:
# train with SAC, stable baseline3
import datetime
import os

import matplotlib.pyplot as plt
import stable_baselines3
import wandb
from gym.wrappers import TimeLimit
# pip install gym-robotics
from gym_robotics.envs.fetch.push import MujocoPyFetchPushEnv
from gym_robotics.envs.fetch.reach import MujocoPyFetchReachEnv
from stable_baselines3 import SAC, PPO
from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.sac import MlpPolicy
from wandb.integration.sb3 import WandbCallback
# ---- Experiment configuration ----
# Directory handed to SB3's Monitor wrapper for episode CSV logs.
log_dir = "./tb_log/"

# HER goal-relabelling strategy; the string "future" is equivalent to
# GoalSelectionStrategy.FUTURE. Read by the 'SAC-HER' branch when building the model.
goal_selection_strategy = "future"

# Pick the MuJoCo Fetch environment class for the requested task.
env_name = 'FetchReach'
supported_env_classes = {
    'FetchReach': MujocoPyFetchReachEnv,
    'FetchPush': MujocoPyFetchPushEnv,
}
if env_name not in supported_env_classes:
    raise ValueError(f"env_name: {env_name} is not supported")
env_class = supported_env_classes[env_name]

# Algorithm variant: 'SAC' or 'SAC-HER'.
model_name = 'SAC'

# Training budget and the reward type passed to the environment.
max_steps = 100_000
reward_type = 'dense'
# Run timestamp used in checkpoint names.
# NOTE: this global shadows the stdlib `time` module name in the notebook namespace.
time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# distance_threshold = 0.05  # unused; kept as a tuning reminder

# Hyper-parameters recorded with the wandb run.
config = dict(
    policy_type=model_name,
    total_timesteps=max_steps,
    env_name=env_name,
    reward_type=reward_type,
    max_steps=max_steps,
)

# Human-readable run label, e.g. "FetchReach-SAC-dense".
name = "-".join([config['env_name'], config['policy_type'], config['reward_type']])
run = wandb.init(
    project="sb3",
    name=name,
    config=config,
    sync_tensorboard=True,  # mirror SB3 tensorboard scalars into wandb (NOTE: only has data if the model has tensorboard_log set)
    monitor_gym=True,       # upload videos recorded by gym wrappers, if any
    save_code=True,         # optional: snapshot this notebook's source
)


# Build the training and evaluation environments.
# Wrapper order: Fetch env -> TimeLimit -> Monitor -> DummyVecEnv.
# Monitor must sit OUTSIDE TimeLimit: with Monitor inside, an episode cut off by
# the time limit never reaches Monitor as "done", so it is missing from the
# episode statistics.
# Monitor opens its CSV file directly and fails if the directory does not exist.
os.makedirs(log_dir, exist_ok=True)


def make_fetch_env(monitor_file, inner_max_episode_steps=50, time_limit=100, action_scale=0.1):
    """Create one monitored Fetch environment of class ``env_class``.

    Args:
        monitor_file: path prefix for SB3 Monitor's CSV log (SB3 appends
            ".monitor.csv" when the path is not an existing directory).
        inner_max_episode_steps: forwarded to the env constructor as
            ``max_episode_steps``; how it is used is defined by gym_robotics
            (presumably its own episode cap — TODO confirm).
        time_limit: hard episode cap enforced by ``gym.wrappers.TimeLimit``.
        action_scale: forwarded to the env constructor.

    Returns:
        The env wrapped as Monitor(TimeLimit(env)).
    """
    fetch_env = env_class(reward_type=config['reward_type'],
                          max_episode_steps=inner_max_episode_steps,
                          action_scale=action_scale)
    fetch_env = TimeLimit(fetch_env, max_episode_steps=time_limit)
    return Monitor(fetch_env, monitor_file)


# Separate Monitor files per env: sharing `log_dir` made both write
# `log_dir/monitor.csv`, and opening the eval Monitor truncated the training log.
env = DummyVecEnv([lambda: make_fetch_env(os.path.join(log_dir, "train"))])
env_eval = DummyVecEnv([lambda: make_fetch_env(os.path.join(log_dir, "eval"))])

env.render_mode = 'rgb_array'
# wrap environment
# init model

# Build the SAC agent; 'SAC-HER' additionally swaps in HER's goal-relabelling
# replay buffer.
# NOTE(review): `wandb_log` is not a stock SB3 SAC argument — this relies on the
# locally installed SB3 fork.
# NOTE(review): stock SB3's HerReplayBuffer requires a dict observation space
# (used with "MultiInputPolicy"); confirm MlpPolicy is valid for this env/fork.
sac_kwargs = dict(
    verbose=1,
    device='cuda',
    wandb_log=True,
    # Without a tensorboard_log directory SB3 writes no tensorboard events, so
    # `tb_log_name` in model.learn() and wandb's sync_tensorboard had nothing to use.
    tensorboard_log=log_dir,
)

if model_name == 'SAC-HER':
    model = SAC(MlpPolicy, env,
                replay_buffer_class=HerReplayBuffer,
                # Parameters for HER
                replay_buffer_kwargs=dict(
                    n_sampled_goal=4,
                    goal_selection_strategy=goal_selection_strategy,
                ),
                **sac_kwargs)
elif model_name == 'SAC':
    model = SAC(MlpPolicy, env, **sac_kwargs)
else:
    raise ValueError(f"model_name: {model_name} is not supported")


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


/research/jesnk_packages/gym_robotics/envs/fetch/reach.py


[34m[1mwandb[0m: Currently logged in as: [33mjesnk[0m. Use [1m`wandb login --relogin`[0m to force relogin


  logger.warn(


Using cuda device


In [2]:
# Inspect the action space of the vectorised training env (recorded output: a 4-D Box bounded to [-0.1, 0.1]).
env.action_space

Box(-0.1, 0.1, (4,), float32)

In [None]:
# Roll out the current policy for a fixed number of steps, overlay per-step
# stats on each rendered frame, and save the frames as a GIF.
# NOTE(review): `RGB2VIDEO` and `cv2` are not imported anywhere in this notebook;
# they currently come from hidden kernel state — TODO import them explicitly.
rgb_to_video = RGB2VIDEO()

# Roll out on the already-vectorised eval env. Re-wrapping `env` in another
# DummyVecEnv nested one VecEnv inside another, overwrote the training-env name,
# and nested one level deeper on every re-run of this cell.
rollout_env = env_eval
rollout_env.render_mode = 'rgb_array'

# Output directory for the GIF (`rollout_path` was previously never defined).
rollout_path = f"./rollout/{name}-{time}/"
os.makedirs(rollout_path, exist_ok=True)

replay_step = 300
# Shared cv2.putText arguments: font face, scale, colour, thickness, line type.
text_style = (cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1, cv2.LINE_AA)

episode_step = 0
episode_num = 0
cumulative_reward = 0
frames = []
success = []
obs = rollout_env.reset()

for _step in range(1, replay_step + 1):
    action, _ = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = rollout_env.step(action)
    cumulative_reward += rewards[0]

    # .copy() hands cv2.putText a contiguous, writable buffer; the former
    # RGB->BGR->RGB round trip was otherwise a no-op.
    frame = rollout_env.render().copy()

    # Overlay per-step info (rewards, episode_step, episode_num) on the frame.
    overlay = [
        (f'rewards: {rewards[0]:.2f}', 35),
        (f'cumulative_reward: {cumulative_reward:.2f}', 15),
        (f'episode_step: {episode_step}', 55),
        (f'episode_num: {episode_num}', 75),
    ]
    for text, y in overlay:
        frame = cv2.putText(frame, text, (10, y), *text_style)
    frames.append(frame)

    episode_step += 1
    if dones[0]:
        # DummyVecEnv already reset this sub-env and `obs` holds the first
        # observation of the next episode, so no extra reset() here.
        success.append(info[0]['is_success'])
        episode_step = 0
        cumulative_reward = 0
        episode_num += 1

episodes_done = len(success)
print(f'{episodes_done} episodes finished in {replay_step} steps')
# NaN (not a ZeroDivisionError) when no episode terminated within replay_step.
success_rate = sum(success) / episodes_done if episodes_done else float('nan')
print(f'success rate: {success_rate}')

rgb_to_video.set_frames(frames)
rgb_to_video.set_fps(5)
rgb_to_video.save(path=f'{rollout_path}epi{len(success)}_sucrat{success_rate:.3f}.gif', mode='gif')

frames = []
rgb_to_video.container.clear()


In [None]:

# Train the agent with periodic evaluation on `env_eval`, then checkpoint it.
# NOTE(review): stock SB3 removed the `eval_env` / `eval_freq` / `n_eval_episodes`
# / `eval_log_path` arguments of learn() in v1.7; this relies on the local fork.
# NOTE(review): evaluating 20 episodes every 100 env steps makes evaluation
# dominate wall-clock time — consider a larger eval_freq.
model.learn(total_timesteps=max_steps,
            log_interval=10,
            # Derive log names from the run label so FetchPush / SAC-HER runs are
            # not filed under a hard-coded "sac_fetch_reach" name.
            tb_log_name=name,
            reset_num_timesteps=False,
            eval_freq=100,
            n_eval_episodes=20,
            # Per-run eval directory (matches the checkpoint naming below) so
            # successive runs do not overwrite each other's eval results.
            eval_log_path=f"./eval/{name}-{time}",
            eval_env=env_eval,
            )

model.save(f"./checkpoint/{name}-{time}")
