In [None]:
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn
import torch

In [None]:
# Directory holding the experiment runs: a `runs` folder next to the
# notebook's working directory.
runs_dir = Path.cwd().parent / 'runs'
# is_dir() is already False for a missing path, so it covers the existence check.
assert runs_dir.is_dir()
# List every entry so the available runs are visible in the cell output.
for run_path in runs_dir.iterdir():
    print(run_path)

In [None]:
from typing import Dict, List
import sys
sys.path.insert(0, str(Path.cwd().parent))
import scripts

import matplotlib as mpl
from matplotlib.ticker import FuncFormatter

def lr_schedule(current: int, last_batch: int):
    """Return the piecewise-constant learning rate for a given batch index.

    Args:
        current: Index of the current batch.
        last_batch: Index of the final batch; the two switch points are
            expressed as fractions of this value.

    Returns:
        1e-2 while ``current`` is below 10% of ``last_batch``, 1e-3 while it is
        below 75%, and 1e-4 afterwards.
    """
    first_boundary = 0.1 * last_batch
    second_boundary = 0.75 * last_batch
    if current < first_boundary:
        return 1e-2
    if current < second_boundary:
        return 1e-3
    return 1e-4

# Base font size for plot text; individual elements use small offsets from it.
FONT_SIZE = 22

# Plot color per method, keyed by the name that is also used as the legend label.
# NOTE(review): the hex values look like the "fivethirtyeight" style palette -- confirm.
FIVE_THIRTY_EIGHT = dict(
    Original="#000000",
    SDN="#30a2da",
    PABEE="#fc4f30",
    ZTW="#e5ae38",
    Stacking="#6d904f",
    Ensembling="#810f7c",
)


def running_average(xs: List[float], alpha: float):
    """Exponentially smooth a sequence of values.

    The first output equals the first input; every later output is
    ``alpha * x + (1 - alpha) * previous_output``.

    Args:
        xs: Values to smooth, in order.
        alpha: Weight given to the new value at each step.

    Returns:
        A list with the same length as ``xs``. An empty input yields an empty
        list instead of raising ``IndexError``.
    """
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(xs) == 0:
        return []
    smoothed = [xs[0]]
    for x in xs[1:]:
        smoothed.append(alpha * x + (1 - alpha) * smoothed[-1])
    return smoothed

def choices_to_cost(chosen_ics: Dict[int, int], total_ops: Dict[int, torch.Tensor]) -> float:
    """Compute the count-weighted average of per-key operation costs.

    Args:
        chosen_ics: Maps a key to the number of samples assigned to it.
        total_ops: Maps the same keys to a cost object exposing ``.item()``.

    Returns:
        Sum of ``count * cost`` over all keys divided by the total count.

    Raises:
        ZeroDivisionError: If the counts sum to zero (including an empty
            ``chosen_ics``).
    """
    sample_count = 0
    weighted_ops = 0.0
    for ic_key in chosen_ics:
        n_chosen = chosen_ics[ic_key]
        sample_count += n_chosen
        weighted_ops += n_chosen * total_ops[ic_key].item()
    return weighted_ops / sample_count

def mark_orig_result(data: Dict, total_ops: Dict[int, torch.Tensor], ax: "matplotlib.axes.Axes", name: str, color):
    """Draw the single-point result of a run that has exactly one entry.

    Args:
        data: Mapping with exactly one key; the value is indexed as
            ``value[0]`` (plotted as the mean return) and ``value[1]``
            (plotted as the error bar).
        total_ops: Cost of the run. For a dict or ``np.ndarray`` the entry at
            index/key ``0`` is used; anything else is treated as the scalar
            cost itself.
        ax: Axes to draw on. The annotation is quoted so it is not evaluated
            when the function is defined; ``matplotlib.axes.SubplotBase`` no
            longer exists in recent Matplotlib releases.
        name: Legend label; also used in the printed summary.
        color: Color of the error bar. The marker itself is always black.

    Returns:
        The cost of the run as a numpy float scalar.
    """
    ts = sorted(data.keys())
    assert len(ts) == 1
    # Dicts and arrays are both indexed at 0; other inputs are already scalars.
    if isinstance(total_ops, (np.ndarray, dict)):
        raw_cost = total_ops[0]
    else:
        raw_cost = total_ops
    # float() accepts Python/numpy scalars and single-element tensors alike.
    cost = np.array([float(raw_cost)])
    ret_mean = np.array([data[t][0] for t in ts])
    ret_std = np.array([data[t][1] for t in ts])
    print(f'name: {name} cost: {cost} ret_mean: {ret_mean} ret_std: {ret_std}')
    print(f'name: {name} mean cost: {cost.mean()} averaged mean return: {ret_mean.mean()} averaged std: {ret_std.mean()}')
    ax.scatter(cost, ret_mean, marker='X', label=name, color='black', s=250, zorder=3, linewidths=0.)
    ax.errorbar(cost, ret_mean, yerr=ret_std, ecolor=color, alpha=0.5)
    return cost[0]

