In [1]:
import sys,os
from pathlib import Path
curr_path = str(Path().absolute())
parent_path = str(Path().absolute().parent)
sys.path.append(parent_path) # add current terminal path to sys.path
from MORLS.env.SimEnv import SimEnv

# benchmark (continuous cartpole) for fun
import mbrl.env.cartpole_continuous as cartpole_env

In [2]:
import argparse
import time
import gym
import torch
import numpy as np
from itertools import count

import logging

import os
import os.path as osp
import json

from sac.replay_memory import ReplayMemory
from sac.sac import SAC
from model import EnsembleDynamicsModel
from predict_env import PredictEnv
from sample_env import EnvSampler
from tf_models.constructor import construct_model, format_samples_for_training


In [3]:
def train(args, env_sampler, predict_env, agent, env_pool, model_pool):
    total_step = 0
    reward_sum = 0
    rollout_length = 1
    exploration_before_start(args, env_sampler, env_pool, agent)

    for epoch_step in range(args.num_epoch):
        print('epoch_step: {}'.format(epoch_step))
        start_step = total_step
        train_policy_steps = 0
        for i in count():
            if i%50==0: print('i:{}'.format(i))
            
            cur_step = total_step - start_step
            if cur_step%50==0:print('cur step:{}'.format(cur_step))
            #if cur_step >= start_step + args.epoch_length and len(env_pool) > args.min_pool_size:
            if cur_step >=args.epoch_length and len(env_pool) > args.min_pool_size:
                break

            if cur_step > 0 and cur_step % args.model_train_freq == 0 and args.real_ratio < 1.0:
                train_predict_model(args, env_pool, predict_env)

                new_rollout_length = set_rollout_length(args, epoch_step)
                if rollout_length != new_rollout_length:
                    rollout_length = new_rollout_length
                    model_pool = resize_model_pool(args, rollout_length, model_pool)

                rollout_model(args, predict_env, agent, model_pool, env_pool, rollout_length)

            cur_state, action, next_state, reward, done, info = env_sampler.sample(agent)
            env_pool.push(cur_state, action, reward, next_state, done)

            if len(env_pool) > args.min_pool_size:
                train_policy_steps += train_policy_repeats(args, total_step, train_policy_steps, cur_step, env_pool, model_pool, agent)

            total_step += 1

            if total_step % 100 == 0:
                print('total step:{}'.format(total_step))
                '''
                avg_reward_len = min(len(env_sampler.path_rewards), 5)
                avg_reward = sum(env_sampler.path_rewards[-avg_reward_len:]) / avg_reward_len
                logging.info("Step Reward: " + str(total_step) + " " + str(env_sampler.path_rewards[-1]) + " " + str(avg_reward))
                print(total_step, env_sampler.path_rewards[-1], avg_reward)
                '''
                env_sampler.current_state = None
                sum_reward = 0
                done = False
                while not done:
                    cur_state, action, next_state, reward, done, info = env_sampler.sample(agent, eval_t=True)
                    sum_reward += reward
                # logger.record_tabular("total_step", total_step)
                # logger.record_tabular("sum_reward", sum_reward)
                # logger.dump_tabular()
                logging.info("Step Reward: " + str(total_step) + " " + str(sum_reward))
                print(total_step, sum_reward)
                

def exploration_before_start(args, env_sampler, env_pool, agent):
    for i in range(args.init_exploration_steps):
        #print("exploration i:{}".format(i))
        cur_state, action, next_state, reward, done, info = env_sampler.sample(agent)
        #print("cur_state:{}".format(cur_state))
        #print("action:{}".format(action))
        #print("next_state:{}".format(next_state))
        #print("reward:{}".format(reward))
        env_pool.push(cur_state, action, reward, next_state, done)
    print("done exploration before starting")

def set_rollout_length(args, epoch_step):
    rollout_length = (min(max(args.rollout_min_length + (epoch_step - args.rollout_min_epoch)
                              / (args.rollout_max_epoch - args.rollout_min_epoch) * (args.rollout_max_length - args.rollout_min_length),
                              args.rollout_min_length), args.rollout_max_length))
    return int(rollout_length)


