# Dataset Preprocessing

The [FewRel](https://github.com/thunlp/FewRel) dataset, which we plan to use, was designed for a slightly different relation extraction task than ours. That task involves learning new relations dynamically, so the dataset comes pre-split into training and validation subsets *(and a testing subset, although this isn't publicly available)* with no relations in common: the relations found in the validation split are not present in the training split.

Our task is slightly different, and easier: our model only needs to learn a fixed set of relations, and we don't have to consider learning new relations dynamically after training. We therefore need to re-split the dataset so that all relations appear in all subsets *(and we also have to create our own testing split)*.
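To illustrate the idea, here is a minimal sketch of a per-relation split on toy data. The function name `split_per_relation` and the relation IDs `P_a` and `P_b` are made up for this example; the boundaries are computed the same way as in the notebook cell, but this is not the cell itself.

```python
import random


def split_per_relation(dataset, val_frac, test_frac, seed):
    """Split every relation's samples into train/val/test portions."""
    rng = random.Random(seed)
    train, val, test = {}, {}, {}
    for relation, samples in dataset.items():
        shuffled = list(samples)  # copy so the input is left untouched
        rng.shuffle(shuffled)
        val_end = int(val_frac * len(shuffled))
        test_end = int((val_frac + test_frac) * len(shuffled))
        val[relation] = shuffled[:val_end]
        test[relation] = shuffled[val_end:test_end]
        train[relation] = shuffled[test_end:]
    return train, val, test


# Toy data: two hypothetical relation IDs with ten placeholder samples each
toy = {"P_a": list(range(10)), "P_b": list(range(10, 20))}
train, val, test = split_per_relation(toy, 0.1, 0.1, seed=0)

# Every relation now appears in every split
print(sorted(train), sorted(val), sorted(test))
```

Because the split is done relation by relation, each subset keeps roughly the same class balance as the full dataset.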

Running the cell below will re-split the dataset to better suit our needs and save the new training, validation, and testing splits to disk. There are also options for some further preprocessing of the data if desired, as described below.

## Options

*   **INPUT_DATA_SPLITS** - File paths of the original dataset splits to include in the re-split dataset.

*   **OUTPUT_DATA_SPLITS** - File paths to save the new training, validation, and testing splits to.

*   **RANDOM_SEED** - Random seed used when dividing the samples into splits.

*   **VAL_TEST_SPLIT** - Proportions of each relation's samples to use in the new validation and testing splits, respectively. The remaining samples form the training split.

*   **ADD_NO_RELATION_SAMPLES** - Whether to add a new "no relation" class, and samples, to the dataset (as it is not included in the original data). Samples for this new class are found by searching the dataset for entities defined within the same token sequence but without any defined relation.

    *   **NO_RELATION_ID** - ID to use for the new "no relation" class.

    *   **NO_RELATION_ORDER_DEPENDENT** - Whether the new "no relation" class depends on entity order. If true, both (a, b) and (b, a) samples are generated for each unrelated entity pair; if false, only one of the two is generated.

*   **DETOKENIZE_SAMPLES** - Whether to detokenize the samples, as they come pre-tokenized in the dataset. I recommend keeping this set to true: we don't know which tokenizer was originally used, which makes it hard to tokenize new inputs consistently after training.

*   **ADD_ENTITY_DELIM_TOKENS** - Whether to add delimiter tokens around entities within each sample. E.g. "The [E1S] dog [E1E] chased the [E2S] cat [E2E]". Note that this is only applied when **DETOKENIZE_SAMPLES** is enabled.

    *   **ENTITY_DELIM_TOKENS** - Start and end delimiter tokens to use for each entity type.
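The two less obvious options above, "no relation" pair generation and entity delimiter tokens, can be sketched on a toy sentence as follows. The helper names `no_relation_pairs` and `wrap_entity` are hypothetical, and the sketch joins tokens with plain spaces rather than running the Moses detokenizer, so treat it as an illustration of the idea only.

```python
import itertools


def no_relation_pairs(entities, related_pairs, order_dependent):
    """List entity pairs that have no defined relation in either direction."""
    pairs = itertools.permutations if order_dependent else itertools.combinations
    return [
        (a, b) for a, b in pairs(entities, 2)
        if (a, b) not in related_pairs and (b, a) not in related_pairs
    ]


def wrap_entity(tokens, start, length, delims):
    """Wrap tokens[start:start + length] in (start, end) delimiter tokens."""
    return (
        tokens[:start]
        + [delims[0]] + tokens[start:start + length] + [delims[1]]
        + tokens[start + length:]
    )


# Toy sequence with three entities, only one pair of which is related
entities = ["dog", "cat", "mouse"]
related = {("dog", "cat")}
ordered = no_relation_pairs(entities, related, order_dependent=True)
unordered = no_relation_pairs(entities, related, order_dependent=False)
print(ordered)    # both (a, b) and (b, a) for each unrelated pair
print(unordered)  # one sample per unrelated pair

# Wrap the later entity first so the earlier indices stay valid
tokens = ["The", "dog", "chased", "the", "cat"]
tokens = wrap_entity(tokens, 4, 1, ("[E2S]", "[E2E]"))
tokens = wrap_entity(tokens, 1, 1, ("[E1S]", "[E1E]"))
marked = " ".join(tokens)
print(marked)
```

Processing entities from the end of the sequence backwards avoids having to shift the stored positions after each insertion.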



In [9]:
!pip install sacremoses -q

import itertools
import json
import numpy as np
import random
from sacremoses import MosesDetokenizer


# === OPTIONS ==================================================================

INPUT_DATA_SPLITS = (
    "raw/train_wiki.json",
    "raw/val_wiki.json"
)

OUTPUT_DATA_SPLITS = {
    "train": "train.json",
    "val": "val.json",
    "test": "test.json"
}

RANDOM_SEED = 41332

VAL_TEST_SPLIT = (0.1, 0.1)

ADD_NO_RELATION_SAMPLES = True
NO_RELATION_ID = "NA"
NO_RELATION_ORDER_DEPENDENT = True

DETOKENIZE_SAMPLES = True

ADD_ENTITY_DELIM_TOKENS = True
ENTITY_DELIM_TOKENS = {
    "h": ("[E1S]", "[E1E]"),
    "t": ("[E2S]", "[E2E]")
}

# ==============================================================================


# Aggregate splits into single dataset
with open(INPUT_DATA_SPLITS[0], "r") as f:
    full_dataset = json.load(f)
for i in range(1, len(INPUT_DATA_SPLITS)):
    with open(INPUT_DATA_SPLITS[i], "r") as f:
        dataset = json.load(f)
        for relation, samples in dataset.items():
            full_dataset.setdefault(relation, [])
            full_dataset[relation].extend(samples)

# Remove the second field of each entity (its Wikidata ID in FewRel), leaving
# [name, positions] (don't think this is useful to us)
for samples in full_dataset.values():
    for sample in samples:
        del sample["h"][1]
        del sample["t"][1]

# Identify and add no relation samples to dataset
if ADD_NO_RELATION_SAMPLES:

    # Identify entities and relations per token sequence
    tokens_to_entities = {}
    tokens_to_relations = {}
    for samples in full_dataset.values():
        for sample in samples:
            tokens = tuple(sample["tokens"])
            tokens_to_entities.setdefault(tokens, {}).update({
                sample["h"][0]: sample["h"],
                sample["t"][0]: sample["t"]
            })
            tokens_to_relations.setdefault(tokens, set()).add(
                (sample["h"][0], sample["t"][0])
            )

    # Add no relation class to dataset
    if NO_RELATION_ORDER_DEPENDENT:
        generator = itertools.permutations
    else:
        generator = itertools.combinations
    for tokens, entities in tokens_to_entities.items():
        for a, b in generator(entities.keys(), 2):
            if (
                (a, b) not in tokens_to_relations[tokens]
                and (b, a) not in tokens_to_relations[tokens]
            ):
                full_dataset.setdefault(NO_RELATION_ID, []).append({
                    "tokens": list(tokens),
                    "h": entities[a],
                    "t": entities[b]
                })

# Detokenize samples using Moses detokenizer if required
if DETOKENIZE_SAMPLES:
    detokenizer = MosesDetokenizer()
    for samples in full_dataset.values():
        for sample in samples:

            # Find all entity positions
            entity_poses = []
            for entity_id in ["h", "t"]:
                entity = sample[entity_id][0]
                for poses in sample[entity_id][1]:
                    entity_poses.append(
                        (entity_id, entity, poses[0], len(poses))
                    )
            entity_poses.sort(key=lambda x: x[2])

            # Check for nested entities
            for i, entity_pos in enumerate(entity_poses[:-1]):
                _, _, start, length = entity_pos
                if start + length > entity_poses[i + 1][2]:
                    raise ValueError(
                        "Overlapping or nested entities are not supported for "
                        f"detokenization, found here: {entity_poses}"
                    )

            # Detokenize entities separately
            for entity_id, entity, start, length in reversed(entity_poses):
                detokenized = detokenizer.detokenize(
                    sample["tokens"][start:start + length]
                )

                # Add entity delimiter tokens if required
                if ADD_ENTITY_DELIM_TOKENS:
                    detokenized = (
                        f"{ENTITY_DELIM_TOKENS[entity_id][0]}"
                        f" {detokenized} "
                        f"{ENTITY_DELIM_TOKENS[entity_id][1]}"
                    )

                # Add detokenized entity back into token list
                sample["tokens"] = (
                    sample["tokens"][:start] +
                    [detokenized] +
                    sample["tokens"][start + length:]
                )

            # Detokenize entire sample
            sample["tokens"] = detokenizer.detokenize(sample["tokens"])

            # Remove irrelevant information
            for entity_id in ["h", "t"]:
                sample[entity_id] = sample[entity_id][0]

# Set random seed
randomiser = random.Random(RANDOM_SEED)

# Split full dataset into new train/val/test splits
train_split, val_split, test_split = {}, {}, {}
for relation, samples in full_dataset.items():
    randomiser.shuffle(samples)
    val_end = int(VAL_TEST_SPLIT[0] * len(samples))
    test_end = int((VAL_TEST_SPLIT[0] + VAL_TEST_SPLIT[1]) * len(samples))
    val_samples = samples[:val_end]
    test_samples = samples[val_end:test_end]
    train_samples = samples[test_end:]
    train_split.setdefault(relation, []).extend(train_samples)
    val_split.setdefault(relation, []).extend(val_samples)
    test_split.setdefault(relation, []).extend(test_samples)

# Save new data splits
with open(OUTPUT_DATA_SPLITS["train"], "w") as f:
    json.dump(train_split, f)
with open(OUTPUT_DATA_SPLITS["val"], "w") as f:
    json.dump(val_split, f)
with open(OUTPUT_DATA_SPLITS["test"], "w") as f:
    json.dump(test_split, f)

# Print information
for split_name, split in ([
    ("Train", train_split),
    ("Validation", val_split),
    ("Test", test_split)
]):
    if ADD_NO_RELATION_SAMPLES:
        split_lens_excl_no_relation = [
            len(samples) for relation, samples in split.items()
            if relation != NO_RELATION_ID
        ]
        print(
            f"{split_name} split info:"
            f"\n\tNum. relations (excl. no relation): "
            f"{len(split_lens_excl_no_relation)}"
            f"\n\tAvg. samples per relation (excl. no relation): "
            f"{int(np.mean(split_lens_excl_no_relation))}"
            f"\n\tNum. no relation samples: "
            f"{len(split.get(NO_RELATION_ID, []))}\n"
        )
    else:
        print(
            f"{split_name} split info:"
            f"\n\tNum. relations: {len(split)}"
            f"\n\tAvg. samples per relation: "
            f"{int(np.mean([len(samples) for samples in split.values()]))}\n"
        )


Train split info:
	Num. relations (excl. no relation): 80
	Avg. samples per relation (excl. no relation): 560
	Num. no relation samples: 3296

Validation split info:
	Num. relations (excl. no relation): 80
	Avg. samples per relation (excl. no relation): 70
	Num. no relation samples: 412

Test split info:
	Num. relations (excl. no relation): 80
	Avg. samples per relation (excl. no relation): 70
	Num. no relation samples: 412
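
As a quick sanity check on the result, the following sketch reports any relation that is missing from (or empty in) one of the splits. The helper `missing_relations` is hypothetical and is shown here on toy data; it is not part of the cell above.

```python
def missing_relations(splits):
    """Map each split name to the relation IDs it lacks samples for."""
    all_relations = set()
    for split in splits.values():
        all_relations.update(split.keys())

    missing = {}
    for name, split in splits.items():
        present = {rel for rel, samples in split.items() if len(samples) > 0}
        absent = all_relations - present
        if absent:
            missing[name] = sorted(absent)
    return missing


# Toy splits with made-up relation IDs; "P_b" is empty in the validation split
toy_splits = {
    "train": {"P_a": [1, 2], "P_b": [3, 4]},
    "val": {"P_a": [5], "P_b": []},
    "test": {"P_a": [6], "P_b": [7]},
}
missing = missing_relations(toy_splits)
print(missing)

# To check the saved files instead, load each path in OUTPUT_DATA_SPLITS with
# json.load and pass the resulting dicts in place of toy_splits.
```

An empty result means every relation has at least one sample in every split, which is the property the re-split is meant to guarantee.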

