In [None]:
import os
import shutil
import json
import pandas as pd
import numpy as np

from itertools import cycle
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.style as style
import matplotlib

import utils_general
import utils_podcasts

# Tab-separated metadata table, read from the path configured in utils_general.
METADATA_DF = pd.read_csv(utils_general.PATH_TO_2020_TESTSET_DF, sep="\t")

# Transformation identifier -> human-readable label shown in the plot legend.
# Insertion order matters: it is the order in which the lines are drawn.
transformation_dict = {
    "repeats": "Repeats",
    "interjections": "Interjections",
    "false-starts": "False Starts",
    "interjections-and-false-starts": "Interjections + False Starts",
    "repeats-and-false-starts": "Repeats + False Starts",
    "repeats-and-interjections": "Repeats + Interjections",
    "all-3": "Repeats + Interjections + False Starts",
}

# Model name -> [lower, upper] pair used as manual y-axis limits when plotting.
model_minmax_dict = {
    "bart": [0.105, 0.140],
    "pegasus": [0.02, 0.14],
    "t5": [0.02, 0.14],
}

In [None]:
# Plot, for every inference model, the mean ROUGE score of each transformation
# as a function of the N parameter, and save one PNG per model.

REPEATS_VALS = list(range(11))  # N parameter values on the x-axis: 0..10
ROUGE_TYPE = "rougeL" # options are: "rouge1", "rouge2", "rougeL", "rougeLsum"
N_SKIPPED_COLORS = 4  # leading palette entries skipped before the first line is drawn

# Colors, line widths, and font sizes are identical for every figure, so they are
# configured once instead of on every loop iteration.
style.use("tableau-colorblind10")
matplotlib.rcParams['lines.linewidth'] = 2
matplotlib.rcParams.update({'font.size': 9})

# str.capitalize() would lower-case the trailing characters ("rougeL" -> "Rougel"),
# so only the first character is upper-cased for the axis label.
rouge_label = ROUGE_TYPE[:1].upper() + ROUGE_TYPE[1:]

for m in utils_general.INFERENCE_MODEL_LIST:
    print(m)

    # Restart the color cycle for each figure so every model's chart uses the
    # same color for the same transformation.
    colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
    for _ in range(N_SKIPPED_COLORS):
        next(colors)

    # write out the chart for that input and model
    current_figure_name = "ROUGE_SCORES__" + m + "__" + ROUGE_TYPE + ".png"
    fig, ax = plt.subplots(figsize=(8, 8))
    # tick_params also applies to ticks created after the data is plotted
    ax.tick_params(axis="both", labelsize="large")

    # The N=0 scores were run separately; load them and force N_parameter to 0
    # so they line up with the first x-axis position.
    zero_df = pd.read_csv(
        os.path.join(utils_general.PATH_TO_CSV, "0" + "__" + m + "__ROUGE_SCORES.csv")
    ).assign(N_parameter=0)

    # One line per transformation, drawn in the dict's insertion order.
    for t, legend_label in transformation_dict.items():

        # get the current csv file for the trf and add in the zero data points
        current_csv_df = pd.read_csv(
            os.path.join(utils_general.PATH_TO_CSV, t + "__" + m + "__ROUGE_SCORES.csv")
        )
        df = pd.concat([current_csv_df, zero_df])

        # Mean score per N value, rounded to 3 decimals. reindex() keeps the
        # x-axis order and yields NaN for any N value absent from the data.
        means = (
            df.groupby("N_parameter")[ROUGE_TYPE]
            .mean()
            .reindex(REPEATS_VALS)
            .round(3)
            .tolist()
        )

        ax.plot(REPEATS_VALS, means, label=legend_label, marker="o", color=next(colors))

    # set labels
    ax.set_ylabel(rouge_label + " Mean", fontsize="x-large")
    ax.set_xlabel("N Parameter", fontsize="x-large")
    ax.set_title(m + " " + ROUGE_TYPE + " Summarization Quality", fontsize="xx-large")

    # Apply the manually chosen y-axis limits when they exist for this model;
    # otherwise matplotlib autoscaling is kept.
    if m in model_minmax_dict:
        ax.set_ylim(model_minmax_dict[m])

    # set the legend
    ax.legend(loc="best", fontsize="medium")

    # save the figure
    fig.savefig(os.path.join(utils_general.PATH_TO_FIGS, current_figure_name), bbox_inches="tight")
