# Experiment 03: Plan Space

Run the following commands to execute this notebook in place and render it as a PDF:

```bash
jupyter execute --inplace 03-Plan-Space.ipynb
jupyter nbconvert --to pdf 03-Plan-Space.ipynb
```

# Internals

The cells in this section can be ignored in the PDF output. They perform the technical aspects of the data analysis.
This mainly involves creating distortion plots (Figure 3 in the original paper) and determining the best replacements for
the original plots.

Please take a look at the following sections to see the actual outputs.

In [None]:
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
from matplotlib import figure

from postbound.db import db, postgres
from postbound.experiments import workloads
from postbound.optimizer import jointree
from postbound.optimizer.policies import cardinalities
from postbound.qal import qal
from postbound.util import collections as collection_utils
from postbound.util import jsonize

In [None]:
# Input locations: raw measurements of experiment 03 and experiment 04, plus the static base results.
results_base = Path("/ari/results/experiment-03-plan-space-analysis/")
imperfect_results_base = Path("/ari/results/experiment-04-base-join-impact/")
static_results_base = Path("/ari/results/00-base")

# Output locations for the generated plots. They are created eagerly so that later cells can write into them.
output_dir = Path("/ari/results/eval/experiment-03-plan-space/")
imperfect_out_dir = Path("/ari/results/eval/experiment-04-base-join-impact/")
output_dir.mkdir(parents=True, exist_ok=True)
imperfect_out_dir.mkdir(parents=True, exist_ok=True)

# Tell the workload loader where the query files are located.
workloads.workloads_base_dir = "/ari/postbound/workloads"

