# Author: ddukic

## Import libraries

In [1]:
import os
from tqdm import tqdm
import json
from tqdm import tqdm
from collections import Counter
from collections import defaultdict

# Dataset split to process; it is spliced into the input and output file names used below.
file = "dev"

## Helper functions

In [2]:
def convert_bio(tags, tag="Subject"):
    """Attach BIO labels to the entries of one triplet slot.

    Args:
        tags: Sequence of indexable items; element ``[1]`` of each item is
            carried over unchanged into the result.
        tag: Label name appended to the ``B-`` / ``I-`` prefix.

    Returns:
        A list of ``(label, item[1])`` tuples. The first item is labelled
        ``"B-" + tag`` and every following item ``"I-" + tag``.
    """
    labelled = []
    for position, item in enumerate(tags):
        # Only the very first entry opens the span; the rest continue it.
        prefix = "I-" if position else "B-"
        labelled.append((prefix + tag, item[1]))
    return labelled


def create_bio_tags(triplet_tokens, subject_, relation_, object_):
    """Build one BIO tag sequence for a sentence from a single triplet.

    Args:
        triplet_tokens: The sentence tokens; only the length is used.
        subject_: Subject entries, each indexable with the token index at ``[1]``.
        relation_: Relation entries, same layout as ``subject_``.
        object_: Object entries, same layout as ``subject_``.

    Returns:
        A list with one tag per token. Positions not covered by the triplet
        hold ``"O"``.
    """
    bio_tags = ["O"] * len(triplet_tokens)

    # Slots are written in subject, relation, object order, so when two slots
    # point at the same position the later slot overwrites the earlier one.
    slots = (
        (subject_, "Subject"),
        (relation_, "Relation"),
        (object_, "Object"),
    )
    for entries, label in slots:
        for bio_label, index in convert_bio(entries, label):
            # Per the original note the stored indexes start at 1, hence the
            # shift to a 0-based list position.
            bio_tags[index - 1] = bio_label

    return bio_tags


def is_implicit(t):
    """Report whether any slot of a triplet contains the ``--2`` marker.

    Args:
        t: Mapping with ``"subject"``, ``"relation"`` and ``"object"`` keys,
            each holding an iterable of strings.

    Returns:
        True as soon as one string in any of the three slots contains the
        substring ``"--2"``, otherwise False.

    NOTE(review): presumably ``--2`` encodes an index of -2 for tokens that
    do not occur in the sentence — confirm against the extractor's format.
    """
    for slot in ("subject", "relation", "object"):
        for entry in t[slot]:
            if "--2" in entry:
                return True
    return False


def special_split(x):
    """Split a ``token-index`` string into ``[token, index]``.

    Args:
        x: String made of a token, a ``-`` separator and an integer index.
            The index may be ``-2``, which yields the suffix ``--2``.

    Returns:
        A two-element list ``[token, index]`` with ``index`` as an int.

    Raises:
        ValueError: If the text after the separator is not an integer.
    """
    marker = "--2"
    # Match the -2 index as a suffix. A substring test combined with a split
    # on the first hyphen breaks on tokens that contain hyphens themselves
    # (e.g. "co-op--2") or that end in "-" with an index starting with 2
    # (e.g. "---25").
    if x.endswith(marker):
        return [x[: -len(marker)], -2]

    # Regular case: everything before the last hyphen is the token.
    token, index = x.rsplit("-", 1)
    return [token, int(index)]

In [3]:
def read_triplets(save=True):
    """Load the extracted triplets and derive BIO tags for each of them.

    Reads ``../data/processed/mini/ace_<file>_triplets.json`` (``file`` is
    the module-level split name). For every sentence it keeps the entries
    whose key contains ``"triplet"`` and that ``is_implicit`` does not flag,
    parsing each slot item with ``special_split``.

    Args:
        save: When true, the result is also written to
            ``../data/processed/mini/ace_<file>_triplets_bio_all.json``.

    Returns:
        A dict mapping sentence key to a dict with ``"tokens"``,
        ``"triplets"`` and ``"bio_tags"`` (one tag sequence per triplet).
        Sentences without any kept triplet are left out.
    """
    with open("../data/processed/mini/ace_" + file + "_triplets.json", "r") as f:
        data = json.load(f)

    triplets = {}
    for sentence_id, record in tqdm(data.items()):
        tokens = record["tokens"]["tokens"]

        explicit = []
        for name, slots in record.items():
            # Skip the non-triplet entries and the triplets flagged as implicit.
            if "triplet" not in name or is_implicit(slots):
                continue
            explicit.append(
                {
                    role: [special_split(item) for item in members]
                    for role, members in slots.items()
                }
            )

        if not explicit:
            continue

        triplets[sentence_id] = {
            "tokens": tokens,
            "triplets": explicit,
            "bio_tags": [
                create_bio_tags(tokens, t["subject"], t["relation"], t["object"])
                for t in explicit
            ],
        }

    if save:
        with open(
            "../data/processed/mini/ace_" + file + "_triplets_bio_all.json", "w"
        ) as f:
            json.dump(triplets, f)
    return triplets

