In [1]:
import argparse
import logging
import random
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
from sklearn.cluster import AffinityPropagation
import numpy as np
from tqdm import tqdm

In [40]:
# Tab-separated input with the target words; read by the main loop and, as gold data, by the evaluation cell
input_file = 'fi_dev.tsv'
# Tab-separated predictions written by the main loop and re-read by the evaluation cell
output_file = 'fi_dev_paraphrase_mean_nongreedy.tsv'

# both languages
# model_name = "setu4993/LEALLA-large"

## Finnish model
model_name = 'TurkuNLP/sbert-cased-finnish-paraphrase'

## Russian models
# model_name = 'siberian-lang-lab/evenki-russian-parallel-corpora'
# model_name = 'DeepPavlov/rubert-base-cased-sentence'
# A cluster is mapped to an old sense only when the cosine similarity between the
# cluster mean and that sense's embedding is strictly greater than this value;
# otherwise the cluster is labelled as a novel sense.
threshold = 0.3
# Only referenced by the commented-out score dump at the end of the evaluation cell.
# NOTE(review): the name says "rubert" although the active model is the Finnish paraphrase
# model, and both file names say "mean" while get_embeddings() returns pooler_output — confirm.
result_file = 'fi_dev_rubert_mean_nongreedy_score.txt'

In [41]:
# Period labels and column names used when slicing the input table
OLD_PERIOD, NEW_PERIOD = "old", "new"
PERIOD_COLUMN = "period"
USAGE_ID_COLUMN = "usage_id"
SENSE_ID_COLUMN = "sense_id"

# Reproducibility: seed every RNG used by the notebook and ask torch for deterministic kernels
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

# INFO-level logging for the progress messages emitted by the cells below
logging.basicConfig(level=logging.INFO)

In [42]:

def load_model(model_name):
    """Load the tokenizer and the encoder registered under `model_name`.

    Both are fetched with the `transformers` Auto* factories. The encoder is put
    into eval mode before it is returned.

    Returns:
        A `(tokenizer, model)` tuple.
    """
    logging.info(f"Loading model {model_name} for sentence embeddings")
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # .eval() returns the module itself, so it can be chained onto the load
    loaded_model = AutoModel.from_pretrained(model_name).eval()
    logging.info(f"Loaded model {model_name}")
    return loaded_tokenizer, loaded_model


# Notebook-level tokenizer/encoder pair; get_embeddings() reads both as globals
tokenizer, model = load_model(model_name)


tokenizer_config.json:   0%|          | 0.00/353 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/424k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/3.00 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/498M [00:00<?, ?B/s]

In [43]:
# Getting representations for input sentences

def get_embeddings(sentences):
    """Embed `sentences` in a single batch and return the model's `pooler_output`.

    Relies on the notebook-level `tokenizer` and `model`. Inputs are padded to a
    common length and truncated to at most 256 tokens; the forward pass runs
    without gradient tracking.
    """
    encoded = tokenizer(
        sentences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )
    with torch.no_grad():
        model_output = model(**encoded)
    return model_output.pooler_output



#### cluster algorithm

In [44]:
def cluster_affinity(new_numpy, dist='euclidean'):
    """Cluster the rows of `new_numpy` with Affinity Propagation.

    Args:
        new_numpy: 2-D array with one embedding per row.
        dist: 'euclidean' lets scikit-learn derive the affinities from the raw
            vectors; 'cos' clusters on a precomputed pairwise cosine-similarity
            matrix.

    Returns:
        The fitted estimator's `labels_` (one cluster label per input row).

    Raises:
        ValueError: if `dist` is neither 'euclidean' nor 'cos'. Previously this
            fell through both branches and failed with UnboundLocalError.
    """
    if dist not in ('euclidean', 'cos'):
        raise ValueError(f"Unsupported dist {dist!r}; expected 'euclidean' or 'cos'")

    if dist == 'euclidean':
        ap = AffinityPropagation(random_state=42, affinity='euclidean', max_iter=500, verbose=True)
        return ap.fit(new_numpy).labels_

    # Cosine similarity of every pair of rows in one matrix product instead of an
    # O(n^2) Python loop. The norm is clamped at 1e-8, the same default eps that
    # torch.nn.functional.cosine_similarity uses, so zero vectors do not divide by zero.
    # Values can differ from the per-pair torch computation in the last float digits.
    vectors = np.asarray(new_numpy)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    unit_vectors = vectors / np.maximum(norms, 1e-8)
    affi_matrix = (unit_vectors @ unit_vectors.T).astype(np.float64)

    ap = AffinityPropagation(random_state=42, affinity='precomputed', max_iter=500, verbose=True)
    return ap.fit(affi_matrix).labels_


#### NON-GREEDY algorithm

