In [None]:
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from config import *
import operator
import math
import os

In [None]:
# Global plotting configuration: seaborn theme, LaTeX text rendering, and the
# keyword arguments that are shared by the plt.savefig calls in this notebook.
# # plt.rcParams.update({
# #     "pgf.texsystem": "lualatex"
# # })
font_scale = 2  # forwarded to sns.set_theme below
plt.rcParams['pgf.texsystem'] = 'pdflatex'
sns.set_theme(style="whitegrid",font_scale=font_scale)
# fontsize = 35
# plt.rc("font", **{"family": "serif", "serif": ["Times"]})#, "size" : fontsize})
# plt.rc("text", usetex=True)

# plt.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}'] #for \text command
# The same preamble (amsmath) is used for the pgf backend and for usetex rendering.
plt.rcParams["pgf.preamble"] = r'\usepackage{amsmath}'
plt.rcParams.update({"text.latex.preamble": plt.rcParams["pgf.preamble"] })
plt.rcParams['text.usetex'] = True
# Individual savefig settings, collected into save_args below.
save_format = 'pgf'  # also used as the file extension of the saved figures
backend = 'pgf'
transparent = True
pad_inches = 0.1
bbox_inches = 'tight'
# Keyword arguments unpacked into plt.savefig(..., **save_args).
save_args = {
    'format' : save_format,
    'backend' : backend,
    'transparent' : transparent,
    'pad_inches' : pad_inches,
    'bbox_inches' : bbox_inches,
}

In [None]:
# Experiment constants used throughout the notebook.
# BASE_OUTPUT_DIR is not defined in this notebook; presumably it is provided by
# `from config import *` -- TODO confirm.
BASE_DIR=BASE_OUTPUT_DIR
SUBFAMILY_SIZE=10  # default `subfamsize` path component in load_results
SEED=2  # default seed for the single-seed loaders
SEEDS=list(range(2,12))  # seeds 2..11, used by the multi-seed aggregations
ENV_NAMES = ['dpm', 'obstacles-10-2', 'avoid', 'obstacles-8-3', 'rover', 'network']
# MINIMIZING[i] is True when ENV_NAMES[i] has a minimisation objective
# (the two lists are paired by position with zip).
MINIMIZING = [False, True, True, True, False, False]

In [None]:
SEEDS

## If statistics.pickle does not exist, then execute the code under 'gather statistics' first!

In [None]:
# Load the pre-computed per-environment statistics.
# NOTE: pickle.load executes arbitrary code from the file -- only load a
# statistics.pickle that was produced by this project.
with open("./statistics.pickle", 'rb') as handle:
    statistics = pickle.load(handle)
# Keys are path-like strings; keep only the last path component so that the
# dictionary can be indexed by the plain environment name.
statistics = {key.split("/")[-1] : value for key, value in statistics.items()}

In [None]:
statistics.keys()

In [None]:
# Maps every environment name to True when its objective is minimised.
ENV_TO_MINIMIZING = dict(zip(ENV_NAMES, MINIMIZING))

def get_env_str(env : str):
    """Return the LaTeX display name of an environment.

    The name is prefixed with a down arrow for minimising environments and an
    up arrow for maximising ones (looked up in ``ENV_TO_MINIMIZING``).

    * ``'dpm'`` is upper-cased (``DPM``).
    * Names containing ``-`` are expected to look like ``name-a-b`` and are
      rendered as ``Name(a, b)``.
    * Every other name is capitalised.

    Raises:
        KeyError: if ``env`` is not a key of ``ENV_TO_MINIMIZING``.
        IndexError: if a dashed name has fewer than three ``-`` separated parts.
    """
    prefix = '$\\downarrow$~' if ENV_TO_MINIMIZING[env] else '$\\uparrow$~'
    suffix = ''
    # suffix = '}'
    # prefix = '\\emph{'
    if env == 'dpm':
        env_id = prefix + env.upper() + suffix
    elif '-' in env:
        # Split once up front: re-using double quotes inside a double-quoted
        # f-string is only valid syntax from Python 3.12 onwards.
        parts = env.split('-')
        name = parts[0][0].upper() + parts[0][1:].lower()
        env_id = prefix + f"{name}({parts[1]}, {parts[2]})" + suffix
        # Kept as in the original: every "3" in the label becomes "5".
        env_id = env_id.replace("3", "5") # Obstacles(8,3) is Obstacles(8,5)
    else:
        env_id = prefix + f"{env[0].upper()}{env[1:].lower()}" + suffix
    return env_id

In [None]:
def wrap(x):
    """Return ``x`` formatted as inline LaTeX math, i.e. enclosed in ``$`` signs."""
    return "${}$".format(x)

# Build one row of model-size statistics per environment, keyed by the LaTeX
# display name of the environment; every value is wrapped in math mode.
# The column headers use the macros \family, \obsset and \Act, which are
# presumably defined in the LaTeX document that includes the table -- TODO confirm.
dimensions = {}
for env in ENV_NAMES:
    dimensions[get_env_str(env)] = {
        '$|\\family|$' : wrap(statistics[env]['family_size']),
                '$|S|$' : wrap(statistics[env]['max_num_states']),
        '$|\\obsset|$' : wrap(statistics[env]['num_observations']), 
        '$|\\Act|$' : wrap(statistics[env]['num_actions']),

        
        }
    # Plain-text echo of the same numbers for a quick sanity check.
    print(f"nS:{statistics[env]['max_num_states']}\tnO:{statistics[env]['num_observations']}\t\tnA:{statistics[env]['num_actions']}", env, sep='\t')

In [None]:
def significant(x, p):
    """Round ``x`` to ``p`` significant digits.

    Entries that are zero or non-finite have no usable magnitude; for those a
    stand-in magnitude of ``10**(p-1)`` is used, which yields a scale of 1.
    """
    has_magnitude = np.isfinite(x) & (x != 0)
    magnitude = np.where(has_magnitude, np.abs(x), 10**(p-1))
    # Power of ten that moves the p-th significant digit to the units place.
    exponent = p - 1 - np.floor(np.log10(magnitude))
    scale = 10 ** exponent
    return np.round(x * scale) / scale

In [None]:
print(pd.DataFrame(dimensions).T.style.to_latex(hrules=True))

