### Mention Identification (Gold Input)

In [30]:
import json
# Load the gold test documents and the model's NER predictions (one JSON object per line).
# Context managers guarantee both file handles are closed once the data has been read.
with open('/home/pbaweja/SciREX/scirex_dataset/release_data/test.jsonl') as gold_file:
    gold = [json.loads(line) for line in gold_file]
with open('/home/pbaweja/SciREX/test_outputs/ner_predictions.jsonl') as predictions_file:
    predicted = [json.loads(line) for line in predictions_file]

# Turn every mention into a hashable (m[0], m[1] - 1, m[2]) tuple so that mentions can be
# compared with set operations.
# NOTE(review): presumably (start, end, label) with the end made inclusive by the -1 -- confirm.
gold_ner = [[(m[0], m[1] - 1, m[2]) for m in doc['ner']] for doc in gold]
predicted_ner = [[(m[0], m[1] - 1, m[2]) for m in doc['ner']] for doc in predicted]

predicted_num = sum(len(mentions) for mentions in predicted_ner)
gold_num = sum(len(mentions) for mentions in gold_ner)

# Documents are paired by position, and zip() silently stops at the shorter list.
# NOTE(review): this assumes both files list the same documents in the same order -- confirm.
matched = sum(
    len(set(gold_mentions) & set(predicted_mentions))
    for gold_mentions, predicted_mentions in zip(gold_ner, predicted_ner)
)

# Guard every ratio so that empty predictions / empty gold / zero matches report 0.0
# instead of raising ZeroDivisionError.
p = matched / predicted_num if predicted_num else 0.0
r = matched / gold_num if gold_num else 0.0
f1 = 2 * p * r / (p + r) if (p + r) else 0.0

print("Mention Identification")
print(f"p = {p}")
print(f"r = {r}")
print(f"f1 = {f1}")

Mention Identification
p = 0.7236133951263268
r = 0.7572059916304901
f1 = 0.7400286669851888


### Coreference Evaluation

In [34]:
def span_match(span_1, span_2):
    """Return the intersection-over-union (IoU) of two ``(start, end)`` spans.

    Args:
        span_1: Two-element sequence that unpacks into ``(start, end)``.
        span_2: Two-element sequence that unpacks into ``(start, end)``.

    Returns:
        float: Length of the overlap divided by the length of the smallest
        interval covering both spans. Disjoint spans score ``0.0`` and the
        degenerate case where the covering interval has no length (for
        example two identical zero-length spans) also scores ``0.0``.
    """
    sa, ea = span_1
    sb, eb = span_2
    # For disjoint spans the raw difference is negative (it measures the gap
    # between them), so clamp it: an overlap can never be smaller than zero.
    intersection = max(0, min(ea, eb) - max(sa, sb))
    union = max(eb, ea) - min(sa, sb)
    if union <= 0:
        # Nothing to divide by; treat the spans as not overlapping.
        return 0.0
    return intersection / union

In [35]:
import pandas as pd
from sys import argv


def overlap_score(cluster_1, cluster_2):
    """Return the fraction of spans in ``cluster_1`` that are matched in ``cluster_2``.

    A span of ``cluster_1`` counts as matched when ``span_match`` against at
    least one span of ``cluster_2`` is strictly greater than 0.5.

    Args:
        cluster_1: Sequence of spans whose coverage is being measured.
        cluster_2: Sequence of spans to match against.

    Returns:
        float: ``matched / len(cluster_1)``; ``0.0`` when ``cluster_1`` is
        empty (an empty cluster has nothing that could be matched).
    """
    if len(cluster_1) == 0:
        # Avoid dividing by zero for an empty cluster.
        return 0.0

    # The generator lets any() stop at the first sufficiently overlapping span.
    matched = sum(
        1
        for s1 in cluster_1
        if any(span_match(s1, s2) > 0.5 for s2 in cluster_2)
    )
    return matched / len(cluster_1)


