In [20]:
from pathlib import Path
from glob import glob
import pandas as pd
import numpy as np

from alternationprober.constants import (
    PATH_TO_RESULTS_DIRECTORY,
)

import seaborn as sns
import matplotlib.pyplot as plt

In [21]:
def sorting_key(layer_path):
    """Return the integer sort position for a layer result directory.

    A path whose stem is ``static`` is treated as layer 0; any other stem is
    converted with ``int``.
    """
    stem = layer_path.stem
    return 0 if stem == "static" else int(stem)

experiment_results = PATH_TO_RESULTS_DIRECTORY / "linear-probe-for-word-embeddings"

# Only ``static`` and integer-named directories are layer results; any other
# directory (e.g. ``.ipynb_checkpoints``) would make ``sorting_key`` raise.
layer_paths = sorted(
    (
        path
        for path in experiment_results.glob("*/")
        if path.is_dir() and (path.stem == "static" or path.stem.isdecimal())
    ),
    key=sorting_key,
)

layer_dfs = []
for layer_path in layer_paths:
    # Take the layer number from the directory name rather than from the
    # position in the sorted list, so a missing layer directory cannot shift
    # the numbering of the layers that follow it.
    layer = sorting_key(layer_path)
    for alternation_csv in layer_path.glob("*.csv"):
        # Match on the file name only: testing the full path would accept
        # every CSV if a parent directory happened to contain "predictions".
        if 'predictions' not in alternation_csv.name:
            continue
        # The frame name is the part of the file stem before the first underscore.
        frame = alternation_csv.stem.split('_')[0]
        layer_df = pd.read_csv(alternation_csv).rename(
            columns={f'{frame}_true': 'label', f'{frame}_predicted': 'predicted'}
        )
        layer_df['layer'] = layer
        layer_df['frame'] = frame
        layer_dfs.append(layer_df)

if not layer_dfs:
    # pd.concat would raise an opaque "No objects to concatenate" here.
    raise ValueError(f"No prediction CSVs found under {experiment_results}")

# One long frame with every layer/frame prediction, ordered by layer then frame.
all_df = pd.concat(layer_dfs, axis=0).sort_values(by=['layer', 'frame'])

In [23]:
# The best layer was chosen by its MCC averaged over all frames.
best_layer = 7
is_best_layer = all_df.layer == best_layer
best_layer_df = all_df.loc[is_best_layer, ['verb', 'label', 'predicted', 'frame']]

print(best_layer_df.shape)

# Keep the rows where the predicted value differs from the label.
is_error = best_layer_df.predicted != best_layer_df.label
error_df = best_layer_df[is_error].sort_values('frame')

# Fraction of all errors that falls into each frame (sums to 1 across frames).
error_dist = error_df[['frame']].value_counts(normalize=True).to_frame('prop_error')
error_dist

(3393, 4)


Unnamed: 0_level_0,prop_error
frame,Unnamed: 1_level_1
Refl,0.25
inchoative,0.159091
preposition,0.159091
2object,0.113636
locative,0.113636
with,0.113636
Non-Refl,0.045455
There,0.045455


In [124]:
# Frame name in the data -> column header in the table. The insertion order of
# this mapping is the left-to-right column order of the table.
frame_headers = {
    'inchoative': 'Inch.',
    'causative': 'Caus.',
    'preposition': 'Prep.',
    '2object': '2-Obj',
    'with': 'with',
    'locative': 'loc.',
    'No-There': 'no-there',
    'There': 'there',
    'Refl': 'Refl',
    'Non-Refl': 'No-Refl',
}

# Per frame: the sum of ``label`` is the positive count, and the remaining
# rows of that frame are the negatives.
positives_per_frame = best_layer_df.groupby('frame')['label'].sum()
rows_per_frame = best_layer_df.frame.value_counts()
frame_counts = pd.DataFrame({
    'positive': positives_per_frame,
    'negative': rows_per_frame - positives_per_frame,
}).astype(int)
frame_counts['total'] = frame_counts.sum(axis=1)

# Put frames on the columns, fix their order, then relabel both axes.
frame_counts = frame_counts.T[list(frame_headers)]
frame_counts.columns = list(frame_headers.values())
frame_counts.index = ['Positive', 'Negative', 'Total']
print(frame_counts.to_latex())

\begin{tabular}{lrrrrrrrrrr}
\toprule
{} &  Inch. &  Caus. &  Prep. &  2-Obj &  with &  loc. &  no-there &  there &  Refl &  No-Refl \\
\midrule
Positive &     73 &    124 &     65 &     74 &   101 &    86 &       149 &     50 &    84 &       11 \\
Negative &    144 &      0 &    377 &    442 &   242 &   257 &         0 &    192 &   419 &      503 \\
Total    &    217 &    124 &    442 &    516 &   343 &   343 &       149 &    242 &   503 &      514 \\
\bottomrule
\end{tabular}

