In [None]:
# Setting auto reload of custom functions
%load_ext autoreload
%autoreload 2

import pickle
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import sys

# Make the parent directory and ../src importable so the project modules below resolve.
sys.path.append("../")
sys.path.append("../src")

# NOTE(review): MDSData and MDSPID are not referenced directly in the visible cells;
# presumably they must be importable so pickle.load can rebuild the stored result
# objects -- confirm before removing.
from data import MDSData
from mds_pid import MDSPID

# Use the GPU when torch reports that CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def get_stats(df, col):
    """Summarise one column of a DataFrame.

    Args:
        df: DataFrame containing the column.
        col: Label of the column to summarise.

    Returns:
        A ``(mean, median, std)`` tuple computed with the pandas Series
        methods of the same names.
    """
    column = df[col]
    return (column.mean(), column.median(), column.std())

def calculate_ranking_of_highest_probability(lists_of_probabilities, num_positions=10):
    """Compute how often each position holds the largest value.

    For every entry of ``lists_of_probabilities`` the index of its maximum
    (``np.argmax``) is counted; the counts are then divided by the number of
    entries processed.

    Args:
        lists_of_probabilities: Iterable of sequences of numbers. A scalar
            entry is accepted by ``np.argmax`` and always counts as position 0.
        num_positions: Minimum length of the returned list (default 10, the
            previously hard-coded size). The result grows beyond this when an
            entry has its maximum at a higher index, instead of raising
            ``IndexError``.

    Returns:
        List of floats where element ``i`` is the fraction of entries whose
        maximum is at index ``i``. All zeros when no entries are given
        (previously a ``ZeroDivisionError``).
    """
    position_counts = [0] * num_positions
    total_lists = 0

    for probabilities_list in lists_of_probabilities:
        max_index = int(np.argmax(probabilities_list))
        # Grow the counter on demand so inputs longer than num_positions work.
        if max_index >= len(position_counts):
            position_counts.extend([0] * (max_index + 1 - len(position_counts)))
        position_counts[max_index] += 1
        # Count while iterating so inputs without len() are supported as well.
        total_lists += 1

    if total_lists == 0:
        return [0.0] * len(position_counts)

    return [count / total_lists for count in position_counts]

def get_relative_results_df(results, total_sources):
    """Express PID measures relative to each row's total positive MI.

    Rows whose ``total_positive_mi`` is not strictly positive are dropped
    *before* any division, so no divide-by-zero warnings or inf/nan
    intermediates are produced for rows that are discarded anyway.

    Args:
        results: Anything ``pd.DataFrame`` accepts that yields the columns
            ``total_positive_mi``, ``redundancy``, ``union``, ``synergy`` and
            ``unique``. Each ``unique`` entry is converted with ``np.array``
            (presumably one value per source -- confirm against the producer).
        total_sources: Value stored in the ``total_sources`` column of every
            returned row.

    Returns:
        DataFrame indexed like the retained input rows with the columns
        ``total_positive_mi``, ``redundancy``, ``union``, ``synergy`` (each
        divided by the row's ``total_positive_mi``), ``unique`` (mean of the
        row's relative unique values), ``unique_variance`` (``np.var`` of those
        values) and ``total_sources``.
    """
    scalar_cols = ["total_positive_mi", "redundancy", "union", "synergy"]
    df = pd.DataFrame(results)[scalar_cols + ["unique"]]

    # Filter first: every remaining divisor is strictly positive.
    df = df.loc[df["total_positive_mi"] > 0]

    df_out = pd.DataFrame(
        df[scalar_cols].to_numpy() / df[["total_positive_mi"]].to_numpy(),
        columns=scalar_cols,
        index=df.index,
    )

    # Per-row arrays of unique information relative to the row total.
    relative_unique = [
        np.array(unique) / np.array(total)
        for unique, total in zip(df["unique"], df["total_positive_mi"])
    ]

    df_out["unique"] = pd.Series(
        [sum(values) / len(values) for values in relative_unique],
        index=df.index,
        dtype=float,
    )
    df_out["unique_variance"] = pd.Series(
        [np.var(values) for values in relative_unique],
        index=df.index,
        dtype=float,
    )

    df_out["total_sources"] = total_sources

    return df_out