def compute_metrics(predicted_clusters, gold_clusters):
    """Compute cluster-level precision, recall and F1 for one document.

    A predicted cluster and a gold cluster are considered a match when
    ``overlap_score(predicted, gold)`` is strictly greater than 0.5.

    Args:
        predicted_clusters: Sequence of predicted clusters.
        gold_clusters: Sequence of gold clusters.

    Returns:
        dict: ``{"p": ..., "r": ..., "f1": ...}`` where ``p`` is the share of
        predicted clusters matching at least one gold cluster and ``r`` is the
        share of gold clusters matched by at least one predicted cluster.
        A ``1e-7`` term in every denominator keeps the ratios defined when a
        side is empty.
    """
    # Every (predicted index, gold index) pair that overlaps enough.
    pairs = [
        (pred_idx, gold_idx)
        for pred_idx, pred_cluster in enumerate(predicted_clusters)
        for gold_idx, gold_cluster in enumerate(gold_clusters)
        if overlap_score(pred_cluster, gold_cluster) > 0.5
    ]

    # A cluster is counted once no matter how many partners it matched.
    hit_predicted = {pred_idx for pred_idx, _ in pairs}
    hit_gold = {gold_idx for _, gold_idx in pairs}

    precision = len(hit_predicted) / (len(predicted_clusters) + 1e-7)
    recall = len(hit_gold) / (len(gold_clusters) + 1e-7)
    f1 = 2 * precision * recall / (precision + recall + 1e-7)

    return {"p": precision, "r": recall, "f1": f1}


def score_scirex_model(predictions, gold_data):
    """Print summary statistics of per-document cluster metrics.

    Args:
        predictions: Iterable of dicts with a ``"doc_id"`` key and a
            ``"clusters"`` mapping; the mapping's values are used as the
            predicted clusters.
        gold_data: Iterable of dicts with a ``"doc_id"`` key and a ``"coref"``
            mapping; the mapping's values are used as the gold clusters.

    Returns:
        None. The ``describe()`` summary of the per-document metrics is
        printed.

    Raises:
        KeyError: If a predicted ``doc_id`` is absent from ``gold_data``.
    """
    # Index the clusters of both sides by document id.
    gold_by_doc = {}
    for doc in gold_data:
        gold_by_doc[doc["doc_id"]] = list(doc["coref"].values())

    predicted_by_doc = {}
    for doc in predictions:
        predicted_by_doc[doc["doc_id"]] = list(doc["clusters"].values())

    # Only documents that have predictions are scored.
    per_document = [
        compute_metrics(clusters, gold_by_doc[doc_id])
        for doc_id, clusters in predicted_by_doc.items()
    ]

    summary = pd.DataFrame(per_document).describe()
    print(summary)

In [48]:
# Load the gold test documents and the predicted clusters (one JSON object per line).
# Context managers guarantee both file handles are closed once the data has been read.
with open('/home/pbaweja/SciREX/scirex_dataset/release_data/test.jsonl') as gold_file:
    gold = [json.loads(line) for line in gold_file]
with open('/home/pbaweja/SciREX/test_outputs/cluster_predictions.jsonl') as predictions_file:
    predicted = [json.loads(line) for line in predictions_file]

In [46]:
# Show the top-level keys of the first gold document.
gold[0].keys()

dict_keys(['coref', 'coref_non_salient', 'doc_id', 'method_subrelations', 'n_ary_relations', 'ner', 'sections', 'sentences', 'words'])

In [49]:
# Show the top-level keys of the first prediction record.
predicted[0].keys()

dict_keys(['doc_id', 'spans', 'clusters'])

In [50]:
# Score the predicted clusters against the gold clusters; the call prints
# summary statistics of the per-document p / r / f1 values.
score_scirex_model(predicted, gold)

              f1          p          r
count  66.000000  66.000000  66.000000
mean    0.092940   0.051823   0.662185
std     0.054236   0.033848   0.225624
min     0.016000   0.008264   0.066667
25%     0.054632   0.029806   0.500000
50%     0.085762   0.045658   0.666667
75%     0.118125   0.063717   0.833333
max     0.278422   0.172973   1.000000
