In [None]:
import nltk
from nltk.corpus import wordnet as wn
# Add the project-local data directory to the locations NLTK searches for corpora.
nltk.data.path.append('../data')  # noqa

# Materialise every noun synset, then order them by the character length of
# their definition (shortest first). Presumably this keeps inputs of similar
# length together so later batches need little padding -- TODO confirm intent.
all_synsets = list(wn.all_synsets(pos=wn.NOUN))
all_synsets.sort(key=lambda synset: len(synset.definition()))
print(f'Read {len(all_synsets)} synsets.')

In [None]:
from tokenization import get_tokenizer, encode
def _synset_text(synset):
    """Render a synset as '<name of its first lemma>, <definition>'."""
    return f'{synset.lemmas()[0].name()}, {synset.definition()}'


print('Encoding...')
tokenizer = get_tokenizer('sentence-transformers/all-MiniLM-L12-v2')

# Pair every synset with the encoding of its rendered text; the order of
# `all_synsets` is preserved.
tokenized_synsets = [
    (synset, encode(tokenizer, _synset_text(synset)))
    for synset in all_synsets
]
print('Done.')

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Device on which padded batches (and the model, further down) are placed.
DEVICE = 'cpu'


def cfn(xs):
    """Collate a batch of ``(item, token_ids)`` pairs for a ``DataLoader``.

    Args:
        xs: Sequence of 2-tuples. The first element is passed through
            untouched; the second is a 1-D sequence of token ids (a list,
            array or tensor).

    Returns:
        A tuple ``(items, batch)`` where ``items`` is the list of first
        elements in input order and ``batch`` is a tensor on ``DEVICE`` of
        shape ``(len(xs), longest_sequence)``, right-padded with 0.
    """
    items = [item for item, _ in xs]
    # as_tensor gives the same result as torch.tensor for plain lists, but does
    # not emit a copy-construct warning when the ids are already tensors.
    sequences = [torch.as_tensor(token_ids) for _, token_ids in xs]
    # NOTE(review): padding_value=0 assumes the tokenizer's pad id is 0 -- confirm.
    batch = pad_sequence(sequences, batch_first=True, padding_value=0)
    return items, batch.to(DEVICE)

# Batches of up to 1024 synsets, served in the order of `tokenized_synsets`
# (no shuffling) and padded per batch by `cfn`.
synset_dl = DataLoader(tokenized_synsets, batch_size=1024, shuffle=False, collate_fn=cfn)

from model_wrappers import SBERT
from tqdm.notebook import tqdm

# Build the sentence encoder on the target device and put it in eval mode
# (SBERT is a project wrapper; presumably a torch.nn.Module -- TODO confirm).
vectorizer = SBERT().to(DEVICE)
vectorizer.eval()

print('Vectorizing synsets...')
synset_vectors = []
with torch.no_grad():  # gradients are not needed when only producing vectors
    for batch_synsets, batch_tokens in tqdm(synset_dl):
        # One forward pass per batch; results are moved to the CPU before
        # being split into one vector per synset.
        batch_vectors = vectorizer(batch_tokens).cpu()
        synset_vectors.extend(
            (synset.name(), vector)
            for synset, vector in zip(batch_synsets, batch_vectors)
        )

In [None]:
import pickle

# NOTE: pickle.load can execute arbitrary code -- only use it on trusted files.
with open('../data/tokenized.p', 'rb') as f:
    # The pickle holds a 2-tuple; only its first element is used here.
    raw_nominos, _ = pickle.load(f)

# Keep the first two fields of each record and drop the third. Presumably
# these are (nomino, token ids, <extra>) -- verify against the producer.
tokenized_nominos = [(first, second) for first, second, _ in raw_nominos]
print(f'Read {len(tokenized_nominos)} nominos.')

# Batches of up to 1024 nominos, served in the order of `tokenized_nominos`
# (no shuffling) and padded per batch by `cfn`.
nomino_dl = DataLoader(tokenized_nominos, batch_size=1024, shuffle=False, collate_fn=cfn)

print('Vectorizing nominos...')
nomino_vectors = []
with torch.no_grad():  # gradients are not needed when only producing vectors
    for batch_nominos, batch_tokens in tqdm(nomino_dl):
        # One forward pass per batch; results are moved to the CPU before
        # being split into one vector per nomino.
        batch_vectors = vectorizer(batch_tokens).cpu()
        nomino_vectors.extend(zip(batch_nominos, batch_vectors))

In [None]:
# Row labels (nominos) and column labels (synset names) of the score matrix.
nomino_labels = [label for label, _ in nomino_vectors]
synset_labels = [label for label, _ in synset_vectors]

# Pairwise dot products: entry [i, j] scores nomino i against synset j.
# NOTE(review): these are cosine similarities only if the vectorizer returns
# unit-length vectors -- confirm in the SBERT wrapper.
similarities = (
    torch.stack([vec for _, vec in nomino_vectors])
    @ torch.stack([vec for _, vec in synset_vectors]).t()
)

# Persist labels and scores together so the matrix can be interpreted later.
with open('../data/sim_matrix.p', 'wb') as f:
    pickle.dump((nomino_labels, synset_labels, similarities), f)
