In [1]:
import time, datetime
import copy
import os
import sys
import warnings
#warnings.filterwarnings("ignore", category=UserWarning)
#warnings.filterwarnings("ignore", category=RuntimeWarning)

import numpy as np
from loguru import logger
import yaml
from utils import dumb_reward_plot
import gym

sys.path.append('./envs/cartpole-envs')
sys.path.append('./')
import cartpole_envs
#import highway_env

from utils import plot_reward, plot_index
from mpc.mpc_cp import MPC
from baselines.NP_epi import NP

def prepare_dynamics(gym_config):
    """Build the ordered list of gym environments that make up the task.

    Args:
        gym_config: dict with
            'dynamics_name'      -- list of registered gym environment ids;
            'task_dynamics_list' -- list of indices into 'dynamics_name' giving the
                                    order in which the dynamics are visited.

    Returns:
        list of gym environments, one per entry of ``task_dynamics_list``.
        Repeated indices share the same environment instance.
    """
    # Each environment is created exactly once. The previous version called
    # gym.make twice per name and leaked the first, never-closed instance.
    # NOTE(review): gym_config['seed'] exists, but environment seeding (env.seed)
    # was disabled in the original and is intentionally left disabled here.
    dynamics_set = [gym.make(name) for name in gym_config['dynamics_name']]

    # use pre-defined env sequence
    return [dynamics_set[i] for i in gym_config['task_dynamics_list']]

def load_config(config_path="config.yml"):
    """Load and parse a YAML configuration file.

    Args:
        config_path: path to the YAML file (default "config.yml").

    Returns:
        The parsed YAML document. For the config files used in this notebook
        it is a dict of config sections.

    Raises:
        Exception: if ``config_path`` is not an existing regular file.
    """
    if not os.path.isfile(config_path):
        raise Exception("Configuration file is not found in the path: " + config_path)
    # `with` guarantees the file handle is closed; the previous version leaked it.
    # NOTE(review): FullLoader is kept for compatibility with existing configs;
    # switch to yaml.safe_load if config files can ever come from untrusted sources.
    with open(config_path) as f:
        return yaml.load(f, Loader=yaml.FullLoader)


In [2]:
# Alternative experiment configurations:
# config = load_config('config/config_cpstable_np.yml')
# config = load_config('config/config_cp_np.yml')
config = load_config('config/config_cp_np_delay.yml')
mpc_config = config['mpc_config']
gym_config = config['gym_config']
render = gym_config['render']
np_config = config['NP_config']

# Seed numpy's global RNG from the config before anything random happens, so the
# per-episode delay sampling below is reproducible. Previously the config seed
# was read but never applied anywhere.
# NOTE(review): gym action-space sampling and any torch RNG used inside NP/MPC
# presumably have their own generators and are not covered by this call.
np.random.seed(gym_config['seed'])

# initialize the dynamics model (NP)
# NOTICE: Model
#  ###########
model = NP(NP_config=np_config)
logger.info('Using model: {}', model.name)

# initialize the MPC controller
mpc_controller = MPC(mpc_config=mpc_config)

# prepare task
# the task is solved if each dynamics is solved
task = prepare_dynamics(gym_config)
print(gym_config)

"""start DPGP-MBRL"""
# Bookkeeping state for the MBRL loop.
# NOTE(review): most of these are not read by the cells visible in this notebook.
data_buffer = []
label_list = []
subtask_list = []
subtask_reward = []
subtask_succ_count = [0]
comp_trainable = [1]
task_reward = []
trainable = True
task_solved = False
subtask_solved = [False] * len(task)  # one flag per entry of the task sequence
total_count = 0
task_epi = 0
log_name = None  # timestamp used to name the reward log; set on the first logged episode

max_delay_step = gym_config["max_delay_step"]
total_tasks = 1  # run only the first task; use len(task) to run all of them

2020-05-24 20:47:11.328 | INFO     | __main__:<module>:13 - Using model: NP


