# Compute the IoUs of Each Subgroup in a Result Set with Each Subgroup in a List

## Default Values for Papermill Parameters

In [None]:
# Dataset directory; later routed through util.prepend_experiment_output_path
# and handed to DatasetReader.
PARAM_DATA_IN_PATH = "../../data"
# Dataset name as a string; converted with to_DatasetName before reading.
PARAM_DATASET_NAME = "OpenML Adult"

# CSV of reference patterns; read without a header row as a single column.
PARAM_PATTERNS_IN_PATH = "../outputs/0.7_0.8_picked_pattern.csv"
# CSV holding the result set whose subgroups are compared to the references.
PARAM_RESULT_SET_IN_PATH = "../outputs/sd_result_set_average_ranking_loss.csv"
# Column holding the pattern strings in the result set; the same name is
# assigned to the single column of the reference pattern file.
PARAM_RESULT_SET_PATTERNS_COLUMN = "pattern"
# Column the result set is sorted by (descending); None skips the sorting.
PARAM_RESULT_SET_SORT_COLUMN = "interestingness"
# Number of leading result-set rows compared against each reference pattern.
PARAM_TOP_K = 10

# Seed for the NumPy random generator created during setup.
PARAM_SEED = 0

# NOTE(review): presumably these mirror the settings of the run that produced
# the result set -- confirm. In this notebook only the generalization-awareness
# flag and the two weights are read (by get_plot_title).
DUMMY_PARAM_CLASS_BALANCE_WEIGHT = 0
DUMMY_PARAM_COVER_SIZE_WEIGHT = 0
DUMMY_PARAM_ENABLE_GENERALIZATION_AWARENESS = "False"
DUMMY_PARAM_MAX_SIZE_FRACTION = 0.006
DUMMY_PARAM_MIN_SIZE_FRACTION = 0.004
DUMMY_PARAM_QF = "prc_auc_score"

## Import and Set Parameters

In [None]:
from subroc.datasets.metadata import to_DatasetName
from subroc.datasets.reader import DatasetReader, DatasetStage
from subroc import util

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

# Route all three input locations through util.prepend_experiment_output_path.
PARAM_DATA_IN_PATH, PARAM_PATTERNS_IN_PATH, PARAM_RESULT_SET_IN_PATH = (
    util.prepend_experiment_output_path(input_path)
    for input_path in (PARAM_DATA_IN_PATH, PARAM_PATTERNS_IN_PATH, PARAM_RESULT_SET_IN_PATH)
)

# Directory the figure is written to; overridable through the environment.
STAGE_OUTPUT_PATH = os.getenv("STAGE_OUTPUT_PATH", "../outputs")

# Per-venue layout constants.
# Line width in inches, measured in LaTeX with
# "linewidth: \printinunitsof{in}\prntlen{\linewidth}".
_LINEWIDTH_INCH_BY_VENUE = {"ECAI Main Track": 3.40457}
# Font size, measured in LaTeX with "fontsize: \makeatletter \f@size \makeatother".
_FONTSIZE_BY_VENUE = {"ECAI Main Track": 9}

venue = "ECAI Main Track"
num_figures = 4
linewidth = _LINEWIDTH_INCH_BY_VENUE[venue]
fontsize = _FONTSIZE_BY_VENUE[venue]
figure_padding = 0.02
estimated_width_ylabel_yticks = 0.523618

# Side length of one square figure: the line width minus the estimated room
# for the y label and ticks is shared by num_figures figures, each of which
# additionally loses the padding on both sides.
_width_shared_by_figures = linewidth - estimated_width_ylabel_yticks
figsize = _width_shared_by_figures / num_figures - figure_padding * 2

_rc_overrides = {
    "text.usetex": True,
    "font.family": "Computer Modern Roman",
    "font.size": fontsize,
    "figure.figsize": [figsize, figsize],
}
plt.rcParams.update(_rc_overrides)

# Dataset
dataset_reader = DatasetReader(PARAM_DATA_IN_PATH)

DATASET_NAME = to_DatasetName(PARAM_DATASET_NAME)

# Fail fast: without a valid dataset name the read below cannot succeed, and
# merely printing would surface the problem later as a less helpful error.
if DATASET_NAME is None:
    raise ValueError(f"dataset name '{PARAM_DATASET_NAME}' not supported.")

DATASET_STAGE = DatasetStage.PROCESSED_PERMUTED_MODEL_PREDICTED

