In [1]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings('ignore')

import yaml
from hydra.experimental import compose, initialize_config_module
import hydra
import torch
from tqdm import tqdm
import json
import faiss
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, List, Dict, Any, Tuple

from bela.transforms.spm_transform import SPMTransform
from bela.evaluation.model_eval import ModelEval, load_file
from bela.utils.prediction_utils import get_predictions_using_windows

# Module-level logger; no cell visible in this notebook logs through it.
logger = logging.getLogger(__name__)
!cat /etc/hostname  # Double check that we are on a gpu node

a100-st-p4d24xlarge-91


In [15]:
@dataclass
class GroundTruthEntity:
    """A gold entity: the character span ``text[offset:offset + length]`` linked to ``entity_id``.

    Equality compares only ``offset``, ``length`` and ``entity_id`` (not ``text``), so a
    ground-truth entity compares equal to any object with the same span and id, e.g. a
    ``PredictedEntity``. Comparing with an object lacking those attributes yields ``False``
    instead of raising.
    """

    offset: int  # start of the mention, as a character index into `text`
    length: int  # number of characters in the mention
    text: str  # full text the span refers to
    entity_id: str

    @property
    def mention(self):
        """Surface form of the mention, sliced out of ``text``."""
        return self.text[self.offset : self.offset + self.length]

    def __repr__(self):
        return f"mention=\"{self.mention}\" -> entity_id={self.entity_id}"

    def __eq__(self, other):
        # Cross-type comparison (gold vs. predicted) is intended, so we duck-type on the
        # span/id attributes. Objects without them are not comparable: returning
        # NotImplemented lets Python fall back to its default (identity -> False) instead
        # of leaking an AttributeError from e.g. `entity == None`.
        try:
            return (
                self.offset == other.offset
                and self.length == other.length
                and self.entity_id == other.entity_id
            )
        except AttributeError:
            return NotImplemented


@dataclass
class PredictedEntity:
    """A predicted entity: span ``text[offset:offset + length]`` linked to ``entity_id``, with scores.

    ``md_score`` and ``el_score`` are the model's scores for this prediction (presumably
    mention-detection and entity-linking confidences, going by their names — confirm).
    Equality compares only ``offset``, ``length`` and ``entity_id`` (not text or scores),
    so a prediction compares equal to a ``GroundTruthEntity`` with the same span and id.
    Comparing with an object lacking those attributes yields ``False`` instead of raising.
    """

    offset: int  # start of the mention, as a character index into `text`
    length: int  # number of characters in the mention
    text: str  # full text the span refers to
    entity_id: str
    md_score: float
    el_score: float

    @property
    def mention(self):
        """Surface form of the mention, sliced out of ``text``."""
        return self.text[self.offset : self.offset + self.length]

    def __repr__(self):
        return f"mention=\"{self.mention}\" -> entity_id={self.entity_id} (md_score={self.md_score:.2f}, el_score={self.el_score:.2f})"

    def __eq__(self, other):
        # Duck-typed so predictions can be matched against gold entities; objects without
        # the span/id attributes are reported as not comparable (NotImplemented) rather
        # than raising AttributeError.
        try:
            return (
                self.offset == other.offset
                and self.length == other.length
                and self.entity_id == other.entity_id
            )
        except AttributeError:
            return NotImplemented




def print_sample(text: str, ground_truth_entities: List[GroundTruthEntity], predicted_entities: List[PredictedEntity], max_display_length=1000):
    """Print a truncated text followed by its gold and predicted entities.

    Only the first ``max_display_length`` characters of ``text`` are shown. Entities whose
    span ends after that point are left out of the listing, but the printed counts still
    include them.
    """
    def ends_within_display(entity):
        return entity.offset + entity.length <= max_display_length

    print("text[:max_display_length]=" + repr(text[:max_display_length]))
    sections = [
        ("Ground truth entities", "ground_truth_entities", ground_truth_entities),
        ("Predicted entities", "predicted_entities", predicted_entities),
    ]
    for title, list_name, entities in sections:
        print(f"***************** {title} *****************")
        print(f"len({list_name})={len(entities)}")
        for entity in entities:
            if ends_within_display(entity):
                print(entity)

