In [1]:
import json
import os

from tqdm import tqdm
from collections import defaultdict

# Load raw WebNLG splits and the predicate list

In [2]:
def load_data(dtype='train', data_dir='../../casrel_data/WebNLG/raw_WebNLG'):
    """Load one raw WebNLG split and convert it to simple triple records.

    Each non-blank line of ``<data_dir>/new_<dtype>.json`` is parsed as a JSON
    object; its ``sentText`` string and the ``em1Text`` / ``label`` /
    ``em2Text`` fields of every item in ``relationMentions`` are used.

    Args:
        dtype: split name used in the file name (e.g. 'train', 'valid',
            'test_epo') and as the prefix of each record id.
        data_dir: directory containing the raw ``new_<dtype>.json`` files.

    Returns:
        A list of dicts ``{'id': '<dtype>-<line index>', 'text': str,
        'spos': [(subject, predicate, object), ...]}`` where ``predicate`` is
        the last '/'-separated segment of the raw label. Duplicate triples
        are dropped; the remaining ones keep the order of their first mention.
    """
    path = os.path.join(data_dir, 'new_{}.json'.format(dtype))
    data = []
    with open(path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if not line.strip():
                # Tolerate blank lines (e.g. an extra empty line at the end)
                # instead of letting json.loads raise on them.
                continue
            raw = json.loads(line)
            # dict.fromkeys de-duplicates like a set but keeps first-seen
            # order; set iteration order for strings depends on hash
            # randomization, which made the saved 'spos' order vary per run.
            spos = dict.fromkeys(
                (item['em1Text'], item['label'].split('/')[-1], item['em2Text'])
                for item in raw['relationMentions']
            )
            data.append({'id': '{}-{}'.format(dtype, i),
                         'text': raw['sentText'],
                         'spos': list(spos)})
    return data
def load_labels(dtype='train', data_dir='../../casrel_data/WebNLG/raw_WebNLG'):
    """Collect the relation names that occur in one raw WebNLG split.

    Args:
        dtype: split name used in the file name ``new_<dtype>.json``.
        data_dir: directory containing the raw ``new_<dtype>.json`` files.

    Returns:
        A sorted list of unique predicate names, each being the last
        '/'-separated segment of a ``relationMentions`` item's ``label``.
    """
    path = os.path.join(data_dir, 'new_{}.json'.format(dtype))
    predicates = set()
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Skip blank lines rather than failing in json.loads.
                continue
            for item in json.loads(line)['relationMentions']:
                predicates.add(item['label'].split('/')[-1])
    return sorted(predicates)
# Read every raw split in a fixed order; the predicate list comes from
# load_labels() called with its default split.
(train_data, valid_data, test_data,
 test_normal_data, test_seo_data, test_epo_data) = [
    load_data(split)
    for split in ('train', 'valid', 'test', 'test_normal', 'test_seo', 'test_epo')
]
predicates = load_labels()

# Save processed splits (JSON Lines) and the predicate list

In [3]:
def save_data(dtype, data):
    """Write ``data`` to ``<dtype>.json`` in the working directory, one JSON record per line.

    The file is opened in 'w' mode (overwritten if present) with UTF-8
    encoding; every record is serialized with ``json.dumps`` and followed by
    a newline.
    """
    out_path = '{}.json'.format(dtype)
    with open(out_path, 'w', encoding='utf-8') as out:
        out.writelines(json.dumps(record) + '\n' for record in data)
# Persist each processed split as <split>.json in the working directory.
save_data(dtype='train', data=train_data)
save_data(dtype='valid', data=valid_data)
save_data(dtype='test', data=test_data)
save_data(dtype='test_normal', data=test_normal_data)
save_data(dtype='test_seo', data=test_seo_data)
save_data(dtype='test_epo', data=test_epo_data)

In [4]:
def save_predicates(predicates):
    """Write each predicate name on its own line to ``predicates.txt`` in the working directory.

    The file is overwritten if present; every name (including the last) is
    newline-terminated.
    """
    with open('predicates.txt', 'w', encoding='utf-8') as out:
        out.write(''.join(name + '\n' for name in predicates))
# Write the predicate list (loaded in the "Load" cell) to predicates.txt.
save_predicates(predicates=predicates)