In [None]:
# Root directory for all generated figures; the heatmap and line-plot
# sub-directories are created up front (no error if they already exist).
OUTPUT_DIR = "./plot_builder_output-pgf"
os.makedirs(f"{OUTPUT_DIR}/heatmaps", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/lineplot", exist_ok=True)

# heatmap func

In [None]:
import random


def make_heatmap(results : dict, our_results : dict, env : str, title : str, baseline : str, union_results : dict = None, minimizing = True, add_whole_family_value = True):
    """Draw and save a train-vs-test heatmap comparing baseline policies with ours.

    Rows are, in order: one row per entry of ``results``, optionally the
    union baseline, and finally our result. Columns are the members of the
    sub-family, optionally followed by a whole-family column.

    Args:
        results: maps a training key to a sequence whose second-to-last item is
            a dict of evaluation values and whose last item is the whole-family
            value (or ``None``).
        our_results: dict with the keys ``'ours'`` (one row of values) and
            ``'whole_family'``.
        env: environment name, used in the output file name.
        title: figure title (also printed).
        baseline: baseline name; selects the union row label and is part of the
            output file name.
        union_results: optional dict with ``'subfamily'`` (list or dict of
            values) and ``'whole_family'``.
        minimizing: True when lower values are better.
        add_whole_family_value: append the whole-family column when True.

    In every row the worst sub-family value is framed in blue; the row whose
    worst value is best is framed in green, as is the best whole-family value.
    The figure is written to ``{OUTPUT_DIR}/heatmaps/{env}-{baseline}.{save_format}``.
    """
    print(title)
    values = np.array([np.ravel(list(item[-2].values())) for train, item in results.items()])
    values[~np.isfinite(values)] = np.nan
    if union_results:
        ur = union_results['subfamily']
        if isinstance(ur, dict):
            ur = list(ur.values())
        values = np.vstack([values, ur])
    values = np.vstack([values, our_results['ours']])
    if add_whole_family_value:
        entire_family_values = np.array([item[-1] if item[-1] is not None else np.nan for train, item in results.items()] + ([union_results['whole_family']] if union_results else []) + [our_results['whole_family']])
        values = np.hstack([values, entire_family_values[..., None]])

    # Reverse the colour map when minimising (selection kept from the original).
    colormap = sns.cm.rocket_r if minimizing else sns.cm.rocket

    plt.figure(figsize=(16,9))

    ax = sns.heatmap(values, annot=True, cbar=False, vmin=np.nanmin(values), vmax=np.nanmax(values), cmap=colormap, fmt='.4g')

    subfamily_size = len(results.keys())

    cmp = operator.le if minimizing else operator.ge
    best_value = np.inf if minimizing else -np.inf
    # Stays None when no row has a comparable (non-NaN) worst value.
    best_rectangle = None
    for r in range(values.shape[0]):
        row_values = values[r][:-1] if add_whole_family_value else values[r]
        if np.isfinite(row_values).any():
            # The worst value of the row: the maximum when minimising.
            idx = np.nanargmax(row_values) if minimizing else np.nanargmin(row_values)
        else:
            # No finite value in this row: an arbitrary cell is framed.
            idx = random.randint(0, len(row_values)-1)
        if cmp(row_values[idx], best_value):
            best_value = row_values[idx]
            best_rectangle = (idx, r)
        ax.add_patch(Rectangle((idx, r),1,1, fill=False, edgecolor='blue', lw=3))
    if best_rectangle is not None:
        ax.add_patch(Rectangle(best_rectangle,1,1, fill=False, edgecolor='green', lw=3))

    if add_whole_family_value:
        family_column = values[:, -1]
        # nanargmin / nanargmax raise ValueError on an all-NaN column.
        if not np.isnan(family_column).all():
            best_family_idx = np.nanargmin(family_column) if minimizing else np.nanargmax(family_column)
            ax.add_patch(Rectangle((subfamily_size, best_family_idx),1,1, fill=False, edgecolor='green', lw=3))

    plt.yticks(rotation=0)
    member_labels = ["$M_{" + f"{i+1}" + "}$" for i in range(subfamily_size)]
    xticks = list(member_labels)
    if add_whole_family_value:
        # Raw string: "\m" is not a valid escape sequence in a normal literal.
        xticks += [r"$\mathcal{M}$"]
    ax.set_xticklabels(xticks)
    prefixU = r'\textsc{gd-U}' if 'gd' in baseline.lower() else r'\textsc{Saynt-U}' if 'saynt' in baseline.lower() else 'WHAT?'
    ax.set_yticklabels(member_labels + ([prefixU] if union_results else []) + [r'\textsc{rfPG-S}'])
    ax.set_title(title)
    plt.tight_layout()
    print(env)
    plt.savefig(f"{OUTPUT_DIR}/heatmaps/{env}-{baseline}.{save_format}", **save_args)

In [None]:
import math
import random


def make_lineplot(results : dict, title : str = 'Placeholder', minimizing = True, type_of_plot = 'family_trace', use_time_x_axis=False):
    """Plot the rfPG trace against the random-rfPG trace for a single run.

    Args:
        results: dict with the keys ``'gd-normal'`` and ``'gd-random'``; each
            entry provides ``'plot_times'`` and the series named by
            ``type_of_plot``.
        title: plot title; also used as the output file name.
        minimizing: True when lower values are better.
        type_of_plot: key of the series that is plotted.
        use_time_x_axis: plot against time instead of the iteration index. The
            random run's times are rescaled so that its last time stamp matches
            the maximum time of the normal run.

    The best point of each curve is annotated, and the figure is saved to
    ``{OUTPUT_DIR}/lineplot/{title}.{save_format}`` and shown. ``results`` is
    not modified.
    """
    fig = plt.figure(dpi=300)
    ax = fig.gca()
    plt.title(f"{title} ({'lower' if minimizing else 'higher'} is better)")
    plt.ylabel("Worst family member value")

    normal_times = results['gd-normal']['plot_times']
    normal_trace = results['gd-normal'][type_of_plot]
    random_trace = results['gd-random'][type_of_plot]

    max_time_normal = max(normal_times)
    max_time_random = max(results['gd-random']['plot_times'])

    # Rescaled copy; the caller's dict is left untouched.
    random_times = (np.array(results['gd-random']['plot_times']) / max_time_random) * max_time_normal

    print(max_time_normal, max_time_random)

    if use_time_x_axis:
        plt.xlabel("Time")
        print(normal_times)
        ax.plot(normal_times, normal_trace, label='rfPG', color='green')
    else:
        plt.xlabel("Iteration")
        ax.plot(normal_trace, label='rfPG', color='green')

    best_index = np.argmin if minimizing else np.argmax

    min_x_normal = best_index(normal_trace)
    min_y_normal = normal_trace[min_x_normal]
    if use_time_x_axis:
        min_x_normal = normal_times[min_x_normal]

    if use_time_x_axis:
        ax.plot(random_times, random_trace, label='Random rfPG', color='red')
    else:
        # The original called exit() here, which terminated the interpreter on
        # the default (iteration-axis) path before the curve could be drawn.
        ax.plot(random_trace, label='Random rfPG', color='red')

    min_x = best_index(random_trace)
    min_y = random_trace[min_x]
    if use_time_x_axis:
        min_x = random_times[min_x]

    mini = min([min_x, min_x_normal])
    maxi = max([min_x, min_x_normal])

    def get(x):
        # Horizontal label offset (in points): push the two annotations apart,
        # using a random offset when both optima share the same x position.
        if math.isclose(mini, maxi) and math.isclose(x, maxi):
            return random.randint(-50, 50)
        elif math.isclose(x, maxi):
            return 50
        elif math.isclose(x, mini):
            return -50
        else:
            return 0

    vertical_offset = 50 if minimizing else -50

    ax.annotate(f"{min_y:.2f}",
            xy=(min_x, min_y), xycoords='data',
            xytext=(get(min_x), vertical_offset), textcoords='offset points',
            arrowprops=dict(facecolor='red', shrink=0),
            horizontalalignment='center', verticalalignment='bottom')

    ax.annotate(f"{min_y_normal:.2f}",
                xy=(min_x_normal, min_y_normal), xycoords='data',
                xytext=(get(min_x_normal), vertical_offset), textcoords='offset points',
                arrowprops=dict(facecolor='green', shrink=0),
                horizontalalignment='center', verticalalignment='bottom')

    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/lineplot/{title}.{save_format}", **save_args)
    plt.show()

In [None]:
def plot_learning_curve_single_seed(envs=ENV_NAMES, minimizings=MINIMIZING, seed=SEED):
    """Draw the single-run learning curve of every environment for one seed.

    Args:
        envs: environment names.
        minimizings: per-environment flags, paired with ``envs`` by position;
            True when lower values are better.
        seed: seed whose ``gd-experiment.pickle`` is plotted.

    Environments whose result file does not exist are reported and skipped.
    """
    for env, minimizing in zip(envs, minimizings):
        # Same "seed{seed}" directory layout as the other gd-experiment loaders
        # in this notebook (the original used a bare "{seed}" component here).
        path = f"{BASE_DIR}/{env}/seed{seed}/gd-experiment.pickle"
        try:
            with open(path, 'rb') as handle:
                results = pickle.load(handle)
        except FileNotFoundError as fnfe:
            print(fnfe)
            continue
        # Plot outside the try block so that only a missing result file is
        # treated as "skip this environment".
        make_lineplot(results, title=env, minimizing=minimizing, use_time_x_axis=True)

In [None]:
import scipy.stats

def plot_learning_curve(envs=ENV_NAMES, minimizings=MINIMIZING, seed=SEED, use_ci=True):
    """Plot, per environment, the mean learning curve over all ``SEEDS``.

    For every seed the ``gd-experiment.pickle`` traces of the normal and the
    random run are loaded, interpolated onto a common time grid (0..3600 with
    360 points) and averaged. The figure is saved to
    ``{OUTPUT_DIR}/lineplot/{env}.{save_format}`` and shown.

    Args:
        envs: environment names.
        minimizings: per-environment flags paired with ``envs`` (iterated
            together with it, not otherwise used).
        seed: unused; kept for backward compatibility (all ``SEEDS`` are used).
        use_ci: shade the 95% Student-t confidence interval of the mean when
            True, otherwise shade mean +/- one standard deviation.
    """
    type_of_plot = 'family_trace'
    num_points = 360
    x_common = np.linspace(0, 3600, num_points)

    def interpolate_curves(xs, ys):
        # Resample every (x, y) curve onto x_common; unusable curves are skipped.
        curves = []
        for this_x, this_y in zip(xs, ys):
            try:
                curves.append(np.interp(x_common, this_x, this_y))
            except ValueError as e:
                print(f"skipping curve that cannot be interpolated: {e}")
        return curves

    def draw_mean_with_band(curves, label, color):
        # Plot the mean curve and shade its uncertainty band.
        mean = np.mean(curves, axis=0)
        plt.plot(x_common, mean, label=label, color=color)
        if use_ci:
            # The samples are the curves (one per seed), so the t-distribution
            # has len(curves) - 1 degrees of freedom -- not one per grid point.
            if len(curves) > 1:
                ci_a, ci_b = scipy.stats.t.interval(0.95, len(curves)-1, loc=mean, scale=scipy.stats.sem(curves, axis=0))
                plt.fill_between(x_common, ci_a, ci_b,  alpha=0.25, color=color)
        else:
            # Standard deviation for both curves (the original used the
            # variance for the random curve only).
            spread = np.std(curves, axis=0)
            plt.fill_between(x_common, mean + spread, mean - spread,  alpha=0.25, color=color)

    for env, minimizing in zip(envs, minimizings):
        rand_X = []
        rand_Y = []
        X = []
        Y = []
        for run_seed in SEEDS:
            try:
                with open(f"{BASE_DIR}/{env}/seed{run_seed}/gd-experiment.pickle", 'rb') as handle:
                    results = pickle.load(handle)
            except FileNotFoundError as fnfe:
                print(fnfe)
                continue

            rand_X.append(np.array(results['gd-random']['plot_times']))
            rand_Y.append(np.array(results['gd-random'][type_of_plot]))

            x = np.array(results['gd-normal']['plot_times'])
            y = np.array(results['gd-normal'][type_of_plot])

            try:
                max_time_normal = max(results['gd-normal']['plot_times'])
                max_time_random = max(results['gd-random']['plot_times'])
            except ValueError:
                # max() of an empty sequence: this run has no recorded times.
                continue

            # NOTE(review): the original also rescaled the random run's times
            # when max_time_random >= max_time_normal, but only after they had
            # been stored in rand_X, so that rescaling never reached the plot.
            # The effective behaviour is preserved here. Please confirm the
            # intended normalisation, including whether this branch should
            # instead compress x by max_time_random / max_time_normal.
            if max_time_random < max_time_normal:
                x = (x / max_time_random) * max_time_normal

            X.append(x)
            Y.append(y)

        rand_Y_interp = interpolate_curves(rand_X, rand_Y)
        Y_interp = interpolate_curves(X, Y)
        if not rand_Y_interp or not Y_interp:
            print(f"no learning curves available for {env}; skipping plot")
            continue

        draw_mean_with_band(rand_Y_interp, 'Random', 'red')
        draw_mean_with_band(Y_interp, r'\textsc{rfPG}', 'green')

        plt.title(get_env_str(env))
        plt.ylabel("Robust performance")
        plt.xlabel("Time (seconds)")
        plt.legend()
        plt.tight_layout()
        os.makedirs(f"{OUTPUT_DIR}/lineplot", exist_ok=True)
        plt.savefig(f"{OUTPUT_DIR}/lineplot/{env}.{save_format}", backend='pgf', format='pgf', bbox_inches = 'tight', pad_inches = 0.05)
        plt.show()
        plt.close()

In [None]:
plot_learning_curve()

# DEFAULT FUNC

In [None]:
def load_results(env : str, include_union_results : bool = True, seed=SEED, subfamsize=SUBFAMILY_SIZE, include_gd_results=False, include_saynt_results : bool = True):
    """Load the pickled sub-family experiment results of one environment and seed.

    Args:
        env: environment name (directory below ``BASE_DIR``).
        include_union_results: also load the union-baseline results.
        seed: seed of the run.
        subfamsize: sub-family size component of the path.
        include_gd_results: also load ``subfam-gradient.pickle``.
        include_saynt_results: also load ``subfam-saynt.pickle``.

    Returns:
        ``(saynt, ours, union_results, gradient)``. ``saynt`` and ``gradient``
        are ``None`` when not requested. ``union_results`` is ``None`` when not
        requested, otherwise a dict with the keys ``'saynt'`` and
        ``'gradient'``; each is ``None`` when its file is missing. A dict-valued
        ``'subfamily'`` entry of a union result is replaced by the list of its
        values.

    Raises:
        FileNotFoundError: if a requested sub-family pickle does not exist.

    NOTE: pickle.load executes arbitrary code; only load trusted result files.
    """
    subfam_dir = f"{BASE_DIR}/{env}/subfamsize{subfamsize}/seed{seed}"
    union_dir = f"{BASE_DIR}/{env}/union/seed{seed}"

    def _load(path):
        # Read a single pickle file.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def _load_union(path):
        # Read a union result; a missing file is reported and yields None.
        try:
            loaded = _load(path)
        except FileNotFoundError as e:
            print(e)
            return None
        # Normalise for both baselines: callers aggregate with max()/min(),
        # which would otherwise iterate over the dict keys.
        if isinstance(loaded, dict) and isinstance(loaded.get('subfamily'), dict):
            loaded['subfamily'] = list(loaded['subfamily'].values())
        return loaded

    saynt = _load(f"{subfam_dir}/subfam-saynt.pickle") if include_saynt_results else None
    ours = _load(f"{subfam_dir}/subfam-ours.pickle")
    gradient = _load(f"{subfam_dir}/subfam-gradient.pickle") if include_gd_results else None

    if include_union_results:
        union_results = {
            'saynt' : _load_union(f"{union_dir}/union-saynt.pickle"),
            'gradient' : _load_union(f"{union_dir}/union-gradient.pickle"),
        }
    else:
        union_results = None

    return saynt, ours, union_results, gradient

In [None]:
def create_table(envs : list[str], minimizings : list[bool], include_union_results = True):
    """Build single-seed result tables (seed ``SEED``) for the given environments.

    Args:
        envs: environment names.
        minimizings: per-environment flags paired with ``envs``; True when
            lower values are better.
        include_union_results: include the SAYNT union baseline.

    Returns:
        ``(subfamresults, whlfamresults)``: two dicts keyed by
        ``"ENV (min.|max.)"`` that map a method label to its value on the
        sub-family and on the whole family respectively. The union entries are
        NaN when the union result is unavailable.
    """
    subfamresults = {}
    whlfamresults = {}
    for env, minimizing in sorted(zip(envs, minimizings), key=lambda x : x[1]):
        # Gradient results are not used by this table, so they are not loaded.
        saynt, ours, union_results, _ = load_results(env, include_union_results=include_union_results)

        # load_results returns the union results keyed by baseline; the SAYNT
        # entry is None when its file is missing.
        saynt_union = union_results['saynt'] if union_results else None

        # Worst value over a collection: the maximum when minimising.
        aggregator = lambda x : float((max if minimizing else min)(x))

        env_id = f"{env.upper()} ({'min.' if minimizing else 'max.'})"

        subfamresults[env_id] = {
            # min if minimizing because taking the best FSC out of the 10 FSCs, but worst value for that FSC among the 10 POMDPs evaluated.
            '\\saynt (one-by-one)' : (min if minimizing else max)([aggregator(np.array(list(saynt[k][2].values())).flatten().tolist()) for k in saynt.keys()]),
            '\\saynt (union)' : aggregator(saynt_union['subfamily']) if saynt_union else np.nan,
            '\\ours  (subfamily)'  : aggregator(ours['ours']),
        }

        with open(f"{BASE_DIR}/{env}/seed{SEED}/gd-experiment.pickle", 'rb') as handle:
            rfpg_results = pickle.load(handle)

        whlfamresults[env_id] = {
            '\\saynt (one-by-one)' : aggregator([item[-1] if item[-1] is not None else np.nan for train, item in saynt.items()]),
            '\\saynt (union)' : float(saynt_union['whole_family']) if saynt_union else np.nan,
            '\\ours  (subfamily)'  : float(ours['whole_family']),
            '\\ours  (whole family)' : rfpg_results['gd-normal']['best_worst_value']
        }

    return subfamresults, whlfamresults


In [None]:
from collections import defaultdict

def hotfix(x : np.ndarray, set_inf_to_random_family_value = False, env = None) -> np.ndarray:
    """Identity placeholder: hand ``x`` back unchanged.

    ``set_inf_to_random_family_value`` and ``env`` are accepted but ignored;
    the NaN filtering and infinity replacement that used them are disabled.
    """
    return x

# Suffixes appended directly to the LaTeX method macros when building the table
# column labels (e.g. f'\\saynt{one_by_one_suffix}' yields '\\sayntE').
# Earlier label variants are kept below for reference.
# one_by_one_suffix = ' (enum.)'
# one_by_one_suffix = '-E'
one_by_one_suffix = 'E'
# union_suffix = ' (union)'
# union_suffix = '-U'
union_suffix = 'U'

def create_table_multiple_seeds(envs : list[str], minimizings : list[bool], include_union_results = True, include_gd_results = True, normalize_results = True, include_error = False, seeds=SEEDS):
    """Aggregate sub-family and whole-family results over several seeds.

    Args:
        envs: environment names.
        minimizings: per-environment flags paired with ``envs``; True when
            lower values are better.
        include_union_results: include the union baselines where available.
        include_gd_results: include the gradient (``\\gd``) baselines.
        normalize_results: report values relative to our result of the same
            seed (``ours / x`` when minimising, ``x / ours`` otherwise), so
            that higher is better for every normalised column.
        include_error: append ``\\pm`` the standard error of the mean.
        seeds: seeds to aggregate; seeds whose files cannot be loaded are
            reported and skipped.

    Returns:
        ``(subfamresults, whlfamresults)``: dicts keyed by the LaTeX
        environment name that map a method label to a formatted LaTeX string
        (mean over seeds, optionally with the standard error).

    Raises:
        AssertionError: if no seed could be loaded for an environment, or if a
            sanity check on the loaded values fails.
    """
    subfamresults = defaultdict(dict)
    whlfamresults = defaultdict(dict)

    def statistic_func(x):
        # Format the mean (and optionally the standard error) as LaTeX math.
        mean = np.mean(x)
        if include_error:
            err = scipy.stats.sem(x)
            # Raw f-strings: "\p" is not a valid escape in a normal literal.
            if normalize_results:
                return rf"${mean:.2f} \pm {err:.2f}$"
            return rf"${mean:.4g} \pm {err:.4g}$"
        return f"${mean:.2f}$" if normalize_results else f"${mean:.4g}$"

    for env, minimizing in sorted(zip(envs, minimizings), key=lambda x : x[1]):

        # Worst value over a collection: the maximum when minimising.
        aggregator = lambda x : float((max if minimizing else min)(x))
        pick_best = min if minimizing else max

        def best_enumerated(per_pomdp):
            # Best policy among the per-POMDP ones, each judged by its worst
            # value over the sub-family.
            return pick_best([aggregator(np.array(list(per_pomdp[k][2].values())).flatten().tolist()) for k in per_pomdp.keys()])

        env_id = get_env_str(env)

        subfamresults[env_id] = defaultdict(list)
        whlfamresults[env_id] = defaultdict(list)

        count = 0
        sub_fam_results_norm = None
        whl_fam_results_norm = None

        for seed in seeds:
            try:
                saynt, ours, union_results, gradient = load_results(env, include_union_results=include_union_results, seed=seed, include_gd_results=include_gd_results)

                with open(f"{BASE_DIR}/{env}/seed{seed}/gd-experiment.pickle", 'rb') as handle:
                    rfpg_results = pickle.load(handle)
            except Exception as e:
                # Best effort: a seed with missing or unreadable files is skipped.
                print(e)
                continue

            count += 1

            whl_fam_results_norm = rfpg_results['gd-normal']['best_worst_value']
            assert np.isfinite(whl_fam_results_norm)
            assert whl_fam_results_norm > 0

            sub_fam_results_norm = aggregator(ours['ours'])
            assert np.isfinite(sub_fam_results_norm)
            assert sub_fam_results_norm > 0

            norm_sub = lambda x : (sub_fam_results_norm / x if minimizing else x / sub_fam_results_norm) if normalize_results else x
            norm_whl = lambda x : (whl_fam_results_norm / x if minimizing else x / whl_fam_results_norm) if normalize_results else x

            random_policy_subfam = statistics[env][seed]['evaluations'].max() if minimizing else statistics[env][seed]['evaluations'].min() # get random policy value of worst POMDP inside subfamily
            random_policy_whlfam = statistics[env]['family_value']

            if minimizing:
                assert random_policy_subfam <= random_policy_whlfam, (random_policy_subfam, random_policy_whlfam)
            else:
                assert random_policy_subfam >= random_policy_whlfam, (random_policy_subfam, random_policy_whlfam)

            # union_results is None when include_union_results is False.
            saynt_union = union_results['saynt'] if union_results else None
            gradient_union = union_results['gradient'] if union_results else None

            # --- evaluation on the sub-family ---
            subfamresults[env_id][f'\\saynt{one_by_one_suffix}'] += [norm_sub(best_enumerated(saynt))]
            if saynt_union:
                res = aggregator(saynt_union['subfamily'])
                if not np.isfinite(res):
                    # Fall back to the random-policy value.
                    res = random_policy_subfam
                subfamresults[env_id][f'\\saynt{union_suffix}'] += [norm_sub(res)]

            if include_gd_results:
                subfamresults[env_id][f'\\gd{one_by_one_suffix}'] += [norm_sub(best_enumerated(gradient))]
                if gradient_union:
                    res = aggregator(gradient_union['subfamily'])
                    if not np.isfinite(res):
                        res = random_policy_subfam
                    subfamresults[env_id][f'\\gd{union_suffix}'] += [norm_sub(res)]

            subfamresults[env_id]['\\ours  (subfamily)'] += [norm_sub(aggregator(ours['ours']))]

            # --- evaluation on the whole family ---
            whlfamresults[env_id][f'\\saynt{one_by_one_suffix}'] += [norm_whl(aggregator([item[-1] if item[-1] is not None else random_policy_whlfam for train, item in saynt.items()]))]
            if saynt_union:
                res = float(saynt_union['whole_family'])
                if not np.isfinite(res):
                    res = statistics[env]['family_value']
                    assert np.isfinite(res), res
                whlfamresults[env_id][f'\\saynt{union_suffix}'] += [norm_whl(res)]

            if include_gd_results:
                whlfamresults[env_id][f'\\gd{one_by_one_suffix}'] += [norm_whl(aggregator([item[-1] if item[-1] is not None else random_policy_whlfam for train, item in gradient.items()]))]
                if gradient_union:
                    # NOTE(review): unlike the SAYNT union value above, a
                    # non-finite gradient union value is not replaced by the
                    # random-policy value -- confirm whether that is intended.
                    whlfamresults[env_id][f'\\gd{union_suffix}'] += [norm_whl(float(gradient_union['whole_family']))]

            whlfamresults[env_id]['\\ours  (subfamily)'] += [norm_whl(float(ours['whole_family']))]
            # Stored un-normalised: this is the reference value itself.
            whlfamresults[env_id]['\\ours  (whole family)'] += [rfpg_results['gd-normal']['best_worst_value']]

        assert count > 0, (env)

        # Normalisation constants of the last successfully loaded seed.
        print('sub norm:', sub_fam_results_norm)
        print('whl norm:', whl_fam_results_norm)

        # After normalisation higher is better for every environment; only the
        # raw values of a minimising environment are "lower is better".
        lower_is_better = minimizing and not normalize_results

        for d in (subfamresults, whlfamresults):
            bests = []
            best_number = np.inf if lower_is_better else -np.inf
            for key, value in list(d[env_id].items()):
                # Replace the per-seed list by its formatted summary.
                d[env_id][key] = statistic_func(value)
                average = np.mean(value)
                improves = average <= best_number if lower_is_better else average >= best_number
                if improves:
                    if math.isclose(average, best_number):
                        bests.append(key)
                    else:
                        best_number = average
                        bests = [key]

            for best in bests:
                print(best)

    return subfamresults, whlfamresults


In [None]:
SEEDS[:1]

In [None]:
# subfam, whlfam = create_table(ENV_NAMES, MINIMIZING, include_union_results=True)
_, whlfam = create_table_multiple_seeds(ENV_NAMES, MINIMIZING, include_union_results=True, normalize_results=False, include_error=False, seeds=SEEDS[0:1])

In [None]:
# subfam, whlfam = create_table(ENV_NAMES, MINIMIZING, include_union_results=True)
subfam, whlfam = create_table_multiple_seeds(ENV_NAMES, MINIMIZING, include_union_results=True, normalize_results=True, include_error=True)

In [None]:
def pd_highlight() -> str:
    """Return the Styler ``props`` string used to mark highlighted table cells."""
    latex_highlight_props = "highlight:--rwrap;"
    return latex_highlight_props

In [None]:
def to_latex_str_new(df : pd.DataFrame, highlight=False):
    """Render ``df`` as a LaTeX table with rules, numbers shown with precision 2.

    ``highlight`` is accepted but has no effect.
    """
    styler = df.style.format(precision=2)
    latex_table = styler.to_latex(hrules=True, multicol_align='l')
    return latex_table

In [None]:
def to_latex_str(df : pd.DataFrame):
    """Render ``df`` as a LaTeX table in which the best entry per row is highlighted.

    The row-wise maximum is highlighted in the first three rows and the
    row-wise minimum in all remaining rows; numbers are shown with precision 2.
    """
    rounded = df.round(decimals=2)
    # str.format on a string without placeholders returns it unchanged.
    props_upper = pd_highlight().format(precision=2)
    upper = rounded[:3].style.highlight_max(axis=1, props=props_upper)
    lower = rounded[3:].style.highlight_min(axis=1, props=pd_highlight())
    lower = lower.format(precision=2)
    combined = upper.concat(lower).format(precision=2)
    return combined.to_latex(hrules=True)

In [None]:
print("Whole family results:")

In [None]:
pd.DataFrame(whlfam).T

In [None]:
pd.DataFrame(whlfam).T.iloc[:, :-1]

In [None]:
pd.DataFrame(subfam).T

In [None]:
df1 = pd.DataFrame(whlfam).T.iloc[:, :-1]
last_col = pd.DataFrame(whlfam).T.iloc[:, -1]
df2 = pd.DataFrame(subfam).T

In [None]:
last_col

In [None]:
df = df1.join(df2, lsuffix="_whl", rsuffix="_sub")

In [None]:
new_col = [col + suffix for pair in zip(df1.columns, df2.columns) for col, suffix in zip(pair, ['_sub', '_whl'])]
new_col

In [None]:
len(new_col)

In [None]:
df

In [None]:
# Method labels in table order; used as the 'Method' level of the column
# MultiIndex built by the create_df_for_table_* helpers (referenced by position
# through the MultiIndex codes).
methods = [
f'\\saynt{one_by_one_suffix}',
f'\\saynt{union_suffix}',
  f'\\gd{one_by_one_suffix}',
 f'\\gd{union_suffix}',
 '\\ours']

def create_df_for_table_one(df1, df2, last_col):
    """Combine the result frames into one table with (Method, Synthesis, Evaluation) columns.

    ``df1`` and ``df2`` are joined with the suffixes ``_whl`` and ``_sub``, the
    columns are interleaved per method (``_sub`` first, then ``_whl``) and
    ``last_col`` is appended as the final column. The interleaving builds the
    names from the columns of both frames, so it relies on ``df1`` and ``df2``
    having the same column names in the same order.
    """
    joined = df1.join(df2, lsuffix="_whl", rsuffix="_sub")

    interleaved = []
    for first_name, second_name in zip(df1.columns, df2.columns):
        interleaved.append(first_name + '_sub')
        interleaved.append(second_name + '_whl')

    table = pd.concat([joined[interleaved], last_col], axis=1)

    method_codes = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4]
    synthesis_codes = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
    evaluation_codes = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]
    table.columns = pd.MultiIndex(
        levels=[methods, ['Subset', 'Whole'], ['Subset', 'Whole']],
        codes=[method_codes, synthesis_codes, evaluation_codes],
        names=['Method', 'Synthesis', 'Evaluation'],
    )
    return table