# Load model

In [3]:
%%prun -s tottime -l 10

# e2e model with isotropic embeddings
checkpoint_path = '/checkpoints/movb/bela/2023-01-13-023711/0/lightning_logs/version_4144/checkpoints/last.ckpt'  # Not working: Unexpected key(s) in state_dict: "saliency_encoder.mlp.0.weight", "saliency_encoder.mlp.0.bias", "saliency_encoder.mlp.3.weight", "saliency_encoder.mlp.3.bias", "saliency_encoder.mlp.6.weight", "saliency_encoder.mlp.6.bias". 
checkpoint_path = '/checkpoints/movb/bela/2022-11-27-225013/0/lightning_logs/version_286287/checkpoints/last_15000.ckpt'  # Works but 0 F1

# E2E checkpoint with new embeddings
# https://fb.quip.com/QVUxA4UcAZ7k#temp:C:OcG977f71fab43d42379521a0dff
# Works but give 0 F1 on tackbp (mention detection is okish, but entity disambiguation is random)
checkpoint_path = '/checkpoints/movb/bela/2023-01-18-220105/0/lightning_logs/version_4820/checkpoints/last.ckpt'  
# Overfit on one sample from /fsx/louismartin/bela/data/debug_mention_detection/mel/train.1st.1_sample.txt
checkpoint_path = '/data/home/louismartin/dev/BELA/multirun/2023-02-14/08-45-13/0/lightning_logs/version_0/checkpoints/checkpoint_best.ckpt'
model_eval = ModelEval(checkpoint_path, config_name="joint_el_mel_new_index")

path='bela.transforms.joint_el_transform.JointELXlmrRawTextTransform', mod='bela.transforms.joint_el_transform'
path='bela.datamodule.joint_el_datamodule.JointELDataModule', mod='bela.datamodule'
path='bela.datamodule.joint_el_datamodule.JointELDataModule', mod='bela.datamodule.joint_el_datamodule'
path='bela.task.joint_el_task.JointELTask', mod='bela.task'
path='bela.task.joint_el_task.JointELTask', mod='bela.task.joint_el_task'
path='bela.models.hf_encoder.HFEncoder', mod='bela.models'
path='bela.models.hf_encoder.HFEncoder', mod='bela.models.hf_encoder'