In [45]:
# this function returns a dictionary {new_example: senseID}
def get_exs2senses_from_cluster(
    cluster_label=None,
    new_numpy=None,
    new_examples=None,
    old_senses=None,
    old_outputs=None,
    threshold=None,
):
    """Align each cluster of new usages with an old sense, or mark it as novel.

    For every distinct cluster label the mean of the cluster's embeddings is
    compared (cosine similarity) with each old-sense embedding. The cluster takes
    the most similar old sense when that similarity is strictly greater than
    `threshold`; otherwise it gets the id `<prefix>_novel_<label>`, where
    `<prefix>` is the part of the first old sense id before its first underscore.

    Every argument is optional. An argument left as None is read from the
    notebook-level variable of the same name, so the historical zero-argument
    call `get_exs2senses_from_cluster()` keeps working.

    Args:
        cluster_label: array with one cluster label per new example.
        new_numpy: 2-D array of new-example embeddings, row-aligned with `cluster_label`.
        new_examples: list of new example strings, aligned with `cluster_label`.
        old_senses: list of old sense ids.
        old_outputs: iterable of old-sense embeddings, aligned with `old_senses`.
        threshold: minimum similarity for reusing an old sense.

    Returns:
        dict mapping each new example string to its assigned sense id. Examples
        with identical text share one key, so the last cluster processed wins.
    """
    notebook_globals = globals()

    def _resolve(name, value):
        # Explicit arguments win; otherwise fall back to the notebook-level variable
        if value is not None:
            return value
        if name not in notebook_globals:
            raise NameError(f"name '{name}' is not defined")
        return notebook_globals[name]

    cluster_label = np.asarray(_resolve("cluster_label", cluster_label))
    new_numpy = _resolve("new_numpy", new_numpy)
    new_examples = _resolve("new_examples", new_examples)
    old_senses = _resolve("old_senses", old_senses)
    old_outputs = _resolve("old_outputs", old_outputs)
    threshold = _resolve("threshold", threshold)

    unique_labels = np.unique(cluster_label)
    # Rows are addressed by position in `unique_labels`, not by the raw label value,
    # so the labels do not have to be exactly 0..k-1.
    similarities = np.zeros((len(unique_labels), len(old_senses)))
    exs2senses = {}
    for row, label in enumerate(unique_labels):
        in_cluster = cluster_label == label
        cluster_mean_vec = torch.Tensor(new_numpy[in_cluster].mean(axis=0))
        for sense_idx, gloss_embedding in enumerate(old_outputs):
            sim = F.cosine_similarity(cluster_mean_vec, gloss_embedding, dim=0)
            similarities[row, sense_idx] = sim.item()

        # assign the closest old sense when it is similar enough, else a novel sense id
        closest_sense_id = similarities[row].argmax()
        if similarities[row, closest_sense_id] > threshold:
            found_sense_id = old_senses[closest_sense_id]
        else:
            latin_name = old_senses[0].split("_")[0]
            found_sense_id = f"{latin_name}_novel_{label}"

        for example_idx in np.where(in_cluster)[0]:
            exs2senses[new_examples[example_idx]] = found_sense_id  # {example: sense_id}
    return exs2senses

#### Main loop: cluster the new usages of each word and align the clusters with old senses

In [46]:
# Per-word pipeline: embed the new usages, cluster them, align every cluster with an
# old sense (or a novel one) and store the answer in the sense column of `targets`.
# NOTE: get_exs2senses_from_cluster() is called without arguments and reads
# `new_examples`, `old_senses`, `old_outputs`, `new_numpy` and `cluster_label`
# from the notebook namespace, so these names must stay as they are.
targets = pd.read_csv(input_file, sep="\t")
for target_word in tqdm(targets.word.unique()):

    this_word = targets[targets.word == target_word]
    new = this_word[this_word[PERIOD_COLUMN] == NEW_PERIOD]
    old = this_word[this_word[PERIOD_COLUMN] == OLD_PERIOD]

    new_examples = new.example.to_list()

    old_senses = old[SENSE_ID_COLUMN].to_list()
    old_examples = old.example.to_list()
    old_gloss = old.gloss.to_list()

    # One sentence per old sense: its example when that is a string, otherwise its gloss
    old_sentence = [
        old_example if isinstance(old_example, str) else gloss
        for old_example, gloss in zip(old_examples, old_gloss)
    ]

    # Sentence representations for the new usages and for the old senses
    new_outputs = get_embeddings(new_examples)
    old_outputs = get_embeddings(old_sentence)

    # Cluster the new usages on their embeddings
    new_numpy = new_outputs.detach().numpy()
    cluster_label = cluster_affinity(new_numpy)

    # {example: sense_id} for every new usage of this word
    exs2senses = get_exs2senses_from_cluster()

    new_usage_ids = new[USAGE_ID_COLUMN]
    # every new example must have exactly one usage id
    assert len(new_examples) == new_usage_ids.shape[0]

    for usage_id, example in zip(new_usage_ids, new_examples):
        # answers are looked up by example text, then written to the row(s) with this usage id
        usage_rows = targets[USAGE_ID_COLUMN] == usage_id
        targets.loc[usage_rows, SENSE_ID_COLUMN] = exs2senses[example]

