In [1]:
import json
import random
from typing import Tuple
from collections import defaultdict
from copy import deepcopy
from tabulate import tabulate

# Perturbe Dataset
Creates a perturbed dataset by randomly removing relations between entities from the train dataset.

In [2]:
TRAIN_DATASET_FILE = "./data/datasets/docred/train_annotated.json"
PERTURBED_FILE = "./data/datasets/docred/train_annotated_perturbed.json"
REMOVED_REL_FILE = "./data/datasets/docred/train_annotated_removed_rels.json"

REMOVE_RELATIONS_PERC = 0.1
MINIFIED = False

In [3]:
with open(TRAIN_DATASET_FILE, "r") as fd:
    orig_data = json.load(fd)

In [4]:
def relation_stats(dataset):
    rels = defaultdict(int)

    for doc in dataset:
        for relation in doc['labels']:
            rels[relation['r']] += 1
    
    return dict(rels)

In [5]:
def minify_dataset(dataset: list) -> list:
    cpy = deepcopy(dataset)
    cpy = cpy[:10]

    return cpy

In [6]:
def perturbe_dataset(dataset: list, perc: int = REMOVE_RELATIONS_PERC) -> Tuple[list]:
    pert_dataset = deepcopy(dataset)
    removed_relations = list()
    
    for i, doc in enumerate(dataset):
        removed_labels = list()

        for j in range(len(doc['labels']))[-1::-1]:
            if random.random() <= perc:
                rel = doc['labels'][j]
                removed_labels.append(rel)
                pert_dataset[i]['labels'].pop(j)

        removed_relations.append(removed_labels)
    
    assert len(pert_dataset) == len(dataset)

    return pert_dataset, removed_relations

In [7]:
data = orig_data

if MINIFIED:
    data = minify_dataset(data)

perturbed_dataset, removed_relations = perturbe_dataset(data)

In [8]:
orig_rel_stats = relation_stats(orig_data)
pert_rel_stats = relation_stats(perturbed_dataset)
removed_rels = 0
for doc in removed_relations:
    removed_rels += len(doc)

In [9]:
# for rel in orig_rel_stats:
#     assert rel in pert_rel_stats

print("Change in relation counts:")
print("--------------------------","\n")

headers = ["Relation", "Count (orig)", "Count (pert)", "Change (percentage)"]
rows = []

# total
orig_rel_count = sum(orig_rel_stats.values())
pert_rel_count = sum(pert_rel_stats.values())
removed_perc = 1 - (pert_rel_count / orig_rel_count)
rows.append(["(Total)", orig_rel_count, pert_rel_count, removed_perc])
rows.append([])

for rel in orig_rel_stats:
    orig_rel_count = orig_rel_stats[rel]
    pert_rel_count = 0 if rel not in pert_rel_stats else pert_rel_stats[rel]
    removed_perc = 1 - (pert_rel_count / orig_rel_count)
    rows.append([rel, orig_rel_count, pert_rel_count, removed_perc])

print(tabulate(rows, headers=headers))

Change in relation counts:
-------------------------- 

Relation      Count (orig)    Count (pert)    Change (percentage)
----------  --------------  --------------  ---------------------
(Total)              38180           34307              0.101441

P159                   264             235              0.109848
P17                   8921            8001              0.103127
P131                  4193            3782              0.0980205
P150                  2004            1789              0.107285
P27                   2689            2393              0.110078
P569                  1044             936              0.103448
P19                    511             473              0.074364
P172                    79              64              0.189873
P571                   475             417              0.122105
P576                    79              66              0.164557
P607                   275             242              0.12
P30                    356        

### Dump perturbed dataset and removed relations

In [10]:
with open(PERTURBED_FILE, "w") as fd:
    json.dump(perturbed_dataset, fd)

with open(REMOVED_REL_FILE, "w") as fd:
    json.dump(removed_relations, fd)