In [31]:
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import sys 
sys.path.append('/home/hhansen/decon/decon_env/DecontextEmbeddings')
import os 
EMBEDDING_DATA_DIR = '/home/hhansen/decon/decon_env/data'
os.environ['EMBEDDING_DATA_DIR'] = EMBEDDING_DATA_DIR
os.environ['EMBEDDING_EVALUATION_DATA_PATH'] = '/home/hhansen/decon/decon_env/DecontextEmbeddings/helpers/embedding_evaluation/data/'
DATA_DIR = '/home/hhansen/decon/decon_env/DecontextEmbeddings/data'
os.environ['DATA_DIR'] = DATA_DIR


from helpers.embedding_evaluation.evaluate import Evaluation as wordsim_evaluate
from helpers.things_evaluation.evaluate import read_embeddings, load_behav, load_sorting, match_behv_sim, evaluate as run_evaluation
from helpers.data import yield_static_data
from helpers.intersection import get_intersection_words
from helpers.plot import get_ax
from helpers.data import load_spose_dimensions


from scipy.stats import spearmanr, pearsonr
from collections import defaultdict
import os 
from scipy.spatial.distance import squareform
from sklearn.linear_model import ElasticNetCV, MultiTaskElasticNetCV, RidgeCV, LassoCV, LinearRegression
from sklearn.model_selection import KFold, RepeatedKFold
from scipy.stats import spearmanr, pearsonr

from sklearn.preprocessing import StandardScaler

In [4]:
with open(f'spose_similarity.pkl', 'rb') as r_file:
    spose_similarity = pickle.load(r_file)
spose_similarity = squareform(spose_similarity, force='tovector', checks=False)
spose_similarity.shape

(360825,)

# Prepare FRRSA layer results

In [5]:
matching = 'main_word'
frrsa_sim_matrix_per_layer = defaultdict(list)

with open(f'frrsa_results_{matching}.pkl', 'rb') as r_file:
    results_frrsa = pickle.load(r_file)
    for model, _ in results_frrsa.items():
        if model != 'w2v' and model != 'glove':
            for layer in _:
                sim_matrix = layer[1][:, :, 0]
                sim_vector = squareform(sim_matrix, force='tovector', checks=False)
                frrsa_sim_matrix_per_layer[model].append(sim_vector)

In [6]:
for model in frrsa_sim_matrix_per_layer:
    df = pd.DataFrame(np.asarray(frrsa_sim_matrix_per_layer[model]).T)
    df['spose_similarity'] = spose_similarity
    df = df[~df.isin([9999]).any(axis=1)]
    frrsa_sim_matrix_per_layer[model] = df

In [7]:
frrsa_sim_matrix_per_layer['bert-base'].shape

(963, 14)

In [8]:
b = pd.DataFrame({'a': [1,2], 'b': [3,4]})
b[b['a'].isin([2])]
b[b.isin([2])]

Unnamed: 0,a,b
0,,
1,2.0,


# Mean baseline average

In [9]:
for model, result in frrsa_sim_matrix_per_layer.items():
    filtered_spose_similarity = result['spose_similarity']
    n_samples = filtered_spose_similarity.shape[0]
    mean_sim = result.iloc[:, :13].to_numpy().mean(axis=1)
    corr = pearsonr(filtered_spose_similarity, mean_sim)
    print(f'{model}: Pearsonr: {corr[0]} n_samples: {n_samples}')

bert-base: Pearsonr: 0.5140261802233548 n_samples: 963
gpt-2: Pearsonr: 0.5545182855673332 n_samples: 968
bert-large: Pearsonr: 0.20202278022732031 n_samples: 7
gpt-2-medium: Pearsonr: -0.2335221864583686 n_samples: 8


# Regression based 

In [58]:
def fit_model(layer_frrsa_similarity, spose_similarity, random_state):
    cv = RepeatedKFold(n_splits=5, n_repeats=3, random_state=random_state)
    ratios = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
    n_alphas = 10
    clf = ElasticNetCV(l1_ratio=ratios, n_alphas=n_alphas, fit_intercept=False, cv=cv, random_state=random_state)
    #clf = LinearRegression(fit_intercept=False)
    clf.fit(layer_frrsa_similarity, spose_similarity)
    return clf

