In [1]:
import sys
import os
import math
import json
import uuid
import datetime
import subprocess
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from models.td3 import TD3
# from osim.env import ProstheticsEnv
from environment.prosthetics_env_with_history import ProstheticsEnvWithHistory
from environment.observations import prepare_model_observation, env_obs_history_to_model_obs
from environment.actions import prepare_env_action, reset_frameskip
from environment.rewards import env_obs_to_custom_reward
from distributed.database import persist_timesteps, persist_event, get_total_timesteps, clear_clients_for_thread
from distributed.db_history_sampler import DatabaseHistorySampler
from distributed.s3_checkpoints import load_s3_model_checkpoint, save_s3_model_checkpoint
import torch
import torch.utils.data



In [2]:
# Read the run configuration from disk and echo it, so the exact settings used
# for this run are captured in the notebook output.
CONFIG = json.loads(Path('config_distributed.json').read_text())
print(json.dumps(CONFIG, indent=4))


{
    "env": {
        "integrator_accuracy": 0.002
    },
    "model": {
        "architecture": "TD3"
    },
    "rollout": {
        "#": "Frameskip will be applied for random durations between 0 and `frameskip` timesteps.",
        "max_episode_steps": 600,
        "expl_noise": 0.25,
        "frameskip": 5
    },
    "distributed": {
        "policy_weights_dir_s3": "s3://colllin-nips-2018-prosthetics/checkpoints/",
        "policy_weights_basename": "checkpoint_TD3",
        "#": "How often (episodes) we download model weights during rollout.",
        "rollout_refresh_model_freq": 5
    },
    "training": {
        "#": "Frequency of delayed policy updates",
        "eval_freq": 2000,
        "batch_size": 100,
        "discount": 0.99,
        "tau": 0.005,
        "policy_noise": 0.2,
        "noise_clip": 0.5,
        "policy_freq": 2
    }
}


### Create Policy, Download & load latest weights

In [3]:
# Build a throwaway environment only to measure the model's input/output sizes,
# then discard it.
env = ProstheticsEnvWithHistory(
    visualize=False,
    integrator_accuracy=CONFIG['env']['integrator_accuracy'],
)
# Reset before asking for a model observation (presumably needed so the
# observation history is populated -- confirm in ProstheticsEnvWithHistory).
env.reset()
state_dim, action_dim, max_action = (
    prepare_model_observation(env).shape[0],
    env.action_space.shape[0],
    # NOTE(review): int() truncates; harmless while the upper bound is a whole
    # number (the recorded output shows 1) -- revisit if the bound is fractional.
    int(env.action_space.high[0]),
)
del env
(state_dim, action_dim, max_action)


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


(1260, 19, 1)

In [4]:
# Instantiate the TD3 agent, sized from the observation/action dimensions and
# action bound measured from the environment in the previous cell.
policy = TD3(state_dim, action_dim, max_action)

In [5]:
# Location pattern of the shared "latest" checkpoint: <s3 dir><basename>*
distributed_cfg = CONFIG['distributed']
latest_checkpoint_glob = (
    f"{distributed_cfg['policy_weights_dir_s3']}{distributed_cfg['policy_weights_basename']}*"
)

print(f"Loading policy checkpoints from {latest_checkpoint_glob}")
load_s3_model_checkpoint(
    policy,
    s3_dir=distributed_cfg['policy_weights_dir_s3'],
    basename=distributed_cfg['policy_weights_basename'],
)
# Record the load in the event log once it has completed without raising.
persist_event(
    'train_load_latest_checkpoint',
    f'Loaded policy checkpoint from {latest_checkpoint_glob}',
)



Loading policy checkpoints from s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*


### Episode Hacking (Custom "done" criteria)


