In [None]:
import os
import re
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from analytics.app.data.load import list_pipelines
from analytics.app.data.transform import (
    df_aggregate_eval_metric,
    dfs_models_and_evals,
    logs_dataframe,
    patch_yearbook_time,
)

%load_ext autoreload
%autoreload 2

In [None]:
# INPUTS
# The defaults point at the original author's machine. Set MODYN_PIPELINES_DIR /
# MODYN_PLOTS_DIR in the environment to run the notebook elsewhere without editing it.
pipelines_dir = Path(
    os.environ.get(
        "MODYN_PIPELINES_DIR",
        "/Users/robinholzinger/robin/dev/eth/modyn-sigmod-data/yearbook/data_selection_50%/logs_agg_patch",
    )
)
output_dir = Path(
    os.environ.get(
        "MODYN_PLOTS_DIR",
        "/Users/robinholzinger/robin/dev/eth/modyn-2/.analytics.log/.data/_plots",
    )
)
assert pipelines_dir.exists(), f"pipelines_dir does not exist: {pipelines_dir}"
assert output_dir.exists(), f"output_dir does not exist: {output_dir}"

In [None]:
# Discover every pipeline stored under pipelines_dir; the dict is keyed by pipeline id.
pipelines = list_pipelines(pipelines_dir)
max_pipeline_id = max(pipelines)
pipelines

In [None]:
from analytics.app.data.load import load_pipeline_logs

# Load the logs of every discovered pipeline, keyed by pipeline id.
# Only the id is needed: load_pipeline_logs receives pipelines_dir directly, so the
# per-pipeline path stored in the listing is not used here.
pipeline_logs = {p_id: load_pipeline_logs(p_id, pipelines_dir) for p_id in pipelines}

In [None]:
# Quick look at the object type returned by load_pipeline_logs. Inspect whichever
# pipeline was loaded first instead of a hard-coded id, which may not exist.
assert pipeline_logs, "no pipeline logs were loaded"
type(next(iter(pipeline_logs.values())))

In [None]:
def map_pipeline_names(pipeline_ref: str) -> str:
    """Turn a raw pipeline reference into the label shown in plots.

    Processing steps:
      1. Strip the "yearbook_yearbooknet_" prefix, then the "cglm_" prefix
         (in that order; each is removed only if present).
      2. Delete every match of the regex "_nosched.*", i.e. the "_nosched"
         suffix and everything after it.
      3. Look the remaining key up in a table of display names; unknown keys
         are passed through unchanged.

    Args:
        pipeline_ref: Pipeline reference string as found in the pipeline listing.

    Returns:
        The display label followed by a single trailing space.
    """
    display_names = {
        "full": "Full",
        "rs2wo": "RS2 (w/o)",
        "grad_bts": "DLIS",
        "margin_bts": "Margin",
        "lc_bts": "Least conf.",
        "entropy_bts": "Entropy",
        "rs2w": "RS2",
        "classb": "Class-Bal.",
        "uniform": "Uniform",
        "loss_bts": "Loss",
    }

    short_ref = pipeline_ref.removeprefix("yearbook_yearbooknet_")
    short_ref = short_ref.removeprefix("cglm_")
    short_ref = re.sub("_nosched.*", "", short_ref)

    label = display_names.get(short_ref, short_ref)
    return f"{label} "

In [None]:
# Evaluation selection used by the cells below.
composite_model_variant = "currently_trained_model"  # boolean column used as a row mask for composite models
patch_yearbook = True  # whether to run patch_yearbook_time on the time columns
dataset_id = "yearbook_test"
eval_handler = "slidingmatrix"
metric = "Accuracy"

# Drop the *_r125 / *_r250 pipeline variants and replace raw pipeline refs with
# display labels.
# NOTE: this rebinds `pipelines`; re-running this cell without first re-running the
# listing cell would apply map_pipeline_names a second time.
pipelines = {
    int(p_id): (map_pipeline_names(p_ref), p_path)
    for p_id, (p_ref, p_path) in pipelines.items()
    if not p_ref.endswith(("_r125", "_r250"))
}

