In [10]:
!pip install sacremoses -q

import itertools
import json
import numpy as np
import random
from sacremoses import MosesDetokenizer

# Input files merged into one dataset (Wikipedia-based splits only)
INPUT_DATA_SPLITS = ["raw/train_wiki.json", "raw/val_wiki.json"]

# Destination file for each regenerated split
OUTPUT_DATA_SPLITS = dict(
    train="train.json",
    val="val.json",
    test="test.json",
)

# Seed for the shuffle that precedes splitting
RANDOM_SEED = 41332

# [validation fraction, test fraction]; the remainder becomes the train split
VAL_TEST_SPLIT = [0.1, 0.1]

# Key under which the generated "no relation" samples are stored
NO_RELATION_ID = "NA"

# True: generate both the (a, b) and the (b, a) no-relation sample.
# False: generate only one of the two.
NO_RELATION_ORDER_DEPENDENT = True

# Join the token lists back into plain text
DETOKENIZE_SAMPLES = True

# Wrap each entity mention in (start, end) marker tokens, keyed by the
# entity's role in the sample ("h" or "t")
ADD_ENTITY_DELIM_TOKENS = True
ENTITY_DELIM_TOKENS = {"h": ("[E1S]", "[E1E]"), "t": ("[E2S]", "[E2E]")}

