In [24]:
import json
import numpy as np
import os
from collections import defaultdict

In [45]:
# Root directory of the TACRED JSON files. Set the TACRED_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
data_dir = os.environ.get('TACRED_DATA_DIR', '/Volumes/External HDD/dataset/tacred/data/json')
train_split_file = os.path.join(data_dir, 'train_split-0.1.json')
test_split_file = os.path.join(data_dir, 'test_split-0.1.json')
# NOTE(review): despite its name this is the negatives file ('train_negatives.json'),
# presumably the source the 0.1 train/test splits were carved from — confirm.
original_file = os.path.join(data_dir, 'train_negatives.json')

In [46]:
def load_data(data_file):
    """Parse the JSON document stored at `data_file` and return the result.

    The file is opened in binary mode and the raw bytes are handed to
    `json.load`, which works out the UTF-8/16/32 encoding on its own.
    """
    with open(data_file, 'rb') as fp:
        return json.load(fp)

def group_by_triple(data):
    """Bucket examples by their (subj_type, relation, obj_type) triple.

    Args:
        data: iterable of example dicts; each must provide the keys
            'relation', 'subj_type' and 'obj_type'.

    Returns:
        A defaultdict mapping each triple to a numpy array of the examples
        sharing it, in input order. Because it is a defaultdict, looking up
        an absent triple inserts and returns an empty list.
    """
    buckets = defaultdict(list)
    for ex in data:
        rel = ex['relation']
        subj = ex['subj_type']
        obj = ex['obj_type']
        buckets[(subj, rel, obj)].append(ex)

    # Convert each finished bucket to an array once all examples are grouped.
    for key in list(buckets):
        buckets[key] = np.array(buckets[key])
    return buckets

def triple2lengths(data1, data2):
    """Print the number of examples per triple in two groupings, side by side.

    Args:
        data1: mapping of triple -> examples for the first split ("Train").
        data2: mapping of triple -> examples for the second split ("Test").

    Triples from `data1` are printed first, in `data1` order, with a Test
    count of 0 when `data2` lacks them. Triples present only in `data2`
    follow, in `data2` order, with a Train count of 0. Neither mapping is
    modified. Returns None.
    """
    for triple, samples1 in data1.items():
        # .get rather than data2[triple]: indexing the defaultdict produced by
        # group_by_triple would silently insert an empty entry into the
        # caller's mapping, and indexing a plain dict would raise KeyError.
        samples2 = data2.get(triple, ())
        print(f'{triple} | Train: {len(samples1)} | Test: {len(samples2)}')
    # Iterate data2 itself (not a set difference) so the order is deterministic.
    for triple, samples2 in data2.items():
        if triple not in data1:
            print(f'{triple} | Train: 0 | Test: {len(samples2)}')

In [47]:
# Load the train split, the test split and the negatives file, in that order.
train_data = load_data(data_file=train_split_file)
test_data = load_data(data_file=test_split_file)
negative_data = load_data(data_file=original_file)

In [48]:
# Bucket each split by its (subj_type, relation, obj_type) triple.
train_triple2data = group_by_triple(data=train_data)
test_triple2data = group_by_triple(data=test_data)

In [49]:
# Per-triple example counts in the train split vs. the test split.
triple2lengths(data1=train_triple2data, data2=test_triple2data)

('ORGANIZATION', 'no_relation', 'ORGANIZATION') | Train: 5805 | Test: 644
('PERSON', 'no_relation', 'CAUSE_OF_DEATH') | Train: 333 | Test: 36
('PERSON', 'no_relation', 'LOCATION') | Train: 1037 | Test: 115
('ORGANIZATION', 'no_relation', 'NUMBER') | Train: 3045 | Test: 338
('PERSON', 'no_relation', 'ORGANIZATION') | Train: 2514 | Test: 279
('PERSON', 'no_relation', 'TITLE') | Train: 1566 | Test: 173
('ORGANIZATION', 'no_relation', 'PERSON') | Train: 4087 | Test: 454
('PERSON', 'no_relation', 'DATE') | Train: 3021 | Test: 335
('PERSON', 'no_relation', 'NUMBER') | Train: 1644 | Test: 182
('PERSON', 'no_relation', 'PERSON') | Train: 11178 | Test: 1241
('PERSON', 'per:title', 'TITLE') | Train: 273 | Test: 30
('ORGANIZATION', 'no_relation', 'DATE') | Train: 3870 | Test: 429
('ORGANIZATION', 'no_relation', 'CITY') | Train: 846 | Test: 94
('PERSON', 'no_relation', 'STATE_OR_PROVINCE') | Train: 355 | Test: 39
('ORGANIZATION', 'no_relation', 'LOCATION') | Train: 1121 | Test: 124
('ORGANIZATION'

In [50]:
# Example ids of each file as sets, used by the coverage and overlap checks.
# Set comprehensions avoid building a throwaway list first.
train_ids = {d['id'] for d in train_data}
test_ids = {d['id'] for d in test_data}
orig_id = {d['id'] for d in negative_data}

In [51]:
# Ids from the negatives file found in neither split; an empty set means full coverage.
orig_id.difference(train_ids | test_ids)

set()

In [52]:
# Ids present in both splits; an empty set means train and test are disjoint.
train_ids & test_ids

set()