In [1]:
import time, datetime
import copy
import os
import sys
import warnings
#warnings.filterwarnings("ignore", category=UserWarning)
#warnings.filterwarnings("ignore", category=RuntimeWarning)

import numpy as np
from loguru import logger
import yaml
from utils import dumb_reward_plot
import gym

sys.path.append('./envs/cartpole-envs')
sys.path.append('./')
import cartpole_envs
#import highway_env

from utils import plot_reward, plot_index
from mpc.mpc_cp import MPC
from baselines.NN import NN

def prepare_dynamics(gym_config):
    """Build the list of gym environments that make up the task.

    Args:
        gym_config: configuration dict. Reads ``'dynamics_name'`` (list of gym
            environment ids) and ``'task_dynamics_list'`` (indices into
            ``dynamics_name`` giving the task's environment sequence).

    Returns:
        list with one environment per entry of ``task_dynamics_list``, in that
        order. Repeated indices share the same environment instance.

    Raises:
        IndexError: if an entry of ``task_dynamics_list`` is out of range.
    """
    dynamics_name = gym_config['dynamics_name']
    # Environment seeding is currently disabled (was: env.seed(gym_config['seed'])).
    # Each referenced environment is created exactly once, and only if the task
    # uses it. The previous version called gym.make twice per id, leaking the
    # first instance, and also built environments the task never referenced.
    envs = {}
    task = []
    for idx in gym_config['task_dynamics_list']:
        # Normalise negative indices so e.g. -1 and len-1 share one instance,
        # while still raising IndexError for out-of-range indices.
        key = range(len(dynamics_name))[idx]
        if key not in envs:
            envs[key] = gym.make(dynamics_name[key])
        task.append(envs[key])
    return task

def load_config(config_path="config.yml"):
    """Load a YAML configuration file.

    Args:
        config_path: path (str or path-like) of the YAML file to read.

    Returns:
        The parsed YAML document (the notebook indexes it like a dict).

    Raises:
        FileNotFoundError: if ``config_path`` is not an existing regular file.
            This subclasses ``Exception``, so callers catching ``Exception``
            keep working.
    """
    if not os.path.isfile(config_path):
        raise FileNotFoundError(
            "Configuration file is not found in the path: " + str(config_path))
    # Context manager guarantees the handle is closed after parsing
    # (the previous version left it open).
    # FullLoader is acceptable only because the config is a trusted local file.
    with open(config_path) as f:
        return yaml.load(f, Loader=yaml.FullLoader)


In [2]:
# Load the swing-up experiment configuration and split out its sections.
config = load_config('config/config_swingup.yml')
nn_config = config['NN_config']
mpc_config = config['mpc_config']
gym_config = config['gym_config']
render = gym_config['render']  # not read by any cell visible in this notebook (env.render() is commented out)

# initialize the dynamics model (NN baseline).
# Alternatives kept for reference; these classes are not imported in this notebook:
# model = DPGPMM(dpgp_config=dpgp_config)
# model = SingleSparseGP(sparse_gp_config=sparse_gp_config)
# model = SingleGP(gp_config=gp_config)
model = NN(NN_config=nn_config)
logger.info('Using model: {}', model.name)

# initialize the MPC controller
mpc_controller = MPC(mpc_config=mpc_config)

# prepare the task: one env per entry of gym_config['task_dynamics_list'].
# The task counts as solved once each of its dynamics is solved.
task = prepare_dynamics(gym_config)
print(gym_config)

"""start DPGP-MBRL"""
# Run bookkeeping kept in the kernel namespace. Apart from log_name and
# total_tasks, none of these is read by the cells visible here.
# NOTE(review): confirm whether the unused ones are still needed.
data_buffer = []
label_list = []
subtask_list = []
subtask_reward = []
subtask_succ_count = [0]
comp_trainable = [1]
task_reward = []
trainable = True
task_solved = False
subtask_solved = [False, False, False, False]  # presumably one flag per dynamics_name entry (4 in the current config) — TODO confirm
total_count = 0
task_epi = 0
log_name = None  # datetime of the first log save; fixes the .npy log filename for the run

