In [None]:
import csv
import os
import numpy as np
from envs.utils import printoptions
from matplotlib.backends.backend_pdf import PdfPages


%pylab inline
import matplotlib.pyplot as plt


def get_stats(algorithm="mceirl", env="room", spec="default", comb="use_prior", param_tuned='k', path="./results/tuning-mceirl", temp_index=0):
    """Aggregate reward statistics per value of one tuned hyperparameter.

    Scans ``path`` for result files whose names contain ``algorithm``, ``env``,
    ``spec``, ``comb`` and the token ``-<param_tuned>=``. For each such file the
    parameter value is parsed from the filename (text after ``-<param_tuned>=``
    up to the next ``-``), and one reward per CSV data row is read from the
    bracketed, comma-separated list in the second column.

    Args:
        algorithm, env, spec, comb: substrings that must all occur in the filename.
        param_tuned: hyperparameter name encoded in the filename as
            ``-<param_tuned>=<value>``.
        path: directory holding the result files.
        temp_index: position in the bracketed list to read; this notebook uses
            0, 1, 2 for temperatures 0, 0.1 and 1.

    Returns:
        ndarray of shape (n_files, 3) with rows ``[param_value, mean, std]``,
        sorted by ``param_value``; shape (0, 3) when no file matches.
    """
    token = "-" + param_tuned + "="
    required = (algorithm, env, spec, comb, token)
    results_list = []
    for fname in os.listdir(path):
        # Filter on the same "-<param>=" token that is parsed below; matching a
        # bare "-<param>" would let float() fail on an unrelated substring.
        if not all(part in fname for part in required):
            continue

        with open(os.path.join(path, fname), 'rt', newline='') as f:
            rows = list(csv.reader(f))

        # The first line holds column names, e.g. [seed, true_r, final_r].
        # Blank lines come back as empty rows and are skipped.
        list_rewards = []
        for res in rows[1:]:
            if not res:
                continue
            fields = res[1].replace(']', '').replace('[', '').replace(' ', '').split(',')
            list_rewards.append(float(fields[temp_index]))
        list_rewards = np.asarray(list_rewards)

        param_val = fname.split(token, 1)[-1].split('-')[0]
        results_list.append([float(param_val), np.mean(list_rewards), np.std(list_rewards)])

    if not results_list:
        # Keep the 2-D (n, 3) shape so callers can still index [:, 0].
        return np.empty((0, 3))
    results_list = np.asarray(results_list)
    # Sort rows by the parameter value.
    return results_list[results_list[:, 0].argsort()]

In [None]:
def plot_params_multiple_one_subplot(stats_list_per_env, ax,
                         c='b', color_list=('blue', 'orange', '#5177d6', '#ffe500', 'deepskyblue', 'coral'),
                         env_names=(' room', ' room', ' train', ' train', ' batteries', ' batteries'),
                         comb=('Bayesian,', 'Additive,', 'Bayesian,', 'Additive,', 'Bayesian,', 'Additive,'),
                         title='',
                         current_subplot=None,
                         total_subplots=None,
                         y_min=None,
                         y_max=None):
    """Plot mean reward against log2(parameter value) for several runs on one axes.

    Every drawing call goes through ``ax`` rather than the pyplot state machine,
    so the output does not depend on ``ax`` being the current axes.

    Args:
        stats_list_per_env: list of groups. Each group is a list of arrays whose
            column 0 is the swept parameter value (plotted as log2 on x) and
            column 1 the mean, as returned by ``get_stats``. Column 2 is unused.
        ax: matplotlib Axes to draw on.
        c: color for every curve when ``color_list`` is None.
        color_list, env_names, comb: per-curve color and legend-label parts. The
            label of curve i is ``comb[i] + env_names[i]``. They are indexed by
            position within a group, so each must be at least as long as the group.
        title: panel title; omitted when empty.
        current_subplot: panel index. Panel 1 gets the x-label; panel 0 gets
            the y-label and a legend sorted by label text.
        total_subplots: unused; kept for backward compatibility.
        y_min, y_max: y-limits, defaulting to 0.45 and 1.05.
    """
    if y_min is None:
        y_min = 0.45
    if y_max is None:
        y_max = 1.05

    for stats_list in stats_list_per_env:
        for i, stats in enumerate(stats_list):
            if color_list is not None:
                c = color_list[i]
            x = np.log2(stats[:, 0])
            ax.scatter(x, stats[:, 1], color=c, edgecolor=c, s=40, label=comb[i] + env_names[i])
            ax.plot(x, stats[:, 1], color=c)

    ax.set_ylim(y_min, y_max)
    ax.tick_params(axis='both', labelsize='large')

    # Label every third parameter value of the first curve. Positions and labels
    # come from the same array so they always line up.
    tick_values = stats_list_per_env[0][0][::3, 0]
    ax.set_xticks(np.log2(tick_values))
    ax.set_xticklabels([str(v) for v in tick_values])

    if current_subplot == 1:
        ax.set_xlabel("Standard deviation", fontsize=21)
    if current_subplot == 0:
        ax.set_ylabel("Fraction of max R", fontsize=19)
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            # Sort both labels and handles by the label text.
            labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
            ax.legend(handles, labels, loc="best", fontsize=12, handletextpad=-0.4)
    if title != '':
        ax.set_title(title, fontsize=24)


