# Evaluate the Subgroups on the Dataset

## Default Values for Papermill Parameters

In [None]:
# Root directory of the prepared datasets; the next cell passes it through
# util.prepend_experiment_output_path before use.
PARAM_DATA_IN_PATH = "../../data"
# Dataset name as accepted by subroc.datasets.metadata.to_DatasetName.
PARAM_DATASET_NAME = "OpenML Adult"
PARAM_DATASET_STAGE = 5  # 4=predicted, 5=permuted

# CSV with the patterns (subgroup descriptions) to evaluate; also passed through
# util.prepend_experiment_output_path in the next cell.
PARAM_PATTERNS_IN_PATH = "../outputs/0.7_0.8_picked_pattern.csv"

# Seed for the numpy Generator created in the next cell.
PARAM_SEED = 0

## Import and Set Parameters

In [None]:
from subroc.datasets.metadata import to_DatasetName
from subroc.datasets.reader import DatasetReader, DatasetStage
from subroc import util

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

# Resolve the (possibly papermill-overridden) relative input paths against the
# experiment output location.
PARAM_DATA_IN_PATH = util.prepend_experiment_output_path(PARAM_DATA_IN_PATH)
PARAM_PATTERNS_IN_PATH = util.prepend_experiment_output_path(PARAM_PATTERNS_IN_PATH)

# Directory the figures are written to; can be overridden via the environment.
STAGE_OUTPUT_PATH = os.environ.get("STAGE_OUTPUT_PATH", "../outputs")

# Typeset all figure text with LaTeX (text.usetex requires a TeX installation).
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "Computer Modern Roman",
    "figure.figsize": [5.1483, 5.1483],
})

# Dataset
dataset_reader = DatasetReader(PARAM_DATA_IN_PATH)

DATASET_NAME = to_DatasetName(PARAM_DATASET_NAME)

if DATASET_NAME is None:
    # Fail fast: continuing would hand None to read_dataset below.
    raise ValueError(f"dataset name '{PARAM_DATASET_NAME}' not supported.")

DATASET_STAGE = DatasetStage(PARAM_DATASET_STAGE)

# Read the stored dataset; only the test split is evaluated in this notebook.
(_, test_data), dataset_meta = dataset_reader.read_dataset(DATASET_NAME, DATASET_STAGE)

evaluation_patterns = pd.read_csv(PARAM_PATTERNS_IN_PATH)

rng = np.random.default_rng(PARAM_SEED)

## Evaluate

In [None]:
from subroc.util import print_metric_colored, create_subgroup, from_str_Conjunction
from subroc.quality_functions.sklearn_metrics import soft_classification_metrics
from subroc.quality_functions.base_qf import label_balance_fraction
from subroc.metrics import average_ranking_loss, prc_auc_score

from sklearn.metrics import RocCurveDisplay, precision_recall_curve
from termcolor import cprint

# Evaluate the whole test set ("Dataset") and then every pattern from the CSV.
# NOTE(review): iterating a DataFrame yields its column labels, so patterns are only
# picked up if the CSV stores them as header entries -- confirm against the writer
# of PARAM_PATTERNS_IN_PATH.
# NOTE(review): "Dataset" presumably parses to an empty conjunction covering every
# row -- confirm from_str_Conjunction's behavior.
output_prefix = f"{STAGE_OUTPUT_PATH}/{os.path.basename(PARAM_DATA_IN_PATH)}"
gt_name = dataset_meta.gt_name
score_name = dataset_meta.score_name

for pattern in ["Dataset"] + list(evaluation_patterns):
    print(f"################ {pattern} ################")

    # Rows of the test set covered by the pattern.
    subgroup = create_subgroup(test_data, from_str_Conjunction(pattern).selectors)
    subgroup_data = test_data[subgroup.representation]
    subgroup_y_true = subgroup_data[gt_name]
    subgroup_y_score = subgroup_data[score_name]

    print(f"cover size: {len(subgroup_data)}")
    print(f"class balance: {label_balance_fraction(subgroup_y_true)}")
    # Fraction of covered rows with ground-truth label 0. Series.mean yields NaN for
    # an empty cover instead of raising ZeroDivisionError.
    print(f"NCR: {(subgroup_y_true == 0).mean()}")

    subgroup_y_true_numpy = subgroup_y_true.to_numpy()
    for metric in soft_classification_metrics:
        try:
            metric_value = metric(subgroup_y_true_numpy, subgroup_y_score)
        except ValueError:
            cprint(f"{metric.__name__}: ValueError", color="red")
        else:
            print_metric_colored(metric.__name__, metric_value)

    # Ground truth and scores in ascending score order. subgroup_data already holds
    # only covered rows, so no further masking is needed (and no assumption about
    # the index of test_data is made).
    sorted_by_score = subgroup_data.sort_values(score_name)
    sorted_y_true = sorted_by_score[gt_name].to_numpy()
    sorted_y_score = sorted_by_score[score_name].to_numpy()
    print(f"average_ranking_loss: {average_ranking_loss(sorted_y_true, sorted_y_score)}")

    print(f"prc_auc_score: {prc_auc_score(subgroup_y_true, subgroup_y_score)}")

    RocCurveDisplay.from_predictions(subgroup_y_true, subgroup_y_score, c="black")
    plt.title("ROC Curve")
    plt.xlim(-0.1, 1.1)
    plt.ylim(-0.1, 1.1)
    roc_path = f"{output_prefix}_ROC_{pattern}.pdf"
    print(roc_path)
    plt.savefig(roc_path)
    plt.show()
    # Non-inline backends keep the figure open after show(); release it explicitly.
    plt.close()

    # Start a fresh figure so the PR curve can never be drawn onto the ROC axes.
    plt.figure()
    precision, recall, _ = precision_recall_curve(subgroup_y_true, subgroup_y_score, drop_intermediate=True)
    plt.plot(recall, precision, c="black")
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.xlim(-0.1, 1.1)
    plt.ylim(-0.1, 1.1)
    plt.savefig(f"{output_prefix}_PR_{pattern}.pdf")