# Derive the ids from the FILTERED dict so that every later `pipelines[pipeline_id]`
# lookup is guaranteed to exist. Edit this list to restrict the analysis to a subset.
pipeline_ids = list(pipelines)

[(p_id, pname) for p_id, (pname, _) in pipelines.items() if p_id in pipeline_ids]

# Wrangle data

In [None]:
# Build the single-model evaluation frame of every selected pipeline and stack them.
# NOTE(review): `df_all` is rebound on each iteration, so after the loop it only
# holds the logs dataframe of the LAST pipeline in `pipeline_ids`.
list_df_eval_single: list[pd.DataFrame] = []

for pipeline_id in pipeline_ids:
    pipeline_label = pipelines[pipeline_id][0]
    df_all = logs_dataframe(pipeline_logs[pipeline_id], pipeline_label)

    # Pass the latest sample time in the logs; keep only the third returned frame.
    _, _, df_eval_single = dfs_models_and_evals(
        pipeline_logs[pipeline_id],
        df_all["sample_time"].max(),
        pipeline_label,
    )
    list_df_eval_single.append(df_eval_single)

df_adjusted = pd.concat(list_df_eval_single)
df_adjusted

In [None]:
# Keep only evaluations of the configured dataset / eval handler / metric and express
# the metric in percent (0-100).
# `.assign` builds a new frame instead of writing a column into a boolean-filtered
# slice, which pandas flags with SettingWithCopyWarning.
# NOTE: re-running this cell scales `value` by 100 again; re-run the wrangling cell first.
eval_mask = (
    (df_adjusted["dataset_id"] == dataset_id)
    & (df_adjusted["eval_handler"] == eval_handler)
    & (df_adjusted["metric"] == metric)
)
df_adjusted = df_adjusted.loc[eval_mask].assign(value=lambda df: df["value"] * 100)
df_adjusted

In [None]:
# Optionally apply patch_yearbook_time to the time columns. Its return value is
# discarded, so it is presumably modifying the frames in place — confirm.
# NOTE(review): `df_all` only holds the last pipeline's logs (it is rebound in the
# wrangling loop) and is not read by any later cell shown here.
if patch_yearbook:
    interval_columns = ("interval_start", "interval_center", "interval_end")
    for column in interval_columns:
        patch_yearbook_time(df_adjusted, column)
    patch_yearbook_time(df_all, "sample_time")

In [None]:
# Order the evaluation rows by the center of their evaluation interval.
df_adjusted = df_adjusted.sort_values("interval_center")

In [None]:
# Reduce to composite models
# Reduce to composite models: the `composite_model_variant` column serves as a row mask.
is_composite_row = df_adjusted[composite_model_variant]
df_adjusted = df_adjusted[is_composite_row]
# After filtering, this should only list True.
df_adjusted[composite_model_variant].unique()

# Dump Data backup

# Create Plot

In [None]:
# reduce evaluation interval to interval where all policies have evaluations
# Restrict the evaluation range to where every policy has evaluations: take each
# pipeline's earliest active interval center, then start at the latest of those.
min_active_eval_center_per_pipeline = (
    df_adjusted[df_adjusted[composite_model_variant]]
    .groupby("pipeline_ref")["interval_center"]
    .min()
)
maximum_min = min_active_eval_center_per_pipeline.max()
print(maximum_min, min_active_eval_center_per_pipeline)

inside_common_window = df_adjusted["interval_center"] >= maximum_min
df_adjusted = df_adjusted[inside_common_window]
df_adjusted["interval_center"].unique()

In [None]:
# Replace each interval center by the text before its first "-" (presumably the year
# of a "YYYY-..." timestamp — confirm against patch_yearbook_time's output).
# `.assign` rebinds a fresh frame rather than writing into the filtered slice from
# the previous cell, which would raise pandas' SettingWithCopyWarning.
df_adjusted = df_adjusted.assign(
    interval_center=df_adjusted["interval_center"].astype(str).str.split("-", n=1).str[0]
)