In [4]:
# Load the triplets in memory only; save=False skips writing the *_triplets_bio_all.json file.
triplets = read_triplets(save=False)

100%|██████████| 873/873 [00:00<00:00, 11877.99it/s]


In [5]:
# this is the number of sentences with at least one triplet
print(len(triplets))

734


## Deal with multiple labelings of the same sentence (keep multiply labeled sentences)

In [6]:
# skip for now

## Filter out triplets with non-consecutive BIO tags, incomplete triplets, relations longer than five tokens, and empty fields

In [7]:
def is_bio_sequent(bio_tags):
    """Check that every ``I-`` tag directly continues a span of its own type.

    Args:
        bio_tags: Sequence of tag strings such as ``"B-Subject"``,
            ``"I-Subject"`` or ``"O"``.

    Returns:
        True when each ``I-X`` tag is immediately preceded by ``B-X`` or
        ``I-X``; False otherwise. An empty sequence is valid.
    """
    for position, tag in enumerate(bio_tags):
        if not tag.startswith("I-"):
            continue
        # An I- tag at the very start has no predecessor. Looking at index -1
        # would wrap around to the last tag and could wrongly accept it.
        if position == 0:
            return False
        previous_tag = bio_tags[position - 1]
        if previous_tag != tag and previous_tag != "B-" + tag[2:]:
            return False
    return True


def contains_triplet(bio_tags):
    """Return True when the tags open a subject, a relation and an object span.

    Args:
        bio_tags: Container of tag strings.

    Returns:
        True only if ``"B-Subject"``, ``"B-Relation"`` and ``"B-Object"`` are
        all present in ``bio_tags``.
    """
    required = ("B-Subject", "B-Relation", "B-Object")
    return all(label in bio_tags for label in required)


def longer_than_five(x):
    """Tell whether more than five tags in ``x`` mention ``"Relation"``.

    Args:
        x: Iterable of tag strings.

    Returns:
        True when the number of elements containing the substring
        ``"Relation"`` exceeds five.
    """
    relation_count = 0
    for elem in x:
        if "Relation" in elem:
            relation_count += 1
    return relation_count > 5


def check_empty(s, r, o):
    """Report whether any of the three triplet slots is empty.

    Args:
        s: Subject entries (any sized container).
        r: Relation entries (any sized container).
        o: Object entries (any sized container).

    Returns:
        True if at least one slot has length zero, otherwise False. The
        result is always a bool, never an implicit ``None``.
    """
    return len(s) == 0 or len(r) == 0 or len(o) == 0

In [8]:
# Keep, per sentence, only the triplets whose tag sequence passes all four
# checks; sentences left without any triplet are not added at all.
triplets_filtered = defaultdict(lambda: defaultdict(str))

for sentence_id, sentence in tqdm(triplets.items()):
    kept = [
        (candidate, labels)
        for candidate, labels in zip(sentence["triplets"], sentence["bio_tags"])
        if is_bio_sequent(labels)
        and contains_triplet(labels)
        and not longer_than_five(labels)
        and not check_empty(
            candidate["subject"], candidate["relation"], candidate["object"]
        )
    ]
    if not kept:
        continue

    entry = triplets_filtered[sentence_id]
    entry["tokens"] = sentence["tokens"]
    entry["triplets"] = [candidate for candidate, _ in kept]
    entry["bio_tags"] = [labels for _, labels in kept]

100%|██████████| 734/734 [00:00<00:00, 45829.15it/s]


In [9]:
# number of sentences that still have at least one triplet after filtering
print(len(triplets_filtered))

485


## Check how many instances have their elements in an order other than subject-relation-object (SRO)

In [10]:
def sro_order(s, r, o):
    """Check that the triplet reads subject, then relation, then object.

    Args:
        s: Subject entries, each indexable with its position at ``[1]``.
        r: Relation entries, same layout.
        o: Object entries, same layout.

    Returns:
        True when every subject position is strictly smaller than every
        relation position and every relation position is strictly smaller
        than every object position. Empty slots impose no constraint.
    """
    subject_first = all(s_e[1] < r_e[1] for s_e in s for r_e in r)
    if not subject_first:
        return False
    return all(r_e[1] < o_e[1] for r_e in r for o_e in o)


def check_data(data):
    """Collect the positions of triplets that are not in SRO order.

    Args:
        data: Mapping from sentence key to a dict whose ``"triplets"`` entry
            is a list of dicts with ``"subject"``, ``"relation"`` and
            ``"object"`` slots.

    Returns:
        A ``defaultdict(list)`` mapping sentence key to the list positions of
        the triplets rejected by ``sro_order``. Sentences with no rejected
        triplet get no key.
    """
    discard = defaultdict(list)
    for sentence_id, sentence in data.items():
        out_of_order = [
            position
            for position, triplet in enumerate(sentence["triplets"])
            if not sro_order(
                triplet["subject"], triplet["relation"], triplet["object"]
            )
        ]
        # Touch the defaultdict only when something was found, so that clean
        # sentences never show up as keys.
        if out_of_order:
            discard[sentence_id].extend(out_of_order)
    return discard