logging.info(f"Writing the result to {output_file}")
targets.to_csv(output_file, sep="\t", index=False)


  1%|          | 2/254 [00:03<08:22,  1.99s/it]

Converged after 17 iterations.


  1%|          | 3/254 [00:05<08:52,  2.12s/it]

Converged after 16 iterations.


  2%|▏         | 4/254 [00:09<10:34,  2.54s/it]

Converged after 23 iterations.


  2%|▏         | 5/254 [00:09<07:56,  1.91s/it]

Converged after 15 iterations.


  2%|▏         | 6/254 [00:21<21:23,  5.17s/it]

Converged after 20 iterations.


  3%|▎         | 7/254 [00:24<18:49,  4.57s/it]

Converged after 22 iterations.


  3%|▎         | 8/254 [00:27<16:27,  4.01s/it]

Converged after 21 iterations.


  4%|▎         | 9/254 [00:41<29:10,  7.15s/it]

Converged after 36 iterations.


  4%|▍         | 11/254 [00:47<21:12,  5.24s/it]

Converged after 20 iterations.


  5%|▍         | 12/254 [00:52<20:43,  5.14s/it]

Converged after 15 iterations.


  6%|▌         | 15/254 [00:55<09:45,  2.45s/it]

Converged after 16 iterations.


  7%|▋         | 17/254 [01:00<09:54,  2.51s/it]

Converged after 17 iterations.


  7%|▋         | 18/254 [01:01<08:26,  2.14s/it]

Converged after 15 iterations.


  7%|▋         | 19/254 [01:04<08:47,  2.24s/it]

Converged after 17 iterations.


  8%|▊         | 20/254 [01:05<07:14,  1.86s/it]

Converged after 65 iterations.


  9%|▉         | 23/254 [01:07<04:45,  1.24s/it]

Converged after 15 iterations.


 10%|█         | 26/254 [01:14<07:55,  2.08s/it]

Converged after 18 iterations.


 11%|█         | 27/254 [01:15<06:08,  1.62s/it]

Converged after 15 iterations.


 11%|█         | 28/254 [01:18<08:08,  2.16s/it]

Converged after 28 iterations.


 11%|█▏        | 29/254 [01:20<07:24,  1.97s/it]

Converged after 15 iterations.


 12%|█▏        | 30/254 [01:20<06:07,  1.64s/it]

Converged after 16 iterations.


 12%|█▏        | 31/254 [01:23<06:39,  1.79s/it]

Converged after 19 iterations.


 14%|█▍        | 35/254 [01:26<04:55,  1.35s/it]

Converged after 15 iterations.


 14%|█▍        | 36/254 [01:36<13:50,  3.81s/it]

Converged after 21 iterations.


 15%|█▍        | 37/254 [01:37<10:39,  2.95s/it]

Converged after 15 iterations.


 15%|█▌        | 39/254 [01:42<10:30,  2.93s/it]

Converged after 19 iterations.


 16%|█▌        | 40/254 [01:58<24:45,  6.94s/it]

Converged after 25 iterations.


 16%|█▌        | 41/254 [01:59<18:17,  5.15s/it]

Converged after 17 iterations.


 17%|█▋        | 42/254 [02:01<14:09,  4.01s/it]

Converged after 15 iterations.


 17%|█▋        | 43/254 [02:02<11:06,  3.16s/it]

Converged after 20 iterations.


 17%|█▋        | 44/254 [02:03<08:34,  2.45s/it]

Converged after 16 iterations.


 18%|█▊        | 45/254 [02:03<06:52,  1.97s/it]

Converged after 16 iterations.


 18%|█▊        | 46/254 [02:10<11:57,  3.45s/it]

Converged after 18 iterations.


 19%|█▉        | 48/254 [02:20<15:10,  4.42s/it]

Converged after 16 iterations.


 20%|█▉        | 50/254 [02:26<13:37,  4.01s/it]

Converged after 21 iterations.


 20%|██        | 51/254 [02:36<19:20,  5.72s/it]

Converged after 17 iterations.


 20%|██        | 52/254 [02:37<14:40,  4.36s/it]

Converged after 15 iterations.


 21%|██        | 53/254 [03:01<34:11, 10.21s/it]

Converged after 22 iterations.


 21%|██▏       | 54/254 [03:04<26:43,  8.02s/it]

Converged after 16 iterations.


 22%|██▏       | 55/254 [03:05<19:54,  6.00s/it]

