In [None]:
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from config import *
import operator
import os

In [None]:
# Global plotting style: seaborn whitegrid theme, then serif (Times) text rendered by LaTeX.
# Switching the PGF TeX system is currently disabled:
#   plt.rcParams.update({"pgf.texsystem": "lualatex"})
font_scale = 1.5
fontsize = 15
sns.set_theme(style="whitegrid", font_scale=font_scale)
# Must come after set_theme, which resets the rc parameters.
plt.rc("font", family="serif", serif=["Times"], size=fontsize)
plt.rc("text", usetex=True)

# Keyword arguments shared by every `plt.savefig` call in this notebook.
save_format = 'pdf'
backend = 'pdf'
transparent = True
pad_inches = 0.1
bbox_inches = 'tight'
save_args = dict(
    format=save_format,
    backend=backend,
    transparent=transparent,
    pad_inches=pad_inches,
    bbox_inches=bbox_inches,
)

In [None]:
# Root directory of the experiment outputs read by this notebook.
# Previously used roots:
#   ./outputs/second, ./output-union, ./bert/output-merged,
#   ./output-parallel-subfamily, ./output-parallel-bert, ./verifai/output/parallel-full
BASE_DIR = './verifaiIJCAI/outputIJCAI/parallel-IJCAI'

# Number of family members per sampled subfamily.
SUBFAMILY_SIZE = 10

# Default seed for single-seed analyses, and the ten consecutive seeds (2..11)
# that the multi-seed analyses aggregate over.
SEED = 2
SEEDS = list(range(SEED, SEED + 10))

In [None]:
# Display the seeds whose results are aggregated in the multi-seed analyses below.
SEEDS

In [None]:
# Per-environment statistics (see "Gather statistics" below, which can regenerate the file).
# NOTE: unpickling can execute arbitrary code -- only load locally generated, trusted files.
with open("./statistics.pickle", 'rb') as stats_file:
    statistics = pickle.load(stats_file)

In [None]:
# Statistics of the first environment in ENVS. ENVS is not defined in this notebook;
# presumably it comes from `from config import *` -- confirm.
statistics[ENVS[0]]

In [None]:
# Every generated figure/table is written below this directory.
OUTPUT_DIR = "./plot_builder_output-clean"
for _subdir in ("heatmaps", "lineplot"):
    os.makedirs(f"{OUTPUT_DIR}/{_subdir}", exist_ok=True)

In [None]:
import random


