In [1]:
import sys
import os
import math
import json
import uuid
import datetime
import subprocess
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from models.td3 import TD3
# from osim.env import ProstheticsEnv
from environment.prosthetics_env_with_history import ProstheticsEnvWithHistory
from environment.observations import prepare_model_observation, env_obs_history_to_model_obs
# from environment.actions import prepare_env_action, reset_frameskip
from environment.rewards import env_obs_to_custom_reward
from distributed.database import persist_timesteps, persist_event, get_total_timesteps, clear_clients_for_thread
from distributed.db_history_sampler import DatabaseHistorySampler
from distributed.s3_checkpoints import load_s3_model_checkpoint, save_s3_model_checkpoint
import torch
import torch.utils.data



In [2]:
# Load the distributed-training configuration from the working directory and
# echo it so every run's log records the settings it used.
CONFIG = json.loads(Path('config_distributed.json').read_text())
print(json.dumps(CONFIG, indent=4))


{
    "env": {
        "integrator_accuracy": 0.002
    },
    "model": {
        "architecture": "TD3"
    },
    "rollout": {
        "#": "Frameskip will be applied for random durations between 0 and `frameskip` timesteps.",
        "max_episode_steps": 600,
        "expl_noise": 0.25,
        "frameskip": 5
    },
    "distributed": {
        "policy_weights_dir_s3": "s3://colllin-nips-2018-prosthetics/checkpoints/",
        "policy_weights_basename": "checkpoint_TD3",
        "#": "How often (episodes) we download model weights during rollout.",
        "rollout_refresh_model_freq": 5
    },
    "training": {
        "#": "Frequency of delayed policy updates",
        "eval_freq": 2500,
        "batch_size": 100,
        "discount": 0.99,
        "tau": 0.005,
        "policy_noise": 0.2,
        "noise_clip": 0.5,
        "policy_freq": 2
    }
}


### Create Policy, Download & load latest weights

In [3]:
# Build a throwaway environment only to discover the model's input/output sizes;
# the training loop below samples transitions from the database instead.
probe_env = ProstheticsEnvWithHistory(
    visualize=False,
    integrator_accuracy=CONFIG['env']['integrator_accuracy'],
)
probe_env.reset()
# Observation size is taken from the model-facing observation, not env.observation_space.
state_dim = prepare_model_observation(probe_env).shape[0]
action_dim = probe_env.action_space.shape[0]
# NOTE(review): assumes every action dimension shares the same upper bound as dimension 0.
max_action = int(probe_env.action_space.high[0])
# NOTE(review): `del` only drops this reference; if the env exposes a gym-style
# `close()`, calling it would release simulator resources deterministically.
del probe_env
state_dim, action_dim, max_action


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


(1260, 19, 1)

In [4]:
# Instantiate the TD3 policy from the dimensions probed in the previous cell;
# the next cell loads the latest S3 checkpoint into this object.
policy = TD3(state_dim, action_dim, max_action)

In [5]:
# Warm-start the policy from the checkpoint stored on S3 under the configured
# basename, and record the load in the shared event log.
# NOTE(review): `map_location='cpu'` is presumably forwarded to torch.load so the
# weights deserialize without a GPU — confirm in distributed.s3_checkpoints.
_dist_cfg = CONFIG['distributed']
_checkpoint_glob = f"{_dist_cfg['policy_weights_dir_s3']}{_dist_cfg['policy_weights_basename']}*"
print(f"Loading policy checkpoints from {_checkpoint_glob}")
load_s3_model_checkpoint(
    policy,
    s3_dir=_dist_cfg['policy_weights_dir_s3'],
    basename=_dist_cfg['policy_weights_basename'],
    map_location='cpu',
)
persist_event('train_load_latest_checkpoint', f'Loaded policy checkpoint from {_checkpoint_glob}')



Loading policy checkpoints from s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*


### Episode hacking (custom "done" criteria) — currently disabled


