In [1]:
%matplotlib widget
import matplotlib.pyplot as plt
from ldg.pickle import pickle_read
import pandas as pd
import re

In [3]:
# Highest rank index stored per row in the prediction tables
# (columns label_1..label_K, ..., distance_1..distance_K).
K = 50
# Values iterated by read_scores; they become the outer keys of the score dicts,
# which xyzize/plot_surface treat as the occurrence cutoff. Covers 5, 10, ..., 50.
K_RANGE = range(5,55,5)
def read_scores(metric, model='bert-base-cased'):
    """Load one pickled score object per value in K_RANGE.

    Args:
        metric: Distance metric prefix of the predictions directory
            (e.g. 'cosine' or 'euclidean').
        model: Model name used in the file name.

    Returns:
        dict mapping each value of K_RANGE to the object read from
        ``{metric}_predictions/{model}.tsv.{value}.pkl``.
    """
    # NOTE(review): pickle_read presumably unpickles -- only point it at trusted files.
    return {
        cutoff: pickle_read(f'{metric}_predictions/{model}.tsv.{cutoff}.pkl')
        for cutoff in K_RANGE
    }

def read_score(metric='cosine', model='bert-base-cased', freq_cutoff=20, rarity=10, eval='rec'):
    """Load a single pickled score object for one (freq_cutoff, rarity, eval) setting.

    The file read is
    ``{metric}_predictions/{model}.tsv.{freq_cutoff}.pklrarity{rarity}.{eval}``;
    callers in this notebook use eval='rec' or eval='prec'.
    ``eval`` shadows the builtin but is kept because callers pass it by keyword.
    """
    score_path = f'{metric}_predictions/{model}.tsv.{freq_cutoff}.pklrarity{rarity}.{eval}'
    return pickle_read(score_path)

# Captures the text between the last two dots of a string; the greedy leading `.*`
# forces the match onto the final pair of dots (input 'dog.n.01' -> 'n').
POS_PATTERN = re.compile(r'.*\.(.*?)\..*')
# Used with re.match, so it only matches labels that *start* with 'NE'
# (titled "Named Entity" in the per-POS plots).
NE_PATTERN = re.compile(r'(NE)')
def get_pos(label):
    """Derive a coarse part-of-speech tag from a label string.

    Args:
        label: Label value from the prediction table's ``label`` column.

    Returns:
        The text between the last two dots of ``label`` when it contains at least
        two dots; otherwise 'NE' when the label starts with 'NE'; otherwise None.
        Non-string values (e.g. NaN for an empty TSV cell) also yield None instead
        of raising TypeError inside the regex match.
    """
    if not isinstance(label, str):
        return None
    pos_match = POS_PATTERN.match(label)
    if pos_match:
        return pos_match.group(1)
    # Only consulted when the dotted POS pattern did not match.
    ne_match = NE_PATTERN.match(label)
    if ne_match:
        return ne_match.group(1)
    return None

def read_data(metric, model='bert-base-cased'):
    """Read the tab-separated prediction table for ``metric``/``model``.

    Returns the DataFrame from ``{metric}_predictions/{model}.tsv`` with an extra
    ``pos`` column computed by applying get_pos to the ``label`` column.
    """
    predictions = pd.read_csv(f'{metric}_predictions/{model}.tsv', sep='\t')
    return predictions.assign(pos=predictions.label.apply(get_pos))

# Score dicts for each distance metric, keyed by the values of K_RANGE.
euclidean_scores = read_scores('euclidean')
cosine_scores = read_scores('cosine')
# Spot-check one entry per metric: outer key 5, inner key 1
# (xyzize treats these as occurrence cutoff and k respectively).
print(cosine_scores[5][1])
print(euclidean_scores[5][1])

# Prediction tables, each with the derived 'pos' column added by read_data.
euclidean_data = read_data('euclidean')
cosine_data = read_data('cosine')
cosine_data.keys()

{'label': 0.7377845310451674, 'synset': 0.74368370873555, 'lemma': 0.9177392444285544}
{'label': 0.7288165892027172, 'synset': 0.7350137051602907, 'lemma': 0.9096353235609582, 'synset_different_lemma': 0.006197115957573591}


Index(['sentence', 'label', 'synset', 'lemma', 'label_freq_in_train',
       'label_1', 'label_2', 'label_3', 'label_4', 'label_5',
       ...
       'distance_42', 'distance_43', 'distance_44', 'distance_45',
       'distance_46', 'distance_47', 'distance_48', 'distance_49',
       'distance_50', 'pos'],
      dtype='object', length=256)