def make_heatmap(results : dict, our_results : dict, title : str, baseline : str, union_results : dict = None, minimizing = True, add_whole_family_value = True):
    """Plot a train/test heatmap of a baseline against our method and save it.

    Rows: one per baseline training run (keys ``0..n-1`` of ``results``), then an
    optional "Saynt on Union" row, then ours. Columns: the test members, optionally
    followed by an "Entire family" column.

    Args:
        results: baseline results keyed by training-member index. For each ``item``,
            ``item[-2]`` maps test member -> value, ``item[-1]`` is the whole-family
            value (or ``None``) and ``item[1]`` is the FSC (or falsy) whose
            ``num_nodes`` appears in the row label.
        our_results: dict with ``'ours'`` (per-member values) and ``'whole_family'``.
        title: title prefix, also part of the output file name.
        baseline: baseline name used in row labels and the file name.
        union_results: optional dict with ``'subfamily'`` (list, or dict of values)
            and ``'whole_family'``.
        minimizing: whether lower values are better.
        add_whole_family_value: append the whole-family column.

    Blue boxes mark each row's worst value over the test members. A green box marks the
    row whose worst value is best, and (if enabled) the best whole-family value.
    The figure is saved to ``{OUTPUT_DIR}/heatmaps/{title}-{baseline}.{save_format}``.
    """
    # dtype=float turns missing (None) entries into NaN and lets non-finite cells be NaN'ed.
    values = np.array([np.ravel(list(item[-2].values())) for _, item in results.items()], dtype=float)
    values[~np.isfinite(values)] = np.nan

    if union_results:
        union_subfamily = union_results['subfamily']
        if isinstance(union_subfamily, dict):
            union_subfamily = list(union_subfamily.values())
        values = np.vstack([values, union_subfamily])
    values = np.vstack([values, our_results['ours']])

    if add_whole_family_value:
        entire_family_values = np.array(
            [item[-1] if item[-1] is not None else np.nan for _, item in results.items()]
            + ([union_results['whole_family']] if union_results else [])
            + [our_results['whole_family']],
            dtype=float,
        )
        values = np.hstack([values, entire_family_values[..., None]])

    plt.figure(figsize=(16,9))
    # rocket_r (reversed) when minimizing, rocket otherwise.
    colormap = sns.cm.rocket_r if minimizing else sns.cm.rocket
    ax = sns.heatmap(values, annot=True, vmin=np.nanmin(values), vmax=np.nanmax(values), cmap=colormap, fmt='.2f')

    subfamily_size = len(results)  # number of baseline training rows
    n_test = values.shape[1] - (1 if add_whole_family_value else 0)  # number of test columns

    # The worst case of a row is its largest value when minimizing, its smallest otherwise;
    # the best row is the one with the best (smallest/largest) worst case.
    worst_index = np.nanargmax if minimizing else np.nanargmin
    is_at_least_as_good = operator.le if minimizing else operator.ge
    best_value = np.inf if minimizing else -np.inf
    best_rectangle = None
    for r in range(values.shape[0]):
        row_values = values[r][:n_test]
        if np.isfinite(row_values).any():
            idx = worst_index(row_values)
        else:
            # No finite entry: box an arbitrary cell so the row is still marked.
            idx = random.randint(0, len(row_values)-1)
        if is_at_least_as_good(row_values[idx], best_value):
            best_value = row_values[idx]
            best_rectangle = (idx, r)
        ax.add_patch(Rectangle((idx, r),1,1, fill=False, edgecolor='blue', lw=3))
    # best_rectangle stays None only if every row is entirely NaN.
    if best_rectangle is not None:
        ax.add_patch(Rectangle(best_rectangle,1,1, fill=False, edgecolor='green', lw=3))

    if add_whole_family_value:
        family_column = values[:, -1]
        if np.isfinite(family_column).any():
            best_family_idx = np.nanargmin(family_column) if minimizing else np.nanargmax(family_column)
            ax.add_patch(Rectangle((n_test, best_family_idx),1,1, fill=False, edgecolor='green', lw=3))

    ax.set_xlabel("Test")
    ax.set_ylabel("Train")
    plt.yticks(rotation=0)
    xticks = [f"{i}" for i in range(n_test)]
    if add_whole_family_value:
        xticks += ["Entire family"]
    ax.set_xticklabels(xticks)
    yticks = [f"{baseline} on {i} ({results[i][1].num_nodes if results[i][1] else None}-FSC)" for i in range(subfamily_size)]
    if union_results:
        yticks += ["Saynt on Union"]
    yticks += ["Ours: GD on (sub)family"]
    ax.set_yticklabels(yticks)
    ax.set_title(f"{title}: {baseline} vs Ours ({'lower' if minimizing else 'higher'} is better)")
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/heatmaps/{title}-{baseline}.{save_format}", **save_args)

In [None]:
import math
import random


def make_lineplot(results : dict, title : str = 'Placeholder', minimizing = True, type_of_plot = 'family_trace', use_time_x_axis=False):
    """Plot rfPG against random rfPG for a single run, annotate the optima and save it.

    Args:
        results: dict with ``'gd-normal'`` and ``'gd-random'`` entries, each holding a
            ``'plot_times'`` sequence and the ``type_of_plot`` value trace.
        title: plot title and output file stem.
        minimizing: whether lower values are better (selects the annotated optimum).
        type_of_plot: key of the value trace to plot.
        use_time_x_axis: plot against ``plot_times`` instead of the iteration index.

    The random run's timestamps are rescaled so that run ends at the same time as the
    normal run. The rescaling works on a copy; ``results`` is not modified.
    Saves to ``{OUTPUT_DIR}/lineplot/{title}.{save_format}`` and shows the figure.
    """
    normal = results['gd-normal']
    randomized = results['gd-random']
    normal_trace = normal[type_of_plot]
    random_trace = randomized[type_of_plot]

    fig = plt.figure()
    ax = fig.gca()
    plt.title(f"{title} ({'lower' if minimizing else 'higher'} is better)")
    plt.ylabel("Worst family member value")

    max_time_normal = max(normal['plot_times'])
    max_time_random = max(randomized['plot_times'])
    normal_times = normal['plot_times']
    random_times = (np.array(randomized['plot_times']) / max_time_random) * max_time_normal
    print(max_time_normal, max_time_random)

    best_index = np.argmin if minimizing else np.argmax
    if use_time_x_axis:
        plt.xlabel("Time")
        ax.plot(normal_times, normal_trace, label='rfPG', color='green')
        ax.plot(random_times, random_trace, label='Random rfPG', color='red')
    else:
        plt.xlabel("Iteration")
        ax.plot(normal_trace, label='rfPG', color='green')
        ax.plot(random_trace, label='Random rfPG', color='red')

    # x (iteration index, or time) and y of each curve's best point.
    min_x_normal = best_index(normal_trace)
    min_y_normal = normal_trace[min_x_normal]
    min_x = best_index(random_trace)
    min_y = random_trace[min_x]
    if use_time_x_axis:
        min_x_normal = normal_times[min_x_normal]
        min_x = random_times[min_x]

    mini = min([min_x, min_x_normal])
    maxi = max([min_x, min_x_normal])

    def get(x):
        """Horizontal text offset (points) that pushes the two annotations apart."""
        if math.isclose(mini, maxi) and math.isclose(x, maxi):
            return random.randint(-50, 50)
        elif math.isclose(x, maxi):
            return 50
        elif math.isclose(x, mini):
            return -50
        else:
            return 0

    ax.annotate(f"{min_y:.2f}",
            xy=(min_x, min_y), xycoords='data',
            xytext=(get(min_x), (50 if minimizing else -50)), textcoords='offset points',
            arrowprops=dict(facecolor='red', shrink=0),
            horizontalalignment='center', verticalalignment='bottom')

    ax.annotate(f"{min_y_normal:.2f}",
                xy=(min_x_normal, min_y_normal), xycoords='data',
                xytext=(get(min_x_normal), (50 if minimizing else -50)), textcoords='offset points',
                arrowprops=dict(facecolor='green', shrink=0),
                horizontalalignment='center', verticalalignment='bottom')

    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/lineplot/{title}.{save_format}", **save_args)
    plt.show()