# Global plot styling (figure size and font embedding type for PDF/PS output, then the seaborn theme).
plt.rcParams.update(
    {
        "figure.figsize": (20, 5),
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
)
sns.set_style("whitegrid")
sns.set_context("talk")

In [None]:
def load_pg_explain(raw_explain: str) -> Optional[postgres.PostgresExplainPlan]:
    """Parse a JSON-encoded EXPLAIN output into a Postgres plan object.

    Parameters
    ----------
    raw_explain : str
        The JSON text of the plan.

    Returns
    -------
    Optional[postgres.PostgresExplainPlan]
        The wrapped plan, or ``None`` if the decoded JSON is falsy (e.g. ``{}``, ``[]`` or ``null``).
    """
    decoded = json.loads(raw_explain)
    return postgres.PostgresExplainPlan(decoded) if decoded else None


def load_results(workload: workloads.Workload) -> Optional[pd.DataFrame]:
    """Load the plan-space measurements of all queries of a workload.

    For each label of the workload, the file ``plan-space-analysis-<label>.csv`` below
    ``results_base / <lower-cased workload name>`` is read. Labels without a result file are skipped.
    Each loaded frame receives two derived columns: ``estimated_cost`` (the ``Total Cost`` of the root plan node)
    and ``plan_hash`` (the hash of the plan object).

    Parameters
    ----------
    workload : workloads.Workload
        The workload whose results should be loaded.

    Returns
    -------
    Optional[pd.DataFrame]
        All samples in one frame with ``label`` converted to an ordered categorical that follows the workload's
        label order. ``None`` if no samples could be loaded at all.
    """
    workload_name = workload.name.lower()

    # Collect the per-query frames and concatenate once at the end. Growing a frame with repeated pd.concat calls
    # copies all previous rows in every iteration and concatenates with an empty seed frame, which pandas deprecates.
    frames: list[pd.DataFrame] = []
    for label in workload.labels():
        data_file = results_base / workload_name / f"plan-space-analysis-{label}.csv"
        if not data_file.exists():
            continue

        current_df = pd.read_csv(data_file, converters={"query_plan": load_pg_explain})
        current_df["estimated_cost"] = current_df["query_plan"].map(
            lambda plan: plan.explain_data["Plan"]["Total Cost"]
        )
        current_df["plan_hash"] = current_df["query_plan"].map(hash)
        frames.append(current_df)

    if not frames:
        return None

    result_df = pd.concat(frames, ignore_index=True)
    if result_df.empty:
        return None

    result_df["label"] = pd.Categorical(
        result_df["label"], categories=workload.labels(), ordered=True
    )

    return result_df


def prediction_error(samples: pd.DataFrame) -> pd.DataFrame:
    """Fit a linear model from estimated cost to runtime and compute its relative error per sample.

    Parameters
    ----------
    samples : pd.DataFrame
        Must provide the columns ``estimated_cost`` and ``runtime``.

    Returns
    -------
    pd.DataFrame
        A frame that shares the index of `samples` and contains two columns: ``predicted_runtime`` (the value of the
        fitted line at the sample's cost) and ``prediction_error`` (``|runtime - prediction| / runtime``).
    """
    fit = scipy.stats.linregress(samples["estimated_cost"], samples["runtime"])
    prediction = fit.slope * samples["estimated_cost"] + fit.intercept
    # The error is relative to the measured runtime, i.e. it is not defined for samples with a runtime of 0.
    error = np.abs(samples["runtime"] - prediction) / samples["runtime"]

    return pd.DataFrame(
        {"predicted_runtime": prediction, "prediction_error": error},
        index=samples.index,
    )


def make_correlation_plots(
    df: pd.DataFrame | None, *, workload: workloads.Workload
) -> None:
    """Write one cost/runtime regression plot per query of the workload.

    Each plot contains the non-timeout samples of one query and is stored as
    ``cost-runtime-corr-<label>.pdf`` below ``output_dir / <lower-cased workload name>``. The plot title contains the
    Pearson correlation between estimated cost and runtime, rounded to two digits.

    Parameters
    ----------
    df : pd.DataFrame | None
        The plan samples. Must provide the columns ``label``, ``timeout``, ``estimated_cost`` and ``runtime``.
        If ``None``, nothing is plotted.
    workload : workloads.Workload
        The workload that the samples belong to. Its labels determine which plots are created.
    """
    if df is None:
        return

    for label in workload.labels():
        samples = df.query("label == @label and ~timeout")
        if len(samples) < 2:
            # load_results skips queries without a result file, so a label might have no samples at all.
            # The Pearson correlation requires at least two points, hence there is nothing to plot in this case.
            continue

        corr_value = round(
            scipy.stats.pearsonr(
                samples["estimated_cost"], samples["runtime"]
            ).statistic,
            2,
        )
        title = f"{workload.name} query {label} [Pearson-r = {corr_value}]"

        fig, ax = plt.subplots()
        g = sns.regplot(
            samples,
            x="estimated_cost",
            y="runtime",
            scatter=True,
            scatter_kws={"edgecolors": "white"},
            ax=ax,
        )
        g.set_xscale("log", subs=[])
        g.set(xlabel="Estimated cost [log]", ylabel="Execution time [s]", title=title)

        out_file = output_dir / workload.name.lower() / f"cost-runtime-corr-{label}.pdf"
        out_file.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(out_file)
        plt.close(fig)


def prediction_error_df(df: pd.DataFrame) -> pd.DataFrame:
    """Attach per-query prediction errors to the plan samples.

    Samples that are flagged as timeouts are removed. Afterwards, `prediction_error` is applied to the samples of
    each query (``label``) individually, i.e. each query gets its own linear fit.

    Parameters
    ----------
    df : pd.DataFrame
        The plan samples. Must provide the columns ``label`` and ``timeout`` in addition to the columns required by
        `prediction_error`.

    Returns
    -------
    pd.DataFrame
        The samples of `df` merged with the columns produced by `prediction_error`.
        NOTE(review): the merge uses the default join type, so timeout samples presumably do not appear in the
        result -- confirm.
    """
    err_df = (
        df.query("~timeout")
        .groupby("label", as_index=False, observed=True)
        .apply(prediction_error, include_groups=False)
        # Move the second index level into a regular "level_1" column, which acts as the merge key below.
        # Presumably this level holds the original row index of df -- TODO confirm.
        .reset_index(level=1)
    )

    return df.merge(err_df, left_index=True, right_on="level_1").drop(columns="level_1")


def lookup_base_joins(
    qep: db.QueryExecutionPlan | postgres.PostgresExplainPlan,
) -> set[db.QueryExecutionPlan]:
    """Collect all base join nodes of a query plan.

    Parameters
    ----------
    qep : db.QueryExecutionPlan | postgres.PostgresExplainPlan
        The plan to inspect. Postgres EXPLAIN plans are converted into a `QueryExecutionPlan` first.

    Returns
    -------
    set[db.QueryExecutionPlan]
        The node itself if it reports to be a base join, otherwise the union of the base joins found in its
        children.
    """
    if not isinstance(qep, db.QueryExecutionPlan):
        qep = qep.as_query_execution_plan()

    if qep.is_base_join():
        return {qep}

    # Not a base join: descend into the children and merge whatever they find.
    return collection_utils.set_union(
        lookup_base_joins(child) for child in qep.children
    )


def make_base_join_df(df: pd.DataFrame | None) -> Optional[pd.DataFrame]:
    """Expand the plan samples to one row per (plan, base join) combination.

    Parameters
    ----------
    df : pd.DataFrame | None
        The plan samples. Must provide a ``query_plan`` column. The frame is not modified.

    Returns
    -------
    Optional[pd.DataFrame]
        A new frame with two additional columns: ``base_join`` (the tables of one base join of the plan as a
        frozenset) and ``join_label`` (the sorted table identifiers joined by " ⋈ "). Plans with multiple base joins
        appear once per base join. ``None`` if `df` is ``None``.
    """
    if df is None:
        return None

    base_joins = (
        df["query_plan"]
        .map(lookup_base_joins)
        .map(lambda joins: {frozenset(join.tables()) for join in joins})
    )

    # assign() creates a new frame. This keeps the caller's plan-level frame free of the intermediate set column.
    exploded = df.assign(base_join=base_joins).explode("base_join")
    exploded["join_label"] = exploded["base_join"].map(
        lambda join: " ⋈ ".join(tab.identifier() for tab in sorted(join))
    )
    return exploded


def make_base_join_plots(
    df: pd.DataFrame | None, *, workload: workloads.Workload
) -> dict[str, figure.Figure]:
    """Create one scatter plot of plan runtime vs. base join per query of the workload.

    Each figure is stored as ``base-joins-<label>.pdf`` below ``output_dir / <lower-cased workload name>`` and is
    closed afterwards.

    Parameters
    ----------
    df : pd.DataFrame | None
        The base join samples. Must provide the columns ``label``, ``runtime`` and ``join_label``.
    workload : workloads.Workload
        The workload that the samples belong to. A figure is created for each of its labels.

    Returns
    -------
    dict[str, figure.Figure]
        The figures, keyed by query label. Empty if `df` is ``None``.
    """
    if df is None:
        return {}

    figures_by_label: dict[str, figure.Figure] = {}
    for label in workload.labels():
        query_samples = df.query("label == @label").sort_values(by="join_label")
        plot_title = f"{workload.name} query {label}"

        fig, ax = plt.subplots()
        sns.scatterplot(data=query_samples, x="runtime", y="join_label", ax=ax)
        ax.set(xlabel="Plan runtime [s]", ylabel="Base join", title=plot_title)

        target_dir = output_dir / workload.name.lower()
        target_dir.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()

        figures_by_label[label] = fig
        fig.savefig(target_dir / f"base-joins-{label}.pdf")
        plt.close(fig)

    return figures_by_label


def make_importance_dfs(
    plan_df: pd.DataFrame | None, base_join_df: pd.DataFrame | None
) -> tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]:
    """Compute how strongly the base join of a plan determines whether the plan is among the fast plans.

    A plan counts as "fast" (called "top 25%" throughout this function) if its runtime is at most
    ``min + 0.25 * (q90 - min)``, with ``min`` being the minimum runtime and ``q90`` being the 0.9 quantile of the
    runtimes of the query.

    Parameters
    ----------
    plan_df : pd.DataFrame | None
        One row per plan. Must provide the columns ``label`` and ``runtime``.
    base_join_df : pd.DataFrame | None
        One row per (plan, base join). Must provide the columns ``label``, ``runtime``, ``plan_hash``, ``base_join``
        and ``join_label``.

    Returns
    -------
    tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]
        ``(join_importance_df, harmonic_importance)``. The first frame contains precision, recall and F1 score per
        (query, base join). The second frame contains one aggregated ``f1_harmonic`` score per (query, best base join):
        the F1 scores of the query are sorted descendingly and summed up with weights 1, 1/2, 1/3, ...
        Both entries are ``None`` if `plan_df` is ``None``.
    """
    if plan_df is None:
        return None, None

    # First, compute the minimum runtime, the 0.9 quantile of the runtimes (stored as max_rt) and the resulting
    # "top 25%" threshold for each query
    rt_summary = (
        plan_df.groupby(  # don't use base_join_df here, it contains duplicates which skew the quantile!
            "label", as_index=False, observed=False
        )
        .agg(
            min_rt=pd.NamedAgg(column="runtime", aggfunc="min"),
            max_rt=pd.NamedAgg(column="runtime", aggfunc=lambda rts: rts.quantile(0.9)),
        )
        .assign(
            top25_rt=lambda sample: 0.75 * sample["min_rt"] + 0.25 * sample["max_rt"]
        )
    )  # unrolled form: min + 0.25 * (max - min)

    # Now, determine how many of the execution plans are in the top 25% for each query
    top25_plans = (
        plan_df.merge(  # see comment above: never use base_join_df here
            rt_summary, on="label"
        )
        .query("runtime <= top25_rt")
        .groupby(
            ["label", "top25_rt"], as_index=False, observed=True
        )  # top25_rt is dependent, we just carry it along
        .size()
        .rename(columns={"size": "total_top25_plans"})
    )

    # We are ready to compute the F1 score for each base join:
    #  - precision: share of the plans with this base join that are top 25% plans
    #  - recall: share of all top 25% plans of the query that use this base join
    # NOTE(review): if a base join has no top 25% plan, precision and recall are both 0 and the F1 score becomes 0/0,
    # i.e. presumably NaN -- confirm that this is handled as intended by the aggregation below.
    join_importance_df = (
        base_join_df.merge(top25_plans, on="label")
        .assign(top25_indicator=lambda sample: sample["runtime"] <= sample["top25_rt"])
        .groupby(
            ["label", "base_join", "join_label", "total_top25_plans"],
            as_index=False,
            observed=True,
        )  # total_top25_plans is dependent, we just carry it along
        .agg(
            base_join_plans=pd.NamedAgg(
                column="plan_hash", aggfunc="nunique"
            ),  # how many plans do we have for this base join? (Theoretically, nunique is not required because the plans should be unique anyway, but it is more expressive.)
            top25_plans=pd.NamedAgg(column="top25_indicator", aggfunc="sum"),
        )  # how many of these plans are in the top 25%?
        .assign(
            precision=lambda sample: sample["top25_plans"] / sample["base_join_plans"],
            recall=lambda sample: sample["top25_plans"] / sample["total_top25_plans"],
            f1_score=lambda sample: 2
            * sample["precision"]
            * sample["recall"]
            / (sample["precision"] + sample["recall"]),
        )
        .sort_values(by=["label", "join_label"])
    )

    # Prepare for the aggregated F1 scores: determine the weighting factor for each base join.
    # Within each query, the base joins are ranked by descending F1 score (rank 1 = highest score).
    # The result keeps the index of join_importance_df, which is what the index-based merge below relies on.
    importance_weigths = (
        join_importance_df.sort_values(by="f1_score", ascending=False)
        .groupby("label", as_index=False, observed=True)["join_label"]
        .transform(lambda sample: np.arange(len(sample)) + 1)
        .to_frame()
        .rename(columns={"join_label": "harmonic_weight"})
    )

    # Also, we will need to know which base join was actually the best for each query, so let's just compute this here as well
    # NOTE(review): if multiple base joins of a query share the maximum F1 score, all of them are retained here.
    # The final result then contains one row per (label, best_join) for that query -- confirm that this is acceptable.
    max_f1s = (
        join_importance_df.assign(
            max_f1=(
                join_importance_df.groupby("label", observed=True)[
                    "f1_score"
                ].transform("max")
            )
        )
        .query("f1_score == max_f1")
        .rename(columns={"base_join": "best_join"})[["label", "best_join"]]
    )

    # Finally, all that's left to do is aggregate: sum up f1_score / rank for each query
    harmonic_importance = (
        join_importance_df.merge(importance_weigths, left_index=True, right_index=True)
        .merge(max_f1s, on="label")
        .assign(
            f1_harmonic=lambda sample: 1
            / sample["harmonic_weight"]
            * sample["f1_score"]
        )
        .groupby(["label", "best_join"], as_index=False, observed=True)
        .agg(f1_harmonic=pd.NamedAgg(column="f1_harmonic", aggfunc="sum"))
    )

    return join_importance_df, harmonic_importance


