In [None]:
import pickle
import numpy as np
from copy import deepcopy
from pprint import pprint
import matplotlib.pyplot as plt


# Environment classes: full name (used as result key and as model-name
# prefix in this file) mapped to the short label shown in the plotted tables.
# Both public lists are derived from this one mapping so they always stay
# aligned index-by-index.
_ENV_LABELS = {
    "adverbs": "ADV",
    "conditional": "COND",
    "determiner_negation_biclausal": "D-NEG",
    "sentential_negation_biclausal": "S-NEG",
    "only": "ONLY",
    "quantifier": "QNT",
    "questions": "QUES",
    "simplequestions": "SMP-Q",
    "superlative": "SUP",
}

# Short labels, in table-column order.
ENV_NAMES = list(_ENV_LABELS.values())

# Full environment names, in the same order as ENV_NAMES.
FULL_ENV_NAMES = list(_ENV_LABELS)

def mn2mt(mn):
    """Map a model name to its model type.

    The name is checked for marker substrings in priority order:
    ``"gulorgood"`` gives ``"full"``, then ``"filter"`` gives ``"no_npi"``.
    A name containing neither marker is classified as ``"no_env"``.
    """
    marker_to_type = (
        ("gulorgood", "full"),
        ("filter", "no_npi"),
    )
    for marker, model_type in marker_to_type:
        if marker in mn:
            return model_type
    return "no_env"