# Aggregate splits into single dataset
# Every file in INPUT_DATA_SPLITS is merged relation by relation, so the list
# may hold any number of files (an empty list yields an empty dataset instead
# of an IndexError). Each file is expected to be a JSON object mapping a
# relation ID to a list of samples.
full_dataset = {}
for split_path in INPUT_DATA_SPLITS:
    # Read explicitly as UTF-8 rather than relying on the platform's default
    # text encoding, which can fail on non-ASCII text
    with open(split_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
    for relation, samples in dataset.items():
        full_dataset.setdefault(relation, []).extend(samples)

# Drop the second field of every head ("h") and tail ("t") entry in place
# (noted as entity type information, which is not needed here). What used to
# be index 2 becomes index 1 afterwards.
for relation_samples in full_dataset.values():
    for sample in relation_samples:
        for role in ("h", "t"):
            del sample[role][1]

# Index the dataset by token sequence:
#   tokens_to_entities:  token tuple -> {entity key: full entity entry}
#   tokens_to_relations: token tuple -> set of (head key, tail key) pairs
# The entity key is the first field of the "h"/"t" entry.
tokens_to_entities = {}
tokens_to_relations = {}
for relation_samples in full_dataset.values():
    for sample in relation_samples:
        # Lists are unhashable, so the token sequence is keyed as a tuple
        key = tuple(sample["tokens"])
        head, tail = sample["h"], sample["t"]

        known_entities = tokens_to_entities.setdefault(key, {})
        known_entities[head[0]] = head
        known_entities[tail[0]] = tail

        tokens_to_relations.setdefault(key, set()).add((head[0], tail[0]))

# Add no relation (NA) class to dataset
# For every token sequence, each pair of its known entities that is not
# related in either direction becomes a sample of the NO_RELATION_ID class.
generator = (
    itertools.permutations
    if NO_RELATION_ORDER_DEPENDENT
    else itertools.combinations
)
for tokens, entities in tokens_to_entities.items():
    related_pairs = tokens_to_relations[tokens]
    for a, b in generator(entities, 2):
        # Skip pairs that already carry a relation, whichever way it points
        if (a, b) in related_pairs or (b, a) in related_pairs:
            continue
        no_relation_sample = {
            "tokens": list(tokens),
            "h": entities[a],
            "t": entities[b]
        }
        full_dataset.setdefault(NO_RELATION_ID, []).append(no_relation_sample)

# Detokenize samples using Moses detokenizer if required
# Each sample's token list is turned into a single string, optionally with
# marker tokens around every entity mention, and its "h"/"t" entries are
# reduced to their first field.
if DETOKENIZE_SAMPLES:
    detokenizer = MosesDetokenizer()
    for relation_samples in full_dataset.values():
        for sample in relation_samples:

            # One (role, entity, first token index, token count) tuple per
            # mention, ordered by where the mention starts.
            # NOTE(review): only the first index and the count of a mention
            # are used, so its indices are presumably contiguous - confirm.
            entity_poses = [
                (role, sample[role][0], mention[0], len(mention))
                for role in ("h", "t")
                for mention in sample[role][1]
            ]
            entity_poses.sort(key=lambda mention_info: mention_info[2])

            # Mentions are sorted by start, so comparing each one with its
            # successor is enough to detect any overlap or nesting
            for current, following in zip(entity_poses, entity_poses[1:]):
                if current[2] + current[3] > following[2]:
                    raise ValueError(
                        "Overlapping or nested entities are not supported for "
                        f"detokenization, found here: {entity_poses}"
                    )

            # Collapse every mention into one detokenized item, working from
            # the last mention to the first so that earlier indices stay valid
            for role, _, begin, count in reversed(entity_poses):
                end = begin + count
                mention_text = detokenizer.detokenize(
                    sample["tokens"][begin:end]
                )

                # Wrap the mention in its role's delimiter tokens if needed
                if ADD_ENTITY_DELIM_TOKENS:
                    markers = ENTITY_DELIM_TOKENS[role]
                    mention_text = f"{markers[0]} {mention_text} {markers[1]}"

                sample["tokens"] = [
                    *sample["tokens"][:begin],
                    mention_text,
                    *sample["tokens"][end:],
                ]

            # Detokenize entire sample
            sample["tokens"] = detokenizer.detokenize(sample["tokens"])

            # Mention positions are meaningless once the tokens are joined,
            # so keep only the first field of each entity entry
            for role in ("h", "t"):
                sample[role] = sample[role][0]

# Dedicated, seeded RNG so that the shuffle below is reproducible
randomiser = random.Random(RANDOM_SEED)

# Split full dataset into new train/val/test splits
# Every relation is shuffled in place and cut on its own, so each split
# contains every relation: the first VAL_TEST_SPLIT[0] share goes to
# validation, the next VAL_TEST_SPLIT[1] share to test, the rest to train.
train_split, val_split, test_split = {}, {}, {}
for relation, samples in full_dataset.items():
    randomiser.shuffle(samples)

    val_end = int(VAL_TEST_SPLIT[0] * len(samples))
    test_end = int((VAL_TEST_SPLIT[0] + VAL_TEST_SPLIT[1]) * len(samples))

    val_samples = samples[:val_end]
    test_samples = samples[val_end:test_end]
    train_samples = samples[test_end:]

    train_split.setdefault(relation, []).extend(train_samples)
    val_split.setdefault(relation, []).extend(val_samples)
    test_split.setdefault(relation, []).extend(test_samples)

# Write each new split to the file configured for it in OUTPUT_DATA_SPLITS
for split_key, split_data in (
    ("train", train_split),
    ("val", val_split),
    ("test", test_split),
):
    with open(OUTPUT_DATA_SPLITS[split_key], "w") as f:
        json.dump(split_data, f)

# Print debug information
# The no-relation class is reported separately from the labelled relations.
# It is absent when no NA samples were generated, so it is looked up
# defensively rather than assumed to be present.
for split_name, split in ([
    ("Train", train_split),
    ("Validation", val_split),
    ("Test", test_split)
]):
    relation_sizes = [
        len(samples)
        for relation, samples in split.items()
        if relation != NO_RELATION_ID
    ]
    # np.mean of an empty list is NaN, which int() cannot convert
    avg_samples = int(np.mean(relation_sizes)) if relation_sizes else 0
    num_no_relation = len(split.get(NO_RELATION_ID, []))
    print(
        f"{split_name} split info:"
        f"\n\tNum. relations (excl. no relation): {len(relation_sizes)}"
        f"\n\tAvg. samples per relation (excl. no relation): {avg_samples}"
        f"\n\tNum. no relation samples: {num_no_relation}\n"
    )


Train split info:
	Num. relations (excl. no relation): 80
	Avg. samples per relation (excl. no relation): 560
	Num. no relation samples: 3296

Validation split info:
	Num. relations (excl. no relation): 80
	Avg. samples per relation (excl. no relation): 70
	Num. no relation samples: 412

Test split info:
	Num. relations (excl. no relation): 80
	Avg. samples per relation (excl. no relation): 70
	Num. no relation samples: 412

