In [1]:
import os
import copy
import random
import gym
import pickle

import numpy as np
import pandas as pd
import tensorflow as tf
tf.keras.backend.set_floatx('float32')

from itertools import permutations
from sklearn.model_selection import KFold, GridSearchCV

from multiprocessing import set_start_method
import multiprocessing as mp

# Put the parent of the current working directory on the import path,
# presumably so the `peal` imports that follow resolve from this checkout.
# `sys` was previously used here without ever being imported (NameError).
import sys

path = os.path.abspath('..')
if path not in sys.path:
    sys.path.append(path)

from peal.agents.default_config import DEFAULT_CONFIG as config
# from peal.agents.dqn import DQNAgent
# from peal.agents.qr_dqn import QuantileAgent
# from peal.agents.multi_head_dqn import MultiHeadDQNAgent
# from peal.agents.discrete_bcq import DiscreteBCQAgent
from peal.agents.kl_control import KLAgent
from peal.algos.kfold import CVS, KFoldCV
from peal.algos.advantage_learner import AdvantageLearner
from peal.algos.behavior_cloning import BehaviorCloning
from peal.algos.density_ratio import VisitationRatioModel
from peal.algos.fqe import FQE

def one_step(seed, data_path='./data/mh/kl/trajs_mh.pkl', nfolds=5, n_splits=5,
             ckpts=None, num_actions=5):
    """Run one seeded, cross-validated KL-control / PEAL experiment.

    For each of ``nfolds`` outer folds produced by ``CVS`` over ``data_path``:

    1. train a KLAgent on the fold's training data, then one KLAgent per
       inner ``KFoldCV`` split;
    2. load the fold's held-out trajectories and fit behavior cloning on the
       training transitions;
    3. for every checkpoint in ``ckpts``, reload all agents, obtain
       ``qtildes`` from ``kf.update_q``, fit an AdvantageLearner on the
       mean-centred ``qtildes`` and run FQE on the held-out trajectories for
       both the KLAgent (column ``'dqn'``) and the advantage learner
       (column ``'peal'``).

    Args:
        seed: Seeds ``random``, NumPy and TensorFlow, and is passed as
            ``random_state`` to ``CVS`` and ``KFoldCV``.
        data_path: Pickled trajectory file to cross-validate over.
        nfolds: Number of outer folds (``CVS``).
        n_splits: Number of inner splits (``KFoldCV``).
        ckpts: Training steps whose checkpoints are evaluated; defaults to
            1000, 2000, ..., 10000.
        num_actions: Number of discrete actions.

    Returns:
        pd.DataFrame indexed by ``(fold, ckpt)`` with columns ``'dqn'`` and
        ``'peal'``, each cell holding the ``.values`` of one FQE run.
    """
    # Seed every RNG source imported by this notebook.
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    if ckpts is None:
        ckpts = (np.arange(10) + 1) * 1000

    # Work on a private copy: mutating the imported DEFAULT_CONFIG would leak
    # fold-specific paths into later folds, later calls, and later seeds run
    # by the same pool worker.
    base_config = copy.deepcopy(config)
    base_config.update({
        'online': False,
        'lr': 5e-4,
        'decay_steps': 50000,
        'max_training_steps': 10000,
        'training_steps_to_checkpoint': 1000,
        'training_steps_to_eval': 100000,
        'hiddens': [64, 64],
        # NOTE(review): key is spelled 'constraint_hyperparm' — confirm this is
        # the name KLAgent actually reads.
        'constraint_hyperparm': 0.1,
    })

    index = pd.MultiIndex.from_product([np.arange(nfolds), ckpts])
    rets = pd.DataFrame(index=index, columns=['dqn', 'peal'])

    print('-'*20, 'start', '-'*20)
    cvs = CVS(data_path, n_splits=nfolds, random_state=seed)
    cvs.split()
    for fold in range(nfolds):
        train_path = cvs.train_paths[fold] + 'trajs.pkl'
        kf = KFoldCV(train_path, n_trajs=None, n_splits=n_splits, shuffle=False, random_state=seed)
        kf.split()

        # Agent trained on the whole training part of this fold.
        print('-'*20, 'training agent', '-'*20)
        main_config = _agent_config(base_config, kf.agent_path, kf.ckpt_path)
        agent = KLAgent(num_actions=num_actions, config=main_config)
        agent.learn()

        # agent_1, ..., agent_K: one agent per inner split.
        print('-'*20, 'training agents', '-'*20)
        for idx in range(kf.n_splits):
            split_agent = KLAgent(
                num_actions=num_actions,
                config=_agent_config(main_config, kf.agent_paths[idx], kf.ckpt_paths[idx]))
            split_agent.learn()

        # Held-out trajectories of this fold, used only by FQE.
        # NOTE: pickle.load can execute arbitrary code — only load trusted files.
        test_path = cvs.test_paths[fold] + 'trajs.pkl'
        with open(test_path, 'rb') as f:
            trajs = pickle.load(f)

        # Behavior cloning on every (transition[0], transition[1]) pair of
        # the training trajectories.
        print('-'*20, 'behavior cloning', '-'*20)
        bc = BehaviorCloning(num_actions=num_actions)
        states = np.array([transition[0] for traj in kf.trajs for transition in traj])
        actions = np.array([transition[1] for traj in kf.trajs for transition in traj])
        bc.train(states, actions)

        for ckpt in ckpts:
            print('-'*20, 'ckpt: ', ckpt, '-'*20)
            ckpt_file = 'offline_kl_{}.ckpt'.format(ckpt)
            agent = KLAgent(num_actions=num_actions, config=main_config)
            agent.load(kf.ckpt_path + ckpt_file)

            split_agents = []
            for idx in range(kf.n_splits):
                split_agent = KLAgent(
                    num_actions=num_actions,
                    config=_agent_config(main_config, kf.agent_paths[idx], kf.ckpt_paths[idx]))
                split_agent.load(kf.ckpt_paths[idx] + ckpt_file)
                split_agents.append(split_agent)
            # The second output (qvalues) is not used by this experiment.
            states, _, qtildes = kf.update_q(split_agents, bc)

            print('-'*20, 'adv learner', '-'*20)
            # Subtract the mean over axis 1 (presumably the action axis — confirm).
            advs = qtildes - qtildes.mean(axis=1, keepdims=True)
            adv_learner = AdvantageLearner(num_actions=num_actions)
            adv_learner._train(states, advs)

            print('-'*20, 'fqe on dqn & peal', '-'*20)
            fqe_dqn = _fit_fqe(agent.greedy_actions, trajs, num_actions)
            fqe_peal = _fit_fqe(adv_learner.greedy_actions, trajs, num_actions)

            rets.loc[(fold, ckpt), 'dqn'] = fqe_dqn.values
            rets.loc[(fold, ckpt), 'peal'] = fqe_peal.values

    return rets


