In [15]:
# NOTE: the star import stays first so the explicit imports below win on any name clash.
from libs.corp_df import *
from libs.get_docs import get_pickled_docs as gpd
from collections import Counter
import itertools
import random
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import silhouette_score
from tqdm.auto import tqdm

In [2]:
# Load the corpus used by every later cell.
# NOTE(review): the path suggests a pickle file; presumably `gpd` unpickles it,
# so only point this at trusted files -- confirm against libs.get_docs.
cdir = 'corpora/articles_stage2.pickle'
print('Retrieving Documents...')
docs = gpd(cdir, verbose=True)

Retrieving Documents...
TIME ELAPSED: 30.13s


In [3]:
# POS tags whose lemmas contribute a root.
# NOTE(review): presumably the verb/participle tags of the tagger -- confirm.
ROOT_POS_TAGS = ('VB', 'BN', 'BNT')
# A root must occur at least this many times to be eligible for sampling.
MIN_ROOT_OCCURRENCES = 20

# Walk docs -> pages -> sentences -> tokens -> lemmas and collect every
# non-missing root whose lemma carries one of the selected POS tags.
roots_raw = []
print('getting raw roots...')
for doc in tqdm(docs):
    for pg in doc:
        for sent in pg:
            for tok in sent:
                for lem in tok:
                    if lem.pos_tag in ROOT_POS_TAGS and lem.shoresh is not None:
                        roots_raw.append(lem.shoresh)

# filter by minimum # occurrences
# Count once with a Counter instead of calling list.count() per element,
# which rescans the whole list for every entry (quadratic in len(roots_raw)).
print('filtering...')
root_counts = Counter(roots_raw)
valid_roots = {
    root for root, occurrences in root_counts.items()
    if occurrences >= MIN_ROOT_OCCURRENCES
}
print(len(valid_roots))

getting raw roots...


  0%|          | 0/524 [00:00<?, ?it/s]

filtering...


  0%|          | 0/26165 [00:00<?, ?it/s]

308


In [136]:
def select_roots(i, pool=None):
    """Draw ``i`` distinct roots at random.

    Parameters
    ----------
    i : int
        Number of roots to draw.
    pool : iterable, optional
        Candidate roots. Defaults to the module-level ``valid_roots``.

    Returns
    -------
    list
        ``i`` distinct elements of the pool.

    Raises
    ------
    ValueError
        If ``i`` is larger than the pool (raised by ``random.sample``).
    """
    if pool is None:
        pool = valid_roots
    # random.sample() no longer accepts sets (deprecated in Python 3.9, a
    # TypeError from 3.11), so sample from a list. Sorting also gives the
    # pool a stable order, so a seeded generator yields the same draw.
    # Assumes roots are mutually orderable (e.g. strings) -- TODO confirm.
    return random.sample(sorted(pool), i)

def get_df(num_roots):
    """Build a DataFrame of lemma occurrences for a random sample of roots.

    ``num_roots`` roots are drawn with ``select_roots``; the module-level
    ``docs`` is then scanned, and one row is emitted for every lemma whose
    root is in the sample. Tokens whose ``tokenizer_index`` equals 1 are
    skipped.

    Returns a DataFrame with columns ``lemma``, ``root``, ``binyan`` and
    ``raw_embedding`` (the embedding of the token the lemma belongs to).
    """
    roots = select_roots(num_roots)
    lemmas, lemma_roots, binyanim, embeddings = [], [], [], []

    # docs -> pages -> sentences, flattened in document order.
    pages = itertools.chain.from_iterable(docs)
    sentences = itertools.chain.from_iterable(pages)

    for sent in sentences:
        for tok in sent:
            if tok.tokenizer_index == 1:
                continue
            for lem in tok:
                if lem.shoresh not in roots:
                    continue
                # get metadata
                lemmas.append(lem.lemma)
                lemma_roots.append(lem.shoresh)
                binyanim.append(lem.binyan)
                embeddings.append(tok.embedding)

    return pd.DataFrame({
        'lemma': lemmas,
        'root': lemma_roots,
        'binyan': binyanim,
        'raw_embedding': embeddings,
    })

def get_summed_layers(embedding, layers):
    """Return one layer of an embedding, or the element-wise sum of several.

    Parameters
    ----------
    embedding : indexable
        ``embedding[k]`` must expose a ``.numpy()`` method.
        NOTE(review): presumably a torch tensor with layers on the first
        axis -- confirm.
    layers : int or iterable of int
        A single layer index, or the indices of the layers to sum.

    Returns
    -------
    numpy.ndarray
        ``embedding[layers].numpy()`` for a single index, otherwise the sum
        of the selected layers.

    Raises
    ------
    ValueError
        If ``layers`` is an empty iterable.
    """
    # Accept NumPy integers too (e.g. an index taken from np.arange).
    if isinstance(layers, (int, np.integer)):
        return embedding[layers].numpy()

    layer_arrays = [embedding[l].numpy() for l in layers]
    if not layer_arrays:
        raise ValueError('layers must contain at least one layer index')
    # The sum of ndarrays is already an ndarray; the previous trailing
    # .numpy() call on it raised AttributeError.
    return np.sum(layer_arrays, axis=0)