Some weights of the model checkpoint at xlm-roberta-large were not used when initializing XLMRobertaModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


 

         170163130 function calls (170086003 primitive calls) in 149.769 seconds

   Ordered by: internal time
   List reduced from 6809 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 44771119   38.025    0.000   38.025    0.000 {method 'readline' of 'mmap.mmap' objects}
        3   29.437    9.812   74.120   24.707 joint_el_datamodule.py:47(__init__)
        5   15.115    3.023   15.116    3.023 serialization.py:994(load_tensor)
        1   13.817   13.817   13.817   13.817 {built-in method faiss._swigfaiss.GpuIndexFlat_add}
        1   11.451   11.451   12.661   12.661 joint_el_datamodule.py:21(__init__)
        1    4.650    4.650  149.773  149.773 <string>:2(<module>)
 61314008    4.082    0.000    4.082    0.000 {method 'append' of 'list' objects}
        1    3.884    3.884    3.884    3.884 {built-in method faiss._swigfaiss.new_GpuIndexFlatIP}
 44771116    3.524    0.000    3.524    0.000 {method 'tell' of 'mmap.mmap' obje

In [6]:
# Set low thresholds (0.01 for both); the end-to-end eval cell overrides these before evaluating.
model_eval.task.md_threshold, model_eval.task.el_threshold = 0.01, 0.01

# End-to-end Eval

In [10]:
# Thresholds used for this run, echoed so the printed scores can be traced back to them.
model_eval.task.md_threshold, model_eval.task.el_threshold = 0.2, 0.4
print(f"{model_eval.checkpoint_path=}")
print(f"{model_eval.task.md_threshold=}")
print(f"{model_eval.task.el_threshold=}")

# Uncomment a path to include that dataset in the evaluation.
datasets = [
    #"/fsx/louismartin/bela/retrieved_from_aws_backup/ndecao/TACKBP2015/train_bela_format.jsonl",
    "/fsx/louismartin/bela/data/debug_mention_detection/mel/eval.1st.1_sample.txt",  # Overfit on one sample
    #'/fsx/movb/data/matcha/mewsli-9/ta.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/ar.jsonl',
    ##'/fsx/movb/data/matcha/mewsli-9/en.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/fa.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/sr.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/tr.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/de.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/es.jsonl',
    #'/fsx/movb/data/matcha/mewsli-9/ja.jsonl',
]
for test_data_path in datasets:
    print(f"Processing {test_data_path}")
    # Only the first 100 samples of each dataset are evaluated, to keep runs short.
    test_data = load_file(test_data_path)[:100]

    predictions = get_predictions_using_windows(model_eval, test_data)
    # An alternative scorer lives in the draft `ModelResults` class at the end of the notebook.
    strict_scores, boe_scores = ModelEval.compute_scores(test_data, predictions)
    f1, precision, recall = strict_scores
    f1_boe, precision_boe, recall_boe = boe_scores

    print(f"F1 = {f1:.4f}, precision = {precision:.4f}, recall = {recall:.4f}")
    print(f"F1 boe = {f1_boe:.4f}, precision = {precision_boe:.4f}, recall = {recall_boe:.4f}")

model_eval.checkpoint_path='/data/home/louismartin/dev/BELA/multirun/2023-02-14/08-45-13/0/lightning_logs/version_0/checkpoints/checkpoint_best.ckpt'
model_eval.task.md_threshold=0.2
model_eval.task.el_threshold=0.4
Processing /fsx/louismartin/bela/data/debug_mention_detection/mel/eval.1st.1_sample.txt


78it [00:00, 30116.52it/s]
100%|███████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.50s/it]


F1 = 0.0000, precision = 0.0000, recall = 0.0000
F1 boe = 0.7179, precision = 0.8750, recall = 0.6087


# Disambiguation Eval

In [None]:

def shift_shift(text):
    """Return the index of the first non-alphabetic character in ``text``.

    Used to advance a window start to a word boundary so the trimmed text does not begin
    mid-word. Returns ``None`` when every character is alphabetic (including for the empty
    string); callers must handle that case. Note that ``str.isalpha`` is also true for
    CJK characters.
    """
    for idx, ch in enumerate(text):
        if not ch.isalpha():
            return idx
    return None


# Per-language (max_length, max_offset) windows for convert_data_for_disambiguation.
# Texts longer than max_length whose mention starts after max_offset are left-trimmed, so
# the mention ends up at most max_offset characters into the kept text.
_DISAMBIGUATION_WINDOWS = {
    'ar': (600, 400),
    'ja': (350, 250),
}
_DEFAULT_DISAMBIGUATION_WINDOW = (800, 600)


def convert_data_for_disambiguation(data, lang):
    """Split examples into one example per gold entity, left-trimming long texts.

    Args:
        data: iterable of examples with 'original_text' and 'gt_entities'. Each gold entity
            is a 6-item sequence: item 2 is the entity id; items 4 and 5 are the character
            offset and length of the mention in 'original_text'.
        lang: language code used to pick the trimming window (see _DISAMBIGUATION_WINDOWS);
            unknown codes use the default window.

    Returns:
        A list of examples, each with exactly one entity in 'gt_entities'. Items 0 and 1 of
        that entity are set to 0, item 3 is copied unchanged, and the offset is rebased
        onto the trimmed text (always >= 0, and the mention is never cut).
    """
    max_length, max_offset = _DISAMBIGUATION_WINDOWS.get(lang, _DEFAULT_DISAMBIGUATION_WINDOW)
    new_examples = []
    for example in tqdm(data):
        original_text = example['original_text']
        for _, _, ent, passthrough, offset, length in example['gt_entities']:
            shift = 0
            if len(original_text) > max_length and offset > max_offset:
                shift = offset - max_offset
                # Snap to the next word boundary, searching only the text *before* the
                # mention: searching further could move the start into or past the mention
                # (negative offsets). Without a boundary there, keep the raw shift.
                boundary = shift_shift(original_text[shift:offset])
                if boundary is not None:
                    shift += boundary
            new_examples.append({
                'original_text': original_text[shift:],
                'gt_entities': [[0, 0, ent, passthrough, offset - shift, length]],
            })
    return new_examples