def create_df_for_table_two(df1, df2, last_col):
    """Combine the result frames into one table with (Synthesis, Evaluation, Method) columns.

    ``df1`` and ``df2`` are joined with the suffixes ``_whl`` and ``_sub``;
    ``last_col`` is appended as the final column. The joined column names are
    printed.

    The MultiIndex codes describe the columns as interleaved per method
    (evaluated on the subset first, then on the whole family), so the joined
    columns are reordered accordingly before the labels are attached. The
    reordering builds the names from the columns of both frames and therefore
    relies on ``df1`` and ``df2`` having the same column names in the same order.
    """
    df = df1.join(df2, lsuffix="_whl", rsuffix="_sub")
    print(df.columns)
    # Without this reorder the join yields all "_whl" columns followed by all
    # "_sub" columns, which does not match the interleaved codes below.
    new_col = [col + suffix for pair in zip(df1.columns, df2.columns) for col, suffix in zip(pair, ['_sub', '_whl'])]
    df = df[new_col]
    df = pd.concat([df, last_col], axis=1)
    idx = pd.MultiIndex(levels=[['Subset', 'Whole'], ['Subset', 'Whole'], methods],
                codes=[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4]],
                names=['Synthesis', 'Evaluation', 'Method'])
    df.columns = idx
    return df