{'render': False, 'task_dynamics_list': [0, 1, 2, 3], 'subtask_episode': 3, 'subtask_episode_length': 200, 'task_episode': 100, 'seed': 1000, 'dynamics_name': ['CartPoleSwingUpEnvCm05Pm04Pl05-v0', 'CartPoleSwingUpEnvCm05Pm04Pl07-v0', 'CartPoleSwingUpEnvCm05Pm08Pl05-v0', 'CartPoleSwingUpEnvCm05Pm08Pl07-v0'], 'max_delay_step': 4}


In [3]:
"""NP pretrain"""
pretrain_episodes = 100
for task_idx in range(total_tasks):
    env = task[task_idx]
    # data collection
    for epi in range(pretrain_episodes):
        act_buf = np.zeros(max_delay_step+1)
        delay_step = np.random.randint(0,max_delay_step-1)
#         delay_step = 4
        obs = env.reset()
        done = False
        mpc_controller.reset()
        while not done:
            action = env.action_space.sample()
            act_buf = np.concatenate((act_buf, action))[1:]
            obs_next, reward, done, state_next = env.step([act_buf[-1-delay_step]])
#             print(np.concatenate((obs, act_buf[:-1])),action, np.concatenate((obs_next - obs, act_buf[1:])) )
            model.data_process([0, np.concatenate((obs, act_buf[:-1])), action, np.concatenate((obs_next - obs, act_buf[1:]))])
            obs = copy.deepcopy(obs_next)
        model.fit()
        model.reset()
# NOTICE: Model
#  ###########
# model.validation_flag = True
# model.fit()

In [4]:
"""testing the model with MPC while training """
test_episode = 1
test_epoch = 1000
log = []
for ep in range(test_epoch):
    print('epoch: ', ep)
    for task_idx in range(total_tasks):
        env = task[task_idx]
        print('task: ', task_idx)
        for epi in range(test_episode):
            #print('episode: ', epi)
            acc_reward = 0
            obs = env.reset()
            act_buf = np.zeros(max_delay_step+1)
            delay_step = np.random.randint(0,max_delay_step)
#             delay_step = 4
            O, A, R, acc_reward, done = [], [], [], 0, False
            mpc_controller.reset()
            i = 0
            while not done:
                i+= 1
                # env.render()
                env_copy = prepare_dynamics(gym_config)[task_idx]
                env_copy.reset()
                # NOTICE: Model
                #  ###########
                act_buf = act_buf[1:]
                action = np.array([0.0])
                if i > 1:
                    action = np.array([mpc_controller.act(task=env_copy, model=model, state=np.concatenate((obs, act_buf)))])
                act_buf = np.concatenate((act_buf, action))
                obs_next, reward, done, state_next = env.step([act_buf[-1-delay_step]])
                model.data_process([0, np.concatenate((obs, act_buf[:-1])), action, np.concatenate((obs_next - obs, act_buf[1:]))])
#                 print(np.concatenate((obs, act_buf[:-1])), action, np.concatenate((obs_next - obs, act_buf[1:])))
                A.append(action)
                O.append(state_next)
                R.append(reward)
                obs = copy.deepcopy(obs_next)
                acc_reward += reward
                # logger.info('reward: {}', reward)
                #time.sleep(0.1)
            print('task: ', task_idx,'delay: ', delay_step,'step: ', i, 'acc_reward: ', acc_reward)
            env.close()

            if done:
                samples = {
                    "obs": np.array(O),
                    "actions": np.array(A),
                    "rewards": np.array(R),
                    "reward_sum": acc_reward,
                }
                print('******************')
                print('acc_reward', acc_reward)
                print('******************')
                log.append(samples)
                if log_name is None:
                    log_name = datetime.datetime.now()
                path = './misc/log/' + log_name.strftime("%d-%H-%M") + '.npy'
                np.save(path, log, allow_pickle=True)
                dumb_reward_plot(path)
            model.fit()
            model.reset()

        # use the collected date to train model
#         print('fitting the model...')
        #model.n_epochs = 20

        # NOTICE: Model
        #  ###########