def _load_pickle(path):
    """Return the object unpickled from the binary file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


raw_results = _load_pickle('results/all_results.pickle')
probing_results = _load_pickle('results/probing_results_no_dup.pickle')
median_rank_results = _load_pickle('results/median_rank_results.pickle')


# Fold the separately stored probing and median-rank results into the
# per-model entries of raw_results. An entry is only overwritten when the
# auxiliary file actually contains that key for the model; a model missing
# from either auxiliary file raises KeyError.
for mn, merged in raw_results.items():
    for key, extra_results in (
        ("probe", probing_results),
        ("median_rank", median_rank_results),
    ):
        if key in extra_results[mn]:
            merged[key] = extra_results[mn][key]

    
# Per-task collection of scores, grouped by model type.
#   * "full" / "no_npi" hold one array per model.
#   * "no_env" holds one list per environment class (same order as
#     FULL_ENV_NAMES), each collecting one scalar per model.
clean_results = {
    "probe": {"full": [], "no_npi": []},
    "warstadt": {"full": [], "no_env": [[] for _ in FULL_ENV_NAMES]},
    "median_rank": {"full": [], "no_env": [[] for _ in FULL_ENV_NAMES]},
}


for model_name, result in raw_results.items():
    model_type = mn2mt(model_name)

    if model_type not in ("full", "no_npi"):
        # Remaining models are identified by the environment class their name
        # starts with; only the scores for that one environment are kept.
        env_idx = [
            idx
            for idx, name in enumerate(FULL_ENV_NAMES)
            if model_name.startswith(name)
        ][0]
        env = FULL_ENV_NAMES[env_idx]

        median_rank_score = result["median_rank"][env][0]
        warstadt_score = result["warstadt"]["warstadt"][env]
        clean_results["median_rank"][model_type][env_idx].append(median_rank_score)
        clean_results["warstadt"][model_type][env_idx].append(warstadt_score)
        continue

    # Probing accuracy: the "uniform" entry first, then one per environment.
    # NOTE(review): the (1, "hx") key presumably selects a layer/activation
    # pair of the probed model -- confirm against the probing code.
    probe_envs = ["uniform"] + FULL_ENV_NAMES
    probe_scores = np.array(
        [result["probe"][env][1, "hx"]["accuracy"] for env in probe_envs]
    )
    clean_results["probe"][model_type].append(probe_scores)

    if model_type != "full":
        continue

    # Only "full" models additionally contribute a complete row (one value
    # per environment) to the median-rank and Warstadt tables; the
    # median-rank row is led by the "all" entry.
    median_rank_row = np.array(
        [result["median_rank"][env][0] for env in ["all"] + FULL_ENV_NAMES]
    )
    warstadt_row = np.array(
        [result["warstadt"]["warstadt"][env] for env in FULL_ENV_NAMES]
    )
    clean_results["median_rank"][model_type].append(median_rank_row)
    clean_results["warstadt"][model_type].append(warstadt_row)

           

# Collapse the collected runs into a (central value, standard deviation) pair
# per task and model type. Median ranks are summarised by their median, every
# other task by its mean.
for task in clean_results:
    summarise = np.mean
    if task == "median_rank":
        summarise = np.median

    per_model_type = clean_results[task]
    for model_type in list(per_model_type):
        runs = per_model_type[model_type]
        is_no_env = model_type == "no_env"

        if is_no_env and task == "median_rank":
            # Prepend a constant placeholder row so this entry gets one more
            # column, matching the extra leading "all" value of the "full"
            # median-rank rows.
            # NOTE(review): the width 3 presumably equals the number of runs
            # per environment -- confirm if that number changes.
            runs.insert(0, [10, 10, 10])

        # "no_env" is stored as (environment, run), the other model types as
        # (run, environment); reduce over the run dimension in both cases.
        axis = 1 if is_no_env else 0
        per_model_type[model_type] = (
            summarise(runs, axis=axis),
            np.std(runs, axis=axis),
        )

In [None]:
import matplotlib.lines as lines

# Default size of every figure created below (matplotlib's "figure.figsize"
# is given as width, height in inches).
plt.rcParams["figure.figsize"] = 10, 7

# One record per task: (task key, row labels, figure title, lower caption).
# The three lookup dicts used by the plotting code are derived from this
# single table so a task's texts are defined in one place.
_TASK_TEXTS = (
    (
        "probe",
        ["Full (exp. 1)", r"Full $\backslash$ NPI (exp. 4)"],
        "Monotonicity probing accuracy",
        "DC evaluated on held-out environment class",
    ),
    (
        "median_rank",
        ["Full (exp. 3)", r"Full $\backslash$ ENV$\cap$NPI (exp. 5b)"],
        r"LM decoder & monotonicity DC cosine similarity $-$ Median NPI rank",
        "Environment class monotonicity DC trained on",
    ),
    (
        "warstadt",
        ["Full (exp. 2)", r"Full $\backslash$ ENV$\cap$NPI (exp. 5a)"],
        r"NPI acceptability accuracy",
        "Environment class",
    ),
)

# Task -> labels of the table rows (one per model type).
task2model_types = {
    task: list(row_labels) for task, row_labels, _title, _caption in _TASK_TEXTS
}

# Task -> figure title.
task2title = {task: title for task, _row_labels, title, _caption in _TASK_TEXTS}

# Task -> caption printed below the table.
task2caption = {
    task: caption for task, _row_labels, _title, caption in _TASK_TEXTS
}

def plot_scores(results):
    """Draw, show and save one colour-coded score table per task.

    Args:
        results: Mapping ``task -> {model_type: (centers, stds)}`` where
            ``centers`` and ``stds`` are equally long 1-D sequences with one
            value per table column. Every task must be a key of
            ``task2model_types``, ``task2title`` and ``task2caption``.

    Side effects:
        Prints each task name, displays each figure with ``plt.show()`` and
        writes it to ``figures/<task>.pdf``.
    """
    for task, result in results.items():
        print(task)

        # One row per model type, one column per environment class.
        means = np.array([arr[0] for arr in result.values()])
        stds = np.array([arr[1] for arr in result.values()])

        sizes = np.full(means.shape[1], 0.2)

        # A single figure per task; no separate plt.figure() call, which
        # would only leave an empty, orphaned figure behind.
        fig, axs = plt.subplots(1, 1)
        axs.axis('off')

        # Values fed to the colour map; cells outside (cmin, cmax) are dark
        # enough to need white text.
        cvals = means
        cmin, cmax = 0.25, 0.75

        if task == "median_rank":
            # Ranks are unbounded and "lower is better": log-compress them,
            # squeeze into the upper half of the colour range and flip so
            # that the lowest rank ends up at 1.
            cvals = means - np.min(means) + 2
            cvals = np.log(cvals) / np.max(np.log(cvals))
            cvals = cvals / 2 + 0.5
            cvals = cvals * -1 + 1.5
            cvals += (1 - np.max(cvals))

        # Extra neutral-coloured (0.5) row for the environment-name labels.
        cvals = np.append(cvals, [np.zeros(means.shape[1])+.5], axis=0)
        colors = plt.cm.PiYG(cvals)

        # The builtin int/str are used as dtypes: the np.int / np.str aliases
        # were removed from NumPy.
        if task == "median_rank":
            # Ranks are shown as integers (astype(int) truncates).
            scores = means.astype(int).astype(str)
        else:
            scores = np.round(means, 2).astype(str)
            str_stds = np.round(stds, 2).astype(str)
            for (i, j), mean in np.ndenumerate(scores):
                std = str_stds[i, j]
                # Pad values such as "0.5" to two decimals ("0.50").
                mean = f"{mean}0" if len(mean) == 3 else mean
                std = f"{std}0" if len(std) == 3 else std
                # Raw f-string so that "\pm" is not parsed as an escape.
                scores[i, j] = rf"{mean}$_{{\pm{std}}}$"

        # Tables with one column more than ENV_NAMES carry a leading
        # "All-ENV" column.
        env_names = [["All-ENV"] + ENV_NAMES] if scores.shape[1] != len(ENV_NAMES) else [ENV_NAMES]
        scores = np.append(scores, env_names, axis=0)

        model_types = task2model_types[task]

        table = axs.table(
            rowLabels=model_types+[""],
            rowLoc="right",
            cellText=scores,
            cellColours=colors,
            loc='center',
            cellLoc='center',
            colLoc="center",
            colWidths=sizes,
            edges="closed",
        )
        table.scale(1, 3)

        # Cells are returned in creation order: all body cells row by row,
        # followed by the row-label cells.
        cells = table.get_children()
        n_cols = cvals.shape[1]

        ## Table values layout
        for idx, cell in enumerate(cells):
            cell.get_text().set_fontsize(18)

            # Only body cells have a colour value; use white text on the
            # dark ends of the colour map.
            if idx < cvals.size:
                i, j = divmod(idx, n_cols)
                if not (cmin < cvals[i, j] < cmax):
                    cell.get_text().set_color('white')

        ## Model name (row label) layout
        for row_cell in cells[-scores.shape[0]:]:
            row_cell.get_text().set_fontsize(18)
            row_cell.set_text_props(fontweight="bold")

        ## Environment name (last row) layout
        for (row, _col), col_cell in table.get_celld().items():
            if row == scores.shape[0]-1:
                col_cell.set_color("w")
                col_cell.get_text().set_color("black")
                col_cell.set_text_props(fontweight="bold")
                col_cell.set_height(0.1)

        ## Reset table cell borders
        for key, cell in table.get_celld().items():
            cell.set_linewidth(4)
            cell.set_edgecolor("w")

        ## Title with adaptive padding
        plt.title(task2title[task], pad=(-160)+(scores.shape[0]*20)-10, fontweight="bold", fontsize=20)

        ## Lower caption text
        label_offset = 0.27 if scores.shape[0] == 4 else 0.3
        plt.text(
            0.5,
            label_offset,
            task2caption[task],
            fontsize=18,
            fontweight="normal",
            horizontalalignment='center'
        )

        ## Uniform line divider
        if "All-ENV" in env_names[0]:
            # A thick white line hides the cell border, then a dashed black
            # line is drawn on top of it to set off the "All-ENV" column.
            bg_line = lines.Line2D(
                [-.107,-.107], [0.475,.6],
                linewidth=6,
                color="white",
                transform=fig.transFigure,
                figure=fig
            )
            line = lines.Line2D(
                [-.107,-.107], [0.4,.62],
                linewidth=2,
                color="black",
                dashes=(2,1),
                transform=fig.transFigure,
                figure=fig
            )
            fig.lines.extend([bg_line, line])

        plt.show()

        fig.savefig(f"figures/{task}.pdf", bbox_inches='tight')


# Draw, display and save (as figures/<task>.pdf) one table per task in
# clean_results.
plot_scores(clean_results)