In [None]:
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import operator
import os

In [None]:
# Root directory of the experiment outputs read by every cell below.
# Alternative roots used before: './outputs/second', './bert/output-merged'.
BASE_DIR = './output-union'
# Subfamily size; per-subfamily pickles are read from {BASE_DIR}/<env>/{SUBFAMILY_SIZE}/.
SUBFAMILY_SIZE = 5

In [None]:
OUTPUT_DIR = "./plot_builder_output"
# Create the figure subdirectories up front; re-running this cell is harmless (exist_ok).
for _figure_kind in ("heatmaps", "lineplot"):
    os.makedirs(f"{OUTPUT_DIR}/{_figure_kind}", exist_ok=True)

In [None]:
def make_heatmap(results : dict, our_results : dict, title : str, baseline : str, minimizing = True, add_whole_family_value = True):
    """Plot a train-vs-test heatmap of a per-member baseline against our (sub)family controller.

    Row ``i`` is the baseline trained on subfamily member ``i`` and evaluated on every
    member (columns). An extra last row holds our results, and optionally an extra last
    column holds each controller's value on the entire family.

    Outlines:
    - Blue marks each row's worst member value: the max when ``minimizing``, the min otherwise.
    - Green marks the row whose worst value is best. Ties go to the later row, i.e. towards ours.
    - When the entire-family column is shown, its best entry is also outlined in green.

    Args:
        results: baseline results, iterated in insertion order.
            - ``item[-2]`` is a dict of per-test-member values. It is squeezed along the
              last axis, so presumably each value is a length-1 sequence.
            - ``item[-1]`` is the entire-family value or ``None``.
            - ``item[1]`` is a controller exposing ``num_nodes``.
        our_results: dict with ``'ours'`` (one value per test member) and ``'whole_family'``.
        title: plot title; also part of the output file name.
        baseline: baseline name used in labels and in the output file name.
        minimizing: whether lower values are better (also picks the colormap direction).
        add_whole_family_value: append the entire-family column.

    Non-finite values (inf, nan, None) are shown as masked cells and ignored when marking.
    The figure is saved to ``{OUTPUT_DIR}/heatmaps/{title}-{baseline}.png``.
    """
    values = np.array([list(item[-2].values()) for item in results.values()]).squeeze(axis=-1)
    values = np.vstack([values, our_results['ours']])
    if add_whole_family_value:
        entire_family_values = np.array(
            [item[-1] if item[-1] is not None else np.nan for item in results.values()]
            + [our_results['whole_family']]
        )
        values = np.hstack([values, entire_family_values[..., None]])
    # Sanitize *after* stacking, so our row and the entire-family column get the same
    # treatment as the baseline rows. A stray inf would otherwise make vmin/vmax infinite
    # and hijack nanargmax/nanargmin. astype(float) also turns a None into nan.
    values = values.astype(float)
    values[~np.isfinite(values)] = np.nan
    # Number of per-member test columns (excludes the entire-family column).
    n_test = values.shape[1] - 1 if add_whole_family_value else values.shape[1]

    plt.figure(figsize=(16,9))
    ax = sns.heatmap(values, annot=True, vmin=np.nanmin(values), vmax=np.nanmax(values),
                     cmap=sns.cm.rocket_r if minimizing else sns.cm.rocket,
                     mask=~np.isfinite(values), fmt='.2f')

    worst_of = np.nanargmax if minimizing else np.nanargmin
    best_of = np.nanargmin if minimizing else np.nanargmax
    at_least_as_good = operator.le if minimizing else operator.ge

    best_value = np.inf if minimizing else -np.inf
    best_rectangle = None
    for r in range(values.shape[0]):
        row_values = values[r, :n_test]
        if np.isnan(row_values).all():
            continue  # nanarg* raises on an all-NaN slice; nothing to mark in this row
        idx = worst_of(row_values)
        if at_least_as_good(row_values[idx], best_value):
            best_value = row_values[idx]
            best_rectangle = (idx, r)
        ax.add_patch(Rectangle((idx, r), 1, 1, fill=False, edgecolor='blue', lw=3))
    if best_rectangle is not None:
        ax.add_patch(Rectangle(best_rectangle, 1, 1, fill=False, edgecolor='green', lw=3))

    if add_whole_family_value and not np.isnan(values[:, -1]).all():
        best_family_idx = best_of(values[:, -1])
        # The entire-family column sits right after the n_test member columns.
        ax.add_patch(Rectangle((n_test, best_family_idx), 1, 1, fill=False, edgecolor='green', lw=3))

    ax.set_xlabel("Test")
    ax.set_ylabel("Train")
    plt.yticks(rotation=0)
    xticks = [f"{i}" for i in range(n_test)]
    if add_whole_family_value:
        xticks += ["Entire family"]
    ax.set_xticklabels(xticks)
    # Build the row labels from the same iteration that built the rows. This keeps labels
    # aligned with rows even if the dict keys were not inserted in 0..n-1 order.
    ax.set_yticklabels([f"{baseline} on {train} ({item[1].num_nodes}-FSC)" for train, item in results.items()]
                       + ["Ours: GD on (sub)family"])
    ax.set_title(f"{title}: {baseline} vs Ours ({'lower' if minimizing else 'higher'} is better)")
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/heatmaps/{title}-{baseline}.png")