epoch:  0
task:  0
task:  0 delay:  1 step:  200 acc_reward:  31.724881510974296
******************
acc_reward 31.724881510974296
******************
epoch:  1
task:  0
task:  0 delay:  0 step:  200 acc_reward:  38.31195246944633
******************
acc_reward 38.31195246944633
******************
epoch:  2
task:  0
task:  0 delay:  2 step:  200 acc_reward:  37.2869399043944
******************
acc_reward 37.2869399043944
******************
epoch:  3
task:  0
task:  0 delay:  2 step:  200 acc_reward:  51.08440320972407
******************
acc_reward 51.08440320972407
******************
epoch:  4
task:  0
task:  0 delay:  1 step:  200 acc_reward:  68.34341225289945
******************
acc_reward 68.34341225289945
******************
epoch:  5
task:  0
task:  0 delay:  0 step:  200 acc_reward:  75.73660322962249
******************
acc_reward 75.73660322962249
******************
epoch:  6
task:  0
task:  0 delay:  2 step:  200 acc_reward:  68.02388064003732
******************
acc_reward 68.02388

epoch:  56
task:  0
task:  0 delay:  3 step:  200 acc_reward:  34.06967769744963
******************
acc_reward 34.06967769744963
******************
epoch:  57
task:  0
task:  0 delay:  2 step:  200 acc_reward:  70.32575431529013
******************
acc_reward 70.32575431529013
******************
epoch:  58
task:  0
task:  0 delay:  2 step:  175 acc_reward:  66.28167724991421
******************
acc_reward 66.28167724991421
******************
epoch:  59
task:  0
task:  0 delay:  1 step:  200 acc_reward:  94.38064848389209
******************
acc_reward 94.38064848389209
******************
epoch:  60
task:  0
task:  0 delay:  1 step:  200 acc_reward:  104.84991983779176
******************
acc_reward 104.84991983779176
******************
epoch:  61
task:  0
task:  0 delay:  1 step:  200 acc_reward:  72.67470648919154
******************
acc_reward 72.67470648919154
******************
epoch:  62
task:  0
task:  0 delay:  1 step:  200 acc_reward:  65.46690155965322
******************
acc_reward

epoch:  112
task:  0
task:  0 delay:  2 step:  200 acc_reward:  60.00857851437036
******************
acc_reward 60.00857851437036
******************
epoch:  113
task:  0
task:  0 delay:  2 step:  200 acc_reward:  56.75704093155251
******************
acc_reward 56.75704093155251
******************
epoch:  114
task:  0
task:  0 delay:  2 step:  200 acc_reward:  60.33712874060047
******************
acc_reward 60.33712874060047
******************
epoch:  115
task:  0
task:  0 delay:  2 step:  200 acc_reward:  63.230817584157414
******************
acc_reward 63.230817584157414
******************
epoch:  116
task:  0
task:  0 delay:  2 step:  200 acc_reward:  66.91950164518893
******************
acc_reward 66.91950164518893
******************
epoch:  117
task:  0
task:  0 delay:  1 step:  200 acc_reward:  79.95204962923161
******************
acc_reward 79.95204962923161
******************
epoch:  118
task:  0
task:  0 delay:  1 step:  159 acc_reward:  70.40578012434013
******************
acc

epoch:  167
task:  0
task:  0 delay:  0 step:  200 acc_reward:  164.3633565202954
******************
acc_reward 164.3633565202954
******************
epoch:  168
task:  0
task:  0 delay:  0 step:  200 acc_reward:  172.59517489552263
******************
acc_reward 172.59517489552263
******************
epoch:  169
task:  0
task:  0 delay:  1 step:  200 acc_reward:  73.31690189053238
******************
acc_reward 73.31690189053238
******************
epoch:  170
task:  0
task:  0 delay:  0 step:  200 acc_reward:  71.23909180869961
******************
acc_reward 71.23909180869961
******************
epoch:  171
task:  0
task:  0 delay:  3 step:  114 acc_reward:  22.866192887374538
******************
acc_reward 22.866192887374538
******************
epoch:  172
task:  0
task:  0 delay:  1 step:  148 acc_reward:  23.557768069101858
******************
acc_reward 23.557768069101858
******************
epoch:  173
task:  0
task:  0 delay:  3 step:  200 acc_reward:  41.06507013485047
******************

KeyboardInterrupt: 