def create_df_for_table_three(df1, df2, last_col, flip_synth_eval=False):
    """Build the joint table grouped by evaluation set instead of by method.

    Columns are ordered as all ``_sub`` columns, then all ``_whl`` columns,
    then ``last_col``. Baseline column names are derived from the module-level
    ``methods`` list (every entry not containing 'ours'); the rfPG columns use
    the fixed name ``'\\ours  (subfamily)'``.

    Parameters
    ----------
    df1 : pd.DataFrame
        Frame whose columns receive the ``_whl`` suffix in the join.
    df2 : pd.DataFrame
        Frame whose columns receive the ``_sub`` suffix in the join.
    last_col : pd.Series
        Extra column appended at the end.
    flip_synth_eval : bool, default False
        When True the first two index levels are swapped, both in their codes
        and in their names ('Evaluation' first, then 'Synthesis').

    Returns
    -------
    pd.DataFrame
        Eleven columns with a three-level column MultiIndex.
    """
    df = df1.join(df2, lsuffix="_whl", rsuffix="_sub")
    print(df.columns)

    baselines = [m for m in methods if 'ours' not in m.lower()]
    ours_col = '\\ours  (subfamily)'
    new_col = (
        [f"{m}_sub" for m in baselines] + [f"{ours_col}_sub"]
        + [f"{m}_whl" for m in baselines] + [f"{ours_col}_whl"]
    )
    df = df[new_col]
    df = pd.concat([df, last_col], axis=1)

    # Level that is 'Whole' only for the trailing column.
    last_only = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
    # Level that is 'Subset' for the first five columns and 'Whole' afterwards.
    by_block = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    method_codes = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 4]

    if flip_synth_eval:
        codes = [by_block, last_only, method_codes]
        names = ['Evaluation', 'Synthesis', 'Method']
    else:
        codes = [last_only, by_block, method_codes]
        names = ['Synthesis', 'Evaluation', 'Method']

    idx = pd.MultiIndex(levels=[['Subset', 'Whole'], ['Subset', 'Whole'], methods],
                codes=codes,
                names=names)
    df.columns = idx
    return df

