In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
# Load the decoding-ablation correlations and the main task correlations.
df_ablation = pd.read_csv('../../output/corr_ablation_decoding.csv')
df = pd.read_csv('../../output/corr_mi_tasks_wsh_rouge_2.csv')

# Stack both result sets row-wise; original row labels are kept (no ignore_index),
# so the combined index may contain duplicates.
df = pd.concat([df, df_ablation])


In [None]:
# Keep only the rows that carry a decoding configuration.
df = df[df['metadata/Decoding config'].notna()]

display(df)
print(df.columns)

# A task is identified by the prefix (text before the first '/') of every
# column whose name contains "kl".
tasks_names = [name.partition('/')[0] for name in df.columns if "kl" in name]
print(tasks_names)

In [None]:

# Show which datasets are present in the combined results.
df['metadata/Dataset name'].unique()

In [None]:


# Expose the unnamed first CSV column under an explicit name.
# NOTE(review): presumably this column holds the source/run identifier written as the CSV index — confirm.
df['source'] = df["Unnamed: 0"]





In [None]:
# print decoding config



# Common metrics
        
        




In [None]:

def plot_correlation_matrix(df):
    """Draw an annotated heatmap of pairwise correlations between the metrics.

    The correlation matrix covers ROUGE-2, the SEAHORSE ``proba_1`` columns,
    both mutual-information columns and the classifiers' ``proba_of_error``
    columns. Rows whose decoding config contains "short" are excluded.
    """
    task_labels = {"mrm8488_distilroberta-finetuned-financial-news-sentiment-analysis" : "Sentiment analysis",
                   "wesleyacheng_news-topic-classification-with-bert" : "Topic classification",
                   "roberta-base-openai-detector" : "GPT detector",
                   }

    rouge_cols = ["common/rouge2"]
    mi_cols = ['I(summary -> text)', 'I(text -> summary)']
    seahorse_cols = [name for name in df.columns if "SHMetric" in name and "proba_1" in name]
    error_cols = [f"{task}/proba_of_error" for task in task_labels]

    # Restrict to the non-"short" decoding configs and the metric columns.
    keep_rows = ~df['metadata/Decoding config'].str.contains("short")
    selected = df.loc[keep_rows, rouge_cols + seahorse_cols + mi_cols + error_cols]
    correlations = selected.corr()

    # Diverging colormap centred on zero correlation.
    diverging = sns.diverging_palette(230, 20, as_cmap=True)

    sns.heatmap(
        correlations,
        annot=True,
        cmap=diverging,
        vmin=-1,
        vmax=1,
        center=0,
        square=True,
        linewidths=.5,
        cbar_kws={"shrink": .5},
    )



In [None]:
# Make a table with the correlation between MI and ROUGE and the SHmetrics, grouped by dataset


def make_correlation_table(df):
    """Build a per-dataset table of correlations against SEAHORSE and classification scores.

    Steps:
      1. Convert each classifier's ``proba_of_error`` column into
         ``proba_of_success`` (``1 - error``) on a copy of ``df``.
      2. Drop rows whose decoding config contains "short" or is one of
         ``beam_sampling_{5,10,20,50}``, and ignore the ``peer_read`` and
         ``arxiv`` datasets.
      3. For each remaining dataset, compute the correlation matrix between
         ROUGE-Lsum, the SEAHORSE ``proba_1`` columns, both MI columns and the
         classification success columns.
      4. Keep ROUGE-Lsum, ``I(summary -> text)`` and the SEAHORSE metrics as
         rows, and the SEAHORSE ("SH.") and classification ("CT.") scores as
         columns, with display-friendly labels.

    Args:
        df: results frame containing ``metadata/Decoding config``,
            ``metadata/Dataset name`` and the metric columns listed above.
            It is not modified.

    Returns:
        DataFrame indexed by ``(Dataset name, Metric)`` with two-level columns
        ``("SH." | "CT.", label)`` holding correlation coefficients.
    """
    df = df.copy()

    ROUGES = ["common/rougeLsum"]
    MI = ['I(summary -> text)', 'I(text -> summary)']
    SHM = [c for c in df.columns if "SHMetric" in c and "proba_1" in c]

    map_tasks = {"mrm8488_distilroberta-finetuned-financial-news-sentiment-analysis" : "Sentiment analysis",
                 "wesleyacheng_news-topic-classification-with-bert" : "Topic classification",
                 "roberta-base-openai-detector" : "GPT detector",
                 }

    error_columns = [c + "/proba_of_error" for c in map_tasks.keys()]

    # Turn the probability of error into a probability of success, then rename.
    df[error_columns] = 1 - df[error_columns]
    df = df.rename(columns={c + "/proba_of_error" : c + "/proba_of_success" for c in map_tasks.keys()})

    classification_tasks = [c + "/proba_of_success" for c in map_tasks.keys()]

    df = df[~df['metadata/Decoding config'].str.contains("short")]

    print(df['metadata/Decoding config'].unique())

    df = df[~df['metadata/Decoding config'].isin([f"beam_sampling_{k}" for k in [5, 10, 20, 50]])]

    display(df)
    datasets = set(df['metadata/Dataset name'].dropna().unique())
    datasets -= set(['peer_read', 'arxiv'])

    print(datasets)

    # One long-format correlation frame per dataset. They are concatenated once
    # at the end: DataFrame.append was removed in pandas 2.0, and growing a
    # frame inside a loop is quadratic anyway.
    metric_columns = ROUGES + SHM + MI + classification_tasks
    per_dataset = []
    for dataset in datasets:
        df_dataset = df[df['metadata/Dataset name'] == dataset][metric_columns].corr()
        df_dataset['Dataset name'] = dataset
        df_dataset['Metric'] = df_dataset.index
        per_dataset.append(
            df_dataset.melt(id_vars=['Dataset name', 'Metric'], var_name="Correlation", value_name="Value")
        )

    if per_dataset:
        df_corr = pd.concat(per_dataset, ignore_index=True)
    else:
        # No dataset survived the filters: keep an empty frame with the expected schema.
        df_corr = pd.DataFrame(columns=['Dataset name', 'Metric', 'Correlation', 'Value'])

    def rename_metrics(x):
        """Map a raw column name to the label used in the paper."""
        splits = x.split('/')

        if len(splits) == 1:
            if splits[0] == "I(summary -> text)":
                return "$I(T,S)$"
            return x
        if splits[0] in map_tasks.keys():
            return map_tasks[splits[0]]
        if splits[1] == "rougeLsum":
            return "\\texttt{ROUGE-L}"
        return splits[1]

    # Back to wide format: rows are (dataset, metric), columns are the correlated metric.
    df_corr = df_corr.pivot(index=['Dataset name', 'Metric'], columns='Correlation', values='Value')

    # Keep only SEAHORSE and classification scores as columns, under group headers.
    df_corr = pd.concat({'SH.' : df_corr[[c for c in df_corr.columns if "SHMetric" in c]], 'CT.' : df_corr[classification_tasks]}, axis=1)

    idx = pd.IndexSlice
    # Rows to display: ROUGE-L, I(summary -> text) and the SEAHORSE metrics.
    df_corr = df_corr.loc[idx[:, ['common/rougeLsum', 'I(summary -> text)'] + SHM], :]
    df_corr = df_corr.dropna()

    # Human-readable column labels.
    df_corr.columns = pd.MultiIndex.from_tuples([(c[0].replace('_', '-'), rename_metrics(c[1])) for c in df_corr.columns])

    df_corr = df_corr.reset_index()
    # Human-readable row labels.
    df_corr[('Metric', '')] = df_corr[('Metric', '')].apply(rename_metrics)

    df_corr = df_corr.set_index(["Dataset name", 'Metric'])
    df_corr = df_corr.sort_index()

    return df_corr
    