In [6]:
def should_abort_episode(env_obs, custom_rewards=None, verbose=False):
    """Decide whether an episode should be cut short (custom "done" criteria).

    Args:
        env_obs: Environment observation dict. Only
            ``env_obs['body_pos']['head']`` and ``env_obs['body_pos']['pelvis']``
            are read here (components 0 and 2 of the head, component 0 of the
            pelvis).
        custom_rewards: Optional mapping of reward name -> value. When ``None``
            the rewards are computed with ``env_obs_to_custom_reward(env_obs)``.
        verbose: When true, print the reason an episode is aborted.

    Returns:
        ``True`` when any abort criterion is met, otherwise ``False``:
          * head component 0 is more than 0.2 below the pelvis component 0
            (reported as "head behind the pelvis");
          * the absolute value of head component 2 exceeds 0.5
            (reported as "away from centerline");
          * the sum of all reward values is below -10.
    """
    # Identity comparison: `!= None` would defer to the argument's own
    # __ne__/__eq__, which can misreport for objects that override equality.
    rewards = custom_rewards if custom_rewards is not None else env_obs_to_custom_reward(env_obs)

    head_pos = env_obs['body_pos']['head']
    head_vs_pelvis = head_pos[0] - env_obs['body_pos']['pelvis'][0]
    if head_vs_pelvis < -.2:
        if verbose:
            print(f'Aborting episode due to head being > .2m behind the pelvis ({head_vs_pelvis}).')
        return True

    if np.fabs(head_pos[2]) > 0.5:
        if verbose:
            print(f'Aborting episode due to head being > 0.5m away from centerline ({head_pos[2]}).')
        return True

    total_reward = sum(rewards.values())
    if total_reward < -10:
        if verbose:
            print(f'Aborting episode due to custom reward < -10 ({total_reward}):')
            # List only the components that pulled the total down.
            for k, v in rewards.items():
                if v < 0:
                    print(f'  reward `{k}` = {v}')
        return True

    return False
    

### Init dataloader

In [7]:
def _total_custom_reward(obs):
    """Add up every component returned by env_obs_to_custom_reward for one observation."""
    return sum(env_obs_to_custom_reward(obs).values())


# Source of training batches. Rewards and "done" flags are derived from stored
# observations via the custom reward / abort functions passed in here.
history_sampler = DatabaseHistorySampler(
    env_obs_history_to_model_obs_fn=env_obs_history_to_model_obs,
    n_obs_history=3,
    env_obs_custom_reward_fn=_total_custom_reward,
    env_obs_custom_done_fn=should_abort_episode,
)

def load_batch(fake_batch):
    """Collate function that swaps a batch of placeholder items for a real sample.

    The contents of ``fake_batch`` are ignored; only its length is used, as the
    number of items to request from ``history_sampler``.
    """
    n_requested = len(fake_batch)
    return history_sampler.sample(n_requested)

# The DataLoader is used only as a multi-worker prefetcher. It iterates over a
# list of placeholder integers, and `load_batch` replaces each batch of
# placeholders with a real sample. The placeholder list is sized so that one
# full pass yields `eval_freq` batches.
training_cfg = CONFIG['training']
fake_dataset_len = training_cfg['eval_freq'] * training_cfg['batch_size']
fake_dataset = list(range(int(fake_dataset_len)))


def _reset_worker_db_clients(instance_id):
    """Call clear_clients_for_thread() when a loader worker starts (worker id unused).

    Presumably this makes each worker open its own database connection rather
    than reuse one inherited from the parent -- confirm in distributed.database.
    """
    clear_clients_for_thread()


dataloader = torch.utils.data.DataLoader(
    dataset=fake_dataset,
    batch_size=training_cfg['batch_size'],
    num_workers=6,
    collate_fn=load_batch,
    pin_memory=True,
    drop_last=True,
    worker_init_fn=_reset_worker_db_clients,
)




# Model Training

In [8]:
# Evaluation subprocesses that are still running. Polled once per epoch so that
# finished children are collected instead of their handles being dropped.
eval_processes = []