def select_important_base_join_replacement(
    importance_df: pd.DataFrame | None, *, plots: dict[str, figure.Figure]
) -> Optional[figure.Figure]:
    """Pick the plot of the query with the highest ``f1_harmonic`` score.

    Parameters
    ----------
    importance_df : pd.DataFrame | None
        Must provide the columns ``label`` and ``f1_harmonic``.
    plots : dict[str, figure.Figure]
        The candidate plots, keyed by query label.

    Returns
    -------
    Optional[figure.Figure]
        The plot of the selected query, or ``None`` if `importance_df` is ``None``.
    """
    if importance_df is None:
        return None

    best_position = importance_df["f1_harmonic"].argmax()
    selected: str = importance_df["label"].iloc[best_position]
    return plots[selected]


def select_whatever_base_join_replacement(
    importance_df: pd.DataFrame | None, *, plots: dict[str, figure.Figure]
) -> Optional[figure.Figure]:
    """Pick the plot of the query with the lowest ``f1_harmonic`` score.

    Parameters
    ----------
    importance_df : pd.DataFrame | None
        Must provide the columns ``label`` and ``f1_harmonic``.
    plots : dict[str, figure.Figure]
        The candidate plots, keyed by query label.

    Returns
    -------
    Optional[figure.Figure]
        The plot of the selected query, or ``None`` if `importance_df` is ``None``.
    """
    if importance_df is None:
        return None

    worst_position = importance_df["f1_harmonic"].argmin()
    selected: str = importance_df["label"].iloc[worst_position]
    return plots[selected]


