In [None]:
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from config import *
import operator
import os

In [None]:
# Root directory of the pickled experiment outputs read by the cells below.
# Earlier locations are kept commented out so they can be switched back in.
# BASE_DIR='./outputs/second'
# BASE_DIR='./output-union'
# BASE_DIR='./bert/output-merged'
# BASE_DIR='./output-parallel-subfamily'
# BASE_DIR='./output-parallel-bert'
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
BASE_DIR='/opt/payntdev/verifai/output/parallel-full'
# Used both as the name of the per-environment results subdirectory
# (f"{BASE_DIR}/{env}/{SUBFAMILY_SIZE}/...") and as the size passed to
# stratified_subfamily_sampling further down.
SUBFAMILY_SIZE=10

In [None]:
# Directory that receives every figure saved by this notebook; the two
# subdirectories match the paths used by make_heatmap and make_lineplot.
OUTPUT_DIR = "./plot_builder_output-clean"
os.makedirs(f"{OUTPUT_DIR}/heatmaps", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/lineplot", exist_ok=True)

In [None]:
def make_heatmap(results : dict, our_results : dict, title : str, baseline : str, union_results : dict = None, minimizing = True, add_whole_family_value = True):
    """Draw a train/test heatmap of a baseline against our results and save it as PNG.

    Rows are, in order: one per entry of ``results``, then (if given) the
    union baseline, then our result. Columns are the test subfamily members,
    optionally followed by one "Entire family" column.

    Highlighting:
      * blue  - the worst value of each row among the subfamily columns
        (largest when ``minimizing``, smallest otherwise);
      * green - the blue cell whose value is best over all rows (on ties the
        later row wins), and the best cell of the "Entire family" column.

    Args:
        results: baseline results keyed by ``0 .. len(results) - 1``. Each value
            is indexed as ``item[-2]`` (a mapping whose values form the row; the
            trailing axis of length 1 is squeezed away), ``item[-1]`` (the
            whole-family value, or ``None``) and ``item[1]`` (an object exposing
            ``num_nodes``, used in the row label).
        our_results: mapping with the keys ``'ours'`` (row of subfamily values)
            and ``'whole_family'``.
        title: plot title prefix; also part of the output file name.
        baseline: baseline name used in the labels and in the output file name.
        union_results: optional mapping with the keys ``'subfamily'`` and
            ``'whole_family'``; adds a "Saynt on Union" row when truthy.
        minimizing: whether lower values are better.
        add_whole_family_value: whether to append the "Entire family" column.

    Side effects:
        Creates a new matplotlib figure and writes
        ``{OUTPUT_DIR}/heatmaps/{title}-{baseline}.png``.
    """
    values = np.array([list(item[-2].values()) for item in results.values()]).squeeze(axis=-1)
    # Non-finite baseline entries are shown as empty cells and ignored by the
    # nan-aware reductions below.
    values[~np.isfinite(values)] = np.nan
    if union_results:
        values = np.vstack([values, union_results['subfamily']])
    values = np.vstack([values, our_results['ours']])
    if add_whole_family_value:
        whole_family = [item[-1] if item[-1] is not None else np.nan for item in results.values()]
        if union_results:
            whole_family.append(union_results['whole_family'])
        whole_family.append(our_results['whole_family'])
        entire_family_values = np.array(whole_family)
        values = np.hstack([values, entire_family_values[..., None]])
    plt.figure(figsize=(16,9))

    colormap = sns.cm.rocket_r if minimizing else sns.cm.rocket

    ax = sns.heatmap(values, annot=True, vmin=np.nanmin(values), vmax=np.nanmax(values), cmap=colormap, fmt='.2f')

    subfamily_size = len(results)

    is_better = operator.le if minimizing else operator.ge
    best_value = np.inf if minimizing else -np.inf
    best_rectangle = None
    for r in range(values.shape[0]):
        row_values = values[r][:-1] if add_whole_family_value else values[r]
        if np.all(np.isnan(row_values)):
            # nanargmax / nanargmin raise ValueError on an all-NaN slice; a row
            # without any finite value simply gets no highlight.
            continue
        # Worst test value of this training row.
        idx = np.nanargmax(row_values) if minimizing else np.nanargmin(row_values)
        if is_better(row_values[idx], best_value):
            best_value = row_values[idx]
            best_rectangle = (idx, r)
        ax.add_patch(Rectangle((idx, r),1,1, fill=False, edgecolor='blue', lw=3))
    if best_rectangle is not None:
        ax.add_patch(Rectangle(best_rectangle,1,1, fill=False, edgecolor='green', lw=3))

    if add_whole_family_value and not np.all(np.isnan(values[:, -1])):
        best_family_idx = np.nanargmin(values[:, -1]) if minimizing else np.nanargmax(values[:, -1])
        ax.add_patch(Rectangle((subfamily_size, best_family_idx),1,1, fill=False, edgecolor='green', lw=3))

    ax.set_xlabel("Test")
    ax.set_ylabel("Train")
    plt.yticks(rotation=0)
    xticks = [f"{i}" for i in range(subfamily_size)]
    if add_whole_family_value:
        xticks += ["Entire family"]
    ax.set_xticklabels(xticks)
    ax.set_yticklabels([f"{baseline} on {i} ({results[i][1].num_nodes}-FSC)" for i in range(subfamily_size)] + (["Saynt on Union"] if union_results else []) + ["Ours: GD on (sub)family"])
    ax.set_title(f"{title}: {baseline} vs Ours ({'lower' if minimizing else 'higher'} is better)")
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/heatmaps/{title}-{baseline}.png")

In [None]:
import math
import random


def make_lineplot(results : dict, title : str = 'Placeholder', minimizing = True, type_of_plot = 'family_trace', use_time_x_axis=False):
    """Plot the 'gd-normal' and 'gd-random' traces, annotate each optimum and save the figure.

    Args:
        results: mapping with the entries ``'gd-normal'`` and ``'gd-random'``;
            each must provide ``'plot_times'`` and the series named by
            ``type_of_plot``. The mapping is left unmodified.
        title: plot title prefix; also the output file name (without ``.png``).
        minimizing: whether lower values are better; selects argmin vs argmax
            for the annotated optimum and the side on which labels are placed.
        type_of_plot: key of the series to draw from each run.
        use_time_x_axis: plot against ``'plot_times'`` instead of the sample
            index. The times of the random run are rescaled so that their
            maximum equals the maximum time of the normal run.

    Side effects:
        Prints the two maximum times, writes
        ``{OUTPUT_DIR}/lineplot/{title}.png`` and shows the figure.
    """
    normal_run = results['gd-normal']
    random_run = results['gd-random']
    normal_trace = normal_run[type_of_plot]
    random_trace = random_run[type_of_plot]

    max_time_normal = max(normal_run['plot_times'])
    max_time_random = max(random_run['plot_times'])
    # Rescale into a local array instead of writing back into `results`, so
    # the caller's data is not silently altered.
    random_times = (np.array(random_run['plot_times']) / max_time_random) * max_time_normal

    print(max_time_normal, max_time_random)

    fig = plt.figure()
    ax = fig.gca()
    plt.title(f"{title} ({'lower' if minimizing else 'higher'} is better)")
    plt.ylabel("Worst family member value")

    if use_time_x_axis:
        plt.xlabel("Time")
        ax.plot(normal_run['plot_times'], normal_trace, label='rfPG', color='green')
        ax.plot(random_times, random_trace, label='Random rfPG', color='red')
    else:
        # A stray exit() used to sit in this branch and shut the kernel down
        # before the second series was drawn.
        plt.xlabel("Iteration")
        ax.plot(normal_trace, label='rfPG', color='green')
        ax.plot(random_trace, label='Random rfPG', color='red')

    pick_best = np.argmin if minimizing else np.argmax

    best_idx_normal = pick_best(normal_trace)
    min_y_normal = normal_trace[best_idx_normal]
    min_x_normal = normal_run['plot_times'][best_idx_normal] if use_time_x_axis else best_idx_normal

    best_idx_random = pick_best(random_trace)
    min_y = random_trace[best_idx_random]
    min_x = random_times[best_idx_random] if use_time_x_axis else best_idx_random

    # Push the two labels apart horizontally: the optimum further to the right
    # gets +50 points, the other -50. When both optima share the same x the
    # split is fixed (rfPG right, random left) rather than drawn at random, so
    # the saved figure is reproducible and the labels never overlap.
    normal_is_right = math.isclose(min_x_normal, min_x) or min_x_normal > min_x
    offset_normal, offset_random = (50, -50) if normal_is_right else (-50, 50)
    vertical_offset = 50 if minimizing else -50

    ax.annotate(f"{min_y:.2f}",
                xy=(min_x, min_y), xycoords='data',
                xytext=(offset_random, vertical_offset), textcoords='offset points',
                arrowprops=dict(facecolor='red', shrink=0),
                horizontalalignment='center', verticalalignment='bottom')

    ax.annotate(f"{min_y_normal:.2f}",
                xy=(min_x_normal, min_y_normal), xycoords='data',
                xytext=(offset_normal, vertical_offset), textcoords='offset points',
                arrowprops=dict(facecolor='green', shrink=0),
                horizontalalignment='center', verticalalignment='bottom')

    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/lineplot/{title}.png")
    plt.show()

In [None]:
# Show which series are stored for the 'gd-normal' run.
# NOTE(review): in the visible notebook `results` is first assigned by the
# line-plot loop in a later cell, so on a fresh kernel (Restart & Run All)
# this cell raises NameError unless `config` happens to export that name.
results['gd-normal'].keys()

In [None]:
# Environment names (used as subdirectories of BASE_DIR) and, position by
# position, the `minimizing` flag passed to the plotting helpers for each one.
# The two lists are consumed together through zip(), so they must stay aligned.
ENVS = ['dpm', 'obstacles-10-2', 'avoid', 'obstacles-8-3', 'rover', 'network']
MINIMIZING = [False, True, True, True, False, False]

In [None]:
# Draw one time-axis line plot per environment from its gd-experiment pickle.
# Environments without such a file are reported and skipped.
for env, minimizing in zip(ENVS, MINIMIZING):
    try:
        with open(f"{BASE_DIR}/{env}/gd-experiment.pickle", 'rb') as handle:
            results = pickle.load(handle)
            # Plotting happens while the file handle is still open.
            make_lineplot(results, title=env, minimizing=minimizing, use_time_x_axis=True)
    except FileNotFoundError as fnfe:
        print(fnfe)

# DEFAULT FUNC

In [None]:
def create_heatmap(env : str, plot_gradient_baseline = False, include_union_results = False, **kwargs):
    """Load the pickled results of one environment and draw its heatmap(s).

    Always draws the "Saynt" heatmap; with ``plot_gradient_baseline`` the "GD"
    heatmap is drawn first. With ``include_union_results`` the union pickle is
    loaded as well; any error while loading it is printed and the union row
    is left out. Remaining keyword arguments are forwarded to ``make_heatmap``.
    """
    def read_pickle(path):
        # Unpickle a single results file.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    saynt = read_pickle(f"{BASE_DIR}/{env}/{SUBFAMILY_SIZE}/saynt.pickle")
    ours = read_pickle(f"{BASE_DIR}/{env}/{SUBFAMILY_SIZE}/ours.pickle")

    union_results = None
    if include_union_results:
        try:
            union_results = read_pickle(f"{BASE_DIR}/{env}/union/union.pickle")
        except Exception as e:
            # Best effort: report the problem and plot without the union row.
            print(e)

    heading = env.upper()
    if plot_gradient_baseline:
        gradient = read_pickle(f"{BASE_DIR}/{env}/{SUBFAMILY_SIZE}/gradient.pickle")
        make_heatmap(gradient, ours, heading, "GD", union_results=union_results, **kwargs)

    make_heatmap(saynt, ours, heading, "Saynt", union_results=union_results, **kwargs)


# ALL

In [None]:
# Saynt heatmap for every environment, including the union row when its
# pickle can be loaded.
for env, minimizing in zip(ENVS, MINIMIZING):
    create_heatmap(env, minimizing=minimizing, include_union_results=True)

# EXAMPLE

In [None]:
# Results directory of the illustrative example read by the cells below.
EXAMPLE_DIR = f"{BASE_DIR}/obstacles-illustrative-2/{SUBFAMILY_SIZE}"

In [None]:
# Build the synthesis object for the illustrative example and sample its
# subfamilies with the same seed.
# NOTE(review): these imports sit mid-notebook; `ILLUSTRATIVE` may already be
# in scope through `from config import *` in the first cell.
from pomdp_families import POMDPFamiliesSynthesis
from config import ILLUSTRATIVE

seed = 11
gd = POMDPFamiliesSynthesis(ILLUSTRATIVE, use_softmax=True, steps=1, learning_rate=0.01, seed=seed)
# Neither returned value is referenced by the other cells visible here.
subfamily_assigments, hole_combinations = gd.stratified_subfamily_sampling(SUBFAMILY_SIZE, seed=seed)

In [None]:
# Load our "sparse" result for the illustrative example; later cells read its
# 'fsc' entry and pass the whole object to make_heatmap.
with open(f"{EXAMPLE_DIR}/ours-sparse.pickle", 'rb') as handle:
    ours = pickle.load(handle)

In [None]:
# Take the controller stored under 'fsc' and display its node count.
fsc = ours['fsc']
fsc.num_nodes

In [None]:
# Pass the loaded controller to paynt_call_given_fsc and display what it returns.
gd.paynt_call_given_fsc(fsc)

In [None]:
import copy
# Build a deterministic variant of `fsc` on a deep copy, leaving `fsc` intact.
det_fsc = copy.deepcopy(fsc)
for node in range(det_fsc.num_nodes):
    for obs in range(det_fsc.num_observations):
        # max(d, key=d.get) keeps the key with the largest value of each
        # mapping - presumably the most probable action / successor node;
        # TODO confirm against the FSC class.
        det_fsc.action_function[node][obs] = max(fsc.action_function[node][obs], key=fsc.action_function[node][obs].get)        
        det_fsc.update_function[node][obs] = max(fsc.update_function[node][obs], key=fsc.update_function[node][obs].get)
det_fsc.is_deterministic = True

In [None]:
# Evaluate the determinised controller. make_stochastic() is called first -
# presumably because paynt_call_given_fsc expects the stochastic
# representation; TODO confirm.
# NOTE(review): re-running this cell calls make_stochastic() again on the
# already converted object.
det_fsc.make_stochastic()
gd.paynt_call_given_fsc(det_fsc)

In [None]:
# Load the pickled Saynt baseline results of the illustrative example.
with open(f"{EXAMPLE_DIR}/saynt.pickle", 'rb') as handle:
    subfamily_saynt_results = pickle.load(handle)

In [None]:
# Load the pickled Paynt baseline results of the illustrative example.
with open(f"{EXAMPLE_DIR}/paynt.pickle", 'rb') as handle:
    subfamily_paynt_results = pickle.load(handle)

In [None]:
# Load the pickled gradient baseline results of the illustrative example.
with open(f"{EXAMPLE_DIR}/gradient.pickle", 'rb') as handle:
    subfamily_gd_results = pickle.load(handle)

In [None]:
# Paynt baseline vs ours on the illustrative example. `ours` is the object
# loaded from ours-sparse.pickle; make_heatmap reads its 'ours' and
# 'whole_family' entries.
make_heatmap(subfamily_paynt_results, ours, "Illustrative Example", "Paynt", minimizing=True)

In [None]:
# Saynt baseline vs ours on the illustrative example.
make_heatmap(subfamily_saynt_results, ours, "Illustrative Example", "Saynt", minimizing=True)

In [None]:
# Gradient baseline vs ours on the illustrative example.
make_heatmap(subfamily_gd_results, ours, "Illustrative Example", "GD", minimizing=True)

# DPM

timeout = 10s

In [None]:
# Hard-coded list of ten values; not referenced by any other cell visible in
# this notebook. NOTE(review): provenance is not recorded here.
dpm_our = [504.76099479, 496.6463078,  522.29707964, 471.26597349, 483.96767624,
 463.09737765, 294.67001291, 498.28841493, 510.35243059, 509.71750883]

In [None]:
# Saynt heatmap for DPM (higher is better), without the union row.
create_heatmap('dpm', minimizing=False)

In [None]:
# Load the DPM results explicitly. This section has no load cells of its own,
# so on a top-to-bottom run the shared names `subfamily_saynt_results` and
# `ours` still hold the illustrative-example data, which would be plotted
# under the "DPM" title and written over DPM-Saynt.png. Cell-local names are
# used so that no other cell is affected.
with open(f"{BASE_DIR}/dpm/{SUBFAMILY_SIZE}/saynt.pickle", 'rb') as handle:
    dpm_saynt_results = pickle.load(handle)
with open(f"{BASE_DIR}/dpm/{SUBFAMILY_SIZE}/ours.pickle", 'rb') as handle:
    dpm_ours = pickle.load(handle)
make_heatmap(dpm_saynt_results, dpm_ours, "DPM", "Saynt", minimizing=False)

In [None]:
# Load the DPM results explicitly. This section has no load cells of its own,
# so on a top-to-bottom run the shared names `subfamily_gd_results` and `ours`
# still hold the illustrative-example data. The gradient pickle is read from
# the same location create_heatmap(..., plot_gradient_baseline=True) uses.
with open(f"{BASE_DIR}/dpm/{SUBFAMILY_SIZE}/gradient.pickle", 'rb') as handle:
    dpm_gd_results = pickle.load(handle)
with open(f"{BASE_DIR}/dpm/{SUBFAMILY_SIZE}/ours.pickle", 'rb') as handle:
    dpm_ours = pickle.load(handle)
make_heatmap(dpm_gd_results, dpm_ours, "DPM", "GD", minimizing=False)

# OBSTACLES 8 3

timeout = 30s

In [None]:
# Load the pickled Saynt baseline results for obstacles-8-3.
with open(f"{BASE_DIR}/obstacles-8-3/{SUBFAMILY_SIZE}/saynt.pickle", 'rb') as handle:
    saynt = pickle.load(handle)

In [None]:
# Load the pickled gradient baseline results for obstacles-8-3.
with open(f"{BASE_DIR}/obstacles-8-3/{SUBFAMILY_SIZE}/gradient.pickle", 'rb') as handle:
    gradient = pickle.load(handle)

In [None]:
# Load our pickled results for obstacles-8-3 (rebinds the shared name `ours`).
with open(f"{BASE_DIR}/obstacles-8-3/{SUBFAMILY_SIZE}/ours.pickle", 'rb') as handle:
    ours = pickle.load(handle)

In [None]:
# Saynt baseline vs ours; relies on make_heatmap's default minimizing=True.
make_heatmap(saynt, ours, "OBSTACLES(8,3)", "Saynt")

## Gradient baseline

In [None]:
# Gradient baseline vs ours; relies on make_heatmap's default minimizing=True.
make_heatmap(gradient, ours, "OBSTACLES(8,3)", "GD")

# OBSTACLES 10 2

timeout = 10s

In [None]:
# Load our pickled results for obstacles-10-2 (rebinds the shared name `ours`).
with open(f"{BASE_DIR}/obstacles-10-2/{SUBFAMILY_SIZE}/ours.pickle", 'rb') as handle:
    ours = pickle.load(handle)

In [None]:
# Load the pickled Saynt baseline results for obstacles-10-2.
with open(f"{BASE_DIR}/obstacles-10-2/{SUBFAMILY_SIZE}/saynt.pickle", 'rb') as handle:
    subfamily_saynt_results = pickle.load(handle)

In [None]:
# Load the pickled gradient baseline results for obstacles-10-2.
with open(f"{BASE_DIR}/obstacles-10-2/{SUBFAMILY_SIZE}/gradient.pickle", 'rb') as handle:
    subfamily_gd_results = pickle.load(handle)

In [None]:
# Hard-coded list of ten values; not referenced by any other cell visible in
# this notebook. NOTE(review): provenance is not recorded here.
ours_raw = [27.85853738, 28.71888006, 30.78940948, 28.75065993, 29.3297606,  30.77959397,
 29.48840814, 29.3297606,  28.22878829, 30.14997737]

In [None]:
# Saynt baseline vs ours; relies on make_heatmap's default minimizing=True.
make_heatmap(subfamily_saynt_results, ours, "OBSTACLES(10,2)", "Saynt")

In [None]:
# Gradient baseline vs ours; relies on make_heatmap's default minimizing=True.
make_heatmap(subfamily_gd_results, ours, "OBSTACLES(10,2)", "GD")

# AVOID

timeout = 60s

In [None]:
# Load our pickled results for avoid (rebinds the shared name `ours`).
with open(f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/ours.pickle", 'rb') as handle:
    ours = pickle.load(handle)

In [None]:
# Display the loaded object in full for inspection.
ours

In [None]:
# with open(f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/paynt.pickle", 'rb') as handle:
#     subfamily_paynt_results = pickle.load(handle)

In [None]:
# Load the pickled gradient baseline results for avoid.
with open(f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/gradient.pickle", 'rb') as handle:
    subfamily_gd_results = pickle.load(handle)

In [None]:
# Load the pickled Saynt baseline results for avoid.
with open(f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/saynt.pickle", 'rb') as handle:
    subfamily_saynt_results = pickle.load(handle)

In [None]:
# Hard-coded list of ten values; not referenced by any other cell visible in
# this notebook. NOTE(review): the single-letter name `l` says nothing about
# what the numbers are, and their provenance is not recorded here.
l = [51.94674404, 65.42364497, 63.76348781, 98.68745545, 32.52550974, 67.09607064,
 87.41523619, 64.81453459, 53.95332456, 56.50783793]

In [None]:
# Saynt baseline vs ours; relies on make_heatmap's default minimizing=True.
make_heatmap(subfamily_saynt_results, ours, "AVOID", "Saynt")

In [None]:
# make_heatmap(subfamily_paynt_results, ours, "AVOID", "Paynt")

In [None]:
# Gradient baseline vs ours; relies on make_heatmap's default minimizing=True.
make_heatmap(subfamily_gd_results, ours, "AVOID", "GD")

# UNIONS

In [None]:
# Display the environment list that the loop below iterates over.
ENVS

In [None]:
# Text summary per environment ('avoid' is skipped): union baseline values,
# our subfamily results and the best worst-case value of the full rfPG run.
# ENVS = ['dpm', 'network', 'obstacles-8-3', 'obstacles-10-2', 'rover']

for env, minimizing in zip(ENVS, MINIMIZING):
    if 'avoid' in env.lower(): continue
    # print(env)
    with open(f"{BASE_DIR}/{env}/union/union.pickle", 'rb') as handle:
        results = pickle.load(handle)
    # One row: the subfamily values followed by the whole-family value. The
    # width follows the loaded data instead of a hard-coded 11, so the cell
    # keeps working for other subfamily sizes.
    values = np.zeros((1, np.size(results['subfamily']) + 1))
    values[0, :-1] = results['subfamily']
    values[0, -1] = results['whole_family']

    with open(f"{BASE_DIR}/{env}/{SUBFAMILY_SIZE}/ours.pickle", 'rb') as handle:
        our_results = pickle.load(handle)

    with open(f"{BASE_DIR}/{env}/gd-experiment.pickle", 'rb') as handle:
        rfpg_results = pickle.load(handle)

    print(rfpg_results['gd-normal'].keys())

    # "Worst" is the maximum when minimizing and the minimum when maximizing.
    print(env.upper(), f"MINIMIZING={minimizing}", results['subfamily'], f"UNION worst out of subfamily: {max(results['subfamily']) if minimizing else min(results['subfamily'])}", f"UNION Whole family worst: {results['whole_family']}", sep='\n')
    print(our_results['ours'], f"OURS worst out of subfamily: {max(our_results['ours']) if minimizing else min(our_results['ours'])}", f"OURS Whole family worst: {our_results['whole_family']}", sep='\n')
    print(f"OURS FULL GD whole family worst: {rfpg_results['gd-normal']['best_worst_value']}")


    # `values` is only consumed by the optional single-row heatmap below.
    # plt.figure()
    # sns.heatmap(values, yticklabels=[env], annot=True, vmin=np.nanmin(values), vmax=np.nanmax(values), cmap=sns.cm.rocket_r if minimizing else sns.cm.rocket, mask=~np.isfinite(values), fmt='.2f')
    # plt.show()

In [None]:
# Display the union results left over from the last iteration of the loop above.
results