In [None]:
# -*- coding: utf-8 -*-

import json
import os
import pprint
import re
from collections import Counter
from pathlib import Path
from tqdm import tqdm

import multiset
import numpy as np
from omegaconf import OmegaConf
from seqeval.metrics import classification_report, f1_score
from seqeval.scheme import IOBES, IOB2
from transformers import AutoTokenizer, AutoConfig, AutoModelForTokenClassification, pipeline

from fairseq.models.roberta import RobertaModel

def parse_file_line(line, with_word_counts=False):
    """
    Parse one line of a gold/prediction file into its label and token.

    Lines starting with 'VOID' denote a space: token ' ', label 'VOID' and word-count '-'.
    Any other line is split on whitespace into ``token label`` or ``token label wc``.
    A wrong number of fields raises ValueError from tuple unpacking; callers rely on that.

    Returns:
        (label, token, wc) when ``with_word_counts`` is True, otherwise (label, token).
    """
    if line.startswith('VOID'):
        token, label, wc = ' ', 'VOID', '-'
    elif with_word_counts:
        token, label, wc = line.split()
    else:
        token, label = line.split()
    return (label, token, wc) if with_word_counts else (label, token)


def majority_vote(preds):
    """Return the most frequent element of ``preds``; ties go to the element seen first.

    Raises IndexError when ``preds`` is empty.
    """
    (winner, _count), = Counter(preds).most_common(1) or [None]
    return winner


def convert_segmented_preds_to_multitags(preds, use_majority_vote=False):
    """
    Merge per-character tags of one word into a single '+'-joined multi-tag.

    Tags are collected left to right until the first tag that re-occurs later in the word
    (the "recurring tag"). That tag closes the sequence:

    * by default, the recurring tag itself is appended;
    * with ``use_majority_vote``, the most common tag from the recurring tag onwards is appended instead.

    Example: t_1, t_2, t_3, t_4, t_5, t_3, t_3, t_60, t_100  ->  't_1+t_2+t_3'
    """
    collected = []
    for pos, tag in enumerate(preds):
        if tag not in preds[pos + 1:]:
            # Tag does not recur: keep collecting.
            if tag not in collected:
                collected.append(tag)
            continue
        # First recurring tag: append the closing tag and stop.
        if use_majority_vote:
            closing = Counter(preds[pos:]).most_common()[0][0]
            if tag not in collected:
                collected.append(closing)
        elif tag not in collected:
            collected.append(tag)
        break
    return '+'.join(collected)


def convert_by_pred_spans(preds):
    """
    Collapse a tag sequence into a '+'-joined multi-tag by spans.

    Starting at position 0, take the tag there and jump past its LAST occurrence in the
    sequence; repeat from the position after it. Tags sandwiched inside another tag's
    span are therefore skipped. E.g. A, A, B, A, C, C -> 'A+C'.
    """
    merged = []
    size = len(preds)
    cursor = 0
    while cursor < size:
        tag = preds[cursor]
        merged.append(tag)
        # Index of the last occurrence of `tag`; it always exists because preds[cursor] == tag.
        last_pos = size - 1 - preds[::-1].index(tag)
        cursor = last_pos + 1
    return '+'.join(merged)


def split_by_spaces_and_punct(sent, labels, sent_delims=None, tokenizer=None, ud_format=True):
    """
    Split a sentence into word spans using its per-character gold labels.

    If ``tokenizer`` is given, spans come from the tokenizer's word boundaries instead
    (via get_original_words_from_tokens), and ``labels`` is ignored.

    Otherwise, characters are scanned together with their labels:

    * a 'VOID' label (used for spaces) ends the current word and is not part of any span;
    * a punctuation label ('PUNCT' in UD format, any label starting with 'yy' otherwise)
      or a character in ``sent_delims`` ends the current word and becomes a
      one-character span of its own.

    Args:
        sent: the sentence text.
        labels: one gold label per character of ``sent``.
        sent_delims: characters always split off as their own spans. Defaults to ['%'].
        tokenizer: optional tokenizer; when given, labels are not used.
        ud_format: whether labels follow the UD tag set (True) or the SPMRL one (False).

    Returns:
        list of (start, end) character offsets (end exclusive), in sentence order.
    """
    if tokenizer is not None:
        return get_original_words_from_tokens(sent, tokenizer)
    # Avoid a shared mutable default argument.
    if sent_delims is None:
        sent_delims = ['%']

    spans = []
    start = 0
    for ind, (ch, label) in enumerate(zip(sent, labels)):
        if label == 'VOID':
            if start != ind:
                spans.append((start, ind))
            start = ind + 1
        elif (ud_format and label == 'PUNCT') or (not ud_format and label.startswith("yy")) or ch in sent_delims:
            if start != ind:
                spans.append((start, ind))
            spans.append((ind, ind + 1))
            start = ind + 1
    # Whatever remains after the last separator is the final word.
    if start < len(sent):
        spans.append((start, len(sent)))
    return spans


def get_original_words_from_tokens(sent, tokenizer):
    """
    Return (start, end) character offsets of the words the tokenizer finds in ``sent``.

    Consecutive tokens sharing a word id are merged into one span, running from the first
    token's start offset to the last token's end offset.

    NOTE(review): assumes ``tokenizer(...)[0]`` yields an encoding object exposing
    ``word_ids`` and ``offsets`` (as HuggingFace fast tokenizers do) — confirm for the
    tokenizers used with this function.

    Returns:
        list of (start, end) tuples; an empty list when ``sent`` produces no tokens
        (previously a bogus ``(None, None)`` span was returned).
    """
    encoding = tokenizer(sent, add_special_tokens=False, return_offsets_mapping=True)[0]
    word_offsets = []
    previous_word_id = None
    start, end = None, None
    for word_id, (tok_start, tok_end) in zip(encoding.word_ids, encoding.offsets):
        if word_id != previous_word_id:
            # A new word starts: close the previous one, if any.
            if start is not None:
                word_offsets.append((start, end))
            start, end = tok_start, tok_end
        else:
            # Continuation subword: extend the current word.
            end = tok_end
        previous_word_id = word_id
    if start is not None:
        word_offsets.append((start, end))
    return word_offsets