# Read the dataset at the chosen stage. Only the second element of the
# returned pair (bound to test_data) is used; the first one is discarded.
(_, test_data), dataset_meta = dataset_reader.read_dataset(DATASET_NAME, DATASET_STAGE)

# The reference pattern file has no header row; its single column gets the
# same name as the pattern column of the result set.
reference_patterns = pd.read_csv(PARAM_PATTERNS_IN_PATH, names=[PARAM_RESULT_SET_PATTERNS_COLUMN], header=None)

# Read the result set. A completely empty file is treated as "no subgroups
# found" and yields an empty, column-less DataFrame instead of an error.
try:
    result_set = pd.read_csv(PARAM_RESULT_SET_IN_PATH)
except pd.errors.EmptyDataError:
    result_set = pd.DataFrame()

rng = np.random.default_rng(PARAM_SEED)

## Sort Result Set

In [None]:
# Order the result set best-first by the configured column. An empty result
# file is loaded as a DataFrame without any columns, for which sort_values
# would raise a KeyError, so there is nothing to sort in that case. A
# non-empty result set that lacks the column still raises, as before.
if PARAM_RESULT_SET_SORT_COLUMN is not None and not result_set.empty:
    result_set.sort_values(by=PARAM_RESULT_SET_SORT_COLUMN, inplace=True, ascending=False)

## Compute IoUs

In [None]:
from subroc.util import create_subgroup, from_str_Conjunction, iou


def get_plot_title():
    """Return the plot title for the configured run.

    The title is the LaTeX string with the cover size weight (alpha) and the
    class balance weight (beta) when generalization awareness is enabled, and
    "baseline" otherwise.
    """
    generalization_aware = util.str_to_bool(DUMMY_PARAM_ENABLE_GENERALIZATION_AWARENESS)

    if generalization_aware:
        return fr"$\alpha={DUMMY_PARAM_COVER_SIZE_WEIGHT}, \beta={DUMMY_PARAM_CLASS_BALANCE_WEIGHT}$"

    return "baseline"


# For every reference pattern, compute the IoU between its cover on the test
# data and the cover of each of the first PARAM_TOP_K result-set subgroups,
# then plot IoU over rank.
#
# NOTE(review): every reference pattern draws onto the same axes and saves to
# the same file name, so with more than one reference pattern the curves
# overlap and the file is rewritten on each pass -- confirm this is intended.
top_k_result_set = result_set[:PARAM_TOP_K]

for reference_idx, reference_row in reference_patterns.iterrows():
    # The reference subgroup depends only on the reference row, so it is
    # built once here rather than once per result-set row.
    reference_pattern = reference_row[PARAM_RESULT_SET_PATTERNS_COLUMN]
    reference_subgroup = create_subgroup(test_data, from_str_Conjunction(reference_pattern).selectors)
    # Positions of the non-zero entries of the subgroup representation.
    reference_subgroup_idx = np.nonzero(reference_subgroup.representation)[0]

    ious = []
    for result_set_idx, result_set_row in top_k_result_set.iterrows():
        result_set_pattern = result_set_row[PARAM_RESULT_SET_PATTERNS_COLUMN]

        print(f"################ {reference_pattern} --- VS --- {result_set_pattern} ################")

        result_set_subgroup = create_subgroup(test_data, from_str_Conjunction(result_set_pattern).selectors)
        result_set_subgroup_idx = np.nonzero(result_set_subgroup.representation)[0]

        ious.append(iou(set(reference_subgroup_idx), set(result_set_subgroup_idx)))
        print(f"IoU = {ious[-1]}")
        print()

    # Ranks are 1-based positions within the (sorted) result set.
    xs = list(range(1, len(ious)+1))

    plt.step(xs, ious, where="mid", c="black", linewidth=1)
    plt.scatter(xs, ious, c="black", s=30, marker="|", linewidths=1)
    plt.ylim(0, 1)

    plt.ylabel("IoU")

    plt.xlabel("Rank")
    # Label every rank for short lists, every second rank otherwise.
    if len(xs) < 10:
        plt.xticks(xs, xs)
    else:
        plt.xticks(xs[::2], xs[::2])
    # Let the axes fill the whole figure; labels and ticks are kept in the
    # saved file through the tight bounding box.
    plt.gca().set_position([0, 0, 1, 1])
    plt.savefig(f"{STAGE_OUTPUT_PATH}/{os.path.basename(PARAM_DATA_IN_PATH)}_result_set_ious.pdf", bbox_inches="tight", pad_inches=figure_padding)
    