# Build the correlation table and transpose it so that (dataset, metric) pairs become columns.
table = make_correlation_table(df).transpose()


# Replace underscores in the dataset names of the column headers.
table.columns = pd.MultiIndex.from_tuples([(c[0].replace('_', '-'), c[1]) for c in table.columns])

style = table.style


# Two decimals for values; LaTeX-escape the row labels.
style = style.format(precision=2)
style = style.format_index(escape="latex", axis=0)

# highlight max for each dataset with bfseries
list_datasets = set(table.columns.get_level_values(0))
# NOTE(review): list_metrics is not used below in this cell.
list_metrics = set(table.columns.get_level_values(1))
idx = pd.IndexSlice
for dataset in list_datasets:
    # Row-wise maximum, restricted to the columns that belong to this dataset.
    style = style.highlight_max(axis=1, subset=(idx[:], idx[dataset, :]), props='bfseries:')
    
# add background gradient
style = style.background_gradient(cmap='viridis', vmin=0.2, vmax=1)
# convert to latex
path = "../../../papers/Mutual-information-for-summarization/tables/correlation_table.tex"
# create parent
# NOTE(review): the directory is created even though the write at the end of the cell is commented out.
Path(path).parent.mkdir(parents=True, exist_ok=True)

display(style)

# NOTE(review): the caption contains the typo "classifcation"; it is a runtime string and is left untouched here.
latex_code = style.to_latex(clines="skip-last;data", sparse_index=True, sparse_columns=True, caption="Correlation between MI and ROUGE, and Seahorse metrics and probability of success of the classifcation task, grouped by datasets for non-trivial decoding strategies. SH. stands for Seahorse metrics and CT. for classification tasks.", label="tab:correlation_table", environment="table*", hrules=True, convert_css=True)

# NOTE(review): import placed mid-cell; consider moving it to the notebook's import cell.
import re

# add a resize box around the tabular
latex_code = re.sub(r"\\begin{tabular}", r"\\resizebox{0.7\\textwidth}{!}{\\begin{tabular}", latex_code)
latex_code = re.sub(r"\\end{tabular}", r"\\end{tabular}}", latex_code)

# add centering to the table environment
latex_code = re.sub(r"\\begin{table\*}", r"\\begin{table*}[h!]\\centering", latex_code)


# save latex code
#with open(path, 'w') as f:
#     f.write(latex_code)

print(latex_code)



In [None]:
sns.set_theme(style="whitegrid")


# Column groups used by the plots in this cell.
ROUGES = ["common/rougeLsum"]
MI = ['I(summary -> text)', 'I(text -> summary)']
SHM = [c for c in df.columns if "SHMetric" in c and "proba_1" in c]


# Classifier identifier (column prefix) -> label shown on the figures.
map_tasks = {"mrm8488_distilroberta-finetuned-financial-news-sentiment-analysis" : "Sentiment analysis",
             "wesleyacheng_news-topic-classification-with-bert" : "Topic classification",
             "roberta-base-openai-detector" : "GPT detector",
             }

classification_tasks_error = [c + "/proba_of_error" for c in map_tasks.keys()]
classification_tasks = [c + "/proba_of_success" for c in map_tasks.keys()]

# make proba_of_error proba_of_success
# This adds the proba_of_success columns to the global `df` in place;
# the proba_of_error columns are kept.
df[classification_tasks] = 1 - df[classification_tasks_error]

