In [1]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings('ignore')

import yaml
from hydra.experimental import compose, initialize_config_module
import hydra
import torch
from tqdm import tqdm
import json
import faiss
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, List, Dict, Any, Tuple

from bela.transforms.spm_transform import SPMTransform
from bela.evaluation.model_eval import ModelEval, load_file
from bela.utils.prediction_utils import get_predictions_using_windows

logger = logging.getLogger(__name__)
!cat /etc/hostname  # Double check that we are on a gpu node

a100-st-p4d24xlarge-18


In [2]:
@dataclass
class GroundTruthEntity:
    """A gold annotation: a character span of ``text`` linked to ``entity_id``."""

    offset: int  # index in ``text`` where the mention starts
    length: int  # number of characters covered by the mention
    text: str  # full text that ``offset``/``length`` index into
    entity_id: str  # identifier of the entity the mention is linked to

    @property
    def mention(self):
        """Return the slice of ``text`` covered by this annotation."""
        start = self.offset
        stop = start + self.length
        return self.text[start:stop]

    def __repr__(self):
        """Show the mention surface form next to its entity id."""
        return f"mention=\"{self.mention}\" -> entity_id={self.entity_id}"


@dataclass
class PredictedEntity:
    """A model prediction: a character span of ``text``, its entity id and two scores."""

    offset: int  # index in ``text`` where the predicted mention starts
    length: int  # number of characters covered by the predicted mention
    text: str  # full text that ``offset``/``length`` index into
    entity_id: str  # identifier of the predicted entity
    md_score: float  # presumably the mention-detection score -- confirm against the model
    el_score: float  # presumably the entity-linking score -- confirm against the model

    @property
    def mention(self):
        """Return the slice of ``text`` covered by this prediction."""
        start = self.offset
        stop = start + self.length
        return self.text[start:stop]

    def __repr__(self):
        """Show the mention, its entity id and both scores rounded to two decimals."""
        return f"mention=\"{self.mention}\" -> entity_id={self.entity_id} (md_score={self.md_score:.2f}, el_score={self.el_score:.2f})"


def print_sample(text: str, ground_truth_entities: List[GroundTruthEntity], predicted_entities: List[PredictedEntity], max_display_length=1000):
    """Print a text preview followed by its ground-truth and predicted entities.

    Args:
        text: Text to preview; only the first ``max_display_length`` characters are shown.
        ground_truth_entities: Gold entities to list.
        predicted_entities: Predicted entities to list.
        max_display_length: Preview length in characters. Entities whose span ends
            after this position are not printed, but they still count in the
            printed totals.
    """

    def ends_within_preview(entity):
        # An entity is listed only if its whole span lies inside the preview.
        return entity.offset + entity.length <= max_display_length

    print(f"{text[:max_display_length]=}")

    print("***************** Ground truth entities *****************")
    print(f"{len(ground_truth_entities)=}")
    for visible_entity in filter(ends_within_preview, ground_truth_entities):
        print(visible_entity)

    print("***************** Predicted entities *****************")
    print(f"{len(predicted_entities)=}")
    for visible_entity in filter(ends_within_preview, predicted_entities):
        print(visible_entity)



# Load model

In [3]:
%%prun -s tottime -l 10

# Profile model loading: the report is sorted by internal time and limited to 10 rows.
# Only the LAST `checkpoint_path` assignment below takes effect; the earlier ones are
# kept as a record of the checkpoints that were tried and what happened with each.

# e2e model with isotropic embeddings
checkpoint_path = '/checkpoints/movb/bela/2023-01-13-023711/0/lightning_logs/version_4144/checkpoints/last.ckpt'  # Not working: Unexpected key(s) in state_dict: "saliency_encoder.mlp.0.weight", "saliency_encoder.mlp.0.bias", "saliency_encoder.mlp.3.weight", "saliency_encoder.mlp.3.bias", "saliency_encoder.mlp.6.weight", "saliency_encoder.mlp.6.bias".
checkpoint_path = '/checkpoints/movb/bela/2022-11-27-225013/0/lightning_logs/version_286287/checkpoints/last_15000.ckpt'  # Works but gives 0 F1

# E2E checkpoint with new embeddings
# https://fb.quip.com/QVUxA4UcAZ7k#temp:C:OcG977f71fab43d42379521a0dff
# Works but gives 0 F1 on TACKBP (mention detection is okay-ish, but entity disambiguation is random)
checkpoint_path = '/checkpoints/movb/bela/2023-01-18-220105/0/lightning_logs/version_4820/checkpoints/last.ckpt'
# Alternative config that was tried previously:
#model_eval = ModelEval(checkpoint_path, config_name="joint_el_mel_new")
model_eval = ModelEval(checkpoint_path, config_name="joint_el_mel_new_index")