def gather_predictions_from_subwords(pos_preds, word_spans, majority=False):
    """
    Collapse subword-level pipeline predictions into one prediction per word span.

    ``pos_preds`` is a list of dicts with at least 'start', 'end' and 'entity' keys (as in the
    sample pipeline output in this notebook); ``word_spans`` is a list of (start, end) tuples.

    A prediction whose (start, end) equals the word span is used directly. For a word split
    into several subwords, the result is the entity of its first subword by default, or the
    most common entity among its subwords when ``majority`` is True (ties go to the earliest).

    NOTE(review): the first-subword entity ``e`` is not reset between words, so a word whose
    start was never matched reuses the previous word's first-subword entity, and a stale
    'PUNCT' entity makes the span be emitted as 'PUNCT' without consuming a prediction.
    This mirrors the original behavior and is kept as is.

    Args:
        pos_preds: subword predictions in sentence order.
        word_spans: word character spans in sentence order.
        majority: use a majority vote over a word's subwords instead of the first subword.
            Previously passed by calc_multiset_f1 but not accepted (TypeError).

    Returns:
        list of entity strings, one per word span that could be aligned.
    """
    res = []
    spans_c, pos_pred_c = 0, 0
    e = None          # entity of the first subword of the latest multi-subword word
    word_ents = []    # entities of the subwords seen so far for the current word
    while pos_pred_c < len(pos_preds) and spans_c < len(word_spans):
        pos_pred = pos_preds[pos_pred_c]
        span = word_spans[spans_c]
        if (pos_pred['start'], pos_pred['end']) == span:
            # Single-subword word.
            res.append(pos_pred['entity'])
            word_ents = []
            spans_c += 1
            pos_pred_c += 1
        elif pos_pred['start'] == span[0]:
            # First subword of a multi-subword word.
            e = pos_pred['entity']
            word_ents = [e]
            pos_pred_c += 1
        elif pos_pred['end'] == span[1]:
            # Last subword: emit the word's prediction.
            word_ents.append(pos_pred['entity'])
            if majority:
                res.append(Counter(word_ents).most_common(1)[0][0])
            else:
                # Fall back to this subword's entity if no first subword was ever seen
                # (previously an UnboundLocalError).
                res.append(e if e is not None else pos_pred['entity'])
            word_ents = []
            spans_c += 1
            pos_pred_c += 1
        elif e == 'PUNCT':
            res.append(e)
            word_ents = []
            spans_c += 1
        else:
            # Middle subword (or a prediction that does not align with the current span).
            word_ents.append(pos_pred['entity'])
            pos_pred_c += 1
    return res


def get_predictions_from_huggingface_output(pred_file):
    """
    Read a predictions file into a list of per-line label lists.

    Each line is stripped and split on single spaces, so every line (including a blank
    one, which yields ['']) produces exactly one entry.
    """
    with open(pred_file, 'r', encoding='utf-8') as handle:
        return [row.strip().split(' ') for row in handle]


def get_spans(wcs):
    """
    Turn per-character word ids into word spans.

    ``wcs`` holds one entry per character: a digit string naming the word that character
    belongs to, or a non-digit marker (e.g. '-') for characters outside any word. For each
    word id, the span runs from its first to one past its last occurrence.

    Returns:
        list of (start, end) pairs sorted by position.
    """
    wc_arr = np.array(wcs)
    word_ids = [w for w in np.unique(wcs) if w.isdigit()]
    found = []
    for w in word_ids:
        positions = np.where(wc_arr == w)[0]
        found.append((min(positions), max(positions) + 1))
    found.sort()
    return found

In [20]:
# Sanity check: collapse a sample subword-level pipeline output for one sentence into
# one prediction per gold word span.
preds = [
    {'word': 'הצג', 'score': 0.9994872212409973, 'entity': 'VERB', 'index': 1, 'start': 0, 'end': 3},
    {'word': '##נו', 'score': 0.9989884495735168, 'entity': 'VERB', 'index': 2, 'start': 3, 'end': 5},
    {'word': 'בפניה', 'score': 0.9901555180549622, 'entity': 'ADP+NOUN+ADP+PRON', 'index': 3, 'start': 6, 'end': 11},
    {'word': 'את', 'score': 0.9995697736740112, 'entity': 'ADP', 'index': 4, 'start': 12, 'end': 14},
    {'word': 'רעיון', 'score': 0.9996181130409241, 'entity': 'NOUN', 'index': 5, 'start': 15, 'end': 20},
    {'word': 'הכפ', 'score': 0.998690664768219, 'entity': 'DET+NOUN', 'index': 6, 'start': 21, 'end': 24},
    {'word': '##ל', 'score': 0.9961203932762146, 'entity': 'DET+NOUN', 'index': 7, 'start': 24, 'end': 25},
    {'word': 'בפעם', 'score': 0.9948251843452454, 'entity': 'ADP+DET+NOUN', 'index': 8, 'start': 26, 'end': 30},
    {'word': 'הראשונה', 'score': 0.9980702996253967, 'entity': 'DET+ADJ', 'index': 9, 'start': 31, 'end': 38},
    {'word': '.', 'score': 0.9997915029525757, 'entity': 'PUNCT', 'index': 10, 'start': 38, 'end': 39},
]

sent = 'הצגנו בפניה את רעיון הכפל בפעם הראשונה.'
# Per-character gold labels, written as (label, run length) pairs.
label_runs = [
    ('VERB', 5), ('VOID', 1), ('ADP+PRON', 5), ('VOID', 1),
    ('ADP', 2), ('VOID', 1), ('NOUN', 5), ('VOID', 1), ('DET+NOUN', 4), ('VOID', 1), ('ADP+DET+NOUN', 4),
    ('VOID', 1), ('DET+ADJ', 7), ('PUNCT', 1),
]
labels = [tag for tag, run in label_runs for _ in range(run)]
word_spans = split_by_spaces_and_punct(sent, labels)
print(word_spans)
gather_predictions_from_subwords(preds, word_spans)

[(0, 5), (6, 11), (12, 14), (15, 20), (21, 25), (26, 30), (31, 38), (38, 39)]


['VERB',
 'ADP+NOUN+ADP+PRON',
 'ADP',
 'NOUN',
 'DET+NOUN',
 'ADP+DET+NOUN',
 'DET+ADJ',
 'PUNCT']

In [None]:
def _merge_feature_sentence(tagged_sents, sent_id, sent, first_feature):
    """Merge one feature file's (token, tag) sentence into ``tagged_sents`` at index ``sent_id``."""
    if first_feature:
        tagged_sents.append([(t, (l,)) for t, l in sent])
    else:
        # Extend each token's tag tuple with this feature's tag.
        tagged_sents[sent_id] = [(t, l + (sent[j][1],)) for j, (t, l) in
                                 enumerate(tagged_sents[sent_id])]