def plot_multiple_datasets_correlations(df, COLS, name):
    """Plot each column of ``COLS`` against ``I(summary -> text)``, one row of panels per dataset.

    Rows whose decoding config contains "short" or is one of
    ``beam_sampling_{5,10,20,50}`` are dropped, and the ``peer_read`` and
    ``arxiv`` datasets are ignored. Every panel shows a regression line and a
    scatter coloured by ``metadata/Model name``. The figure is saved as
    ``.../img/multiple_datasets_correlations_{name}.png``.

    Panel titles for non-SEAHORSE columns are looked up in the module-level
    ``map_tasks`` mapping.

    Args:
        df: results frame with the ``metadata/*`` columns,
            ``I(summary -> text)`` and every column listed in ``COLS``.
        COLS: column names plotted on the y axis, one grid column each.
        name: suffix of the output file name.

    Raises:
        ValueError: if ``COLS`` is empty or no dataset is left after filtering.
    """
    sns.set_theme(style="whitegrid")
    df = df.copy()

    df = df[~df['metadata/Decoding config'].str.contains("short")]
    df = df[~df['metadata/Decoding config'].isin([f"beam_sampling_{k}" for k in [5, 10, 20, 50]])]

    # Sorted so that the row order of the figure is the same on every run
    # (iteration order of a set of strings changes between interpreter sessions).
    datasets = sorted(set(df['metadata/Dataset name'].dropna().unique()) - {'peer_read', 'arxiv'})

    if not datasets or len(COLS) == 0:
        raise ValueError("nothing to plot: COLS is empty or no dataset is left after filtering")

    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(len(datasets), len(COLS), figsize=(25, 10), sharey=False, sharex=False, dpi=300, squeeze=False)

    def rename_cols(x):
        """Return the panel title for column ``x``."""
        if "SHMetric" in x:
            return x.split('/')[1]
        return map_tasks[x.split('/')[0]]

    for idx, col in enumerate(COLS):
        for didx, ds in enumerate(datasets):
            ax = axes[didx, idx]
            group = df[df['metadata/Dataset name'] == ds]

            # ci=None disables the confidence band; ci=False is not a documented
            # value and makes seaborn bootstrap a zero-width band.
            sns.regplot(data=group, x="I(summary -> text)", y=col, ax=ax, x_ci=None, ci=None, scatter=False, line_kws={'alpha': 0.9, 'linewidth': 5})
            sns.scatterplot(data=group, x="I(summary -> text)", y=col, hue='metadata/Model name', ax=ax, palette='tab10', s=400, markers='o')

            # Column titles on the first row only, dataset names on the first column only.
            ax.set_xlabel("")
            if didx == 0:
                ax.set_title(rename_cols(col), fontsize=22, fontweight='bold')

            ax.set_ylabel("")
            if idx == 0:
                ax.set_ylabel(ds, fontsize=22, fontweight='bold')

            # make tick labels bigger
            ax.tick_params(axis='x', labelsize=18)
            ax.tick_params(axis='y', labelsize=18)

            # add grid
            ax.grid(True, which='both', axis='both', linestyle='--')

    fig.tight_layout()

    # global legend below the figure
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.1), ncol=3, fontsize=18)
    # remove the per-panel legends; a panel without legend returns None
    for ax in axes.flatten():
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    path = f"../../../papers/Mutual-information-for-summarization/img/multiple_datasets_correlations_{name}.png"
    # create parent
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    fig.savefig(path, dpi=300, bbox_inches='tight')
     
     

# One figure for the SEAHORSE metrics, one for the classification success probabilities.
plot_multiple_datasets_correlations(df, COLS=SHM, name="shmetrics")
plot_multiple_datasets_correlations(df, COLS=classification_tasks, name="classification_tasks")


In [None]:


def plot_classification_tasks_proba_kl(df, dataset):
    """Plot P(error) and KL of each classifier against ``I(summary -> text)`` for one dataset.

    Rows whose decoding config contains "short" or is one of
    ``beam_sampling_{5,10,20,50}`` are excluded. The figure has one row per
    classifier (left panel: P(error), right panel: KL) and is saved under
    ``.../img/classification_tasks/{dataset}_classification_tasks.png``.
    """
    config = df['metadata/Decoding config']
    excluded_configs = [f"beam_sampling_{k}" for k in [5, 10, 20, 50]]
    keep = (
        ~config.str.contains("short")
        & (df['metadata/Dataset name'] == dataset)
        & ~config.isin(excluded_configs)
    )
    df = df[keep].copy()

    task_titles = {"mrm8488_distilroberta-finetuned-financial-news-sentiment-analysis" : "Sentiment analysis",
                   "wesleyacheng_news-topic-classification-with-bert" : "Topic classification",
                   "roberta-base-openai-detector" : "GPT detector",
                   }

    def custom_reg_plot(data, x=None, y=None, hue=None, ax=None, **kwargs):
        """Regression line (no confidence band) plus a scatter coloured by ``hue``."""
        sns.regplot(x=x, y=y, data=data, scatter=False, ci=None, x_ci='sd', ax=ax)
        sns.scatterplot(x=x, y=y, hue=hue, data=data, s=100, alpha=1, ax=ax, **kwargs, palette="tab10")
        return ax

    mi_column = "I(summary -> text)"
    # Shorter column names give shorter legend titles.
    renamed = df.rename(columns={"metadata/Decoding size": "Decoding size", "metadata/Model name": "Model name", "metadata/Decoding config": "Decoding config"})

    fig, axes = plt.subplots(len(task_titles), 2, figsize=(10, 7), sharey=False, sharex=False, dpi=300)

    for row, (task, title) in enumerate(task_titles.items()):
        error_ax, kl_ax = axes[row, 0], axes[row, 1]
        error_column = f"{task}/proba_of_error"
        kl_column = f"{task}/kl"

        custom_reg_plot(data=renamed, x=mi_column, y=error_column, hue="Model name", style='Model name', ax=error_ax)
        custom_reg_plot(data=renamed, x=mi_column, y=kl_column, hue="Model name", style='Model name', ax=kl_ax)

        # Annotate each panel with the Pearson correlation.
        error_r = renamed[mi_column].corr(df[error_column])
        kl_r = renamed[mi_column].corr(df[kl_column])
        error_ax.annotate(f"r={error_r:.2f}", xy=(0.05, 0.2), xycoords='axes fraction', fontsize=12,
                          horizontalalignment='left', verticalalignment='top')
        kl_ax.annotate(f"r={kl_r:.2f}", xy=(0.05, 0.1), xycoords='axes fraction', fontsize=12)

        for panel, ylabel in ((error_ax, "P(error)"), (kl_ax, "KL")):
            panel.set_title(title, fontsize=20, fontweight='bold')
            panel.set_ylabel(ylabel, fontsize=16, fontweight='bold')

    # One shared legend below the figure replaces the per-panel legends.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.15), ncol=2, fontsize=16)
    for panel in axes.flatten():
        panel.get_legend().remove()

    fig.tight_layout()

    # Save the figure, creating the output folder when needed.
    path = f"../../../papers/Mutual-information-for-summarization/img/classification_tasks/{dataset}_classification_tasks.png"
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    fig.savefig(path, dpi=300, bbox_inches='tight')