def make_join_importance_plot(
    importance_df: pd.DataFrame | None,
    *,
    workload: workloads.Workload,
    thresh: float = 0.75,
) -> Optional[figure.Figure]:
    """Draw the per-query importance scores as a bar chart and store it as PDF.

    Bars are coloured depending on whether their score reaches `thresh`, which is also drawn as a dashed line.
    The figure is written to ``base-join-importance.pdf`` below ``output_dir / <lower-cased workload name>`` and is
    closed afterwards.

    Parameters
    ----------
    importance_df : pd.DataFrame | None
        Must provide the columns ``label`` and ``f1_harmonic``.
    workload : workloads.Workload
        The workload that the scores belong to. Its labels define the ticks of the x axis.
    thresh : float, optional
        The score from which on a query is flagged as relevant. Defaults to 0.75.

    Returns
    -------
    Optional[figure.Figure]
        The figure, or ``None`` if `importance_df` is ``None``.
    """
    if importance_df is None:
        return None

    def sparse_tick_label(position, tick):
        # Only labels that end with "a" are candidates. Of those, the very first tick and each label whose numeric
        # part is a multiple of 5 are shown. All other ticks stay blank.
        text = tick.get_text()
        if text[-1] != "a":
            return ""
        if int(text[:-1]) % 5 == 0 or position == 0:
            return tick
        return ""

    fig, ax = plt.subplots()
    flagged_df = importance_df.assign(relevant=importance_df["f1_harmonic"] >= thresh)
    sns.barplot(
        flagged_df,
        x="label",
        y="f1_harmonic",
        hue="relevant",
        hue_order=[True, False],
        ax=ax,
    )
    ax.xaxis.set_ticks(workload.labels())
    ax.axhline(thresh, color="grey", linestyle="dashed")

    tick_labels = [
        sparse_tick_label(position, tick)
        for position, tick in enumerate(ax.get_xticklabels())
    ]
    ax.set_xticklabels(tick_labels)
    ax.set(xlabel="Query", ylabel="Importance score")
    ax.legend().remove()
    fig.tight_layout()

    target_dir = output_dir / workload.name.lower()
    target_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(target_dir / "base-join-importance.pdf")
    plt.close(fig)

    return fig


