In [1]:
import sys
import os
import math
import json
import subprocess
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from models.td3 import TD3
# from osim.env import ProstheticsEnv
from environment.prosthetics_env_with_history import ProstheticsEnvWithHistory
from environment.observations import prepare_model_observation
from environment.actions import prepare_env_action, reset_frameskip
from environment.rewards import env_obs_to_custom_reward
from distributed.database import persist_timesteps, persist_event
from distributed.s3_checkpoints import load_s3_model_checkpoint


In [2]:
with open('config_distributed.json', 'r') as f:
    CONFIG = json.load(f)
print(json.dumps(CONFIG, indent=4))


{
    "env": {
        "integrator_accuracy": 0.002
    },
    "model": {
        "architecture": "TD3"
    },
    "rollout": {
        "#": "Frameskip will be applied for random durations between 0 and `frameskip` timesteps.",
        "max_episode_steps": 600,
        "expl_noise": 0.25,
        "frameskip": 5
    },
    "distributed": {
        "policy_weights_dir_s3": "s3://colllin-nips-2018-prosthetics/checkpoints/",
        "policy_weights_basename": "checkpoint_TD3",
        "#": "How often (episodes) we download model weights during rollout.",
        "rollout_refresh_model_freq": 5
    },
    "training": {
        "#": "Frequency of delayed policy updates",
        "eval_freq": 5000.0,
        "batch_size": 100,
        "discount": 0.99,
        "tau": 0.005,
        "policy_noise": 0.2,
        "noise_clip": 0.5,
        "policy_freq": 2
    }
}


### Create simulation env

In [3]:
env = ProstheticsEnvWithHistory(visualize=False, integrator_accuracy=CONFIG['env']['integrator_accuracy'])
env_step_kwargs = {'project': False}


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


### Create Policy, Download & load latest weights

In [4]:
# state_dim = env.observation_space.shape[0]
env.reset(**env_step_kwargs)
state_dim = prepare_model_observation(env).shape[0]
action_dim = env.action_space.shape[0]
max_action = int(env.action_space.high[0])
state_dim, action_dim, max_action

print(env.action_space.low, env.action_space.high)


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [5]:
policy = TD3(state_dim, action_dim, max_action)

In [6]:
print(f"Loading policy checkpoints from {CONFIG['distributed']['policy_weights_dir_s3']}{CONFIG['distributed']['policy_weights_basename']}*")
load_s3_model_checkpoint(
    policy, 
    s3_dir=CONFIG['distributed']['policy_weights_dir_s3'],
    basename=CONFIG['distributed']['policy_weights_basename'],
)


Loading policy checkpoints from s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*


### Episode Hacking (Custom "done" criteria)


In [7]:
def should_abort_episode(env_obs, custom_rewards=None, verbose=False):
#     print((np.array(env_obs['body_pos_rot']['torso'])*180/math.pi > 60).any())
#     if env_obs['body_pos_rot']['torso'][2] < -0.2:
#         return True
    rewards = custom_rewards if custom_rewards != None else env_obs_to_custom_reward(env_obs)
    # print(f'Custom reward: {sum(rewards.values())}')
    if (env_obs['body_pos']['head'][0] - env_obs['body_pos']['pelvis'][0]) < -.2:
        if verbose: print(f'Aborting episode due to head being > .2m behind the pelvis ({env_obs["body_pos"]["head"][0] - env_obs["body_pos"]["pelvis"][0]}).')
        return True
    if np.fabs(env_obs['body_pos']['head'][2]) > 0.5:
        if verbose: print(f'Aborting episode due to head being > 0.5m away from centerline ({env_obs["body_pos"]["head"][2]}).')
        return True
    if sum(rewards.values()) < -10:
        if verbose:
            print(f'Aborting episode due to custom reward < -10 ({sum(rewards.values())}):')
            for k,v in rewards.items():
                if v < 0:
                    print(f'  reward `{k}` = {v}')
        return True
    return False
    

# Policy rollout (Record & Persist Simulations)

In [8]:
total_timesteps = 0
episode_num = 0
done = True
episode_timesteps = 0
total_timesteps, episode_num, episode_timesteps


(0, 0, 0)