Some weights of the model checkpoint at xlm-roberta-large were not used when initializing XLMRobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


 

         170161569 function calls (170084442 primitive calls) in 110.106 seconds

   Ordered by: internal time
   List reduced from 6805 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 44771119   22.447    0.000   22.447    0.000 {method 'readline' of 'mmap.mmap' objects}
        3   22.347    7.449   50.189   16.730 joint_el_datamodule.py:47(__init__)
       67   12.799    0.191   12.802    0.191 serialization.py:994(load_tensor)
        1   12.300   12.300   12.300   12.300 {built-in method faiss._swigfaiss.GpuIndexFlat_add}
        1    9.243    9.243   10.246   10.246 joint_el_datamodule.py:21(__init__)
        1    4.084    4.084    4.084    4.084 {built-in method faiss._swigfaiss.new_GpuIndexFlatIP}
        1    3.742    3.742  110.110  110.110 <string>:2(<module>)
 61313958    3.383    0.000    3.383    0.000 {method 'append' of 'list' objects}
 44771116    2.724    0.000    2.724    0.000 {method 'tell' of 'mmap.mmap' obje

In [6]:
# Set low thresholds on the task for both scores (md_* and el_*).
# NOTE(review): presumably this keeps almost every candidate mention/link so that
# low-scoring predictions can be inspected -- confirm how the task applies them.
model_eval.task.md_threshold = 0.01
model_eval.task.el_threshold = 0.01

# End-to-end Eval

In [11]:
# Record which checkpoint and thresholds produced the scores printed below.
print(f"{model_eval.checkpoint_path=}")
print(f"{model_eval.task.md_threshold=}")
print(f"{model_eval.task.el_threshold=}")
# Datasets to evaluate in this run; commented-out entries are skipped.
datasets = [
    "/fsx/louismartin/bela/retrieved_from_aws_backup/ndecao/TACKBP2015/train_bela_format.jsonl",
    #'/fsx/movb/data/matcha/mewsli-9/ta.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/ar.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/en.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/fa.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/sr.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/tr.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/de.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/es.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/ja.jsonl',
]
for test_data_path in datasets:
    print(f"Processing {test_data_path}")
    test_data = load_file(test_data_path)
    # Uncomment to evaluate on the first 1000 examples only.
    #test_data = test_data[:1000]

    predictions = get_predictions_using_windows(model_eval, test_data)
    # compute_scores returns two (f1, precision, recall) triples; the second one is
    # reported as "boe" below (presumably bag-of-entities -- confirm in ModelEval).
    (f1, precision, recall), (f1_boe, precision_boe, recall_boe) = ModelEval.compute_scores(test_data, predictions)

    print(f"F1 = {f1:.4f}, precision = {precision:.4f}, recall = {recall:.4f}")
    print(f"F1 boe = {f1_boe:.4f}, precision = {precision_boe:.4f}, recall = {recall_boe:.4f}")

model_eval.checkpoint_path='/checkpoints/movb/bela/2023-01-18-220105/0/lightning_logs/version_4820/checkpoints/last.ckpt'
model_eval.task.md_threshold=0.01
model_eval.task.el_threshold=0.01
Processing /fsx/louismartin/bela/retrieved_from_aws_backup/ndecao/TACKBP2015/train_bela_format.jsonl


269it [00:00, 9825.29it/s]
 50%|████████████████████████████████                                | 2/4 [01:01<01:01, 30.51s/it]

# Disambiguation Eval

In [None]:

def shift_shift(text):
    """Return the index of the first non-alphabetic character in ``text``.

    The caller adds the result to a cut position so that a truncated text starts
    at a non-alphabetic character instead of in the middle of a word.

    Args:
        text: String to scan from its beginning.

    Returns:
        The index of the first character for which ``str.isalpha()`` is false,
        or 0 when there is none (including when ``text`` is empty). Returning 0
        instead of falling through to ``None`` keeps ``shift += shift_shift(...)``
        valid in the caller and simply leaves the cut position unchanged.
    """
    return next((idx for idx, ch in enumerate(text) if not ch.isalpha()), 0)

def convert_data_for_disambiguation(data, lang):
    """Split examples into one-entity examples, left-truncating long texts.

    For every entry of ``example['gt_entities']`` a new example is produced that
    contains only that entity. When the text is longer than the language's length
    limit and the entity starts after the language's offset limit, the beginning
    of the text is cut so that the entity starts at most ``max_offset`` characters
    into the new text, and the entity offset is adjusted by the same amount.

    Args:
        data: Iterable of dicts with keys ``'original_text'`` (str) and
            ``'gt_entities'`` (list of 6-item sequences whose 3rd, 5th and 6th
            items are used here as entity, offset and length).
        lang: Language code; ``'ar'`` and ``'ja'`` use tighter limits than the
            default applied to every other value.

    Returns:
        A list of dicts with keys ``'original_text'`` and ``'gt_entities'``, where
        ``'gt_entities'`` holds exactly one entry
        ``[0, 0, entity, <4th original item>, shifted_offset, length]``.
    """
    # (max text length, max entity offset) per language; other languages use the default.
    limits_by_lang = {
        'ar': (600, 400),
        'ja': (350, 250),
    }
    max_length, max_offset = limits_by_lang.get(lang, (800, 600))

    new_examples = []
    for example in tqdm(data):
        original_text = example['original_text']
        for _, _, ent, fourth_item, offset, length in example['gt_entities']:
            shift = 0
            if len(original_text) > max_length and offset > max_offset:
                shift = offset - max_offset
                # Move the cut forward to the next non-alphabetic character so the
                # text does not start mid-word.
                boundary = shift_shift(original_text[shift:])
                # Skip the adjustment when no boundary was found, or when it lies
                # past the entity start: applying it would cut into the mention
                # and make the shifted offset negative.
                if boundary is not None and shift + boundary <= offset:
                    shift += boundary
            new_examples.append({
                'original_text': original_text[shift:],
                'gt_entities': [[0, 0, ent, fourth_item, offset - shift, length]],
            })
    return new_examples


def metrics_disambiguation(test_data, predictions):
    """Compute disambiguation accuracy over examples that received a prediction.

    Each example is expected to hold a single ground-truth entity (the third item
    of ``example['gt_entities'][0]``), which is compared with the first entry of
    ``prediction['entities']``. Examples with no predicted entity are skipped and
    do not count towards ``support``.

    Args:
        test_data: Iterable of example dicts with a ``'gt_entities'`` list.
        predictions: Iterable of prediction dicts with an ``'entities'`` list,
            aligned with ``test_data``.

    Returns:
        ``(accuracy, support)`` where ``support`` is the number of examples that
        had at least one predicted entity and ``accuracy`` is the fraction of
        those whose first predicted entity equals the target. ``accuracy`` is
        ``0.0`` when ``support`` is 0, instead of raising ``ZeroDivisionError``.
    """
    support = 0
    correct = 0

    for example, prediction in tqdm(zip(test_data, predictions)):
        predicted_entities = prediction['entities']
        if len(predicted_entities) == 0:
            continue
        target = example['gt_entities'][0][2]
        top_prediction = predicted_entities[0]
        correct += (target == top_prediction)
        support += 1

    # Guard the empty case: no example produced a prediction.
    accuracy = correct / support if support else 0.0

    return accuracy, support

In [None]:
# Mewsli-9 files, one per language; the file name (without extension) is the language code.
datasets = [
    '/fsx/movb/data/matcha/mewsli-9/ta.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/ar.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/en.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/fa.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/sr.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/tr.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/de.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/es.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/ja.jsonl',
]
for test_data_path in datasets:
    print(f"Processing {test_data_path}")
    # Derive the language code from the file name rather than from fixed slice
    # positions, so other extensions or code lengths keep working.
    lang = test_data_path.rsplit('/', 1)[-1].split('.', 1)[0]
    test_data = load_file(test_data_path)
    # Only the first 10000 documents are converted, to bound the run time.
    test_data = convert_data_for_disambiguation(test_data[:10000], lang)
    predictions = model_eval.get_disambiguation_predictions(test_data)
    accuracy, support = metrics_disambiguation(test_data, predictions)
    print(f"Accuracy {accuracy}, support {support}")

# Inference time

In [6]:
def batch_samples(samples, batch_size):
    """Lazily yield consecutive slices of ``samples`` with at most ``batch_size`` items.

    The final slice is shorter when ``len(samples)`` is not a multiple of
    ``batch_size``; an empty ``samples`` yields nothing.
    """
    batch_starts = range(0, len(samples), batch_size)
    yield from (samples[start : start + batch_size] for start in batch_starts)


# Measure raw batched inference time over the English Mewsli-9 texts.
texts = [sample["original_text"] for sample in load_file("/fsx/movb/data/matcha/mewsli-9/en.jsonl")]
batch_size = 1024
print(f"Processing {len(texts)} texts, {batch_size=}, {model_eval.transform.max_seq_len=}")
# %time reports CPU and wall time for this single statement; the results are
# bound to `_` because only the timing matters here.
%time _ = [model_eval.process_batch(batch_texts) for batch_texts in tqdm(batch_samples(texts, batch_size), desc="Inference")]

12679it [00:00, 70671.85it/s]


Processing 12679 texts, batch_size=1024, model_eval.transform.max_seq_len=256


13it [04:00, 18.53s/it]

CPU times: user 11min 50s, sys: 35.3 s, total: 12min 25s
Wall time: 4min





## Eyeball Samples

In [8]:
# Run the model on one hand-written sentence and list every predicted entity.
text = "Taylor Swift lives in New York City. New York City is a city in the United States."
prediction = model_eval.process_batch([text])[0]
# The prediction stores one parallel list per field; zip them into entity objects.
predicted_entities = [
    PredictedEntity(offset, length, text, entity_id, md_score, el_score)
    for offset, length, entity_id, md_score, el_score in zip(
        prediction["offsets"],
        prediction["lengths"],
        prediction["entities"],
        prediction["md_scores"],
        prediction["el_scores"],
    )
]
print_sample(text, [], predicted_entities)

text[:max_display_length]='Taylor Swift lives in New York City. New York City is a city in the United States.'
***************** Ground truth entities *****************
len(ground_truth_entities)=0
***************** Predicted entities *****************
len(predicted_entities)=19
mention="Taylor" -> entity_id=Q26876 (md_score=0.01, el_score=0.03)
mention="Taylor Swift lives" -> entity_id=Q26876 (md_score=0.58, el_score=0.78)
mention="Taylor Swift lives in New York City. New" -> entity_id=Q26876 (md_score=0.04, el_score=0.53)
mention="Swift lives" -> entity_id=Q26876 (md_score=0.02, el_score=0.06)
mention="lives" -> entity_id=Q26876 (md_score=0.03, el_score=0.01)
mention="New York City. New" -> entity_id=Q60 (md_score=0.13, el_score=0.19)
mention="York City" -> entity_id=Q60 (md_score=0.01, el_score=0.01)
mention="York City." -> entity_id=Q60 (md_score=0.04, el_score=0.03)
mention="York City. New" -> entity_id=Q60 (md_score=0.41, el_score=0.45)
mention="York City. New York City is" -> en

In [10]:

# Compare gold and predicted entities on a single TACKBP2015 document.
test_data_path = "/fsx/louismartin/bela/retrieved_from_aws_backup/ndecao/TACKBP2015/train_bela_format.jsonl"
print(f"Processing {test_data_path}")
test_data = load_file(test_data_path)
sample = test_data[200]
prediction = get_predictions_using_windows(model_eval, [sample])[0]
text = sample["original_text"]
max_length = 1024

# Gold entries are 6-item records; only entity id, offset and length are used here.
ground_truth_entities = [
    GroundTruthEntity(offset, length, text, entity_id)
    for _, _, entity_id, _, offset, length in sample["gt_entities"]
]
# The prediction stores one parallel list per field; zip them into entity objects.
predicted_entities = [
    PredictedEntity(offset, length, text, entity_id, md_score, el_score)
    for offset, length, entity_id, md_score, el_score in zip(
        prediction["offsets"],
        prediction["lengths"],
        prediction["entities"],
        prediction["md_scores"],
        prediction["el_scores"],
    )
]
print_sample(text, ground_truth_entities, predicted_entities, max_display_length=1000)

Processing /fsx/louismartin/bela/retrieved_from_aws_backup/ndecao/TACKBP2015/train_bela_format.jsonl


269it [00:00, 8626.63it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.87it/s]

text[:max_display_length]='<DOC id="ENG_NW_001048_20150228_F0000001S"> <SOURCE>http://www.telegraph.co.uk/news/worldnews/europe/russia/11441729/David-Cameron-says-callous-murder-of-Boris-Nemtsov-must-be-rapidly-investigated.html</SOURCE> <DATE_TIME>2015-02-28T00:00:00</DATE_TIME> <HEADLINE> David Cameron says \'callous murder\' of Boris Nemtsov must be rapidly investigated </HEADLINE> <TEXT> <P> Prime Minister says "callous murder" of Russian opposition politician "must be fully, rapidly and transparently investigated, and those responsible brought to justice" </P> <P> David Cameron has said he is "shocked and sickened" by the murder of Boris Nemtsov. </P> <P> The Prime Minister said the "callous" killing of the Russian opposition politician "must be fully, rapidly and transparently investigated, and those responsible brought to justice". </P> <P> Mr Nemtsov, a leading critic of president Vladimir Putin and of the war in Ukraine, was gunned down near the Kremlin on the eve of a major r




# Draft