In [None]:
from datasets import load_dataset
import pandas as pd

import pickle

import sys
sys.path.append("../../")
sys.path.append("../../src")

from mds_pid import MDSPID

In [None]:
def get_stats(df, col):
    """Summarize a single column of ``df``.

    Args:
        df: DataFrame containing ``col``.
        col: Name of the column to summarize.

    Returns:
        Tuple ``(mean, median, std)`` of ``df[col]`` using pandas' default
        reductions (NaN values skipped; ``std`` is the sample standard
        deviation, ddof=1).
    """
    column = df[col]
    return column.mean(), column.median(), column.std()

def print_pid_stats(df):
    """Print mean / median / std of the PID components stored in ``df``.

    ``redundancy``, ``union`` and ``synergy`` are summarized directly.
    ``unique`` is first collapsed to one number per row: the average of that
    row's entries.

    Args:
        df: DataFrame with columns ``redundancy``, ``union``, ``synergy`` and
            ``unique``. Each ``unique`` entry must be a sized sequence of numbers
            (presumably one value per source -- TODO confirm against MDSPID).

    The input frame is not modified. The per-row average is added to a copy
    via ``assign``, so callers may pass a boolean-mask slice without
    triggering pandas' SettingWithCopyWarning.
    """
    mean_r, median_r, std_r = get_stats(df, "redundancy")
    mean_u, median_u, std_u = get_stats(df, "union")
    mean_s, median_s, std_s = get_stats(df, "synergy")

    # Average each row's unique values. An empty entry becomes NaN, which the
    # pandas reductions skip, instead of raising ZeroDivisionError.
    with_unique = df.assign(
        unique_total=df["unique"].apply(
            lambda x: sum(x) / len(x) if len(x) else float("nan")
        )
    )
    mean_unique, median_unique, std_unique = get_stats(with_unique, "unique_total")

    print(f"Redundancy -- Mean: {mean_r}, Median: {median_r}, Std_dev: {std_r}")
    print(f"Union -- Mean: {mean_u}, Median: {median_u}, Std_dev: {std_u}")
    print(f"Synergy -- Mean: {mean_s}, Median: {median_s}, Std_dev: {std_s}")
    print(f"Unique -- Mean: {mean_unique}, Median: {median_unique}, Std_dev: {std_unique}")

In [None]:
# Load the MultiRC-based MDS dataset and count the documents per example.
# Documents inside one "document" field are separated by "|||||".
dataset = load_dataset("mtc/multirc_train_all_answers_and_random")

df_dataset = pd.DataFrame(dataset["train"])
df_dataset["n_docs"] = df_dataset["document"].apply(lambda doc: doc.count("|||||") + 1)

# One frame per number of sources, each re-indexed from 0. A count of 5 is not
# selected (the PID-results cell also skips 5).
only_2_sources, only_3_sources, only_4_sources, only_6_sources = (
    df_dataset.loc[df_dataset["n_docs"] == n_sources].reset_index(drop=True)
    for n_sources in (2, 3, 4, 6)
)

# Stack in 2, 3, 4, 6 order. The row order must match the PID results, because
# the two frames are later joined column-wise by position.
sources_combined = pd.concat(
    [only_2_sources, only_3_sources, only_4_sources, only_6_sources],
    ignore_index=True,
)

In [None]:
# Load the per-source-count MDS-PID results and stack them into one frame.
# The counts and their order (2, 3, 4, 6) must match `sources_combined`,
# because the two frames are later joined column-wise by row position.
results_path = "multirc/"
pid_columns = ["total_positive_mi", "redundancy", "union", "synergy", "unique"]

pid_frames = []
for total_sources in (2, 3, 4, 6):
    file_name = f"multiRC_MDS_fixedMDS_dataset__sources_{total_sources}_PID.pkl"
    # NOTE(review): pickle.load can execute arbitrary code; only load result
    # files produced by this project.
    with open(f"{results_path}{file_name}", "rb") as f:
        mds_pid_results_df_multirc = pickle.load(f)
    pid_frames.append(pd.DataFrame(mds_pid_results_df_multirc.results)[pid_columns])

# Concatenate once instead of growing the frame inside the loop.
multirc = pd.concat(pid_frames, ignore_index=True)

In [None]:
# Join each dataset row with its PID result side by side. pd.concat(axis=1)
# aligns on the index, so a row-count mismatch would silently pad with NaN
# instead of failing. Check it explicitly.
assert len(sources_combined) == len(multirc), (
    f"row count mismatch: {len(sources_combined)} dataset rows vs {len(multirc)} PID rows"
)
df_concat = pd.concat([sources_combined, multirc], axis=1)

In [None]:
# Report PID statistics separately for each `isAnswer` value.
# NOTE(review): -1 presumably marks the randomly sampled answers
# (dataset "..._all_answers_and_random"); confirm.
for split in [0, 1, -1]:
    # .copy() gives print_pid_stats an independent frame. Any column it adds then
    # does not land on a view of df_concat (which would raise SettingWithCopyWarning).
    df_split = df_concat[df_concat["isAnswer"] == split].copy()
    print(f"split: {split}")
    print_pid_stats(df_split)