In [1]:
import traceback
from datetime import datetime

import gym
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize

# NOTE(review): DARMSFEnv is not referenced directly; presumably imported so the
# "darm/DarmSFHand-v0" env id gets registered with gym — confirm before removing.
from darm_gym_env import DARMSFEnv

In [None]:
NUM_CPU = 4

env = make_vec_env("darm/DarmSFHand-v0", n_envs=NUM_CPU, seed=0)
# env = VecNormalize(env)   #FIXME: Remember to save norm params if using VecNorm env
env = VecMonitor(env)

policy_kwargs = dict(net_arch=dict(pi=[64, 256, 256, 64], qf=[64, 256, 256, 64]))
model = SAC("MlpPolicy", env, verbose=1,
            learning_starts=40_000,
            gradient_steps=NUM_CPU, # = number of envs (~1 update per collected transition, assuming the default train_freq)
            policy_kwargs=policy_kwargs,
            tensorboard_log="./results/darm_sf_hand")

try:
    model.learn(total_timesteps=10_000_000, log_interval=8, tb_log_name="SF_RelPos_NoNorm_GS4_[64, 256, 256, 64]")
    # TODO: add callbacks (e.g. periodic checkpointing)
except Exception:
    # Report the error with its full traceback but do not re-raise, so the `finally`
    # below always checkpoints. KeyboardInterrupt is not an Exception subclass: it is
    # not caught here and propagates after the checkpoint has been saved.
    print("Exception caught:")
    traceback.print_exc()
finally:
    # One clock reading, so the date and time parts always belong to the same instant
    # (two separate datetime.now() calls could straddle midnight).
    now = datetime.now()
    timestamp = f"{now.date()}__{now.time()}"
    print(f"Saving checkpoint {timestamp}")
    model_name = f"./checkpoints/darm_sf_hand_{timestamp}"
    env_norm_name = f"./checkpoints/darm_sf_hand_env_norm_{timestamp}"
    model.save(model_name)
    # env.save(env_norm_name) # FIXME: Remember to save norm params if using VecNorm env