In [66]:
def xyzize(scores_dict, key='synset', skip_k_above_cutoff=False):
    """Flatten nested score dicts into parallel coordinate lists for 3-D plotting.

    Args:
        scores_dict: ``{cutoff: {k: scores}}`` where each ``scores`` is a dict
            that contains ``key``.
        key: Which score component becomes the z value.
        skip_k_above_cutoff: If True, drop points whose k exceeds their cutoff.
            Defaults to False, which keeps every point; this replaces a filter
            that had been disabled with a hard-coded ``and False``.

    Returns:
        ``(x, y, z)`` lists of cutoffs, ks and ``scores[key]``, in dict
        iteration order.
    """
    x, y, z = [], [], []
    for cutoff, p_at_k_dict in scores_dict.items():
        for k, scores in p_at_k_dict.items():
            if skip_k_above_cutoff and k > cutoff:
                continue
            x.append(cutoff)
            y.append(k)
            z.append(scores[key])
    return x, y, z

def plot_surface(scores_dict, key='synset'):
    """3-D scatter of one score component over (occurrence cutoff, k).

    Points come from ``xyzize(scores_dict, key=key)``; colour encodes the
    z value via the 'viridis' colormap.
    """
    plt.figure(figsize=(8, 8))
    ax = plt.axes(projection='3d')
    ax.set_xlabel('occurrence cutoff')
    ax.set_ylabel('k')
    ax.set_zlabel(f'{key} recall')
    # Optionally pin the z range: ax.set_zlim(0.3, 1)

    cutoffs, ks, values = xyzize(scores_dict, key=key)
    ax.scatter(cutoffs, ks, values, c=values, cmap='viridis', linewidth=0.5)
    plt.show()
    
def plot_p_vs_k(p_at_k_dict, p_at_k_dict_2=None, key='synset'):
    """Scatter ``scores[key]`` against k for one or two ``{k: scores}`` dicts.

    The first dict uses the default marker; the optional second dict (plotted
    only when truthy) uses 'x' markers. The y axis is fixed to [0, 1].
    """
    def _split(at_k):
        # Unpacking raises ValueError on an empty dict, as before.
        ks, per_k = zip(*at_k.items())
        return list(ks), [entry[key] for entry in per_k]

    plt.figure(figsize=(7, 7))
    ax = plt.axes()
    ax.set_xlabel('k')
    ax.set_ylabel(f'{key} precision')
    ax.set_ylim(0, 1)

    series = [(p_at_k_dict, {})]
    if p_at_k_dict_2:
        series.append((p_at_k_dict_2, {'marker': 'x'}))
    for at_k, marker_kwargs in series:
        ks, values = _split(at_k)
        ax.scatter(x=ks, y=values, c=values, **marker_kwargs)
    plt.show()
    
def plot_p_vs_k_grid(p_at_k_dict_2=None, key='synset', eval='rec',
                     freq_cutoffs=(5, 10, 25, 50), rarities=(5, 10, 30, 50),
                     ylim=(0, 0.30)):
    """Grid of score-at-k scatter plots, one panel per (freq_cutoff, rarity).

    Each panel loads ``read_score(freq_cutoff=..., rarity=..., eval=eval)`` and
    plots ``scores[k][key]`` against k.

    Args:
        p_at_k_dict_2: Unused; kept so existing calls keep working.
        key: Score component to plot (e.g. 'synset', 'lemma', 'label').
        eval: Score-file suffix forwarded to read_score ('rec' or 'prec');
            'rec' titles the figure "Recall", anything else "Precision".
        freq_cutoffs: Grid row values (default matches the original 4x4 grid).
        rarities: Grid column values; panel titles show rarity/100.
        ylim: y-axis limits applied to every panel.
    """
    # squeeze=False keeps `axs` 2-D even for a single row or column.
    fig, axs = plt.subplots(len(freq_cutoffs), len(rarities), figsize=(8, 8), squeeze=False)
    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    fig.suptitle(("Recall" if eval == 'rec' else "Precision") + " at K")

    for row, freq_cutoff in enumerate(freq_cutoffs):
        for col, rarity in enumerate(rarities):
            ax = axs[row][col]
            ax.set_title(f'c={freq_cutoff}, rar={rarity/100}')
            ax.set_ylim(*ylim)
            # Forward `eval`: it used to be hard-coded to 'rec', so the
            # "Precision" grid silently re-plotted recall.
            scores = read_score(freq_cutoff=freq_cutoff, rarity=rarity, eval=eval)

            ks, per_k = zip(*scores.items())
            values = [entry[key] for entry in per_k]
            ax.scatter(x=list(ks), y=values, c=values)
    plt.show()
    