def create_df_for_table_four(df1, df2, last_col):
    """Build the joint table with an extra 'Type' (Enum/Union) column level.

    Columns are ordered as all ``_sub`` columns, then all ``_whl`` columns,
    then ``last_col``. Baseline column names are derived from the module-level
    ``methods`` list (every entry not containing 'ours'); the rfPG columns use
    the fixed name ``'\\ours  (subfamily)'``.

    Parameters
    ----------
    df1 : pd.DataFrame
        Frame whose columns receive the ``_whl`` suffix in the join.
    df2 : pd.DataFrame
        Frame whose columns receive the ``_sub`` suffix in the join.
    last_col : pd.Series
        Extra column appended at the end.

    Returns
    -------
    pd.DataFrame
        Eleven columns indexed by (Synthesis, Evaluation, Method, Type).
    """
    df = df1.join(df2, lsuffix="_whl", rsuffix="_sub")
    print(df.columns)

    baselines = [m for m in methods if 'ours' not in m.lower()]
    ours_col = '\\ours  (subfamily)'
    new_col = (
        [f"{m}_sub" for m in baselines] + [f"{ours_col}_sub"]
        + [f"{m}_whl" for m in baselines] + [f"{ours_col}_whl"]
    )
    df = df[new_col]
    df = pd.concat([df, last_col], axis=1)

    idx = pd.MultiIndex(
        levels=[['Subset', 'Whole'], ['Subset', 'Whole'], methods, ['Enum', 'Union', '']],
        codes=[
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],  # Synthesis: only the trailing column is 'Whole'
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],  # Evaluation: first five 'Subset', rest 'Whole'
            [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 4],  # Method
            [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 2],  # Type: Enum/Union for baselines, '' for the rest
        ],
        names=['Synthesis', 'Evaluation', 'Method', 'Type'])
    df.columns = idx
    return df