def plot_params_multiple(env_lists_per_t, titles_list=None, figname='', y_min=None, y_max=None,
                         out_dir='results'):
    """Draw one panel per entry of ``env_lists_per_t`` side by side and save it as a PDF.

    Args:
        env_lists_per_t: per-panel inputs; each is passed as ``stats_list_per_env``
            to ``plot_params_multiple_one_subplot``.
        titles_list: per-panel titles (default: no titles).
        figname: suffix appended to the output file name.
        y_min, y_max: y-limits forwarded to every panel.
        out_dir: output directory. The file written is
            ``<out_dir>/prior_vs_addition_all<figname>.pdf``.
    """
    n_panels = len(env_lists_per_t)
    if titles_list is None:
        titles_list = [''] * n_panels

    fig = plt.figure(figsize=(5 * n_panels, 3.4))
    for j, stats_list in enumerate(env_lists_per_t):
        # plt.subplot also makes this axes current, for any pyplot-level calls
        # made by the per-panel helper.
        ax = plt.subplot(1, n_panels, j + 1)
        plot_params_multiple_one_subplot(stats_list, ax, title=titles_list[j], current_subplot=j,
                                         total_subplots=n_panels, y_min=y_min, y_max=y_max)
    # tight_layout recomputes left/right/bottom/top, so the former
    # fig.subplots_adjust(top=1.1) had no effect and is dropped.
    fig.tight_layout()

    # The context manager closes the PDF even if savefig raises.
    with PdfPages(os.path.join(out_dir, 'prior_vs_addition_all' + figname + '.pdf')) as pp:
        pp.savefig(fig)

    plt.show()

# Adding rewards vs. using a prior: room, train and batteries envs

In [None]:
# Stats for the two ways of combining rewards ("use_prior" and "add_rewards")
# in the room, train and batteries envs, as a function of k.
# The list order (env-major, use_prior before add_rewards) must match the
# default comb/env_names/color lists of plot_params_multiple_one_subplot.
PRIOR_VS_ADD_DIR = "./results/use_prior_vs_add_r"
PRIOR_VS_ADD_RUNS = [(env, comb)
                     for env in ("room", "train", "batteries")
                     for comb in ("use_prior", "add_rewards")]


def prior_vs_add_stats(temp_index):
    """Return ``[[get_stats(...) for each (env, comb) run]]`` for one temperature column."""
    return [[get_stats("mceirl", env, "default", comb, "k", PRIOR_VS_ADD_DIR, temp_index=temp_index)
             for env, comb in PRIOR_VS_ADD_RUNS]]


stats_list_per_env_t0 = prior_vs_add_stats(0)    # temp=0 (rational agent)
stats_list_per_env_t01 = prior_vs_add_stats(1)   # temp=.1
stats_list_per_env_t1 = prior_vs_add_stats(2)    # temp=1

In [None]:
# One panel per temperature, using the stats lists built in the cell above.
plt.rcParams.update({"font.family": "Times New Roman"})
env_lists_per_t = [
    stats_list_per_env_t0,
    stats_list_per_env_t01,
    stats_list_per_env_t1,
]
titles_list = ['temperature = 0', 'temperature = 0.1', 'temperature = 1']
plot_params_multiple(env_lists_per_t, titles_list=titles_list, figname='_x3')

# Horizon

In [None]:
# Horizon sweep (result files tagged "-f=True", parameter H). Every temperature
# uses the same env order, so a given color/label means the same env in every
# panel. The single-panel plot further down passes env_names in this order
# (train, room, batteries, apples).
HORIZON_DIR = "./results/horizon-fixed"
HORIZON_ENVS = ("train", "room", "batteries", "apples")


def horizon_stats(temp_index):
    """Return ``[[get_stats(...) for each env in HORIZON_ENVS]]`` for one temperature column."""
    return [[get_stats("mceirl", env, "default", "-f=True", "H", HORIZON_DIR, temp_index=temp_index)
             for env in HORIZON_ENVS]]


stats_list_per_env_t0 = horizon_stats(0)
stats_list_per_env_t01 = horizon_stats(1)
stats_list_per_env_t1 = horizon_stats(2)

In [None]:
# One panel per temperature for the horizon sweep, with a lower y-limit of 0.3.
# NOTE(review): plot_params_multiple does not forward env_names/comb, so the
# legend uses the helper's default "Bayesian,/Additive, room/train/batteries"
# labels, which do not describe these horizon runs — confirm before publishing.
plt.rcParams.update({"font.family": "Times New Roman"})
env_lists_per_t = [
    stats_list_per_env_t0,
    stats_list_per_env_t01,
    stats_list_per_env_t1,
]
titles_list = ['temperature = 0', 'temperature = 0.1', 'temperature = 1']
plot_params_multiple(env_lists_per_t, titles_list=titles_list, figname='_horizon_x3', y_min=0.3)

In [None]:
# Single-panel temperature-0 horizon plot with per-env labels and colors.
# env_names must follow the env order of stats_list_per_env_t0
# (train, room, batteries, apples).
fig = plt.figure(figsize=(5, 2.6))
ax = plt.subplot(1, 1, 1)
plot_params_multiple_one_subplot(stats_list_per_env_t0, ax, title='', current_subplot=0,
                                 total_subplots=1, y_min=0.45, y_max=1.05,
                                 env_names=['train', 'room', 'batteries', 'apples'],
                                 comb=['', '', '', ''],
                                 color_list=['green', 'orange', '#5177d6', 'firebrick'])
# current_subplot=0 leaves the x-label unset, so set it here.
ax.set_xlabel("Horizon", fontsize=19)
# Replace the helper's legend with one anchored outside the top-right corner.
ax.legend(bbox_to_anchor=(1.05, 1.05), fontsize=12)
fig.tight_layout()

# The context manager closes the PDF even if savefig raises.
with PdfPages('results/horizon_t0.pdf') as pp:
    pp.savefig(fig)
plt.show()