# One classification figure per dataset.
plot_classification_tasks_proba_kl(df, dataset="cnn_dailymail")
plot_classification_tasks_proba_kl(df, dataset="rotten_tomatoes")
plot_classification_tasks_proba_kl(df, dataset="xsum")

In [None]:
def plot_model_average_mi_per_dataset(df):
    """Horizontal bar plot of ``I(summary -> text)`` per model, one hue per dataset.

    Rows whose decoding config contains "short" or is one of
    ``beam_sampling_{5,10,20,50}`` are excluded, as are the ``peer_read`` and
    ``arxiv`` datasets.
    """
    skipped_configs = [f"beam_sampling_{k}" for k in [5, 10, 20, 50]]

    not_short = ~df['metadata/Decoding config'].str.contains("short")
    df = df[not_short].copy()
    df = df[~df['metadata/Decoding config'].isin(skipped_configs)]

    kept_datasets = set(df['metadata/Dataset name'].dropna().unique())
    kept_datasets -= set(['peer_read', 'arxiv'])

    # Keep the selected datasets, highest mutual information first.
    ranked = df[df['metadata/Dataset name'].isin(kept_datasets)].sort_values(
        by="I(summary -> text)", ascending=False
    )

    ax = sns.barplot(
        data=ranked,
        y="metadata/Model name",
        x="I(summary -> text)",
        hue="metadata/Dataset name",
        orient="h",
    )

    # The bars are horizontal, so the value axis is x: clip it to [40, 100].
    ax.set_xlim([40, 100])
    
    
    
# Draw the per-model mutual-information comparison for the combined results.
plot_model_average_mi_per_dataset(df)
    
    

In [None]:



def plot_classification_correlations(df, COL, path ="test.png"):
    """Plot P(error) and KL of each classifier against the column ``COL``.

    Rows whose decoding config contains "short" are excluded. The figure has
    one row per classifier (left: P(error), right: KL) with a shared legend.
    The output folder for ``path`` (relative to the paper's ``img`` folder) is
    created, but the figure itself is not written: saving is disabled.
    """
    df = df[~df['metadata/Decoding config'].str.contains('short')]

    task_titles = {"mrm8488_distilroberta-finetuned-financial-news-sentiment-analysis" : "Sentiment analysis",
                   "wesleyacheng_news-topic-classification-with-bert" : "Topic classification",
                   "roberta-base-openai-detector" : "GPT detector",
                   }

    def custom_reg_plot(data, x=None, y=None, hue=None, ax=None, **kwargs):
        """Regression line with a 95% band plus a scatter coloured by ``hue``."""
        sns.regplot(x=x, y=y, data=data, scatter=False, ci=95, x_ci='sd', ax=ax)
        sns.scatterplot(x=x, y=y, hue=hue, data=data, s=100, alpha=1, ax=ax, **kwargs, palette="tab10")
        return ax

    # Shorter column names give shorter legend titles.
    renamed = df.rename(columns={"metadata/Model name": "Model name", "metadata/Decoding config": "Decoding config"})

    fig, axes = plt.subplots(len(task_titles), 2, figsize=(10, 7), sharey=False, sharex=True, dpi=300)

    for row, task in enumerate(task_titles):
        left, right = axes[row, 0], axes[row, 1]
        error_column = f"{task}/proba_of_error"
        kl_column = f"{task}/kl"

        custom_reg_plot(data=renamed, x=COL, y=error_column, hue="Model name", style='Model name', ax=left)
        custom_reg_plot(data=renamed, x=COL, y=kl_column, hue="Model name", style='Model name', ax=right)

        # Pearson correlation shown in the lower-left corner of each panel.
        error_r = renamed[COL].corr(df[error_column])
        kl_r = renamed[COL].corr(df[kl_column])
        left.annotate(f"r={error_r:.2f}", xy=(0.05, 0.2), xycoords='axes fraction', fontsize=12,
                      horizontalalignment='left', verticalalignment='top')
        right.annotate(f"r={kl_r:.2f}", xy=(0.05, 0.1), xycoords='axes fraction', fontsize=12)

        left.set_title(task_titles[task], fontsize=20)
        right.set_title(task_titles[task], fontsize=20)

        left.set_ylabel("P(error)", fontsize=16)
        right.set_ylabel("KL", fontsize=16)

    # A single legend below the figure replaces the per-panel legends.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.3), ncol=2, fontsize=14)
    for panel in axes.flatten():
        panel.get_legend().remove()

    fig.tight_layout()

    # Resolve the target inside the paper's image folder and make sure its folder exists.
    path = Path("../../../papers/Mutual-information-for-summarization/img/") / path
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    # fig.savefig(path, dpi=300, bbox_inches='tight')