In [None]:
import math
import random


def make_lineplot(results : dict, title : str = 'Placeholder', minimizing = True, type_of_plot = 'family_trace'):
    """Plot the per-iteration trace of each rfPG variant and annotate each trace's best point.

    Args:
        results: dict with entries 'gd-normal', 'gd-no-momentum' and 'gd-random'. Each entry
            holds a per-iteration sequence under ``type_of_plot``.
        title: plot title; also the output file stem.
        minimizing: whether lower is better (argmin vs argmax picks the annotated point).
        type_of_plot: key of the trace to plot inside each variant's entry.

    The figure is saved to ``{OUTPUT_DIR}/lineplot/{title}.png`` and then shown.
    """
    # (results key, legend label, colour), in plotting/legend order.
    variants = [
        ('gd-normal', 'rfPG', 'green'),
        ('gd-no-momentum', 'rfPG (no momentum)', 'blue'),
        ('gd-random', 'Random rfPG', 'red'),
    ]
    best_index = np.argmin if minimizing else np.argmax

    fig, ax = plt.subplots()
    ax.set_title(f"{title} ({'lower' if minimizing else 'higher'} is better)")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Worst family member value")

    best_points = []  # (iteration, value, colour) of each variant's best point
    for key, label, color in variants:
        trace = results[key][type_of_plot]
        ax.plot(trace, label=label, color=color)
        x = int(best_index(trace))
        best_points.append((x, trace[x], color))

    # Spread the labels horizontally by the rank of their iteration: -50, 0 and +50 points
    # for three variants. This is deterministic (no random jitter) and keeps labels apart
    # even when several variants peak at the same iteration. sorted() is stable, so ties
    # keep the variant order above.
    ranked = sorted(best_points, key=lambda point: point[0])
    centre = (len(ranked) - 1) / 2
    y_offset = 50 if minimizing else -50
    for rank, (x, y, color) in enumerate(ranked):
        ax.annotate(f"{y:.2f}",
                    xy=(x, y), xycoords='data',
                    xytext=(50 * (rank - centre), y_offset), textcoords='offset points',
                    arrowprops=dict(facecolor=color, shrink=0),
                    horizontalalignment='center', verticalalignment='bottom')

    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{OUTPUT_DIR}/lineplot/{title}.png")
    plt.show()

In [None]:
# Which traces does a gd-experiment pickle hold? These are the valid `type_of_plot` values
# for make_lineplot. One pickle is loaded explicitly here, instead of reading `results`,
# which is only bound by a later cell and so breaks Restart & Run All.
with open(f"{BASE_DIR}/dpm/gd-experiment.pickle", 'rb') as handle:
    gd_experiment_example = pickle.load(handle)
gd_experiment_example['gd-normal'].keys()

In [None]:
# Objective direction per environment: True = lower is better, False = higher is better.
# A single mapping keeps the two parallel lists below aligned.
_ENV_MINIMIZING = {
    'dpm': False,
    'obstacles-10-2': True,
    'avoid': True,
    'obstacles-8-3': True,
    'rover': False,
    'network': False,
}
ENVS = list(_ENV_MINIMIZING)
MINIMIZING = list(_ENV_MINIMIZING.values())

In [None]:
# One convergence line plot per environment. Environments without a gd-experiment pickle
# are reported and skipped.
for env, minimizing in zip(ENVS, MINIMIZING):
    # Guard only the load. A FileNotFoundError raised while plotting or saving must not be
    # mistaken for a missing experiment.
    try:
        with open(f"{BASE_DIR}/{env}/gd-experiment.pickle", 'rb') as handle:
            results = pickle.load(handle)
    except FileNotFoundError as fnfe:
        print(fnfe)
        continue
    make_lineplot(results, title=env, minimizing=minimizing)

# DPM

timeout = 10s