def train_predict_model(args, env_pool, predict_env):
    print("train predict model")
    # Get all samples from environment
    state, action, reward, next_state, done = env_pool.sample(len(env_pool))
    delta_state = next_state - state
    inputs = np.concatenate((state, action), axis=-1)
    labels = np.concatenate((np.reshape(reward, (reward.shape[0], -1)), delta_state), axis=-1)

    predict_env.model.train(inputs, labels, batch_size=256, holdout_ratio=0.2)


def resize_model_pool(args, rollout_length, model_pool):
    rollouts_per_epoch = args.rollout_batch_size * args.epoch_length / args.model_train_freq
    model_steps_per_epoch = int(rollout_length * rollouts_per_epoch)
    new_pool_size = args.model_retain_epochs * model_steps_per_epoch

    sample_all = model_pool.return_all()
    new_model_pool = ReplayMemory(new_pool_size)
    new_model_pool.push_batch(sample_all)

    return new_model_pool


def rollout_model(args, predict_env, agent, model_pool, env_pool, rollout_length):
    state, action, reward, next_state, done = env_pool.sample_all_batch(args.rollout_batch_size)
    for i in range(rollout_length):
        # TODO: Get a batch of actions
        action = agent.select_action(state)
        next_states, rewards, terminals, info = predict_env.step(state, action)
        # TODO: Push a batch of samples
        model_pool.push_batch([(state[j], action[j], rewards[j], next_states[j], terminals[j]) for j in range(state.shape[0])])
        nonterm_mask = ~terminals.squeeze(-1)
        if nonterm_mask.sum() == 0:
            break
        state = next_states[nonterm_mask]


def train_policy_repeats(args, total_step, train_step, cur_step, env_pool, model_pool, agent):
    #print("train policy repeats")
    if total_step % args.train_every_n_steps > 0:
        return 0

    if train_step > args.max_train_repeat_per_step * total_step:
        return 0

    for i in range(args.num_train_repeat):
        env_batch_size = int(args.policy_train_batch_size * args.real_ratio)
        model_batch_size = args.policy_train_batch_size - env_batch_size

        env_state, env_action, env_reward, env_next_state, env_done = env_pool.sample(int(env_batch_size))

        if model_batch_size > 0 and len(model_pool) > 0:
            model_state, model_action, model_reward, model_next_state, model_done = model_pool.sample_all_batch(int(model_batch_size))
            batch_state, batch_action, batch_reward, batch_next_state, batch_done = np.concatenate((env_state, model_state), axis=0), \
                                                                                    np.concatenate((env_action, model_action),
                                                                                                   axis=0), np.concatenate(
                (np.reshape(env_reward, (env_reward.shape[0], -1)), model_reward), axis=0), \
                                                                                    np.concatenate((env_next_state, model_next_state),
                                                                                                   axis=0), np.concatenate(
                (np.reshape(env_done, (env_done.shape[0], -1)), model_done), axis=0)
        else:
            batch_state, batch_action, batch_reward, batch_next_state, batch_done = env_state, env_action, env_reward, env_next_state, env_done

        batch_reward, batch_done = np.squeeze(batch_reward), np.squeeze(batch_done)
        batch_done = (~batch_done).astype(int)
        agent.update_parameters((batch_state, batch_action, batch_reward, batch_next_state, batch_done), args.policy_train_batch_size, i)

    return args.num_train_repeat


