# Generating Plots of Metric Values on Search Data vs. Holdout-Generalizability Data

## Default Values for Papermill Parameters

In [None]:
# Papermill parameters: these are the defaults used when the notebook is run
# without injected values.

# Input/output locations. All of these except PARAM_PLOT_BASENAME are later
# passed through util.prepend_experiment_output_path.
PARAM_FULL_RESULT_SET_PATH = "../outputs/p_value_augmented_result_set.csv"
PARAM_FILTERED_RESULT_SET_PATH = "../outputs/p_value_filtered_result_set.csv"
PARAM_QF_PATH = "../outputs/interestingness_measure.pickle"
# file name (without the ".pdf" extension) of the plot saved in STAGE_OUTPUT_PATH
PARAM_PLOT_BASENAME = "generalizability_plot"
PARAM_DATA_IN_PATH = "../../data"
PARAM_MODELS_IN_PATH = "../../models"

# Dataset and model selection. PARAM_DATASET_STAGE = None selects
# DatasetStage.PROCESSED_MODEL_READY; any other value is passed to DatasetStage(...).
PARAM_DATASET_NAME = "OpenML Adult"
PARAM_DATASET_STAGE = None
PARAM_MODEL_NAME = "sklearn_gaussian_nb_adult_4_splits"

# Axis limits of the plot; a None value leaves that limit unset.
PARAM_PLOT_XMIN = 0
PARAM_PLOT_XMAX = 1
PARAM_PLOT_YMIN = 0
PARAM_PLOT_YMAX = 1
# Axis labels of the plot
PARAM_PLOT_XLABEL = "PRC AUC on Search Data"
PARAM_PLOT_YLABEL = "PRC AUC on Test Data"

## Imports and Parameter Setup

In [None]:
from subroc.datasets.metadata import to_DatasetName
from subroc.datasets.reader import DatasetReader, DatasetStage, meta_dict
from subroc.model_serialization import deserialize
from subroc.quality_functions.base_qf import PredictionType
from subroc.quality_functions.soft_classifier_target import SoftClassifierTarget
from subroc import util

import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import pysubgroup as ps

# Prefix the configured input paths via util.prepend_experiment_output_path
# (presumably an environment-dependent experiment directory — see subroc.util).
PARAM_FULL_RESULT_SET_PATH = util.prepend_experiment_output_path(PARAM_FULL_RESULT_SET_PATH)
PARAM_FILTERED_RESULT_SET_PATH = util.prepend_experiment_output_path(PARAM_FILTERED_RESULT_SET_PATH)
PARAM_QF_PATH = util.prepend_experiment_output_path(PARAM_QF_PATH)
PARAM_DATA_IN_PATH = util.prepend_experiment_output_path(PARAM_DATA_IN_PATH)
PARAM_MODELS_IN_PATH = util.prepend_experiment_output_path(PARAM_MODELS_IN_PATH)

# Output directory of this stage; "../outputs" when the variable is unset.
STAGE_OUTPUT_PATH = os.environ.get("STAGE_OUTPUT_PATH", "../outputs")

# Dataset
dataset_reader = DatasetReader(PARAM_DATA_IN_PATH)

DATA_OUT_PATH = f"{STAGE_OUTPUT_PATH}/data/processed"
# exist_ok replaces the separate exists-check (no check-then-create race)
os.makedirs(DATA_OUT_PATH, exist_ok=True)

DATASET_NAME = to_DatasetName(PARAM_DATASET_NAME)

if DATASET_NAME is None:
    # Fail fast: without a known dataset the following cells cannot work, and
    # lookups such as meta_dict[DATASET_NAME] would presumably fail less clearly.
    raise ValueError(f"dataset name '{PARAM_DATASET_NAME}' not supported.")

# Without an explicit stage, start from model-ready data (predictions are then
# computed in the next cell).
if PARAM_DATASET_STAGE is None:
    DATASET_STAGE = DatasetStage.PROCESSED_MODEL_READY
else:
    DATASET_STAGE = DatasetStage(PARAM_DATASET_STAGE)

# Model
model = deserialize(PARAM_MODELS_IN_PATH, PARAM_MODEL_NAME)

## Get and Preprocess Data

In [None]:
# read data and preprocess it for the model
dataset_meta = meta_dict[DATASET_NAME]

# prepare classification predictions
dataset_meta.prediction_type = PredictionType.CLASSIFICATION_SOFT


def _add_model_scores(data):
    """Write the model's predictions into the score column of `data` (in place).

    Every column except the ground-truth column is used as model input.
    """
    features = data.loc[:, data.columns != dataset_meta.gt_name]
    data[dataset_meta.score_name] = model.predict(features)


