# Author: ddukic

## Import libraries

In [41]:
import os
from tqdm import tqdm
import json
from tqdm import tqdm
from collections import Counter
from collections import defaultdict

# tpath = "../data/processed/mini/evextra_"
tpath = "../data/processed/stanford/evextra_"
file = "test"

## Help functions

In [42]:
def convert_bio(tags, tag="Subject"):
    return [
        ("B-" + tag, x[1]) if i == 0 else ("I-" + tag, x[1]) for i, x in enumerate(tags)
    ]


def create_bio_tags(triplet_tokens, subject_, relation_, object_):
    bio_tags = ["O"] * len(triplet_tokens)

    subject_bio = convert_bio(subject_)
    relation_bio = convert_bio(relation_, "Relation")
    object_bio = convert_bio(object_, "Object")

    # for some reason MINI indexes go from 1
    for s in subject_bio:
        bio_tags[s[1] - 1] = s[0]

    for r in relation_bio:
        bio_tags[r[1] - 1] = r[0]

    for o in object_bio:
        bio_tags[o[1] - 1] = o[0]

    return bio_tags


def is_implicit(t):
    return (
        any("--2" in x for x in t["subject"])
        or any("--2" in x for x in t["relation"])
        or any("--2" in x for x in t["object"])
    )


def special_split(x):
    if "--2" in x:
        return [x.split("-", 1)[0], int(x.split("-", 1)[1])]
    else:
        return [x.rsplit("-", 1)[0], int(x.rsplit("-", 1)[1])]

In [43]:
def read_triplets(save=True):
    triplets = {}
    triplet_ids = set()
    with open(tpath + file + "_triplets.json", "r") as f:
        data = json.load(f)
        for k, v in tqdm(data.items()):
            tokens = data[k]["tokens"]["tokens"]
            sentence_triplets = [
                {sro: [special_split(x) for x in idx] for sro, idx in values.items()}
                for entry, values in data[k].items()
                if "triplet" in entry and not is_implicit(values)
            ]

            if len(sentence_triplets) > 0:
                triplets[k] = {
                    "tokens": tokens,
                    "triplets": sentence_triplets,
                }

                triplets[k]["bio_tags"] = [
                    create_bio_tags(tokens, t["subject"], t["relation"], t["object"])
                    for t in triplets[k]["triplets"]
                ]
    if save:
        with open(tpath + file + "_triplets_bio_all.json", "w") as f:
            json.dump(triplets, f)
    return triplets

In [44]:
triplets = read_triplets(save=False)

100%|██████████| 2482/2482 [00:00<00:00, 5595.11it/s]


In [45]:
# this is the number of sentences with at least one triplet
print(len(triplets))

1844


## Filter non-consecutive BIO tags, non_triplets, more than five tokens, empty fields

In [46]:
def is_bio_sequent(bio_tags):
    i_positions = []
    for i, token in enumerate(bio_tags):
        if token.startswith("I-"):
            i_positions.append(i)
    for position in i_positions:
        current_tag = bio_tags[position]
        previous_tag = bio_tags[position - 1]
        if not previous_tag == current_tag and not (
            previous_tag == "B-" + current_tag.split("-")[1]
        ):
            return False
    return True


def contains_triplet(bio_tags):
    return (
        "B-Subject" in bio_tags and "B-Relation" in bio_tags and "B-Object" in bio_tags
    )


def longer_than_five(x):
    return sum(["Relation" in elem for elem in x]) > 5


def check_empty(s, r, o):
    if len(s) == 0 or len(r) == 0 or len(o) == 0:
        return True

In [47]:
triplets_filtered_first = defaultdict(lambda: defaultdict(str))

for k, v in tqdm(triplets.items()):
    triplet_filtered = []
    tag_filtered = []
    for ts, tags in zip(v["triplets"], v["bio_tags"]):
        if (
            is_bio_sequent(tags)
            and contains_triplet(tags)
            and not longer_than_five(tags)
            and not check_empty(*[ts[x] for x in ["subject", "relation", "object"]])
        ):
            triplet_filtered.append(ts)
            tag_filtered.append(tags)
    if len(tag_filtered) > 0:
        triplets_filtered_first[k]["tokens"] = v["tokens"]
        triplets_filtered_first[k]["triplets"] = triplet_filtered
        triplets_filtered_first[k]["bio_tags"] = tag_filtered

100%|██████████| 1844/1844 [00:00<00:00, 20697.70it/s]


In [48]:
print(len(triplets_filtered_first))

1637


## Check how many instances have the data in some other order than SRO

In [49]:
def sro_order(s, r, o):
    for s_e in s:
        for r_e in r:
            if not (s_e[1] < r_e[1]):
                return False

    for r_e in r:
        for o_e in o:
            if not (r_e[1] < o_e[1]):
                return False

    return True


def check_data(data):
    discard = defaultdict(list)
    for k, v in data.items():
        for i, triplet in enumerate(v["triplets"]):
            if not sro_order(triplet["subject"], triplet["relation"], triplet["object"]):
                discard[k].append(i)
    return discard