# Per sentence key, the list positions of the filtered triplets that are not in SRO order.
triplets_to_discard = check_data(triplets_filtered)

In [11]:
# Build the final dataset by dropping every triplet that check_data flagged as
# not being in subject -> relation -> object order.
triplets_final = defaultdict(lambda: defaultdict(str))

for k, v in tqdm(triplets_filtered.items()):
    if k not in triplets_to_discard:
        # Nothing was flagged for this sentence: carry it over unchanged.
        triplets_final[k]["tokens"] = v["tokens"]
        triplets_final[k]["triplets"] = v["triplets"]
        triplets_final[k]["bio_tags"] = v["bio_tags"]
        continue

    flagged = set(triplets_to_discard[k])
    # The accumulators must be created once per sentence. Re-creating them on
    # every loop iteration would keep at most the last triplet of the sentence.
    ts_final = []
    tags_final = []
    for i, (ts, tags) in enumerate(zip(v["triplets"], v["bio_tags"])):
        if i not in flagged:
            ts_final.append(ts)
            tags_final.append(tags)

    # Sentences whose triplets were all flagged are left out entirely.
    if len(tags_final) > 0:
        triplets_final[k]["tokens"] = v["tokens"]
        triplets_final[k]["triplets"] = ts_final
        triplets_final[k]["bio_tags"] = tags_final

100%|██████████| 485/485 [00:00<00:00, 300345.11it/s]


In [12]:
# number of sentences remaining after the SRO-order filtering
print(len(triplets_final))

485


## Drop duplicates

In [13]:
# TODO

## Visualize some triplets

In [14]:
import spacy
from spacy import displacy
from spacy.tokens import Doc

# Loaded once at module level; viz_sentence below reads its vocab (nlp.vocab) when building Doc objects.
nlp = spacy.load("en_core_web_lg")


def viz_sentence(tokens, tags):
    """Render one tagged sentence with displaCy's entity style.

    Args:
        tokens: The words of the sentence, passed to ``Doc`` as ``words``.
        tags: One label per token, passed to ``Doc`` as ``ents``
            (presumably BIO strings such as those built by
            ``create_bio_tags`` — confirm against the caller).

    Raises:
        AssertionError: If ``tokens`` and ``tags`` differ in length.
    """
    assert len(tokens) == len(tags)

    palette = {
        "Subject": "#ff6961",
        "Relation": "#3CB371",
        "Object": "#85C1E9",
    }
    render_options = {
        "ents": ["Subject", "Relation", "Object"],
        "colors": palette,
    }

    doc = Doc(nlp.vocab, words=tokens, ents=tags)
    displacy.render(doc, style="ent", options=render_options)

In [None]:
# Render every tag sequence of the first 100 sentences only.
for i, v in enumerate(triplets_final.values()):
    if i >= 100:
        # Past the preview limit: stop instead of scanning the remaining sentences.
        break
    for ta in v["bio_tags"]:
        viz_sentence(v["tokens"], ta)

## Total number of triplets

In [16]:
# One BIO tag sequence exists per kept triplet, so summing the sequence counts
# over all sentences gives the number of triplets.
total = sum(len(entry["bio_tags"]) for entry in triplets_final.values())

total

731

## Total number of relations without duplicates

In [17]:
# Count distinct relation annotations: each tag sequence is reduced to its
# relation tags (everything else becomes "O") and paired with the tokens.
visited = []
# Hashable copies of the entries in `visited`, so each duplicate check is a
# set lookup instead of a scan over the whole list (assumes tokens are
# hashable, e.g. strings).
seen = set()

for v in triplets_final.values():
    tokens = v["tokens"]
    for bio_tags in v["bio_tags"]:
        pairs = [
            (token, tag if tag in ("B-Relation", "I-Relation") else "O")
            for token, tag in zip(tokens, bio_tags)
        ]
        key = tuple(pairs)
        if key not in seen:
            seen.add(key)
            visited.append(pairs)

len(visited)

665

## Total number of triplets without duplicates

In [18]:
# Count distinct triplet annotations: a triplet is identified by the full
# sequence of (token, tag) pairs of its sentence.
visited = []
# Hashable copies of the entries in `visited`, so each duplicate check is a
# set lookup instead of a scan over the whole list (assumes tokens are
# hashable, e.g. strings).
seen = set()

for v in triplets_final.values():
    tokens = v["tokens"]
    for bio_tags in v["bio_tags"]:
        pairs = list(zip(tokens, bio_tags))
        key = tuple(pairs)
        if key not in seen:
            seen.add(key)
            visited.append(pairs)

len(visited)

718

## Dump to file

In [19]:
# Persist the final, filtered triplets for the current split as JSON.
output_path = "../data/processed/mini/ace_" + file + "_triplets_filtered.json"
with open(output_path, "w") as f:
    json.dump(triplets_final, f)