Converged after 18 iterations.


 22%|██▏       | 56/254 [03:06<14:32,  4.41s/it]

Converged after 15 iterations.


 22%|██▏       | 57/254 [03:10<14:30,  4.42s/it]

Converged after 22 iterations.


 23%|██▎       | 58/254 [03:11<11:23,  3.49s/it]

Converged after 16 iterations.


 24%|██▎       | 60/254 [03:27<20:13,  6.25s/it]

Converged after 17 iterations.


 24%|██▍       | 62/254 [03:39<21:18,  6.66s/it]

Converged after 51 iterations.


 25%|██▍       | 63/254 [03:43<17:50,  5.60s/it]

Converged after 16 iterations.


 25%|██▌       | 64/254 [03:49<18:44,  5.92s/it]

Converged after 25 iterations.


 26%|██▌       | 65/254 [03:51<15:01,  4.77s/it]

Converged after 15 iterations.


 26%|██▌       | 66/254 [03:54<13:02,  4.16s/it]

Converged after 21 iterations.


 26%|██▋       | 67/254 [03:55<09:50,  3.16s/it]

Converged after 15 iterations.


 27%|██▋       | 68/254 [03:56<08:16,  2.67s/it]

Converged after 15 iterations.


 27%|██▋       | 69/254 [03:58<07:02,  2.29s/it]

Converged after 18 iterations.


 28%|██▊       | 70/254 [04:00<06:39,  2.17s/it]

Converged after 16 iterations.


 28%|██▊       | 71/254 [04:01<05:34,  1.83s/it]

Converged after 16 iterations.


 28%|██▊       | 72/254 [04:03<06:21,  2.10s/it]

Converged after 17 iterations.


 29%|██▊       | 73/254 [04:08<08:48,  2.92s/it]

Converged after 24 iterations.


 29%|██▉       | 74/254 [04:10<08:01,  2.67s/it]

Converged after 22 iterations.


 30%|██▉       | 75/254 [04:13<07:52,  2.64s/it]

Converged after 18 iterations.


 30%|███       | 77/254 [04:20<09:38,  3.27s/it]

Converged after 22 iterations.
Converged after 32 iterations.


 31%|███       | 79/254 [07:55<2:19:34, 47.86s/it]

Converged after 23 iterations.


 31%|███▏      | 80/254 [08:05<1:45:26, 36.36s/it]

Converged after 18 iterations.


 32%|███▏      | 81/254 [08:07<1:15:02, 26.02s/it]

Converged after 15 iterations.


 32%|███▏      | 82/254 [08:08<53:05, 18.52s/it]  

Converged after 67 iterations.


 33%|███▎      | 83/254 [08:12<40:50, 14.33s/it]

Converged after 37 iterations.


 33%|███▎      | 84/254 [08:18<33:23, 11.79s/it]

Converged after 16 iterations.


 33%|███▎      | 85/254 [08:20<24:49,  8.81s/it]

Converged after 16 iterations.


 34%|███▍      | 86/254 [08:24<20:25,  7.29s/it]

Converged after 19 iterations.


 34%|███▍      | 87/254 [08:25<15:09,  5.44s/it]

Converged after 17 iterations.


 35%|███▍      | 88/254 [08:42<24:38,  8.91s/it]

Converged after 18 iterations.


 35%|███▌      | 89/254 [08:43<18:12,  6.62s/it]

Converged after 33 iterations.


 35%|███▌      | 90/254 [08:44<13:14,  4.85s/it]

Converged after 33 iterations.


 36%|███▌      | 91/254 [08:45<09:40,  3.56s/it]

Converged after 15 iterations.


 36%|███▌      | 92/254 [08:45<07:24,  2.74s/it]

Converged after 30 iterations.


 37%|███▋      | 93/254 [08:47<06:18,  2.35s/it]

Converged after 21 iterations.


 37%|███▋      | 94/254 [08:48<05:39,  2.12s/it]

Converged after 25 iterations.


 38%|███▊      | 97/254 [08:53<04:43,  1.81s/it]

Converged after 23 iterations.


 39%|███▉      | 99/254 [08:55<04:17,  1.66s/it]

Converged after 16 iterations.


 39%|███▉      | 100/254 [08:57<04:22,  1.70s/it]

Converged after 15 iterations.


 40%|███▉      | 101/254 [08:58<03:50,  1.50s/it]

Converged after 17 iterations.


 40%|████      | 102/254 [09:02<05:10,  2.04s/it]

Converged after 25 iterations.


 41%|████      | 103/254 [09:03<05:00,  1.99s/it]

Converged after 17 iterations.


 41%|████▏     | 105/254 [09:07<04:38,  1.87s/it]

Converged after 17 iterations.


 42%|████▏     | 106/254 [09:09<04:49,  1.96s/it]

