In [1]:
# The wildcard import is kept first on purpose: every explicit import below
# then takes precedence over any same-named symbol it happens to export.
from libs.corp_df import *
from libs.get_docs import get_pickled_docs as gpd

from collections import Counter
import itertools
import random

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import silhouette_score
from tqdm.auto import tqdm

In [2]:
# Load the pre-processed corpus from disk into `docs`.
# Later cells iterate it as docs -> pages -> sentences -> tokens -> lemmas.
# NOTE(review): judging by its name, get_pickled_docs presumably unpickles this
# file; unpickling can execute arbitrary code, so only load trusted corpora.
print('Retrieving Documents...')
# Path is relative to the notebook's working directory.
cdir = 'corpora/articles_stage2.pickle'
docs = gpd(cdir, verbose=True)

Retrieving Documents...
TIME ELAPSED: 30.32s


In [3]:
# Part-of-speech tags whose lemmas contribute a root to the inventory.
# NOTE(review): presumably the verb / participle tags of the tagger -- confirm.
ROOT_POS_TAGS = ('VB', 'BN', 'BNT')
# A root must occur at least this many times to be eligible for sampling.
MIN_ROOT_OCCURRENCES = 20

# Collect one entry per matching lemma occurrence (duplicates are intentional:
# they are what gets counted below).
roots_raw = []
print('getting raw roots...')
for doc in tqdm(docs):
    for pg in doc:
        for sent in pg:
            for tok in sent:
                for lem in tok:
                    if lem.pos_tag in ROOT_POS_TAGS and lem.shoresh is not None:
                        roots_raw.append(lem.shoresh)

# filter by minimum # occurrences
# A single Counter pass replaces calling list.count() once per element, which
# was quadratic in the number of collected roots.
print('filtering...')
root_counts = Counter(roots_raw)
valid_roots = {
    root for root, count in root_counts.items()
    if count >= MIN_ROOT_OCCURRENCES
}
print(len(valid_roots))

getting raw roots...


  0%|          | 0/524 [00:00<?, ?it/s]

filtering...


  0%|          | 0/26165 [00:00<?, ?it/s]

308


In [4]:
def select_roots(i):
    """Draw `i` distinct roots at random from the module-level `valid_roots`.

    Args:
        i: number of roots to draw; must not exceed ``len(valid_roots)``.

    Returns:
        list: `i` distinct roots.

    Raises:
        ValueError: propagated from ``random.sample`` when `i` is negative or
            larger than the number of valid roots.
    """
    # random.sample() no longer accepts a set (deprecated in Python 3.9,
    # TypeError from 3.11), so sample from a sequence instead. Sorting gives
    # the population a fixed order, so a seeded RNG yields the same draw
    # regardless of set iteration order.
    return random.sample(sorted(valid_roots), i)

def get_df(num_roots):
    """Build a per-lemma table for a fresh random selection of roots.

    Draws `num_roots` roots via `select_roots`, then walks the module-level
    `docs` corpus and records every lemma whose root is in the selection.
    Tokens whose ``tokenizer_index`` equals 1 are skipped.

    Args:
        num_roots: how many roots to sample.

    Returns:
        pandas.DataFrame with columns ``lemma``, ``root``, ``binyan`` and
        ``raw_embedding`` (the embedding of the token that holds the lemma),
        one row per matching lemma occurrence.
    """
    chosen_roots = select_roots(num_roots)
    lemma_col, root_col, binyan_col, embedding_col = [], [], [], []

    # Flatten the doc -> page -> sentence -> token nesting into one stream.
    token_stream = (
        tok
        for doc in docs
        for pg in doc
        for sent in pg
        for tok in sent
    )

    for tok in token_stream:
        if tok.tokenizer_index == 1:
            continue
        for lem in tok:
            if lem.shoresh not in chosen_roots:
                continue
            # get metadata
            lemma_col.append(lem.lemma)
            root_col.append(lem.shoresh)
            binyan_col.append(lem.binyan)
            # The embedding lives on the token, not on the lemma.
            embedding_col.append(tok.embedding)

    return pd.DataFrame({
        'lemma': lemma_col,
        'root': root_col,
        'binyan': binyan_col,
        'raw_embedding': embedding_col,
    })

def get_summed_layers(embedding, layers):
    """Return one layer of an embedding, or the element-wise sum of several.

    Args:
        embedding: object indexable by layer number whose items expose a
            ``.numpy()`` method (presumably a torch tensor with the layer
            axis first -- confirm against the corpus loader).
        layers: a single layer index (Python or NumPy integer), or a
            non-empty iterable of layer indices to add together.

    Returns:
        numpy.ndarray: the selected layer, or the sum of the selected layers.

    Raises:
        ValueError: if `layers` is an empty iterable.
    """
    # isinstance (rather than an exact type check) also accepts NumPy integer
    # indices such as those produced by np.arange / np.argmax.
    if isinstance(layers, (int, np.integer)):
        return embedding[layers].numpy()

    layer_arrays = [embedding[l].numpy() for l in layers]
    if not layer_arrays:
        raise ValueError('layers must contain at least one layer index')
    # The items are already NumPy arrays, so the sum is one as well; the
    # previous trailing .numpy() call on that sum raised AttributeError.
    return np.sum(layer_arrays, axis=0)