def read_from_morph_tags_file(gold_dir, feats, split='test'):
    """
    Read per-feature gold morphology files and merge them into tagged sentences.

    For each feature ``f`` in ``feats``, ``<gold_dir>/<f>_<split>_segmented.txt`` is read.
    Each line is one of the following:
      * a digit-only line, which is skipped;
      * a blank line, which ends a sentence;
      * 'VOID' or 'X', which gives token ' ' and that tag;
      * ``<token> <tag>``.
    A final sentence that is not followed by a blank line is still included.

    Args:
        gold_dir: directory holding the feature files.
        feats: feature names; their order fixes the order of tags in each tuple.
        split: data split in the file names. Defaults to 'test', which was previously
            hard-coded; calc_multiset_f1 passes it explicitly.

    Returns:
        list of sentences, each a list of (token, tags) where ``tags`` is a tuple with one
        tag per feature in ``feats`` order.

    Raises:
        ValueError: a data line does not split into exactly two fields.
        IndexError: a later feature file has more sentences than the first one.
    """
    tagged_sents = []
    for i, f in enumerate(feats):
        # Per-file state, so a trailing unterminated sentence cannot leak into the next file.
        sent = []
        sent_id = 0
        with open(os.path.join(gold_dir, f'{f}_{split}_segmented.txt'), 'r', encoding='utf-8') as fobj:
            for line in fobj:
                line = line.strip()
                if line.isdigit():
                    continue
                if line == '':
                    _merge_feature_sentence(tagged_sents, sent_id, sent, first_feature=(i == 0))
                    sent = []
                    sent_id += 1
                    continue
                if line == 'VOID' or line == 'X':
                    tok = ' '
                    tag = line
                else:
                    tok, tag = line.split()
                sent.append((tok, tag))
        # Flush a last sentence not followed by a blank line.
        if sent:
            _merge_feature_sentence(tagged_sents, sent_id, sent, first_feature=(i == 0))

    return tagged_sents

In [None]:
def aligned_f1_score(fn_total, fp_total, tp_total):
    """
    Return the F1 score, scaled to 0-100, from aggregate counts.

    Args:
        fn_total: number of false negatives.
        fp_total: number of false positives.
        tp_total: number of true positives.

    Returns:
        100 * harmonic mean of precision and recall. When there are no true positives,
        precision/recall are zero or undefined, so 0.0 is returned instead of raising
        ZeroDivisionError.
    """
    if tp_total == 0:
        return 0.0
    precision = tp_total / (tp_total + fp_total)
    recall = tp_total / (tp_total + fn_total)
    aligned_mset_f1 = 2 * precision * recall / (precision + recall)
    return 100 * aligned_mset_f1