Converged after 16 iterations.


 42%|████▏     | 107/254 [09:12<05:31,  2.26s/it]

Converged after 17 iterations.
Converged after 27 iterations.


 43%|████▎     | 109/254 [09:53<23:51,  9.87s/it]

Converged after 19 iterations.


 44%|████▍     | 112/254 [09:55<09:21,  3.95s/it]

Converged after 15 iterations.


 44%|████▍     | 113/254 [09:57<07:17,  3.10s/it]

Converged after 15 iterations.


 45%|████▌     | 115/254 [10:03<07:36,  3.28s/it]

Converged after 21 iterations.


 46%|████▌     | 117/254 [10:04<04:43,  2.07s/it]

Converged after 16 iterations.


 47%|████▋     | 119/254 [10:06<02:56,  1.31s/it]

Converged after 15 iterations.
Converged after 34 iterations.


 48%|████▊     | 121/254 [10:47<21:41,  9.79s/it]

Converged after 19 iterations.


 48%|████▊     | 122/254 [10:50<16:34,  7.54s/it]

Converged after 19 iterations.


 48%|████▊     | 123/254 [10:51<12:13,  5.60s/it]

Converged after 15 iterations.


 49%|████▉     | 124/254 [10:53<10:12,  4.71s/it]

Converged after 16 iterations.


 49%|████▉     | 125/254 [10:55<07:59,  3.71s/it]

Converged after 21 iterations.


 50%|████▉     | 126/254 [10:56<06:12,  2.91s/it]

Converged after 15 iterations.


 50%|█████     | 127/254 [10:57<05:01,  2.38s/it]

Converged after 17 iterations.


 50%|█████     | 128/254 [11:11<12:17,  5.86s/it]

Converged after 22 iterations.


 51%|█████     | 130/254 [11:12<06:32,  3.17s/it]

Converged after 15 iterations.


 52%|█████▏    | 131/254 [11:15<06:05,  2.97s/it]

Converged after 17 iterations.


 52%|█████▏    | 132/254 [11:25<10:38,  5.23s/it]

Converged after 171 iterations.


 52%|█████▏    | 133/254 [11:39<15:47,  7.83s/it]

Converged after 17 iterations.


 53%|█████▎    | 134/254 [11:54<19:50,  9.92s/it]

Converged after 25 iterations.


 53%|█████▎    | 135/254 [12:02<18:48,  9.48s/it]

Converged after 17 iterations.


 54%|█████▎    | 136/254 [12:03<13:31,  6.87s/it]

Converged after 69 iterations.


 54%|█████▍    | 137/254 [12:06<11:17,  5.79s/it]

Converged after 69 iterations.


 54%|█████▍    | 138/254 [12:09<09:15,  4.79s/it]

Converged after 15 iterations.


 55%|█████▌    | 140/254 [12:12<06:00,  3.16s/it]

Converged after 18 iterations.


 56%|█████▌    | 142/254 [12:18<06:15,  3.35s/it]

Converged after 19 iterations.


 57%|█████▋    | 144/254 [12:21<04:35,  2.50s/it]

Converged after 22 iterations.


 57%|█████▋    | 146/254 [12:25<04:30,  2.51s/it]

Converged after 20 iterations.


 58%|█████▊    | 147/254 [12:28<04:22,  2.45s/it]

Converged after 19 iterations.


 58%|█████▊    | 148/254 [12:31<04:49,  2.73s/it]

Converged after 17 iterations.


 59%|█████▊    | 149/254 [12:38<06:55,  3.95s/it]

Converged after 19 iterations.


 59%|█████▉    | 150/254 [12:39<05:32,  3.20s/it]

Converged after 17 iterations.


 59%|█████▉    | 151/254 [12:41<04:56,  2.88s/it]

Converged after 18 iterations.


 60%|█████▉    | 152/254 [12:44<04:32,  2.67s/it]

Converged after 16 iterations.


 60%|██████    | 153/254 [12:45<04:00,  2.39s/it]

Converged after 17 iterations.


 61%|██████    | 154/254 [12:58<09:03,  5.44s/it]

Converged after 20 iterations.


 61%|██████    | 155/254 [13:02<08:31,  5.16s/it]

Converged after 32 iterations.


 61%|██████▏   | 156/254 [13:04<06:38,  4.07s/it]

Converged after 20 iterations.


 62%|██████▏   | 157/254 [13:05<05:05,  3.15s/it]

Converged after 19 iterations.


 62%|██████▏   | 158/254 [13:06<03:55,  2.45s/it]

Converged after 15 iterations.


 63%|██████▎   | 159/254 [13:10<04:52,  3.08s/it]

Converged after 20 iterations.


 63%|██████▎   | 161/254 [13:12<03:08,  2.03s/it]