triplets_to_discard = check_data(triplets_filtered_first)

In [50]:
triplets_filtered = defaultdict(lambda: defaultdict(str))

for k, v in tqdm(triplets_filtered_first.items()):
    if k not in triplets_to_discard.keys():
        triplets_filtered[k]["tokens"] = triplets_filtered_first[k]["tokens"]
        triplets_filtered[k]["triplets"] = triplets_filtered_first[k]["triplets"]
        triplets_filtered[k]["bio_tags"] = triplets_filtered_first[k]["bio_tags"]
    else:
        for i, (ts, tags) in enumerate(zip(v["triplets"], v["bio_tags"])):
            ts_final = []
            tags_final = []
            if i not in triplets_to_discard[k]:
                ts_final.append(ts)
                tags_final.append(tags)
        if len(tags_final) > 0:
            triplets_filtered[k]["tokens"] = triplets_filtered_first[k]["tokens"]
            triplets_filtered[k]["triplets"] = ts_final
            triplets_filtered[k]["bio_tags"] = tags_final

100%|██████████| 1637/1637 [00:00<00:00, 270498.98it/s]


In [51]:
print(len(triplets_filtered))

1380


## Deal with multiple labelings of the same sentence (merge if possible, if not keep the one with longest sum of BI tags)

In [52]:
def merge_tags(to_merge):
    length = len(to_merge[0])
    merged = [[] for _ in range(length)]
    for i in range(length):
        for x in to_merge:
            merged[i].append(x[i])
        if sum([x == "O" for x in merged[i]]) not in (len(to_merge) - 1, len(to_merge)):
            return False, []
    return True, [
        x[0] if x.count(x[0]) == len(x) else next(y for y in x if y != "O")
        for x in merged
    ]

def find_longest(bio_tags):
    longest = 0
    longest_idx = 0
    for i, tags in enumerate(bio_tags):
        length = sum([True if x != "O" else False for x in tags])
        if length > longest:
            longest = length
            longest_idx = i
    return bio_tags[longest_idx]

In [53]:
triplets_final = {}

for t in tqdm(triplets_filtered.keys()):
    if len(triplets_filtered[t]["triplets"]) == 1:
        triplets_final[t] = triplets_filtered[t]
        triplets_final[t]["triplets"] = triplets_filtered[t]["triplets"]
        triplets_final[t]["bio_tags"] = triplets_filtered[t]["bio_tags"]
    else:
        mergable, merged = merge_tags(triplets_filtered[t]["bio_tags"])
        if mergable:
            triplets_final[t] = triplets_filtered[t]
            triplets_final[t]["triplets"] = triplets_filtered[t]["triplets"]
            triplets_final[t]["bio_tags"] = [merged]
        else:
            triplets_final[t] = triplets_filtered[t]
            triplets_final[t]["triplets"] = triplets_filtered[t]["triplets"]
            triplets_final[t]["bio_tags"] = [find_longest(triplets_filtered[t]["bio_tags"])]

100%|██████████| 1380/1380 [00:00<00:00, 82164.20it/s]


In [54]:
print(len(triplets_final))

1380


## Viz some triplets

In [55]:
import spacy
from spacy import displacy
from spacy.tokens import Doc

nlp = spacy.load("en_core_web_lg")

def viz_sentence(tokens, tags):
    assert len(tokens) == len(tags)

    doc = Doc(nlp.vocab, words=tokens, ents=tags)
    
    displacy.render(
        doc,
        style="ent",
        options={
            "ents": ["Subject", "Relation", "Object"],
            "colors": {
                "Subject": "#ff6961",
                "Relation": "#3CB371",
                "Object": "#85C1E9",
            },
        },
    )

In [None]:
for i, (k, v) in enumerate(triplets_final.items()):
    for tr, ta in zip(v["triplets"], v["bio_tags"]):
        if i < 100:
            try:
               viz_sentence(v["tokens"], ta)
            except:
                print("----------------")
                print("Problem with", i)
                print("----------------")

## Total number of triplets

In [57]:
total = 0

for v in triplets_final.values():
    total += len(v["bio_tags"])
    
total

1380

## Total number of relations without duplicates

In [58]:
visited = []

for v in triplets_final.values():
    tokens = v["tokens"]
    for bio_tags in v["bio_tags"]:
        pairs = []
        for token, tag in zip(tokens, bio_tags):
            pairs.append((token, tag if tag in ["B-Relation", "I-Relation"] else "O"))
        if pairs not in visited:
            visited.append(pairs)
        
len(visited)

1371

## Total number of triplets without duplicates

In [59]:
visited = []

for v in triplets_final.values():
    tokens = v["tokens"]
    for bio_tags in v["bio_tags"]:
        pairs = []
        for token, tag in zip(tokens, bio_tags):
            pairs.append((token, tag))
        if pairs not in visited:
            visited.append(pairs)

len(visited)

1371

## Dump to file

In [60]:
with open(tpath + file + "_triplets_filtered_merged.json", "w") as f:
    json.dump(triplets_final, f)