def load_jointree(
    query_plan: postgres.PostgresExplainPlan, *, query: qal.SqlQuery
) -> jointree.LogicalJoinTree:
    """Derive the logical join tree of a Postgres EXPLAIN plan.

    Parameters
    ----------
    query_plan : postgres.PostgresExplainPlan
        The plan to convert.
    query : qal.SqlQuery
        The query that the plan belongs to.

    Returns
    -------
    jointree.LogicalJoinTree
        The join tree that was loaded from the plan.
    """
    return jointree.LogicalJoinTree.load_from_query_plan(
        query_plan.as_query_execution_plan(), query=query
    )


def load_imperfect_results(workload: workloads.Workload) -> Optional[pd.DataFrame]:
    """Load the measurements of experiment 04 for all queries of a workload.

    For each query of the workload, the file ``base-tab-operator-flexibility-<label>.csv`` below
    ``imperfect_results_base / <lower-cased workload name>`` is read. Queries without a result file are skipped.
    Each loaded frame receives a ``join_order`` column that contains the join tree of the sample's query plan.

    Parameters
    ----------
    workload : workloads.Workload
        The workload whose results should be loaded.

    Returns
    -------
    Optional[pd.DataFrame]
        All samples in one frame with ``label`` converted to an ordered categorical that follows the workload's
        label order. ``None`` if no samples could be loaded at all.
    """
    # Collect the per-query frames and concatenate once at the end. Growing a frame with repeated pd.concat calls
    # copies all previous rows in every iteration and concatenates with an empty seed frame, which pandas deprecates.
    frames: list[pd.DataFrame] = []
    for label, query in workload.entries():
        data_file = (
            imperfect_results_base
            / workload.name.lower()
            / f"base-tab-operator-flexibility-{label}.csv"
        )
        if not data_file.exists():
            continue

        current_df = pd.read_csv(data_file, converters={"query_plan": load_pg_explain})
        current_df["join_order"] = current_df["query_plan"].apply(
            load_jointree, query=query
        )
        frames.append(current_df)

    if not frames:
        return None

    result_df = pd.concat(frames, ignore_index=True)
    if result_df.empty:
        return None

    result_df["label"] = pd.Categorical(
        result_df["label"], categories=workload.labels(), ordered=True
    )

    return result_df