In [None]:
def calc_multiset_f1(checkpoint_path=None, huggingface_checkpoint=False, predictions_path=None, split="dev",
                     pred_format="segmented", gold_format="segmented", use_majority_vote=False, use_prediction_spans=False,
                     use_data_from_dir=True, output_file=None, eval_morph=True, gold_data_dir=None, ud_format=True,
                    verbose=False):
    """The main function for evaluating mset-F1 scores for POS and morphological features.

    Gold POS labels are read from ``<gold_data_dir>/pos_<split>_<gold_format>_{ud|spmrl}.txt``.
    That file contains:
      * ``<token> <label>`` lines;
      * sentence-id lines, numeric or any line that does not parse as a token/label pair;
      * blank lines that end a sentence.
    Labels are indexed by character offset into the concatenated sentence.

    Predictions come from exactly one source:
      * ``huggingface_checkpoint=True``: a token-classification pipeline loaded from the
        directory of ``checkpoint_path``. Its subword predictions are collapsed per word.
      * ``predictions_path``: a pre-computed predictions file with one line per sentence.
      * Otherwise: a fairseq RoBERTa checkpoint at ``checkpoint_path``.

    For each word, predicted and gold tags are split on '+' into multisets. True positives,
    false positives and false negatives are accumulated per evaluation set: 'pos', plus
    'morph' and 'all' when ``eval_morph`` is True.

    Args:
        checkpoint_path: model checkpoint file. Its parent directory is the default ``gold_data_dir``.
        huggingface_checkpoint: load ``checkpoint_path``'s directory as a HuggingFace model.
        predictions_path: pre-computed predictions file, used when no model is loaded.
        split: split name used in gold file names (e.g. "dev", "test").
        pred_format, gold_format: "segmented" or "multitag". Segmented predictions are merged
            into multi-tags when evaluated against multitag gold.
        use_majority_vote: merge a word's predictions by majority vote.
        use_prediction_spans: merge a word's predictions with convert_by_pred_spans.
        use_data_from_dir: forwarded to ``RobertaModel.from_pretrained`` as ``override_data_key``.
        output_file: if given, write a results summary there and the error records to
            ``<output_file>.json``.
        eval_morph: also evaluate morphological features. Requires a fairseq checkpoint.
        gold_data_dir: directory with the gold files.
        ud_format: use the UD tag/feature set (True) or the SPMRL one (False).
        verbose: print each word whose POS multiset is wrong and collect it into the error records.

    Returns:
        dict mapping each evaluation set name to its aligned mset-F1 score (0-100).

    Raises:
        ValueError: ``eval_morph`` is requested with a source that produces no morphological
            predictions, or no gold directory can be determined.
    """
    feats = ['abbr',
             'case',
             'definite',
             'gender',
             'hebbinyan',
             'hebexistential',
             'hebsource',
             'mood',
             'number',
             'person',
             'polarity',
             'prefix',
             'prontype',
             'reflex',
             'tense',
             'verbform',
             'verbtype',
             'voice',
             'xtra'] if ud_format else ['feats_gen',
                                        'feats_HebBinyan',
                                        'feats_num',
                                        'feats_per',
                                        'feats_polar',
                                        'feats_suf_gen',
                                        'feats_suf_num',
                                        'feats_suf_per',
                                        'feats_tense']

    # Morphological predictions are only produced by the fairseq path (roberta.predict_tags per feature).
    # Fail early instead of raising a NameError on morph_preds deep inside the loop.
    if eval_morph and (huggingface_checkpoint or predictions_path is not None):
        raise ValueError("eval_morph=True requires a fairseq checkpoint; the HuggingFace and predictions-file "
                         "sources do not produce morphological predictions.")

    if checkpoint_path is not None:
        print("Loading model from checkpoint {} ...".format(checkpoint_path))
    if huggingface_checkpoint:
        model_dir = str(Path(checkpoint_path).parent)
        with open(os.path.join(model_dir, 'ids_to_labels.json'), 'r', encoding='utf-8') as fobj:
            id2label = {int(i): l for i, l in json.load(fobj).items()}
        label2id = {l: i for i, l in id2label.items()}
        config = AutoConfig.from_pretrained(model_dir, label2id=label2id, id2label=id2label)
        model = AutoModelForTokenClassification.from_pretrained(model_dir, config=config)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        roberta = pipeline(task='ner', model=model, tokenizer=tokenizer)
    elif predictions_path is None:
        roberta = RobertaModel.from_pretrained(str(Path(checkpoint_path).parent),
                                               checkpoint_file=os.path.basename(checkpoint_path),
                                               override_data_key=use_data_from_dir)
        roberta.eval()
    else:
        predictions = get_predictions_from_huggingface_output(predictions_path)

    # Word spans are derived from the gold labels, not from the model tokenizer.
    tokenizer = None
    if gold_data_dir is None:
        if checkpoint_path is None:
            raise ValueError("gold_data_dir must be given when checkpoint_path is None.")
        gold_data_dir = str(Path(checkpoint_path).parent)

    gold_pos_path = os.path.join(gold_data_dir, "pos_{}_{}{}.txt".format(split, gold_format,
                                                                         "_spmrl" if not ud_format else "_ud"))
    if eval_morph:
        morph_tagged_sents = read_from_morph_tags_file(gold_data_dir, feats=feats, split=split)
        sent_lens = [len(s) for s in morph_tagged_sents]

    print("Evaluating msetF1 for {} results in path:".format(split), gold_pos_path)
    tp_total, fp_total, fn_total = {}, {}, {}
    eval_sets = ['pos', 'morph', 'all'] if eval_morph else ['pos']
    for d in [tp_total, fp_total, fn_total]:
        for eval_set in eval_sets:
            d[eval_set] = 0

    error_stats = []
    multitag_sep_char = '+'
    with open(gold_pos_path, 'r', encoding='utf-8') as fin_pos:
        sent = ''
        sent_id = 0
        sent_count = 0
        token_ind_in_sent = 0
        gold_pos = []
        n_errors = 0
        n_words = 0
        for pos_line in tqdm(fin_pos):
            pos_line = pos_line.strip()
            # Lines that parse as integers are sentence ids.
            try:
                sent_id = int(pos_line)
                continue
            except ValueError:
                pass

            if pos_line != '':
                try:
                    pos_label, token = parse_file_line(pos_line)
                except ValueError:
                    # Not a "<token> <label>" line: treat it as a (non-numeric) sentence id.
                    sent_id = pos_line
                    continue
                token_ind_in_sent += 1
                sent += token
                gold_pos.append(pos_label)
                continue

            # A blank line ends the sentence: predict, align to words and score it.
            if huggingface_checkpoint:
                pos_preds = roberta(sent)
            elif predictions_path is None:
                tokens = roberta.encode(sent)
                if eval_morph:
                    pos_preds = roberta.predict_tags('pos' if ud_format else 'upostag', tokens)
                    morph_preds = list(zip(*[roberta.predict_tags(feat, tokens) for feat in feats]))
                    assert len(pos_preds) == sent_lens[
                        sent_count], f"Error! Length of POS preds ({len(pos_preds)}) does not match length of " \
                                     f"morph preds ({sent_lens[sent_count]})."
                    assert len(morph_preds) == sent_lens[
                        sent_count], f"Error! Length of morph preds ({len(morph_preds)}) does not match length of " \
                                     f"gold morph labels ({sent_lens[sent_count]})."
                else:
                    pos_preds = roberta.predict_tags('postagging', tokens)

            word_spans = split_by_spaces_and_punct(sent, gold_pos, tokenizer=tokenizer, ud_format=ud_format,
                                                   sent_delims=['%', ':', ';'])
            n_words += len(word_spans)
            if huggingface_checkpoint:
                # One prediction per word span (first subword, or majority when use_majority_vote).
                pos_preds = gather_predictions_from_subwords(pos_preds, word_spans, majority=use_majority_vote)
            elif predictions_path is not None:
                pos_preds = predictions[sent_count]

            if huggingface_checkpoint:
                assert len(word_spans) == len(pos_preds), f"Error! Number of word spans: {len(word_spans)}, " \
                                                          f"number of POS predictions: {len(pos_preds)}, " \
                                                          f"sentence is: {sent}"
            for ind, (w_start, w_end) in enumerate(word_spans):
                for g in [gold_pos]:
                    assert "VOID" not in g[w_start:w_end], \
                        "Error! Found a VOID tag in gold labels for the word {}".format(sent[w_start:w_end])

                if pred_format == "segmented" and gold_format == "multitag":
                    # Convert the input format into the multi-tag format by splitting all united tags
                    if w_end - w_start == 1:
                        union_pos_pred = pos_preds[w_start:w_end][0]
                    elif use_prediction_spans:
                        union_pos_pred = convert_by_pred_spans(pos_preds[w_start:w_end])
                    else:
                        union_pos_pred = convert_segmented_preds_to_multitags(pos_preds[w_start:w_end],
                                                                              use_majority_vote=use_majority_vote)
                    pos_pred_tokens = union_pos_pred.split(multitag_sep_char)
                else:  # multitag -> multitag
                    # HuggingFace / predictions-file preds are per word (index `ind`);
                    # fairseq preds are per character (index by offsets).
                    if w_end - w_start == 1:
                        if predictions_path is not None:
                            pos_pred_tokens = [pos_preds[ind]]
                        elif not huggingface_checkpoint:
                            pos_pred_tokens = pos_preds[w_start:w_end]
                        elif huggingface_checkpoint:
                            pos_pred_tokens = [pos_preds[ind]]
                    elif use_prediction_spans:
                        union_pos_pred = convert_by_pred_spans(pos_preds[w_start:w_end])
                        pos_pred_tokens = union_pos_pred.split(multitag_sep_char)
                    elif use_majority_vote:
                        if huggingface_checkpoint:
                            pos_pred_tokens = pos_preds[ind].split(multitag_sep_char)
                        else:
                            union_pos_pred = majority_vote(pos_preds[w_start:w_end])
                            pos_pred_tokens = union_pos_pred.split(multitag_sep_char)
                    else:  # prediction by first token
                        if predictions_path is not None or huggingface_checkpoint:
                            pos_pred_tokens = pos_preds[ind].split(multitag_sep_char)
                        else:
                            pos_pred_tokens = pos_preds[w_start].split(multitag_sep_char)

                pos_gold_tokens = convert_by_pred_spans(gold_pos[w_start:w_end]).split(multitag_sep_char)

                if eval_morph:
                    morph_pred_tokens = []
                    morph_gold_tokens = []
                    for i in range(len(feats)):
                        if use_prediction_spans:
                            pred_feat = convert_by_pred_spans([p[i] for p in morph_preds[w_start:w_end]])
                            gold_feat = convert_by_pred_spans(
                                [p[1][i] for p in morph_tagged_sents[sent_count][w_start:w_end]])
                        else:
                            pred_feat = majority_vote([p[i] for p in morph_preds[w_start:w_end]])
                            gold_feat = majority_vote(
                                [p[1][i] for p in morph_tagged_sents[sent_count][w_start:w_end]])
                        # 'X' marks an absent feature value; drop it (and empty parts) from the multisets.
                        if pred_feat != 'X':
                            pred_feat = [p for p in pred_feat.split(multitag_sep_char) if p != 'X']
                        else:
                            pred_feat = []

                        if gold_feat != 'X':
                            gold_feat = [p for p in gold_feat.split(multitag_sep_char) if p != 'X']
                        else:
                            gold_feat = []

                        pred_feat = [p for p in pred_feat if p != '']
                        gold_feat = [g for g in gold_feat if g != '']
                        morph_pred_tokens.extend(pred_feat)
                        morph_gold_tokens.extend(gold_feat)

                    all_pred_tokens = pos_pred_tokens + morph_pred_tokens

                    for g in [gold_pos]:
                        assert len(set(g[w_start:w_end])) == 1, \
                            "Error! Not all gold lables are identical for the word {}: {}".format(
                                sent[w_start:w_end], g[w_start:w_end])

                    all_gold_tokens = pos_gold_tokens + morph_gold_tokens
                else:
                    all_pred_tokens = pos_pred_tokens
                    all_gold_tokens = pos_gold_tokens

                if verbose and sorted(pos_pred_tokens) != sorted(pos_gold_tokens):
                    n_errors += 1
                    print("========================")
                    print("Wrong prediction in sentence id:", sent_id)
                    print("Sentence:", sent)
                    print("Word:", sent[w_start:w_end])
                    if huggingface_checkpoint:
                        print("Pred POS labels:", pos_preds[ind])
                    else:
                        print("Pred POS labels:", pos_preds[w_start:w_end])
                    print("Gold POS labels:", gold_pos[w_start:w_end])

                    error_stats.append({"sent_id": sent_id,
                                        "sent_text": sent,
                                        "word": sent[w_start:w_end],
                                        "pred_pos": pos_preds[
                                            ind] if huggingface_checkpoint else pos_preds[w_start:w_end],
                                        "gold_pos": gold_pos[w_start:w_end],
                                        "pred_mset": all_pred_tokens if eval_morph else pos_pred_tokens,
                                        "gold_mset": all_gold_tokens if eval_morph else pos_gold_tokens})
                    if eval_morph:
                        print("Separated pred labels (All):", all_pred_tokens)
                        print("Separated gold labels (All):", all_gold_tokens)

                    print("Separated pred labels (POS):", pos_pred_tokens)
                    print("Separated gold labels (POS):", pos_gold_tokens)

                for eval_set in eval_sets:
                    if eval_set == 'pos':
                        pred_toks = pos_pred_tokens
                        gold_toks = pos_gold_tokens
                    elif eval_set == 'morph':
                        pred_toks = morph_pred_tokens
                        gold_toks = morph_gold_tokens
                    else:
                        pred_toks = all_pred_tokens
                        gold_toks = all_gold_tokens
                    tp_token = multiset.Multiset(pred_toks).intersection(multiset.Multiset(gold_toks))
                    fp_token = multiset.Multiset(pred_toks).difference(multiset.Multiset(gold_toks))
                    fn_token = multiset.Multiset(gold_toks).difference(multiset.Multiset(pred_toks))
                    tp_total[eval_set] += len(tp_token)
                    fp_total[eval_set] += len(fp_token)
                    fn_total[eval_set] += len(fn_token)
            sent = ''
            token_ind_in_sent = 0
            gold_pos = []
            sent_count += 1

    aligned_mset_f1 = {}
    for eval_set in eval_sets:
        aligned_mset_f1[eval_set] = aligned_f1_score(fn_total[eval_set],
                                                     fp_total[eval_set],
                                                     tp_total[eval_set])

    if output_file is not None:
        logging_str = "============ Results ============\n"
        logging_str += "Checkpoint path: {}\n".format(checkpoint_path)
        logging_str += "Prediction format: {}\n".format(pred_format)
        logging_str += "Gold format: {}\n".format(gold_format)
        logging_str += "Heuristics used: Majority={}, Prediction Spans={}\n".format(use_majority_vote,
                                                                                    use_prediction_spans)
        for eval_set in eval_sets:
            logging_str += "Aligned MSET-F1 on {} for {}: {:.2f}\n".format(split, eval_set, aligned_mset_f1[eval_set])
        # Close the output files deterministically (previously left open).
        with open(output_file, 'w', encoding='utf-8') as fobj:
            fobj.write(logging_str)
        with open('{}.json'.format(output_file), 'w', encoding='utf-8') as fobj:
            fobj.write(json.dumps(error_stats, indent=4, ensure_ascii=False))

    return aligned_mset_f1