# Classification correlations against ROUGE-Lsum and against I(summary -> text).
# NOTE(review): the rows are filtered on "peer_read" while the file name says
# "rotten_tomaties.png" (sic), and both calls share that file name — confirm the
# intended dataset and names before re-enabling the save inside the function.
plot_classification_correlations(df[df['metadata/Dataset name'] == "peer_read"], COL="common/rougeLsum", path="rotten_tomaties.png")
plot_classification_correlations(df[df['metadata/Dataset name'] == "peer_read"], COL="I(summary -> text)", path="rotten_tomaties.png")



In [None]:


def plot_seahorse_metrics(df, COL):
    """Plot every SEAHORSE ``proba_1`` score against the column ``COL``.

    Rows whose decoding config contains "short" are excluded. One panel is
    drawn per SEAHORSE task (second path component of every column whose name
    contains "SH"), laid out on two rows, with a regression line, a scatter
    coloured by model and the Pearson correlation as annotation.

    Args:
        df: results frame with ``metadata/Decoding config``,
            ``metadata/Model name``, ``COL`` and the
            ``SHMetric/<task>/proba_1`` columns.
        COL: name of the column used on the x axis.
    """
    df = df[~df['metadata/Decoding config'].str.contains('short')]

    def custom_reg_plot(data, x=None, y=None, hue=None, ax=None, **kwargs):
        """Regression line with a 95% band plus a scatter coloured by ``hue``."""
        sns.regplot(data=data, x=x, y=y, ci=95, scatter=False, ax=ax, x_ci='sd')
        sns.scatterplot(data=data, x=x, y=y, hue=hue, alpha=1, s=100, ax=ax, **kwargs, palette='mako')
        return ax

    # Sorted so that the panel order is the same on every run
    # (iteration order of a set of strings changes between interpreter sessions).
    tasks_sh = sorted({c.split('/')[1] for c in df.columns if "SH" in c})

    # Round the column count up so an odd number of tasks still gets a panel
    # each; squeeze=False keeps the axes array two-dimensional before flattening.
    n_cols = max(1, (len(tasks_sh) + 1) // 2)
    fig, axes = plt.subplots(2, n_cols, figsize=(15, 7), sharey=True, sharex=False, dpi=300, squeeze=False)
    axes = axes.flatten()

    toplot = df.rename(columns={ "metadata/Model name": "Model name", "metadata/Decoding config": "Decoding config"})

    sns.set_theme(style="whitegrid")
    for tidx, task in enumerate(tasks_sh):
        ax = axes[tidx]
        score_column = f"SHMetric/{task}/proba_1"

        custom_reg_plot(data=toplot, x=COL, y=score_column, hue="Model name", style='Model name', ax=ax)
        # annotate with r value
        ax.annotate(f"r={toplot[COL].corr(toplot[score_column]):.2f}", xy=(0.05, 0.95), xycoords='axes fraction', fontsize=16,
                    horizontalalignment='left', verticalalignment='top')

        # change y title to be more readable
        ax.set_ylabel("P(yes)", fontsize=16)
        ax.set_xlabel(COL, fontsize=16)

        # add title
        ax.set_title(f"{task}", fontsize=20)

    # Hide the panels that have no task (odd task count, or no task at all).
    for ax in axes[len(tasks_sh):]:
        ax.set_visible(False)

    # add global legend
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=3, fontsize=16)
    # remove the per-panel legends; a panel without legend returns None
    for ax in axes:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    fig.tight_layout()
    # Output location; the folder is created but saving is currently disabled.
    # NOTE(review): the file name is hard-coded to cnn_dailymail whatever dataset is passed in.
    path = "../../../papers/Mutual-information-for-summarization/img/ablation_decoding_size/cnn_dailymail_sh_metrics.png"
    # create parent
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    # plt.savefig(path, dpi=300, bbox_inches='tight')
    
# SEAHORSE scores on rotten_tomatoes, against ROUGE-Lsum and against I(summary -> text).
plot_seahorse_metrics(df[df['metadata/Dataset name'] == "rotten_tomatoes"], COL="common/rougeLsum")
plot_seahorse_metrics(df[df['metadata/Dataset name'] == "rotten_tomatoes"], COL="I(summary -> text)")



# Other stuff


In [None]:

# List of datasets
# Per-dataset comparison of the models' I(summary -> text) at decoding size 50.
datasets = df['metadata/Dataset name'].dropna().unique()

for dataset in datasets:

    sns.set_theme(style="whitegrid")

    # Apply BOTH filters. The second filter used to restart from `df`, which
    # silently discarded the dataset selection and plotted the same rows for
    # every dataset.
    df_top_p = df[(df['metadata/Dataset name'] == dataset) & (df['metadata/Decoding size'] == 50)]

    # Nothing to draw for this dataset (and min()/max() below would be NaN).
    if df_top_p.empty:
        continue

    # sort by I(summary -> text)
    df_top_p = df_top_p.sort_values(by="I(summary -> text)", ascending=False)

    # One figure per dataset, so bars of different datasets are not drawn on top of each other.
    fig, ax = plt.subplots()
    sns.barplot(data=df_top_p, y="metadata/Model name", x="I(summary -> text)", orient="h", ax=ax)

    # change y title to be more readable
    ax.set(ylabel="Model name", title=dataset)

    # x range: a little below the min and a little above the max
    ax.set_xlim([df_top_p['I(summary -> text)'].min() - 5, df_top_p['I(summary -> text)'].max() + 5])

    # save figure
    #plt.savefig(f"../../../papers/Mutual-information-for-summarization/img/model_comparison/{dataset}_top_p.png", dpi=300, bbox_inches='tight')
    # plt.clf()
    