In [4]:
def main(args=None):
    if args is None:
        args = readParser()


    
    env=SimEnv()
    # Set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    env.seed(args.seed)

    # Intial agent
    agent = SAC(env.observation_space.shape[0], env.action_space, args)

    # Initial ensemble model
    state_size = np.prod(env.observation_space.shape)
    action_size = np.prod(env.action_space.shape)
    print(env.action_space)
    if args.model_type == 'pytorch':
        env_model = EnsembleDynamicsModel(args.num_networks, args.num_elites, state_size, action_size, args.reward_size, args.pred_hidden_size,
                                          use_decay=args.use_decay)
    else:
        env_model = construct_model(obs_dim=state_size, act_dim=action_size, hidden_dim=args.pred_hidden_size, num_networks=args.num_networks,
                                    num_elites=args.num_elites)

    
    # Predict environments
    predict_env = PredictEnv(env_model, args.model_type)

   
    # Initial pool for env
    env_pool = ReplayMemory(args.replay_size)
    print(env_pool)
    
    # Initial pool for model
    rollouts_per_epoch = args.rollout_batch_size * args.epoch_length / args.model_train_freq
    model_steps_per_epoch = int(1 * rollouts_per_epoch)
    new_pool_size = args.model_retain_epochs * model_steps_per_epoch
    model_pool = ReplayMemory(new_pool_size)
    
    
    # Sampler of environment
    env_sampler = EnvSampler(env)
    print('-------------------final training!----------------------')
    train(args, env_sampler, predict_env, agent, env_pool, model_pool)


In [5]:
parser = argparse.ArgumentParser(description='MBPO')
parser.add_argument('--seed', type=int, default=123456, metavar='N',
                    help='random seed (default: 123456)')

parser.add_argument('--use_decay', type=bool, default=True, metavar='G',
                    help='discount factor for reward (default: 0.99)')

parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy\
                        term against the reward (default: 0.2)')
# parser.add_argument('--policy', default="Gaussian",
#                     help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G',
                    help='Automaically adjust α (default: False)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')

parser.add_argument('--num_networks', type=int, default=3, metavar='E',
                    help='ensemble size (default: 7)')
parser.add_argument('--num_elites', type=int, default=5, metavar='E',
                    help='elite size (default: 5)')
parser.add_argument('--pred_hidden_size', type=int, default=200, metavar='E',
                    help='hidden size for predictive model')
parser.add_argument('--reward_size', type=int, default=1, metavar='E',
                    help='environment reward size')

parser.add_argument('--replay_size', type=int, default=10000, metavar='N',
                    help='size of replay buffer (default: 10000000)')

parser.add_argument('--model_retain_epochs', type=int, default=1, metavar='A',
                    help='retain epochs')
parser.add_argument('--model_train_freq', type=int, default=90, metavar='A',
                    help='frequency of training')
parser.add_argument('--rollout_batch_size', type=int, default=1000, metavar='A',
                    help='rollout number M')
parser.add_argument('--epoch_length', type=int, default=100, metavar='A',
                    help='steps per epoch')
parser.add_argument('--rollout_min_epoch', type=int, default=20, metavar='A',
                    help='rollout min epoch')
parser.add_argument('--rollout_max_epoch', type=int, default=150, metavar='A',
                    help='rollout max epoch')
parser.add_argument('--rollout_min_length', type=int, default=1, metavar='A',
                    help='rollout min length')
parser.add_argument('--rollout_max_length', type=int, default=15, metavar='A',
                    help='rollout max length')
parser.add_argument('--num_epoch', type=int, default=50, metavar='A',
                    help='total number of epochs')
parser.add_argument('--min_pool_size', type=int, default=100, metavar='A',
                    help='minimum pool size')
parser.add_argument('--real_ratio', type=float, default=0.05, metavar='A',
                    help='ratio of env samples / model samples')
parser.add_argument('--train_every_n_steps', type=int, default=1, metavar='A',
                    help='frequency of training policy')
parser.add_argument('--num_train_repeat', type=int, default=20, metavar='A',
                    help='times to training policy per step')
parser.add_argument('--max_train_repeat_per_step', type=int, default=5, metavar='A',
                    help='max training times per step')
parser.add_argument('--policy_train_batch_size', type=int, default=256, metavar='A',
                    help='batch size for training policy')
parser.add_argument('--init_exploration_steps', type=int, default=1000, metavar='A',
                    help='exploration steps initially')

parser.add_argument('--model_type', default='pytorch', metavar='A',
                    help='predict model -- pytorch or tensorflow')

parser.add_argument('--cuda', default=True, action="store_true",
                    help='run on CUDA (default: True)')

_StoreTrueAction(option_strings=['--cuda'], dest='cuda', nargs=0, const=True, default=True, type=None, choices=None, help='run on CUDA (default: True)', metavar=None)