In [None]:
# POS-only mset-F1 (eval_morph=False) of a pre-computed predictions file on the test split.
pos_eval_kwargs = dict(
    # checkpoint_path=r"checkpoints\ud_pos_segmented_new_tavbert_ar_lr_0.0001_bsz_32_5epochs\checkpoint5.pt",
    predictions_path=r"checkpoints\tavbert_base_he_pos_lr_0.0001_bsz_16_5epochs\predictions.txt",
    split="test",
    pred_format="segmented",
    gold_format="multitag",
    use_majority_vote=False,
    use_prediction_spans=True,
    use_data_from_dir=True,
    output_file="postagging_only_tavbert_base_he_segmented_new.txt",
    eval_morph=False,
    gold_data_dir=r"finetuning\pos\data",
    ud_format=True,
    verbose=True,
)
calc_multiset_f1(**pos_eval_kwargs)

In [None]:
def calc_ner_f1(checkpoint_path, predictions_path=None, split="dev", use_majority_vote=False,
                use_prediction_spans=False,
                use_data_from_dir=True, output_file=None, gold_data_dir=None, set_id=1, scheme=IOBES, f1_type="micro",
               ner_format='single'):
    """
    Evaluate NER predictions against ``<gold_data_dir>/token-single_gold_<split>.bmes_word_counts``.

    Each non-blank gold line is parsed with ``parse_file_line(..., with_word_counts=True)`` into a
    per-character label and word id. Blank lines end a sentence, whose word spans come from
    ``get_spans``. Predictions come from a fairseq checkpoint (``predictions_path is None``) or from a
    predictions file with one whitespace-separated, per-word line per sentence.

    The first tag of each word's prediction and gold multi-tag forms the seqeval sequences. F1 is
    computed with ``average=f1_type``; previously ``f1_type`` was ignored even though it labels the
    logged score. A strict classification report using ``scheme`` is printed.

    Args:
        checkpoint_path: model checkpoint file. Its parent directory is the default ``gold_data_dir``.
        predictions_path: optional pre-computed predictions file.
        split: split name in the gold file name.
        use_majority_vote, use_prediction_spans: how per-character fairseq predictions are merged per word.
        use_data_from_dir: forwarded to ``RobertaModel.from_pretrained`` as ``override_data_key``.
        output_file: if given, write a summary there and error records to ``<output_file>.json``.
            Also write per-word "<tok> <label>" predictions next to the checkpoint or predictions directory.
        gold_data_dir: directory of the gold file.
        set_id: only used in the progress message.
        scheme: seqeval tagging scheme for the classification report.
        f1_type: seqeval averaging for the returned F1 ("micro", "macro", "weighted").
        ner_format: currently unused.

    Returns:
        (ner_f1, gold_words), where gold_words holds, for each sentence, the list of gold word strings.
    """
    print("Loading model from checkpoint {} ...".format(checkpoint_path))
    if predictions_path is None:
        roberta = RobertaModel.from_pretrained(str(Path(checkpoint_path).parent),
                                               checkpoint_file=os.path.basename(checkpoint_path),
                                               override_data_key=use_data_from_dir)
        pprint.pprint(roberta.cfg)
        roberta.eval()
    else:
        with open(predictions_path, 'r', encoding='utf-8') as fobj:
            predictions = [p.strip().split() for p in fobj]

    if gold_data_dir is None:
        gold_data_dir = str(Path(checkpoint_path).parent)

    gold_ner_path_wc = os.path.join(gold_data_dir, f"token-single_gold_{split}.bmes_word_counts")
    print(f"Evaluating F1 for {split} on set {set_id}, with results in path: {gold_ner_path_wc}")
    y_true, y_pred, preds_tok_label = [], [], []
    error_stats = []
    gold_words = []
    with open(gold_ner_path_wc, 'r', encoding='utf-8') as fin_ner:
        sent = ''
        sent_count = 0
        gold_pos = []
        gold_wcs = []
        y_pred_per_sent = []
        y_true_per_sent = []
        preds_tok_label_per_sent = []
        n_errors = 0
        n_words = 0
        for pos_line in fin_ner:
            pos_line = pos_line.strip()
            sent_id = sent_count
            if pos_line == '':
                # A blank line ends the sentence: predict, align to words and score it.
                if sent_count % 10 == 0:
                    print("Processed {} sentences".format(sent_count))
                word_spans = get_spans(gold_wcs)
                n_words += len(word_spans)
                gold_words.append([sent[w_start:w_end] for w_start, w_end in word_spans])
                if predictions_path is None:
                    tokens = roberta.encode(sent)
                    ner_preds = roberta.predict_tags('ner', tokens)
                else:
                    ner_preds = predictions[sent_count]
                    assert len(ner_preds) == len(word_spans), f'{len(ner_preds)}, {len(word_spans)}'
                for ind, (w_start, w_end) in enumerate(word_spans):
                    for g in [gold_pos]:
                        assert "VOID" not in g[w_start:w_end], "Error! Found a VOID tag in gold labels for the word {}".format(sent[w_start:w_end])
                        assert len(set(g[
                                       w_start:w_end])) == 1, "Error! Gold labels for word {} contain more than one unique tag: {}".format(
                            sent[w_start:w_end], g[w_start:w_end])
                    if predictions_path is not None:
                        # File predictions are already per word.
                        all_pred_tokens = ner_preds[ind].split('+')
                    else:
                        if use_prediction_spans:
                            union_pos_pred = convert_by_pred_spans(ner_preds[w_start:w_end])
                            all_pred_tokens = union_pos_pred.split('+')
                        elif use_majority_vote:
                            union_pos_pred = majority_vote(ner_preds[w_start:w_end])
                            all_pred_tokens = union_pos_pred.split('+')
                        else:  # prediction by first token
                            all_pred_tokens = ner_preds[w_start].split('+')

                    all_gold_tokens = convert_by_pred_spans(gold_pos[w_start:w_end]).split('+')

                    if sorted(all_pred_tokens) != sorted(all_gold_tokens):
                        n_errors += 1
                        print("========================")
                        print("Wrong prediction in sentence id:", sent_id)
                        print("Sentence:", sent)
                        print("Word:", sent[w_start:w_end])
                        print("Pred labels:", ner_preds[w_start:w_end] if predictions_path is None else ner_preds[ind])
                        print("Gold labels:", gold_pos[w_start:w_end])

                        error_stats.append({"sent_id": sent_id,
                                            "sent_text": sent,
                                            "word": sent[w_start:w_end],
                                            "pred_pos": ner_preds[w_start:w_end] if predictions_path is None else ner_preds[ind],
                                            "gold_pos": gold_pos[w_start:w_end],
                                            "pred_mset": all_pred_tokens,
                                            "gold_mset": all_gold_tokens})

                        print("Separated pred labels:", all_pred_tokens)
                        print("Separated gold labels:", all_gold_tokens)

                    y_true_per_sent.append(all_gold_tokens[0])
                    y_pred_per_sent.append(all_pred_tokens[0])
                    preds_tok_label_per_sent.append((sent[w_start:w_end], all_pred_tokens[0]))
                sent = ''
                gold_pos = []
                gold_wcs = []
                y_true.append(y_true_per_sent)
                y_pred.append(y_pred_per_sent)
                preds_tok_label.append(preds_tok_label_per_sent)
                y_true_per_sent = []
                y_pred_per_sent = []
                preds_tok_label_per_sent = []
                sent_count += 1
                continue

            pos_label, token, wc = parse_file_line(pos_line, with_word_counts=True)
            sent += token

            gold_pos.append(pos_label)
            gold_wcs.append(wc)

    # Honor f1_type (the default "micro" matches seqeval's default averaging).
    ner_f1 = f1_score(y_true, y_pred, average=f1_type)
    print(classification_report(y_true, y_pred, digits=4, mode='strict', scheme=scheme))

    if output_file is not None:
        logging_str = "============ Results ============\n"
        logging_str += "Checkpoint path: {}\n".format(checkpoint_path)
        logging_str += "Heuristics used: Majority={}, Prediction Spans={}\n".format(use_majority_vote,
                                                                                    use_prediction_spans)
        logging_str += "{} F1 on {}: {:.4f}\n".format(f1_type, split, ner_f1)
        with open(output_file, 'w', encoding='utf-8') as fobj:
            fobj.write(logging_str)
        with open('{}.json'.format(output_file), 'w', encoding='utf-8') as fobj:
            fobj.write(json.dumps(error_stats, indent=4, ensure_ascii=False))
        with open('{}_for_ner_eval.txt'.format(str(Path(checkpoint_path).parent) if predictions_path is None else str(Path(predictions_path).parent)),
                  'w', encoding='utf-8') as f:
            for sent_preds in preds_tok_label:
                for tok, label in sent_preds:
                    f.write(f'{tok} {label}\n')
                f.write('\n')

    return ner_f1, gold_words