In [None]:
# Environments analysed in this notebook. MINIMIZING is aligned positionally with
# ENV_NAMES: True when lower values are better, False when higher values are better.
ENV_NAMES = [
    'dpm',             # maximize
    'obstacles-10-2',  # minimize
    'avoid',           # minimize
    'obstacles-8-3',   # minimize
    'rover',           # maximize
    'network',         # maximize
]
MINIMIZING = [name in {'obstacles-10-2', 'avoid', 'obstacles-8-3'} for name in ENV_NAMES]

In [None]:
def plot_learning_curve_single_seed(envs=ENV_NAMES, minimizings=MINIMIZING, seed=SEED):
    """Plot the single-seed rfPG vs. random-rfPG curve of every environment against time.

    Reads ``{BASE_DIR}/{env}/seed{seed}/gd-experiment.pickle``. This is the same layout
    used by ``plot_learning_curve`` and the table builders; the old ``{env}/{seed}/``
    path did not match it. Environments without that file are reported and skipped.
    """
    for env, minimizing in zip(envs, minimizings):
        path = f"{BASE_DIR}/{env}/seed{seed}/gd-experiment.pickle"
        try:
            with open(path, 'rb') as handle:
                results = pickle.load(handle)
        except FileNotFoundError as fnfe:
            print(fnfe)
            continue
        make_lineplot(results, title=env, minimizing=minimizing, use_time_x_axis=True)

In [None]:
import scipy.stats

def _interp_curves(xs, ys, x_grid):
    """Resample every (x, y) curve onto ``x_grid`` with ``np.interp``; unusable curves are skipped."""
    curves = []
    for this_x, this_y in zip(xs, ys):
        try:
            curves.append(np.interp(x_grid, this_x, this_y))
        except (ValueError, TypeError):
            # e.g. an empty or length-mismatched trace: this run contributes no curve.
            continue
    return curves


def _plot_mean_band(x_grid, curves, label, color, use_ci):
    """Plot the mean of ``curves`` over ``x_grid`` with a shaded uncertainty band.

    The band is a 95% Student-t confidence interval of the mean when ``use_ci`` is
    true, otherwise mean +/- one standard deviation.
    """
    if len(curves) == 0:
        print(f"No curves to plot for {label}")
        return
    curves = np.asarray(curves)
    mean = np.mean(curves, axis=0)
    plt.plot(x_grid, mean, label=label, color=color)
    if use_ci:
        # Degrees of freedom come from the number of curves (seeds), not the grid size.
        dof = max(len(curves) - 1, 1)
        lower, upper = scipy.stats.t.interval(0.95, dof, loc=mean, scale=scipy.stats.sem(curves, axis=0))
    else:
        spread = np.std(curves, axis=0)
        lower, upper = mean - spread, mean + spread
    plt.fill_between(x_grid, lower, upper, alpha=0.25, color=color)