In [None]:
# NOTE(review): hard-coded DPM values; no cell in the visible notebook reads `dpm_our`.
dpm_our = [
    504.76099479, 496.6463078, 522.29707964, 471.26597349, 483.96767624,
    463.09737765, 294.67001291, 498.28841493, 510.35243059, 509.71750883,
]

In [None]:
# Our (sub)family GD results for DPM (read by make_heatmap as 'ours' / 'whole_family').
pickle_path = f"{BASE_DIR}/dpm/{SUBFAMILY_SIZE}/ours.pickle"
with open(pickle_path, mode='rb') as handle:
    ours = pickle.load(handle)

In [None]:
# SAYNT baseline results for DPM, one entry per subfamily member.
pickle_path = f"{BASE_DIR}/dpm/{SUBFAMILY_SIZE}/saynt.pickle"
with open(pickle_path, mode='rb') as handle:
    subfamily_saynt_results = pickle.load(handle)

In [None]:
# Dump each SAYNT controller: the member index and node count, then its update_function
# entry for every node.
for member in range(SUBFAMILY_SIZE):
    controller = subfamily_saynt_results[member][1]
    print(member, controller.num_nodes)
    for node in range(controller.num_nodes):
        print(controller.update_function[node])

In [None]:
# Per-member GD baseline results for DPM.
pickle_path = f"{BASE_DIR}/dpm/{SUBFAMILY_SIZE}/gradient.pickle"
with open(pickle_path, mode='rb') as handle:
    subfamily_gd_results = pickle.load(handle)

In [None]:
# Inspect our DPM results (make_heatmap reads the 'ours' and 'whole_family' entries).
ours

In [None]:
# DPM is a maximization objective (higher is better).
make_heatmap(subfamily_saynt_results, ours, title="DPM", baseline="Saynt", minimizing=False)

In [None]:
# Same comparison against the per-member GD baseline.
make_heatmap(subfamily_gd_results, ours, title="DPM", baseline="GD", minimizing=False)

# OBSTACLES 8 3

timeout = 30s

In [None]:
# SAYNT baseline results for OBSTACLES(8,3).
pickle_path = f"{BASE_DIR}/obstacles-8-3/{SUBFAMILY_SIZE}/saynt.pickle"
with open(pickle_path, mode='rb') as handle:
    saynt = pickle.load(handle)

In [None]:
# Per-member GD baseline results for OBSTACLES(8,3).
pickle_path = f"{BASE_DIR}/obstacles-8-3/{SUBFAMILY_SIZE}/gradient.pickle"
with open(pickle_path, mode='rb') as handle:
    gradient = pickle.load(handle)

In [None]:
# Our (sub)family GD results for OBSTACLES(8,3).
pickle_path = f"{BASE_DIR}/obstacles-8-3/{SUBFAMILY_SIZE}/ours.pickle"
with open(pickle_path, mode='rb') as handle:
    ours = pickle.load(handle)

In [None]:
# OBSTACLES is a minimization objective (lower is better).
make_heatmap(saynt, ours, title="OBSTACLES(8,3)", baseline="Saynt", minimizing=True)

## Gradient baseline

In [None]:
make_heatmap(gradient, ours, title="OBSTACLES(8,3)", baseline="GD", minimizing=True)

# OBSTACLES 10 2

timeout = 10s

In [None]:
# Our (sub)family GD results for OBSTACLES(10,2).
pickle_path = f"{BASE_DIR}/obstacles-10-2/{SUBFAMILY_SIZE}/ours.pickle"
with open(pickle_path, mode='rb') as handle:
    ours = pickle.load(handle)

In [None]:
# SAYNT baseline results for OBSTACLES(10,2).
pickle_path = f"{BASE_DIR}/obstacles-10-2/{SUBFAMILY_SIZE}/saynt.pickle"
with open(pickle_path, mode='rb') as handle:
    subfamily_saynt_results = pickle.load(handle)

In [None]:
print("Obstacles(10, 2):")
# One line per subfamily member: the SAYNT controller's node count and memory model.
# (The previous cell pattern `print(controller.update_function[node])` dumps full controllers.)
for member in range(SUBFAMILY_SIZE):
    controller = subfamily_saynt_results[member][1]
    print("POMDP", member, 'with a', f"{controller.num_nodes}-FSC. Memory model:", controller.memory_model)

In [None]:
# Per-member GD baseline results for OBSTACLES(10,2).
pickle_path = f"{BASE_DIR}/obstacles-10-2/{SUBFAMILY_SIZE}/gradient.pickle"
with open(pickle_path, mode='rb') as handle:
    subfamily_gd_results = pickle.load(handle)