def make_native_est_impact_df(
    importance_df: pd.DataFrame | None, *, workload: workloads.Workload
) -> Optional[pd.DataFrame]:
    """Pair the runtimes of experiment 03 and experiment 04 for plans that use the same join order.

    Only queries that have results in experiment 04 (see `load_imperfect_results`) are considered.

    Parameters
    ----------
    importance_df : pd.DataFrame | None
        The samples of experiment 03. Must provide the columns ``label``, ``runtime`` and ``query_plan``.
    workload : workloads.Workload
        The workload that the samples belong to. It is used to load the experiment 04 results and to look up the
        queries by label.

    Returns
    -------
    Optional[pd.DataFrame]
        One row per matching (label, join order) pair with the columns ``label``, ``join_order``, ``runtime_perf``
        (runtime taken from `importance_df`) and ``runtime_nat`` (runtime taken from the experiment 04 results),
        sorted by label. ``None`` if `importance_df` is ``None`` or if there are no experiment 04 results.
    """
    if importance_df is None:
        return None

    imperfect_df = load_imperfect_results(workload)
    if imperfect_df is None:
        # No experiment 04 results are available, so there is nothing to compare against.
        return None
    important_labels = set(imperfect_df["label"])

    perfect_samples = importance_df[
        importance_df["label"].isin(important_labels)
    ].copy()
    perfect_samples["join_order"] = perfect_samples.apply(
        lambda sample: load_jointree(
            sample["query_plan"], query=workload[sample["label"]]
        ),
        axis="columns",
    )

    # Collect the per-query matches and concatenate once instead of growing a frame in each iteration.
    impact_frames: list[pd.DataFrame] = []
    for label in important_labels:
        perfect_cards = perfect_samples.query("label == @label")
        imperfect_cards = imperfect_df.query("label == @label")

        impact_frames.append(
            pd.merge(
                perfect_cards[["label", "runtime", "join_order"]],
                imperfect_cards[["label", "runtime", "join_order"]],
                on=["label", "join_order"],
                suffixes=("_perf", "_nat"),
            )
        )

    native_est_impact = pd.concat(impact_frames, ignore_index=True)
    return native_est_impact.sort_values(by="label")