def plot_learning_curve(envs=ENV_NAMES, minimizings=MINIMIZING, seed=SEED, use_ci=True):
    """Plot the seed-averaged rfPG vs. random-rfPG learning curve of every environment.

    For each environment, each seed in ``SEEDS`` contributes one rfPG and one random-rfPG
    ``family_trace``, read from ``{BASE_DIR}/{env}/seed{seed}/gd-experiment.pickle``.
    Within a seed, the longer run is time-rescaled so that both runs end at the shorter
    run's final time. A seed is used for both methods or for neither, so the two
    aggregates stay paired. Curves are resampled onto a 0-3600 s grid and drawn as the
    mean with a 95% CI (``use_ci``) or a +/- one-standard-deviation band.

    Args:
        envs, minimizings: environment names and whether each one is minimizing.
        seed: unused (all ``SEEDS`` are aggregated); kept for backward compatibility.
        use_ci: draw a confidence interval instead of a standard-deviation band.
    """
    type_of_plot = 'family_trace'
    num_points = 3600
    x_common = np.linspace(0, 3600, num_points)

    for env, minimizing in zip(envs, minimizings):
        rand_X, rand_Y, X, Y = [], [], [], []
        for run_seed in SEEDS:
            try:
                with open(f"{BASE_DIR}/{env}/seed{run_seed}/gd-experiment.pickle", 'rb') as handle:
                    results = pickle.load(handle)
            except FileNotFoundError as fnfe:
                print(fnfe)
                continue

            try:
                max_time_normal = max(results['gd-normal']['plot_times'])
                max_time_random = max(results['gd-random']['plot_times'])
            except (ValueError, TypeError):
                # Empty/missing time trace: skip this seed for both methods.
                continue

            x_random = np.array(results['gd-random']['plot_times'])
            y_random = np.array(results['gd-random'][type_of_plot])
            x = np.array(results['gd-normal']['plot_times'])
            y = np.array(results['gd-normal'][type_of_plot])

            # Compress the longer run so that both runs end at the shorter run's final time.
            if max_time_random >= max_time_normal:
                x_random = (x_random / max_time_random) * max_time_normal
            else:
                x = (x / max_time_normal) * max_time_random

            # Append only after rescaling so the rescaled times are what gets plotted.
            rand_X.append(x_random)
            rand_Y.append(y_random)
            X.append(x)
            Y.append(y)

        plt.figure()
        _plot_mean_band(x_common, _interp_curves(rand_X, rand_Y, x_common), 'Random', 'red', use_ci)
        _plot_mean_band(x_common, _interp_curves(X, Y, x_common), r'\textsc{rfPG}', 'green', use_ci)
        plt.title(f"{env.upper()} ({'lower' if minimizing else 'higher'} is better)")
        plt.ylabel("Robust performance")
        plt.xlabel("Time (seconds)")
        plt.legend()
        plt.tight_layout()
        os.makedirs(f"{OUTPUT_DIR}/lineplot", exist_ok=True)
        plt.savefig(f"{OUTPUT_DIR}/lineplot/{env}.{save_format}", **save_args)
        plt.show()
        plt.close()

In [None]:
# Seed-averaged learning curves for all environments, with 95% confidence bands.
plot_learning_curve(envs=ENV_NAMES, minimizings=MINIMIZING, use_ci=True)

# Default functions: loading results and building tables

