# Notebook to Generate Benchmark Plots

In [49]:
import glob
import json
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages

# Apply matplotlib's "ggplot" style sheet to every figure created after this point.
plt.style.use("ggplot")

In [411]:
def parse_json_results_to_df(json_results_fname):
    """Load one benchmark-results JSON file into a flat DataFrame.

    The file must contain a top-level ``"benchmarks"`` list whose entries
    carry ``"name"``, ``"params"``, ``"extra_info"`` and a ``"stats"`` dict
    holding at least min/max/mean/median/stddev/ops/iterations/rounds.
    (NOTE(review): this layout looks like pytest-benchmark's saved-run
    format -- confirm against the tool that produced the files.)

    Args:
        json_results_fname: Path of the JSON results file.

    Returns:
        A DataFrame with one row per benchmark and the columns
        ``test_name``, ``param_name``, ``param_value``, the selected timing
        statistics, ``extra_info`` and ``params``. ``param_name`` and
        ``param_value`` describe only the first entry of ``params``; they
        are missing (None/NaN) for benchmarks without parameters. An empty
        ``"benchmarks"`` list yields an empty frame with the same columns.
    """
    # columns kept in the returned frame, in this order
    columns = [
        "test_name",
        "param_name",
        "param_value",
        "min",
        "max",
        "mean",
        "median",
        "stddev",
        "ops",
        "iterations",
        "rounds",
        "extra_info",
        "params",
    ]

    # read in json from fname
    with open(json_results_fname, "r", encoding="utf-8") as f:
        json_data = json.load(f)

    benchmarks = json_data["benchmarks"]
    if not benchmarks:
        # nothing to expand; keep the schema so callers can still concat
        return pd.DataFrame(columns=columns)

    def first_or_none(items):
        # first element of an iterable, or None when it is empty
        return next(iter(items), None)

    # construct dataframe
    df = pd.DataFrame(benchmarks)
    # expand each "stats" dict into one column per statistic
    df = pd.concat([df, df["stats"].apply(pd.Series)], axis=1)
    # names look like "test_foo[param]"; keep the part before the bracket
    df["test_name"] = df["name"].apply(lambda x: x.split("[")[0])
    # "params" is not a dict for non-parametrized benchmarks, so guard it
    df["param_name"] = df["params"].apply(
        lambda p: first_or_none(p.keys()) if isinstance(p, dict) else None
    )
    df["param_value"] = df["params"].apply(
        lambda p: first_or_none(p.values()) if isinstance(p, dict) else None
    )

    return df[columns]


def plot_series(x, y, color="C0", label=""):
    """Scatter one series of timings on the current axes with a linear fit.

    Measured points are drawn as small "o" markers. Points whose timing is
    missing are drawn as larger "x" markers at the value predicted by the
    fit. With fewer than two measured points no line can be fitted: an empty
    line is plotted so the legend entry still appears, and missing points
    fall back to the single measured value (or stay NaN when there is none).

    Args:
        x: Parameter values (array-like). Must be numeric whenever a line is
            fitted or a missing timing has to be predicted.
        y: Timings aligned with ``x``; NaN/None marks a missing measurement.
        color: Matplotlib color used for the line and the markers.
        label: Legend label attached to the fit line.
    """
    # work on plain arrays so pandas index alignment cannot interfere
    params = np.asarray(x)
    mean_times = np.asarray(y, dtype=float)
    isnan = np.isnan(mean_times)
    valid = ~isnan

    if np.sum(valid) > 1:
        m, b = np.polyfit(params[valid], mean_times[valid], deg=1)
        plt.plot(params, m * params + b, "--", color=color, label=label)
    else:
        # flat "fit": a scalar intercept, not an array of the valid values
        m = 0
        b = mean_times[valid][0] if valid.any() else np.nan
        plt.plot([], [], "--", color=color, label=label)

    ys = mean_times.copy()
    if isnan.any():
        ys[isnan] = m * params[isnan] + b  # replace nans with linear prediction

    if valid.any():
        plt.scatter(params[valid], ys[valid], marker="o", s=40, color=color)
    if isnan.any():
        plt.scatter(params[isnan], ys[isnan], marker="x", s=80, color=color)


def load_multiple_runs(benchmark_runs):
    """Parse several benchmark result files and stack them into one frame.

    Args:
        benchmark_runs: Mapping of run name -> path of that run's JSON
            results file.

    Returns:
        The per-run frames concatenated in mapping order, each with a
        leading ``run_name`` column identifying the run a row came from.
    """

    def tagged_frame(run_name, run_fname):
        # parse a single run and put its name in the first column
        frame = parse_json_results_to_df(run_fname)
        ordered = frame.columns.insert(0, "run_name")
        frame["run_name"] = run_name
        return frame[ordered]

    return pd.concat(
        [tagged_frame(name, fname) for name, fname in benchmark_runs.items()]
    )


