# ROM - Evaluate Experiments Automatically

#### Description

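This notebook evaluates offline ROM experiments: it loads the saved results of the selected country, experiment and configurations, and compares them through metrics and plots.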

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import pickle as pkl
import sys
import os
import os.path as path

sys.path.insert(0, path.dirname(os.getcwd()))
from pathlib import Path
from modelling_pkg.config import DATADIR
from utils import *
from glob import glob
import pandas as pd
import numpy as np
from rom import ROMTrainerConfig
import dataclasses
from random import sample
from collections import defaultdict
import math
from warnings import warn
from IPython.display import display

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, PercentFormatter
import seaborn as sns
from matplotlib.ticker import StrMethodFormatter

from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    log_loss,
)
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.metrics import average_precision_score

# Apply seaborn's default theme to every matplotlib figure created below.
sns.set_theme()

# Prefix the configured data directory with "../" (presumably because the
# notebook runs from a sub-directory of the project root - TODO confirm).
DATADIR = "../" + DATADIR
# Root directory under which plot_rcr_comparison saves its figures.
ROM_PLOT_DIR = "plots"

## Utils

In [6]:
# PREP

# Directory that is globbed for per-country result folders.
res_path = f"{DATADIR}/rom_results/"

# Discount sources (column-name prefixes such as "rom_discount") and the
# labels presumably used for them in plots.
discounts = ("real", "rom", "mdd")
discount_names = ("Real", "ROM", "MDD")


def raise_warning(s, score, cfgns=None):
    """Emit a warning and record its penalty in the global warning bookkeeping.

    Args:
        s: Message passed to ``warnings.warn``.
        score: Penalty attached to the warning.
        cfgns: Config name, or list of config names, the warning refers to.
            ``None`` means the warning is not tied to specific configs.

    Side effects:
        Adds ``score`` to ``warning_scores[name]`` for every listed name that
        is a key of ``warning_scores`` and adds
        ``score * n / len(cfg_names)`` to ``warning_score``, where ``n`` is
        the number of listed names (1 when ``cfgns`` is ``None``).
    """
    global warning_score, warning_scores
    warn(s)
    if cfgns is None:
        n_affected = 1
    else:
        affected = [cfgns] if isinstance(cfgns, str) else cfgns
        for name in affected:
            if name in warning_scores:
                warning_scores[name] += score
        n_affected = len(affected)
    # The overall score is scaled by the share of configs that are affected.
    warning_score += score * (n_affected / len(cfg_names))


def get_markers(n_lines):
    """Return matplotlib marker symbols for a plot with ``n_lines`` config lines.

    The result has ``n_lines + 2`` entries: an empty marker for the first line,
    one marker per config line, and ``"d"`` for the last line. Config lines
    share a symbol in groups of ten ("o", "^", "P", "*"); the symbols repeat
    from the start once all four are used, so any number of lines is
    supported (previously more than 40 lines raised ``IndexError``).

    Args:
        n_lines: Number of config lines; values <= 0 yield no config markers.

    Returns:
        List of marker strings.
    """
    symbols = "o^P*"
    group_size = 10
    markers = [symbols[(i // group_size) % len(symbols)] for i in range(n_lines)]
    return [""] + markers + ["d"]


def show_experiments_w_countries(name_pattern=None):
    """Print every experiment together with the countries it was run on.

    Experiment directories are discovered with the glob
    ``{DATADIR}/rom_results/*/*``: the last path component is the experiment
    name and its parent directory name is split on ``"~"`` (the first part is
    the country). Path components are taken with ``pathlib`` so the function
    works with any path separator, not only ``"\\"``.

    Args:
        name_pattern: If given, only experiments whose name contains this
            substring are printed. The returned mapping is never filtered.

    Returns:
        ``defaultdict`` mapping experiment name to a list of tuples, one per
        country directory the experiment was found in.
    """
    exp2ctrs = defaultdict(list)
    for p in glob(f"{DATADIR}/rom_results/*/*"):
        exp_dir = Path(p)
        exp2ctrs[exp_dir.name].append(tuple(exp_dir.parent.name.split("~")))
    printmd("**On which countries was each experiment conducted:**\n")
    for e in sorted(exp2ctrs):
        if name_pattern is None or name_pattern in e:
            printmd(f"**{e}** - {[ctr_span[0] for ctr_span in exp2ctrs[e]]}")
    return exp2ctrs


def show_available_countries(preselected_exp, exp2ctrs):
    """Print a numbered list of the country directories found under ``res_path``.

    Each directory name is split on ``"~"``; the first part is printed as the
    country and the remaining parts as its span. Directory names are taken
    with ``pathlib`` so the function works with any path separator.

    Args:
        preselected_exp: Experiment name used to filter the countries, or
            ``None`` to list all of them.
        exp2ctrs: Mapping of experiment name to the country tuples it was run
            on (as returned by ``show_experiments_w_countries``).

    Returns:
        ``(ctr_spans, ctr_paths)``: the listed country tuples and their
        directory paths, in the order printed.
    """
    ctr_spans = []
    ctr_paths = []

    printmd("**Countries:**\n")
    for p in glob(f"{res_path}/*"):
        ctr_span = tuple(Path(p).name.split("~"))
        if preselected_exp is not None and ctr_span not in exp2ctrs[preselected_exp]:
            continue
        ctr, *span = ctr_span
        # The printed number is the index into the returned lists.
        print(f"{len(ctr_spans):2}.  {ctr} - {span}")
        ctr_spans.append(ctr_span)
        ctr_paths.append(p)
    return ctr_spans, ctr_paths


def show_available_experiments(
    country_selection, ctr_spans, ctr_paths, preselected_exp
):
    """Print the experiments (and their configs) available for one country.

    Directory names are taken with ``pathlib`` so the function works with any
    path separator, not only ``"\\"``.

    Args:
        country_selection: Index into ``ctr_spans`` / ``ctr_paths``.
        ctr_spans: Country tuples (country first, span parts after it).
        ctr_paths: Country directory paths, parallel to ``ctr_spans``.
        preselected_exp: Experiment name to mark with ``$$$`` in the listing,
            or ``None``.

    Returns:
        ``(ctr_name, span, exp_paths, exp_names, preselected_exp_idx)`` where
        ``preselected_exp_idx`` is the index of ``preselected_exp`` in
        ``exp_names`` or ``None`` if it was not found.
    """
    ctr_name, *span = ctr_spans[country_selection]

    exp_paths = glob(f"{ctr_paths[country_selection]}/*")
    exp_names = []
    preselected_exp_idx = None

    printmd(f"Selected country: **{ctr_name}**, span: {span}\n\n**Experiments:**\n")
    for i, p in enumerate(exp_paths):
        exp_name = Path(p).name
        exp_names.append(exp_name)
        pre = ""
        if exp_name == preselected_exp:
            preselected_exp_idx = i
            pre = "$$$ "
        print(f"{pre}{i:2}.  {exp_name}")
        for cfg_path in glob(f"{p}/*"):
            print(f"\t- {Path(cfg_path).name}")

    return ctr_name, span, exp_paths, exp_names, preselected_exp_idx


def show_available_configs(experiment_selection, preselected_exp_idx, exp_names):
    """Print a numbered list of the configurations of the selected experiment.

    Reads the notebook-level global ``exp_paths`` (experiment directory paths,
    parallel to ``exp_names``). Directory names are taken with ``pathlib`` so
    the function works with any path separator.

    Args:
        experiment_selection: Index of the experiment to use.
        preselected_exp_idx: If not ``None``, overrides
            ``experiment_selection``.
        exp_names: Experiment names, indexed like ``exp_paths``.

    Returns:
        ``(exp_name, cfg_paths, cfg_names_all)``: the chosen experiment name,
        the config directory paths and their directory names.
    """
    if preselected_exp_idx is not None:
        experiment_selection = preselected_exp_idx
    exp_name = exp_names[experiment_selection]
    printmd(f"Selected experiment: **{exp_name}**\n\n**Configurations:**\n")
    cfg_paths = glob(f"{exp_paths[experiment_selection]}/*")
    cfg_names_all = [Path(p).name for p in cfg_paths]
    for i, cfg_name in enumerate(cfg_names_all):
        print(f"{i:2}.  {cfg_name}")
    return exp_name, cfg_paths, cfg_names_all


def show_selected_configs(config_selection, cfg_names_all):
    """Resolve the selected configurations and print them.

    Args:
        config_selection: Indices into ``cfg_names_all``; an empty/falsy value
            selects every configuration.
        cfg_names_all: Names of all available configurations.

    Returns:
        ``(cfg_names, cfg_hash, config_selection)`` where ``cfg_hash`` is a
        number in ``[0, 1000)`` identifying the ordered selection and
        ``config_selection`` is the resolved list of indices.
    """
    import zlib  # local import: only needed for the stable hash below

    if config_selection:
        cfg_names = [cfg_names_all[cs] for cs in config_selection]
    else:
        cfg_names = cfg_names_all.copy()
        config_selection = list(range(len(cfg_names_all)))
    # The built-in hash() of strings is salted per process, so it would give a
    # different value after every kernel restart; CRC32 is stable, which
    # matters because the hash ends up in saved figure file names.
    cfg_hash = zlib.crc32("\n".join(cfg_names).encode("utf-8")) % 1000
    printmd(f"**Selected configurations (hash: {cfg_hash}):**\n")
    for cfgn in cfg_names:
        print(cfgn)

    return cfg_names, cfg_hash, config_selection


def load_data_and_prep_eval(config_selection, cfg_paths, n_cols=4, n_cols_wide=2):
    """Load the results of the selected configs and prepare plot-grid sizes.

    For every selected config directory this reads ``cfg.json``, ``clf.csv``,
    ``rom.csv`` and ``eval.pkl``. It also resets the global warning
    bookkeeping (``warning_score`` / ``warning_scores``) and displays a table
    of the configs in which rows that differ between configs are highlighted.

    Reads the notebook-level globals ``cfg_names``, ``ctr_name`` and ``span``.

    Args:
        config_selection: Indices into ``cfg_paths`` of the configs to load.
        cfg_paths: Paths of the config result directories.
        n_cols: Maximum number of columns of the regular plot grid.
        n_cols_wide: Number of columns of the "wide" plot grid.

    Returns:
        Tuple ``(clf_dfs, rcm_dfs, eval_dfs, n_samples, configs,
        (n_cols, n_rows, n_cols_wide, n_rows_wide), models, n_models, gammas,
        n_configs)``. The three ``*_dfs`` dicts and ``configs`` are keyed by
        ``cfg.name``; ``gammas`` is taken from the last loaded config.
    """
    global warning_score, warning_scores, cfg_names
    warning_score = 0
    warning_scores = {cfgn: 0 for cfgn in cfg_names}

    clf_dfs = {}
    rcm_dfs = {}
    eval_dfs = {}
    configs = {}
    dicts = []

    for cs in config_selection:
        res_path_ = cfg_paths[cs]

        cfg = ROMTrainerConfig.from_json_file(f"{res_path_}/cfg.json")
        name = cfg.name

        df = pd.read_csv(f"{res_path_}/clf.csv", index_col=0)
        clf_dfs[name] = df.reset_index() if df.index.name else df

        df = pd.read_csv(f"{res_path_}/rom.csv", index_col=0)
        rcm_dfs[name] = df.reset_index() if df.index.name else df

        # NOTE: unpickling can execute arbitrary code - only load result files
        # from a trusted source. The context manager closes the file handle.
        with open(f"{res_path_}/eval.pkl", "rb") as f:
            eval_dfs[name] = pkl.load(f)
        configs[name] = cfg
        dicts.append(dataclasses.asdict(cfg))

    gammas = cfg.gammas
    n_configs = len(config_selection)

    df_name = f"{ctr_name} - {span}"
    cfgs_df = pd.DataFrame(dicts).rename(columns={"name": df_name}).set_index(df_name).T

    # highlight differences: collect config fields whose value varies
    per_config = cfgs_df.T
    var_cols = []
    for c in per_config.columns:
        col = per_config[c]
        # lists are unhashable, so convert them to tuples before nunique();
        # .iloc[0] is an explicit positional lookup (the index holds names)
        if isinstance(col.iloc[0], list):
            col = col.apply(tuple)
        if col.nunique() > 1:
            var_cols.append(c)

    def custom_style(row):
        if row.name in var_cols:
            return ["background-color: pink"] * len(row.values)
        return [""] * len(row.values)

    display(cfgs_df.style.apply(custom_style, axis=1))

    n_cols = min(n_cols, n_configs)
    n_rows = math.ceil(n_configs / n_cols)
    n_rows_wide = math.ceil(n_configs / n_cols_wide)
    models = cfg_names
    n_models = len(models)

    n_samples = len(rcm_dfs[cfg_names[0]])
    return (
        clf_dfs,
        rcm_dfs,
        eval_dfs,
        n_samples,
        configs,
        (n_cols, n_rows, n_cols_wide, n_rows_wide),
        models,
        n_models,
        gammas,
        n_configs,
    )


# EVAL


def show_selected_evaluation(ctr_name, span, exp_name, cfg_names):
    """Render a markdown summary of what is being evaluated.

    Args:
        ctr_name: Country name.
        span: Span shown in parentheses after the country.
        exp_name: Experiment name.
        cfg_names: Names of the selected configurations, listed as bullets.
    """
    header = (
        '<font size="+2">**Selected evaluation:**</font>'
        f"\n\nCountry: **{ctr_name}** ({span})"
        f"\n\nExperiment: **{exp_name}**"
    )
    printmd(header)
    bullets = "".join(f"\n- **{cfgn}**" for cfgn in cfg_names)
    printmd(f"Configurations:{bullets}")


def plot_pred_dist_acc_vs_rej():
    """Plot, per config, the KDE of predicted probabilities split by ``accepted``.

    Reads the notebook-level globals ``ctr_name``, ``exp_name``, ``n_rows``,
    ``n_cols``, ``n_configs``, ``cfg_names`` and ``clf_dfs``. Each frame in
    ``clf_dfs`` must have the columns ``accepted`` and ``pred``.
    """
    # squeeze=False always returns a 2-D array, so a 1x1 grid (single config)
    # can be flattened like any other grid.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, n_rows * 3.5), squeeze=False
    )

    plt.suptitle(
        f"{ctr_name} - exp: {exp_name} | Distribution of predicted probabilities for accepted vs rejected offers"
    )
    axes = axes.flatten()

    for cfgn, ax in zip(cfg_names, axes):
        # One KDE line per value of `accepted`, all drawn on this subplot.
        clf_dfs[cfgn].groupby("accepted").pred.plot(kind="kde", ax=ax)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, None)
        ax.set_title(f"{cfgn}")
        ax.set_xlabel("Predicted probability of acceptance")
        # Attach the legend to this subplot; plt.legend would target the
        # current axes instead.
        ax.legend(["Rejected", "Accepted"])
    plt.tight_layout()

    # Remove empty subplots if the number of plots is less than the total subplots
    for unused_ax in axes[n_configs:]:
        fig.delaxes(unused_ax)


def calc_metrics():
    """Compute classification metrics for every selected config.

    Reads the notebook-level globals ``cfg_names`` and ``clf_dfs``. Each frame
    in ``clf_dfs`` must have the columns ``accepted`` (labels), ``pred_b``
    (used for accuracy/precision/recall/F1) and ``pred`` (used for ROC AUC,
    average precision and log loss).

    Returns:
        DataFrame with one row per config name and one column per metric.
    """
    rows = []
    for cfgn in cfg_names:
        clf_res = clf_dfs[cfgn]
        y_true = clf_res["accepted"]
        y_label = clf_res["pred_b"]
        y_score = clf_res["pred"]
        rows.append(
            {
                "acceptance_rate": clf_res.accepted.mean(),
                "accuracy": accuracy_score(y_true, y_label),
                "precision": precision_score(y_true, y_label),
                "recall": recall_score(y_true, y_label),
                "f1": f1_score(y_true, y_label),
                "auc_roc": roc_auc_score(y_true, y_score),
                "avg_precision": average_precision_score(y_true, y_score),
                "logloss": log_loss(y_true, y_score),
            }
        )

    return pd.DataFrame(rows, index=cfg_names)


def plot_metrics(df):
    """Draw bar charts of accuracy, F1, ROC AUC and log loss per config.

    Reads the notebook-level globals ``n_configs``, ``ctr_name`` and
    ``exp_name``.

    Args:
        df: Metrics frame with one row per config and at least the columns
            ``accuracy``, ``f1``, ``auc_roc`` and ``logloss``.
    """
    # Draw a throw-away line plot only to read the colour assigned to each
    # config, then close its figure so no empty figure is left behind.
    scratch_ax = df.T.plot()
    colors_ = [l.get_c() for l in scratch_ax.get_lines()]
    plt.close(scratch_ax.figure)

    n_rows_ = 2
    n_cols_ = 2
    fig, axes = plt.subplots(
        n_rows_, n_cols_, figsize=(n_cols_ * n_configs * 0.8, n_rows_ * 4)
    )

    plt.suptitle(f"{ctr_name} - exp: {exp_name} | Metrics")

    for col, ax in zip(["accuracy", "f1", "auc_roc", "logloss"], axes.flatten()):
        df[col].plot.bar(color=colors_, rot=25, ax=ax)
        ax.set_title(col)
        ax.set_ylabel(col)

    plt.tight_layout()


def plot_roc():
    """Plot one ROC curve per config and warn when its AUC is below 0.55.

    Reads the notebook-level globals ``n_rows``, ``n_cols``, ``n_configs``,
    ``ctr_name``, ``exp_name``, ``cfg_names`` and ``clf_dfs``. Each frame in
    ``clf_dfs`` must have the columns ``accepted`` and ``pred``.
    """
    # squeeze=False always returns a 2-D array, so a 1x1 grid (single config)
    # can be flattened like any other grid.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, n_rows * 4), squeeze=False
    )

    plt.suptitle(f"{ctr_name} - exp: {exp_name} | ROC Curve")
    axes = axes.flatten()

    warn_roc_thresh = 0.55
    for cfgn, ax in zip(cfg_names, axes):
        clf_res = clf_dfs[cfgn]

        fpr, tpr, _ = roc_curve(clf_res["accepted"], clf_res["pred"])
        roc_auc = auc(fpr, tpr)

        if roc_auc < warn_roc_thresh:
            raise_warning(f"ROC < {warn_roc_thresh} for {cfgn}", 1, cfgn)

        ax.plot(fpr, tpr, color="blue", lw=2, label="ROC curve (AUC = %0.2f)" % roc_auc)
        ax.plot(
            [0, 1], [0, 1], color="gray", lw=1, linestyle="--"
        )  # Diagonal line for random classifier
        ax.set_xlabel("False Positive Rate (FPR)")
        ax.set_ylabel("True Positive Rate (TPR)")
        ax.set_title(f"{cfgn}")
        ax.legend(loc="lower right")

    plt.tight_layout()

    # Remove empty subplots if the number of plots is less than the total subplots
    for unused_ax in axes[n_configs:]:
        fig.delaxes(unused_ax)


def get_random_examples(n_examples):
    """Pick the same random example positions from every config's eval data.

    Reads the notebook-level globals ``cfg_names`` and ``eval_dfs``. Positions
    are sampled without replacement from the range of the first config's
    evaluation data and reused for all configs.

    Args:
        n_examples: Number of examples to draw.

    Returns:
        List with one entry per config (in ``cfg_names`` order), each a list
        of the ``n_examples`` selected evaluation items.
    """
    n_available = len(eval_dfs[cfg_names[0]])
    example_idxs = sample(range(n_available), n_examples)
    return [[eval_dfs[cfgn][i] for i in example_idxs] for cfgn in cfg_names]


def plot_examples_probability(examples, gamma, suptitle, real_discount):
    """Plot predicted acceptance probability vs discount for example offers.

    The grid has one row per example and one column per config. Reads the
    notebook-level globals ``n_configs``, ``ctr_name``, ``exp_name`` and
    ``cfg_names``.

    Args:
        examples: One list per config (in ``cfg_names`` order) of
            ``(di, df)`` pairs; ``df`` needs the columns ``discount_shp`` and
            ``rom_pa`` (and ``real_discount`` when ``real_discount`` is set),
            ``di[gamma]`` must unpack to ``(dmin, dmax, dopt)``.
        gamma: Key into ``di``; when truthy the ``[dmin, dmax]`` band and the
            ``dopt`` line are drawn.
        suptitle: Whether to add a figure title.
        real_discount: Whether to mark the real discount with a grey line.
    """
    n_examples = len(examples[0])
    # squeeze=False keeps `axes` 2-D even with a single example or a single
    # config, so axes[i, j] below is always valid.
    fig, axes = plt.subplots(
        n_examples,
        n_configs,
        figsize=(4 * n_configs, n_examples * 3),
        squeeze=False,
    )

    if suptitle:
        plt.suptitle(
            f"{ctr_name} - exp. {exp_name} - predicted probability of accepting vs discount"
        )

    for j, (cfgn, dfs) in enumerate(zip(cfg_names, examples)):
        for i, (di, df) in enumerate(dfs):
            ax = axes[i, j]
            ax.plot(
                df["discount_shp"],
                df["rom_pa"] * 100,
                marker="o",
                linestyle="-",
                linewidth=1,
                markersize=4,
            )
            # Only the first row carries the config name in its title.
            ax.set_title(f"{f'{cfgn} - ' if i == 0 else ''}Plot {i + 1}")
            ax.set_xlabel("discount")
            ax.set_ylabel("probability")
            ax.yaxis.set_major_formatter(PercentFormatter(decimals=0))
            if gamma:
                dmin, dmax, dopt = di[gamma]
                ax.axvspan(dmin, dmax, color="orange", alpha=0.2)
                ax.axvline(dopt, c="orange", alpha=0.8, ls="--")
            if real_discount:
                ax.axvline(df.real_discount.iloc[0], c="grey", ls="--")
            ax.set_ylim(0, 100)
            ax.set_xlim(0, 100)
    plt.tight_layout()


def plot_examples_revenue(examples, gamma, suptitle, real_discount):
    """Plot expected revenue vs discount for example offers.

    The grid has one row per example and one column per config. Reads the
    notebook-level globals ``n_configs``, ``ctr_name``, ``exp_name`` and
    ``cfg_names``.

    Args:
        examples: One list per config (in ``cfg_names`` order) of
            ``(di, df)`` pairs; ``df`` needs the columns ``discount_shp`` and
            ``rom_exp_rev`` (and ``real_discount`` when ``real_discount`` is
            set), ``di[gamma]`` must unpack to ``(dmin, dmax, dopt)``.
        gamma: Key into ``di``; when truthy the ``[dmin, dmax]`` band and the
            ``dopt`` line are drawn.
        suptitle: Whether to add a figure title.
        real_discount: Whether to mark the real discount with a grey line.
    """
    n_examples = len(examples[0])
    # squeeze=False keeps `axes` 2-D even with a single example or a single
    # config, so axes[i, j] below is always valid.
    fig, axes = plt.subplots(
        n_examples,
        n_configs,
        figsize=(4 * n_configs, n_examples * 3),
        squeeze=False,
    )
    if suptitle:
        plt.suptitle(f"{ctr_name} - exp. {exp_name} - expected revenue vs discount")

    for j, (cfgn, dfs) in enumerate(zip(cfg_names, examples)):
        for i, (di, df) in enumerate(dfs):
            ax = axes[i, j]
            ax.plot(
                df["discount_shp"],
                df["rom_exp_rev"],
                marker="o",
                linestyle="-",
                linewidth=1,
                markersize=4,
            )
            # Only the first row carries the config name in its title.
            ax.set_title(f"{f'{cfgn} - ' if i == 0 else ''}Plot {i + 1}")
            ax.set_xlabel("discount")
            ax.set_ylabel("expected revenue")
            if gamma:
                dmin, dmax, dopt = di[gamma]
                ax.axvspan(dmin, dmax, color="orange", alpha=0.2)
                ax.axvline(dopt, c="orange", alpha=0.8, ls="--")
            if real_discount:
                ax.axvline(df.real_discount.iloc[0], c="grey", ls="--")
            ax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.0f}"))
            ax.set_xlim(0, 100)
    plt.tight_layout()


def plot_interval_sizes():
    """Plot, per config, the KDE of recommended discount-interval sizes.

    For every gamma the interval size is
    ``rom_discount_{gamma}_high - rom_discount_{gamma}_low``. Reads the
    notebook-level globals ``n_rows_wide``, ``n_cols_wide``, ``n_configs``,
    ``ctr_name``, ``exp_name``, ``cfg_names``, ``rcm_dfs`` and ``cfg``.

    NOTE(review): ``cfg`` must exist as a notebook global; in the visible code
    it is only a local of ``load_data_and_prep_eval`` - confirm it is defined
    before calling this function.
    """
    #     todo
    # squeeze=False always returns a 2-D array, so a 1x1 grid can be flattened
    # like any other grid.
    fig, axes = plt.subplots(
        n_rows_wide,
        n_cols_wide,
        figsize=(10 * n_cols_wide, n_rows_wide * 4),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    plt.suptitle(
        f"{ctr_name} - exp: {exp_name} | Sizes of recommended discount intervals"
    )
    axes = axes.flatten()

    for cfgn, ax in zip(cfg_names, axes):
        rcm_res2 = rcm_dfs[cfgn].copy()
        for gamma in cfg.gammas:
            rcm_res2[f"Γ = {gamma}"] = (
                rcm_res2[f"rom_discount_{gamma}_high"]
                - rcm_res2[f"rom_discount_{gamma}_low"]
            )
        rcm_res2 = rcm_res2[[f"Γ = {gamma}" for gamma in cfg.gammas]]
        ax = rcm_res2.plot.kde(bw_method=0.25, ax=ax)
        # `gamma` still holds the last value of cfg.gammas here: the x-range is
        # capped at the 97% quantile of that gamma's interval sizes.
        ax.set_xlim(-1, rcm_res2[f"Γ = {gamma}"].quantile(0.97))
        ax.set_title(f"{cfgn}")
        ax.set_xlabel("size of discount interval")

    # Remove empty subplots if the number of plots is less than the total subplots
    for unused_ax in axes[n_configs:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()


def plot_discount_distributions():
    """Plot, per config, the KDEs of real and ROM-recommended discounts.

    Also raises warnings (via ``raise_warning``) and annotates the subplot
    when

    * discounts below 3 are recommended more than twice as often as they occur
      in reality and their share is more than 1.2 points higher, or
    * the medians of real and recommended discounts differ by more than
      ``med_diff_thresh``.

    Reads the notebook-level globals ``n_rows_wide``, ``n_cols_wide``,
    ``n_configs``, ``ctr_name``, ``exp_name``, ``cfg_names`` and ``rcm_dfs``.
    Each frame in ``rcm_dfs`` must have the columns ``real_discount`` and
    ``rom_discount``.
    """
    # squeeze=False always returns a 2-D array, so a 1x1 grid can be flattened
    # like any other grid.
    fig, axes = plt.subplots(
        n_rows_wide,
        n_cols_wide,
        figsize=(8 * n_cols_wide, n_rows_wide * 4),
        squeeze=False,
    )

    plt.suptitle(f"{ctr_name} - exp. {exp_name} - Distribution of discounts")
    axes = axes.flatten()

    for cfgn, ax in zip(cfg_names, axes):
        # Look the frame up by name instead of relying on dict ordering.
        rcm_res = rcm_dfs[cfgn]
        cfgn_warned = False

        n = len(rcm_res)
        # check potential issues
        real0 = (rcm_res.real_discount < 3).sum()
        real0pct = real0 / n * 100
        rec_opt0 = (rcm_res.rom_discount < 3).sum()
        rec_opt0pct = rec_opt0 / n * 100
        if rec_opt0 > 2 * real0 and rec_opt0pct > real0pct + 1.2:
            raise_warning(
                f"0% discounts more frequently recommended than real discounts - {rec_opt0} vs {real0}/{len(rcm_res)} - {cfgn}",
                10,
                cfgn,
            )
            cfgn_warned = True

        ax = rcm_res[["real_discount", "rom_discount"]].plot.kde(
            ax=ax, bw_method=0.25, color=["b", "r"]
        )
        ax.set_xlim((0, 100))
        ax.set_xlabel("Discount")
        ax.legend([f"Real (zero-{real0pct:.1f}%)", f"ROM (zero-{rec_opt0pct:.1f}%)"])

        real_med = rcm_res.real_discount.median()
        rom_med = rcm_res.rom_discount.median()
        med_diff = round(abs(real_med - rom_med))
        med_diff_thresh = 9
        if med_diff > med_diff_thresh:
            # Quadratic penalty on the excess over the threshold, capped at 50.
            score = min((med_diff - med_diff_thresh) ** 2, 50)
            raise_warning(
                f"medians of discounts differ too much! diff={med_diff} - {cfgn}",
                score,
                cfgn,
            )
            ax.text(
                16,
                ax.get_ylim()[1] * 0.5,
                "median",
                fontsize=20,
                weight="extra bold",
                color="b",
                horizontalalignment="left",
                verticalalignment="top",
            )
        ax.axvline(real_med, color=ax.get_lines()[0].get_c(), alpha=0.5, ls="--")
        ax.axvline(rom_med, color="r", alpha=0.5, ls="--")

        ax.set_title(f"{cfgn} | med_diff={round(med_diff)}")

        if cfgn_warned:
            ax.text(
                10,
                ax.get_ylim()[1] * 0.6,
                "0%",
                fontsize=20,
                weight="extra bold",
                color="r",
                horizontalalignment="left",
                verticalalignment="top",
            )

    # Remove empty subplots if the number of plots is less than the total subplots
    for unused_ax in axes[n_configs:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()


# # How many real discounts within ROM\'s recommended interval?
# gamma = 8
# within = []
# for cfgn in cfg_names:
#     res = rcm_dfs[cfgn]
#     within0 = []
#     within.append(within0)
#     for gamma in gammas:
#         w = len(res[(res.real_discount >= res[f'rom_discount_{gamma}_low'] - 0.2) & (res.real_discount <= res[f'rom_discount_{gamma}_high'] + 0.2)])
#         within0.append(f'{round(100*w/len(res), 1)}%')
# # print(f'Real discount is within ROM\'s recommended interval in {w:.1f}% cases')
# printmd('**How many real discounts within ROM\'s recommended interval?**')
# pd.DataFrame(within, columns=pd.Index(gammas, name='Γ'), index=cfg_names).T


def plot_difference_between_discounts(show_means, means_alpha=0.3):
    """Plot, per config, the KDE of ``rom_diff`` split by ``accepted``.

    Also raises warnings (via ``raise_warning``) and annotates the subplot
    when the per-group medians are outside the expected ranges or when more
    than 3% of the rows have ``rom_diff`` below -40.

    Reads the notebook-level globals ``n_rows_wide``, ``n_cols_wide``,
    ``n_models``, ``ctr_name``, ``models`` and ``rcm_dfs``. Each frame in
    ``rcm_dfs`` must have the columns ``accepted`` (0/1) and ``rom_diff``.

    Args:
        show_means: Whether to draw a dashed vertical line at each group's
            median (despite the name, medians are drawn).
        means_alpha: Opacity of those lines.
    """
    # squeeze=False always returns a 2-D array, so a 1x1 grid can be flattened
    # like any other grid.
    fig, axes = plt.subplots(
        n_rows_wide,
        n_cols_wide,
        figsize=(8 * n_cols_wide, n_rows_wide * 4),
        squeeze=False,
    )
    plt.suptitle(
        f"{ctr_name} - Distribution of diff between recommended and actual discount"
    )
    axes = axes.flatten()
    m = "rom"
    for cfgn, ax in zip(models, axes):
        res = rcm_dfs[cfgn]

        diff = res.groupby("accepted")[f"{m}_diff"]
        diff.plot(kind="kde", ax=ax)
        # Symmetric x-range covering the 0.1%-99.9% quantiles.
        lim = max(
            np.abs(res[f"{m}_diff"].quantile(0.001)),
            np.abs(res[f"{m}_diff"].quantile(0.999)),
        )
        ax.set_xlim((-lim, lim))
        ax.set_title(cfgn)
        ax.axvline(0, c="black", ls=":")
        ax.legend(["Rejected", "Accepted"])
        ax.set_xlabel("Discount difference")

        diff_med = diff.median()
        # Medians of the rejected (0) and accepted (1) groups.
        med0, med1 = diff_med.loc[0], diff_med.loc[1]
        if show_means:
            colors = [l.get_c() for l in ax.get_lines()]
            ax.axvline(med1, color=colors[1], alpha=means_alpha, ls="--")
            ax.axvline(med0, color=colors[0], alpha=means_alpha, ls="--")

        # check potential issues
        conds = (med0 < -5), (med1 > 5)
        conds2 = (med0 > 19), (med1 < -13)
        if sum(conds) or sum(conds2):
            score = min(
                100,
                (
                    ((abs(med0) - 4) ** 2 if conds[0] else 0)
                    + ((abs(med1) - 4) ** 2 if conds[1] else 0)
                )
                * sum(conds),
            )
            if sum(conds2):
                score += min(
                    100,
                    ((med0 - 19) ** 2 if conds2[0] else 0)
                    + ((abs(med1) - 13) ** 2 if conds2[1] else 0),
                )
            raise_warning(
                f"(+{score:5.1f}) - {cfgn} - Unexpected median value", score, cfgn
            )
            ax.text(
                ax.get_xlim()[0] * 0.8,
                ax.get_ylim()[1] * 0.6,
                f"median{'s' if sum(conds) == 2 else ''} (+{score:.1f})",
                fontsize=20,
                weight="extra bold",
                color="b",
                horizontalalignment="left",
                verticalalignment="top",
            )

        large_diff_thresh = 40
        large_diff_pct = (res[f"{m}_diff"] < -large_diff_thresh).sum() / len(res) * 100
        if large_diff_pct > 3:
            score = 10
            if m == "rom":
                raise_warning(
                    f"{cfgn} - Much smaller recommended discounts than real discounts (diff>{large_diff_thresh}) are frequent ({large_diff_pct:.2f}%)",
                    score,
                    cfgn,
                )
            ax.text(
                ax.get_xlim()[0] * 0.8,
                ax.get_ylim()[1] * 0.7,
                f"small (+{score})",
                fontsize=20,
                weight="extra bold",
                color="r",
                horizontalalignment="left",
                verticalalignment="top",
            )

    # Remove empty subplots if the number of plots is less than the total subplots
    for unused_ax in axes[n_models:]:
        fig.delaxes(unused_ax)
    plt.tight_layout()


# RCR

# def calc_rcr_dfs():
#     global cfg_names, clf_dfs, rcm_dfs, eval_dfs, MDD_EPSS
#     MDD_EPSS = [x//2 for x in gammas]
#     assert MDD_EPSS[0] > 0

#     rcr_dict = defaultdict(list)
#     rcr_dict['gamma'].extend(gammas)
#     ar_dict = defaultdict(list)
#     ar_dict['gamma'].extend(gammas)
#     counts_dict = defaultdict(list)
#     counts_dict['gamma'].extend(gammas)
#     exp_list = []

#     res = rcm_dfs[cfg_names[0]]
#     rcr_real = res.real_rev.sum() / res.pub_rev.sum() * 100
#     rcr_dict['Real'].extend([rcr_real]*len(gammas))
#     ar_real = res.accepted.mean()*100
#     ar_dict['Real'].extend([ar_real]*len(gammas))

#     for cfgn in cfg_names:
#         res = rcm_dfs[cfgn]

#         for gamma in gammas:
#             close = res[(res.real_discount >= res[f'rom_discount_{gamma}_low']-0.2) & (res.real_discount <= res[f'rom_discount_{gamma}_high']+0.2)]
#             rcr_rom = close.real_rev.sum() / close.pub_rev.sum() * 100
#             rcr_dict[cfgn].append(rcr_rom)
#             counts_dict[cfgn].append(len(close))
#             ar_dict[cfgn].append(close.accepted.mean()*100)

#             rcr_exp = close.rom_exp_rev.sum() / close.pub_rev.sum() * 100
#             exp_list.append(pd.Series(dict(name=cfgn, gamma=gamma, rcr=rcr_rom, calc='real')))
#             exp_list.append(pd.Series(dict(name=cfgn, gamma=gamma, rcr=rcr_exp, calc='predicted')))

#     m = 'mdd'
#     for eps in MDD_EPSS:
#         close = res[np.abs(res[f'{m}_diff']) < eps]
#         rcr_mdd = close.real_rev.sum() / close.pub_rev.sum() * 100
#         rcr_dict['MDD'].append(rcr_mdd)
#         ar_dict['MDD'].append(close.accepted.mean()*100)
#         counts_dict['MDD'].append(len(close))

#         # rcr_dict['Real'].append(rcr_real)

#     # rcr_mdd = close.real_rev.sum() / close.pub_rev.sum() * 100

#     rcr_df = pd.DataFrame(rcr_dict).set_index('gamma')
#     counts_df = pd.DataFrame(counts_dict).set_index('gamma')
#     ar_df = pd.DataFrame(ar_dict).set_index('gamma')
#     exp_df = pd.DataFrame(exp_list)

#     # print(f'Acceptance rate: {res.accepted.mean()*100:.0f}%')

#     return rcr_df, counts_df, ar_df, rcr_real, ar_real, exp_df


def plot_rcr_comparison(
    rcr_df,
    counts_df,
    ar_df,
    rcr_real,
    ar_real,
    ar,
    save_figures,
    check_issues=True,
    subset=None,
):
    """Plot the RCR comparison dashboard for the selected configs.

    The figure has a bar chart of the average RCR on top and, below it, line
    plots of RCR per gamma, subset sizes per gamma and (optionally) acceptance
    rate per gamma.

    Reads the notebook-level globals ``ctr_name``, ``exp_name``,
    ``cfg_names``, ``cfg_hash``, ``MDD_EPSS``, ``gammas``, ``n_configs``,
    ``n_samples`` and ``warning_score``.

    Args:
        rcr_df: RCR per gamma (index) with the columns ``Real``, one per
            config, and ``MDD`` (in that order: the first plotted line is
            styled as "Real" and the last one as "MDD").
        counts_df: Subset sizes per gamma with one column per config and
            ``MDD``.
        ar_df: Acceptance rates per gamma, same columns as ``rcr_df``.
        rcr_real: Real RCR, drawn as a reference line and used in the checks.
        ar_real: Not used by this function.
        ar: Whether to add the acceptance-rate panel.
        save_figures: Whether to save the figure under ``ROM_PLOT_DIR`` (once
            per experiment and once per country).
        check_issues: Whether to raise warnings for configs whose mean RCR is
            below ``rcr_real`` or whose subset at ``gammas[1]`` holds fewer
            than 10% of the samples.
        subset: Optional list of config names to restrict the plots to.
    """
    if subset:
        rcr_df = rcr_df[["Real"] + subset + ["MDD"]]
        ar_df = ar_df[["Real"] + subset + ["MDD"]]
        counts_df = counts_df[subset + ["MDD"]]

    if check_issues:
        below_rcr = rcr_df.drop(columns=["Real", "MDD"]).mean() < rcr_real
        below_rcr = below_rcr[below_rcr].index.to_list()
        if len(below_rcr) > 0:
            raise_warning("ROM's RCR drops below real RCR!", 50, below_rcr)

        # Subset sizes are checked at the second gamma only.
        below_counts = (counts_df.drop(columns="MDD") / n_samples * 100 < 10).loc[
            gammas[1]
        ]
        below_counts = below_counts[below_counts].index.to_list()
        if len(below_counts) > 0:
            raise_warning("Subset(s) smaller than 10%!", 30, below_counts)

    xticks = [f"{g} ({e})" for g, e in zip(gammas, MDD_EPSS)]
    # One marker per plotted config line, so that the trailing "d" marker
    # always lands on the MDD line - also when only a subset is plotted.
    markers = get_markers(len(subset) if subset else n_configs)

    width = 15 if ar else 10
    n_cols_ = 3 if ar else 2
    fig = plt.figure(figsize=(width, 10))
    gs = fig.add_gridspec(2, n_cols_, height_ratios=[2, 3])
    ax1 = fig.add_subplot(gs[0, :])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[1, 1])
    if ar:
        ax4 = fig.add_subplot(gs[1, 2])

    # plot RCR
    ax = rcr_df.plot(ax=ax2)
    ax.set_title("RCR")
    ax.set_xlabel("Γ (ε)")
    ax.yaxis.set_major_formatter(PercentFormatter(decimals=0))
    ax.set_xticks(gammas, xticks)

    lines = ax.get_lines()
    real_line = lines[0]
    real_line.set_c("hotpink")
    real_line.set_ls("--")

    for m, line in zip(markers, lines):
        line.set_marker(m)
    mdd_line = lines[-1]
    mdd_line.set_c("grey")
    mdd_line.set_ls("--")

    # With the acceptance-rate panel the legend is shown there instead.
    if ar:
        ax.legend([])
    else:
        ax.legend(bbox_to_anchor=(1.04, 1))

    # Reuse the RCR line colours in the other panels.
    colors = [l.get_c() for l in ax.get_lines()]

    # plot counts
    # todo swap % with counts
    ax = counts_df.plot.line(marker="o", ax=ax3)
    ax.set_title(f"num offers in subset (total: {n_samples})")
    ax.set_xlabel("Γ (ε)")

    # counts_df has no "Real" column, hence the [1:] offsets.
    for l, c, m in zip(ax.get_lines(), colors[1:], markers[1:]):
        l.set_c(c)
        l.set_marker(m)
    mdd_line = ax.get_lines()[-1]
    mdd_line.set_ls("--")

    ax.legend([])
    ax.set_xticks(gammas, xticks)
    yticks = ax.get_yticks()
    ax.set_yticks(yticks, [f"{int(y)} ({y / n_samples * 100:.1f}%)" for y in yticks])

    if ar:
        # plot acceptance rate
        ax = ar_df.plot(ax=ax4)
        ax.set_title("Acceptance Rate")
        ax.set_xlabel("Γ (ε)")
        ax.yaxis.set_major_formatter(PercentFormatter(decimals=0))
        ax.set_xticks(gammas, xticks)

        lines = ax.get_lines()
        real_line = lines[0]
        real_line.set_c("hotpink")
        real_line.set_ls("--")

        for m, line in zip(markers, lines):
            line.set_marker(m)
        mdd_line = lines[-1]
        mdd_line.set_c("grey")
        mdd_line.set_ls("--")
        ax.legend(bbox_to_anchor=(1.04, 1))

    # plot avg rcr
    ax = rcr_df.mean().round(1).plot.bar(rot=15, color=colors, ax=ax1)
    ax.yaxis.set_major_formatter(PercentFormatter(decimals=0))
    ax.bar_label(ax.containers[0])
    ax.set_title("Avg RCR")
    ax.axhline(rcr_real, c="grey", ls=":")

    subset_title_sfx = f" - {len(subset)}/{len(cfg_names)} configs" if subset else ""
    plt.suptitle(f"{ctr_name} - exp. {exp_name}{subset_title_sfx}")
    plt.tight_layout()

    # File names carry the accumulated warning score (if any) as "_W<score>".
    warn_suffix = (
        f"_W{round(warning_score) if warning_score >= 1 else round(warning_score, 1)}"
        if warning_score
        else ""
    )
    subset_suffix = "_sub" if subset else ""
    if save_figures:
        dir_path = f"{ROM_PLOT_DIR}/experiment-rcr/per_exp/{exp_name}"
        fig_path = f"{dir_path}/{ctr_name}_{cfg_hash}{warn_suffix}{subset_suffix}.png"
        Path(dir_path).mkdir(parents=True, exist_ok=True)
        plt.savefig(fig_path)

        dir_path2 = f"{ROM_PLOT_DIR}/experiment-rcr/per_ctr/{ctr_name}"
        fig_path2 = f"{dir_path2}/{exp_name}_{cfg_hash}{warn_suffix}{subset_suffix}.png"
        Path(dir_path2).mkdir(parents=True, exist_ok=True)
        plt.savefig(fig_path2)
        print(f"Saved figure to:\n{fig_path}\n{fig_path2}")


def corr_logloss_rcr(metrics, rcr_df, negative=False):
    """Draw a regression plot of mean RCR against log loss, one point per config.

    Args:
        metrics: Metrics frame indexed by config name with a ``logloss``
            column; it is not modified.
        rcr_df: RCR frame with the columns ``Real``, ``MDD`` and one per
            config; the config columns are averaged to get each config's RCR.
        negative: If true, the x-axis shows the negated log loss.
    """
    mean_rcr = rcr_df.drop(columns=["Real", "MDD"]).mean()
    data = metrics.assign(rcr=mean_rcr, negative_logloss=-metrics["logloss"])
    x_col = "negative_logloss" if negative else "logloss"
    sns.regplot(data, x=x_col, y="rcr")


### Select best config


def calc_within_score(within):
    """Return the (non-positive) penalty for a too small ``within`` percentage.

    For each ``(threshold, exponent)`` pair the shortfall
    ``max(threshold - within, 0)`` is raised to ``exponent``; the penalty is
    the negated sum of the three terms, so it is 0 once ``within`` reaches 40.

    Args:
        within: Scalar, numpy array or pandas Series of percentages.

    Returns:
        Penalty with the same shape as ``within``.
    """
    penalty_terms = ((27, 1.2), (12, 1.7), (40, 0.8))
    return -sum(
        (threshold - np.minimum(threshold, within)) ** exponent
        for threshold, exponent in penalty_terms
    )


def select_best_cfg(rcr_df, counts_df, subset):
    """Pick the config with the best combined RCR / subset-size score.

    The RCR score rescales each config's mean RCR so that the real RCR maps to
    0 and the best config to 100; ``calc_within_score`` then penalises configs
    whose mean subset size is a small share of ``n_samples`` (notebook-level
    global). The per-config scores are displayed as a table.

    Args:
        rcr_df: RCR per gamma with a ``Real`` column and one column per config.
        counts_df: Subset sizes per gamma with one column per config.
        subset: Config names to choose from.

    Returns:
        Name of the config with the highest total score.
    """
    mean_rcr = rcr_df[subset].mean()
    rcr_real = rcr_df.Real.iloc[0]
    rcr_score = (mean_rcr - rcr_real) / (mean_rcr.max() - rcr_real) * 100

    within = counts_df[subset].mean() / n_samples * 100
    within_score = calc_within_score(within)

    score = rcr_score + within_score
    summary = pd.DataFrame(
        {"rcr_score": rcr_score, "within_score": within_score, "score": score}
    )
    display(summary.round(1))
    ranked = score.sort_values(ascending=False)
    return ranked.index[0]


def select_best_gamma(rcr_df, counts_df, cfgn):
    """Return the index label with the highest combined score for one config.

    Uses the same scoring as ``select_best_cfg`` (rescaled RCR plus the
    ``calc_within_score`` penalty) but applies it to the individual entries
    of a single config instead of to per-config means. A table with both
    components and the total (rounded to one decimal) is displayed as a
    side effect.

    Args:
        rcr_df: Frame with a ``Real`` column (its first row is the baseline)
            and an entry keyed by ``cfgn``.
        counts_df: Frame with an entry keyed by ``cfgn``.
        cfgn: Label of the config to evaluate.

    Returns:
        The index label of the best-scoring entry of ``rcr_df[cfgn]``
        (presumably a gamma value -- confirm against how rcr_df is built).
    """
    # n_samples is a notebook-level global (set when the data is loaded).
    global n_samples
    # Index with the caller's argument. This previously read the notebook
    # global ``cfg_best``, which silently ignored ``cfgn`` and raised a
    # NameError whenever ``cfg_best`` had not been defined yet.
    rcr = rcr_df[cfgn]
    counts = counts_df[cfgn]

    # Rescale so the real baseline maps to 0 and the best entry to 100.
    rcr_max = rcr.max()
    rcr_real = rcr_df.Real.iloc[0]
    rcr_score = (rcr - rcr_real) / (rcr_max - rcr_real) * 100

    # Counts expressed as a percentage of all samples.
    within = counts / n_samples * 100

    within_score = calc_within_score(within)

    score = rcr_score + within_score
    display(
        pd.DataFrame(
            dict(rcr_score=rcr_score, within_score=within_score, score=score)
        ).round(1)
    )
    score_sorted = score.sort_values(ascending=False)
    return score_sorted.index[0]

# Select country and experiment

In [None]:
# Collect the stored experiments (restricted via name_pattern) and keep the
# result in `exp2ctrs`, which the country-selection cell below passes on to
# show_available_countries.
exp2ctrs = show_experiments_w_countries(name_pattern="v07")

<div class="alert alert-block alert-warning"><b>You may preselect an experiment</b></div>

In [None]:
# Only the LAST uncommented assignment below takes effect; the earlier lines
# are kept as a history of previously inspected experiments.
# NOTE(review): the trailing `None` option presumably disables preselection --
# confirm against show_available_countries.
preselected_exp = "v02-trn-period"
preselected_exp = "v02-CH-trnm3"
# preselected_exp = 'augm-basic-ab-02-fixed'
preselected_exp = "v02-different-trn-periods"
# preselected_exp = 'augm-weight-01'
preselected_exp = "v03-onboarding"
preselected_exp = "v05-onboarding-test"
preselected_exp = "v06-2025-onboarding"
preselected_exp = "v07-mck-poc"
# preselected_exp = None

# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Edit the selection above; the call below consumes it together with the
# experiment mapping from the previous cell.
ctr_spans, ctr_paths = show_available_countries(preselected_exp, exp2ctrs)

<div class="alert alert-block alert-warning"><b>Select country</b></div>

In [None]:
# Country to evaluate -- presumably an index into the list printed by the
# previous cell (confirm against show_available_experiments).
country_selection = 0

# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Resolve the selection; the unpacked names are used by the following cells
# (ctr_name also appears in the plot titles further down).
ctr_name, span, exp_paths, exp_names, preselected_exp_idx = show_available_experiments(
    country_selection, ctr_spans, ctr_paths, preselected_exp
)

<div class="alert alert-block alert-warning"><b>Select experiment</b></div>

In [None]:
# Experiment to evaluate. Both this value and preselected_exp_idx (from the
# previous cell) are handed to show_available_configs.
# NOTE(review): uncommenting the line below presumably makes
# experiment_selection the deciding value -- confirm in show_available_configs.
experiment_selection = 0
# preselected_exp_idx = None

# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
exp_name, cfg_paths, cfg_names_all = show_available_configs(
    experiment_selection, preselected_exp_idx, exp_names
)

<div class="alert alert-block alert-warning"><b>Select configurations</b></div>

In [None]:
# Choose which configs to evaluate. The commented lines are earlier choices:
# either explicit index tuples or indices derived by filtering cfg_names_all
# on substrings of the config name.
# NOTE(review): `None` presumably means "use all configs" -- confirm in
# show_selected_configs, which returns the resolved selection.
# config_selection = 4,5,11,12,14,15,16,17,20,21
# config_selection = 0,1,6,7,9,10,11,12,13
# config_selection = 5,12,15,17,21
# config_selection = 0,1,6,7
config_selection = None

# temp
# config_selection = 0,1,5,6,7,10,11,15,16,4,19

# all cm1
# config_selection = list(map(lambda x: x[0], filter(lambda x: 'cm1' in x[1], list(enumerate(cfg_names_all)))))

# weight 0.4
# config_selection = list(map(lambda x: x[0], filter(lambda x: 'cm1' not in x[1] and '0.4' in x[1], list(enumerate(cfg_names_all)))))


# config_selection = list(map(lambda x: x[0], filter(lambda x: any(s in x[1] for s in '0.4 0.6 1.0'.split()), list(enumerate(cfg_names_all)))))

# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# config_selection is rebound to whatever show_selected_configs returns, and
# that value is what the data-loading cell receives.
cfg_names, cfg_hash, config_selection = show_selected_configs(
    config_selection, cfg_names_all
)

## Load data and prep evaluation

In [None]:
# Load the results for the selected configs and unpack them into notebook
# globals used by the evaluation cells below: rcm_dfs and models drive the
# ε-RCR computation, and n_samples is read by select_best_cfg /
# select_best_gamma. The nested tuple is unpacked in place.
(
    clf_dfs,
    rcm_dfs,
    eval_dfs,
    n_samples,
    configs,
    (n_cols, n_rows, n_cols_wide, n_rows_wide),
    models,
    n_models,
    gammas,
    n_configs,
) = load_data_and_prep_eval(config_selection, cfg_paths, n_cols=4, n_cols_wide=2)
# rcr_df, counts_df, ar_df, rcr_real, ar_real, exp_df = calc_rcr_dfs()

# Eval

In [15]:
# show_selected_evaluation(ctr_name, span, exp_name, cfg_names)

In [None]:
# Quick sanity check: frequency of each value in the `accepted` column of the
# first frame stored in rcm_dfs.
list(rcm_dfs.values())[0].accepted.value_counts()

## Model eval

### ROC

In [None]:
# plot_roc is defined outside this excerpt; it is called without arguments,
# so it presumably reads the loaded data from notebook globals -- confirm.
plot_roc()

## Recommended discounts

### Distribution of discounts

- red X - 0% recommended discounts more frequent than 0% real discounts
- blue X - medians of discounts differ too much

In [None]:
# Helper defined outside this excerpt; the return value is kept in `ax`
# (presumably the matplotlib axes, for further tweaking -- confirm).
ax = plot_discount_distributions()

### Difference between real and recommended discounts

- red X - Much smaller recommended discounts than real discounts
- blue X - Unexpected median value

In [None]:
# Helper defined outside this excerpt; show_means=True presumably adds the
# mean values to the plot -- confirm in the helper.
plot_difference_between_discounts(show_means=True)

## RCR

In [39]:
# plot_rcr_comparison(rcr_df, counts_df, ar_df, rcr_real, ar_real, ar=True, save_figures=True, check_issues=True)

### ε-RCR

In [None]:
# todo improve this code, maybe add to trainer
# ε-RCR: for every tolerance ε, keep only the offers whose absolute
# `rom_diff` is below ε (+0.2 slack) and recompute RCR, the mean per-offer
# revenue ratio, the acceptance rate and the number of offers kept.
# Results are collected column-wise in eps_dict and indexed by ε in eps_df.
epsilons = [0.2, 0.5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16]
eps_dict = defaultdict(list)
eps_dict["eps"] = epsilons

# Baselines over ALL offers of the first model's frame (no ε filter).
res = rcm_dfs[models[0]]
# res = res[res.pub_rev < res.pub_rev.quantile(0.95)]
rev0 = res.real_rev.mean()  # not used in this cell
acc0 = res.accepted.mean()
rcr0 = res.real_rev.sum() / res.pub_rev.sum() * 100  # ratio of sums, in percent

arcr0 = (res.real_rev / res.pub_rev).mean() * 100  # mean of per-offer ratios; not used in this cell

close_dict = {}

for EPS in epsilons:
    for m in models:
        # NOTE(review): m_ is hard-coded to "rom", so every model (including
        # "MDD") is filtered on the `rom_diff` column. Because "MDD" also
        # reuses the first model's frame, its columns in eps_df are identical
        # to those of models[0]. Confirm whether an MDD-specific diff column
        # was intended here.
        m_ = "rom"
        res = rcm_dfs[m] if m != "MDD" else rcm_dfs[models[0]]

        close = res[np.abs(res[f"{m_}_diff"]) < EPS + 0.2]
        # Always true given the assignment above, so close_dict[EPS] is
        # overwritten on each pass and ends up holding the LAST model's subset.
        if m_ == "rom":
            close_dict[EPS] = close

        acc1 = close.accepted.mean()
        rcr1 = close.real_rev.sum() / close.pub_rev.sum() * 100
        arcr1 = (close.real_rev / close.pub_rev).mean() * 100
        rcr_inc1 = rcr1 / rcr0 * 100 - 100  # relative change vs. baseline, in percent

        eps_dict[f"{m}_rcr_increase"].append(rcr_inc1)
        eps_dict[f"{m}_rcr"].append(rcr1)
        eps_dict[f"{m}_arcr"].append(arcr1)
        eps_dict[f"{m}_acc_rate"].append(acc1)
        eps_dict[f"{m}_counts"].append(len(close))

eps_df = pd.DataFrame(eps_dict).set_index("eps")
eps_df.head(3)

In [None]:
# Three-panel ε-RCR overview: RCR (top-left) and acceptance rate
# (bottom-left) as lines over ε, and the number of offers kept per ε as bars
# in a right-hand panel spanning both rows.
fig = plt.figure(figsize=(16, 7))

# 2x2 grid with the left column twice as wide as the right one.
gs = fig.add_gridspec(2, 2, width_ratios=[2, 1])
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[:, 1])

# fig, axes = plt.subplots(2, 2, )
# axes = axes.flatten()

ax = eps_df[[f"{m}_rcr" for m in models]].plot(marker="o", ax=ax1)
ax.axhline(rcr0, c="r", ls="--")  # unfiltered baseline RCR
ax.set_title("RCR comparison")
ax.xaxis.set_minor_locator(MultipleLocator(1))
ax.tick_params(which="both", bottom=True)
# ax.legend(model_names)
ax.set_xlabel("ε")
# RCR values are already in percent (0-100), hence the default xmax.
ax.yaxis.set_major_formatter(PercentFormatter(decimals=0))


ax = eps_df[[f"{m}_acc_rate" for m in models]].plot(marker="o", ax=ax2)
ax.axhline(acc0, c="r", ls="--")  # unfiltered baseline acceptance rate
ax.set_title("Acceptance rate")
# Acceptance rate is a 0-1 fraction, so xmax=1 maps it to 0-100%.
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0))
ax.xaxis.set_minor_locator(MultipleLocator(1))
ax.tick_params(which="both", bottom=True)
# ax.legend(model_names)
ax.set_xlabel("ε")

df = eps_df[[f"{m}_counts" for m in models]]
ax = df.plot.bar(ax=ax3)
# Choose the major tick spacing from the tallest bar plus 5% headroom.
tick_step = calc_tick_step(df.max().max() * 1.05)
ax.yaxis.set_major_locator(MultipleLocator(tick_step))
ax.yaxis.set_major_formatter("{x:,.0f}")
# NOTE(review): `tick_step // 5` is 0 whenever tick_step < 5, which is not a
# usable MultipleLocator base -- confirm calc_tick_step never returns such a
# small step, or switch to true division.
ax.yaxis.set_minor_locator(MultipleLocator(tick_step // 5))
ax.tick_params(which="both", left=True)
ax.set_title("Number of offers included")
ax.set_ylabel("Number of offers")
# ax.legend(model_names)
ax.set_xlabel("ε")

plt.suptitle(f"{ctr_name} - ε-RCR")
plt.tight_layout()

In [None]:
# Stand-alone, full-width version of the "RCR comparison" panel from the
# previous cell (same data and formatting, thicker lines).
fig, ax = plt.subplots(figsize=(16, 7))

ax = eps_df[[f"{m}_rcr" for m in models]].plot(marker="o", ax=ax, lw=2)
ax.axhline(rcr0, c="r", ls="--")  # unfiltered baseline RCR
ax.set_title("RCR comparison")
ax.xaxis.set_minor_locator(MultipleLocator(1))
ax.tick_params(which="both", bottom=True)
# ax.legend(model_names)
ax.set_xlabel("ε")
# RCR values are already in percent (0-100), hence the default xmax.
ax.yaxis.set_major_formatter(PercentFormatter(decimals=0))

plt.suptitle(f"{ctr_name} - ε-RCR")
plt.tight_layout()