In [None]:




# Long-format view of the results: the metadata columns identify a run,
# every other column becomes a (Score, Value) pair.
ddf = df.drop(columns=['Unnamed: 0'])
ddf = ddf.melt(
    id_vars=[name for name in ddf.columns if "metadata" in name],
    var_name="Score",
    value_name="Value",
)








In [None]:

# Global matplotlib font sizes for the figures drawn after this cell.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# palette = sns.color_palette("viridis", len(df['metadata/Decoding size'].unique()))



def custom_reg_plot(data, x=None, y=None, hue=None, ax=None, **kwargs):
    """Draw a regression line (95% band) and a ``hue``-coloured scatter on ``ax``; return ``ax``."""
    shared = dict(data=data, x=x, y=y, ax=ax)
    sns.regplot(scatter=False, ci=95, x_ci='sd', **shared)
    sns.scatterplot(hue=hue, s=100, alpha=1, **shared, **kwargs, palette="tab10")
    return ax


# Difference between the two directions of the mutual-information estimate.
df["I(summary -> text) - I(text -> summary)"] =  df["I(summary -> text)"] - df["I(text -> summary)"]

# COL = 'I(summary -> text) - I(text -> summary)'
COL = "I(summary -> text)"

# Second and last path components of every column whose name contains "SH".
tasks_sh = set([c.split('/')[1] for c in df.columns if "SH" in c])
metrics = set([c.split('/')[-1] for c in df.columns if "SH" in c])





datasets = df['metadata/Dataset name'].dropna().unique()
print(datasets)
for ds in datasets:
    # Shorter column names, then keep the rows of the current dataset.
    toplot = df.rename(columns={"metadata/Model name": "Model name", "metadata/Decoding config": "Decoding config", "metadata/Dataset name": "Dataset name"})
    toplot = toplot[toplot['Dataset name'] == ds]
    
    # NOTE(review): only the table is displayed; the plotting code below is commented out.
    display(toplot)

    # fig, axes = plt.subplots(2, len(tasks_sh) // 2, figsize=(15, 7), sharey=True, sharex=False, dpi=300)
    # axes = axes.flatten()
    # for tidx, task in enumerate(tasks_sh):
    #     sns.set_theme(style="whitegrid")
    # 
    #     custom_reg_plot(data=toplot, x=COL, y=f"SHMetric/{task}/proba_1", hue="Model name", style='Model name', ax=axes[tidx])
    #     # annotate with r value
    #     axes[tidx].annotate(f"r={toplot[COL].corr(df[f'SHMetric/{task}/proba_1']):.2f}", xy=(0.05, 0.95), xycoords='axes fraction', fontsize=16,
    #                         horizontalalignment='left', verticalalignment='top')
    # 
    #     # change y title to be more readable
    #     axes[tidx].set_ylabel("P(yes)", fontsize=16)
    #     axes[tidx].set_xlabel(COL, fontsize=16)
    # 
    #     # add title
    #     axes[tidx].set_title(f"{task}", fontsize=20)
    # 
    # 
    # 
    # 
    # # add global legend
    # handles, labels = axes[0].get_legend_handles_labels()
    # fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=3, fontsize=16)
    # # remove all legends
    # for ax in axes:
    #     if ax.get_legend():
    #         ax.get_legend().remove()
    # 
    # 
    # fig.tight_layout()
    # # save figure
    # path = f"../../../papers/Mutual-information-for-summarization/img/tasks_perfs/{ds}_sh_metrics.png"
    # # create parent
    # Path(path).parent.mkdir(parents=True, exist_ok=True)
    # 
    # 
    # plt.savefig(path, dpi=300, bbox_inches='tight')


# Ablation decoding size

## SH metrics

In [None]:


def custom_reg_plot(data, x=None, y=None, hue=None, **kwargs):
    """Draw a regression line (95% band) and a ``hue``-coloured scatter on the current axes.

    Returns the axes that were drawn on.
    """
    current_axes = plt.gca()
    sns.regplot(x=x, y=y, data=data, scatter=False, ci=95, x_ci='sd', ax=current_axes)
    sns.scatterplot(x=x, y=y, hue=hue, data=data, s=100, alpha=1, ax=current_axes, **kwargs, palette='mako')
    return current_axes

# From here on `df` holds the decoding-size ablation results; the frame
# loaded at the top of the notebook is replaced.
df = pd.read_csv('../../output/corr_ablation_shmetrics.csv')


# The decoding size is the last '_'-separated token of the decoding config name.
df['metadata/Decoding size'] = df['metadata/Decoding config'].apply(lambda x: x.split('_')[-1])
# change type to int
df['metadata/Decoding size'] = df['metadata/Decoding size'].astype(int)


In [None]:
# Inspect the ablation frame.
df

In [None]:

# Global matplotlib font sizes for the ablation figures.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


def custom_reg_plot(data, x=None, y=None, hue=None, ax=None, **kwargs):
    """Regression line with a 95% band plus a ``hue``-coloured scatter ('mako' palette) on ``ax``."""
    sns.regplot(
        x=x, y=y, data=data,
        scatter=False, ci=95, x_ci='sd',
        ax=ax,
    )
    sns.scatterplot(
        x=x, y=y, hue=hue, data=data,
        s=100, alpha=1,
        ax=ax,
        **kwargs,
        palette='mako',
    )
    return ax