if DATASET_STAGE == DatasetStage.PROCESSED_MODEL_READY:
    search_data = dataset_reader._read_processed(dataset_meta, "model_ready_test.csv", ",")
    holdout_generalizability_data = dataset_reader._read_processed(dataset_meta, "model_ready_holdout_generalizability.csv", ",")

    _add_model_scores(search_data)
    _add_model_scores(holdout_generalizability_data)

    # save data with predictions so later runs can start from the predicted stage
    out_path = DATA_OUT_PATH + "/" + dataset_meta.dataset_dir
    os.makedirs(out_path, exist_ok=True)

    search_data.to_csv(out_path + "/" + "model_predicted_test.csv", index=False)
    holdout_generalizability_data.to_csv(out_path + "/" + "model_predicted_holdout_generalizability.csv", index=False)
elif DATASET_STAGE == DatasetStage.PROCESSED_MODEL_PREDICTED:
    search_data = dataset_reader._read_processed(dataset_meta, "model_predicted_test.csv", ",")
    holdout_generalizability_data = dataset_reader._read_processed(dataset_meta, "model_predicted_holdout_generalizability.csv", ",")
elif DATASET_STAGE == DatasetStage.PROCESSED_PERMUTED_MODEL_PREDICTED:
    search_data = dataset_reader._read_processed(dataset_meta, "permuted_model_predicted_test.csv", ",")
    holdout_generalizability_data = dataset_reader._read_processed(dataset_meta, "permuted_model_predicted_holdout_generalizability.csv", ",")
else:
    # Any other stage would leave both frames undefined and fail far from the cause.
    raise ValueError(f"dataset stage '{DATASET_STAGE}' not supported.")

# sd objects
target = SoftClassifierTarget(dataset_meta.gt_name, dataset_meta.score_name)

## Read the Full Result Set

In [None]:
# Load the full result set; rows must provide the `pattern` column used later.
full_result_set = pd.read_csv(str(PARAM_FULL_RESULT_SET_PATH))
full_result_set

## Read the Filtered Result Set

In [None]:
# Load the filtered result set (same layout as the full result set).
filtered_result_set = pd.read_csv(str(PARAM_FILTERED_RESULT_SET_PATH))
filtered_result_set

## Read and Configure the Interestingness Measure

In [None]:
def read_and_configure_qf(data):
    """Load a fresh copy of the pickled interestingness measure, configured for `data`.

    The pickle at PARAM_QF_PATH is read on every call, so every caller receives
    its own object. A ``ps.GeneralizationAwareQF`` is unwrapped to its inner
    ``qf``. All significance-related adjustments of the qf value are switched
    off, and every constraint that has an ``update`` method is updated with
    `data`.

    Args:
        data: dataset handed to each constraint's ``update`` method.

    Returns:
        The configured quality function.
    """
    # NOTE: unpickling executes code from the file; PARAM_QF_PATH must be trusted.
    with open(PARAM_QF_PATH, "rb") as handle:
        loaded = pickle.load(handle)

    qf = loaded.qf if isinstance(loaded, ps.GeneralizationAwareQF) else loaded

    # Disable any significance-related changes to the qf value
    qf.subgroup_size_weight = 0
    qf.subgroup_class_balance_weight = 0
    qf.random_sampling_p_value_factor = False
    qf.random_sampling_normalization = False

    # constraints are optional; those with an `update` hook get the new data
    for constraint in getattr(qf, "constraints", []):
        if hasattr(constraint, "update"):
            constraint.update(data)

    return qf


# One independently loaded measure per dataset: constraint updates are data-specific.
search_qf, holdout_generalizability_qf = [
    read_and_configure_qf(frame) for frame in (search_data, holdout_generalizability_data)
]

## Calculate Plot Points

In [None]:
def calculate_metric_value(pattern, qf, data):
    """Evaluate the measure's metric on the rows of `data` covered by `pattern`.

    Args:
        pattern: string form of a selector conjunction, parsed with
            ``util.from_str_Conjunction``.
        qf: configured quality function; ``qf.metric`` is evaluated and, if
            present, ``qf.constraints`` are checked first.
        data: DataFrame containing the ground-truth and score columns named in
            ``dataset_meta``.

    Returns:
        ``qf.metric(y_true, y_score)`` on the covered rows, ordered by ascending
        score, or ``np.nan`` if the subgroup violates the measure's constraints.
    """
    # sort data and set up some datastructures to access sorted data
    dataset_sorted_by_score = data.sort_values(dataset_meta.score_name)
    scores_sorted = dataset_sorted_by_score.loc[:, dataset_meta.score_name]
    gt_sorted_by_score = dataset_sorted_by_score.loc[:, dataset_meta.gt_name]
    # index labels of `data` in score order (same labels iterrows() would yield,
    # without building a Series per row)
    sorted_to_original_index = dataset_sorted_by_score.index.tolist()

    # recreate the pysubgroup object for the subgroup with a representation for the dataset
    sel_conjunction = util.from_str_Conjunction(pattern)
    subgroup = util.create_subgroup(data, sel_conjunction.selectors)

    # calculate statistics
    statistics = qf.calculate_statistics(subgroup, target, data)

    # check constraints (optional on the qf, as in read_and_configure_qf)
    if not ps.constraints_satisfied(
            getattr(qf, "constraints", []),
            subgroup,
            statistics,
            data,
    ):
        return np.nan

    # get true and predicted labels for subgroup cover, in score order
    sorted_subgroup_representation = \
        [subgroup.representation[original_index] for original_index in sorted_to_original_index]
    sorted_subgroup_y_true = gt_sorted_by_score[sorted_subgroup_representation].to_numpy()
    sorted_subgroup_y_pred = scores_sorted[sorted_subgroup_representation].to_numpy()

    # compute the metric values
    return qf.metric(sorted_subgroup_y_true, sorted_subgroup_y_pred)