def metrics_disambiguation(test_data, predictions, include_unpredicted=False):
    """Entity-disambiguation accuracy on one-gold-entity-per-example data.

    For each example, the id of its first gold entity (item 2 of ``gt_entities[0]``) is
    compared with the first predicted entity id.

    Args:
        test_data: examples with a 'gt_entities' list.
        predictions: per-example dicts with an 'entities' list, aligned with test_data.
        include_unpredicted: if False (default, the original behaviour), examples with no
            predicted entity are skipped, so accuracy covers only examples the model made a
            prediction for. If True, such examples count as incorrect.

    Returns:
        (accuracy, support): the accuracy (0.0 when support is 0, instead of raising
        ZeroDivisionError) and the number of examples it was computed over.
    """
    support = 0
    correct = 0
    for example, prediction in tqdm(zip(test_data, predictions)):
        predicted_ids = prediction['entities']
        has_prediction = len(predicted_ids) > 0
        if not has_prediction and not include_unpredicted:
            continue
        support += 1
        if has_prediction and example['gt_entities'][0][2] == predicted_ids[0]:
            correct += 1

    accuracy = correct / support if support > 0 else 0.0
    return accuracy, support

In [None]:
datasets = [
    '/fsx/movb/data/matcha/mewsli-9/ta.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/ar.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/en.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/fa.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/sr.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/tr.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/de.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/es.jsonl',
    '/fsx/movb/data/matcha/mewsli-9/ja.jsonl',
]
# Samples read per dataset before conversion; each sample yields one example per gold entity.
MAX_DISAMBIGUATION_SAMPLES = 10000

for test_data_path in datasets:
    print(f"Processing {test_data_path}")
    # Language code from the file name, e.g. ".../mewsli-9/ta.jsonl" -> "ta"
    # (robust to codes that are not exactly two characters, unlike a fixed slice).
    lang = test_data_path.rsplit('/', 1)[-1].split('.', 1)[0]
    test_data = load_file(test_data_path)
    test_data = convert_data_for_disambiguation(test_data[:MAX_DISAMBIGUATION_SAMPLES], lang)
    predictions = model_eval.get_disambiguation_predictions(test_data)
    accuracy, support = metrics_disambiguation(test_data, predictions)
    print(f"Accuracy {accuracy}, support {support}")

## Eyeball Samples

In [24]:
# Probe sentences; only the last one is run. Whatever thresholds were last set on
# model_eval.task are in effect here.
probe_texts = [
    "Her name is Taylor Swift. New York City is a city in the United States.",
    "My dog is a Shiba Inu. Taylor Swift is a singer.",
    "Her name is Taylor Swift.",
    "He is the Dalai Lama. He is a Buddhist monk.",
]
text = probe_texts[-1]
prediction = model_eval.process_batch([text])[0]
predicted_entities = [
    PredictedEntity(offset=offset, length=length, text=text, entity_id=entity_id, md_score=md_score, el_score=el_score)
    for offset, length, entity_id, md_score, el_score in zip(
        prediction["offsets"],
        prediction["lengths"],
        prediction["entities"],
        prediction["md_scores"],
        prediction["el_scores"],
    )
]
print_sample(text, [], predicted_entities)

text[:max_display_length]='He is the Dalai Lama. He is a Buddhist monk.'
***************** Ground truth entities *****************
len(ground_truth_entities)=0
***************** Predicted entities *****************
len(predicted_entities)=3
mention="Dalai Lama." -> entity_id=Q37349 (md_score=0.25, el_score=0.01)
mention="i Lama." -> entity_id=Q37349 (md_score=0.37, el_score=0.05)
mention="Buddhist" -> entity_id=Q748 (md_score=0.22, el_score=0.01)


In [19]:

# Alternative dataset: "/fsx/louismartin/bela/retrieved_from_aws_backup/ndecao/TACKBP2015/train_bela_format.jsonl"
test_data_path = "/fsx/louismartin/bela/data/debug_mention_detection/mel/eval.1st.1_sample.txt"  # Overfit on one sample
print(f"Processing {test_data_path}")
test_data = load_file(test_data_path)
sample = test_data[0]  # an earlier run inspected test_data[200]
prediction = get_predictions_using_windows(model_eval, [sample])[0]
text = sample["original_text"]

# Gold entities are laid out as (_, _, entity_id, _, offset, length).
ground_truth_entities = [GroundTruthEntity(offset, length, text, entity_id) for _, _, entity_id, _, offset, length in sample["gt_entities"]]
predicted_entities = [PredictedEntity(offset, length, text, entity_id, md_score, el_score) for offset, length, entity_id, md_score, el_score in zip(prediction["offsets"], prediction["lengths"], prediction["entities"], prediction["md_scores"], prediction["el_scores"])]
print_sample(text, ground_truth_entities, predicted_entities, max_display_length=100000)

Processing /fsx/louismartin/bela/data/debug_mention_detection/mel/eval.1st.1_sample.txt


78it [00:00, 20122.75it/s]
100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.73it/s]

text[:max_display_length]="   Martina Steuk (née Kämpfert; born 11 November 1959) is a German former track and field athlete who represented East Germany. She competed in the 800 metres and occasionally the 400 metres.  Her first success came at the 1977 European Athletics Junior Championships, where she won the 800 m title. At twenty years old, reached the final at the 1980 Moscow Olympics and placed fourth in a lifetime best time of 1:56.21 behind a Soviet trio of Nadiya Olizarenko, Olga Mineyeva and Tatyana Providokhina that broke the world record. A successful 1981 season followed, which included a win at the 1981 European Cup, and at the 1981 IAAF World Cup she won an 800 m silver behind Lyudmila Veselkova and a gold with the East German 4 × 400 metres relay team. She ranked second in the world that year for the 800 m behind Vesselkova with a time of 1:57.16. That success continued into 1982 with a silver medal at the 1982 European Athletics Indoor Championships, finishing second 




# Debug mention detection bug

In [8]:
# Lower both thresholds to 0.1 and run a single short sentence through the model.
model_eval.task.md_threshold, model_eval.task.el_threshold = 0.1, 0.1
text = "Taylor Swift is a singer."
prediction = model_eval.process_batch([text])[0]
prediction_columns = [prediction[key] for key in ("offsets", "lengths", "entities", "md_scores", "el_scores")]
predicted_entities = [
    PredictedEntity(offset=offset, length=length, text=text, entity_id=entity_id, md_score=md_score, el_score=el_score)
    for offset, length, entity_id, md_score, el_score in zip(*prediction_columns)
]
print_sample(text, [], predicted_entities)

text[:max_display_length]='Taylor Swift is a singer.'
***************** Ground truth entities *****************
len(ground_truth_entities)=0
***************** Predicted entities *****************
len(predicted_entities)=1
mention="Taylor Swift" -> entity_id=Q26876 (md_score=0.15, el_score=0.02)


## SPM Encode and Decode

In [25]:
text = "400 metres.  Her"
token_ids = model_eval.transform.processor.encode(text)
print(token_ids)
# Decode the trailing tokens taken from the encoding itself, not hard-coded ids, so this
# stays meaningful if the SentencePiece model changes. With the current model these are
# [4, 1839] -> ". Her": the double space before "Her" is not preserved.
print(model_eval.transform.processor.decode(token_ids[-2:]))
# Full round trip, to compare against `text` directly.
print(repr(model_eval.transform.processor.decode(token_ids)))

[4081, 106383, 4, 1839]
. Her


# Inference time