def _agent_config(base_config, persistent_directory, checkpoint_path):
    """Return a deep copy of ``base_config`` pointed at one agent's directories."""
    agent_config = copy.deepcopy(base_config)
    agent_config['persistent_directory'] = persistent_directory
    agent_config['checkpoint_path'] = checkpoint_path
    return agent_config


def _fit_fqe(policy, trajs, num_actions):
    """Train fitted-Q evaluation of ``policy`` on ``trajs`` and return the FQE object."""
    fqe = FQE(policy, num_actions=num_actions, hiddens=[64, 64], activation='tanh',
              max_iter=100, eps=0.0015)
    fqe.train(trajs)
    return fqe

In [None]:
# Run one independent experiment per seed in parallel and pickle the list of
# per-seed result frames returned by `one_step`.
N_SEEDS = 5
N_WORKERS = 5

# NOTE(review): this notebook trains KLAgent on ./data/mh/kl/, yet the results
# are written to discrete_bear/rets_discrete_bear_mh.pkl — looks like a
# copy-paste leftover; confirm the intended destination before re-running.
save_path = './data/mh/discrete_bear/'
os.makedirs(save_path, exist_ok=True)

# The context manager tears the workers down even if `one_step` raises;
# map() blocks until every seed has finished, so `rets` is complete here.
with mp.Pool(N_WORKERS) as pool:
    rets = pool.map(one_step, range(N_SEEDS))

with open(os.path.join(save_path, 'rets_discrete_bear_mh.pkl'), 'wb') as f:
    pickle.dump(rets, f)

---------------------------------------- start start  --------------------
--------------------
-------------------- training agent --------------------
-------------------- training agent --------------------
Loaded trajectories from load path: /home/jupyt/leyuan/SUPRL/data/mh/kl/tmp/774252441/fold0/train/agent/trajs.pkl!
Refresh buffer every 1000000 sampling!


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Loaded trajectories from load path: /home/jupyt/leyuan/SUPRL/data/mh/kl/tmp/91571465/fold0/train/agent/trajs.pkl!


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you ca

-----iteration:  57 target diff:  0.0026299078791195093 values:  -47.77637 ----- 

-----iteration:  48 target diff:  0.007521828453694091 values:  -43.071636 ----- 

-----iteration:  58 target diff:  0.0028622734580266793 values:  -47.676807 ----- 

-----iteration:  49 target diff:  0.006106318320707184 values:  -42.812054 ----- 

-----iteration:  59 target diff:  0.002697293718952061 values:  -47.58132 ----- 

-----iteration:  50 target diff:  0.005841113699873013 values:  -42.553936 ----- 

-----iteration:  60 target diff:  0.002327099654219986 values:  -47.47473 ----- 

-----iteration:  51 target diff:  0.005452299215780968 values:  -42.29826 ----- 

-----iteration:  61 target diff:  0.002529393635191214 values:  -47.36484 ----- 

-----iteration:  52 target diff:  0.005232985840932566 values:  -42.08378 ----- 

-----iteration:  62 target diff:  0.0026543744545211375 values:  -47.254005 ----- 

-----iteration:  53 target diff:  0.005049820982844556 values:  -41.867466 ----- 

-----it