def plot_across_params(
    runs_df, show_plots=False, save_to_pdf: Union[bool, str] = False
):
    """Draw one figure per test: timing vs. parameter value, one series per run.

    Args:
        runs_df: Frame with at least the columns ``test_name``, ``run_name``,
            ``param_name``, ``param_value`` and ``mean``.
        show_plots: If true, display each figure with ``plt.show()``.
        save_to_pdf: Falsy to skip saving; otherwise the value is passed to
            ``PdfPages`` as the output file and every figure becomes a page.
    """
    colors = ["green", "blue", "purple", "orange", "red"]
    pdf = PdfPages(save_to_pdf) if save_to_pdf else None
    try:
        # Group by a plain column name (not a one-element list) so the group
        # keys are scalars; newer pandas yields 1-tuples for list groupers,
        # which would otherwise end up in the title and legend.
        for test_name, test_grp in runs_df.groupby("test_name"):
            plt.figure(figsize=(10, 6))
            # plot time vs. params for different runs (as different series)
            for idx, (run_name, run_grp) in enumerate(test_grp.groupby("run_name")):
                plot_series(
                    run_grp["param_value"],
                    run_grp["mean"],
                    label=run_name,
                    # cycle the palette so runs beyond the fifth are still drawn
                    color=colors[idx % len(colors)],
                )
            plt.title(test_name)
            # x label comes from the last run plotted for this test
            plt.xlabel(run_grp["param_name"].iloc[0])
            plt.ylabel("time (s)")
            ax = plt.gca()
            box = ax.get_position()  # Shrink current axis by 20% to fit legend in pdf
            ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
            plt.legend(bbox_to_anchor=(1, 1))
            if pdf is not None:
                pdf.savefig()
            if show_plots:
                plt.show()
            plt.close()
    finally:
        # finalize the PDF even if plotting fails part-way through
        if pdf is not None:
            pdf.close()


def plot_across_runs(
    runs_df, show_plots=False, save_to_pdf: Union[bool, str] = False
):
    """Draw one bar chart per (test, parameter value) comparing the runs.

    Bars appear in reverse row order of ``runs_df``. A run whose mean timing
    is missing (NaN) is drawn with height -1 so it stands out from real,
    non-negative timings.

    Args:
        runs_df: Frame with at least the columns ``test_name``, ``run_name``,
            ``param_value`` and ``mean``.
        show_plots: If true, display each figure with ``plt.show()``.
        save_to_pdf: Falsy to skip saving; otherwise the value is passed to
            ``PdfPages`` as the output file and every figure becomes a page.
    """
    colors = ["green", "blue", "purple", "orange", "red"][::-1]
    pdf = PdfPages(save_to_pdf) if save_to_pdf else None
    try:
        # Group by plain column names (not one-element lists) so the keys are
        # scalars; newer pandas yields 1-tuples for list groupers, which
        # would otherwise be rendered verbatim in the title.
        for test_name, test_grp in runs_df.groupby("test_name"):
            for param_value, param_grp in test_grp.groupby("param_value"):
                plt.figure(figsize=(10, 6))
                run_names = param_grp["run_name"][::-1]
                means = np.array(param_grp["mean"], dtype=float)[::-1]
                # `nan` must be passed by keyword: the second positional
                # argument of nan_to_num is `copy`, not the fill value.
                height = np.nan_to_num(means, nan=-1)
                plt.bar(x=run_names, height=height, color=colors)
                plt.title(f"{test_name} [{param_value}]")
                plt.xticks(rotation=45)
                plt.ylabel("time (s)")
                ax = plt.gca()
                box = ax.get_position()  # Shrink current axis by 20% to fit xlabels in pdf
                ax.set_position(
                    [box.x0, box.y0 + (box.height * 0.2), box.width, box.height * 0.8]
                )
                if pdf is not None:
                    pdf.savefig()
                if show_plots:
                    plt.show()
                plt.close()
    finally:
        # finalize the PDF even if plotting fails part-way through
        if pdf is not None:
            pdf.close()


In [412]:
# Runs to compare: display name -> path of that run's saved benchmark JSON.
# The paths are absolute and machine-specific; edit them before re-running.
# Entry order is kept by load_multiple_runs and determines the (reversed)
# bar order in plot_across_runs.
benchmark_runs = {
    "log normalizer (latest)": "/Users/collinschlager/Code/ssm-jax-refactor/tests/timing_comparisons/.benchmarks/Darwin-CPython-3.9-64bit/0006_use_log_normalizer.json",
    "discrete chain (prev1)": "/Users/collinschlager/Code/ssm-jax-refactor/tests/timing_comparisons/.benchmarks/Darwin-CPython-3.9-64bit/0004_scott-refactor.json",
    "m_step_refactor (prev2)": "/Users/collinschlager/Code/ssm-jax-refactor/tests/timing_comparisons/.benchmarks/Darwin-CPython-3.9-64bit/0005_main_pre_refactor.json",
    "components (prev3)": "/Users/collinschlager/Code/ssm-jax-refactor/tests/timing_comparisons/.benchmarks/Darwin-CPython-3.9-64bit/0003_pre_hmm_refactor.json",
    "ssm_v0": "/Users/collinschlager/Code/ssm-jax-refactor/tests/timing_comparisons/ssm_v0_benchmark_tests/.benchmarks/Darwin-CPython-3.9-64bit/0002_ssm-v0-hmm.json",
}

# Load every run into a single frame, then write the two PDF reports:
# A = timing vs. parameter value per test, B = run-vs-run bars per (test, parameter).
runs_df = load_multiple_runs(benchmark_runs)
plot_across_params(runs_df, save_to_pdf="timing_report_A.pdf")
plot_across_runs(runs_df, save_to_pdf="timing_report_B.pdf")