def _metric_values(result_set, qf, data):
    """Return the metric value on `data` of every pattern in `result_set`, in row order."""
    return [calculate_metric_value(pattern, qf, data) for pattern in result_set["pattern"]]


full_search_metric_values = _metric_values(full_result_set, search_qf, search_data)
full_holdout_generalizability_metric_values = _metric_values(
    full_result_set, holdout_generalizability_qf, holdout_generalizability_data
)

filtered_search_metric_values = _metric_values(filtered_result_set, search_qf, search_data)
filtered_holdout_generalizability_metric_values = _metric_values(
    filtered_result_set, holdout_generalizability_qf, holdout_generalizability_data
)

# Reference values for the plot's axis lines. The "Dataset" pattern presumably
# parses to the empty conjunction, i.e. the whole dataset.
search_overall_metric_value = calculate_metric_value("Dataset", search_qf, search_data)
holdout_generalizability_overall_metric_value = calculate_metric_value("Dataset", holdout_generalizability_qf, holdout_generalizability_data)

print("full_search_metric_values:", full_search_metric_values)
print("full_holdout_generalizability_metric_values:", full_holdout_generalizability_metric_values)

## Generate the Plot

In [None]:
def scatter_plot(search_metric_values, holdout_generalizability_metric_values, c, linewidths, filtered):
    """Plot search-data metric values (x) against holdout-generalizability values (y).

    Pairs with both values present are drawn as "x" markers. A pair whose
    holdout value is NaN cannot be placed as a point, so a dashed vertical line
    is drawn at its search value instead. The line is black when `filtered`,
    gray otherwise. Pairs whose search value is NaN have no x position and are
    skipped.

    Args:
        search_metric_values: metric values on the search data.
        holdout_generalizability_metric_values: metric values on the holdout
            data, aligned element-wise with `search_metric_values`.
        c: marker color passed to ``plt.scatter``.
        linewidths: marker line width passed to ``plt.scatter``.
        filtered: whether the values belong to the filtered result set.
    """
    xs = []
    ys = []
    only_xs = []

    for search_metric_value, holdout_generalizability_metric_value in zip(search_metric_values, holdout_generalizability_metric_values):
        if np.isnan(search_metric_value):
            continue
        if np.isnan(holdout_generalizability_metric_value):
            only_xs.append(search_metric_value)
        else:
            xs.append(search_metric_value)
            ys.append(holdout_generalizability_metric_value)

    # plot only the complete pairs collected above
    plt.scatter(xs, ys, s=30, c=c, marker="x", linewidths=linewidths)

    line_color = "black" if filtered else "gray"
    for x in only_xs:
        plt.axvline(x, color=line_color, linestyle="--", linewidth=0.5)



plt.figure(figsize=(4, 4))

# Full result set first (thin gray markers); filtered result set drawn on top in black.
scatter_plot(
    full_search_metric_values,
    full_holdout_generalizability_metric_values,
    c="gray",
    linewidths=0.5,
    filtered=False,
)
scatter_plot(
    filtered_search_metric_values,
    filtered_holdout_generalizability_metric_values,
    c="black",
    linewidths=1,
    filtered=True,
)
plt.grid(True, which="major", linestyle="dotted")

# Apply each configured limit on its own; a None parameter leaves that side untouched.
for _limit_setter, _side, _limit in (
    (plt.xlim, "left", PARAM_PLOT_XMIN),
    (plt.xlim, "right", PARAM_PLOT_XMAX),
    (plt.ylim, "bottom", PARAM_PLOT_YMIN),
    (plt.ylim, "top", PARAM_PLOT_YMAX),
):
    if _limit is not None:
        _limit_setter(**{_side: _limit})

plt.xlabel(PARAM_PLOT_XLABEL)
plt.ylabel(PARAM_PLOT_YLABEL)

# Reference lines: the "Dataset" metric value on the search and holdout data.
plt.axvline(search_overall_metric_value, color="black", linewidth=0.75)
plt.axhline(holdout_generalizability_overall_metric_value, color="black", linewidth=0.75)

plt.savefig(f"{STAGE_OUTPUT_PATH}/{PARAM_PLOT_BASENAME}.pdf")
plt.close()