Loaded XML file successfully


  logger.warn(


Loaded XML file successfully
Loaded XML file successfully
Loaded XML file successfully
Using cpu device
Logging to ./results/darm_sf_hand/SF_RelPos_NoNorm_GS4_[64, 256, 256, 64]_1




-----------------------------------
| rollout/           |            |
|    ep_len_mean     | 91.1       |
|    ep_rew_mean     | -127.10165 |
| time/              |            |
|    episodes        | 8          |
|    fps             | 495        |
|    time_elapsed    | 1          |
|    total_timesteps | 800        |
-----------------------------------
------------------------------------
| rollout/           |             |
|    ep_len_mean     | 89.6        |
|    ep_rew_mean     | -124.125084 |
| time/              |             |
|    episodes        | 16          |
|    fps             | 498         |
|    time_elapsed    | 3           |
|    total_timesteps | 1600        |
------------------------------------
-----------------------------------
| rollout/           |            |
|    ep_len_mean     | 93         |
|    ep_rew_mean     | -141.32576 |
| time/              |            |
|    episodes        | 24         |
|    fps             | 503        |
|    time_elapsed 

In [None]:
# MORE TRAINING
# Continues training the in-memory `model` created by the training cell above.
# NOTE(review): to resume from a saved checkpoint instead, first run
# `model = SAC.load(path, env=env)`.

try:
    # reset_num_timesteps=False keeps the model's timestep counter, so training resumes
    # past `learning_starts` (rather than repeating the random-action warm-up) and the
    # TensorBoard step axis continues where the previous run stopped.
    # NOTE(review): this logs under a different tb_log_name than the first training cell.
    model.learn(total_timesteps=10_000_000, log_interval=8, tb_log_name="PlainDarmEnv",
                reset_num_timesteps=False)
    # TODO: add callbacks (e.g. periodic checkpointing)
except Exception:
    # Report the error with its full traceback but do not re-raise, so the `finally`
    # below always checkpoints.
    print("Exception caught:")
    traceback.print_exc()
finally:
    # One clock reading, so the date and time parts always belong to the same instant.
    now = datetime.now()
    timestamp = f"{now.date()}__{now.time()}"
    print(f"Saving checkpoint {timestamp}")
    model_name = f"./checkpoints/darm_sf_hand_{timestamp}"
    env_norm_name = f"./checkpoints/darm_sf_hand_env_norm_{timestamp}"
    model.save(model_name)
    # env.save(env_norm_name) # FIXME: Remember to save norm params if using VecNorm env

In [None]:
env.close()  # release the training vec env before the evaluation cells below build a new one

### Done training — evaluate a saved checkpoint

In [2]:
# Checkpoint to evaluate. The training cells name the model and its VecNormalize stats
# with the same timestamp, so derive both paths from one value to keep them in sync.
CHECKPOINT_TIMESTAMP = "2022-12-28__10:10:05.637581"
model_name = f"./checkpoints/darm_sf_hand_{CHECKPOINT_TIMESTAMP}"
env_norm_name = f"./checkpoints/darm_sf_hand_env_norm_{CHECKPOINT_TIMESTAMP}"

In [7]:
def _make_render_env():
    """Build one human-rendered DARM SF hand environment for visual evaluation."""
    return gym.make("darm/DarmSFHand-v0", render_mode="human", hand_name="hand1")

# Single-env vec env wrapped with the VecNormalize statistics stored at `env_norm_name`.
env = VecNormalize.load(env_norm_name, DummyVecEnv([_make_render_env]))
# Evaluation mode: do not keep updating the running normalization statistics.
env.training = False
# NOTE(review): -0.47959065 is an unexplained constant — presumably the normalized reward
# corresponding to "zero norm" (per the print label); confirm and give it a name.
print("Zero Norm: ", env.unnormalize_reward(-0.47959065))

Loaded XML file successfully
Zero Norm:  -16.065026528454638


In [8]:
# Restore the trained SAC agent from the checkpoint. No env is attached: the cells below
# only call model.predict().
model = SAC.load(path=model_name)

In [9]:
import numpy as np
def norm_to_target(obs):
    """
    Return the Euclidean (L2) distance of each fingertip to its target position.

    obs: an observation laid out as [...fingertip_pos, ...target_pos], i.e. a flat
        sequence of 3-D points whose first half are fingertip positions and whose
        second half are the matching target positions. Any array-like is accepted.
        The input is flattened, so a batch holding a single observation (shape (1, D))
        also works; a batch of several observations would be mixed together, so pass
        one observation at a time.

    Returns: 1-D array with one distance per fingertip.

    Raises: ValueError if obs cannot be reshaped into 3-D points, or if the number of
        points is odd (fingertips and targets must pair up one-to-one).
    """
    points = np.asarray(obs).reshape((-1, 3))
    n_points = len(points)
    if n_points % 2 != 0:
        # Without this check, e.g. 1 fingertip vs 2 "targets" silently broadcasts and
        # returns distances that do not correspond to real fingertip/target pairs.
        raise ValueError(
            f"Expected an even number of 3-D points (fingertips + targets), got {n_points}"
        )
    n_fingertips = n_points // 2

    fingertip_poses = points[:n_fingertips]
    target_poses = points[n_fingertips:]

    return np.linalg.norm(fingertip_poses - target_poses, ord=2, axis=-1)

In [10]:
N_EPISODES = 10
# Heuristic success check. In the recorded runs, goal-reaching episodes ended early with
# strongly positive returns, while episodes that ran to 100 steps scored about -100 or below.
# NOTE(review): prefer an explicit success flag from `info` if the env provides one.
GOAL_RETURN_THRESHOLD = -70
# Set to True to print per-step observations, actions, rewards and fingertip distances.
VERBOSE = False

for episode in range(N_EPISODES):
  obs = env.reset()
  done = False
  episode_steps = 0
  episode_return = 0.0       # sum of unnormalized rewards
  episode_return_norm = 0.0  # sum of rewards as returned by the VecNormalize wrapper

  while not done:
    if VERBOSE:
      print("Observation: ", env.unnormalize_obs(obs))
      old_norm = norm_to_target(env.unnormalize_obs(obs))

    action, _states = model.predict(obs, deterministic=True)

    obs, reward, dones, info = env.step(action)
    done = bool(dones[0])  # single-env VecEnv: unwrap the 1-element arrays
    episode_steps += 1

    # Recover the environment's actual reward from the normalized one.
    unnormalized_reward = env.unnormalize_reward(reward)
    episode_return += float(unnormalized_reward[0])
    episode_return_norm += float(reward[0])

    if VERBOSE:
      # On the final step the VecEnv has already auto-reset, so `obs` belongs to the next
      # episode; the real last observation of this episode is in info["terminal_observation"].
      last_obs = info[0].get("terminal_observation", obs) if done else obs
      new_norm = norm_to_target(env.unnormalize_obs(last_obs))
      print("Action: ", action)
      print(f"Reward: {unnormalized_reward}; Normalized: {reward}")
      print(f"Next Observation: {env.unnormalize_obs(last_obs)}")
      print(f"Change in Norm: {new_norm - old_norm}")
      print("-----------------------------------------------------")

    env.render()

  print(f"Num Steps: {episode_steps}")
  print(f"Episode Return: {episode_return}")
  print(f"Episode Return Norm: {episode_return_norm}")
  if episode_return > GOAL_RETURN_THRESHOLD:
    print("Goal Reached!")
  print("\n")

env.close()

Num Steps: 100
Episode Return: [-206.27367]
Episode Return Norm: [-6.1579084]


Num Steps: 7
Episode Return: [238.47386]
Episode Return Norm: [7.1191816]
Goal Reached!


Num Steps: 100
Episode Return: [-112.209465]
Episode Return Norm: [-3.3497994]


Num Steps: 29
Episode Return: [194.17151]
Episode Return Norm: [5.7966194]
Goal Reached!


Num Steps: 100
Episode Return: [-154.97745]
Episode Return Norm: [-4.6265574]


Num Steps: 4
Episode Return: [243.22809]
Episode Return Norm: [7.2611094]
Goal Reached!


Num Steps: 100
Episode Return: [-204.50894]
Episode Return Norm: [-6.1052227]


Num Steps: 100
Episode Return: [-108.443665]
Episode Return Norm: [-3.2373753]


Num Steps: 100
Episode Return: [-206.46591]
Episode Return Norm: [-6.163656]


Num Steps: 100
Episode Return: [-206.59006]
Episode Return Norm: [-6.167353]