# Difference between the two directions of the mutual-information estimate.
df["I(summary -> text) - I(text -> summary)"] =  df["I(summary -> text)"] - df["I(text -> summary)"]

# COL = 'I(summary -> text) - I(text -> summary)'
COL = "I(summary -> text)"

# NOTE(review): sets have no stable iteration order for strings, so the panel order can change between runs.
tasks_sh = set([c.split('/')[1] for c in df.columns if "SH" in c])
metrics = set([c.split('/')[-1] for c in df.columns if "SH" in c])


# Two rows of panels; assumes an even number of SEAHORSE tasks — TODO confirm.
fig, axes = plt.subplots(2, len(tasks_sh) // 2, figsize=(15, 7), sharey=True, sharex=False, dpi=300)
axes = axes.flatten()

toplot = df.rename(columns={"metadata/Decoding size": "Decoding size", "metadata/Model name": "Model name", "metadata/Decoding config": "Decoding config"})

for tidx, task in enumerate(tasks_sh):
        sns.set_theme(style="whitegrid")
        
        # Colour encodes the decoding size, marker style encodes the model.
        custom_reg_plot(data=toplot, x=COL, y=f"SHMetric/{task}/proba_1", hue="Decoding size", style='Model name', ax=axes[tidx])
        # annotate with r value
        axes[tidx].annotate(f"r={toplot[COL].corr(df[f'SHMetric/{task}/proba_1']):.2f}", xy=(0.05, 0.95), xycoords='axes fraction', fontsize=16,
                horizontalalignment='left', verticalalignment='top')
        
        # change y title to be more readable
        axes[tidx].set_ylabel("P(yes)", fontsize=16)
        axes[tidx].set_xlabel(COL, fontsize=16)
        
        # add title
        axes[tidx].set_title(f"{task}", fontsize=20)
        

        
        
# add global legend
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=3, fontsize=16)
# remove all legends
for ax in axes:
    ax.get_legend().remove()

fig.tight_layout()
# save figure
path = f"../../../papers/Mutual-information-for-summarization/img/ablation_decoding_size/cnn_dailymail_sh_metrics.png"
# create parent
Path(path).parent.mkdir(parents=True, exist_ok=True)


plt.savefig(path, dpi=300, bbox_inches='tight')
# plt.clf()
#plt.close('all')


## Classifications performance

In [None]:
# Keep only the rows that carry a decoding configuration.
df = df[~df['metadata/Decoding config'].isna()]

# df = df[~df['metadata/Decoding config'].str.contains('short')]
display(df)
print(df.columns)


# tasks names
# Prefix (text before the first '/') of every column whose name contains "kl".
tasks_names = [c.split('/')[0] for c in df.columns if "kl" in c]
print(tasks_names)

def custom_reg_plot(data, x=None, y=None, hue=None, **kwargs):
    """Plot helper for ``FacetGrid.map_dataframe``: regression line plus scatter on the active axes.

    Returns the axes that were drawn on.
    """
    target = plt.gca()
    common = dict(data=data, x=x, y=y, ax=target)
    sns.regplot(scatter=False, ci=95, x_ci='sd', **common)
    sns.scatterplot(hue=hue, s=100, alpha=1, **common, **kwargs, palette='mako')
    return target


# Difference between the two directions of the mutual-information estimate.
df["I(summary -> text) - I(text -> summary)"] =  df["I(summary -> text)"] - df["I(text -> summary)"]

COL = 'I(summary -> text) - I(text -> summary)'
for task in tasks_names:
    # 'proba_of_error' used to be listed twice, so its figure was rendered and
    # written a second time for nothing; each metric is now plotted once.
    for metric in ['proba_of_error', 'l2', 'l1', 'kl', 'dot']:
        print(task)

        sns.set_theme(style="whitegrid")
        # One facet per dataset; colour encodes the decoding size, marker style the model.
        g = sns.FacetGrid(df, col="metadata/Dataset name", sharey=False, sharex=False, col_wrap=2, height=4, aspect=1.5)
        g.map_dataframe(custom_reg_plot, x=COL, y=f"{task}/{metric}", hue="metadata/Decoding size", style='metadata/Model name')

        # change y title to be more readable
        g.set(ylabel=metric)

        # put legend outside, under the plot in the center
        g.add_legend(loc='lower center', bbox_to_anchor=(0.20, -0.30), ncol=3)

        # save figure
        path = f"../../../papers/Mutual-information-for-summarization/img/ablation_decoding_size/{task}_{metric}.png"
        # create parent
        Path(path).parent.mkdir(parents=True, exist_ok=True)

        plt.savefig(path, dpi=300, bbox_inches='tight')
        # Free the figure: many figures are created in this loop.
        plt.clf()
        plt.close('all')


In [None]:

# Global matplotlib font sizes for the ablation figures.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# One viridis colour per distinct decoding size.
palette = sns.color_palette("viridis", len(df['metadata/Decoding size'].unique()))



def custom_reg_plot(data, x=None, y=None, hue=None, ax=None, **kwargs):
    """Regression line (95% band) plus a ``hue``-coloured scatter using the module-level ``palette``."""
    coordinates = {"data": data, "x": x, "y": y, "ax": ax}
    sns.regplot(**coordinates, scatter=False, ci=95, x_ci='sd')
    sns.scatterplot(**coordinates, hue=hue, s=100, alpha=1, **kwargs, palette=palette)
    return ax




# Difference between the two directions of the mutual-information estimate.
df["I(summary -> text) - I(text -> summary)"] =  df["I(summary -> text)"] - df["I(text -> summary)"]