In [6]:
# def should_abort_episode(env_obs, custom_rewards=None, verbose=False):
# #     print((np.array(env_obs['body_pos_rot']['torso'])*180/math.pi > 60).any())
# #     if env_obs['body_pos_rot']['torso'][2] < -0.2:
# #         return True
#     rewards = custom_rewards if custom_rewards != None else env_obs_to_custom_reward(env_obs)
#     # print(f'Custom reward: {sum(rewards.values())}')
#     if (env_obs['body_pos']['head'][0] - env_obs['body_pos']['pelvis'][0]) < -.2:
#         if verbose: print(f'Aborting episode due to head being > .2m behind the pelvis ({env_obs["body_pos"]["head"][0] - env_obs["body_pos"]["pelvis"][0]}).')
#         return True
#     if np.fabs(env_obs['body_pos']['head'][2]) > 0.5:
#         if verbose: print(f'Aborting episode due to head being > 0.5m away from centerline ({env_obs["body_pos"]["head"][2]}).')
#         return True
#     if sum(rewards.values()) < -10:
#         if verbose:
#             print(f'Aborting episode due to custom reward < -10 ({sum(rewards.values())}):')
#             for k,v in rewards.items():
#                 if v < 0:
#                     print(f'  reward `{k}` = {v}')
#         return True
#     return False
    

### Init dataloader

In [5]:
# Training-data source: transitions are read back from the shared timestep
# database (presumably populated by the rollout workers) and converted to model
# observations with an observation history of length 3.
# The optional `env_obs_custom_reward_fn` / `env_obs_custom_done_fn` hooks are
# deliberately not passed; the custom "done" cell above is commented out.
history_sampler = DatabaseHistorySampler(
    n_obs_history=3,
    env_obs_history_to_model_obs_fn=env_obs_history_to_model_obs,
)

def load_batch(fake_batch):
    """DataLoader ``collate_fn``: swap a chunk of placeholder indices for a real batch.

    The DataLoader's dataset is only a run of dummy indices that fixes how many
    batches make up one training pass. The index values themselves are ignored;
    only their count is used as the number of samples to draw from
    ``history_sampler``.

    Args:
        fake_batch: the list of placeholder indices the DataLoader grouped together.

    Returns:
        Whatever ``history_sampler.sample`` returns for that many samples.
    """
    n_samples = len(fake_batch)
    return history_sampler.sample(n_samples)

# The DataLoader is used purely as a parallel batch pump. Its "dataset" is a
# placeholder index range whose length makes one pass exactly `eval_freq`
# batches of `batch_size`. `load_batch` (the collate_fn) replaces each chunk of
# indices with a real batch drawn from the database.
fake_dataset_len = CONFIG['training']['eval_freq'] * CONFIG['training']['batch_size']
# A `range` supports len()/indexing like a list without materializing every index.
fake_dataset = range(int(fake_dataset_len))


def _init_dataloader_worker(worker_id):
    """Reset the database clients in a freshly started DataLoader worker process.

    Args:
        worker_id: index supplied by the DataLoader; unused here.
    """
    # NOTE(review): presumably DB clients inherited from the parent process must not
    # be reused inside worker processes — confirm against distributed.database.
    clear_clients_for_thread()


dataloader = torch.utils.data.DataLoader(
    fake_dataset,
    batch_size=CONFIG['training']['batch_size'],
    # Configurable worker count; 6 preserves the previous hard-coded value.
    num_workers=CONFIG['training'].get('num_workers', 6),
    collate_fn=load_batch,
    # Pinned host memory only helps when batches are copied to a CUDA device.
    pin_memory=torch.cuda.is_available(),
    # The index range divides evenly by batch_size; this keeps batch sizes uniform regardless.
    drop_last=True,
    worker_init_fn=_init_dataloader_worker,
)




# Model Training

In [None]:
# Training loop: train -> publish -> archive, repeated until the kernel is interrupted.
# Each iteration trains on one pass of `dataloader` (`eval_freq` batches), then
# uploads the weights twice. The first upload uses the fixed basename that rollout
# instances download from. The second uses a unique name (total timesteps +
# timestamp) and is kept as a historical checkpoint.
_s3_dir = CONFIG['distributed']['policy_weights_dir_s3']
_basename = CONFIG['distributed']['policy_weights_basename']
# The architecture cannot change mid-run, so pick the training call once.
_use_td3 = CONFIG['model']['architecture'] == "TD3"