-----iteration:  7 target diff:  0.0019707782325431166 values:  -48.768738 ----- 

-----iteration:  8 target diff:  0.00198031869710119 values:  -48.68766 ----- 

-----iteration:  9 target diff:  0.0021690658687595547 values:  -48.671364 ----- 

-----iteration:  0 target diff:  0.9111728371249166 values:  -51.2384 ----- 

-----iteration:  10 target diff:  0.0016389300083064998 values:  -48.752323 ----- 

-----iteration:  1 target diff:  0.0029777286146075808 values:  -51.177586 ----- 

-----iteration:  11 target diff:  0.0024262028962459456 values:  -48.724262 ----- 

-----iteration:  2 target diff:  0.0023354709511557276 values:  -51.21978 ----- 

-----iteration:  12 target diff:  0.0013476507008457362 values:  -48.625618 ----- 

-------------------- ckpt:  3000 --------------------
Loaded trajectories from load path: /home/jupyt/leyuan/SUPRL/data/mh/kl/tmp/91571465/fold0/train/agent/trajs.pkl!
Refresh buffer every 1000000 sampling!
-------------------- fqe on dqn & sale ------------

-----iteration:  4 target diff:  0.001363880593451003 values:  -48.9022 ----- 

-------------------- ckpt:  6000 --------------------
Loaded trajectories from load path: /home/jupyt/leyuan/SUPRL/data/mh/kl/tmp/91571465/fold0/train/agent/trajs.pkl!
Refresh buffer every 1000000 sampling!
-------------------- fqe on dqn & sale --------------------


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

-----iteration:  0 target diff:  0.9184778092722697


-----iteration:  16 target diff:  0.0021306011510334832 values:  -47.82837 ----- 

-----iteration:  0 target diff:  0.9127219686963597 values:  -50.651497 ----- 

-----iteration:  17 target diff:  0.002513289912457839 values:  -47.81375 ----- 

-----iteration:  1 target diff:  0.004917974042667133 values:  -50.63364 ----- 

-----iteration:  18 target diff:  0.0020638418253410497 values:  -47.788334 ----- 

-----iteration:  2 target diff:  0.0022428970737552387 values:  -50.676826 ----- 

-----iteration:  19 target diff:  0.0024745269382481813 values:  -47.867153 ----- 

-----iteration:  3 target diff:  0.003284652860792506 values:  -50.66984 ----- 

-----iteration:  4 target diff:  0.0016994275644209042 values:  -50.745533 ----- 

-----iteration:  20 target diff:  0.0023651422121901172 values:  -47.837605 ----- 

-----iteration:  5 target diff:  0.0037157993621291453 values:  -50.76586 ----- 

-----iteration:  6 target diff:  0.002014249437973438 values:  -50.826134 ----- 

-----itera

-----iteration:  10 target diff:  0.001985107032057846 values:  -51.461388 ----- 

-----iteration:  11 target diff:  0.0021529424460685473 values:  -51.48136 ----- 
-----iteration: 
 2 target diff:  0.0016411577923965334 values:  -47.20119 ----- 

-----iteration:  12 target diff:  0.0016326608361822996 values:  -51.563374 ----- 
-----iteration: 
 3 target diff:  0.0019344437983614443 values:  -47.12812 ----- 

-----iteration:  13 target diff:  0.001774881566834684 values:  -51.56757 ----- 

-----iteration:  4 target diff:  0.0015214231950550052 values:  -47.032463 ----- 

-----iteration:  14 target diff:  0.0020466008913533522 values:  -51.502678 ----- 

-----iteration:  5 target diff:  0.0016822501732732585 values:  -46.994865 ----- 

-----iteration:  15 target diff:  0.0016585191146725312 values:  -51.406128 ----- 

-----iteration:  6 target diff:  0.0019627129870558777 values:  -46.97508 ----- 

-----iteration:  7 target diff:  0.0012790460109426865 values:  -46.943043 ----- 

-----


-----iteration:  0 target diff:  0.9120704562465765 values:  -51.504753 ----- 

-----iteration:  1 target diff:  0.006324454502637083 values:  -51.541878 ----- 

-----iteration:  2 target diff:  0.0027167334359829254 values:  -51.48101 ----- 

-----iteration:  3 target diff:  0.0027901241557260234 values:  -51.45444 ----- 

-----iteration:  4 target diff:  0.0033104565721850106 values:  -51.443733 ----- 

-----iteration:  5 target diff:  0.0016999744508440647 values:  -51.37751 ----- 

-----iteration:  6 target diff:  0.0018134561424615861 values:  -51.43711 ----- 

-----iteration:  7 target diff:  0.0022618703098555046 values:  -51.359524 ----- 

-----iteration:  8 target diff:  0.0022514333908860516 values:  -51.376926 ----- 

-----iteration:  9 target diff:  0.0015988545847271402 values:  -51.402107 ----- 

-----iteration:  10 target diff:  0.001584659969404036 values:  -51.543636 ----- 

-----iteration:  11 target diff:  0.002389652355586579 values:  -51.586662 ----- 

-----iterat