def print_pid_stats(mds_pid_results, file_name, total_sources):
    """Print dataset statistics and summary statistics of the relative PID measures.

    Prints ``file_name``, delegates to
    ``mds_pid_results.print_dataset_prepared_stats()``, and then prints the
    mean, median and standard deviation of each relative measure derived from
    ``mds_pid_results.results``.

    Args:
        mds_pid_results: Object exposing ``print_dataset_prepared_stats()`` and
            a ``results`` attribute accepted by ``get_relative_results_df``.
        file_name: Label printed as the heading of the report.
        total_sources: Forwarded to ``get_relative_results_df``.
    """
    print(file_name)

    mds_pid_results.print_dataset_prepared_stats()
    print("\n")

    relative_df = get_relative_results_df(mds_pid_results.results, total_sources)

    # (printed label, column name) pairs, in output order.
    measures = (
        ("Redundancy", "redundancy"),
        ("Union", "union"),
        ("Synergy", "synergy"),
        ("Unique", "unique"),
        ("Unique variance", "unique_variance"),
    )

    # Compute every summary before printing anything, as the report did originally.
    summaries = [(label, get_stats(relative_df, column)) for label, column in measures]

    for label, (mean, median, std_dev) in summaries:
        print(f"{label} -- Mean: {mean}, Median: {median}, Std_dev: {std_dev}")

    print("\n")

In [None]:
# Directory prefix (relative to the notebook) that the result pickle file names are appended to.
results_path = "../outputs/results/"

In [None]:
# Load the PID results for every MultiNews source count (2 through 10) and
# stack their relative-measure tables into one DataFrame.
# NOTE: pickle.load can execute arbitrary code; only open result files you trust.
per_source_frames = []

for total_sources in range(2, 11):
    file_name = f"multi_news_fixedMDS_dataset__sources_{total_sources}_sample_PID.pkl"
    with open(f"{results_path}{file_name}", "rb") as f:
        mds_pid_results = pickle.load(f)

    c_df = get_relative_results_df(mds_pid_results.results, total_sources)
    c_df["dataset"] = "MultiNews"
    per_source_frames.append(c_df)

# Concatenate once instead of growing the frame inside the loop; this also
# removes the string placeholder and the first-iteration special case.
multiNews = pd.concat(per_source_frames, ignore_index=True)

In [None]:
# Print the summary report for each MultiNews result file (2 through 10 sources).
# NOTE: pickle.load can execute arbitrary code; only open result files you trust.
for total_sources in range(2, 11):
    file_name = f"multi_news_fixedMDS_dataset__sources_{total_sources}_sample_PID.pkl"
    pickle_path = f"{results_path}{file_name}"

    with open(pickle_path, "rb") as f:
        mds_pid_results = pickle.load(f)

    # The file is already closed here; the report only needs the loaded object.
    print_pid_stats(mds_pid_results, file_name, total_sources)

In [None]:
# Pairwise relationships between the relative measures, coloured by the number of sources.
g = sns.pairplot(multiNews[["redundancy", "unique_variance", "total_sources", "unique"]], hue='total_sources')

In [None]:
# Distribution of relative redundancy for each number of sources.
g = sns.boxplot(x='total_sources', y='redundancy', data=multiNews)

In [None]:
# Build ranking[total_sources][dataset] = fraction of rows per arg-max position.
#
# NOTE(review): the "unique" column of multiNews is produced by
# get_relative_results_df as the per-row *mean* of the relative unique values
# (a scalar), not the per-source list. np.argmax of a scalar is always 0, so
# every ranking computed here puts all of its mass on position 0. Computing a
# real ranking needs the per-source values to be kept in the frame -- confirm
# the intent and adjust both places together.
source = "MultiNews"
ranking = {}

for total_sources in range(2, 11):
    all_unique_lists = multiNews.loc[multiNews["total_sources"] == total_sources, "unique"]

    per_dataset = ranking.setdefault(total_sources, {})
    per_dataset[source] = calculate_ranking_of_highest_probability(all_unique_lists)

ranking