Converged after 16 iterations.


 64%|██████▍   | 163/254 [13:14<02:22,  1.56s/it]

Converged after 15 iterations.


 65%|██████▍   | 164/254 [13:16<02:15,  1.50s/it]

Converged after 15 iterations.


 65%|██████▍   | 165/254 [13:21<03:52,  2.62s/it]

Converged after 18 iterations.


 65%|██████▌   | 166/254 [13:22<03:05,  2.11s/it]

Converged after 15 iterations.


 66%|██████▌   | 167/254 [13:26<03:53,  2.68s/it]

Converged after 40 iterations.


 66%|██████▌   | 168/254 [13:28<03:35,  2.51s/it]

Converged after 16 iterations.


 67%|██████▋   | 170/254 [13:30<02:27,  1.76s/it]

Converged after 21 iterations.


 67%|██████▋   | 171/254 [13:31<02:13,  1.61s/it]

Converged after 17 iterations.


 68%|██████▊   | 172/254 [13:47<07:52,  5.77s/it]

Converged after 19 iterations.


 68%|██████▊   | 173/254 [13:48<06:02,  4.48s/it]

Converged after 16 iterations.


 69%|██████▊   | 174/254 [13:50<04:50,  3.64s/it]

Converged after 15 iterations.


 69%|██████▉   | 175/254 [13:51<03:41,  2.81s/it]

Converged after 15 iterations.


 70%|███████   | 178/254 [13:58<03:55,  3.09s/it]

Converged after 19 iterations.


 70%|███████   | 179/254 [13:59<03:12,  2.57s/it]

Converged after 19 iterations.


 71%|███████   | 180/254 [14:07<05:07,  4.15s/it]

Converged after 17 iterations.


 72%|███████▏  | 182/254 [14:09<02:53,  2.42s/it]

Converged after 15 iterations.


 72%|███████▏  | 183/254 [14:09<02:12,  1.86s/it]

Converged after 37 iterations.


 73%|███████▎  | 186/254 [14:50<08:06,  7.15s/it]

Converged after 30 iterations.


 74%|███████▍  | 188/254 [15:06<08:54,  8.10s/it]

Converged after 22 iterations.


 75%|███████▍  | 190/254 [15:14<06:57,  6.52s/it]

Converged after 17 iterations.


 75%|███████▌  | 191/254 [15:16<05:29,  5.22s/it]

Did not converge
Converged after 34 iterations.


 76%|███████▌  | 193/254 [15:56<11:03, 10.88s/it]

Converged after 18 iterations.


 78%|███████▊  | 197/254 [16:03<04:28,  4.70s/it]

Converged after 20 iterations.


 78%|███████▊  | 198/254 [16:06<03:52,  4.16s/it]

Converged after 17 iterations.


 79%|███████▊  | 200/254 [16:12<03:13,  3.58s/it]

Converged after 24 iterations.


 79%|███████▉  | 201/254 [16:12<02:20,  2.66s/it]

Converged after 15 iterations.


 80%|███████▉  | 202/254 [16:13<01:49,  2.10s/it]

Converged after 16 iterations.


 80%|███████▉  | 203/254 [16:17<02:23,  2.82s/it]

Converged after 15 iterations.


 80%|████████  | 204/254 [16:58<11:41, 14.03s/it]

Converged after 22 iterations.


 81%|████████  | 206/254 [17:02<06:32,  8.17s/it]

Converged after 18 iterations.


 81%|████████▏ | 207/254 [17:07<05:29,  7.02s/it]

Converged after 20 iterations.


 82%|████████▏ | 208/254 [17:08<03:59,  5.21s/it]

Converged after 16 iterations.


 82%|████████▏ | 209/254 [17:10<03:23,  4.52s/it]

Converged after 19 iterations.


 83%|████████▎ | 210/254 [17:16<03:37,  4.95s/it]

Converged after 17 iterations.


 83%|████████▎ | 211/254 [17:18<02:48,  3.91s/it]

Converged after 18 iterations.


 83%|████████▎ | 212/254 [17:23<02:56,  4.21s/it]

Converged after 20 iterations.


 84%|████████▍ | 213/254 [17:26<02:37,  3.83s/it]

Converged after 20 iterations.


 84%|████████▍ | 214/254 [17:28<02:16,  3.41s/it]

Converged after 15 iterations.


 85%|████████▍ | 215/254 [17:31<02:08,  3.29s/it]

Converged after 17 iterations.


 85%|████████▌ | 217/254 [17:33<01:17,  2.11s/it]

Converged after 15 iterations.


 86%|████████▌ | 218/254 [17:35<01:09,  1.93s/it]

Converged after 17 iterations.


 87%|████████▋ | 220/254 [17:37<00:53,  1.58s/it]

