# 🧪 Evaluation: Event Extraction
This notebook evaluate the test set for the task `event extraction`.

**Note**: Before conducting experiments, you need to install `kaner` package first. Otherwise, this notebook will raise an *import error*.

```bash
cd ../
python setup.py install
```

In [36]:
import json
import os
from copy import deepcopy

import tqdm

from kaner.context import GlobalContext as gctx
from kaner.adapter.tokenizer import CharTokenizer
from kaner.adapter.knowledge import Gazetteer
from kaner.trainer import NERTrainer, TrainerConfig
from kaner.common import load_json, load_jsonl


gctx.init()

## 1. Define Prediction Function

In [37]:
def load_trainer(model_folder: str) -> NERTrainer:
    options = load_json("utf-8", model_folder, "config.json")
    options["output_folder"] = model_folder
    options["identity"] = os.path.basename(os.path.normpath(model_folder))
    config = TrainerConfig(options, data_folder="../data")
    tokenizer = CharTokenizer(model_folder)
    gazetteer = Gazetteer(model_folder)
    out_adapter = gctx.create_outadapter(config.out_adapter, dataset_folder=model_folder, file_name="labels")
    in_adapters = (
        gctx.create_inadapter(
            config.in_adapter, dataset=[], tokenizer=tokenizer, out_adapter=out_adapter, gazetteer=gazetteer,
            **config.hyperparameters
        )
        for _ in range(3)
    )
    collate_fn = gctx.create_batcher(
        config.model, input_pad=tokenizer.pad_id, output_pad=out_adapter.unk_id, lexicon_pad=gazetteer.pad_id, device=config.device
    )
    model = gctx.create_model(config.model, **config.hyperparameters)
    trainer = NERTrainer(
        config, tokenizer, in_adapters, out_adapter, collate_fn, model, None,
        gctx.create_traincallback(config.model), gctx.create_testcallback(config.model)
    )
    trainer.startp()

    return trainer


def predict(trainer: NERTrainer, text: str) -> dict:
    max_seq_len = 510
    max_query_len = 128
    back_offset = 32
    max_seq_len -= max_query_len
    # cut document
    fragments = []
    pointer = 0
    while pointer < len(text):
        new_text = text[pointer: pointer + max_seq_len]
        fragments.append(new_text)
        pointer += max_seq_len - back_offset
    # predict
    raw_result = trainer.predict(fragments)
    result = {
        "text": text,
        "spans": []
    }
    offset = 0
    for i in range(len(fragments)):
        for j in range(len(raw_result[i]["spans"])):
            raw_result[i]["spans"][j]["start"] += offset
            raw_result[i]["spans"][j]["end"] += offset
        result["spans"].extend(raw_result[i]["spans"])
        offset += (max_seq_len - back_offset)
    # check
    for span in result["spans"]:
        assert result["text"][span["start"]: span["end"] + 1] == span["text"]

    return result

## 2. Get Predictions

In [38]:
datarf = load_jsonl("utf-8", "../data", "datahub", "ccksee", "datarf.jsonl")
datarf_ne = []
for datapoint in datarf:
    if len(datapoint["events"]) > -1:
        datarf_ne.append(datapoint)
test_len = int(len(datarf_ne) * 0.1)
test_data = datarf_ne[len(datarf_ne) - test_len:]

trainder = load_trainer("../data/logs/trainer-plmtg-ccksee-1")
for i in tqdm.tqdm(range(len(test_data)), "Evaluation (testset)"):
    test_data[i]["predicted_spans"] = predict(trainder, test_data[i]["text"])["spans"]

Lexicon embedding is None! ../data/logs/trainer-plmtg-ccksee-1


Text2Tensor: 0it [00:00, ?it/s]
Text2Tensor: 0it [00:00, ?it/s]
Text2Tensor: 0it [00:00, ?it/s]
Evaluation (testset): 100%|██████████| 395/395 [00:26<00:00, 14.65it/s]


# 3. Event Evaluation

In [39]:
from typing import List
from collections import defaultdict


def merge_entity(data: List[dict]) -> List[dict]:
    data = deepcopy(data)
    for i, _ in enumerate(data):
        # merge entity
        events = {}
        for span in data[i]["predicted_spans"]:
            event_type, role_name = span["label"].split(".")
            if event_type not in events.keys():
                events[event_type] = {}
            if role_name not in events[event_type].keys():
                events[event_type][role_name] = []
            span.pop("label")
            events[event_type][role_name].append(span)
        data[i]["predictions"] = events
        
        data[i]["merged_predictions"] = []
        # default: one event only for each event type
        # 1) compare count
        # 2) compare probability
        for event_type, event in events.items():
            final_event = {"event_type": event_type, "arguments": {}}
            for role_name, entities in events[event_type].items():
                role_confidence = defaultdict(float)
                role_count = defaultdict(int)
                for entity in entities:
                    role_confidence[entity["text"]] += entity["confidence"]
                    role_count[entity["text"]] += 1
                for role_value in role_confidence.keys():
                    role_confidence[role_value] /= role_count[role_value]
                role_with_max_count, max_confidence = "", 0.
                for role_value in role_count.keys():
                    if role_with_max_count == "" or role_count[role_with_max_count] < role_count[role_value]:
                        role_with_max_count = role_value
                    elif role_count[role_with_max_count] == role_count[role_value] and role_confidence[role_with_max_count] < role_confidence[role_value]:
                        role_with_max_count = role_value
                final_event["arguments"][role_name] = role_with_max_count
            data[i]["merged_predictions"].append(final_event)

    return data


def evaluate(data: List[dict]) -> dict:
    n_correct, n_predicted, n_goldtruth = 0, 0, 0
    for i, _ in enumerate(data):
        predicted = deepcopy(data[i]["merged_predictions"])
        goldtruth = deepcopy(data[i]["events"])
        for event in predicted:
            n_predicted += len(event["arguments"])
        for event in goldtruth:
            n_goldtruth += len(event["arguments"])
        # extract entities
        for gold_event in goldtruth:
            matched_count = 0
            matched_index = -1
            gold_roles = set(
                [
                    "{0}#{1}".format(gold_event["arguments"][i]["label"], gold_event["arguments"][i]["text"])
                    for i, _ in enumerate(gold_event["arguments"])
                ]
            )
            for i, pred_event in enumerate(predicted):
                pred_roles = set(["{0}#{1}".format(role, pred_event["arguments"][role]) for role in pred_event["arguments"].keys()])
                correct_roles = gold_roles.intersection(pred_roles)
                if matched_index == -1 or len(correct_roles) > matched_count:
                    matched_count = len(correct_roles)
                    matched_index = i
            if matched_index != -1:
                n_correct += matched_count
                predicted.pop(matched_index)
    p = n_correct / n_predicted if n_predicted > 0 else 0
    r = n_correct / n_goldtruth if n_goldtruth > 0 else 0
    f1 = 2*p*r / (p+r) if p+r > 0 else 0

    return {
        "f1": f1, "p": p, "r": r
    }


merged_data = merge_entity(test_data)
print(evaluate(merged_data))

{'f1': 0.6906187624750499, 'p': 0.7966231772831927, 'r': 0.6095126247798004}
