In [6]:
import json
import numpy as np
import os
from collections import defaultdict

In [2]:
# Directory holding the TACRED JSON splits. Set the TACRED_DATA_DIR
# environment variable to point at another copy of the dataset; the literal
# path below is only the fallback used when that variable is unset.
data_dir = os.environ.get(
    'TACRED_DATA_DIR',
    '/Volumes/External HDD/dataset/tacred/data/json',
)
train_split_file = os.path.join(data_dir, 'train_split.json')
test_split_file = os.path.join(data_dir, 'test_split.json')

In [22]:
def load_data(data_file):
    """Read a JSON file and return the decoded Python object.

    Args:
        data_file: Path of the JSON file to read.

    Returns:
        Whatever the JSON document decodes to (list, dict, ...).
    """
    # The file is read as bytes; the json module decodes them itself.
    with open(data_file, 'rb') as stream:
        raw_bytes = stream.read()
    return json.loads(raw_bytes)

def group_by_triple(data):
    """Bucket examples by their (subj_type, relation, obj_type) triple.

    Args:
        data: Iterable of examples; each must support lookup of the keys
            'relation', 'subj_type' and 'obj_type'.

    Returns:
        A defaultdict keyed by the triple. Every key seen in ``data`` maps to
        a numpy array of its examples; looking up an unseen key yields (and
        stores) an empty list.
    """
    buckets = defaultdict(list)
    for item in data:
        rel, subj, obj = item['relation'], item['subj_type'], item['obj_type']
        buckets[(subj, rel, obj)].append(item)

    # Swap every accumulated list for a numpy array. Iterating over a
    # snapshot of the keys keeps the reassignment independent of the dict view.
    for key in list(buckets):
        buckets[key] = np.array(buckets[key])
    return buckets

def triple2lengths(data1, data2):
    """Print per-triple example counts for two grouped datasets.

    One line is printed for every triple in ``data1`` with its number of
    examples in ``data1`` (labelled ``Train``) and in ``data2`` (labelled
    ``Test``); a triple absent from ``data2`` is reported with a count of 0.
    Triples that occur only in ``data2`` follow, in ``data2``'s own iteration
    order, with a train count of 0.

    Args:
        data1: Mapping from triple to a sized collection of examples.
        data2: Mapping with the same layout as ``data1``.

    Returns:
        None. The report is written to stdout and neither mapping is modified.
    """
    for triple, samples1 in data1.items():
        # Use .get instead of indexing: indexing a defaultdict would insert an
        # empty entry into the caller's mapping, and a plain dict would raise
        # KeyError for a triple that only exists in data1.
        samples2 = data2.get(triple, ())
        print(f'{triple} | Train: {len(samples1)} | Test: {len(samples2)}')

    # Walk data2 directly rather than a set difference so the trailing lines
    # come out in a reproducible order.
    for triple, samples2 in data2.items():
        if triple not in data1:
            print(f'{triple} | Train: 0 | Test: {len(samples2)}')

In [4]:
# Parse both split files into memory with load_data.
train_data = load_data(train_split_file)
test_data = load_data(test_split_file)

In [16]:
# Bucket the examples of each split by their (subj_type, relation, obj_type) triple.
train_triple2data = group_by_triple(train_data)
test_triple2data = group_by_triple(test_data)

In [23]:
# Print the number of train and test examples for every triple.
triple2lengths(train_triple2data, test_triple2data)

('ORGANIZATION', 'no_relation', 'ORGANIZATION') | Train: 4283 | Test: 2166
('PERSON', 'no_relation', 'CAUSE_OF_DEATH') | Train: 228 | Test: 141
('PERSON', 'no_relation', 'LOCATION') | Train: 758 | Test: 394
('ORGANIZATION', 'no_relation', 'NUMBER') | Train: 2299 | Test: 1084
('PERSON', 'no_relation', 'TITLE') | Train: 1169 | Test: 570
('PERSON', 'no_relation', 'NUMBER') | Train: 1213 | Test: 613
('PERSON', 'no_relation', 'PERSON') | Train: 8255 | Test: 4164
('ORGANIZATION', 'no_relation', 'DATE') | Train: 2831 | Test: 1468
('ORGANIZATION', 'no_relation', 'CITY') | Train: 644 | Test: 296
('ORGANIZATION', 'no_relation', 'LOCATION') | Train: 810 | Test: 435
('ORGANIZATION', 'no_relation', 'PERSON') | Train: 3041 | Test: 1500
('ORGANIZATION', 'no_relation', 'COUNTRY') | Train: 1053 | Test: 529
('PERSON', 'no_relation', 'DATE') | Train: 2247 | Test: 1109
('PERSON', 'no_relation', 'CITY') | Train: 505 | Test: 232
('ORGANIZATION', 'no_relation', 'STATE_OR_PROVINCE') | Train: 263 | Test: 126
(