In [None]:
df = create_df_for_table_one(df1, df2, last_col)
df

In [None]:
# create_df_for_table_two(df1, df2, last_col)

In [None]:
with open(f"{OUTPUT_DIR}/joint.tex", 'w') as file:
    print(to_latex_str_new(df), file=file)

In [None]:
df = create_df_for_table_three(df1, df2, last_col)
pd.concat([df.iloc[:, 0:4], df.iloc[:, 5:]], axis=1)

In [None]:
df = create_df_for_table_four(df1, df2, last_col)
df = pd.concat([df.iloc[:, 0:4], df.iloc[:, 5:-1]], axis=1)
df

In [None]:
with open(f"{OUTPUT_DIR}/joint.tex", 'w') as file:
    print(to_latex_str_new(df), file=file)

In [None]:
subdf = df.iloc[:, :4]
with open(f"{OUTPUT_DIR}/joint-error-1.tex", 'w') as file:
    print(to_latex_str_new(subdf), file=file)
subdf

In [None]:
# Write the trailing columns of `df` to joint-error-2.tex and display them.
# NOTE(review): if `df` is still the 9-column frame left by the table-four cell
# (`pd.concat([df.iloc[:, 0:4], df.iloc[:, 5:-1]], axis=1)`), then `5:` skips
# positional column 4 and keeps only the last four columns, while the
# joint-error-1 cell takes `:4`. Confirm whether `4:` was intended here.
subdf = df.iloc[:, 5:]
with open(f"{OUTPUT_DIR}/joint-error-2.tex", 'w') as file:
    print(to_latex_str_new(subdf), file=file)
subdf

In [None]:
print("".join(["\\cmidrule(lr){" + f"{i}-{i}" + "}" for i in range(1, 13)]))

In [None]:
df = create_df_for_table_four(df1, df2, last_col)
df

In [None]:
df.iloc[:3, :5]

In [None]:
# Highlight the row-wise maximum in the first three rows and the row-wise
# minimum in the remaining rows, then render everything with two decimals.
# `.format(precision=2)` belongs on the Styler; it was previously attached to
# the value returned by `pd_highlight()` because of a misplaced parenthesis.
string = (
    df[:3].style.highlight_max(axis=1, props=pd_highlight())
    .concat(df[3:].style.highlight_min(axis=1, props=pd_highlight()).format(precision=2))
    .format(precision=2)
    .to_latex(hrules=True)
)
print(string)

In [None]:
# Unique method names found in either result table for the first environment.
# dict.fromkeys de-duplicates while keeping first-seen order; list(set(...))
# produced an ordering that could change between kernel sessions.
# NOTE: this rebinds the global `methods` that the create_df_for_table_*
# helpers read when they are called.
methods = list(dict.fromkeys(list(subfam[ENV_NAMES[0].upper()].keys()) + list(whlfam[ENV_NAMES[0].upper()].keys())))
methods

In [None]:
columns = [
 '\\saynt (enum.)_sub',
 '\\saynt (enum.)_whl',
 '\\saynt (union)_sub',
 '\\saynt (union)_whl',
 '\\gd (enum.)_sub',
 '\\gd (enum.)_whl',
 '\\gd (union)_sub',
 '\\gd (union)_whl',
 '\\ours  (subfamily)_sub',
 '\\ours  (subfamily)_whl'
 ]

In [None]:
print(to_latex_str_new(pd.DataFrame(df)))