In [None]:
# Evaluate the fairseq NER checkpoint on the dev split and keep the per-sentence gold words.
dev_eval_kwargs = dict(
    checkpoint_path=r"checkpoints\punct_corrected_nemo_multi_token_base_he_from_last_lr_5e-05_bsz_32_5epochs\checkpoint5.pt",
    split="dev",
    use_majority_vote=False,
    use_prediction_spans=False,
    use_data_from_dir=True,
    output_file=r"ner_roberta_base_he_nemo_token_multi_test_first_token_dev.txt",
    gold_data_dir=r"finetuning_data\ner\data\raw_data_he\nemo",
)
_, dev_gold_words = calc_ner_f1(**dev_eval_kwargs)

In [None]:
# gold_words is local to calc_ner_f1; the dev-split result was bound to dev_gold_words in the cell above.
dev_gold_words

In [None]:
# Evaluate an NER checkpoint on the test split (majority vote over per-character predictions).
test_eval_kwargs = dict(
    checkpoint_path=r"checkpoints\roberta_base_lambda_5_oscar_3e-04\finetune_ner\nemo_ner_base_he_from_last_lr_1e-04_bsz_32_5epochs\checkpoint5.pt",
    # Alternative checkpoints:
    # checkpoint_path=r"checkpoints\ner_base_ar_from_last_lr_5e-05_bsz_32_5epochs\checkpoint5.pt",
    # checkpoint_path=r"checkpoints\punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_929721_15epochs\checkpoint_best.pt",
    # checkpoint_path=r"checkpoints\punct_corrected_nemo_token_single_base_he_from_last_lr_5e-05_bsz_32_seed_42298_15epochs\checkpoint_best.pt",
    # checkpoint_path=r"checkpoints\punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_5023924_15epochs\checkpoint_best.pt",
    # checkpoint_path=r"checkpoints\punct_corrected_nemo_token_single_base_he_from_last_lr_5e-05_bsz_32_seed_903443_15epochs\checkpoint_best.pt",
    # checkpoint_path=r"checkpoints\punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_32763_15epochs\checkpoint_best.pt",
    # checkpoint_path=r"checkpoints\punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_569021_15epochs\checkpoint_best.pt",
    # checkpoint_path=r"checkpoints\punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_74212987_15epochs\checkpoint_best.pt",
    # checkpoint_path=r"checkpoints\ner_base_ar_from_last_lr_3e-05_bsz_16_5epochs\checkpoint5.pt",
    split="test",
    use_majority_vote=True,
    use_prediction_spans=False,
    use_data_from_dir=True,
    gold_data_dir=r"finetuning_data\ner\raw_data_he\nemo",
    ner_format='single',
    scheme=IOB2,
    output_file=r'arabertv01_ner_lr_3e-05_bsz_64_5epochs_test.txt',
    f1_type="macro",
)
calc_ner_f1(**test_eval_kwargs)

