In [1]:
import pandas as pd
import numpy as np
from d3rlpy.algos import DiscreteCQL
from d3rlpy.algos import DiscreteBCQ
from sklearn.preprocessing import MinMaxScaler
import os, glob
import re

In [2]:
# Name of the MDP whose penalised target policy provides the set of states.
mdp_name = "MDP_aug4687"

# Load the target policy table; the 'state' column is handled as a string so
# that its individual characters can be addressed by position.
# NOTE(review): if state codes can start with '0', read_csv may parse them as
# integers and drop the leading zeros before the cast below — confirm.
policy_csv_path = 'policies_penalised/policy_{}.csv'.format(mdp_name)
target_policy = pd.read_csv(policy_csv_path)
target_policy['state'] = target_policy['state'].astype('str')

# Build the observation matrix: one row per state, one column per character
# of the state string. The width is taken from the first row only.
state_codes = target_policy['state']
num_features = len(target_policy.loc[0, 'state'])
num_states = len(target_policy)
obs = np.empty(shape=(num_states, num_features))
for position in range(num_features):
    # Character at `position` of every state string, cast to float on assignment.
    obs[:, position] = state_codes.str.get(position)


# Export, for every trained run under `quan_path`, the greedy policy with its
# values (policy_<run>.csv) and the full Q table over all actions (Q_<run>.csv),
# both evaluated on the observations in `obs`.
quan_path = 'd3rlpy_results/limited/'
output_dir = 'policies_CQL'
# Used only when the loaded model does not expose its action count.
default_num_actions = 36


def _checkpoint_index(path):
    """Return the integer suffix of a checkpoint file name, or -1 if absent.

    Example: ``.../model_10.pt`` -> 10.
    """
    match = re.search(r'(\d+)\.pt$', os.path.basename(path))
    return int(match.group(1)) if match else -1


os.makedirs(output_dir, exist_ok=True)

# Keep directories only (os.listdir also returns stray files) and sort them so
# runs are processed in a deterministic order.
quan_folders = sorted(
    name for name in os.listdir(quan_path)
    if os.path.isdir(os.path.join(quan_path, name))
)

for folder in quan_folders:
    # Run name = folder name with the '_2021...' suffix stripped.
    run_name = re.sub(r'_2021.*$', "", folder)
    print('******************************************************')
    print('Current run: {}'.format(run_name))
    folder_path = os.path.join(quan_path, folder)

    # Prefer the model_10.pt checkpoint; otherwise fall back to the checkpoint
    # with the highest numeric suffix (glob order is arbitrary, so the last
    # element of the glob result is not necessarily the latest checkpoint).
    models = glob.glob('{}/*.pt'.format(folder_path))
    model_path = os.path.join(folder_path, 'model_10.pt')
    if not os.path.exists(model_path):
        if not models:
            print('No .pt checkpoint found in {}, skipping'.format(folder_path))
            continue
        model_path = max(models, key=_checkpoint_index)

    # Folders whose name contains 'bcq' hold DiscreteBCQ runs; everything else
    # is loaded as DiscreteCQL.
    if 'bcq' in folder:
        algo = DiscreteBCQ.from_json('{}/params.json'.format(folder_path))
    else:
        algo = DiscreteCQL.from_json('{}/params.json'.format(folder_path))
    algo.load_model(model_path)

    # Greedy action for every state and the value of that action.
    cql_policy = algo.predict(obs)
    cql_state_vals = algo.predict_value(obs, cql_policy)

    print("Unique actions in policy: {}".format(np.unique(cql_policy)))

    cql_csv = pd.DataFrame(index = target_policy['state'])
    cql_csv['policy'] = cql_policy
    cql_csv['values'] = cql_state_vals

    cql_csv.to_csv('{}/policy_{}.csv'.format(output_dir, run_name))

    # Number of discrete actions: taken from the loaded model when available.
    # A fixed range(36) would drop action 36, which the policies do select.
    # The fallback covers at least every action the policy actually chose.
    num_actions = getattr(algo, 'action_size', None)
    if num_actions is None:
        num_actions = max(default_num_actions, int(np.max(cql_policy)) + 1)

    # One row of Q values per action, then transposed to (state, action).
    q_vals = np.array([
        algo.predict_value(obs, np.full(len(obs), act))
        for act in range(num_actions)
    ]).transpose()

    q_vals = pd.DataFrame(
        q_vals,
        index=pd.Index(target_policy['state'].values, name='state'),
    )
    q_vals.to_csv('{}/Q_{}.csv'.format(output_dir, run_name))

******************************************************
Current run: bcq_q_qr_bndo
Unique actions in policy: [ 0  1  2  3  4  6  7  8  9 11 14 15 17 18 20 21 22 23 24 26 27 28 31 32
 33 34 36]
******************************************************
Current run: lim_q_iqn
Unique actions in policy: [ 2  3  4  5  6  8  9 11 13 14 16 20 21 22 23 24 27 28 29 33 36]
******************************************************
Current run: lim_q_mean
Unique actions in policy: [ 0  2  4  5  6  7  8  9 10 12 13 14 15 16 20 21 22 23 24 25 28 29 30 31
 33 34 36]
******************************************************
Current run: lim_q_qr
Unique actions in policy: [ 0  2  3  4  8  9 10 11 13 14 15 16 17 19 20 22 23 24 27 28 31 32 33 34
 36]
******************************************************
Current run: lim_q_qr_bndo
Unique actions in policy: [23 24]
******************************************************
Current run: lim_q_qr_bndo_alph0.5
Unique actions in policy: [21 22 23 24]
***********************