In [1]:
# Show the kernel's working directory; the %%bash cell below uses a path relative to it (../mujoco_env).
!pwd

In [2]:
import os
# Root of the darm_mujoco checkout; later cells build output paths from it.
# setdefault keeps a value already exported in the shell (e.g. on a remote
# training machine) and only falls back to this machine-specific path when unset.
os.environ.setdefault("DARM_MUJOCO_PATH", "/home/daniel/DARM/darm_mujoco")
os.getenv('DARM_MUJOCO_PATH')

'/home/daniel/DARM/darm_mujoco'

In [3]:
%%bash
# Runs the project's XML generation script before the env is created in later cells
# (presumably regenerating the DARM MuJoCo model XML, per the script name).
# NOTE(review): the meaning of the two positional flags ("false true") is defined
# inside generate_darm_xml.sh — confirm against that script.
cd ../mujoco_env
bash generate_darm_xml.sh false true

In [4]:
import gym
from darm_gym_env import DARMEnv
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env

import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.callbacks import CallbackList, EvalCallback, StopTrainingOnRewardThreshold, StopTrainingOnNoModelImprovement


from datetime import datetime

In [5]:
run_name = "test1_MF_SB3_SAC"

# Run settings. The whole dict is passed to wandb.init(config=...) and the
# training / evaluation cells read their knobs from here.
config = dict(
    # Environment
    env_id="darm/DarmHand-v0",
    single_finger_env=False,  # the notes/tags describe this run as five-fingers
    algo="SAC",
    rl_lib="SB3",

    # Training
    seed=0,
    mean_reward_thresh=1_300,  # only read by the (currently disabled) reward-threshold stop callback
    total_timesteps=10_000_000,
    pi_net_arch=[128, 256, 256, 128],  # actor (pi) MLP hidden layers
    qf_net_arch=[128, 256, 256, 128],  # critic (qf) MLP hidden layers
    learning_starts=40_000,
    num_cpu=6,  # number of vectorized envs; also used as SAC gradient_steps

    # Evaluation / early stopping
    # NOTE(review): SB3 counts eval_freq in callback calls (vec-env steps); with
    # num_cpu envs each call covers num_cpu env timesteps — confirm this is intended.
    eval_freq=2_000,
    max_no_improvement_evals=10,  # only read by the disabled no-improvement stop callback
    no_improvement_min_evals=20,  # only read by the disabled no-improvement stop callback

    # Logging / checkpoints
    log_interval=20,  # episodes
    wandb_model_save_freq=2_000,  # presumably also counted in callback calls — confirm
    # NOTE(review): if DARM_MUJOCO_PATH is unset this resolves to "None/darm_training/...".
    run_local_dir=f"{os.getenv('DARM_MUJOCO_PATH')}/darm_training/results/darm_mf_hand/{run_name}",
)

In [6]:
# Free-text change log for this run; passed to wandb.init(notes=...).
notes = """
- The environment was updated such that the target is within a range from the start point
- Velocity penalty was removed and only effort penalty was used
- The reward function was updated according to the reach task reward used in facebookresearch/myosuite [https://github.com/facebookresearch/myosuite/blob/main/myosuite/envs/myo/reach_v0.py]
- The done signal is triggered only when the fingertip goes beyond a threshold. The episode continues to the maximum timestep otherwise.
- The friction and damping coefficients of the environment were updated. Values are inspired by Deepmind's Mujoco Menagerie [https://github.com/deepmind/mujoco_menagerie/blob/main/shadow_hand/right_hand.xml]
- The range of action from the model was changed to [-1, 1]. This action is mapped to the actual action sent to mujoco e.g. [0, 2]. This change is inspired by values used in OpenAI's Gym Mujoco environments.
- max_episode_steps was updated to 200.
- Velocity vector (size [3,]) was added to observation. Observation size is now (9,)
- Action range was increased to [0, 5]
- Observation wrapper to scale observation from m and m/s to cm and cm/s was applied
- Max Tension for Digitorum Extensor Communis was increased to 10
- FIXED: Velocity Observation from (prev_pos - new_pos)/time to (new_pos - prev_pos)/time
- FIXED: Removed weight of 1 from 'sparse', 'solved', and 'done' in reward weighting
- Reduced max_target_th to 5*0.004 (20 mm)

- Five-Fingers; No Wrist Environment
- This run was trained on vast_ai using SB3's SAC algo.
"""

tags = ["five_fingers", "sac", "sb3", "vast_ai"]

