In [17]:
# Imports for evaluating a trained stable-baselines3 (SAC) agent on FetchReach and recording a rollout GIF
import os

import cv2
import matplotlib.pyplot as plt
import stable_baselines3
import wandb
from gym.wrappers import TimeLimit
# pip install gym-robotics
from gym_robotics.envs.fetch.reach import MujocoPyFetchReachEnv
from stable_baselines3 import SAC, PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.sac import MlpPolicy
from wandb.integration.sb3 import WandbCallback

from jesnk_utils.rgb_to_video import RGB2VIDEO


In [40]:
# Checkpoint to evaluate, and the directory where rollout GIFs are written.
load_checkpoint_path = "./checkpoint/PPO-FetchReach-dense-20230621_004248.zip"
# Keep the trailing slash: this prefix is concatenated directly with the GIF
# file name when the rollout is saved.
rollout_path = "./checkpoint_rollout/"

# NOTE(review): the checkpoint file name says "PPO" but it is loaded with
# SAC.load. The recorded rollout output shows this load succeeded, so the file
# may actually contain a SAC model under a misleading name — confirm which
# algorithm produced it before relying on the name.
model = SAC.load(load_checkpoint_path)

In [41]:
# Derive the run name from the checkpoint file name: drop the directory and
# only the final extension. os.path handles the platform's path separator, and
# splitext keeps names with extra dots intact (split(".")[0] would truncate
# "model.v2.zip" to "model").
model_name = os.path.splitext(os.path.basename(load_checkpoint_path))[0]
print(model_name)

PPO-FetchReach-dense-20230621_004248


In [43]:
# Roll out the loaded policy deterministically for `replay_step` steps, overlay
# per-step stats on each rendered frame, then save the frames as a GIF whose
# name includes the model name and the success rate over finished episodes.
rgb_to_video = RGB2VIDEO()

base_env = MujocoPyFetchReachEnv(reward_type='dense',)
base_env.render_mode = 'rgb_array'
# NOTE(review): the env is not wrapped in TimeLimit here; if it never reports
# done within replay_step, no episode is counted and success_rate is NaN.
# Give the raw env its own name so the factory lambda does not rely on late
# binding of `env`, which is rebound to the VecEnv on this same line.
env = DummyVecEnv([lambda: base_env])


def _annotate_frame(frame, lines):
    """Draw each (text, y) pair at x=10 on `frame` and return the frame.

    Colour (255, 0, 0) is red when the frame is RGB, as rendered with
    render_mode='rgb_array'.
    """
    for text, y in lines:
        frame = cv2.putText(frame, text, (10, y), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (255, 0, 0), 1, cv2.LINE_AA)
    return frame


episode_step = 0
episode_num = 0
replay_step = 300
cumulative_reward = 0
frames = []
success = []
obs = env.reset()

for i in range(1, replay_step + 1):
    action, _ = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    cumulative_reward += rewards[0]

    # cv2 drawing needs a writable, C-contiguous array and the rendered image
    # may be a view into a renderer buffer, so take an explicit copy. (The old
    # RGB->BGR->RGB cvtColor round-trip only had this copying effect.)
    frame = env.render().copy()
    frame = _annotate_frame(frame, [
        (f'rewards: {rewards[0]:.2f}', 35),
        (f'cumulative_reward: {cumulative_reward:.2f}', 15),
        (f'episode_step: {episode_step}', 55),
        (f'episode_num: {episode_num}', 75),
    ])
    frames.append(frame)

    episode_step += 1
    if dones[0]:
        # DummyVecEnv already reset this sub-env inside step(): `obs` is the
        # first observation of the next episode, so no manual env.reset().
        # `info` still describes the episode that just ended.
        success.append(info[0]['is_success'])
        episode_step = 0
        cumulative_reward = 0
        episode_num += 1

print(f'{replay_step} steps done, {episode_num} episodes finished')
# Avoid ZeroDivisionError when no episode finished within replay_step.
success_rate = sum(success) / len(success) if success else float('nan')
print(f'success rate: {success_rate}')

os.makedirs(rollout_path, exist_ok=True)
rgb_to_video.set_frames(frames)
rgb_to_video.set_fps(5)
rgb_to_video.save(path=f'{rollout_path}{model_name}_{success_rate:.3f}.gif',mode='gif')

# Release the frame buffers so re-running the cell starts clean.
frames = []
rgb_to_video.container.clear()


episode 300 done
success rate: 1.0
