In [None]:
import pickle
import nltk
from nltk.corpus import wordnet as wn
from nltk.corpus.reader.wordnet import Synset
nltk.data.path.append('../data')  # noqa
from tqdm.notebook import tqdm


# Load the precomputed similarity data: `nominos` (one per row), `synsets` (WordNet
# synset names, one per column) and the `similarities` matrix between them.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
with open('../data/sim_matrix.p', mode='rb') as f:
    nominos, synsets, similarities = pickle.load(f)

# Rank every row from most to least similar (presumably a torch tensor, given the
# `dim=`/`descending=` signature): `values` holds the sorted scores and
# `indices[i, k]` is the column -- i.e. position in `synsets` -- of the k-th best match.
values, indices = similarities.sort(descending=True, dim=1)

In [None]:
# For every nomino, keep its best-matching WordNet synsets: walk its row of the sorted
# similarity matrix from the top, stopping at the first score below MIN_SIMILARITY or
# after MAX_MATCHES entries. Nominos with no qualifying match are dropped.
MIN_SIMILARITY = 0.5
MAX_MATCHES = 11  # ranks 0..10, matching the previous `k > 10` cut-off

groups = []
for n_idx, nomino in enumerate(tqdm(nominos)):
    matches = []
    # Bound by the number of *columns* (synsets), not rows, so short rows cannot
    # be indexed past their end.
    for k in range(min(MAX_MATCHES, values.shape[1])):
        v = values[n_idx, k].item()
        if v < MIN_SIMILARITY:
            break  # rows are sorted descending, so every later score is lower too
        matches.append((v, wn.synset(synsets[indices[n_idx, k]])))
    if matches:
        groups.append((nomino, matches))

In [None]:
# Show the first (nomino, matches) group whose entry is 'yahe'; raises StopIteration if absent.
next(filter(lambda group: group[0].entry == 'yahe', groups))

In [None]:
from collections import Counter, defaultdict
from pprint import pprint

def get_hypers(xs: list[Synset], min_depth: int = 0, max_depth: int = -1) -> dict[Synset, float]:
    """Return the relative frequency of every synset on the hypernym paths of ``xs``.

    All paths returned by ``Synset.hypernym_paths()`` for every synset in ``xs`` are
    pooled, and each synset occurrence is counted. The result maps each synset seen to
    its count divided by the total number of occurrences, so the values sum to 1. The
    dict is empty when ``xs`` is empty.

    Args:
        xs: synsets whose hypernym paths are pooled.
        min_depth: start index of the slice kept from each path (0 keeps from the start).
        max_depth: end index (exclusive) of the slice kept from each path. The default
            ``-1`` is a sentinel meaning "keep the path to its end" (the synset itself
            included); any other value is used as an ordinary slice bound.

    Returns:
        Mapping from synset to its share of all counted occurrences.
    """
    end = None if max_depth == -1 else max_depth
    hypers = [h for x in xs for path in x.hypernym_paths() for h in path[min_depth:end]]
    total = len(hypers)
    # Counter only yields keys with count >= 1, so no extra filter is needed; an empty
    # Counter yields nothing, so `total == 0` is never divided by.
    return {hyper: count / total for hyper, count in Counter(hypers).items()}


# Accumulate hypernym mass in both directions (concord -> synset and synset -> concord),
# plus, per hypernym, the number of nomino groups whose hypernym paths contain it.
concord_to_synset_counts = defaultdict(lambda: defaultdict(int))
synset_to_concord_counts = defaultdict(lambda: defaultdict(int))
synset_global_counts = defaultdict(int)

for nomino, matches in tqdm(groups):
    concord = nomino.subject_concord
    hypers = get_hypers([synset for _, synset in matches])
    for hyper, value in hypers.items():
        concord_to_synset_counts[concord][hyper] += value
        synset_to_concord_counts[hyper][concord] += value
        synset_global_counts[hyper] += 1

In [None]:
# Peek at the accumulated hypernym mass for the 'u-/i-' concord. `.get` avoids the
# defaultdict side effect of inserting an empty entry for a missing key, which would
# otherwise show up as a zero-mass concord in the norms computed from this dict later.
concord_to_synset_counts.get('u-/i-', {})

