# Author: ddukic

In [131]:
import sys

sys.path.append("../baselines/")
from collections import defaultdict
from dataset import MavenDataset, ACE2005Dataset, EVEXTRADataset, EDNYTDataset
import seqeval
from transformers import AutoTokenizer
from util import *
import json

# Per-dataset evaluation configuration, keyed by the name passed to
# get_metrics(dataset=...). Each value is a 3-item list:
#   [0] dataset class used to load the gold annotations,
#   [1] short name used to build the mini predictions file name
#       ("../data/processed/mini/<short name>_test_triplets_filtered_merged.json"),
#   [2] path of the split file handed to the dataset class as `fpath`.
# Note that Maven points at valid.jsonl while the other three point at test.json.
dataset_map = {
    "ACE2005": [ACE2005Dataset, "ace", "../data/raw/ace/test.json"],
    "Maven": [MavenDataset, "maven", "../data/raw/maven/valid.jsonl"],
    "EVEXTRA": [EVEXTRADataset, "evextra", "../data/processed/evextra/test.json"],
    "EDNYT": [EDNYTDataset, "ednyt", "../data/processed/ednyt/test.json"],
}

# Single tokenizer instance shared by every dataset class built in get_metrics.
# NOTE(review): add_prefix_space=True is presumably needed because the dataset
# classes feed pre-split words to the tokenizer - confirm in dataset.py.
tokenizer = AutoTokenizer.from_pretrained(
    "roberta-base", add_prefix_space=True, do_lower_case=False
)

In [134]:
def get_metrics(dataset="ACE2005"):
    """Score the "mini" triplet predictions as trigger identification.

    Loads ``../data/processed/mini/<short name>_test_triplets_filtered_merged.json``
    for the chosen dataset, turns its BIO tags into trigger BIO tags and
    evaluates them against the trigger tags of the dataset's evaluation split.

    Tag conversion: a predicted tag containing ``"Relation"`` keeps the text
    before its first ``"-"`` and becomes ``"<prefix>-Trigger"``; every other
    tag becomes ``"O"``. Only the first tag sequence (``bio_tags[0]``) of each
    prediction entry is used.

    Sentences are matched by position: entry ``str(i)`` of the predictions
    file belongs to sentence ``i`` of the dataset.

    * No entry for a sentence: it is scored as an all-``"O"`` prediction.
    * Entry whose tokens differ from the dataset's tokens: the sentence is
      reported on stdout and left out of the evaluation entirely.

    Args:
        dataset: Key of ``dataset_map`` selecting which dataset to evaluate.

    Returns:
        The result of ``seqeval.compute`` on the collected predictions and
        references (``scheme="IOB2"``, ``mode="strict"``).

    Raises:
        KeyError: If ``dataset`` is not a key of ``dataset_map``.
    """
    dataset_cls, short_name, split_path = dataset_map[dataset]

    predictions_path = (
        "../data/processed/mini/" + short_name + "_test_triplets_filtered_merged.json"
    )
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(predictions_path, "r", encoding="utf-8") as f:
        mini = json.load(f)

    # Only the label -> id mapping is needed to construct the dataset.
    _, _, trigger2idx = build_vocab(["Trigger"])
    dataset_test = dataset_cls(
        fpath=split_path,
        tokenizer=tokenizer,
        trigger2id=trigger2idx,
        task="trigger identification",
    )

    tags_pred = []
    tags_true = []

    for i, gold_tags in enumerate(dataset_test.triggers):
        key = str(i)

        if key not in mini:
            # No prediction stored for this sentence: count it as "nothing
            # predicted" so its gold triggers still weigh on recall.
            tags_pred.append(["O"] * len(gold_tags))
            tags_true.append(gold_tags)
            continue

        entry = mini[key]
        if entry["tokens"] == dataset_test.tokens[i]:
            tags_pred.append(
                [
                    tag.split("-")[0] + "-Trigger" if "Relation" in tag else "O"
                    for tag in entry["bio_tags"][0]
                ]
            )
            tags_true.append(gold_tags)
        else:
            # Tokenization mismatch: the tag sequences cannot be aligned, so
            # the sentence is dropped from both predictions and references.
            print("Problem with index ", key)

    return seqeval.compute(
        predictions=tags_pred, references=tags_true, scheme="IOB2", mode="strict"
    )


In [135]:
# ACE2005: score the mini predictions against ../data/raw/ace/test.json
get_metrics(dataset="ACE2005")

{'Trigger': {'precision': 0.03205128205128205,
  'recall': 0.035545023696682464,
  'f1': 0.033707865168539325,
  'number': 422},
 'overall_precision': 0.03205128205128205,
 'overall_recall': 0.035545023696682464,
 'overall_f1': 0.033707865168539325,
 'overall_accuracy': 0.928023758099352}

In [136]:
# Maven: score the mini predictions against ../data/raw/maven/valid.jsonl
get_metrics(dataset="Maven")

{'Trigger': {'precision': 0.20375284306292646,
  'recall': 0.056866271688531526,
  'f1': 0.08891645988420183,
  'number': 18904},
 'overall_precision': 0.20375284306292646,
 'overall_recall': 0.056866271688531526,
 'overall_f1': 0.08891645988420183,
 'overall_accuracy': 0.8797737538864663}

In [137]:
# EDNYT: score the mini predictions against ../data/processed/ednyt/test.json
get_metrics(dataset="EDNYT")

{'Trigger': {'precision': 0.06382978723404255,
  'recall': 0.021791767554479417,
  'f1': 0.03249097472924188,
  'number': 413},
 'overall_precision': 0.06382978723404255,
 'overall_recall': 0.021791767554479417,
 'overall_f1': 0.03249097472924188,
 'overall_accuracy': 0.9069833191970597}

In [138]:
# EVEXTRA: score the mini predictions against ../data/processed/evextra/test.json
get_metrics(dataset="EVEXTRA")

Problem with index  2156


{'Trigger': {'precision': 0.1712403951701427,
  'recall': 0.06319627304030788,
  'f1': 0.0923213493120284,
  'number': 4937},
 'overall_precision': 0.1712403951701427,
 'overall_recall': 0.06319627304030788,
 'overall_f1': 0.0923213493120284,
 'overall_accuracy': 0.8840432195427498}