total_tasks = 1  # only task[0:total_tasks] is pretrained/evaluated by the later cells


2020-08-05 18:07:32.811 | INFO     | __main__:<module>:12 - Using model: NN


{'render': False, 'task_dynamics_list': [0], 'subtask_episode': 3, 'subtask_episode_length': 200, 'task_episode': 100, 'seed': 1000, 'dynamics_name': ['CartPoleSwingUpEnvCm05Pm05Pl05-v0', 'CartPoleSwingUpEnvCm05Pm04Pl07-v0', 'CartPoleSwingUpEnvCm05Pm08Pl05-v0', 'CartPoleSwingUpEnvCm05Pm08Pl07-v0']}


In [9]:
"""NN pretrain"""
# Warm-start the dynamics model on transitions gathered with randomly sampled actions.
pretrain_episodes = 1
for task_idx in range(total_tasks):
    env = task[task_idx]
    # data collection
    for epi in range(pretrain_episodes):
        obs = env.reset()
        done = False
        mpc_controller.reset()  # the controller does not choose actions in this cell
        while not done:
            action = env.action_space.sample()  # random action from the env's action space
            obs_next, reward, done, state_next = env.step(action)
            # Store (0, obs, action, obs_next - obs): the model learns the state delta.
            # The leading 0 is presumably a component/label id — TODO confirm against NN.data_process.
            model.data_process([0, obs, action, obs_next - obs])
            obs = obs_next

# training the model on everything passed to data_process so far
model.validation_flag = True  # presumably enables a held-out validation split inside NN.fit — TODO confirm
#model.n_epochs = 20
model.fit()