In [None]:
print(to_latex_str(pd.DataFrame(whlfam).T))

In [None]:
methods = [
 '\\saynt (enum.)',
 '\\saynt (union)',
 '\\gd (enum.)',
 '\\gd (union)',
 '\\ours'
 ]

In [None]:
with open(f"{OUTPUT_DIR}/joint.tex", 'w') as file:
    print(to_latex_str_new(df), file=file)

In [None]:
with open(f"{OUTPUT_DIR}/whlfam.tex", 'w') as file:
    print(to_latex_str(pd.DataFrame(whlfam).T), file=file)

In [None]:
with open(f"{OUTPUT_DIR}/subfam.tex", 'w') as file:
    print(to_latex_str(pd.DataFrame(subfam).T), file=file)

In [None]:
pd.DataFrame(subfam).T

In [None]:
print(to_latex_str(pd.DataFrame(subfam).T))

In [None]:
def create_heatmap(env : str, plot_gradient_baseline = False, include_union_results = False, seed=SEED, subfamsize=SUBFAMILY_SIZE, **kwargs):
    """Draw the baseline-vs-ours heatmaps for one environment and one run.

    Parameters
    ----------
    env : str
        Environment name; used for loading results and, via ``get_env_str``,
        for the title passed to ``make_heatmap``.
    plot_gradient_baseline : bool
        When True the gradient results are read from ``subfam-gradient.pickle``
        and a "GD" heatmap is drawn in addition to the "Saynt" one.
    include_union_results : bool
        Forwarded to ``load_results``.
    seed, subfamsize
        Select the run directory ``{BASE_DIR}/{env}/subfamsize{subfamsize}/seed{seed}``.
    **kwargs
        Forwarded unchanged to ``make_heatmap``.

    A heatmap is only drawn when the matching union results are available;
    otherwise a message is printed and that heatmap is skipped.
    """
    saynt, ours, union_results, gradient = load_results(env, include_union_results=include_union_results, include_gd_results=plot_gradient_baseline, seed=seed, subfamsize=subfamsize)

    if plot_gradient_baseline:
        # The pickle is produced by this project's own experiment runs; do not
        # point this at files from untrusted sources.
        with open(f"{BASE_DIR}/{env}/subfamsize{subfamsize}/seed{seed}/subfam-gradient.pickle", 'rb') as handle:
            gradient = pickle.load(handle)

        # Guard `union_results` itself, as the Saynt branch below does, so an
        # empty or missing value is reported instead of raising on the lookup.
        if gradient and union_results and union_results['gradient']:
            make_heatmap(gradient, ours, env, get_env_str(env), "GD", union_results=union_results['gradient'], **kwargs)
        else:
            print("Gradient Union results None for", env, seed)

    if union_results and union_results['saynt']:
        make_heatmap(saynt, ours, env, get_env_str(env), "Saynt", union_results=union_results['saynt'], **kwargs)
    else:
        print("Saynt Union results None for", env, seed)


# Heatmaps ALL

In [None]:
env = 'obstacles-illustrative'
sns.set_theme(style="whitegrid",font_scale=3)

saynt, ours, union_results, gradient = load_results(env, include_union_results=False, include_gd_results=True, include_saynt_results=False, seed=11, subfamsize=3)

make_heatmap(gradient, ours, env, r"$\downarrow~$Illustrative Example", "GD")

In [None]:
env = 'obstacles-illustrative'
sns.set_theme(style="whitegrid",font_scale=4)

saynt, ours, union_results, gradient = load_results(env, include_union_results=False, include_gd_results=True, include_saynt_results=False, seed=11, subfamsize=3)
ours

In [None]:
env = 'obstacles-illustrative'
sns.set_theme(style="whitegrid",font_scale=4)

# Temporarily point BASE_DIR at the example run. The override is undone in
# `finally` so a failing load does not leave the notebook reading from the
# example directory in later cells.
temp = BASE_DIR
BASE_DIR = './example-test/parallel-IJCAI-example-deterministic'
try:
    saynt, ours, union_results, gradient = load_results(env, include_union_results=False, include_gd_results=True, include_saynt_results=False, seed=11, subfamsize=3)
finally:
    BASE_DIR = temp

add_whole_family_value = True
minimizing = True
results = gradient
our_results = ours

# One row per entry of `results`, one column per evaluated subfamily member.
values = np.array([np.ravel(list(item[-2].values())) for train, item in results.items()])
values[~np.isfinite(values)] = np.nan

if union_results:
    # Union rows are appended in this order: saynt first, then gradient.
    for _union_key in ('saynt', 'gradient'):
        ur = union_results[_union_key]['subfamily']
        if isinstance(ur, dict):
            ur = list(ur.values())
        values = np.vstack([values, ur])
values = np.vstack([values, our_results['ours']])

if add_whole_family_value:
    # Extra right-most column: value on the whole family, same row order as above.
    entire_family_values = np.array(
        [item[-1] if item[-1] is not None else np.nan for train, item in results.items()]
        + ([union_results['saynt']['whole_family'], union_results['gradient']['whole_family']] if union_results else [])
        + [our_results['whole_family']]
    )
    values = np.hstack([values, entire_family_values[..., None]])

colormap = sns.cm.rocket_r if minimizing else sns.cm.rocket

plt.figure(figsize=(16,9))

ax = sns.heatmap(values, annot=True, cbar=False, vmin=np.nanmin(values), vmax=np.nanmax(values), cmap=colormap, fmt='.0f')

subfamily_size = len(results.keys())

# Blue box: worst subfamily value of each row. Green box: the best of those.
best_value = np.inf if minimizing else -np.inf
best_rectangle = None  # reset so a value left over from another cell is never drawn
cmp = operator.le if minimizing else operator.ge
for r in range(values.shape[0]):
    row_values = values[r][:-1] if add_whole_family_value else values[r]
    if np.isfinite(row_values).any():
        idx = np.nanargmax(row_values) if minimizing else np.nanargmin(row_values)
    else:
        # Row without any finite value: nanarg* would raise, so a random
        # column is boxed instead. numpy's RNG is used because `random` is not
        # among the notebook's visible imports; note the draw is unseeded.
        idx = np.random.randint(len(row_values))
    if cmp(row_values[idx], best_value):
        best_value = row_values[idx]
        best_rectangle = (idx, r)
    ax.add_patch(Rectangle((idx, r),1,1, fill=False, edgecolor='blue', lw=3))
if best_rectangle is not None:
    ax.add_patch(Rectangle(best_rectangle,1,1, fill=False, edgecolor='green', lw=3))

if add_whole_family_value:
    best_family_idx = np.nanargmin(values[:, -1]) if minimizing else np.nanargmax(values[:, -1])
    ax.add_patch(Rectangle((subfamily_size, best_family_idx),1,1, fill=False, edgecolor='green', lw=3))

plt.yticks(rotation=0)
xticks = ["$M_{" + f"{i+1}" + "}$" for i in range(subfamily_size)]
if add_whole_family_value:
    xticks += [r"$\mathcal{M}$"]
ax.set_xticklabels(xticks)
# Union labels follow the row order built above (saynt row, then gradient row).
ax.set_yticklabels(["$\\displaystyle M_{" + f"{i+1}" + "}$" for i in range(subfamily_size)] + (['\\textsc{Saynt-U}', '\\textsc{gd-U}'] if union_results else []) + ['\\textsc{rfPG}'])

print(env)
plt.savefig(f"{OUTPUT_DIR}/heatmaps/{env}-main-body", **save_args)
# Second copy of the figure through the PDF backend.
save_args2 = {**save_args, 'backend': 'pdf', 'format': 'pdf'}
print(save_args, save_args2)
plt.savefig(f"{OUTPUT_DIR}/heatmaps/{env}-main-body.pdf", **save_args2)
plt.show()

