In [1]:
import json
import random
from collections import defaultdict
from copy import deepcopy
from tabulate import tabulate

# Perturbe Dataset
Creates a perturbed dataset by randomly removing relations between entities from the train dataset.

In [2]:
TRAIN_DATASET_FILE = "/Users/mfa/code/lavis/jerex-infl/jerex/data/datasets/docred_joint/train_joint.json"
PERTURBED_FILE = "/Users/mfa/code/lavis/jerex-infl/jerex/data/datasets/docred_joint/train_joint_perturbed.json"
REMOVED_REL_FILE = "/Users/mfa/code/lavis/jerex-infl/jerex/data/datasets/docred_joint/train_joint_removed_rels.json"

REMOVE_RELATIONS_PERC = 0.1

In [3]:
with open(TRAIN_DATASET_FILE, "r") as fd:
    orig_data = json.load(fd)

In [4]:
def relation_stats(dataset):
    rels = defaultdict(int)

    for doc in dataset:
        for relation in doc['labels']:
            rels[relation['r']] += 1
    
    return dict(rels)

In [5]:
def perturbe_dataset(dataset: list, perc: int = REMOVE_RELATIONS_PERC) -> list:
    pert_dataset = deepcopy(dataset)
    removed_relations = list()
    
    for i, doc in enumerate(dataset):
        removed_labels = list()

        for j in range(len(doc['labels']))[-1::-1]:
            if random.random() <= perc:
                rel = doc['labels'][j]
                removed_labels.append(rel)
                pert_dataset[i]['labels'].pop(j)

        removed_relations.append(removed_labels)
    
    assert len(pert_dataset) == len(dataset)

    return pert_dataset, removed_relations

In [6]:
perturbed_dataset, removed_relations = perturbe_dataset(orig_data)

In [7]:
orig_rel_stats = relation_stats(orig_data)
pert_rel_stats = relation_stats(perturbed_dataset)
removed_rels = 0
for doc in removed_relations:
    removed_rels += len(doc)

In [8]:
for rel in orig_rel_stats:
    assert rel in pert_rel_stats

print("Change in relation counts:")
print("--------------------------","\n")

headers = ["Relation", "Count (orig)", "Count (pert)", "Change (percentage)"]
rows = []

# total
orig_rel_count = sum(orig_rel_stats.values())
pert_rel_count = sum(pert_rel_stats.values())
removed_perc = 1 - (pert_rel_count / orig_rel_count)
rows.append(["(Total)", orig_rel_count, pert_rel_count, removed_perc])
rows.append([])

for rel in orig_rel_stats:
    orig_rel_count = orig_rel_stats[rel]
    pert_rel_count = pert_rel_stats[rel]
    removed_perc = 1 - (pert_rel_count / orig_rel_count)
    rows.append([rel, orig_rel_count, pert_rel_count, removed_perc])

print(tabulate(rows, headers=headers))

Change in relation counts:
-------------------------- 

Relation      Count (orig)    Count (pert)    Change (percentage)
----------  --------------  --------------  ---------------------
(Total)              37486           33693              0.101184

P131                  4063            3639              0.104356
P1412                  133             116              0.12782
P6                     200             189              0.055
P17                   8785            7890              0.101878
P571                   488             454              0.0696721
P488                    69              68              0.0144928
P102                   383             346              0.0966057
P27                   2631            2378              0.0961612
P3373                  322             284              0.118012
P19                    466             426              0.0858369
P20                    189             176              0.0687831
P569                  1035   

### Dump perturbed dataset and removed relations

In [9]:
with open(PERTURBED_FILE, "w") as fd:
    json.dump(perturbed_dataset, fd)

with open(REMOVED_REL_FILE, "w") as fd:
    json.dump(removed_relations, fd)