# Open the W&B run: the config dict and the change-log notes are attached to it,
# sync_tensorboard mirrors the TensorBoard metrics SB3 writes, and save_code asks
# W&B to keep a copy of the code for this run. (monitor_gym=True, which would
# upload agent videos, is intentionally left off.)
run = wandb.init(
    project="DARM",
    name=run_name,
    config=config,
    notes=notes,
    tags=tags,
    save_code=True,
    sync_tensorboard=True,
)

In [7]:
from gym.wrappers import TransformObservation
# from gym.wrappers import RescaleAction

def create_env(obs_scale: float = 100):
    """Build one DARM env whose observations are multiplied by ``obs_scale``.

    The default of 100 converts observations from m and m/s to cm and cm/s
    (as described in the run notes). The env id and finger mode come from
    ``config``. No argument is required, so the function can be handed
    directly to ``make_vec_env`` as an env factory.

    NOTE(review): only observation values are scaled; the wrapped env's
    ``observation_space`` bounds are not rescaled — confirm nothing relies on them.
    """
    base_env = gym.make(config["env_id"], single_finger_env=config["single_finger_env"])
    return TransformObservation(base_env, lambda obs: obs * obs_scale)

In [8]:
NUM_CPU = config["num_cpu"]

# One training env per configured "cpu", seeded from config.
# NOTE(review): make_vec_env defaults to DummyVecEnv, which steps these envs
# sequentially in this process; pass vec_env_cls=SubprocVecEnv if real
# parallelism was intended.
env = make_vec_env(create_env, n_envs=NUM_CPU, seed=config["seed"])
# env = VecNormalize(env)   #FIXME: Remember to save norm params if using VecNorm env
# env = VecMonitor(env)

policy_kwargs = dict(net_arch=dict(pi=config["pi_net_arch"], qf=config["qf_net_arch"]))

# gradient_steps=NUM_CPU: each vec-env step collects NUM_CPU transitions, so this
# keeps roughly one gradient update per collected transition (with SAC's default
# train_freq of one step).
# seed is also given to SAC so network init and action sampling are reproducible,
# not only the env resets.
model = SAC("MlpPolicy", env, verbose=1,
            learning_starts=config["learning_starts"],
            gradient_steps=NUM_CPU,
            seed=config["seed"],
            policy_kwargs=policy_kwargs,
            tensorboard_log=config['run_local_dir'])

In [9]:
# Separate single-env copy used only for evaluation.
# NOTE(review): it shares the training seed, so its episodes may mirror training env 0.
eval_env = make_vec_env(create_env, n_envs=1, seed=config["seed"])

# W&B callback: periodically saves the model under <run_local_dir>/models.
wandb_callback = WandbCallback(
    model_save_path=f"{config['run_local_dir']}/models",
    model_save_freq=config["wandb_model_save_freq"],
    verbose=2,
)

# Deterministic evaluation every `eval_freq` callback calls; the best model goes to
# models/best and the evaluation log to models/best/logs.
# Early stopping is disabled for this run. To enable it, pass to EvalCallback:
#   callback_on_new_best=StopTrainingOnRewardThreshold(
#       reward_threshold=config["mean_reward_thresh"], verbose=1)
#   callback_after_eval=StopTrainingOnNoModelImprovement(
#       max_no_improvement_evals=config["max_no_improvement_evals"],
#       min_evals=config["no_improvement_min_evals"], verbose=1)
eval_callback = EvalCallback(
    eval_env,
    eval_freq=config["eval_freq"],
    best_model_save_path=f"{config['run_local_dir']}/models/best",
    log_path=f"{config['run_local_dir']}/models/best/logs",
    deterministic=True,
    render=False,
    verbose=1,
)

# Both callbacks are handed to model.learn through a single CallbackList.
callback = CallbackList([wandb_callback, eval_callback])
callback

<stable_baselines3.common.callbacks.CallbackList at 0x7f218449c7f0>

In [10]:
import traceback

# Train; whatever happens (error or manual interrupt), the `finally` clause saves the
# latest weights so progress since the last periodic save is not lost.
try:
    model.learn(total_timesteps=config["total_timesteps"],
                log_interval=config["log_interval"],
                tb_log_name=run_name,
                callback=callback)
except Exception:
    # Deliberately best-effort: don't abort the notebook, so the checkpoint below is
    # written and the W&B run can still be finished. Print the full traceback, though —
    # print(e) alone drops the exception type and location, and prints an empty line
    # for exceptions without a message.
    print("Exception caught:")
    traceback.print_exc()
finally:
    print("Saving last checkpoint")
    model_name = f"{config['run_local_dir']}/models/last_model"
    model.save(model_name)
    print(f"Last checkpoint saved in: {model_name}")

In [11]:
# Finish the run if it's final
# Close the W&B run opened earlier in the notebook, then confirm which run it was.
run.finish()
print("Finished run", run_name)