def draw_time_return_for_thresholds(data: Dict, total_ops: Dict[int, torch.Tensor], ax: "matplotlib.axes.Axes", name: str, color, smoothing_alpha: float = 0.25):
    """Plot a smoothed return-vs-cost curve with a shaded band.

    For every key of ``data`` the cost is computed with ``choices_to_cost``
    from ``data[key][2]``; the points are ordered by that cost, and
    ``data[key][0]`` / ``data[key][1]`` are smoothed with ``running_average``
    before being drawn as a line and as a band around it.

    Args:
        data: Mapping whose values are indexed as ``value[0]`` (plotted as the
            mean return), ``value[1]`` (plotted as the band half-width) and
            ``value[2]`` (passed to ``choices_to_cost`` as ``chosen_ics``).
        total_ops: Per-key costs, passed to ``choices_to_cost``.
        ax: Axes to draw on. The annotation is quoted so it is not evaluated
            when the function is defined; ``matplotlib.axes.SubplotBase`` no
            longer exists in recent Matplotlib releases.
        name: Legend label; also used in the printed summary.
        color: Color of the line and of the band.
        smoothing_alpha: Weight of the new value in ``running_average``.
            Defaults to 0.25, the value that used to be hard-coded.
    """
    ts = sorted(data.keys())
    # sorted() is stable, so points with equal cost keep their key order.
    points = sorted(
        ((choices_to_cost(data[t][2], total_ops), data[t][0], data[t][1]) for t in ts),
        key=lambda point: point[0],
    )
    cost = np.array([point[0] for point in points])
    ret_mean = np.array(running_average([point[1] for point in points], alpha=smoothing_alpha))
    ret_std = np.array(running_average([point[2] for point in points], alpha=smoothing_alpha))
    print(f'name: {name} mean cost: {cost.mean()} averaged mean return: {ret_mean.mean()} averaged std: {ret_std.mean()}')
    ax.plot(cost, ret_mean, label=name, color=color)
    ax.fill_between(cost, ret_mean - ret_std, ret_mean + ret_std, alpha=0.3, color=color)

def draw_time_return_for_ics(data: List, total_ops: Dict[int, torch.Tensor], ax: "matplotlib.axes.Axes", name: str, color):
    """Scatter one marker per entry of ``data`` except the last one.

    Entry ``i`` is drawn at x = ``total_ops[i]`` and y = ``data[i][0]`` with an
    error bar of ``data[i][1]``.

    Args:
        data: Sequence of entries indexed as ``entry[0]`` / ``entry[1]``. The
            final entry is skipped -- presumably it belongs to the full
            network rather than to an internal classifier; TODO confirm.
        total_ops: Cost per index. Values may be Python/numpy scalars or
            single-element tensors; each is converted with ``float()``.
        ax: Axes to draw on. The annotation is quoted so it is not evaluated
            when the function is defined; ``matplotlib.axes.SubplotBase`` no
            longer exists in recent Matplotlib releases.
        name: Method name; the legend label is ``f'{name} IC'``.
        color: Marker fill color and error-bar color.
    """
    print(f'data: {data}')
    print(f'total_ops: {total_ops}')
    # float() keeps this consistent with choices_to_cost, which unwraps the
    # same values with .item() before doing arithmetic on them.
    cost = np.array([float(total_ops[i]) for i in range(len(data) - 1)])
    ret_mean = np.array([entry[0] for entry in data[:-1]])
    ret_std = np.array([entry[1] for entry in data[:-1]])
    ax.scatter(cost, ret_mean, marker='X', label=f'{name} IC', color=color, s=200, zorder=3, edgecolors='black', linewidths=1)
    ax.errorbar(cost, ret_mean, yerr=ret_std, ecolor=color, alpha=0.5, fmt='none')

def y_fmt(y, pos):
    """Format a tick value with a metric suffix (G, M, k, m, u, n).

    Written for ``matplotlib.ticker.FuncFormatter``. The value is divided by
    the largest decade in the table that does not exceed ``|y|``, and printed
    with as many decimals as ``str()`` shows for the scaled value.

    Args:
        y: Tick value.
        pos: Tick position; unused, required by the formatter interface.

    Returns:
        A string such as ``'2 k'`` or ``'1.5 M'``. The number and the suffix
        are separated by a space, so an empty suffix leaves a trailing space
        (``'5 '``). Values with no matching decade fall back to ``str(y)``.
    """
    if y == 0:
        return str(0)
    scales = (
        (1e9, "G"), (1e6, "M"), (1e3, "k"), (1e0, ""),
        (1e-3, "m"), (1e-6, "u"), (1e-9, "n"),
    )
    for scale, suffix in scales:
        # Written as ">=" (not "< ... continue") so NaN never matches a decade.
        if abs(y) >= scale:
            val = y / float(scale)
            text = str(val)
            # Very large or non-finite values print as '1e+16' / 'inf': there
            # is no plain fractional part to count, so use general format.
            if "." not in text or "e" in text:
                return '{val:g} {suffix}'.format(val=val, suffix=suffix)
            fraction = text.split(".")[1]
            if fraction == "0":
                return '{val:d} {suffix}'.format(val=int(round(val)), suffix=suffix)
            return '{val:.{digits}f} {suffix}'.format(val=val, digits=len(fraction), suffix=suffix)
    # Smaller than the last decade (or NaN): still return a string, because
    # tick formatters are expected to produce text.
    return str(y)
    
    
