In [1]:
import warnings
warnings.filterwarnings('ignore')

import yaml
from hydra.experimental import compose, initialize_config_module
import hydra
import torch
from tqdm import tqdm
import json
import faiss
import logging
from collections import defaultdict

from typing import Optional, List, Dict, Any, Tuple

from bela.transforms.spm_transform import SPMTransform

logger = logging.getLogger(__name__)

ModuleNotFoundError: No module named 'bela'

In [2]:
# text_encodings = data["text_encodings"].to(device)
# text_pad_mask = data["batch"]["attention_mask"].to(device)
# gold_mention_offsets = data["batch"]["mention_offsets"].to(device)
# gold_mention_lengths = data["batch"]["mention_lengths"].to(device)
# entities_ids = data["batch"]["entities"].to(device)
# tokens_mapping = data["batch"]["tokens_mapping"].to(device)

In [None]:
# Set git editor to vim.
# NOTE(review): `--global` rewrites the user's global git configuration, a side
# effect unrelated to the evaluation below — consider removing this cell.
!git config --global core.editor vim

In [9]:
def load_file(path):
    """Load a JSON-lines file into a list of decoded objects.

    Args:
        path: Path of a file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    all_data = []
    # Explicit UTF-8: the datasets are multilingual and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(path, 'rt', encoding='utf-8') as fd:
        for line in tqdm(fd):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            all_data.append(json.loads(line))
    return all_data


def convert_sp_to_char_offsets(
    text: str,
    sp_offsets: List[int],
    sp_lengths: List[int],
    sp_tokens_boundaries: List[List[int]],
) -> Tuple[List[int], List[int]]:
    """
    Translate sentencepiece-level mention spans into character-level spans
    of `text`.

    `sp_offsets` count the leading cls token while `sp_tokens_boundaries`
    does not, so each offset is shifted down by one before the lookup.
    Spans whose sentencepiece offset is 0 are dropped, so the returned lists
    may be shorter than the inputs.

    Returns a pair `(char_offsets, char_lengths)`.
    """
    starts: List[int] = []
    sizes: List[int] = []
    n_chars = len(text)

    for token_offset, token_count in zip(sp_offsets, sp_lengths):
        if sp_offset_is_cls := (token_offset == 0):
            continue

        first_token = token_offset - 1
        last_token = first_token + token_count - 1
        begin = sp_tokens_boundaries[first_token][0]
        end = sp_tokens_boundaries[last_token][1]

        # Token boundaries include leading whitespace; skip over it.
        while text[begin].isspace():
            begin += 1
            assert begin < n_chars

        starts.append(begin)
        sizes.append(end - begin)

    return starts, sizes
    

class ModelEval:
    """Inference helper built around a hydra-instantiated task.

    The constructor instantiates the transform, datamodule and task from the
    hydra config, moves the task to the GPU, loads checkpoint weights, and
    keeps references to the task's entity embeddings and faiss index. The
    remaining methods run end-to-end entity linking (`process_batch`,
    `get_predictions`) or disambiguation of given mentions
    (`process_disambiguation_batch`, `get_disambiguation_predictions`).
    """

    def __init__(self, checkpoint_path, config_name="joint_el_mel"):
        """Build the task from config and load weights from `checkpoint_path`.

        Args:
            checkpoint_path: Path of a checkpoint readable by `torch.load`
                that contains a "state_dict" entry.
            config_name: Name of the hydra config to compose.
        """
        # NOTE(review): the device is hard-coded; a CUDA device is required.
        self.device = torch.device("cuda:0")

        logger.info("Create task")
        with initialize_config_module("bela/conf"):
            cfg = compose(config_name=config_name)

        self.transform = hydra.utils.instantiate(cfg.task.transform)
        datamodule = hydra.utils.instantiate(cfg.datamodule, transform=self.transform)
        self.task = hydra.utils.instantiate(cfg.task, datamodule=datamodule, _recursive_=False)

        self.task.setup("train")
        self.task = self.task.eval()
        self.task = self.task.to(self.device)
        self.embeddings = self.task.embeddings
        self.faiss_index = self.task.faiss_index

        # logger.info("Create GPU index")
        # self.create_gpu_index()

        logger.info("Create ent index")
        # Position i holds the i-th entry yielded by the entity catalogue
        # index; faiss result indices are mapped through this list.
        self.ent_idx = list(datamodule.ent_catalogue.idx)

        logger.info("Load checkpoint")
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.task.load_state_dict(checkpoint["state_dict"])

    def create_gpu_index(self, gpu_id=0):
        """Replace `self.faiss_index` with a float16 GPU inner-product index.

        The new index is filled with `self.embeddings`.

        Args:
            gpu_id: Device number assigned to the faiss index config.
        """
        flat_config = faiss.GpuIndexFlatConfig()
        flat_config.device = gpu_id
        flat_config.useFloat16 = True

        res = faiss.StandardGpuResources()

        # Fixed: the dimension must come from self.embeddings; the bare name
        # `embeddings` is undefined in this scope and raised NameError.
        self.faiss_index = faiss.GpuIndexFlatIP(
            res, self.embeddings.shape[1], flat_config
        )
        self.faiss_index.add(self.embeddings)

    def lookup(
        self,
        query: torch.Tensor,
    ):
        """Return the top-1 score and index of each query vector.

        Args:
            query: Query vectors passed unchanged to `self.faiss_index.search`.

        Returns:
            `(scores, indices)` with the trailing k=1 dimension squeezed
            away, both moved to `self.device`.
        """
        scores, indices = self.faiss_index.search(query, k=1)

        return scores.squeeze(-1).to(self.device), indices.squeeze(-1).to(self.device)

    def process_batch(self, texts):
        """Run mention detection and entity linking on a batch of texts.

        Args:
            texts: List of input strings.

        Returns:
            One dict per text with keys "offsets" and "lengths"
            (character-level spans), "entities", "md_scores" and "el_scores".
            Only mentions whose detection score reaches
            `self.task.md_threshold` are reported.
        """
        batch: Dict[str, Any] = {"texts": texts}
        model_inputs = self.transform(batch)

        token_ids = model_inputs["input_ids"].to(self.device)
        text_pad_mask = model_inputs["attention_mask"].to(self.device)
        tokens_mapping = model_inputs["tokens_mapping"].to(self.device)
        sp_tokens_boundaries = model_inputs["sp_tokens_boundaries"].tolist()

        with torch.no_grad():
            _, last_layer = self.task.encoder(token_ids)
            text_encodings = last_layer
            text_encodings = self.task.project_encoder_op(text_encodings)

            mention_logits, mention_bounds = self.task.mention_encoder(
                text_encodings, text_pad_mask, tokens_mapping
            )

            (
                chosen_mention_logits,
                chosen_mention_bounds,
                chosen_mention_mask,
                mention_pos_mask,
            ) = self.task.mention_encoder.prune_ctxt_mentions(
                mention_logits,
                mention_bounds,
                num_cand_mentions=50,
                threshold=self.task.md_threshold,
            )

            mention_offsets = chosen_mention_bounds[:, :, 0]
            mention_lengths = (
                chosen_mention_bounds[:, :, 1] - chosen_mention_bounds[:, :, 0] + 1
            )
            # A mention at offset 0 is treated as "no mention" from here on.
            mention_lengths[mention_offsets == 0] = 0

            mentions_repr = self.task.span_encoder(
                text_encodings, mention_offsets, mention_lengths
            )

            # flat mentions and entities indices (mentions_num x embedding_dim)
            flat_mentions_repr = mentions_repr[mention_lengths != 0]
            mentions_scores = torch.sigmoid(chosen_mention_logits)

            # retrieve candidates top-1 ids and scores
            cand_scores, cand_indices = self.lookup(
                flat_mentions_repr.detach()
            )

            entities_repr = self.embeddings[cand_indices].to(self.device)

            flat_mentions_scores = mentions_scores[mention_lengths != 0].unsqueeze(-1)
            cand_scores = cand_scores.unsqueeze(-1)

            el_scores = torch.sigmoid(
                self.task.el_encoder(
                    flat_mentions_repr,
                    entities_repr,
                    flat_mentions_scores,
                    cand_scores,
                )
            ).squeeze(1)

        predictions = []
        # cand_idx walks the flattened candidate tensors; it advances for
        # every non-empty mention, including those below the threshold.
        cand_idx = 0
        for example_idx, (offsets, lengths, md_scores) in enumerate(
            zip(mention_offsets, mention_lengths, mentions_scores)
        ):
            ex_sp_offsets = []
            ex_sp_lengths = []
            ex_entities = []
            ex_md_scores = []
            ex_el_scores = []
            for offset, length, md_score in zip(offsets, lengths, md_scores):
                if length != 0:
                    if md_score >= self.task.md_threshold:
                        ex_sp_offsets.append(offset.detach().cpu().item())
                        ex_sp_lengths.append(length.detach().cpu().item())
                        ex_entities.append(self.ent_idx[cand_indices[cand_idx].detach().cpu().item()])
                        ex_md_scores.append(md_score.item())
                        ex_el_scores.append(el_scores[cand_idx].item())
                    cand_idx += 1

            char_offsets, char_lengths = convert_sp_to_char_offsets(
                texts[example_idx],
                ex_sp_offsets,
                ex_sp_lengths,
                sp_tokens_boundaries[example_idx],
            )

            predictions.append(
                {
                    "offsets": char_offsets,
                    "lengths": char_lengths,
                    "entities": ex_entities,
                    "md_scores": ex_md_scores,
                    "el_scores": ex_el_scores,
                }
            )

        return predictions

    def process_disambiguation_batch(self, texts, mention_offsets, mention_lengths, entities):
        """Link given mentions to their top-1 retrieved entity.

        Args:
            texts: List of input strings.
            mention_offsets: Per-text list of mention offsets handed to the
                transform.
            mention_lengths: Per-text list of mention lengths handed to the
                transform.
            entities: Per-text list of entity ids handed to the transform.

        Returns:
            One dict per text with keys "offsets", "lengths" (character-level
            spans), "entities" and "scores" (retrieval scores).
        """
        batch: Dict[str, Any] = {
            "texts": texts,
            "mention_offsets": mention_offsets,
            "mention_lengths": mention_lengths,
            "entities": entities,
        }
        model_inputs = self.transform(batch)

        token_ids = model_inputs["input_ids"].to(self.device)
        # From here on the offsets/lengths are the transform's outputs, not
        # the raw arguments.
        mention_offsets = model_inputs["mention_offsets"]
        mention_lengths = model_inputs["mention_lengths"]
        tokens_mapping = model_inputs["tokens_mapping"].to(self.device)
        sp_tokens_boundaries = model_inputs["sp_tokens_boundaries"].tolist()

        with torch.no_grad():
            _, last_layer = self.task.encoder(token_ids)
            text_encodings = last_layer
            text_encodings = self.task.project_encoder_op(text_encodings)

            mentions_repr = self.task.span_encoder(
                text_encodings, mention_offsets, mention_lengths
            )

            flat_mentions_repr = mentions_repr[mention_lengths != 0]
            # retrieve candidates top-1 ids and scores
            cand_scores, cand_indices = self.lookup(
                flat_mentions_repr.detach()
            )
            predictions = []
            cand_idx = 0
            for example_idx, (offsets, lengths) in enumerate(
                zip(mention_offsets, mention_lengths)
            ):
                ex_sp_offsets = []
                ex_sp_lengths = []
                ex_entities = []
                ex_dis_scores = []
                for offset, length in zip(offsets, lengths):
                    if length != 0:
                        ex_sp_offsets.append(offset.detach().cpu().item())
                        ex_sp_lengths.append(length.detach().cpu().item())
                        ex_entities.append(self.ent_idx[cand_indices[cand_idx].detach().cpu().item()])
                        ex_dis_scores.append(cand_scores[cand_idx].detach().cpu().item())
                        cand_idx += 1

                char_offsets, char_lengths = convert_sp_to_char_offsets(
                    texts[example_idx],
                    ex_sp_offsets,
                    ex_sp_lengths,
                    sp_tokens_boundaries[example_idx],
                )

                predictions.append({
                    "offsets": char_offsets,
                    "lengths": char_lengths,
                    "entities": ex_entities,
                    "scores": ex_dis_scores
                })

        return predictions

    def get_predictions(self, test_data, batch_size=256):
        """Run `process_batch` over `test_data` in chunks of `batch_size`.

        Args:
            test_data: Sequence of dicts with an 'original_text' entry.
            batch_size: Number of examples per call to `process_batch`.

        Returns:
            The concatenated per-example predictions, in input order.
        """
        all_predictions = []
        for batch_start in tqdm(range(0, len(test_data), batch_size)):
            batch = test_data[batch_start:batch_start + batch_size]
            texts = [example['original_text'] for example in batch]
            all_predictions.extend(self.process_batch(texts))
        return all_predictions

    def get_disambiguation_predictions(self, test_data, batch_size=256):
        """Run `process_disambiguation_batch` over `test_data` in chunks.

        Args:
            test_data: Sequence of dicts with 'original_text' and
                'gt_entities'; each gt entity is a 6-item record whose last
                two items are the mention offset and length.
            batch_size: Number of examples per batch.

        Returns:
            The concatenated per-example predictions, in input order.
        """
        all_predictions = []
        for batch_start in tqdm(range(0, len(test_data), batch_size)):
            batch = test_data[batch_start:batch_start + batch_size]
            texts = [example['original_text'] for example in batch]
            mention_offsets = [[offset for _, _, _, _, offset, _ in example['gt_entities']] for example in batch]
            mention_lengths = [[length for _, _, _, _, _, length in example['gt_entities']] for example in batch]
            # A placeholder id (0) is passed for every mention instead of the
            # gold entity.
            entities = [[0 for _, _, _, _, _, _ in example['gt_entities']] for example in batch]

            predictions = self.process_disambiguation_batch(texts, mention_offsets, mention_lengths, entities)
            all_predictions.extend(predictions)
        return all_predictions

In [8]:
# Use the %pip magic so the install targets the running kernel's environment
# (a bare `!pip` may resolve to a different interpreter on PATH).
# NOTE: must be run with the project root (containing setup.py/pyproject.toml)
# as the working directory.
%pip install -e .

[0mObtaining file:///data/home/louismartin/dev/BELA
[31mERROR: file:///data/home/louismartin/dev/BELA does not appear to be a Python project: neither 'setup.py' nor 'pyproject.toml' found.[0m[31m
[0m

In [6]:
#from bela.evaluation.model_eval import ModelEval  # Use overridden ModelEval class
# Earlier checkpoints, kept for reference:
# checkpoint_path = '/checkpoints/movb/bela/2022-11-27-225013/0/lightning_logs/version_286287/checkpoints/last_15000.ckpt'
# checkpoint_path = '/checkpoints/movb/bela/2023-01-13-023711/0/lightning_logs/version_4144/checkpoints/last.ckpt'

# e2e model with isotropic embeddings
checkpoint_path = '/checkpoints/movb/bela/2023-01-18-220105/0/lightning_logs/version_4820/checkpoints/last.ckpt'
# `model_eval` is read as a global by get_predictions_using_windows and by the
# disambiguation evaluation loop further down.
model_eval = ModelEval(checkpoint_path, config_name="joint_el_mel_new")

MissingConfigException: Primary config module 'bela.conf' not found.
Check that it's correct and contains an __init__.py file

In [5]:
# Tokenizer transform read as a global by get_windows.
# NOTE(review): the large max_seq_len presumably keeps whole documents from
# being truncated before windowing — confirm against SPMTransform.
sp_transform = SPMTransform(max_seq_len=100000)

In [6]:
def get_windows(text, window_length=254):
    """Split `text` into overlapping character windows aligned on tokens.

    The text is tokenized with the global `sp_transform`; the first and last
    tokens of the result are dropped. Windows start every
    `window_length // 2` tokens, so consecutive windows overlap by roughly
    half.

    Args:
        text: Document to split.
        window_length: Window size in tokens. Defaults to 254, the value
            previously hard-coded here.

    Returns:
        A list of `(start_pos, end_pos)` pairs taken from items 1 and 2 of
        the first and last token of each window (presumably character
        offsets into `text`; the caller slices `text` with them). Empty when
        there are no tokens.
    """
    tokens = sp_transform([text])[0]
    tokens = tokens[1:-1]
    # Guard against a zero step, which range() rejects, for tiny windows.
    stride = max(1, window_length // 2)
    windows = []
    for window_start in range(0, len(tokens), stride):
        start_pos = tokens[window_start][1]
        if window_start + window_length >= len(tokens):
            # Window runs past the end: clamp to the last token.
            end_pos = tokens[-1][2]
        else:
            end_pos = tokens[window_start + window_length][2]
        windows.append((start_pos, end_pos))
    return windows

In [7]:
def convert_predictions_to_dict(example_predictions):
    """Turn a list of prediction rows into a dict of parallel lists.

    Args:
        example_predictions: Iterable of 5-item rows
            `(offset, length, entity, md_score, el_score)`.

    Returns:
        A dict with keys 'offsets', 'lengths', 'entities', 'md_scores' and
        'el_scores'. Every value is a list, including when the input is
        empty (previously tuples were returned for non-empty input and lists
        for empty input).
    """
    columns = [list(column) for column in zip(*example_predictions)]
    if not columns:
        columns = [[], [], [], [], []]
    offsets, lengths, entities, md_scores, el_scores = columns
    return {
        'offsets': offsets,
        'lengths': lengths,
        'entities': entities,
        'md_scores': md_scores,
        'el_scores': el_scores,
    }

In [8]:
def group_predictions_by_example(all_predictions, extended_examples):
    """Collect per-window predictions into per-document sorted rows.

    Args:
        all_predictions: One prediction dict per window, with keys
            'offsets', 'lengths', 'entities', 'md_scores' and 'el_scores'.
        extended_examples: Window descriptors aligned with
            `all_predictions`; each provides 'document_id' and
            'window_start'.

    Returns:
        A dict mapping document id to a sorted list of
        `(offset, length, entity, md_score, el_score)` tuples, where each
        offset has been shifted by its window's 'window_start'. The input
        dicts are not modified.
    """
    grouped_predictions = defaultdict(list)
    for prediction, extended_example in zip(all_predictions, extended_examples):
        window_start = extended_example['window_start']
        # Shallow copy so the caller's prediction dict keeps window-relative
        # offsets.
        shifted = dict(prediction)
        shifted['offsets'] = [offset + window_start for offset in prediction['offsets']]
        grouped_predictions[extended_example['document_id']].append(shifted)

    predictions = {}
    for document_id, window_predictions in grouped_predictions.items():
        example_predictions = [
            row
            for prediction in window_predictions
            for row in zip(
                prediction['offsets'],
                prediction['lengths'],
                prediction['entities'],
                prediction['md_scores'],
                prediction['el_scores'],
            )
        ]
        # Sort once per document rather than after every appended row.
        example_predictions.sort()
        predictions[document_id] = example_predictions

    return predictions

In [9]:
def merge_predictions(example_predictions):
    """Keep one prediction per group of overlapping spans.

    Predictions are scanned in the given order (expected sorted by offset).
    A prediction joins the current group when its offset lies before the
    group's end; within a group only the prediction with the highest
    md_score survives, and the earlier one wins ties.

    Args:
        example_predictions: Iterable of
            `(offset, length, ent_id, md_score, el_score)` rows.

    Returns:
        A list of the surviving rows as tuples, in scan order.
    """
    merged = []
    best = None       # highest-md_score row of the current overlap group
    group_end = None  # furthest end position covered by the current group

    for offset, length, ent_id, md_score, el_score in example_predictions:
        candidate = (offset, length, ent_id, md_score, el_score)
        candidate_end = offset + length

        if best is None:
            best, group_end = candidate, candidate_end
            continue

        if offset < group_end:
            # intersection of two predictions
            if md_score > best[3]:
                best = candidate
            # Fixed: never shrink the group end. Overwriting it with a short
            # nested span let a later span that still overlaps the kept
            # prediction escape the merge.
            group_end = max(group_end, candidate_end)
        else:
            merged.append(best)
            best, group_end = candidate, candidate_end

    if best is not None:
        merged.append(best)

    return merged

In [10]:
def correct_mention_offsets(example_predictions, text):
    """Snap predicted spans to separator-delimited word boundaries.

    The start is moved forward until it sits right after a separator and on
    a non-separator character (or reaches the end of `text`); offset 0 is
    left alone. The end is then extended until it reaches a separator or the
    end of `text`.

    Args:
        example_predictions: Iterable of
            `(offset, length, ent_id, md_score, el_score)` rows.
        text: The document the offsets refer to.

    Returns:
        A list of rows with corrected `offset` and `length`; the other
        fields are passed through unchanged.
    """
    TEXT_SEPARATORS = [' ', '.', ',', '!', '?', '-', '\n']
    corrected_example_predictions = []
    for offset, length, ent_id, md_score, el_score in example_predictions:
        while (
            offset != 0
            and offset < len(text)
            and (text[offset - 1] not in TEXT_SEPARATORS or text[offset] in TEXT_SEPARATORS)
        ):
            offset += 1
            # Fixed: clamp at zero. Moving the start past the original end
            # used to leave a negative length.
            length = max(length - 1, 0)
        while offset + length < len(text) and text[offset + length] not in TEXT_SEPARATORS:
            length += 1
        corrected_example_predictions.append((
            offset, length, ent_id, md_score, el_score
        ))
    return corrected_example_predictions

In [11]:
def get_predictions_using_windows(test_data):
    """Predict entity links for whole documents using overlapping windows.

    Each document is cut into windows with `get_windows`, every window is
    scored by the global `model_eval`, and the window-level predictions are
    regrouped per document, merged, boundary-corrected and converted to the
    dict-of-lists format.

    Args:
        test_data: Sequence of dicts with 'document_id', 'original_text' and
            'gt_entities'.

    Returns:
        One prediction dict per input document, in input order. A document
        that yields no windows gets the conversion of an empty prediction
        list.
    """
    extended_examples = []

    for example in test_data:
        text = example['original_text']
        for idx, (start_pos, end_pos) in enumerate(get_windows(text)):
            extended_examples.append({
                'document_id': example['document_id'],
                'original_text': text[start_pos:end_pos],
                'gt_entities': example['gt_entities'],
                'window_idx': idx,
                'window_start': start_pos,
                'window_end': end_pos,
            })

    all_predictions = model_eval.get_predictions(extended_examples)
    predictions_dict = group_predictions_by_example(all_predictions, extended_examples)

    predictions = []
    for example in test_data:
        # Fixed: a document without any window has no entry in
        # predictions_dict; fall back to no predictions rather than KeyError.
        example_predictions = predictions_dict.get(example['document_id'], [])
        example_predictions = merge_predictions(example_predictions)
        example_predictions = correct_mention_offsets(example_predictions, example['original_text'])
        example_predictions = convert_predictions_to_dict(example_predictions)
        predictions.append(example_predictions)

    return predictions

In [12]:
def compute_scores(data, predictions, md_threshold=0.2, el_threshold=0.05):
    """Compute span-level and bag-of-entities F1 / precision / recall.

    Args:
        data: Sequence of examples with a 'gt_entities' list of 6-item
            records whose items 2, 4 and 5 are the entity id, offset and
            length.
        predictions: Sequence of prediction dicts aligned with `data`, with
            keys 'offsets', 'lengths', 'entities', 'md_scores' and
            'el_scores'.
        md_threshold: A prediction is kept only if its md_score is strictly
            greater than this value.
        el_threshold: A prediction is kept only if its el_score is strictly
            greater than this value.

    Returns:
        `((f1, precision, recall), (f1_boe, precision_boe, recall_boe))`.
        The first triple requires an exact (offset, length, entity) match;
        the second compares the sets of entities per example. A metric whose
        denominator is zero is reported as 0.0.
    """
    tp, fp, support = 0, 0, 0
    tp_boe, fp_boe, support_boe = 0, 0, 0

    for example, example_predictions in zip(data, predictions):

        example_targets = {
            (offset, length): ent_id
            for _, _, ent_id, _, offset, length in example['gt_entities']
        }

        example_predictions = {
            (offset, length): ent_id
            for offset, length, ent_id, md_score, el_score in zip(
                example_predictions['offsets'],
                example_predictions['lengths'],
                example_predictions['entities'],
                example_predictions['md_scores'],
                example_predictions['el_scores'],
            )
            if (el_score > el_threshold and md_score > md_threshold)
        }

        for pos, ent in example_targets.items():
            support += 1
            if pos in example_predictions and example_predictions[pos] == ent:
                tp += 1
        for pos, ent in example_predictions.items():
            if pos not in example_targets or example_targets[pos] != ent:
                fp += 1

        example_targets_set = set(example_targets.values())
        example_predictions_set = set(example_predictions.values())

        for ent in example_targets_set:
            support_boe += 1
            if ent in example_predictions_set:
                tp_boe += 1
        for ent in example_predictions_set:
            if ent not in example_targets_set:
                fp_boe += 1

    def safe_ratio(numerator, denominator):
        # Fixed: empty data or no surviving predictions used to raise
        # ZeroDivisionError.
        return numerator / denominator if denominator else 0.0

    def compute_f1_p_r(tp, fp, fn):
        precision = safe_ratio(tp, tp + fp)
        recall = safe_ratio(tp, tp + fn)
        f1 = safe_ratio(2 * tp, 2 * tp + fp + fn)
        return f1, precision, recall

    fn = support - tp
    fn_boe = support_boe - tp_boe
    return compute_f1_p_r(tp, fp, fn), compute_f1_p_r(tp_boe, fp_boe, fn_boe)

In [13]:
# predictions = get_predictions_using_windows(test_data)
# (f1, precision, recall), (f1_boe, precision_boe, recall_boe) = compute_scores(
#     test_data, predictions
# )
# (f1, precision, recall), (f1_boe, precision_boe, recall_boe)

In [14]:
# (f1, precision, recall), (f1_boe, precision_boe, recall_boe) = compute_scores(
#     test_data, predictions
# )
# (f1, precision, recall), (f1_boe, precision_boe, recall_boe)

In [24]:
# Mewsli-9 evaluation files, one JSON-lines file per language, in run order.
datasets = [
    f'/fsx/movb/data/matcha/mewsli-9/{language_code}.jsonl'
    for language_code in ('ta', 'ar', 'en', 'fa', 'sr', 'tr', 'de', 'es', 'ja')
]

In [26]:
# End-to-end evaluation: for every dataset file, predict with overlapping
# windows and print span-level and bag-of-entities (boe) scores.
# The loop variables stay in the kernel namespace after the cell finishes.
for test_data_path in datasets:
    print(f"Processing {test_data_path}")
    test_data = load_file(test_data_path)
    # Uncomment to evaluate on a small subset while debugging:
    # test_data = test_data[0:300]
    
    predictions = get_predictions_using_windows(test_data)
    # compute_scores is called with its default md/el thresholds.
    (f1, precision, recall), (f1_boe, precision_boe, recall_boe) = compute_scores(test_data, predictions)
    
    print(f"F1 = {f1:.4f}, precision = {precision:.4f}, recall = {recall:.4f}")
    print(f"F1 boe = {f1_boe:.4f}, precision = {precision_boe:.4f}, recall = {recall_boe:.4f}")

Processing /fsx/movb/data/matcha/mewsli-9/ta.jsonl


1000it [00:00, 11604.33it/s]
100%|██████████| 14/14 [01:11<00:00,  5.11s/it]


F1 = 0.0490, precision = 0.0410, recall = 0.0609
F1 boe = 0.3388, precision = 0.2991, recall = 0.3908
Processing /fsx/movb/data/matcha/mewsli-9/ar.jsonl


1468it [00:00, 2019.90it/s]
100%|██████████| 30/30 [02:36<00:00,  5.22s/it]


F1 = 0.0107, precision = 0.0085, recall = 0.0143
F1 boe = 0.3904, precision = 0.3109, recall = 0.5243
Processing /fsx/movb/data/matcha/mewsli-9/en.jsonl


12679it [00:02, 5611.53it/s]
100%|██████████| 200/200 [16:23<00:00,  4.92s/it]


F1 = 0.0075, precision = 0.0065, recall = 0.0089
F1 boe = 0.4499, precision = 0.4063, recall = 0.5038
Processing /fsx/movb/data/matcha/mewsli-9/fa.jsonl


165it [00:00, 10144.39it/s]
100%|██████████| 2/2 [00:07<00:00,  3.64s/it]


F1 = 0.0491, precision = 0.0370, recall = 0.0729
F1 boe = 0.4552, precision = 0.3529, recall = 0.6408
Processing /fsx/movb/data/matcha/mewsli-9/sr.jsonl


15011it [00:01, 9743.64it/s] 
100%|██████████| 169/169 [13:48<00:00,  4.90s/it]


F1 = 0.0946, precision = 0.0667, recall = 0.1627
F1 boe = 0.3871, precision = 0.2780, recall = 0.6372
Processing /fsx/movb/data/matcha/mewsli-9/tr.jsonl


997it [00:00, 21305.42it/s]
100%|██████████| 11/11 [00:53<00:00,  4.89s/it]


F1 = 0.1327, precision = 0.1362, recall = 0.1294
F1 boe = 0.5149, precision = 0.5488, recall = 0.4849
Processing /fsx/movb/data/matcha/mewsli-9/de.jsonl


13703it [00:02, 6112.08it/s]
100%|██████████| 189/189 [15:14<00:00,  4.84s/it]


F1 = 0.2122, precision = 0.1566, recall = 0.3290
F1 boe = 0.4329, precision = 0.3339, recall = 0.6154
Processing /fsx/movb/data/matcha/mewsli-9/es.jsonl


10284it [00:01, 7382.07it/s]
100%|██████████| 144/144 [11:51<00:00,  4.94s/it]


F1 = 0.0806, precision = 0.0651, recall = 0.1059
F1 boe = 0.4563, precision = 0.3852, recall = 0.5596
Processing /fsx/movb/data/matcha/mewsli-9/ja.jsonl


3410it [00:00, 6122.71it/s]
100%|██████████| 47/47 [03:48<00:00,  4.85s/it]


F1 = 0.0002, precision = 0.0003, recall = 0.0001
F1 boe = 0.2689, precision = 0.4229, recall = 0.1971


In [16]:
def shift_shift(text):
    """Return the index of the first non-alphabetic character of `text`.

    The caller adds this value to a cut position so that the cut lands on a
    word boundary rather than inside a word.

    Args:
        text: String to scan.

    Returns:
        The index of the first character for which `str.isalpha()` is False,
        or 0 when there is none (including empty input), meaning "no extra
        shift". Previously None was returned in that case, which broke the
        caller's integer arithmetic.
    """
    for idx, ch in enumerate(text):
        if not ch.isalpha():
            return idx
    return 0

def convert_data_for_disambiguation(data, lang):
    """Expand examples into one example per gold entity, trimming long text.

    When a text is longer than MAX_LENGTH and the mention starts after
    MAX_OFFSET, the start of the text is cut so the mention begins at most
    MAX_OFFSET characters in; the cut is then moved forward to the next
    non-alphabetic character when that does not cut into the mention.

    Args:
        data: Iterable of dicts with 'original_text' and 'gt_entities'
            (6-item records: items 2, 4 and 5 are the entity, offset and
            length; item 3 is carried over unchanged).
        lang: Language code; 'ar' and 'ja' use smaller limits than the rest.

    Returns:
        A list of dicts with 'original_text' (possibly trimmed) and a
        single-element 'gt_entities' whose offset is relative to the trimmed
        text.
    """
    # convert examples to 1 entity per example and shift if needed
    if lang == 'ar':
        MAX_LENGTH = 600
        MAX_OFFSET = 400
    elif lang == 'ja':
        MAX_LENGTH = 350
        MAX_OFFSET = 250
    else:
        MAX_LENGTH = 800
        MAX_OFFSET = 600
    new_examples = []
    for example in tqdm(data):
        original_text = example['original_text']
        # `extra_field` is the 4th item; it used to be read back through the
        # throwaway name `_`.
        for _, _, ent, extra_field, offset, length in example['gt_entities']:
            shift = 0
            if len(original_text) > MAX_LENGTH and offset > MAX_OFFSET:
                shift = (offset - MAX_OFFSET)
                boundary = shift_shift(original_text[shift:])
                # Fixed: ignore a missing boundary (None) and never shift
                # past the mention start, which would give a negative offset.
                if boundary is not None and shift + boundary <= offset:
                    shift += boundary
            new_example = {
                'original_text': original_text[shift:],
                'gt_entities': [[0, 0, ent, extra_field, offset - shift, length]],
            }
            new_examples.append(new_example)
    return new_examples

In [17]:
# Mewsli-9 files used by the disambiguation evaluation (same list as the
# end-to-end evaluation above), one JSON-lines file per language.
datasets = [
    f'/fsx/movb/data/matcha/mewsli-9/{language_code}.jsonl'
    for language_code in ('ta', 'ar', 'en', 'fa', 'sr', 'tr', 'de', 'es', 'ja')
]

In [19]:
def metrics_disambiguation(test_data, predictions):
    """Compute top-1 disambiguation accuracy.

    Only the first gold entity and the first predicted entity of each
    example are compared. Examples without any predicted entity are skipped
    and do not count towards the support.

    Args:
        test_data: Sequence of examples whose 'gt_entities'[0][2] is the
            gold entity.
        predictions: Sequence of prediction dicts with an 'entities' list,
            aligned with `test_data`.

    Returns:
        `(accuracy, support)`, where support is the number of examples
        actually scored. Accuracy is 0.0 when support is 0.
    """
    support = 0
    correct = 0

    for example, prediction in tqdm(zip(test_data, predictions)):
        if len(prediction['entities']) == 0:
            continue
        target = example['gt_entities'][0][2]
        predicted = prediction['entities'][0]
        correct += (target == predicted)
        support += 1

    # Fixed: avoid ZeroDivisionError when nothing was scored.
    accuracy = correct / support if support else 0.0

    return accuracy, support

In [23]:
# Disambiguation evaluation: for every dataset file, build one example per
# gold mention, predict its entity and print top-1 accuracy.
for test_data_path in datasets:
    print(f"Processing {test_data_path}")
    # Language code = file name without extension (e.g. '.../ta.jsonl' -> 'ta').
    # Replaces a fixed-width slice that only worked for two-letter codes.
    lang = test_data_path.rsplit('/', 1)[-1].split('.', 1)[0]
    test_data = load_file(test_data_path)
    # Only the first 10000 documents of each file are evaluated.
    test_data = convert_data_for_disambiguation(test_data[:10000], lang)
    predictions = model_eval.get_disambiguation_predictions(test_data)
    accuracy, support = metrics_disambiguation(test_data, predictions)
    # Fixed typo in the printed label ("Accuracty").
    print(f"Accuracy {accuracy}, support {support}")

Processing /fsx/movb/data/matcha/mewsli-9/ta.jsonl


1000it [00:00, 11419.94it/s]
100%|██████████| 1000/1000 [00:00<00:00, 3834.62it/s]
100%|██████████| 11/11 [00:47<00:00,  4.29s/it]
2692it [00:00, 1057513.01it/s]


Accuracty 0.8959881129271917, support 2692
Processing /fsx/movb/data/matcha/mewsli-9/ar.jsonl


1468it [00:00, 9301.85it/s] 
100%|██████████| 1468/1468 [00:00<00:00, 79520.06it/s]
100%|██████████| 29/29 [02:08<00:00,  4.42s/it]
7367it [00:00, 1161109.18it/s]


Accuracty 0.8839419030813085, support 7367
Processing /fsx/movb/data/matcha/mewsli-9/en.jsonl


12679it [00:00, 67349.94it/s]
100%|██████████| 10000/10000 [00:00<00:00, 12563.32it/s]
100%|██████████| 247/247 [17:07<00:00,  4.16s/it]
63164it [00:00, 915818.94it/s]


Accuracty 0.8362519644632606, support 62358
Processing /fsx/movb/data/matcha/mewsli-9/fa.jsonl


165it [00:00, 9809.22it/s]
100%|██████████| 165/165 [00:00<00:00, 253873.87it/s]
100%|██████████| 3/3 [00:07<00:00,  2.62s/it]
535it [00:00, 785422.70it/s]


Accuracty 0.8953271028037383, support 535
Processing /fsx/movb/data/matcha/mewsli-9/sr.jsonl


15011it [00:00, 16498.93it/s]
100%|██████████| 10000/10000 [00:00<00:00, 342448.07it/s]
100%|██████████| 102/102 [06:37<00:00,  3.89s/it]
25957it [00:00, 1128015.55it/s]


Accuracty 0.9280292814486611, support 25955
Processing /fsx/movb/data/matcha/mewsli-9/tr.jsonl


997it [00:00, 21980.48it/s]
100%|██████████| 997/997 [00:00<00:00, 120129.88it/s]
100%|██████████| 23/23 [01:28<00:00,  3.83s/it]
5811it [00:00, 949848.03it/s]


Accuracty 0.8652555498193082, support 5811
Processing /fsx/movb/data/matcha/mewsli-9/de.jsonl


13703it [00:02, 6202.94it/s] 
100%|██████████| 10000/10000 [00:00<00:00, 12988.24it/s]
100%|██████████| 186/186 [12:39<00:00,  4.08s/it]
47603it [00:00, 968088.39it/s]


Accuracty 0.896246034913766, support 47603
Processing /fsx/movb/data/matcha/mewsli-9/es.jsonl


10284it [00:01, 7281.39it/s] 
100%|██████████| 10000/10000 [00:00<00:00, 12100.67it/s]
100%|██████████| 214/214 [14:12<00:00,  3.98s/it]
54713it [00:00, 916439.85it/s]


Accuracty 0.8769031126057792, support 54713
Processing /fsx/movb/data/matcha/mewsli-9/ja.jsonl


3410it [00:00, 6116.93it/s]
100%|██████████| 3410/3410 [00:00<00:00, 71673.07it/s]
100%|██████████| 135/135 [08:33<00:00,  3.80s/it]
34463it [00:00, 850289.11it/s]

Accuracty 0.841477569496837, support 34462