In [6]:
def batch_samples(samples, batch_size):
    """Yield consecutive slices of ``samples`` of length ``batch_size``.

    The last batch may be shorter. ``samples`` must support ``len`` and slicing
    (e.g. a list). Because this is a generator, the argument check below runs when
    iteration starts.

    Raises:
        ValueError: if ``batch_size`` is less than 1 (a negative step would otherwise
            silently yield nothing).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    for start in range(0, len(samples), batch_size):
        yield samples[start : start + batch_size]


texts = [sample["original_text"] for sample in load_file("/fsx/movb/data/matcha/mewsli-9/en.jsonl")]
batch_size = 1024
print(f"Processing {len(texts)} texts, {batch_size=}, {model_eval.transform.max_seq_len=}")
%time _ = [model_eval.process_batch(batch_texts) for batch_texts in tqdm(batch_samples(texts, batch_size), desc="Inference")]

12679it [00:00, 70671.85it/s]


Processing 12679 texts, batch_size=1024, model_eval.transform.max_seq_len=256


13it [04:00, 18.53s/it]

CPU times: user 11min 50s, sys: 35.3 s, total: 12min 25s
Wall time: 4min





# Draft

In [None]:

class Sample:
    """A text with its gold and predicted entities, plus per-text matches.

    Strict matching pairs each prediction with at most one equal, not-yet-matched gold
    entity. Equality is whatever the entity objects' ``==`` defines; for the entity classes
    in this notebook that means the same offset, length and entity_id. A duplicated
    prediction of one gold span therefore counts as one true positive plus one false
    positive, not two true positives.

    Bag-of-entities (BoE) matching compares the *sets* of entity ids, ignoring spans.

    Attributes set by ``_compute_scores``:
        true_positives / false_positives: predicted entities that were matched / unmatched.
        false_negatives: gold entities left unmatched.
        true_positives_boe / false_positives_boe / false_negatives_boe: entity ids.
    """

    # String annotations: the entity types need not exist when the class is defined.
    text: str
    ground_truth_entities: "List[GroundTruthEntity]"
    predicted_entities: "List[PredictedEntity]"

    def __init__(self, text, ground_truth_entities, predicted_entities):
        self.text = text
        self.ground_truth_entities = ground_truth_entities
        self.predicted_entities = predicted_entities
        self._compute_scores()

    def _compute_scores(self):
        # Strict: one-to-one matching so each gold entity is consumed at most once.
        unmatched_gold = list(self.ground_truth_entities)
        self.true_positives = []
        self.false_positives = []
        for predicted_entity in self.predicted_entities:
            match_idx = next(
                (idx for idx, gold in enumerate(unmatched_gold) if predicted_entity == gold),
                None,
            )
            if match_idx is None:
                self.false_positives.append(predicted_entity)
            else:
                del unmatched_gold[match_idx]
                self.true_positives.append(predicted_entity)
        self.false_negatives = unmatched_gold

        # Bag of entities
        gold_entity_ids = set(entity.entity_id for entity in self.ground_truth_entities)
        predicted_entity_ids = set(entity.entity_id for entity in self.predicted_entities)
        self.true_positives_boe = [entity_id for entity_id in predicted_entity_ids if entity_id in gold_entity_ids]
        self.false_positives_boe = [entity_id for entity_id in predicted_entity_ids if entity_id not in gold_entity_ids]
        self.false_negatives_boe = [entity_id for entity_id in gold_entity_ids if entity_id not in predicted_entity_ids]

    def print(self, max_display_length=1000):
        """Print the truncated text and its entities.

        Entities ending past ``max_display_length`` are omitted from the listing; the
        printed counts still include them.
        """
        print(f"{self.text[:max_display_length]=}")
        print("***************** Ground truth entities *****************")
        print(f"{len(self.ground_truth_entities)=}")
        for ground_truth_entity in self.ground_truth_entities:
            if ground_truth_entity.offset + ground_truth_entity.length > max_display_length:
                continue
            print(ground_truth_entity)
        print("***************** Predicted entities *****************")
        print(f"{len(self.predicted_entities)=}")
        for predicted_entity in self.predicted_entities:
            if predicted_entity.offset + predicted_entity.length > max_display_length:
                continue
            print(predicted_entity)



class ModelResults:
    """Strict and bag-of-entities (BoE) precision / recall / F1 for a set of predictions.

    Args:
        data: gold samples, each a dict with 'original_text' and 'gt_entities'. Gold
            entities are laid out as (_, _, entity_id, _, offset, length), the same layout
            unpacked by the other cells of this notebook.
        predictions: per-sample dicts with aligned 'offsets', 'lengths', 'entities',
            'md_scores' and 'el_scores' sequences.
        md_threshold / el_threshold: a prediction is kept only if both its md_score and
            its el_score are strictly above these values.
        verbose: stored on the instance; not read by this class.

    After construction, scores are available as attributes (precision, recall, f1,
    precision_boe, recall_boe, f1_boe) and per-text matches are in ``samples``.
    """

    def __init__(self, data, predictions, md_threshold=0.2, el_threshold=0.05, verbose=False):
        self.data = data
        self.predictions = predictions
        self.md_threshold = md_threshold
        self.el_threshold = el_threshold
        self.verbose = verbose
        self.samples = []
        self._compute_scores()

    def compute_scores(self):
        """Recompute and return ((f1, precision, recall), (f1_boe, precision_boe, recall_boe)).

        Recomputing picks up any change made to the thresholds after construction.
        """
        self._compute_scores()
        return (self.f1, self.precision, self.recall), (self.f1_boe, self.precision_boe, self.recall_boe)

    def _build_sample(self, ground_truth_sample, predicted_sample):
        """Turn one gold sample and its prediction dict into a matched ``Sample``."""
        text = ground_truth_sample['original_text']
        ground_truth_entities = [
            GroundTruthEntity(offset=offset, length=length, text=text, entity_id=ent_id)
            # offset/length are items 4 and 5 (the character span in 'original_text');
            # items 0 and 1 are NOT the span.
            for _, _, ent_id, _, offset, length in ground_truth_sample['gt_entities']
        ]
        predicted_entities = [
            PredictedEntity(
                offset=offset,
                length=length,
                text=text,
                entity_id=ent_id,
                md_score=md_score,
                el_score=el_score,
            )
            for offset, length, ent_id, md_score, el_score in zip(
                predicted_sample['offsets'],
                predicted_sample['lengths'],
                predicted_sample['entities'],
                predicted_sample['md_scores'],
                predicted_sample['el_scores'],
            )
            if el_score > self.el_threshold and md_score > self.md_threshold
        ]
        return Sample(
            text=text,
            ground_truth_entities=ground_truth_entities,
            predicted_entities=predicted_entities,
        )

    def _compute_scores(self):
        self.samples = [
            self._build_sample(ground_truth_sample, predicted_sample)
            for ground_truth_sample, predicted_sample in zip(self.data, self.predictions)
        ]

        def safe_division(numerator, denominator):
            return numerator / denominator if denominator > 0 else 0.0

        def f1_score(precision, recall):
            return safe_division(2 * precision * recall, precision + recall)

        # Strict (span + id) scores. How duplicate predictions are counted is decided by
        # Sample's matching.
        num_true_positives = sum(len(sample.true_positives) for sample in self.samples)
        num_predicted = sum(len(sample.predicted_entities) for sample in self.samples)
        num_gold = sum(len(sample.ground_truth_entities) for sample in self.samples)
        self.precision = safe_division(num_true_positives, num_predicted)
        self.recall = safe_division(num_true_positives, num_gold)
        self.f1 = f1_score(self.precision, self.recall)

        # BoE scores. BoE true positives are *unique* entity ids per text, so the
        # denominators must be unique predicted / gold ids too, not mention counts;
        # otherwise repeated mentions of one entity deflate precision and recall.
        num_true_positives_boe = sum(len(sample.true_positives_boe) for sample in self.samples)
        num_predicted_boe = sum(
            len(sample.true_positives_boe) + len(sample.false_positives_boe) for sample in self.samples
        )
        num_gold_boe = sum(
            len(sample.true_positives_boe) + len(sample.false_negatives_boe) for sample in self.samples
        )
        self.precision_boe = safe_division(num_true_positives_boe, num_predicted_boe)
        self.recall_boe = safe_division(num_true_positives_boe, num_gold_boe)
        self.f1_boe = f1_score(self.precision_boe, self.recall_boe)