# Endless train -> publish -> evaluate cycle; stop it by interrupting the kernel.
while True:
    # Train for `eval_freq` batches:
    if CONFIG['model']['architecture'] == "TD3":
        policy.train(
            dataloader,
            CONFIG['training']['discount'],
            CONFIG['training']['tau'],
            CONFIG['training']['policy_noise'],
            CONFIG['training']['noise_clip'],
            CONFIG['training']['policy_freq'],
        )
    else:
        # Other architectures take the sampler directly plus an iteration count.
        policy.train(
            history_sampler,#replay_buffer,
            int(CONFIG['training']['eval_freq']),
            CONFIG['training']['batch_size'],
            CONFIG['training']['discount'],
            CONFIG['training']['tau']
        )
    persist_event('train_epoch_completed', f'Trained policy for {len(dataloader)} batches of {dataloader.batch_size}')

    # Upload policy weights to S3, to be picked up by instances running the Rollout Distributed process.
    print(f"SAving policy checkpoints to {CONFIG['distributed']['policy_weights_dir_s3']}{CONFIG['distributed']['policy_weights_basename']}*")
    save_s3_model_checkpoint(
        policy,
        s3_dir=CONFIG['distributed']['policy_weights_dir_s3'],
        basename=CONFIG['distributed']['policy_weights_basename'],
    )
    persist_event('train_update_s3_checkpoint', f'Uploaded policy checkpoint to {CONFIG["distributed"]["policy_weights_dir_s3"]}{CONFIG["distributed"]["policy_weights_basename"]}*')

    # Also upload policy weights under unique name as a historical checkpoint.
    total_timesteps = get_total_timesteps()
    evalname = f"{CONFIG['distributed']['policy_weights_basename']}_T{total_timesteps}_{datetime.datetime.now().isoformat()}"
    print(f"SAving policy checkpoints to {CONFIG['distributed']['policy_weights_dir_s3']}{evalname}*")
    save_s3_model_checkpoint(
        policy,
        s3_dir=CONFIG['distributed']['policy_weights_dir_s3'],
        basename=evalname,
    )
    persist_event('train_save_historical_checkpoint', f'Uploaded policy checkpoint to {CONFIG["distributed"]["policy_weights_dir_s3"]}{evalname}*')

    # Run Evaluation script
    # A fresh local directory per evaluation keeps concurrent evaluations apart.
    evaldir = str(uuid.uuid4())
    print(f"SAving policy checkpoints to {evaldir}/{evalname}*")
    os.makedirs(evaldir, exist_ok=True)
    policy.save(evaldir, evalname)
    # `evalcmd` is kept only as a human-readable log line; the process itself is
    # started from an argument list with the checkpoint location passed through
    # the environment, so no value is ever parsed by a shell.
    evalcmd = f"CHECKPOINT_DIR={evaldir} CHECKPOINT_NAME={evalname} pipenv run python evaluate_policy.py"
    print(f"Launching evaluation script with cmd: `{evalcmd}`")
    # Drop handles of evaluations that have already exited (poll() collects them).
    eval_processes = [proc for proc in eval_processes if proc.poll() is None]
    # Fire-and-forget: training continues while the evaluation runs.
    eval_processes.append(subprocess.Popen(
        ['pipenv', 'run', 'python', 'evaluate_policy.py'],
        env=dict(os.environ, CHECKPOINT_DIR=evaldir, CHECKPOINT_NAME=evalname),
    ))
    
    