In [None]:
from math import log2 as log

# Only hypernyms that occur in more than `threshold` nomino groups enter the PMI table.
threshold = 10

# Marginal masses per concord and per synset, and the grand total `norm`.
concord_norms = {concord: sum(vs.values()) for concord, vs in concord_to_synset_counts.items()}
synset_norms = {synset: sum(vs.values()) for synset, vs in synset_to_concord_counts.items()}
norm = sum(concord_norms.values())
log_norm = log(norm)

# joint[c][s] = log(count(c, s)) - log(norm), restricted to frequent-enough synsets.
joint = {}
for concord, synset_counts in concord_to_synset_counts.items():
    joint[concord] = {
        synset: log(count) - log_norm
        for synset, count in synset_counts.items()
        if synset_global_counts[synset] > threshold
    }

# pmi[c] = [(s, p(c,s) * (log p(c,s) - log p(c) - log p(s))), ...] sorted by decreasing score.
# Concords whose synsets were all filtered out get no entry at all.
pmi = {}
for concord, log_joints in joint.items():
    if not log_joints:
        continue
    log_concord_norm = log(concord_norms[concord])
    scored = []
    for synset, log_joint in log_joints.items():
        weight = concord_to_synset_counts[concord][synset] / norm
        scored.append((synset, weight * (log_joint - (log_concord_norm + log(synset_norms[synset]) - 2 * log_norm))))
    pmi[concord] = sorted(scored, key=lambda pair: pair[1], reverse=True)

In [None]:
# Top-20 synsets (by weighted PMI) for several concords; print those shared by the
# i-/zi-, ki-/vi- and li-/ya- classes.
ks1, ks2, ks3, ks4, ks5 = (
    {synset for synset, _ in pmi[concord][:20]}
    for concord in ('i-/zi-', 'ki-/vi-', 'li-/ya-', 'i-', 'ya-')
)

print(ks1 & ks2 & ks3)

In [None]:
def bold(x: float, threshold: float = 1.) -> str:
    r"""Format ``x`` to one decimal place, wrapped in LaTeX ``\textbf{...}`` when it is >= ``threshold``.

    The comparison uses the rounded value that is actually printed, so two cells that
    both read e.g. ``1.0`` are always formatted the same way (a raw 0.96 is bold too).

    Args:
        x: value to format.
        threshold: smallest printed value that gets bolded (default 1.0).

    Returns:
        The formatted string, e.g. ``'0.4'`` or ``'\\textbf{2.3}'``.
    """
    text = f'{x:.1f}'
    if float(text) < threshold:
        return text
    return '\\textbf{' + text + '}'


def pp(k: Synset, ks: list[Synset], v: float) -> str:
    """Render synset ``k`` with its score as LaTeX, e.g. ``dog.n.01~(12.3)``.

    ``v`` is scaled by 100 and formatted via ``bold``. The whole label is greyed out
    (``\\textcolor{gray!80}``) when one of ``k``'s direct hypernyms is already in ``ks``.
    """
    label = f'{k.name()}~({bold(v * 100)})'
    if any(parent in ks for parent in k.hypernyms()):
        return '\\textcolor{gray!80}{' + label + '}'
    return label
  

# For each concord: print its name, the summed weighted-PMI score of all its ranked
# synsets, and its top-20 synsets as a LaTeX-ready comma-separated list. `pp` greys out
# a synset when one of its direct hypernyms already appears earlier in the list.
for concord in ['a-/wa-', 'i-/zi-', 'u-', 'ki-/vi-', 'u-/i-', 'li-/ya-', 'ya-', 'u-/zi-', 'i-']:
    ranked = pmi[concord]
    print(concord)
    print(f'{sum(v for _, v in ranked):.3f}')
    print(', '.join(
        pp(k, [prev for prev, _ in ranked[:idx]], v)
        for idx, (k, v) in enumerate(ranked[:20])
    ).replace('_', '\\_'))  # escape underscores in synset names for LaTeX

In [None]:
# List Synset's attributes whose names start with "h" (e.g. the hypernym accessors used
# above). A bare `Synset.h` is not an attribute and raises AttributeError on Run All.
[name for name in dir(Synset) if name.startswith('h')]