In [None]:
# NOTE(review): hard-coded OBSTACLES(10,2) values; no cell in the visible notebook reads `ours_raw`.
ours_raw = [
    27.85853738, 28.71888006, 30.78940948, 28.75065993, 29.3297606,
    30.77959397, 29.48840814, 29.3297606, 28.22878829, 30.14997737,
]

In [None]:
# Inspect our OBSTACLES(10,2) results (make_heatmap reads 'ours' and 'whole_family').
ours

In [None]:
# OBSTACLES is a minimization objective (lower is better).
make_heatmap(subfamily_saynt_results, ours, title="OBSTACLES(10,2)", baseline="Saynt", minimizing=True)

In [None]:
make_heatmap(subfamily_gd_results, ours, title="OBSTACLES(10,2)", baseline="GD", minimizing=True)

# AVOID

timeout = 60s

In [None]:
# Our (sub)family GD results for AVOID.
pickle_path = f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/ours.pickle"
with open(pickle_path, mode='rb') as handle:
    ours = pickle.load(handle)

In [None]:
# Inspect our AVOID results (make_heatmap reads 'ours' and 'whole_family').
ours

In [None]:
# PAYNT baseline for AVOID is disabled; its heatmap cell further down is commented out as well.
# with open(f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/paynt.pickle", 'rb') as handle:
#     subfamily_paynt_results = pickle.load(handle)

In [None]:
# Per-member GD baseline results for AVOID.
pickle_path = f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/gradient.pickle"
with open(pickle_path, mode='rb') as handle:
    subfamily_gd_results = pickle.load(handle)

In [None]:
# SAYNT baseline results for AVOID.
pickle_path = f"{BASE_DIR}/avoid/{SUBFAMILY_SIZE}/saynt.pickle"
with open(pickle_path, mode='rb') as handle:
    subfamily_saynt_results = pickle.load(handle)

In [None]:
# Raw update_function of each GD-trained AVOID controller, one line per member.
for member in range(SUBFAMILY_SIZE):
    controller = subfamily_gd_results[member][1]
    print(member, controller.update_function)

In [None]:
print("AVOID:")
# One line per subfamily member: the SAYNT controller's node count and memory model.
for member in range(SUBFAMILY_SIZE):
    controller = subfamily_saynt_results[member][1]
    print("POMDP", member, 'with a', f"{controller.num_nodes}-FSC. Memory model:", controller.memory_model)

In [None]:
# NOTE(review): hard-coded AVOID values under a single-letter name; no cell in the visible
# notebook reads `l`.
l = [
    51.94674404, 65.42364497, 63.76348781, 98.68745545, 32.52550974,
    67.09607064, 87.41523619, 64.81453459, 53.95332456, 56.50783793,
]

In [None]:
# AVOID is a minimization objective (lower is better).
make_heatmap(subfamily_saynt_results, ours, title="AVOID", baseline="Saynt", minimizing=True)

In [None]:
# Disabled together with the PAYNT pickle load above.
# make_heatmap(subfamily_paynt_results, ours, "AVOID", "Paynt")

In [None]:
make_heatmap(subfamily_gd_results, ours, title="AVOID", baseline="GD", minimizing=True)

# UNIONS

In [None]:
# Union experiments: for each env, print the per-member values, the worst of them and the
# entire-family value. AVOID is skipped (presumably no union run exists for it; TODO confirm).
for env, minimizing in zip(ENVS, MINIMIZING):
    if 'avoid' in env.lower():
        continue
    try:
        with open(f"{BASE_DIR}/{env}/union/union.pickle", 'rb') as handle:
            results = pickle.load(handle)
    except FileNotFoundError as fnfe:
        # Same policy as the line-plot loop: report the missing run and keep going.
        print(fnfe)
        continue
    # Single row: per-member values followed by the entire-family value. The array is sized
    # from the data rather than a hard-coded 11, so any subfamily size fits.
    values = np.array([[*results['subfamily'], results['whole_family']]], dtype=float)

    # Worst member: the max when minimizing, the min otherwise.
    worst_member = max(results['subfamily']) if minimizing else min(results['subfamily'])
    print(env, results['subfamily'], worst_member, results['whole_family'], sep='\n')
    # To visualise `values` as a one-row heatmap:
    # plt.figure()
    # sns.heatmap(values, yticklabels=[env], annot=True, vmin=np.nanmin(values), vmax=np.nanmax(values), cmap=sns.cm.rocket_r if minimizing else sns.cm.rocket, mask=~np.isfinite(values), fmt='.2f')
    # plt.show()