In [None]:
# AlephBERT prediction files (raw HuggingFace output) still waiting to be
# converted; every run is currently disabled, so the list is empty.
alephbert_pred_files = [
    # r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_929721_15epochs\predictions.txt",
    # r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_5023924_15epochs\predictions.txt",
    # r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_929721_15epochs\dev_predictions.txt",
    # r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_5023924_15epochs\dev_predictions.txt",
]

# Prediction files that are already in "<word> <label>" evaluation format.
# Despite the variable name, the first two entries are AlephBERT runs.
# File names ending in "dev.txt" hold dev-split predictions.
tavbert_pred_files = [
    r"alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_929721_15epochs_for_ner_eval.txt",
    r"alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_5023924_15epochs_for_ner_eval.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_32763_15epochs_for_ner_eval.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_929721_15epochs_for_ner_eval.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_5023924_15epochs_for_ner_eval.txt",
    # r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_42298_15epochs_dev.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_32763_15epochs_for_ner_eval_dev.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_569021_15epochs_for_ner_eval_test.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_569021_15epochs_for_ner_eval_dev.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_5023924_15epochs_for_ner_eval_dev.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_929721_15epochs_for_ner_dev.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_74212987_15epochs_for_ner_eval.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_3e-05_bsz_16_seed_74212987_15epochs_for_ner_eval_dev.txt",
]

# Gold NEMO token-single annotations: test split, then dev split.
gold_path = r"finetuning_data\ner\raw_data_he\nemo\raw_tokens\token-single_gold_test.bmes"
dev_path = r"finetuning_data\ner\raw_data_he\nemo\raw_tokens\token-single_gold_dev.bmes"
import pandas as pd
import re

# Results table per model; filled in as seed -> split -> F1 (in percent).
f1_scores = {model: {} for model in ('alephbert', 'tavbert')}

# Extracts the learning rate, batch size and seed encoded in a run/file name.
pat = re.compile(r'^.*_lr_(?P<lr>.*)_bsz_(?P<bsz>.*)_seed_(?P<seed>[0-9]+).*$')