Converged after 18 iterations.


 87%|████████▋ | 221/254 [17:38<00:46,  1.42s/it]

Converged after 16 iterations.


 88%|████████▊ | 223/254 [17:48<01:49,  3.52s/it]

Converged after 29 iterations.


 88%|████████▊ | 224/254 [17:49<01:22,  2.75s/it]

Converged after 22 iterations.


 89%|████████▉ | 226/254 [18:40<06:48, 14.58s/it]

Did not converge


 89%|████████▉ | 227/254 [18:44<05:08, 11.43s/it]

Converged after 16 iterations.


 90%|████████▉ | 228/254 [18:45<03:35,  8.28s/it]

Converged after 20 iterations.


 90%|█████████ | 229/254 [18:46<02:32,  6.11s/it]

Converged after 16 iterations.


 91%|█████████ | 230/254 [18:47<01:51,  4.66s/it]

Converged after 15 iterations.


 91%|█████████ | 231/254 [18:57<02:22,  6.18s/it]

Converged after 21 iterations.


 91%|█████████▏| 232/254 [19:10<02:59,  8.18s/it]

Converged after 157 iterations.


 92%|█████████▏| 233/254 [19:11<02:09,  6.16s/it]

Converged after 16 iterations.


 92%|█████████▏| 234/254 [19:18<02:06,  6.32s/it]

Converged after 21 iterations.


 93%|█████████▎| 235/254 [19:19<01:31,  4.80s/it]

Converged after 18 iterations.


 93%|█████████▎| 236/254 [19:21<01:11,  3.99s/it]

Converged after 20 iterations.


 93%|█████████▎| 237/254 [19:22<00:50,  2.98s/it]

Converged after 15 iterations.


 94%|█████████▎| 238/254 [19:30<01:12,  4.54s/it]

Converged after 22 iterations.


 94%|█████████▍| 239/254 [19:33<01:00,  4.02s/it]

Converged after 18 iterations.


 94%|█████████▍| 240/254 [19:40<01:09,  4.96s/it]

Converged after 18 iterations.


 95%|█████████▌| 242/254 [19:45<00:46,  3.84s/it]

Converged after 15 iterations.


 96%|█████████▌| 243/254 [19:46<00:33,  3.02s/it]

Converged after 17 iterations.


 96%|█████████▌| 244/254 [19:49<00:27,  2.80s/it]

Converged after 17 iterations.


 96%|█████████▋| 245/254 [19:50<00:22,  2.46s/it]

Converged after 27 iterations.


 97%|█████████▋| 247/254 [19:52<00:11,  1.63s/it]

Converged after 15 iterations.


 98%|█████████▊| 248/254 [19:59<00:19,  3.19s/it]

Converged after 30 iterations.


 98%|█████████▊| 249/254 [20:00<00:13,  2.69s/it]

Converged after 15 iterations.


 98%|█████████▊| 250/254 [20:02<00:09,  2.41s/it]

Converged after 15 iterations.


 99%|█████████▉| 251/254 [20:03<00:05,  1.98s/it]

Converged after 15 iterations.


 99%|█████████▉| 252/254 [20:04<00:03,  1.63s/it]

Converged after 15 iterations.


100%|██████████| 254/254 [20:14<00:00,  4.78s/it]

Converged after 18 iterations.





## Evaluation

In [47]:
### evaluation codes
import logging
import numpy as np
import pandas as pd
from sklearn.metrics import adjusted_rand_score, f1_score
from tqdm import tqdm