def roots_to_num(roots):
    """Assign each distinct root an integer id in ``0 .. n_unique - 1``.

    Args:
        roots: iterable of hashable root values; duplicates are collapsed.

    Returns:
        dict mapping every distinct root to its own id. Ids follow set
        iteration order, so the particular root -> id pairing is arbitrary.
    """
    distinct = set(roots)
    return {root: idx for idx, root in enumerate(distinct)}

def silhouette_exp(df, layers):
    """Score how well the chosen embedding layer(s) separate lemmas by root.

    Args:
        df: frame with a ``raw_embedding`` column (per-row embeddings) and a
            ``root`` column (the cluster label of each row). It is only read,
            never modified.
        layers: layer index or indices, forwarded to `get_summed_layers`.

    Returns:
        The cosine-metric silhouette score of the rows' layer vectors, using
        the root of each row as its cluster label.
    """
    # One vector per row, taken from the requested layer(s).
    layer_vectors = [
        get_summed_layers(raw, layers) for raw in df['raw_embedding']
    ]
    # Integer cluster labels, one per row, derived from the root strings.
    root_ids = roots_to_num(df['root'])
    labels = df['root'].map(root_ids).to_numpy()

    return silhouette_score(layer_vectors, labels, metric='cosine')

In [5]:
# --- Experiment configuration -------------------------------------------
N_LAYERS = 12         # layer indices 0 .. N_LAYERS-1 are scored per experiment
N_EXPERIMENTS = 100   # number of independent random root selections
NUM_ROOTS = 5         # roots sampled per experiment
SEED = 0

# Seed Python's `random` module (the source of the root draws) so that
# re-running this cell in the same session repeats the same experiments.
random.seed(SEED)

layer_columns = ['silhouette_layer_' + str(layer) for layer in range(N_LAYERS)]

# Column-oriented accumulator: one list per output column, one entry per run.
results = {'exp_id': [], 'roots': [], 'best_layer': []}
for column in layer_columns:
    results[column] = []

for i in tqdm(range(N_EXPERIMENTS)):
    results['exp_id'].append(i)
    df = get_df(NUM_ROOTS)
    results['roots'].append(df['root'].unique())

    # Silhouette score of every layer for this root selection.
    scores = [silhouette_exp(df, layer) for layer in range(N_LAYERS)]
    for column, score in zip(layer_columns, scores):
        results[column].append(score)

    # Index of the highest-scoring layer; ties resolve to the lowest index.
    results['best_layer'].append(max(range(N_LAYERS), key=scores.__getitem__))

resultsdf = pd.DataFrame(results)

  0%|          | 0/100 [00:00<?, ?it/s]

In [6]:
# List the distinct layer indices that scored highest in at least one experiment.
print(resultsdf['best_layer'].unique())

[0]


In [7]:
# Rank the experiments from highest to lowest layer-0 silhouette score and
# preview the top rows.
resultsdf = resultsdf.sort_values('silhouette_layer_0', ascending=False)
resultsdf.head()

Unnamed: 0,exp_id,roots,best_layer,silhouette_layer_0,silhouette_layer_1,silhouette_layer_2,silhouette_layer_3,silhouette_layer_4,silhouette_layer_5,silhouette_layer_6,silhouette_layer_7,silhouette_layer_8,silhouette_layer_9,silhouette_layer_10,silhouette_layer_11
23,23,"[צרר, צרך, כרז, ברך, מחש]",0,0.662512,0.632522,0.60361,0.561553,0.542485,0.537145,0.52649,0.513411,0.498907,0.475186,0.461631,0.459922
42,42,"[חסל, צלח, דמה, כנע, טהר]",0,0.471508,0.441891,0.420218,0.386346,0.360357,0.360402,0.338064,0.308588,0.289905,0.260138,0.264804,0.278537
16,16,"[צבע, רצה, כשל, קנה, זהר]",0,0.468705,0.438921,0.409975,0.367003,0.342957,0.333371,0.324255,0.304,0.285999,0.267309,0.253588,0.263429
21,21,"[חוש, פרט, שלם, נכר, כונ]",0,0.465632,0.443504,0.41324,0.38974,0.368373,0.356841,0.3497,0.341256,0.327098,0.314876,0.309766,0.314474
90,90,"[סיע, פסד, הרס, תקן, נוע]",0,0.46259,0.419125,0.37133,0.331106,0.302964,0.298957,0.287007,0.276205,0.262493,0.247598,0.238293,0.244467