In [None]:
# Sanity check: inspect the evaluation rows of the full-data ("Full") reference pipeline.
full_rows_mask = df_adjusted["pipeline_ref"].str.contains("Full")
df_adjusted[full_rows_mask]

In [None]:
# Aggregate metrics to a scalar value per pipeline
# Aggregate the metric to one scalar per pipeline. This presumably averages `value`
# within each (pipeline_ref, metric) group into `metric_value` — semantics of
# df_aggregate_eval_metric not shown here.
mean_accuracies = df_aggregate_eval_metric(
    df_adjusted,
    in_col="value",
    out_col="metric_value",
    group_by=["pipeline_ref", "metric"],
    aggregate_func="mean",
)

mean_accuracies.sort_values("metric_value")

In [None]:
# Separate the full-data ("Full") reference from the data-selection candidates.
is_full_reference = mean_accuracies["pipeline_ref"].str.contains("Full")
mean_accuracy_ref = mean_accuracies[is_full_reference]
mean_accuracies_candidate = mean_accuracies[~is_full_reference]

In [None]:
# Create the strip plot: one marker per data-selection policy (sorted by mean
# accuracy), plus a dashed horizontal reference line for full-data training.
from analytics.plotting.common.common import INIT_PLOT

INIT_PLOT()
sns.set_style("whitegrid")

DOUBLE_FIG_WIDTH = 10
DOUBLE_FIG_HEIGHT = 3.5
DOUBLE_FIG_SIZE = (DOUBLE_FIG_WIDTH, 1.1 * DOUBLE_FIG_HEIGHT)
# Fixed y-range in percent, tuned for this dataset; markers outside it are clipped.
Y_LIMITS = (82, 93)
Y_TICKS = list(range(82, 92 + 1, 3))
REF_LABEL_OFFSET = 2  # the "Full data training" label sits this far below the line

assert not mean_accuracy_ref.empty, "no 'Full' reference pipeline found in mean_accuracies"
full_accuracy = mean_accuracy_ref["metric_value"].iloc[0]

candidate_order = mean_accuracies_candidate.sort_values(by="metric_value")["pipeline_ref"].tolist()
# Every policy gets the same colour. A dict palette covers each hue level exactly once,
# whereas a 2-entry list would be cycled by seaborn (and trigger a warning).
marker_color = sns.color_palette("RdBu_r", 10)[1]

fig, ax = plt.subplots(
    edgecolor="black",
    frameon=True,
    figsize=DOUBLE_FIG_SIZE,
    dpi=300,
)

sns.stripplot(
    mean_accuracies_candidate,
    x="pipeline_ref",
    order=candidate_order,
    y="metric_value",
    hue="pipeline_ref",
    hue_order=candidate_order,
    palette={name: marker_color for name in candidate_order},
    s=15,
    marker="X",
    legend=False,
    ax=ax,
)
ax.set(ylim=Y_LIMITS)

# Horizontal reference line for the "Full" model, labelled just below it.
ax.axhline(y=full_accuracy, color="dimgrey", linestyle="--", linewidth=3)
ax.text(x=-0.2, y=full_accuracy - REF_LABEL_OFFSET, s="Full data training", color="dimgrey")

# x-axis: no axis label, rotated policy names.
ax.set_xlabel("")
ax.tick_params(axis="x", labelrotation=45)

# y-axis: equally spaced integer ticks.
ax.set_ylabel("Mean Accuracy %", labelpad=15)
ax.set_yticks(Y_TICKS)
ax.set_yticklabels([str(tick) for tick in Y_TICKS], rotation=0)

fig.tight_layout()
plt.show()

# Save Plot as PNG and SVG

In [None]:
# Export the figure once per file format into output_dir.
for img_type in ("png", "svg"):
    fig.savefig(
        output_dir / f"scatter_selection_yb.{img_type}",
        bbox_inches="tight",
        transparent=True,
    )