def plot_time_return_tradeoff(results_dirs: List[Path], names: List[str], title: str):
    """Plot return against inference cost for several runs on one figure.

    Each directory must contain the files ``eval_results``, ``total_ops`` and
    ``settings`` (loaded with ``torch.load``); ``eval_ics`` is optional. A run
    whose ``total_ops`` is a ``np.float64`` or a single-key mapping is drawn
    as a single reference marker, and its cost becomes the 100% mark of the
    x-axis. Every other run is drawn as a smoothed curve, plus per-IC markers
    when ``eval_ics`` exists.

    Args:
        results_dirs: Run directories, in plotting order.
        names: Method name for each directory; it is used as the legend label
            and as the key into ``FIVE_THIRTY_EIGHT`` for the color.
        title: Currently unused. The figure is titled with ``env_id`` taken
            from the settings of the last directory.

    Raises:
        ValueError: If ``results_dirs`` is empty.
    """
    if len(results_dirs) == 0:
        raise ValueError('results_dirs must contain at least one run directory')
    seaborn.set_style('whitegrid')
    current_palette = FIVE_THIRTY_EIGHT
    _, ax = plt.subplots(1, 1, figsize=(15, 9))
    # Stays None when no reference run is among results_dirs.
    baseline_ops = None
    for i, r_dir in enumerate(results_dirs):
        name = names[i]
        color = current_palette[name]
        # NOTE(security): torch.load unpickles the file contents; only point
        # this at run directories you trust.
        threshold_data = torch.load(r_dir / 'eval_results')
        total_ops = torch.load(r_dir / 'total_ops')
        settings = torch.load(r_dir / 'settings')
        if isinstance(total_ops, np.float64) or len(total_ops.keys()) == 1:
            baseline_ops = mark_orig_result(threshold_data, total_ops, ax, name, color)
        else:
            draw_time_return_for_thresholds(threshold_data, total_ops, ax, name, color)
            ic_results_path = r_dir / 'eval_ics'
            # Per-IC results are optional; load them only when they are drawn.
            if ic_results_path.exists():
                ic_data = torch.load(ic_results_path)
                draw_time_return_for_ics(ic_data, total_ops, ax, name, color)
    env_id = settings.env_id
    ax.legend(loc='lower right', prop={'size': FONT_SIZE - 2})
    ax.set_title(env_id, fontdict={'fontsize': FONT_SIZE + 1})
    ax.set_xlabel('Inference Time', fontsize=FONT_SIZE)
    ax.set_ylabel('Return', fontsize=FONT_SIZE)
    # The x-axis is shown as a percentage of the reference cost. Without a
    # reference run there is nothing to normalize by, so keep raw values.
    if baseline_ops is not None:
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(baseline_ops / 4))
        ax.xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=baseline_ops))
    ax.yaxis.set_major_formatter(FuncFormatter(y_fmt))
    # tick_params replaces per-tick `tick.label`, which recent Matplotlib
    # releases no longer provide.
    ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE - 4)
    # plt.show() also works with the inline notebook backend, where
    # Figure.show() only emits a "non-GUI backend" warning.
    plt.show()

In [None]:
import traceback

# Plot every baseline run matching the pattern together with its ZTW
# counterpart, which lives in a sibling directory with a fixed suffix.
# Sorted so the figures appear in the same order on every run.
base_names = sorted(str(p) for p in runs_dir.glob('kld_small_IC*v?#0'))
for base_name in base_names:
    base_dir = Path(base_name)
    dir_list = [base_dir, base_dir.with_name(f'{base_dir.name}_stacking#0_rensb#0')]
    names = ['Original', 'ZTW']
    # `title` was never defined, so every call failed with a NameError that
    # the except clause then printed instead of a plot.
    title = base_dir.name
    print(f'base_name: {base_name}')
    try:
        plot_time_return_tradeoff(dir_list, names, title)
    except Exception:
        # Keep going with the remaining runs, but let KeyboardInterrupt and
        # SystemExit through.
        traceback.print_exc()