In [59]:
def score(test_spose_similarity, pred_spose_similarity):
    return pearsonr(test_spose_similarity, pred_spose_similarity)[0]

In [60]:
def fit_and_score_regression(layer_frrsa_sim, spose_similarity):
    print(f'Predictor shape: {layer_frrsa_sim.shape} Target shape: {spose_similarity.shape}')
    outer_cv = RepeatedKFold(n_splits=5, n_repeats=5, random_state=42)
    n_samples = layer_frrsa_sim.shape[0]
    scores = []

    for i_split, (outer_train_ind, outer_test_ind) in enumerate(outer_cv.split(range(n_samples))):
        train_layer_frrsa_similarity = layer_frrsa_sim[outer_train_ind, :]
        #print(f'Number train samples: {train_layer_frrsa_similarity.shape[0]}')

        train_spose_similarity = spose_similarity[outer_train_ind]
        model = fit_model(train_layer_frrsa_similarity, train_spose_similarity, 42)

        test_spose_similarity = spose_similarity[outer_test_ind]
        test_frrsa_similarity = layer_frrsa_sim[outer_test_ind, :]
        #print(f'Number test samples: {test_frrsa_similarity.shape[0]}')

        pred_spose_similarity = model.predict(test_frrsa_similarity)
        pearson = score(test_spose_similarity, pred_spose_similarity)
        scores.append(pearson)
        
    return scores

In [61]:
for model, result in frrsa_sim_matrix_per_layer.items():
    #scaler = StandardScaler()
    spose_similarity = result.loc[:, 'spose_similarity'].to_numpy()
    result = result.drop(columns=['spose_similarity'])
    layer_frrsa_sim = result.to_numpy()
    #layer_frrsa_sim = scaler.fit_transform(layer_frrsa_sim)

    scores = fit_and_score_regression(layer_frrsa_sim, spose_similarity)
    cv_score = np.asarray(scores).mean()
    print(f'{model}: {cv_score}')

(963, 13) (963,)
bert-base: 0.5455151120201068
(968, 13) (968,)
gpt-2: 0.6025194106882601
(7, 25) (7,)


  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = c

  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(
  model = cd_fast.enet_coordinate_descent(


ValueError: x and y must have length at least 2.

In [13]:
df= frrsa_sim_matrix_per_layer['bert-base']
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,spose_similarity
1099,0.306237,0.273510,0.285984,0.284254,0.265938,0.255552,0.282335,0.271220,0.235269,0.212936,0.196851,0.204278,0.230541,0.129707
3468,0.303125,0.222166,0.237801,0.238570,0.252180,0.198098,0.204494,0.245597,0.230679,0.200986,0.178957,0.225390,0.277387,0.216209
3814,0.314255,0.335804,0.356492,0.341691,0.432666,0.467250,0.540340,0.509627,0.527027,0.591948,0.509716,0.450000,0.361018,0.516748
3837,0.288585,0.293443,0.298734,0.293094,0.324815,0.318068,0.412525,0.481433,0.529062,0.450301,0.415534,0.412009,0.376312,0.327900
3856,0.342099,0.353948,0.327357,0.342607,0.381556,0.286950,0.283015,0.376693,0.327220,0.341442,0.289354,0.300546,0.385862,0.251489
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
359761,0.301224,0.249336,0.241258,0.208347,0.214260,0.210785,0.236306,0.235901,0.172584,0.172250,0.231718,0.178078,0.234691,0.312069
360011,0.342137,0.297560,0.295694,0.286483,0.265788,0.238111,0.258982,0.250928,0.244811,0.221458,0.220769,0.283343,0.346612,0.640844
360178,0.345948,0.307158,0.347527,0.382970,0.431268,0.447351,0.472872,0.427386,0.469078,0.443052,0.432534,0.405102,0.478486,0.268397
360407,0.302457,0.283571,0.324616,0.305253,0.366046,0.329993,0.315409,0.323675,0.307080,0.274973,0.355095,0.361250,0.279596,0.225815