data size:  41
torch.Size([41, 6])
torch.Size([41, 6]) torch.Size([41, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch.Size([33, 5]) torch.Size([33, 5])
torch.Size([33, 6]) torch

0.0018703675596043468

In [10]:
"""testing the model with MPC while training """
# Alternate between evaluating the current model inside the MPC loop (collecting
# the transitions it visits) and refitting the model on everything collected.
test_episode = 2  # evaluation episodes per task and epoch
test_epoch = 20   # number of evaluate-then-refit rounds
log = []
log_dir = './misc/log/'
for ep in range(test_epoch):
    print('epoch: ', ep)
    for task_idx in range(total_tasks):
        env = task[task_idx]
        # gym id behind task[task_idx] (the same lookup prepare_dynamics performs);
        # used below to build the planner's scratch env.
        env_id = gym_config['dynamics_name'][gym_config['task_dynamics_list'][task_idx]]
        print('task: ', task_idx)
        for epi in range(test_episode):
            obs = env.reset()

            O, A, R, acc_reward, done = [], [], [], 0, False
            mpc_controller.reset()
            i = 0
            while not done:
                i += 1

                # env.render()
                # Fresh scratch env handed to the planner each step. Building only this
                # env, instead of calling prepare_dynamics (which constructs every
                # configured env), and closing it afterwards avoids leaking envs.
                env_copy = gym.make(env_id)
                env_copy.reset()
                action = np.array([mpc_controller.act(task=env_copy, model=model, state=obs)])
                env_copy.close()
                obs_next, reward, done, state_next = env.step(action)
                A.append(action)
                # NOTE(review): gym's env.step returns `info` as its 4th value; confirm
                # this env stores the state there, otherwise this should be obs_next.
                O.append(state_next)
                R.append(reward)

                # append data but not training
                model.data_process([0, obs, action, obs_next - obs])
                obs = copy.deepcopy(obs_next)
                acc_reward += reward
            print('task: ', task_idx, 'step: ', i, 'acc_reward: ', acc_reward)
            env.close()

            # The while-loop exits only once `done` is True, so every finished episode
            # is logged (the former `if done:` guard was always true).
            samples = {
                "obs": np.array(O),
                "actions": np.array(A),
                "rewards": np.array(R),
                "reward_sum": acc_reward,
            }
            print('******************')
            print('acc_reward', acc_reward)
            print('******************')
            log.append(samples)
            if log_name is None:
                log_name = datetime.datetime.now()
            # The whole log is re-saved after every episode; make sure the directory exists.
            os.makedirs(log_dir, exist_ok=True)
            path = log_dir + log_name.strftime("%d-%H-%M") + '.npy'
            np.save(path, log, allow_pickle=True)
            dumb_reward_plot(path)

        # use the collected data to train the model
        print('fitting the model...')
        #model.n_epochs = 20
        model.fit()

epoch:  0
task:  0
[0.21687567]
[-0.5307585]
[0.4572252]
[-0.0628475]
[0.6124457]
[-0.0209409]
[-0.97256005]
[0.1026597]
[0.01417851]
[-0.80613303]
[-0.130265]
[-0.601365]
[-0.98650706]
[-0.53677845]
[-0.7855493]
[-0.61960626]
[-0.80024505]
[0.80684364]
[0.7455486]
[0.41950154]
[-0.4499613]
[0.13148665]
[-0.07458448]
[0.16239119]
[-0.30847967]
[-0.12946093]
[0.75982046]
[0.38173735]
[0.2123164]
[0.00515008]
[0.09787834]
[0.7925559]
[0.85437846]
[0.82264864]
task:  0 step:  34 acc_reward:  2.1450148421005917
******************
acc_reward 2.1450148421005917
******************
[0.16149592]
[-0.4289242]
[0.5794027]
[-0.2208705]
[0.09817255]
[0.763765]
[-0.97912145]
[0.01798832]
[-0.6528646]
[-0.89498353]
[0.33381593]
[-0.87796676]
[-0.96540105]
[-0.77876294]
[-0.99296737]
[-0.7910297]
[0.739969]
[0.96966445]
[0.80252767]
[0.3212384]
[0.38287532]
[0.13877952]
[0.62644076]
[0.87078714]
[0.9301553]
[0.66356194]
[0.41257417]
[-0.9628624]
[-0.52788806]
[0.13115776]
[-0.15704656]
[0.4050603]
[0.

[0.12883317]
[0.88324416]
[-0.92072403]
[0.3813753]
[0.33716738]
[0.4535352]
[-0.53141975]
[-0.746096]
[-0.77130806]
[-0.31553447]
[0.35350895]
[0.5660131]
[-0.36869013]
[-0.9761063]
[0.7200879]
[0.592896]
[-0.5924219]
[-0.5154165]
[0.6529585]
[-0.39287746]
[-0.03361583]


KeyboardInterrupt: 

In [6]:
    # Scratch: shuffle 1000 indices and take two prefixes. The target prefix is
    # at least as long as the context prefix, so targets include every context index.
    indices = np.arange(1000).tolist()
    np.random.shuffle(indices)
    num_context = np.random.randint(10, 100)
    num_target = num_context + np.random.randint(0, 100 - num_context)
    rand_ind_ctt = indices[:num_context]
    rand_ind_tgt = indices[:num_target]

In [9]:
import torch

# Random permutation of 0..9; int64 is randperm's default dtype, spelled out here.
torch.randperm(10, dtype=torch.int64)

tensor([4, 3, 7, 8, 2, 9, 6, 5, 0, 1])

In [10]:
# View the context indices (from the index-split cell) as an int64 tensor.
torch.as_tensor(indices[:num_context])

tensor([199, 138, 643, 754, 702, 874, 782, 494, 355,  24, 967, 670, 170, 471,
        725, 315, 942, 720, 223, 966, 304, 129, 137, 646, 206, 865, 210, 738,
        525,  66, 979, 787, 801, 831, 323, 929, 662, 827, 744, 361, 795, 622,
        568,  14, 976, 469, 299, 669,  78, 908, 441, 179, 251, 811, 500, 724,
        896, 444, 561, 642, 651,  17, 716, 969, 855, 205, 872, 794, 989, 166,
         20, 885, 728, 727, 346, 685, 680, 417, 912, 110, 488, 577, 834, 848,
        862, 352, 495, 609, 892, 455, 558,  83, 406])