# Single-configuration view (uncomment to inspect one setting):
# score = read_score(freq_cutoff=5, rarity=5, eval='prec')
# plot_p_vs_k(score, key='synset')
for score_kind in ('rec', 'prec'):
    plot_p_vs_k_grid(key='synset', eval=score_kind)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# Prediction-table columns (i runs from 1 to K):
#   sentence: original sentence
#   label, lemma, synset: gold values the i-th predictions are compared against below
#   label_freq_in_train
#   label_i, synset_i, lemma_i: presumably the i-th nearest prediction -- confirm
#   distance_i: distance associated with the i-th prediction
#   pos: derived in read_data via get_pos(label)
# NOTE: not the last expression of the cell, so this list is not displayed.
list(cosine_data.keys())

def distance_correctness_corr(data, key='synset', max_k=None):
    """Plot, for each rank i, the Pearson correlation of distance_i with correctness.

    A prediction at rank i counts as correct when ``data[f'{key}_{i}']`` equals
    ``data[key]``.

    Args:
        data: Prediction table with columns ``key``, ``f'{key}_{i}'`` and
            ``f'distance_{i}'`` for i in 1..max_k.
        key: Gold column to compare against ('synset', 'lemma', 'label').
        max_k: Highest rank to include; defaults to the notebook constant K.
    """
    if max_k is None:
        max_k = K
    ranks = list(range(1, max_k + 1))
    correlations = []
    for i in ranks:
        correct = (data[f'{key}_{i}'] == data[f'{key}']).astype(float)
        # Series.corr is the same pairwise Pearson as DataFrame.corr() on the two
        # columns, without copying the frame or the deprecated positional lookup
        # `corr()['correct'][1]` on a string-labelled Series.
        correlations.append(correct.corr(data[f'distance_{i}']))

    plt.figure(figsize=(8, 5))
    ax = plt.axes()
    ax.set_xlabel('k')
    ax.set_ylabel('distance correlated with correctness')
    ax.scatter(ranks, correlations, color='blue')
    plt.show()
    

def distance_correctness_histogram(data, key='synset'):
    """Stacked 50-bin histogram of distance_1, split by top-1 correctness.

    Top-1 is correct when ``data[f'{key}_1']`` equals ``data[key]``. ``data``
    itself is not modified.
    """
    flagged = data.assign(first_correct=data[f'{key}_1'] == data[f'{key}'])
    by_correctness = flagged[['first_correct', 'distance_1']].pivot(columns='first_correct')
    by_correctness.distance_1.plot.hist(stacked=True, bins=50)
    plt.show()
    
def distance_correctness_histogram_by_pos(data, key='synset'):
    """3x2 grid of stacked distance_1 histograms, one panel per POS tag.

    Each panel splits rows with ``pos == tag`` by top-1 correctness
    (``data[f'{key}_1'] == data[key]``). A tag with no rows gets a labelled
    empty panel instead of raising inside pivot/hist.
    """
    pos_titles = {"n": "Noun", "v": "Verb", "a": "Adj", "s": "Adj satellite",
                  "r": "Adverb", "NE": "Named Entity"}
    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 8))
    # Only the columns needed here -- avoids copying the full (wide) table.
    flagged = data[['pos', 'distance_1']].assign(
        first_correct=data[f'{key}_1'] == data[f'{key}'])
    for i, pos in enumerate(['NE', 'a', 'n', 'r', 's', 'v']):
        ax = axes[i // 2, i % 2]
        ax.set_title(pos_titles[pos])
        subset = flagged.loc[flagged['pos'] == pos, ['first_correct', 'distance_1']]
        if subset.empty:
            ax.text(0.5, 0.5, 'no data', ha='center', va='center', transform=ax.transAxes)
            continue
        subset.pivot(columns='first_correct').distance_1.plot.hist(stacked=True, bins=50, ax=ax)
    plt.show()
    
distance_correctness_corr(cosine_data)
# Other views, kept for reference:
# distance_correctness_histogram(euclidean_data[euclidean_data.pos=='r'])
# distance_correctness_histogram(cosine_data)
for heading, frame in (("Cosine correctness", cosine_data),
                       ("Euclidean correctness", euclidean_data)):
    print(heading)
    distance_correctness_histogram_by_pos(frame)
    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Cosine correctness


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Euclidean correctness


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Display every per-k entry for outer key 5 (the occurrence cutoff, as xyzize
# treats it); each value is a dict like the one printed earlier for k=1.
cosine_scores[5]