def roots_to_num(roots):
    """Map each distinct root to an integer id in ``0 .. n_distinct - 1``.

    Ids follow the iteration order of a set built from ``roots``, so which
    root gets which id is arbitrary; only distinctness is guaranteed.
    """
    return {root: idx for idx, root in enumerate(set(roots))}

def silhouette_exp(df, layers):
    """Score how well the chosen embedding layer(s) separate the roots.

    Each row's ``raw_embedding`` is reduced with ``get_summed_layers`` and
    each row's ``root`` is turned into an integer label; the result is
    whatever ``silhouette_score`` returns for those vectors and labels
    under the cosine metric. ``df`` itself is not modified.
    """
    # One vector per row, taken from the requested layer(s).
    vectors = [get_summed_layers(raw, layers) for raw in df['raw_embedding']]
    # Integer cluster label per row, one id per distinct root.
    root_ids = df['root'].map(roots_to_num(df['root'])).to_numpy()
    return silhouette_score(vectors, root_ids, metric='cosine')

In [150]:
# Experiment configuration.
NUM_ROOTS = 5          # roots sampled per experiment
NUM_EXPERIMENTS = 100  # number of independent random draws
NUM_LAYERS = 12        # layers scored per experiment (indices 0 .. NUM_LAYERS - 1)
SEED = 0

# Seed the generator behind the root sampling so re-running this cell replays
# the same random stream. Fully repeatable draws additionally require the
# candidate pool to be iterated in a stable order.
random.seed(SEED)

# Column-oriented accumulator: one list per output column, one entry per experiment.
results = {'exp_id': [], 'roots': [], 'best_layer': []}
for layer in range(NUM_LAYERS):
    results['silhouette_layer_' + str(layer)] = []

for i in tqdm(range(NUM_EXPERIMENTS)):
    results['exp_id'].append(i)
    df = get_df(NUM_ROOTS)
    # Record the roots actually present in the frame; this can be fewer than
    # NUM_ROOTS when a sampled root yields no rows.
    results['roots'].append(df['root'].unique())

    scores = []
    for layer in range(NUM_LAYERS):
        score = silhouette_exp(df, layer)
        scores.append((layer, score))
        results['silhouette_layer_' + str(layer)].append(score)

    # Layer with the highest score; ties resolve to the lowest layer index.
    results['best_layer'].append(max(scores, key=lambda j: j[1])[0])

resultsdf = pd.DataFrame(results)

  0%|          | 0/100 [00:00<?, ?it/s]

In [156]:
# Distinct layers that achieved the top silhouette score in any experiment.
best_layers_seen = resultsdf['best_layer'].unique()
print(best_layers_seen)

[0]


In [159]:
# Rank experiments by their layer-0 silhouette score, best first.
# The sort is done in place, so `resultsdf` stays reordered for later cells.
resultsdf.sort_values(
    by=['silhouette_layer_0'],
    ascending=[False],
    inplace=True,
)
# Show the five top-ranked experiments.
resultsdf.head(5)

Unnamed: 0,exp_id,roots,best_layer,silhouette_layer_0,silhouette_layer_1,silhouette_layer_2,silhouette_layer_3,silhouette_layer_4,silhouette_layer_5,silhouette_layer_6,silhouette_layer_7,silhouette_layer_8,silhouette_layer_9,silhouette_layer_10,silhouette_layer_11
86,86,"[צרר, חוש, צרף, חרב, עין]",0,0.56286,0.540777,0.536892,0.515488,0.497242,0.487823,0.472279,0.463946,0.443638,0.429091,0.41743,0.41708
67,67,"[כול, עזז, חשש, אבה, דבק]",0,0.558901,0.51408,0.480959,0.44591,0.422282,0.411944,0.390666,0.376121,0.351288,0.332329,0.332153,0.334097
14,14,"[כרז, שים, תנה, רגש, ברך]",0,0.527245,0.487096,0.454471,0.418222,0.395167,0.385192,0.367691,0.351244,0.33035,0.310055,0.306229,0.315689
23,23,"[סכן, גיס, רחש, איש]",0,0.480137,0.444539,0.414725,0.385807,0.363707,0.361524,0.34607,0.332891,0.313137,0.290194,0.289118,0.297084
52,52,"[יחס, לחם, הוה, כרז, קרע]",0,0.478021,0.440877,0.411828,0.377016,0.346842,0.341148,0.335202,0.331443,0.318781,0.306709,0.287714,0.291596
