# Author: ddukic

In [5]:
import sys

sys.path.append("../baselines/")

from dataset import (
    ACE2005TriggerRelationDataset,
    EDNYTTriggerRelationDataset,
    EVEXTRATriggerRelationDataset,
    MavenTriggerRelationDataset,
)

from transformers import AutoTokenizer

from util import build_vocab

# Label vocabularies for the two tagging tasks. `build_vocab` comes from the
# project-local `util` module; judging by the names it returns the label list
# plus id->label and label->id mappings -- TODO confirm against util.build_vocab.
all_labels_trigger, id2trigger, trigger2id = build_vocab(["Trigger"])
all_labels_relation, id2relation, relation2id = build_vocab(["Relation"])


# Shared RoBERTa tokenizer handed to every dataset constructor below.
# NOTE(review): add_prefix_space=True is presumably set because the datasets
# tokenize pre-split words -- confirm against the dataset classes.
tokenizer = AutoTokenizer.from_pretrained(
    "roberta-base", do_lower_case=False, add_prefix_space=True
)


def calculate_freqs(
    source_dataset, dataset_train_path, dataset_train_mini_extractions_path, tokenizer
):
    """Count how often trigger and relation labels co-occur on the same position.

    Instantiates ``source_dataset`` and walks its ``triggers`` and
    ``relations`` label sequences in lockstep.  A trigger label is treated as
    positive when it contains the substring ``"Trigger"`` (or already equals
    the collapsed marker ``"T"``); a relation label is positive when it
    contains ``"Relation"`` (or equals ``"T"``).  The dataset's label lists are
    only read, never modified.

    Args:
        source_dataset: Callable (typically a dataset class) that accepts the
            keyword arguments ``fpath_trigger``, ``fpath_relation``,
            ``tokenizer``, ``trigger2id``, ``relation2id`` and ``implicit``,
            and returns an object exposing ``triggers`` and ``relations`` as
            index-aligned sequences of label sequences.
        dataset_train_path: Forwarded as ``fpath_trigger``.
        dataset_train_mini_extractions_path: Forwarded as ``fpath_relation``.
        tokenizer: Forwarded unchanged to ``source_dataset``.

    Returns:
        A tuple ``(yes_yes, yes_no, no_yes, no_no)``.  In each name the first
        word says whether the trigger label was positive and the second
        whether the relation label was positive.

    Raises:
        IndexError: If an inner relation sequence is shorter than the trigger
            sequence it is paired with.

    Note:
        ``trigger2id`` and ``relation2id`` are read from module scope.
    """
    dataset_train = source_dataset(
        fpath_trigger=dataset_train_path,
        fpath_relation=dataset_train_mini_extractions_path,
        tokenizer=tokenizer,
        trigger2id=trigger2id,
        relation2id=relation2id,
        implicit=True,
    )

    # Keyed by (trigger_is_positive, relation_is_positive).
    counts = {
        (True, True): 0,
        (True, False): 0,
        (False, True): 0,
        (False, False): 0,
    }

    for trigger_labels, relation_labels in zip(
        dataset_train.triggers, dataset_train.relations
    ):
        # Positions are driven by the trigger labels; a shorter relation
        # sequence would leave positions unlabeled, so fail loudly instead of
        # letting zip() silently truncate.
        if len(relation_labels) < len(trigger_labels):
            raise IndexError(
                "relation labels are shorter than trigger labels: "
                f"{len(relation_labels)} < {len(trigger_labels)}"
            )

        for trigger_label, relation_label in zip(trigger_labels, relation_labels):
            has_trigger = trigger_label == "T" or "Trigger" in trigger_label
            has_relation = relation_label == "T" or "Relation" in relation_label
            counts[has_trigger, has_relation] += 1

    return (
        counts[True, True],
        counts[True, False],
        counts[False, True],
        counts[False, False],
    )

In [25]:
# ACE 2005 train files: prints (yes_yes, yes_no, no_yes, no_no), where the
# first word refers to the trigger label and the second to the relation label.
print(
    calculate_freqs(
        source_dataset=ACE2005TriggerRelationDataset,
        dataset_train_path="../data/raw/ace/train.json",
        dataset_train_mini_extractions_path="../data/processed/mini/ace_train_triplets_filtered_merged.json",
        tokenizer=tokenizer,
    )
)

(746, 3799, 15564, 229186)


In [26]:
# EDNYT train files: prints (yes_yes, yes_no, no_yes, no_no), where the
# first word refers to the trigger label and the second to the relation label.
print(
    calculate_freqs(
        source_dataset=EDNYTTriggerRelationDataset,
        dataset_train_path="../data/processed/ednyt/train.json",
        dataset_train_mini_extractions_path="../data/processed/mini/ednyt_train_triplets_filtered_merged.json",
        tokenizer=tokenizer,
    )
)

(479, 4219, 2471, 60942)


In [27]:
# EVEXTRA train files: prints (yes_yes, yes_no, no_yes, no_no), where the
# first word refers to the trigger label and the second to the relation label.
print(
    calculate_freqs(
        source_dataset=EVEXTRATriggerRelationDataset,
        dataset_train_path="../data/processed/evextra/train.json",
        dataset_train_mini_extractions_path="../data/processed/mini/evextra_train_triplets_filtered_merged.json",
        tokenizer=tokenizer,
    )
)

(3038, 14350, 10564, 192710)


In [6]:
# MAVEN train files: prints (yes_yes, yes_no, no_yes, no_no), where the
# first word refers to the trigger label and the second to the relation label.
print(
    calculate_freqs(
        source_dataset=MavenTriggerRelationDataset,
        dataset_train_path="../data/raw/maven/train.jsonl",
        dataset_train_mini_extractions_path="../data/processed/mini/maven_train_triplets_filtered_merged.json",
        tokenizer=tokenizer,
    )
)

(15051, 65487, 30896, 720752)
