### Model embeddings
This notebook creates sense embeddings using the pre-trained xl-lexeme model.

### Usage
Set the `source_file` variable in the second code cell to the path of the JSON file containing the context sentences created by the `extract_context.ipynb` notebook.
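By default, one gloss embedding is created per sense. To create example embeddings (the average of several usage vectors) instead, use the `get_average_vector` function defined in the third code cell: in the fifth code cell, uncomment the `get_average_vector` line and comment out the `model.encode` line.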

In [2]:
import json
import os
import sys

sys.path.insert(1, '../xl-lexeme/WordTransformer')

from InputExample import InputExample
from WordTransformer import WordTransformer

# Load the pre-trained xl-lexeme model once; later cells encode with `model`.
model = WordTransformer('pierluigic/xl-lexeme')

In [8]:
# Context-sentence JSON produced by extract_context.ipynb.
source_file = '../data/outputs/dictionary_context/wordnet_sense_id_context_[3].json'
# The output is named after the input file. splitext drops only the final
# extension and basename honours the OS path separator, so file names that
# contain extra dots are kept whole (split('.')[0] cut them at the first dot).
output_file = (
    "../data/outputs/sense_embeddings/english/"
    f"{os.path.splitext(os.path.basename(source_file))[0]}_embeddings.json"
)

In [4]:
# get vector for word usage
def get_average_vector(inputs):
    vectors = model.encode(inputs) # get vectors for all inputs
    return sum(vectors) / len(vectors) # return average vector

In [9]:
# count total senses
with open(source_file) as f:
    data = json.load(f)
    total_senses = 0

    for d in data:
        for s in data[d]:
            total_senses += 1

In [10]:
# create sense embeddings
total_embeddings = 0

for entry in data:
    for sense in data[entry]:
        
        # get data from json
        usages = sense['usages']
        gloss = usages['usage']
        target = usages['target']
        
        inputs = InputExample(texts=gloss, positions=target) # create input example (see xl-lexeme documentation)
        #embedding = get_average_vector(inputs).tolist() # get average vector (for example embeddings)
        embedding = model.encode(inputs).tolist() # get vector (for gloss embeddings)
        total_embeddings += 1
        sense['embedding'] = embedding

    # print percentage done
    print(f"{total_embeddings / total_senses * 100:.2f} %", end='\r')

print(f"Total embeddings: {total_embeddings}")

Total embeddings: 117659


In [11]:
# Write the senses (now carrying an 'embedding' key) to output_file.
# Create the output directory first: open(..., 'w') does not create missing
# directories and would raise FileNotFoundError on a fresh checkout.
output_dir = os.path.dirname(output_file)
if output_dir:  # dirname is '' for a bare file name; makedirs('') would fail
    os.makedirs(output_dir, exist_ok=True)
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(data, f, indent=4)