def make_card_impact_plot(
    card_impact_df: pd.DataFrame | None, *, workload: workloads.Workload
) -> Optional[figure.Figure]:
    """Plot the runtime with perfect cardinalities against the runtime with estimated cardinalities.

    Both axes use a log scale and the dashed diagonal marks equal runtimes. The figure is written to
    ``base-join-imperfect-slowdown.pdf`` below ``imperfect_out_dir / <lower-cased workload name>`` and is closed
    afterwards.

    Parameters
    ----------
    card_impact_df : pd.DataFrame | None
        Must provide the columns ``runtime_perf`` and ``runtime_nat``.
    workload : workloads.Workload
        The workload that the samples belong to. It determines the output directory.

    Returns
    -------
    Optional[figure.Figure]
        The figure, or ``None`` if `card_impact_df` is ``None``.
    """
    if card_impact_df is None:
        return None

    min_rt = card_impact_df[["runtime_nat", "runtime_perf"]].min().min()
    max_rt = card_impact_df[["runtime_nat", "runtime_perf"]].max().max()

    fig, ax = plt.subplots()
    ticks = [0.1, 0.5, 1.0, 2.5, 5.0, 10]

    g = sns.scatterplot(card_impact_df, x="runtime_perf", y="runtime_nat", ax=ax)
    g.plot([min_rt, max_rt], [min_rt, max_rt], color="grey", linestyle="dashed")
    g.set_xscale("log")
    g.set_yscale("log")
    g.xaxis.set_ticks(ticks, labels=ticks)
    g.yaxis.set_ticks(ticks, labels=ticks)
    # The labels have to follow the plotted columns: runtime_perf (perfect cardinalities) is on the x axis,
    # runtime_nat (native estimates) is on the y axis.
    g.set(xlabel="Perfect card. runtime [s]", ylabel="Estimated card. runtime [s]")

    out_file = (
        imperfect_out_dir / workload.name.lower() / "base-join-imperfect-slowdown.pdf"
    )
    out_file.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_file)
    plt.close(fig)
    return fig


In [None]:
# The two workloads that are evaluated by the remaining cells of this notebook.
job = workloads.job()
stats = workloads.stats()

In [None]:
# Plan-level samples (one row per executed plan). A frame is None if no result files exist for the workload.
df_job = load_results(job)
df_stats = load_results(stats)

In [None]:
# Base-join-level samples (one row per plan and base join of that plan).
base_joins_job = make_base_join_df(df_job)
base_joins_stats = make_base_join_df(df_stats)

In [None]:
# The plots are written to disk for both workloads, but only the JOB figures are kept for the selection cells below.
base_join_plots_job = make_base_join_plots(base_joins_job, workload=job)
_ = make_base_join_plots(base_joins_stats, workload=stats)

In [None]:
# Per-base-join F1 scores (join_importance_*) and the aggregated per-query score (harmonic_importance_*).
# The plan-level frame is passed alongside the base join frame because the runtime thresholds must be computed
# without the duplicates that the base join frame contains.
join_importance_job, harmonic_importance_job = make_importance_dfs(
    df_job, base_joins_job
)
join_importance_stats, harmonic_importance_stats = make_importance_dfs(
    df_stats, base_joins_stats
)

In [None]:
# Samples of experiment 04 (base join impact).
# NOTE(review): these two frames are not referenced by any later cell of this notebook;
# make_native_est_impact_df loads the same data on its own.
imperfect_df_job = load_imperfect_results(job)
imperfect_df_stats = load_imperfect_results(stats)

In [None]:
# make_native_est_impact_df reads the query_plan and runtime columns of its input. Only the plan-level frames provide
# them: the join importance frames are aggregated per base join and no longer contain these columns.
native_est_impact_job = make_native_est_impact_df(df_job, workload=job)
native_est_impact_stats = make_native_est_impact_df(df_stats, workload=stats)

# Section 5.1 - cost / runtime correlation

(Mean, median) prediction error on JOB:

In [None]:
# Relative error of the per-query linear cost -> runtime fit; the cell output is the tuple (mean, median).
errs_job = prediction_error_df(df_job)["prediction_error"]
errs_job.mean(), errs_job.median()

# Section 5.2 - impact of base joins

In [None]:
# Show the base join plot of the JOB query with the highest aggregated importance score.
select_important_base_join_replacement(
    harmonic_importance_job, plots=base_join_plots_job
)

In [None]:
# Show the base join plot of the JOB query with the lowest aggregated importance score.
select_whatever_base_join_replacement(
    harmonic_importance_job, plots=base_join_plots_job
)

In [None]:
# The importance charts are written to disk for both workloads; only the JOB figure is shown as cell output.
join_importance_plot = make_join_importance_plot(harmonic_importance_job, workload=job)
_ = make_join_importance_plot(harmonic_importance_stats, workload=stats)

join_importance_plot

# Section 5.2 - impact of imperfect estimates

In [None]:
# The runtime comparison plots are written to disk for both workloads; only the JOB figure is shown as cell output.
card_impact_plot = make_card_impact_plot(native_est_impact_job, workload=job)
_ = make_card_impact_plot(native_est_impact_stats, workload=stats)

card_impact_plot