In [6]:
args = parser.parse_args(args=[])

In [7]:
args

Namespace(alpha=0.2, automatic_entropy_tuning=False, cuda=True, epoch_length=100, gamma=0.99, hidden_size=256, init_exploration_steps=1000, lr=0.0003, max_train_repeat_per_step=5, min_pool_size=100, model_retain_epochs=1, model_train_freq=90, model_type='pytorch', num_elites=5, num_epoch=50, num_networks=3, num_train_repeat=20, policy='Gaussian', policy_train_batch_size=256, pred_hidden_size=200, real_ratio=0.05, replay_size=10000, reward_size=1, rollout_batch_size=1000, rollout_max_epoch=150, rollout_max_length=15, rollout_min_epoch=20, rollout_min_length=1, seed=123456, target_update_interval=1, tau=0.005, train_every_n_steps=1, use_decay=True)

In [8]:
if __name__ == '__main__':
    main(args)

Box(2,)
<sac.replay_memory.ReplayMemory object at 0x00000140A42AB108>
-------------------final training!----------------------
done exploration before starting
epoch_step: 0
i:0
cur step:0
i:50
cur step:50
train predict model




epoch: 0, holdout mse losses: [46.29018 46.4403  46.31322]
epoch: 1, holdout mse losses: [44.89128 45.46177 44.88377]
epoch: 2, holdout mse losses: [40.81736 42.25771 40.60838]
epoch: 3, holdout mse losses: [33.08612 35.66463 32.69725]
epoch: 4, holdout mse losses: [24.03356 27.55015 23.49522]
epoch: 5, holdout mse losses: [16.5349  24.08055 18.39691]
epoch: 6, holdout mse losses: [12.99663 22.17967 17.59348]
epoch: 7, holdout mse losses: [13.91541 18.83678 14.42812]
epoch: 8, holdout mse losses: [ 7.20264 14.8828   8.55309]
epoch: 9, holdout mse losses: [ 2.64997 11.49969  4.29515]
epoch: 10, holdout mse losses: [1.56292 8.27298 1.86567]
epoch: 11, holdout mse losses: [1.74043 4.41443 1.46596]
epoch: 12, holdout mse losses: [1.37803 1.5441  1.56998]
epoch: 13, holdout mse losses: [0.76081 1.21147 0.92863]
epoch: 14, holdout mse losses: [0.38997 1.57873 0.44257]
epoch: 15, holdout mse losses: [0.3029  0.88416 0.35642]
epoch: 16, holdout mse losses: [0.31688 0.44904 0.35005]
epoch: 17, 

epoch: 143, holdout mse losses: [0.01414 0.01883 0.01518]
epoch: 144, holdout mse losses: [0.01461 0.0156  0.01404]
epoch: 145, holdout mse losses: [0.01689 0.01402 0.01503]
epoch: 146, holdout mse losses: [0.01473 0.01505 0.01493]
epoch: 147, holdout mse losses: [0.01464 0.01213 0.03312]
epoch: 148, holdout mse losses: [0.0179  0.01467 0.01662]
epoch: 149, holdout mse losses: [0.02837 0.01282 0.01974]
epoch: 150, holdout mse losses: [0.0132  0.01166 0.02585]
epoch: 151, holdout mse losses: [0.01438 0.01082 0.02037]
epoch: 152, holdout mse losses: [0.0157  0.01105 0.02425]
epoch: 153, holdout mse losses: [0.01133 0.01175 0.02268]
epoch: 154, holdout mse losses: [0.01664 0.01109 0.01747]
epoch: 155, holdout mse losses: [0.01183 0.01163 0.6836 ]
epoch: 156, holdout mse losses: [0.01027 0.01113 0.25629]
epoch: 157, holdout mse losses: [0.01003 0.01123 0.60019]
epoch: 158, holdout mse losses: [0.00983 0.01297 0.10296]
epoch: 159, holdout mse losses: [0.0107  0.02174 0.13195]
epoch: 160, ho