# Scores the predictions file against the input file: per target word it computes the
# Adjusted Rand Index over all "new" usages and a macro-F1 over the "new" usages whose
# gold sense is one of the word's "old" senses, then averages both over the words.
# NOTE(review): logging.basicConfig() was already called in an earlier cell; without
# force=True a second call is documented to do nothing once the root logger has
# handlers, so this format string may never take effect — confirm.
logging.basicConfig(
    format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

# The sense_id column of the input file serves as gold; the output file holds the predictions
gold_data = pd.read_csv(input_file, sep="\t")
predictions = pd.read_csv(output_file, sep="\t")

# Both files must contain the same rows, with the "new" examples in the same order
assert len(gold_data) == len(predictions)
assert (
        gold_data[gold_data.period == "new"].example.tolist()
        == predictions[predictions.period == "new"].example.tolist()
)

logger.info(f"Data loaded from {input_file} and {output_file}")
logger.info(f"{len(gold_data)} example usages")
logger.info("Computing Adjusted Rand Index and F1 for predicted senses...")

ari_scores = []
f1_scores = []

for targetword in tqdm(gold_data.word.unique()):
    # ARI compares the gold and predicted groupings of this word's "new" usages
    gold_senses = gold_data[
        (gold_data.word == targetword) & (gold_data.period == "new")
        ].sense_id.values
    pred_senses = predictions[
        (predictions.word == targetword) & (predictions.period == "new")
        ].sense_id.values
    ari = adjusted_rand_score(gold_senses, pred_senses)
    ari_scores.append(ari)
    logger.debug(f"ARI for {targetword}: {ari}")
    # NOTE: this rebinds the notebook-level `old_senses` (a list in the alignment cell) to a set
    old_senses = set(
        gold_data[(gold_data.word == targetword) & (gold_data.period == "old")]
            .sense_id.unique()
            .tolist()
    )
    if len(old_senses) == 0:
        logger.info(f"Not computing F1 for {targetword}: no old senses")
        continue
    # F1 is restricted to "new" usages whose gold sense already existed in the old period
    test_usages = gold_data[
        (gold_data.word == targetword)
        & (gold_data.period == "new")
        & (gold_data.sense_id.isin(old_senses))
        ]
    test_usages_ids = set(test_usages.usage_id.tolist())
    if len(test_usages) == 0:
        # No gold usage carries an old sense: score 1.0 if the system predicted none either, else 0.0
        test_usages_predicted = predictions[
            (predictions.word == targetword)
            & (predictions.period == "new")
            & (predictions.sense_id.isin(old_senses))
            ]
        if len(test_usages_predicted) == 0:
            f1_scores.append(1.0)
            logger.info(
                f"Macro F1 set to 1.0 for {targetword}: "
                f"no new usages with old senses, and none predicted"
            )
        else:
            f1_scores.append(0.0)
            logger.info(
                f"Macro F1 set to 0.0 for {targetword}: "
                f"old senses predicted when there are none"
            )
        continue
    test_usages_gold_senses = test_usages.sense_id.tolist()
    # Predictions are selected by usage id across the whole predictions table
    test_usages_predictions = predictions[predictions.usage_id.isin(test_usages_ids)]
    test_usages_predicted_senses = test_usages_predictions.sense_id.tolist()
    assert len(test_usages_gold_senses) == len(test_usages_predicted_senses)
    # Any predicted sense that is not an old sense of this word is collapsed into one "novel" class
    test_usages_predicted_senses = [
        "novel" if el not in old_senses else el
        for el in test_usages_predicted_senses
    ]
    f1 = f1_score(
        test_usages_gold_senses,
        test_usages_predicted_senses,
        average="macro",
        zero_division=0.0,
    )
    logger.debug(f"Macro F1 for {targetword}: {f1}")
    f1_scores.append(f1)

# Unweighted means over target words; words without old senses contribute to ARI only
average_ari = np.mean(ari_scores)
logger.info(
    f"Average ARI across {len(ari_scores)} target words: {average_ari:0.3f}"
)
average_f1 = np.mean(f1_scores)
logger.info(
    f"Average macro-F1 across {len(f1_scores)} target words: {average_f1:0.3f}"
)


print(f"ARI: {average_ari:0.3f}")
print(f"F1: {average_f1:0.3f}")

# with open(result_file, "w") as out:

#     print(f"ARI: {average_ari:0.3f}", file=out)
#     print(f"F1: {average_f1:0.3f}", file=out)


100%|██████████| 254/254 [00:03<00:00, 80.35it/s]

ARI: 0.571
F1: 0.670





In [None]:


# def get_sentence_embeddings(outputs: ModelOutput, no_pooling: bool) -> torch.Tensor:
#   if no_pooling:
#       return outputs.last_hidden_state[:, 0, :]
#   return outputs.pooler_output
# from transformers import (
#     AutoModel,
#     AutoTokenizer,
#     BatchEncoding,
#     PreTrainedModel,
#     PreTrainedTokenizerFast,
# )
# from transformers.utils import ModelOutput

# def find_target_id(example: str, target: str, orth: str) -> int:
#     idx = example.lower().find(target)
#     if idx == -1:
#         idx = example.lower().find(orth)
#     return max(0, idx)

# def find_token_id(inputs: BatchEncoding, batch_id: int, target_id: int) -> int:
#     token_id = inputs.char_to_token(batch_id, target_id)
#     if token_id is None:
#         return 1
#     return token_id

# def get_target_embeddings(
#     inputs: BatchEncoding, outputs: ModelOutput, examples: list, target: str, orths: list
# ) -> torch.Tensor:
#     # Find the starting index of the target word in each example sentence; use 0 as a fallback
#     target_ids = [find_target_id(example, target, orth) for example, orth in zip(examples, orths)]
#     # Get the index of the 1st subtoken of the target word
#     token_ids = [find_token_id(inputs, i, j) for i, j in enumerate(target_ids)]
#     # Get the last hidden layer of this subtoken (or the baseline CLS token as a fallback)
#     return torch.stack([outputs.last_hidden_state[i, j, :] for i, j in enumerate(token_ids)])