def _upload_checkpoint(basename, event_name):
    """Upload the current `policy` weights to S3 under `basename` and log `event_name`.

    Args:
        basename: checkpoint basename inside the configured S3 checkpoint directory.
        event_name: event type recorded via `persist_event` after the upload.
    """
    uri = f"{_s3_dir}{basename}*"
    print(f"Saving policy checkpoints to {uri}")
    save_s3_model_checkpoint(policy, s3_dir=_s3_dir, basename=basename)
    persist_event(event_name, f'Uploaded policy checkpoint to {uri}')


while True:
    # Train for `eval_freq` batches:
    if _use_td3:
        policy.train(
            dataloader,
            CONFIG['training']['discount'],
            CONFIG['training']['tau'],
            CONFIG['training']['policy_noise'],
            CONFIG['training']['noise_clip'],
            CONFIG['training']['policy_freq'],
        )
    else:
        # NOTE(review): `policy` is always constructed as a TD3 above, so this
        # legacy sampler/iterations-style call only applies if that changes — confirm.
        policy.train(
            history_sampler,
            int(CONFIG['training']['eval_freq']),
            CONFIG['training']['batch_size'],
            CONFIG['training']['discount'],
            CONFIG['training']['tau'],
        )
    persist_event('train_epoch_completed', f'Trained policy for {len(dataloader)} batches of {dataloader.batch_size}')

    # Publish under the fixed basename, to be picked up by instances running the
    # Rollout Distributed process.
    _upload_checkpoint(_basename, 'train_update_s3_checkpoint')

    # Also keep a uniquely named copy as a historical checkpoint.
    total_timesteps = get_total_timesteps()
    evalname = f"{_basename}_T{total_timesteps}_{datetime.datetime.now().isoformat()}"
    _upload_checkpoint(evalname, 'train_save_historical_checkpoint')

    # TODO: evaluation is currently disabled. It previously saved the policy to a
    # local uuid-named directory, then launched in the background (subprocess.Popen):
    # `CHECKPOINT_DIR=<dir> CHECKPOINT_NAME=<evalname> pipenv run python evaluate_policy.py`.
    
    

Train model: 100%|██████████| 2500/2500 [06:57<00:00,  5.99batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T21638_2018-10-19T19:27:53.108049*


Train model: 100%|██████████| 2500/2500 [05:49<00:00,  9.10batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T21638_2018-10-19T19:33:44.266989*


Train model: 100%|██████████| 2500/2500 [05:53<00:00,  8.38batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T21638_2018-10-19T19:39:39.434553*


Train model: 100%|██████████| 2500/2500 [05:55<00:00,  7.04batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T21638_2018-10-19T19:45:36.113071*


Train model: 100%|██████████| 2500/2500 [05:50<00:00,  7.14batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T21638_2018-10-19T19:51:28.709823*


Train model: 100%|██████████| 2500/2500 [05:50<00:00,  7.13batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T21638_2018-10-19T19:57:20.749984*


Train model: 100%|██████████| 2500/2500 [05:53<00:00,  7.16batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T21638_2018-10-19T20:03:16.300675*


Train model: 100%|██████████| 2500/2500 [05:52<00:00,  7.10batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T33334_2018-10-19T20:09:10.137125*


Train model: 100%|██████████| 2500/2500 [05:52<00:00,  7.09batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T47909_2018-10-19T20:15:04.372928*


Train model: 100%|██████████| 2500/2500 [05:55<00:00,  7.04batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T51472_2018-10-19T20:21:01.083062*


Train model: 100%|██████████| 2500/2500 [05:52<00:00,  7.10batch/s]


SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3*
SAving policy checkpoints to s3://colllin-nips-2018-prosthetics/checkpoints/checkpoint_TD3_T61332_2018-10-19T20:26:54.947192*


Train model:  42%|████▏     | 1048/2500 [02:27<03:56,  6.15batch/s]