# COL = 'I(summary -> text) - I(text -> summary)'
COL = "I(summary -> text)"

# NOTE(review): sets have no stable iteration order for strings, so the panel order can change between runs.
tasks_sh = set([c.split('/')[1] for c in df.columns if "SH" in c])
metrics = set([c.split('/')[-1] for c in df.columns if "SH" in c])


# Two rows of panels; assumes an even number of SEAHORSE tasks — TODO confirm.
fig, axes = plt.subplots(2, len(tasks_sh) // 2, figsize=(15, 7), sharey=True, sharex=False, dpi=300)
axes = axes.flatten()

toplot = df.rename(columns={"metadata/Decoding size": "Decoding size", "metadata/Model name": "Model name", "metadata/Decoding config": "Decoding config"})

for tidx, task in enumerate(tasks_sh):
    sns.set_theme(style="whitegrid")

    # Colour encodes the decoding size, marker style encodes the model.
    custom_reg_plot(data=toplot, x=COL, y=f"SHMetric/{task}/proba_1", hue="Decoding size", style='Model name', ax=axes[tidx])
    # annotate with r value
    axes[tidx].annotate(f"r={toplot[COL].corr(df[f'SHMetric/{task}/proba_1']):.2f}", xy=(0.05, 0.95), xycoords='axes fraction', fontsize=16,
                        horizontalalignment='left', verticalalignment='top')

    # change y title to be more readable
    axes[tidx].set_ylabel("P(yes)", fontsize=16)
    axes[tidx].set_xlabel(COL, fontsize=16)

    # add title
    axes[tidx].set_title(f"{task}", fontsize=20)




# add global legend
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=3, fontsize=16)
# remove all legends
for ax in axes:
    ax.get_legend().remove()

fig.tight_layout()
# save figure
# NOTE(review): same output path as the earlier SEAHORSE ablation cell, so this overwrites that file.
path = f"../../../papers/Mutual-information-for-summarization/img/ablation_decoding_size/cnn_dailymail_sh_metrics.png"
# create parent
Path(path).parent.mkdir(parents=True, exist_ok=True)


plt.savefig(path, dpi=300, bbox_inches='tight')
# plt.clf()
#plt.close('all')


In [None]:
# Classifier identifiers found in the ablation frame.
tasks_names

In [None]:

# NOTE(review): this bare expression is not the last statement of the cell, so it displays nothing.
tasks_names

# Classifier identifier (column prefix) -> label shown on the figure.
map_tasks = {"mrm8488_distilroberta-finetuned-financial-news-sentiment-analysis" : "Sentiment analysis", 
             "wesleyacheng_news-topic-classification-with-bert" : "Topic classification",
             "roberta-base-openai-detector" : "GPT detector",
             }
# select only the tasks we want

# create a discrete sequential color palette with viridis
# (one colour per distinct decoding size)
palette = sns.color_palette("viridis", len(df['metadata/Decoding size'].unique()))



def custom_reg_plot(data, x=None, y=None, hue=None, ax=None, **kwargs):
    """Overlay a regression line (95% band) and a scatter coloured by ``hue`` with the module-level ``palette``."""
    regression_options = dict(scatter=False, ci=95, x_ci='sd')
    sns.regplot(x=x, y=y, data=data, ax=ax, **regression_options)
    sns.scatterplot(x=x, y=y, data=data, ax=ax, hue=hue, s=100, alpha=1, **kwargs, palette=palette)
    return ax


# One row per classifier: P(error) on the left, KL on the right.
fig, axes = plt.subplots(len(map_tasks), 2, figsize=(10, 7), sharey=False, sharex=True, dpi=300)

for tidx, task in enumerate(map_tasks.keys()):
    topplot = df
    # rename columns
    topplot = topplot.rename(columns={"metadata/Decoding size": "Decoding size", "metadata/Model name": "Model name", "metadata/Decoding config": "Decoding config"})
    
    # Colour encodes the decoding size, marker style encodes the model.
    custom_reg_plot(data=topplot, x="I(summary -> text)", y=f"{task}/proba_of_error", hue="Decoding size", style='Model name', ax=axes[tidx, 0])
    custom_reg_plot(data=topplot, x="I(summary -> text)", y=f"{task}/kl", hue="Decoding size", style='Model name', ax=axes[tidx, 1])
    
    # annotate with r value
    axes[tidx, 0].annotate(f"r={topplot['I(summary -> text)'].corr(df[f'{task}/proba_of_error']):.2f}", xy=(0.05, 0.2), xycoords='axes fraction', fontsize=12,
                        horizontalalignment='left', verticalalignment='top')
    axes[tidx, 1].annotate(f"r={topplot['I(summary -> text)'].corr(df[f'{task}/kl']):.2f}", xy=(0.05, 0.1), xycoords='axes fraction', fontsize=12,)
    

    
    # add title
    axes[tidx, 0].set_title(map_tasks[task], fontsize=20)
    axes[tidx, 1].set_title(map_tasks[task], fontsize=20)
    
    # add y label
    axes[tidx, 0].set_ylabel("P(error)", fontsize=16)
    axes[tidx, 1].set_ylabel("KL", fontsize=16)
    


# add global legend
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.3), ncol=2, fontsize=14)
# remove all legends
for ax in axes.flatten():
    ax.get_legend().remove()


fig.tight_layout()

# save figure
path = f"../../../papers/Mutual-information-for-summarization/img/ablation_decoding_size/classification.png"
# create parent
Path(path).parent.mkdir(parents=True, exist_ok=True)

fig.savefig(path, dpi=300, bbox_inches='tight')