In [None]:
env = 'obstacles-8-3'
sns.set_theme(style="whitegrid",font_scale=2)

saynt, ours, union_results, gradient = load_results(env, include_union_results=True, include_gd_results=True, include_saynt_results=True, seed=SEEDS[0], subfamsize=10)
add_whole_family_value = True
minimizing = True
results = saynt
our_results = ours

# One row per entry of `results`, one column per evaluated subfamily member.
values = np.array([np.ravel(list(item[-2].values())) for train, item in results.items()])
values[~np.isfinite(values)] = np.nan

if union_results:
    # Union rows are appended in this order: saynt first, then gradient.
    for _union_key in ('saynt', 'gradient'):
        ur = union_results[_union_key]['subfamily']
        if isinstance(ur, dict):
            ur = list(ur.values())
        values = np.vstack([values, ur])
values = np.vstack([values, our_results['ours']])

if add_whole_family_value:
    # Extra right-most column: value on the whole family, same row order as above.
    entire_family_values = np.array(
        [item[-1] if item[-1] is not None else np.nan for train, item in results.items()]
        + ([union_results['saynt']['whole_family'], union_results['gradient']['whole_family']] if union_results else [])
        + [our_results['whole_family']]
    )
    values = np.hstack([values, entire_family_values[..., None]])

colormap = sns.cm.rocket_r if minimizing else sns.cm.rocket

plt.figure(figsize=(16,9))

ax = sns.heatmap(values, annot=True, cbar=False, vmin=np.nanmin(values), vmax=np.nanmax(values), cmap=colormap, fmt='.0f')

subfamily_size = len(results.keys())

# Blue box: worst subfamily value of each row. Green box: the best of those.
best_value = np.inf if minimizing else -np.inf
best_rectangle = None  # reset so a value left over from another cell is never drawn
cmp = operator.le if minimizing else operator.ge
for r in range(values.shape[0]):
    row_values = values[r][:-1] if add_whole_family_value else values[r]
    if np.isfinite(row_values).any():
        idx = np.nanargmax(row_values) if minimizing else np.nanargmin(row_values)
    else:
        # Row without any finite value: nanarg* would raise, so a random
        # column is boxed instead. numpy's RNG is used because `random` is not
        # among the notebook's visible imports; note the draw is unseeded.
        idx = np.random.randint(len(row_values))
    if cmp(row_values[idx], best_value):
        best_value = row_values[idx]
        best_rectangle = (idx, r)
    ax.add_patch(Rectangle((idx, r),1,1, fill=False, edgecolor='blue', lw=3))
if best_rectangle is not None:
    ax.add_patch(Rectangle(best_rectangle,1,1, fill=False, edgecolor='green', lw=3))

if add_whole_family_value:
    best_family_idx = np.nanargmin(values[:, -1]) if minimizing else np.nanargmax(values[:, -1])
    ax.add_patch(Rectangle((subfamily_size, best_family_idx),1,1, fill=False, edgecolor='green', lw=3))

plt.yticks(rotation=0)
xticks = ["$M_{" + f"{i+1}" + "}$" for i in range(subfamily_size)]
if add_whole_family_value:
    xticks += [r"$\mathcal{M}$"]
ax.set_xticklabels(xticks)
# Union labels follow the row order built above (saynt row, then gradient row).
ax.set_yticklabels(["\\textsc{Saynt} $M_{" + f"{i+1}" + "}$" for i in range(subfamily_size)] + (['\\textsc{Saynt-U}', '\\textsc{gd-U}'] if union_results else []) + ['\\textsc{rfPG-S}'])

plt.tight_layout()
print(env)
plt.savefig(f"{OUTPUT_DIR}/heatmaps/{env}-main-body.{save_format}", **save_args)
plt.show()

In [None]:
sns.set_theme(style="whitegrid", font_scale=2)

# One pass over the benchmark environments, each paired with its optimisation
# direction; every call uses the first seed and draws both baselines.
for env, minimizing in zip(ENV_NAMES, MINIMIZING):
    create_heatmap(
        env,
        plot_gradient_baseline=True,
        include_union_results=True,
        seed=SEEDS[0],
        minimizing=minimizing,
    )
    # except:
        # print(env, 'failed!')

# Gather statistics

In [None]:
from pomdp_families import POMDPFamiliesSynthesis

In [None]:
def get_max_number_of_states(env):
    """Return the largest ``model.nr_states`` over all POMDPs of the family.

    Every hole combination of the sketch built for ``env`` is turned into an
    assignment and a POMDP; the result is the maximum state count seen, or 0
    when the family yields no combinations.
    """
    synthesizer = POMDPFamiliesSynthesis(env)
    sketch = synthesizer.pomdp_sketch

    # Seeded with 0 so an empty family still has a well-defined maximum.
    state_counts = [0]
    for combination in sketch.family.all_combinations():
        member = sketch.build_pomdp(sketch.family.construct_assignment(combination))
        state_counts.append(member.model.nr_states)
    return max(state_counts)

In [None]:
def get_statistics_and_random_policy_value(env, stratified=True, subfamily_size=SUBFAMILY_SIZE):
    """Collect family statistics and the value of a random FSC for ``env``.

    A single FSC from ``random_fsc(1)`` is evaluated on one sampled subfamily
    per seed in ``SEEDS`` and on the whole family via ``paynt_call``.

    Parameters
    ----------
    env
        Passed to ``POMDPFamiliesSynthesis`` and ``get_max_number_of_states``.
    stratified : bool, default True
        Use ``stratified_subfamily_sampling`` (seeded) when True, otherwise
        ``create_random_subfamily``.
    subfamily_size : int
        Number of family members requested from the sampler.

    Returns
    -------
    dict
        One entry per seed holding ``'hole_combinations'`` and
        ``'evaluations'``, plus the keys ``'family_value'``, ``'family_size'``,
        ``'num_actions'``, ``'num_observations'`` and ``'max_num_states'``.
    """
    gd = POMDPFamiliesSynthesis(env)
    rand_fsc = gd.random_fsc(1)  # presumably a 1-node FSC; verify against random_fsc

    dtmc_sketch = gd.get_dtmc_sketch(rand_fsc)

    results = {}

    for seed in SEEDS:
        if stratified:
            subfamily_assignments, hole_combinations = gd.stratified_subfamily_sampling(subfamily_size, seed=seed)
        else:
            # NOTE(review): `seed` is not passed on this path, so the sampled
            # subfamily is not tied to the seed it is stored under; confirm
            # whether create_random_subfamily can be seeded.
            subfamily_assignments, hole_combinations = gd.create_random_subfamily(subfamily_size)

        evaluations = gd.get_values_on_subfamily(dtmc_sketch, subfamily_assignments)

        results[seed] = {
            'hole_combinations' : hole_combinations,
            'evaluations' : evaluations
        }

    # Only the second element of the paynt_call result is kept.
    _, family_value = gd.paynt_call(dtmc_sketch)

    results['family_value'] = family_value

    results['family_size'] = gd.pomdp_sketch.family.size
    results['num_actions'] = gd.pomdp_sketch.num_actions
    results['num_observations'] = gd.pomdp_sketch.num_observations

    results['max_num_states'] = get_max_number_of_states(env)

    return results

In [None]:
# statistics = {}
# for env in ENVS:
#     statistics[env] = get_statistics_and_random_policy_value(env)

In [None]:
# with open("./statistics.pickle", 'wb') as handle:
#     pickle.dump(statistics, handle)