In [8]:
import json
import random
from typing import Tuple
from collections import defaultdict
from copy import deepcopy
from tabulate import tabulate

# Perturbe Dataset
Creates a perturbed dataset by randomly removing relations between entities from the train dataset.

In [9]:
TRAIN_DATASET_FILE = "/Users/mfa/code/lavis/jerex-infl/jerex/data/datasets/docred_joint/train_joint.json"
PERTURBED_FILE = "./data/datasets/docred_joint/train_joint_perturbed_min.json"
REMOVED_REL_FILE = "./data/datasets/docred_joint/train_joint_removed_rels_min.json"

REMOVE_RELATIONS_PERC = 0.1
MINIFIED = True

In [10]:
with open(TRAIN_DATASET_FILE, "r") as fd:
    orig_data = json.load(fd)

In [11]:
def relation_stats(dataset):
    rels = defaultdict(int)

    for doc in dataset:
        for relation in doc['labels']:
            rels[relation['r']] += 1
    
    return dict(rels)

In [12]:
def minify_dataset(dataset: list) -> list:
    cpy = deepcopy(dataset)
    cpy = cpy[:10]

    return cpy

In [13]:
def perturbe_dataset(dataset: list, perc: int = REMOVE_RELATIONS_PERC) -> Tuple[list]:
    pert_dataset = deepcopy(dataset)
    removed_relations = list()
    
    for i, doc in enumerate(dataset):
        removed_labels = list()

        for j in range(len(doc['labels']))[-1::-1]:
            if random.random() <= perc:
                rel = doc['labels'][j]
                removed_labels.append(rel)
                pert_dataset[i]['labels'].pop(j)

        removed_relations.append(removed_labels)
    
    assert len(pert_dataset) == len(dataset)

    return pert_dataset, removed_relations

In [23]:
data = orig_data

if MINIFIED:
    data = minify_dataset(data)

perturbed_dataset, removed_relations = perturbe_dataset(data)

In [24]:
orig_rel_stats = relation_stats(orig_data)
pert_rel_stats = relation_stats(perturbed_dataset)
removed_rels = 0
for doc in removed_relations:
    removed_rels += len(doc)

In [25]:
# for rel in orig_rel_stats:
#     assert rel in pert_rel_stats

print("Change in relation counts:")
print("--------------------------","\n")

headers = ["Relation", "Count (orig)", "Count (pert)", "Change (percentage)"]
rows = []

# total
orig_rel_count = sum(orig_rel_stats.values())
pert_rel_count = sum(pert_rel_stats.values())
removed_perc = 1 - (pert_rel_count / orig_rel_count)
rows.append(["(Total)", orig_rel_count, pert_rel_count, removed_perc])
rows.append([])

for rel in orig_rel_stats:
    orig_rel_count = orig_rel_stats[rel]
    pert_rel_count = 0 if rel not in pert_rel_stats else pert_rel_stats[rel]
    removed_perc = 1 - (pert_rel_count / orig_rel_count)
    rows.append([rel, orig_rel_count, pert_rel_count, removed_perc])

print(tabulate(rows, headers=headers))

Change in relation counts:
-------------------------- 

Relation      Count (orig)    Count (pert)    Change (percentage)
----------  --------------  --------------  ---------------------
(Total)              37486              90               0.997599

P131                  4063               2               0.999508
P1412                  133               6               0.954887
P6                     200               1               0.995
P17                   8785              12               0.998634
P571                   488               2               0.995902
P488                    69               4               0.942029
P102                   383               4               0.989556
P27                   2631               8               0.996959
P3373                  322               2               0.993789
P19                    466               3               0.993562
P20                    189               4               0.978836
P569                  

### Dump perturbed dataset and removed relations

In [26]:
with open(PERTURBED_FILE, "w") as fd:
    json.dump(perturbed_dataset, fd)

with open(REMOVED_REL_FILE, "w") as fd:
    json.dump(removed_relations, fd)