epoch: 21, holdout mse losses: [0.00822 0.11394 0.05704]
epoch: 22, holdout mse losses: [0.00766 0.10808 0.03291]
epoch: 23, holdout mse losses: [0.0073  0.10592 0.05442]
epoch: 24, holdout mse losses: [0.0072  0.09496 0.04263]
epoch: 25, holdout mse losses: [0.00713 0.08856 0.0354 ]
epoch: 26, holdout mse losses: [0.00707 0.08566 0.02117]
epoch: 27, holdout mse losses: [0.007   0.08173 0.01788]
epoch: 28, holdout mse losses: [0.00724 0.07827 0.02092]
epoch: 29, holdout mse losses: [0.00715 0.07736 0.01489]
epoch: 30, holdout mse losses: [0.00724 0.075   0.01576]
epoch: 31, holdout mse losses: [0.00739 0.07151 0.01431]
total step:600
600 89.53636461496353
i:100
cur step:100
epoch_step: 6
i:0
cur step:0
i:50
cur step:50
train predict model
epoch: 0, holdout mse losses: [0.11722 0.06916 0.01055]
epoch: 1, holdout mse losses: [0.09446 0.06652 0.0117 ]
epoch: 2, holdout mse losses: [0.02886 0.05933 0.01249]
epoch: 3, holdout mse losses: [0.01616 0.05544 0.00983]
epoch: 4, holdout mse losse

epoch: 62, holdout mse losses: [0.07036 0.00756 0.05929]
epoch: 63, holdout mse losses: [0.02204 0.00746 0.00447]
epoch: 64, holdout mse losses: [0.01639 0.00958 0.01954]
epoch: 65, holdout mse losses: [0.01478 0.00842 0.00782]
epoch: 66, holdout mse losses: [0.00548 0.00798 0.00699]
epoch: 67, holdout mse losses: [0.00747 0.0077  0.00496]
epoch: 68, holdout mse losses: [0.00639 0.00671 0.00655]
epoch: 69, holdout mse losses: [0.00532 0.00668 0.00605]
epoch: 70, holdout mse losses: [0.00548 0.00761 0.00401]
epoch: 71, holdout mse losses: [0.00593 0.007   0.00433]
epoch: 72, holdout mse losses: [0.00648 0.00717 0.00511]
epoch: 73, holdout mse losses: [0.00427 0.00713 0.00424]
epoch: 74, holdout mse losses: [0.00392 0.0065  0.0049 ]
epoch: 75, holdout mse losses: [0.00451 0.00819 0.00517]
epoch: 76, holdout mse losses: [0.005   0.00635 0.00465]
epoch: 77, holdout mse losses: [0.00434 0.00638 0.00572]
epoch: 78, holdout mse losses: [0.0046  0.0078  0.00541]
epoch: 79, holdout mse losses: 

total step:1500
1500 93.90983426570892
i:100
cur step:100
epoch_step: 15
i:0
cur step:0
i:50
cur step:50
train predict model
epoch: 0, holdout mse losses: [0.01083 0.00369 0.00587]
epoch: 1, holdout mse losses: [0.00772 0.00355 0.00375]
epoch: 2, holdout mse losses: [0.00738 0.007   0.00382]
epoch: 3, holdout mse losses: [0.00717 0.00477 0.00346]
epoch: 4, holdout mse losses: [0.00502 0.00355 0.00261]
epoch: 5, holdout mse losses: [0.00586 0.00282 0.00253]
epoch: 6, holdout mse losses: [0.00524 0.00296 0.00247]
epoch: 7, holdout mse losses: [0.00442 0.00243 0.00259]
epoch: 8, holdout mse losses: [0.00505 0.00691 0.00255]
epoch: 9, holdout mse losses: [0.00449 0.03403 0.00424]
epoch: 10, holdout mse losses: [0.00502 0.02665 0.01564]
epoch: 11, holdout mse losses: [0.0049  0.01159 0.008  ]
epoch: 12, holdout mse losses: [0.00498 0.00848 0.00456]
total step:1600
1600 84.99829983711243
i:100
cur step:100
epoch_step: 16
i:0
cur step:0
i:50
cur step:50


KeyboardInterrupt: 