def get_predictions_from_huggingface_output(pred_file):
    """Read a HuggingFace token-classification predictions file.

    Each line is stripped of surrounding whitespace and split on single
    spaces, giving one list of label strings per line. An empty line
    therefore yields ``['']``.
    """
    with open(pred_file, 'r', encoding='utf-8') as handle:
        return [row.strip().split(' ') for row in handle]

# `ne_evaluate_mentions` is otherwise imported only in a later cell; import it
# here so this cell also works when the notebook is run top to bottom.
from NEMO import ne_evaluate_mentions

# Score every prediction file exactly once. Each file is compared against both
# the dev and the test gold files. A return of -1 from `evaluate_files` is
# treated as "the prediction file does not match this gold file", which is how
# the file's split is inferred.
for tav_p in tavbert_pred_files:
    print(tav_p)
    match = pat.match(tav_p)
    if match is None:
        raise ValueError(f"Cannot parse lr/bsz/seed from prediction file name: {tav_p}")
    hparams = match.groupdict()
    # Sanity check: only runs with the same hyper-parameters are compared.
    assert hparams['lr'] == '3e-05', tav_p
    assert int(hparams['bsz']) == 16, tav_p
    seed = int(hparams['seed'])

    _, _, f1_s_a = ne_evaluate_mentions.evaluate_files(gold_path=dev_path, pred_path=tav_p, verbose=False)
    _, _, f1_s_b = ne_evaluate_mentions.evaluate_files(gold_path=gold_path, pred_path=tav_p, verbose=False)
    print("dev:", f1_s_a, "test:", f1_s_b)
    if f1_s_a == -1 and f1_s_b == -1:
        # Neither gold file matched; do not record a bogus -100 score.
        print(f"WARNING: {tav_p} matches neither the dev nor the test gold file; skipping")
        continue
    f1_s = f1_s_a if f1_s_a != -1 else f1_s_b
    split = 'dev' if f1_s_a != -1 else 'test'

    # AlephBERT runs are listed alongside TavBERT runs and share some seeds
    # (e.g. 929721, 5023924). File them under their own model so they do not
    # overwrite each other.
    model = 'alephbert' if os.path.basename(tav_p).startswith('alephbert') else 'tavbert'
    f1_scores[model].setdefault(seed, {})[split] = 100 * f1_s
    pprint.pprint(f1_scores)

print("===================================================")
pprint.pprint(f1_scores)

In [None]:
# Load the predictions of a single AlephBERT run. Each line of the file holds
# the space-separated labels of one sentence.
# Previously inspected runs:
#   r"checkpoints\alephbert_ner_nemo_token_multi_lr_5e-05_bsz_64_6epochs\test_predictions.txt"
#   r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_64_6epochs\test_predictions.txt"
#   r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_64_5epochs\predictions.txt"
#   r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_929721_15epochs\predictions.txt"
pred_file = r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_5023924_15epochs\predictions.txt"
preds = get_predictions_from_huggingface_output(pred_file=pred_file)

In [None]:
import pprint
from NEMO import ne_evaluate_mentions

# `gold_words` used below was produced by an earlier call such as:
# _, gold_words = calc_ner_f1(
#     checkpoint_path=r"checkpoints\punct_corrected_nemo_multi_token_base_he_from_last_lr_5e-05_bsz_32_5epochs\checkpoint5.pt",
#     split="test",
#     use_majority_vote=False,
#     use_prediction_spans=False,
#     use_data_from_dir=True,
#     output_file=r"ner_roberta_base_he_nemo_token_multi_test_first_token.txt",
#     gold_data_dir=r"finetuning_data\ner\raw_data_he\nemo")

# Gold NEMO token-single test annotations used for scoring.
gold_path = r"finetuning_data\ner\raw_data_he\nemo\raw_tokens\token-single_gold_test.bmes"

# EDIT THIS LIST: AlephBERT HuggingFace prediction files to convert.
# Other available seed:
#   r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_929721_15epochs\predictions.txt"
pred_files = [
    r"checkpoints\alephbert_ner_nemo_token_single_lr_3e-05_bsz_16_seed_5023924_15epochs\predictions.txt",
]
def get_predictions_from_huggingface_output(pred_file):
    """Return one list of label strings per line of a HuggingFace predictions file.

    Lines are stripped and split on a single space, so a blank line becomes
    ``['']``.
    """
    sentences = []
    with open(pred_file, encoding='utf-8') as fh:
        for raw in fh:
            sentences.append(raw.strip().split(' '))
    return sentences

# Convert each HuggingFace prediction file into the "<word> <label>" per line,
# blank-line-between-sentences layout written below, and save it under the
# name of its checkpoint directory.
# NOTE(review): `gold_words` is not defined in this cell. It presumably comes
# from an earlier `calc_ner_f1(...)` call and holds one list of words per
# sentence; confirm before running this cell on its own.
pred_paths = []
for pred_file in pred_files:
    preds = get_predictions_from_huggingface_output(pred_file)

    # Sentence counts and per-sentence token counts must agree. Otherwise the
    # zip() pairing below would silently truncate mismatched sentences.
    assert len(gold_words) == len(preds), f"{len(gold_words)}, {len(preds)}"
    assert all(len(g) == len(p) for g, p in zip(gold_words, preds)), pred_file

    pred_path = os.path.basename(str(Path(pred_file).parent))
    pred_paths.append(pred_path)
    print(pred_path)
    with open(pred_path, 'w', encoding='utf-8') as fobj:
        for words, labels in zip(gold_words, preds):
            for w, p in zip(words, labels):
                fobj.write(f'{w} {p}\n')
            fobj.write('\n')

# NOTE: the files scored below are a fixed list of TavBERT runs, not the
# `pred_paths` written above. Set `eval_paths = pred_paths` to score those instead.
eval_paths = [
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_5e-05_bsz_32_seed_42298_15epochs_for_ner_eval.txt",
    r"punct_corrected_nemo_token_single_base_he_from_last_lr_5e-05_bsz_32_seed_903443_15epochs_for_ner_eval.txt",
]

f1_scores = []
for pred_path in eval_paths:
    print(pred_path)
    _, _, f1_s = ne_evaluate_mentions.evaluate_files(gold_path=gold_path, pred_path=pred_path, verbose=True)
    f1_scores.append(f1_s)

f1_scores = np.array(f1_scores)
print(f1_scores)
# Mean +- spread across runs, in percent. np.std uses ddof=0
# (population standard deviation).
print(100 * np.mean(f1_scores), '+-', np.std(100 * f1_scores))