In [9]:
while True:
    if done: 
        if (episode_num % CONFIG['distributed']['rollout_refresh_model_freq']) == 0:
            print(f"\nLoading policy checkpoint from {CONFIG['distributed']['policy_weights_dir_s3']}{CONFIG['distributed']['policy_weights_basename']}\n")
            load_s3_model_checkpoint(
                policy, 
                s3_dir=CONFIG['distributed']['policy_weights_dir_s3'],
                basename=CONFIG['distributed']['policy_weights_basename'],
            )
            timesteps_since_model_update = 0
            persist_event('rollout_model_refreshed', {
                'episode_num': episode_num,
            })

        # Reset environment
        obs = env.reset(**env_step_kwargs)
        reset_frameskip(CONFIG['rollout']['frameskip'])
        done = False
        episode_reward = 0
        episode_timesteps = 0
        episode_num += 1 

    # # Select action randomly or according to policy
    # if total_timesteps < CONFIG['training']['start_timesteps']:
    #     action = env.action_space.sample()
    # else:
    action = policy.select_action(prepare_model_observation(env))
    if CONFIG['rollout']['expl_noise'] != 0: 
        action = (action + np.random.normal(0, CONFIG['rollout']['expl_noise'], size=env.action_space.shape[0])).clip(env.action_space.low[0], env.action_space.high[0])

    # Perform action
    action = prepare_env_action(action)
    obs, reward, done, _ = env.step(action, **env_step_kwargs)

    if not done:
        done = should_abort_episode(env.get_state_desc(), verbose=True)
    done_bool = 0 if episode_timesteps + 1 == CONFIG['rollout']['max_episode_steps'] else float(done)

    # custom_rewards = compute_rewards(new_obs_dict)
    episode_reward += reward #+ sum(custom_rewards.values())

    episode_timesteps += 1
    total_timesteps += 1

    if done:
        # Persist timesteps to central database
        persist_timesteps(env.history())
        env.reset_history()

        # Log episode
        persist_event('rollout_episode_completed', {
            'episode_num': episode_num,
            'episode_timesteps': episode_timesteps,
            'episode_reward': episode_reward,
        })
        print(f"Total T: {total_timesteps} Episode Num: {episode_num} Episode T: {episode_timesteps} Reward: {episode_reward}")
        sys.stdout.flush()




Loading policy checkpoint from s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3

[0.90156519 0.         0.58485466 0.         0.95628807 0.
 0.         0.77024758 0.         1.         0.         1.
 0.         0.         0.90779868 1.         1.         1.
 1.        ]
[1.         0.         1.         0.         0.87208195 0.
 0.         0.82868751 0.         0.63772332 0.         1.
 0.         0.         0.57978286 1.         1.         1.
 1.        ]
[1.         0.         0.55363365 0.         1.         0.
 0.         0.83663117 0.         1.         0.         1.
 0.         0.         0.91170211 1.         1.         1.
 0.40239909]
[0.98010331 0.         1.         0.         0.96340011 0.
 0.         0.90504331 0.         0.85506601 0.         0.62530168
 0.         0.         0.60945816 0.42312787 0.91651506 0.76630642
 0.99581278]
[0.92783191 0.         1.         0.         1.         0.
 0.         1.         0.         1.         0.         0.78961317
 0.

[0.70834113 0.         0.9618427  0.         1.         0.
 0.         1.         0.         1.         0.         1.
 0.         0.         1.         1.         1.         0.89155708
 0.68215728]
[0.98644332 0.         0.95926823 0.         1.         0.
 0.         0.55346757 0.         1.         0.         0.77338501
 0.         0.         0.94803394 1.         1.         0.64160171
 0.97561284]
[0.47699488 0.         1.         0.         1.         0.
 0.         0.65284687 0.         1.         0.         0.2840375
 0.         0.         1.         0.59633167 0.76131245 1.
 1.        ]
[0.98534554 0.         0.95106563 0.         1.         0.
 0.         0.97514647 0.         0.89468667 0.         0.76083781
 0.         0.         1.         0.6368228  0.65733061 1.
 1.        ]
[1.         0.         0.6761892  0.         1.         0.
 0.         1.         0.         1.         0.         1.
 0.         0.         1.         0.69400735 1.         1.
 0.87076561]
[0.90908567

[1.         0.         1.         0.         0.73880899 0.
 0.         1.         0.         0.79793124 0.         0.75622129
 0.         0.         0.82421884 0.83661591 0.90346172 1.
 0.9483231 ]
[1.         0.         0.97713984 0.         1.         0.
 0.         1.         0.         0.82436782 0.         1.
 0.         0.         1.         1.         0.67986996 1.
 1.        ]
[0.72215824 0.         1.         0.         0.63653482 0.
 0.         0.90279319 0.         1.         0.         0.95816892
 0.         0.         0.88065981 0.9702827  0.98857946 0.73761356
 1.        ]
[0.84861909 0.         0.52286349 0.         0.87314688 0.
 0.         1.         0.         0.88521436 0.         0.53936722
 0.         0.         0.84660142 1.         1.         1.
 0.49149226]
[1.         0.         0.84600741 0.         1.         0.
 0.         0.85353667 0.         1.         0.         0.83081461
 0.         0.         0.68245275 1.         1.         1.
 1.        ]
[1.       

[0.94984194 0.         1.         0.         0.97014779 0.
 0.         0.92160921 0.         0.88655061 0.         1.
 0.         0.         0.97906729 1.         0.92233608 0.9285055
 1.        ]
[1.         0.         1.         0.         0.99599689 0.
 0.         0.7721455  0.         1.         0.         1.
 0.         0.         0.88282386 0.99942154 0.86800747 1.
 0.53542261]
[1.         0.         1.         0.         0.79522747 0.
 0.         1.         0.         0.98972008 0.         1.
 0.         0.         0.98838629 0.95155523 1.         1.
 0.75965763]
[0.91668757 0.         0.80738447 0.         0.797171   0.
 0.         1.         0.         0.5179125  0.         0.63611003
 0.         0.         1.         0.61887026 0.85227111 1.
 0.9767528 ]
[0.61203961 0.         0.91457469 0.         1.         0.
 0.         1.         0.         1.         0.         1.
 0.         0.         0.91000149 1.         0.66601998 1.
 0.76348581]
[1.         0.         0.80919739 0

SystemError: <built-in function Manager_integrate> returned a result with an error set