Train model: 100%|██████████| 2000/2000 [04:41<00:00,  7.11batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T27221_2018-09-25T18:40:23.718494*
SAving policy checkpoints to 926d23b3-0e17-47c8-a33a-d8bdc4c51005/checkpoint_TD3_T27221_2018-09-25T18:40:23.718494*
Launching evaluation script with cmd: `CHECKPOINT_DIR=926d23b3-0e17-47c8-a33a-d8bdc4c51005 CHECKPOINT_NAME=checkpoint_TD3_T27221_2018-09-25T18:40:23.718494 pipenv run python evaluate_policy.py`


Train model: 100%|██████████| 2000/2000 [05:42<00:00,  5.83batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T27221_2018-09-25T18:46:08.955179*
SAving policy checkpoints to 52fd7daa-c297-418b-b9e6-bf828b6c8462/checkpoint_TD3_T27221_2018-09-25T18:46:08.955179*
Launching evaluation script with cmd: `CHECKPOINT_DIR=52fd7daa-c297-418b-b9e6-bf828b6c8462 CHECKPOINT_NAME=checkpoint_TD3_T27221_2018-09-25T18:46:08.955179 pipenv run python evaluate_policy.py`


Train model: 100%|██████████| 2000/2000 [06:40<00:00,  4.99batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T27221_2018-09-25T18:52:51.850118*
SAving policy checkpoints to a70b2283-8055-4387-a99d-1c146edaa8ae/checkpoint_TD3_T27221_2018-09-25T18:52:51.850118*
Launching evaluation script with cmd: `CHECKPOINT_DIR=a70b2283-8055-4387-a99d-1c146edaa8ae CHECKPOINT_NAME=checkpoint_TD3_T27221_2018-09-25T18:52:51.850118 pipenv run python evaluate_policy.py`


Train model: 100%|██████████| 2000/2000 [06:59<00:00,  7.89batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T28572_2018-09-25T18:59:53.464152*
SAving policy checkpoints to db33dd0b-1c31-48ea-9cb7-a03f84499f1d/checkpoint_TD3_T28572_2018-09-25T18:59:53.464152*
Launching evaluation script with cmd: `CHECKPOINT_DIR=db33dd0b-1c31-48ea-9cb7-a03f84499f1d CHECKPOINT_NAME=checkpoint_TD3_T28572_2018-09-25T18:59:53.464152 pipenv run python evaluate_policy.py`


Train model:  70%|██████▉   | 1390/2000 [05:02<02:26,  4.18batch/s]Process Process-26:
Process Process-30:
Process Process-25:
Process Process-28:
Process Process-29:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, 

  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/cursor.py", line 1126, in _refresh
    self.__send_message(g)
KeyboardInterrupt
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/pool.py", line 745, in _raise_connection_failure
    raise error
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/pool.py", line 612, in receive_message
    self._raise_connection_failure(error)
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/cursor.py", line 931, in __send_message
    operation, exhaust=self.__exhaust, address=self.__address)
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/mongo_client.py", line 1145, in _send_message_with_response
    exhaust

Traceback (most recent call last):
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-28c7e80da0cf>", line 10, in <module>
    CONFIG['training']['policy_freq'],
  File "/home/ubuntu/nips-2018-ai-for-prosthetics/models/td3.py", line 111, in train
    for ibatch, (x, y, u, r, d) in enumerate(tqdm(iter(replay_dataloader), desc='Train model', unit='batch')):
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/tqdm/_tqdm.py", line 937, in __iter__
    for obj in iterable:
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 330, in __next__
    idx, batch = self._get_batch()
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-pros

  File "/home/ubuntu/nips-2018-ai-for-prosthetics/distributed/db_history_sampler.py", line 15, in sample
    docs_with_history = sample_timesteps(n=batch_size, n_obs_history=self.aggregate_obs_history)


KeyboardInterrupt: 

  File "/home/ubuntu/nips-2018-ai-for-prosthetics/distributed/database.py", line 93, in sample_timesteps
    for pt in past_timesteps:
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/cursor.py", line 1189, in next
    if len(self.__data) or self._refresh():
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/cursor.py", line 1126, in _refresh
    self.__send_message(g)
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/cursor.py", line 978, in __send_message
    codec_options=self.__codec_options)
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-prosthetics-fCqIkKV7/lib/python3.6/site-packages/pymongo/cursor.py", line 1067, in _unpack_response
    return response.unpack_response(cursor_id, codec_options)
  File "/home/ubuntu/.local/share/virtualenvs/nips-2018-ai-for-