In [None]:
def load_results(env : str, include_union_results : bool = True, seed=SEED, subfamsize=SUBFAMILY_SIZE, include_gd_results=False):
    """Load the pickled subfamily experiment results of one environment and seed.

    Reads ``subfam-saynt.pickle``, ``subfam-ours.pickle`` and, with
    ``include_gd_results``, ``subfam-gradient.pickle`` from
    ``{BASE_DIR}/{env}/subfamsize{subfamsize}/seed{seed}/``. When requested, the union
    baseline is read from ``{BASE_DIR}/{env}/union/seed{seed}/union.pickle``; any error
    while loading it is printed and yields ``None``.

    Returns:
        ``(saynt, ours, union_results)``, or ``(saynt, ours, union_results, gradient)``
        when ``include_gd_results`` is true.
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    subfam_dir = f"{BASE_DIR}/{env}/subfamsize{subfamsize}/seed{seed}"
    saynt = _unpickle(f"{subfam_dir}/subfam-saynt.pickle")
    ours = _unpickle(f"{subfam_dir}/subfam-ours.pickle")
    gradient = _unpickle(f"{subfam_dir}/subfam-gradient.pickle") if include_gd_results else None

    union_results = None
    if include_union_results:
        try:
            union_results = _unpickle(f"{BASE_DIR}/{env}/union/seed{seed}/union.pickle")
        except Exception as e:
            # Union results are optional: report and continue without them.
            print(e)
            union_results = None

    if not include_gd_results:
        return saynt, ours, union_results
    return saynt, ours, union_results, gradient

In [None]:
def create_table(envs : list[str], minimizings : list[bool], include_union_results = True):
    """Build single-seed comparison tables of worst-case values per environment.

    Uses ``load_results`` with its default seed/subfamily size, plus the rfPG
    whole-family run in ``{BASE_DIR}/{env}/seed{SEED}/gd-experiment.pickle``.
    Environments are processed maximizing first (sorted on the ``minimizing`` flag).

    Returns:
        ``(subfamresults, whlfamresults)``: dicts keyed by ``"ENV (min.)"`` /
        ``"ENV (max.)"``, each mapping a LaTeX method label to a value.
    """
    def as_values(x):
        # Union 'subfamily' results may be a list or a dict of values (see make_heatmap).
        return list(x.values()) if isinstance(x, dict) else x

    subfamresults = {}
    whlfamresults = {}
    for env, minimizing in sorted(zip(envs, minimizings), key=lambda x : x[1]):
        # The gradient baseline is not part of these tables, so its pickle is not required.
        saynt, ours, union_results = load_results(env, include_union_results=include_union_results)

        # Worst case over members: the largest value is worst when minimizing.
        # NOTE(review): builtin max/min are order-dependent when the input contains NaN
        # (every comparison with NaN is False); consider np.nanmax/np.nanmin if NaN = "missing".
        aggregator = lambda x : float((max if minimizing else min)(x))
        # Best FSC among the per-member SAYNT runs (smallest worst case when minimizing).
        best_of = min if minimizing else max

        env_id = f"{env.upper()} ({'min.' if minimizing else 'max.'})"

        subfamresults[env_id] = {
            '\\saynt (one-by-one)' : best_of([aggregator(np.array(list(saynt[k][2].values())).flatten().tolist()) for k in saynt.keys()]),
            '\\saynt (union)' : aggregator(as_values(union_results['subfamily'])) if union_results else np.nan,
            '\\ours  (subfamily)'  : aggregator(ours['ours']),
        }

        with open(f"{BASE_DIR}/{env}/seed{SEED}/gd-experiment.pickle", 'rb') as handle:
            rfpg_results = pickle.load(handle)

        whlfamresults[env_id] = {
            '\\saynt (one-by-one)' : aggregator([item[-1] if item[-1] is not None else np.nan for _, item in saynt.items()]),
            '\\saynt (union)' : float(union_results['whole_family']) if union_results else np.nan,
            '\\ours  (subfamily)'  : float(ours['whole_family']),
            '\\ours  (whole family)' : rfpg_results['gd-normal']['best_worst_value']
        }

    return subfamresults, whlfamresults


In [None]:
from collections import defaultdict

def create_table_multiple_seeds(envs : list[str], minimizings : list[bool], include_union_results = True, include_gd_results = True):
    """Build comparison tables whose cells are medians over ``SEEDS``.

    For each seed, the per-method worst-case values are computed as in ``create_table``.
    Seeds whose pickles fail to load are printed and skipped. Environments are
    processed maximizing first. For the SAYNT one-by-one whole-family column, NaN
    seeds are dropped before the median is taken.

    Returns:
        ``(subfamresults, whlfamresults)`` keyed by ``"ENV (min.)"`` / ``"ENV (max.)"``.

    Raises:
        AssertionError: if no seed could be loaded for an environment.
    """
    saynt_one_key = '\\saynt (one-by-one)'
    saynt_union_key = '\\saynt (union)'
    ours_sub_key = '\\ours  (subfamily)'
    ours_whole_key = '\\ours  (whole family)'

    def as_values(x):
        # Union 'subfamily' results may be a list or a dict of values (see make_heatmap).
        return list(x.values()) if isinstance(x, dict) else x

    subfamresults = defaultdict(dict)
    whlfamresults = defaultdict(dict)
    for env, minimizing in sorted(zip(envs, minimizings), key=lambda x : x[1]):
        # Worst case over members (largest value when minimizing); the best FSC is the
        # one with the best worst case.
        aggregator = lambda x : float((max if minimizing else min)(x))
        best_of = min if minimizing else max

        env_id = f"{env.upper()} ({'min.' if minimizing else 'max.'})"

        subfam_per_seed = defaultdict(list)
        whlfam_per_seed = defaultdict(list)
        count = 0

        for seed in SEEDS:
            try:
                # load_results returns a 4-tuple (..., gradient) only with include_gd_results;
                # the gradient results are not used here.
                loaded = load_results(env, include_union_results=include_union_results, seed=seed, include_gd_results=include_gd_results)
                saynt, ours, union_results = loaded[:3]
                with open(f"{BASE_DIR}/{env}/seed{seed}/gd-experiment.pickle", 'rb') as handle:
                    rfpg_results = pickle.load(handle)
            except Exception as e:
                # Best effort: a seed with missing/corrupt results is reported and skipped.
                print(e)
                continue

            count += 1

            subfam_per_seed[saynt_one_key].append(best_of([aggregator(np.array(list(saynt[k][2].values())).flatten().tolist()) for k in saynt.keys()]))
            subfam_per_seed[saynt_union_key].append(aggregator(as_values(union_results['subfamily'])) if union_results else np.nan)
            subfam_per_seed[ours_sub_key].append(aggregator(ours['ours']))

            whlfam_per_seed[saynt_one_key].append(aggregator([item[-1] if item[-1] is not None else np.nan for _, item in saynt.items()]))
            whlfam_per_seed[saynt_union_key].append(float(union_results['whole_family']) if union_results else np.nan)
            whlfam_per_seed[ours_sub_key].append(float(ours['whole_family']))
            whlfam_per_seed[ours_whole_key].append(rfpg_results['gd-normal']['best_worst_value'])

        assert count > 0, f"No seed in {SEEDS} could be loaded for {env}"
        print(env, whlfam_per_seed[saynt_one_key])

        subfamresults[env_id] = {key: np.median(vals) for key, vals in subfam_per_seed.items()}

        whlfam_medians = {}
        for key, vals in whlfam_per_seed.items():
            if key == saynt_one_key:
                # TODO this is a hotfix (replace NaN's by value of uniform random value):
                # for now, seeds with a NaN value are dropped before taking the median.
                arr = np.array(vals)
                vals = arr[~np.isnan(arr)]
            whlfam_medians[key] = np.median(vals)
        whlfamresults[env_id] = whlfam_medians

    return subfamresults, whlfamresults


In [None]:
# Single-seed alternative:
#   subfam, whlfam = create_table(ENV_NAMES, MINIMIZING, include_union_results=True)
subfam, whlfam = create_table_multiple_seeds(
    envs=ENV_NAMES,
    minimizings=MINIMIZING,
    include_union_results=False,
)

In [None]:
def pd_highlight() -> str:
    """Return the Styler property used to highlight the best cell of a row.

    ``Styler.to_latex`` converts ``highlight:--rwrap;`` into a LaTeX ``highlight``
    command that wraps the cell value.
    """
    command = "highlight"
    return f"{command}:--rwrap;"

In [None]:
def to_latex_str(df : pd.DataFrame, num_maximizing : int = 3):
    """Render a results table as LaTeX, highlighting the best value of every row.

    The first ``num_maximizing`` rows are treated as maximization tasks (highest value
    highlighted) and the remaining rows as minimization tasks (lowest value
    highlighted). ``create_table*`` emit maximizing environments first; the default
    of 3 matches the three maximizing entries in ``MINIMIZING``. Values are rounded
    to, and displayed with, 2 decimals.
    """
    df = df.round(decimals=2)
    highlight = pd_highlight()
    # `.format(precision=2)` belongs on the Styler; it was previously applied to the props string.
    maximizing_rows = df.iloc[:num_maximizing].style.highlight_max(axis=1, props=highlight).format(precision=2)
    minimizing_rows = df.iloc[num_maximizing:].style.highlight_min(axis=1, props=highlight).format(precision=2)
    return maximizing_rows.concat(minimizing_rows).format(precision=2).to_latex(hrules=True)

In [None]:
# NOTE(review): the cell right after this one displays the subfamily table (`subfam`);
# the whole-family table (`whlfam`) is displayed after that.
print("Whole family results:")

In [None]:
# Subfamily table: one row per environment label, one column per method.
pd.DataFrame(subfam).T

In [None]:
# Whole-family table: one row per environment label, one column per method.
pd.DataFrame(whlfam).T

In [None]:
# LaTeX source of the whole-family table (the same text is written to whlfam.tex below).
print(to_latex_str(pd.DataFrame(whlfam).T))

In [None]:
# Write the whole-family LaTeX table to {OUTPUT_DIR}/whlfam.tex.
with open(f"{OUTPUT_DIR}/whlfam.tex", 'w') as tex_file:
    tex_file.write(to_latex_str(pd.DataFrame(whlfam).T) + "\n")

In [None]:
# Write the subfamily LaTeX table to {OUTPUT_DIR}/subfam.tex.
with open(f"{OUTPUT_DIR}/subfam.tex", 'w') as tex_file:
    tex_file.write(to_latex_str(pd.DataFrame(subfam).T) + "\n")

In [None]:
# Subfamily table again (same display as earlier in this section).
pd.DataFrame(subfam).T

In [None]:
# LaTeX source of the subfamily table.
print(to_latex_str(pd.DataFrame(subfam).T))

In [None]:
def create_heatmap(env : str, plot_gradient_baseline = False, include_union_results = False, **kwargs):
    """Draw the SAYNT-vs-ours heatmap for ``env`` and, optionally, the GD-vs-ours one.

    All results come from ``load_results`` (default seed and subfamily size), so the
    gradient baseline is read from the same directory layout as the other results.
    The previous ``{env}/{SUBFAMILY_SIZE}/gradient.pickle`` path did not match it.
    Extra ``kwargs`` (e.g. ``minimizing``) are forwarded to ``make_heatmap``.
    """
    if plot_gradient_baseline:
        saynt, ours, union_results, gradient = load_results(env, include_union_results=include_union_results, include_gd_results=True)
        make_heatmap(gradient, ours, env.upper(), "GD", union_results=union_results, **kwargs)
    else:
        saynt, ours, union_results = load_results(env, include_union_results=include_union_results)

    make_heatmap(saynt, ours, env.upper(), "Saynt", union_results=union_results, **kwargs)


# Gather statistics

In [None]:
from pomdp_families import POMDPFamiliesSynthesis

In [None]:
def get_max_number_of_states(env):
    """Return the largest ``nr_states`` over all member POMDPs of ``env``'s family (0 if none).

    Every member POMDP is built, so the cost grows with the family size.
    """
    gd = POMDPFamiliesSynthesis(env)
    family = gd.pomdp_sketch.family
    state_counts = (
        gd.pomdp_sketch.build_pomdp(family.construct_assignment(combination)).model.nr_states
        for combination in family.all_combinations()
    )
    return max([0, *state_counts])

In [None]:
def get_statistics_and_random_policy_value(env, stratified=True, subfamily_size=SUBFAMILY_SIZE):
    """Collect family statistics and the values of one random FSC for ``env``.

    A single random FSC (``gd.random_fsc(1)``) is evaluated on a sampled subfamily for
    every seed in ``SEEDS``, and on the whole family via ``gd.paynt_call``.

    Returns:
        dict with one entry per seed (``{'hole_combinations', 'evaluations'}``) plus
        ``'family_value'``, ``'family_size'``, ``'num_actions'``,
        ``'num_observations'`` and ``'max_num_states'``.
    """
    gd = POMDPFamiliesSynthesis(env)
    rand_fsc = gd.random_fsc(1)
    dtmc_sketch = gd.get_dtmc_sketch(rand_fsc)

    results = {}
    for seed in SEEDS:
        if stratified:
            subfamily_assignments, hole_combinations = gd.stratified_subfamily_sampling(subfamily_size, seed=seed)
        else:
            # NOTE(review): `seed` is not passed here, so unstratified subfamilies may not
            # be reproducible per seed -- confirm how create_random_subfamily is seeded.
            subfamily_assignments, hole_combinations = gd.create_random_subfamily(subfamily_size)

        results[seed] = {
            'hole_combinations' : hole_combinations,
            'evaluations' : gd.get_values_on_subfamily(dtmc_sketch, subfamily_assignments),
        }

    _, family_value = gd.paynt_call(dtmc_sketch)
    results['family_value'] = family_value

    results['family_size'] = gd.pomdp_sketch.family.size
    results['num_actions'] = gd.pomdp_sketch.num_actions
    results['num_observations'] = gd.pomdp_sketch.num_observations
    # Builds every family member in a fresh synthesizer (can be slow).
    results['max_num_states'] = get_max_number_of_states(env)

    return results

In [None]:
# Regenerate the per-environment statistics (slow: get_max_number_of_states builds every
# family member). Uncomment together with the next cell to rewrite ./statistics.pickle.
# statistics = {}
# for env in ENVS:
#     statistics[env] = get_statistics_and_random_policy_value(env)

In [None]:
# Persist the regenerated statistics; ./statistics.pickle is loaded near the top of the notebook.
# with open("./statistics.pickle", 'wb') as handle:
#     pickle.dump(statistics, handle)

# Heatmaps for all environments

In [None]:
# SAYNT-vs-ours heatmap for every environment (union baseline excluded).
for env, minimizing in zip(ENV_NAMES, MINIMIZING):
    create_heatmap(env, include_union_results=False, minimizing=minimizing)

# Illustrative example

In [None]:
# Results of the illustrative obstacles example (this path has no seed component,
# unlike the layout read by load_results).
EXAMPLE_DIR = "/".join([BASE_DIR, "obstacles-illustrative-2", str(SUBFAMILY_SIZE)])

In [None]:
from pomdp_families import POMDPFamiliesSynthesis
from config import ILLUSTRATIVE

# Synthesizer for the illustrative family; the same seed also drives subfamily sampling.
seed = 11
gd = POMDPFamiliesSynthesis(
    ILLUSTRATIVE,
    use_softmax=True,
    steps=1,
    learning_rate=0.01,
    seed=seed,
)
subfamily_assigments, hole_combinations = gd.stratified_subfamily_sampling(SUBFAMILY_SIZE, seed=seed)

In [None]:
# Our results on the illustrative subfamily (contains the learned FSC under 'fsc').
with open(f"{EXAMPLE_DIR}/ours-sparse.pickle", 'rb') as ours_file:
    ours = pickle.load(ours_file)

In [None]:
# The FSC stored in our illustrative results, and its number of nodes (displayed).
fsc = ours['fsc']
fsc.num_nodes

In [None]:
# Evaluate the learned FSC with gd.paynt_call_given_fsc (result displayed).
gd.paynt_call_given_fsc(fsc)

In [None]:
import copy
# Determinize a copy of the learned FSC. For every (node, observation) pair, keep only
# the key with the largest value (presumably the most probable action / memory update),
# reading the values from the untouched original `fsc`.
det_fsc = copy.deepcopy(fsc)
for node in range(det_fsc.num_nodes):
    for obs in range(det_fsc.num_observations):
        action_probs = fsc.action_function[node][obs]
        update_probs = fsc.update_function[node][obs]
        det_fsc.action_function[node][obs] = max(action_probs, key=action_probs.get)
        det_fsc.update_function[node][obs] = max(update_probs, key=update_probs.get)
det_fsc.is_deterministic = True

In [None]:
# Convert the determinized FSC back to the stochastic form before evaluating it with the
# same call used for the learned FSC (presumably paynt_call_given_fsc expects that form -- confirm).
det_fsc.make_stochastic()
gd.paynt_call_given_fsc(det_fsc)

In [None]:
# SAYNT baseline results on the illustrative subfamily.
with open(f"{EXAMPLE_DIR}/saynt.pickle", 'rb') as saynt_file:
    subfamily_saynt_results = pickle.load(saynt_file)

In [None]:
# PAYNT baseline results on the illustrative subfamily.
with open(f"{EXAMPLE_DIR}/paynt.pickle", 'rb') as paynt_file:
    subfamily_paynt_results = pickle.load(paynt_file)

In [None]:
# Gradient-descent baseline results on the illustrative subfamily.
with open(f"{EXAMPLE_DIR}/gradient.pickle", 'rb') as gradient_file:
    subfamily_gd_results = pickle.load(gradient_file)

In [None]:
make_heatmap(
    results=subfamily_paynt_results,
    our_results=ours,
    title="Illustrative Example",
    baseline="Paynt",
    minimizing=True,
)

In [None]:
make_heatmap(
    results=subfamily_saynt_results,
    our_results=ours,
    title="Illustrative Example",
    baseline="Saynt",
    minimizing=True,
)

In [None]:
make_heatmap(
    results=subfamily_gd_results,
    our_results=ours,
    title="Illustrative Example",
    baseline="GD",
    minimizing=True,
)

# Union baseline vs. ours

In [None]:
# Environments iterated by the union comparison below.
ENV_NAMES

In [None]:
# Print, per environment, the worst-case values of the union baseline, of ours trained on
# the subfamily, and of full rfPG on the whole family. Seed-SEED files are read using the
# same directory layout as load_results / create_table_multiple_seeds.
for env, minimizing in zip(ENV_NAMES, MINIMIZING):
    # NOTE(review): 'avoid' is skipped -- presumably it has no union results; confirm.
    if 'avoid' in env.lower(): continue

    with open(f"{BASE_DIR}/{env}/union/seed{SEED}/union.pickle", 'rb') as handle:
        results = pickle.load(handle)
    # 'subfamily' may be stored as a list or as a dict of values (see make_heatmap);
    # max/min over a dict would compare its keys, not its values.
    union_subfamily = results['subfamily']
    if isinstance(union_subfamily, dict):
        union_subfamily = list(union_subfamily.values())

    # One row: the subfamily members followed by the whole-family value.
    values = np.zeros((1, len(union_subfamily) + 1))
    values[0, :-1] = union_subfamily
    values[0, -1] = results['whole_family']

    with open(f"{BASE_DIR}/{env}/subfamsize{SUBFAMILY_SIZE}/seed{SEED}/subfam-ours.pickle", 'rb') as handle:
        our_results = pickle.load(handle)

    with open(f"{BASE_DIR}/{env}/seed{SEED}/gd-experiment.pickle", 'rb') as handle:
        rfpg_results = pickle.load(handle)

    print(rfpg_results['gd-normal'].keys())

    # Worst case over members: the largest value when minimizing, the smallest otherwise.
    worst = max if minimizing else min
    print(env.upper(), f"MINIMIZING={minimizing}", union_subfamily, f"UNION worst out of subfamily: {worst(union_subfamily)}", f"UNION Whole family worst: {results['whole_family']}", sep='\n')
    print(our_results['ours'], f"OURS worst out of subfamily: {worst(our_results['ours'])}", f"OURS Whole family worst: {our_results['whole_family']}", sep='\n')
    print(f"OURS FULL GD whole family worst: {rfpg_results['gd-normal']['best_worst_value']}")

In [None]:
# Union results of the last environment processed by the loop above (loop variables stay
# defined at notebook level).
results