In [1]:
import copy
import json
import logging
import os
import re
from collections import Counter
from pprint import pprint

import torch
import yaml
from IPython.core.debugger import set_trace
from tqdm import tqdm
from transformers import BertTokenizerFast

from common.utils import Preprocessor

In [2]:
# Put the root logger at DEBUG so the INFO messages logged by later cells
# (e.g. "... is output to ...") are not filtered out.
logger = logging.getLogger(None)
logger.setLevel("DEBUG")

In [19]:
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Build configuration, resolved relative to the notebook's working directory.
CONFIG_PATH = "build_data_config.yaml"
# Parse inside a `with` block so the file handle is closed right after loading
# instead of being left to the garbage collector. The Loader/Dumper names imported
# above are not used here; yaml.FullLoader is kept as the loader.
with open(CONFIG_PATH, "r", encoding = "utf-8") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

In [20]:
def load_json(filepath, by_line=True):
    """Load JSON data from ``filepath``.

    Args:
        filepath: Path of the file to read (decoded as UTF-8).
        by_line: If True, treat the file as JSON Lines (one JSON document per
            line) and return a list with one parsed document per non-blank
            line. Blank / whitespace-only lines, such as a trailing empty line
            at end of file, are skipped. If False, parse the whole file as a
            single JSON document and return it.

    Returns:
        A list of parsed documents when ``by_line`` is True, otherwise the
        top-level JSON value stored in the file.

    Raises:
        json.JSONDecodeError: If a line (or the whole file) is not valid JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        if by_line:
            # json.loads('') raises, so blank lines must be filtered out first.
            return [json.loads(line) for line in handle if line.strip()]
        return json.load(handle)

In [24]:
exp_name = config["exp_name"]
# Input/output directories come from build_data_config.yaml rather than
# hard-coded local paths.
data_in_dir = config["data_in_dir"]
data_out_dir = config["data_out_dir"]
# exist_ok avoids the check-then-create race and keeps this cell safe to re-run.
os.makedirs(data_out_dir, exist_ok = True)

In [25]:
# Display the input directory read from the config.
data_in_dir

'/Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_raw'

# Load Data

In [26]:
# Map each data file's name without extension (e.g. "raw_train") to its loaded samples.
file_name2data = {}
for path, _dirs, files in os.walk(data_in_dir):
    for file_name in files:
        stem, ext = os.path.splitext(file_name)
        # Only JSON / JSON-lines files are data; anything else (e.g. macOS
        # .DS_Store) is skipped instead of crashing on a failed regex match.
        if ext not in (".json", ".jsonl"):
            continue
        file_path = os.path.join(path, file_name)
        print(file_name)
        print(file_path)
        file_name2data[stem] = load_json(file_path)

raw_train.json
/Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_raw/raw_train.json
raw_test.json
/Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_raw/raw_test.json
raw_valid.json
/Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_raw/raw_valid.json


In [27]:
# Pull out the three raw splits and report how many samples each one holds.
train_data, dev_data, test_data = (
    file_name2data[split_key] for split_key in ('raw_train', 'raw_valid', 'raw_test')
)
print('Train: {} | Dev: {} | Test: {}'.format(*map(len, (train_data, dev_data, test_data))))

Train: 56196 | Dev: 5000 | Test: 5000


In [28]:
# Inspect the fields of a raw (pre-transform) NYT sample.
train_data[0].keys()

dict_keys(['sentText', 'articleId', 'relationMentions', 'entityMentions', 'sentId'])

In [29]:
# Entity mentions of a raw sample. In the output below, 'start' runs 0, 1, 2, ...
# which looks like mention order rather than a character offset (verify).
train_data[3]['entityMentions']

[{'start': 0, 'label': 'LOCATION', 'text': 'Columbus'},
 {'start': 1, 'label': 'LOCATION', 'text': 'Ohio'},
 {'start': 2, 'label': 'PERSON', 'text': 'Zach Wells'},
 {'start': 3, 'label': 'PERSON', 'text': 'Kyle Martino'}]

In [30]:
# Relation mentions of a raw sample: em1Text / em2Text / label, with no span offsets.
train_data[6]['relationMentions']

[{'em1Text': 'Anthony D. Weiner',
  'em2Text': 'Brooklyn',
  'label': '/people/person/place_lived'},
 {'em1Text': 'Anthony D. Weiner',
  'em2Text': 'Queens',
  'label': '/people/person/place_lived'}]

In [31]:
# Raw entity labels can be noisy: e.g. the person 'Debra Hill' is tagged LOCATION below.
train_data[2]['entityMentions']

[{'start': 0, 'label': 'LOCATION', 'text': 'Debra Hill'},
 {'start': 1, 'label': 'LOCATION', 'text': 'Haddonfield'}]

# Preprocess

In [32]:
# @specific
from transformers import AutoTokenizer
import transformers

# Checkpoint that takes precedence over config["bert_path"] (as in the original run);
# set this to None to use the path from build_data_config.yaml instead.
BERT_PATH_OVERRIDE = 'bert-large-cased'
bert_path = BERT_PATH_OVERRIDE or config.get("bert_path")

if config["encoder"] == "BERT":
    tokenizer = BertTokenizerFast.from_pretrained(bert_path, add_special_tokens = False, do_lower_case = False)
    tokenize = tokenizer.tokenize

    def get_tok2char_span_map(text):
        """Return the tokenizer's (char_start, char_end) offset for each sub-token of ``text``, without special tokens."""
        return tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)["offset_mapping"]
elif config["encoder"] == "BiLSTM":
    def tokenize(text):
        """Split ``text`` on single spaces (consecutive spaces yield empty tokens)."""
        return text.split(" ")

    def get_tok2char_span_map(text):
        """Return the half-open (start, end) character span of every space-separated token of ``text``."""
        tok2char_span = []
        char_num = 0
        for tok in tokenize(text):
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1  # +1 skips the single-space separator
        return tok2char_span
else:
    # Fail fast: otherwise `tokenize` would be undefined (or silently stale from an
    # earlier kernel run) when the Preprocessor is built.
    raise ValueError("Unsupported encoder {!r}; expected 'BERT' or 'BiLSTM'".format(config["encoder"]))

In [33]:
# A single Preprocessor, wired to the tokenize / offset-mapping pair chosen by
# the encoder cell above, is reused by every later processing step.
preprocessor = Preprocessor(
    tokenize_func=tokenize,
    get_tok2char_span_map_func=get_tok2char_span_map,
)

## Transform

In [34]:
# Source format that takes precedence over config["ori_data_format"] (as in the
# original run); set this to None to use the value from build_data_config.yaml instead.
ORI_FORMAT_OVERRIDE = 'raw_nyt'
ori_format = ORI_FORMAT_OVERRIDE or config["ori_data_format"]
print(ori_format)
if ori_format != "tplinker": # if tplinker, skip transforming
    for file_name, data in file_name2data.items():
        # When a name contains several keywords, test > valid > train decides the split.
        # An unrecognised name raises instead of reusing the previous file's data_type.
        if "test" in file_name:
            data_type = "test"
        elif "valid" in file_name:
            data_type = "valid"
        elif "train" in file_name:
            data_type = "train"
        else:
            raise ValueError("Cannot infer dataset type (train/valid/test) from file name {!r}".format(file_name))
        # Replacing the value of an existing key is safe while iterating over items().
        file_name2data[file_name] = preprocessor.transform_data(data, ori_format = ori_format, dataset_type = data_type, add_id = True)

Transforming data format: 56196it [00:00, 445318.56it/s]
Clean:   0%|          | 0/56196 [00:00<?, ?it/s]

raw_nyt


Clean: 100%|██████████| 56196/56196 [00:00<00:00, 236967.80it/s]
Transforming data format: 5000it [00:00, 557412.22it/s]
Clean: 100%|██████████| 5000/5000 [00:00<00:00, 193805.69it/s]
Transforming data format: 5000it [00:00, 358420.13it/s]
Clean: 100%|██████████| 5000/5000 [00:00<00:00, 185337.73it/s]


In [35]:
# Re-bind the split handles to the transformed data; sizes should match the raw load.
train_data, dev_data, test_data = (
    file_name2data[split_key] for split_key in ('raw_train', 'raw_valid', 'raw_test')
)
print('Train: {} | Dev: {} | Test: {}'.format(*map(len, (train_data, dev_data, test_data))))

Train: 56196 | Dev: 5000 | Test: 5000


In [36]:
# Inspect one transformed sample: only text, id and relation_list remain at this stage.
train_data[1]

{'text': 'North Carolina EASTERN MUSIC FESTIVAL Greensboro , June 25-July 30 .',
 'id': 'train_1',
 'relation_list': [{'subject': 'North Carolina',
   'predicate': '/location/location/contains',
   'object': 'Greensboro'}]}

## Clean and Add Spans

In [37]:
# check token level span
def check_tok_span(data):
    """Verify that every relation's token spans decode back to its gold entity strings.

    For each sample, the text covered by ``subj_tok_span`` / ``obj_tok_span``
    (half-open token ranges mapped to characters via the module-level
    ``get_tok2char_span_map``) is compared with ``subject`` / ``object``.

    Args:
        data: Samples with "text" and a "relation_list" whose entries carry
            "subject", "object", "subj_tok_span" and "obj_tok_span".

    Returns:
        set of str: one "extr: <decoded>---gold: <gold>" message per distinct
        mismatch; empty when every span decodes exactly.
    """
    def extr_ent(text, tok_span, tok2char_span):
        """Return the substring of ``text`` covered by token range [start, end)."""
        char_span_list = tok2char_span[tok_span[0]:tok_span[1]]
        if not char_span_list:
            # Empty / out-of-range token span: report it as a mismatch instead of raising IndexError.
            return ""
        return text[char_span_list[0][0]:char_span_list[-1][1]]

    span_error_memory = set()
    for sample in tqdm(data, desc = "check tok spans"):
        text = sample["text"]
        tok2char_span = get_tok2char_span_map(text)
        for rel in sample["relation_list"]:
            for span_key, gold_key in (("subj_tok_span", "subject"), ("obj_tok_span", "object")):
                extracted = extr_ent(text, rel[span_key], tok2char_span)
                if extracted != rel[gold_key]:
                    span_error_memory.add("extr: {}---gold: {}".format(extracted, rel[gold_key]))

    return span_error_memory

In [38]:
# For every split in file_name2data: clean the text, add char-level and token-level
# spans, collect the relation vocabulary, and optionally verify the token spans.
# Per-split error counts are accumulated in error_statistics and printed at the end.
rel_set = set()  # predicates seen in ANY processed split (train, valid and test alike)
# file_name -> counts; "char_span_error" is always set, "miss_samples" / "tok_span_error"
# only when config["add_char_span"] / config["check_tok_span"] are enabled.
error_statistics = {}
for file_name, data in file_name2data.items():
    assert len(data) > 0
    if "relation_list" in data[0]: # samples carry relation annotations (all three splits do after the transform step)
        # rm redundant whitespaces
        # separate by whitespaces
        # (the `separate` behaviour is controlled by config["separate_char_by_white"])
        data = preprocessor.clean_data_wo_span(data, separate = config["separate_char_by_white"])
        error_statistics[file_name] = {}
        # add char span
        # presumably locates subject/object strings in the text to fill subj_char_span /
        # obj_char_span; miss_sample_list is taken to be the failures — TODO confirm in common.utils
        if config["add_char_span"]:
            data, miss_sample_list = preprocessor.add_char_span(data, config["ignore_subword"])
            error_statistics[file_name]["miss_samples"] = len(miss_sample_list)
            
        # clean
        # returns (kept samples, samples with char-span errors); in the logged run the sample
        # count drops across this step (e.g. train 56196 -> 56032), so bad samples appear to be removed
        data, bad_samples_w_char_span_error = preprocessor.clean_data_w_span(data)
        error_statistics[file_name]["char_span_error"] = len(bad_samples_w_char_span_error)
        
        # add tok span
        data = preprocessor.add_tok_span(data)
        
        # collect relations
        for sample in tqdm(data, desc = "collect relations"):
            for rel in sample["relation_list"]:
                rel_set.add(rel["predicate"])
        
        # check tok span
        # mismatches are only printed and counted; the affected samples are kept in `data`
        if config["check_tok_span"]:
            span_error_memory = check_tok_span(data)
            if len(span_error_memory) > 0:
                print(span_error_memory)
            error_statistics[file_name]["tok_span_error"] = len(span_error_memory)
            
        # write the processed split back so later cells see it
        file_name2data[file_name] = data
pprint(error_statistics)

clean data: 100%|██████████| 56196/56196 [00:00<00:00, 61059.06it/s]
clean data w char spans: 100%|██████████| 56196/56196 [00:00<00:00, 94366.13it/s]
Adding token level spans: 100%|██████████| 56032/56032 [00:03<00:00, 15991.43it/s]
collect relations: 100%|██████████| 56032/56032 [00:00<00:00, 1179345.44it/s]
check tok spans: 100%|██████████| 56032/56032 [00:00<00:00, 81667.28it/s]
clean data: 100%|██████████| 5000/5000 [00:00<00:00, 56174.07it/s]
clean data w char spans: 100%|██████████| 5000/5000 [00:00<00:00, 156289.27it/s]
Adding token level spans:   0%|          | 0/4979 [00:00<?, ?it/s]

{'extr: Brooklynites---gold: Brooklyn', 'extr: Austrians---gold: Austria', 'extr: Oregonian---gold: Oregon', 'extr: Sukarnoputri---gold: Sukarno', 'extr: Congressional---gold: Congress', 'extr: Kyrgyzstan---gold: Kyrgyz', 'extr: Germanys---gold: Germany', 'extr: Chinatown---gold: China', 'extr: Africans---gold: Africa', 'extr: Australian---gold: Australia', 'extr: Cambodian---gold: Cambodia', 'extr: Japanese---gold: Japan', 'extr: Rwandan---gold: Rwanda', 'extr: Brooklyn-Queens---gold: Brooklyn', 'extr: Russian---gold: Russia', 'extr: Brooklyn-Queens---gold: Queens', 'extr: Iraqi-manned---gold: Iraq', 'extr: Miami-Dade---gold: Miami', 'extr: Vietnamese---gold: Vietnam', 'extr: Taiwanese---gold: Taiwan', 'extr: Rutgers-Newark---gold: Newark', 'extr: Israelis---gold: Israel', 'extr: Arab-Israeli---gold: Israel', 'extr: Zimbabwean---gold: Zimbabwe', 'extr: Chicago-based---gold: Chicago', 'extr: Catalonian---gold: Catalonia', 'extr: Indianapolis---gold: Indiana', 'extr: Brooklyn-born---gol

Adding token level spans: 100%|██████████| 4979/4979 [00:00<00:00, 16462.72it/s]
collect relations: 100%|██████████| 4979/4979 [00:00<00:00, 877824.28it/s]
check tok spans: 100%|██████████| 4979/4979 [00:00<00:00, 70464.32it/s]
clean data: 100%|██████████| 5000/5000 [00:00<00:00, 53364.00it/s]
clean data w char spans: 100%|██████████| 5000/5000 [00:00<00:00, 167854.07it/s]
Adding token level spans:   0%|          | 0/4983 [00:00<?, ?it/s]

{'extr: New Yorkers---gold: New York', 'extr: Israeli-Lebanon---gold: Lebanon', 'extr: Ex-Enron---gold: Enron', 'extr: Asia-Pacific---gold: Asia', 'extr: Queensborough---gold: Queens', 'extr: one-China---gold: China', 'extr: Texas-Oklahoma---gold: Texas', 'extr: Africans---gold: Africa', 'extr: Australian---gold: Australia', 'extr: Japanese---gold: Japan', 'extr: Iranian---gold: Iran', 'extr: Russian---gold: Russia', 'extr: Miami-Dade---gold: Miami', 'extr: Serbian---gold: Serbia', 'extr: Iraqi---gold: Iraq', 'extr: Taiwanese---gold: Taiwan', 'extr: New Jersey-oriented---gold: New Jersey', 'extr: Cambodian-born---gold: Cambodia', 'extr: New Zealander---gold: New Zealand', 'extr: Israelis---gold: Israel', 'extr: Arab-Israeli---gold: Israel', 'extr: Tasmanian---gold: Tasmania', 'extr: Indians---gold: India', 'extr: Zambian---gold: Zambia', 'extr: Chilean---gold: Chile', 'extr: Newarks---gold: Newark', 'extr: Indian---gold: India', 'extr: Cuban---gold: Cuba', 'extr: European---gold: Europ

Adding token level spans: 100%|██████████| 4983/4983 [00:00<00:00, 16941.90it/s]
collect relations: 100%|██████████| 4983/4983 [00:00<00:00, 1108118.17it/s]
check tok spans: 100%|██████████| 4983/4983 [00:00<00:00, 81316.52it/s]

{'extr: New Yorkers---gold: New York', 'extr: Queensborough---gold: Queens', 'extr: Austrians---gold: Austria', 'extr: Liberian---gold: Liberia', 'extr: anti-Syrian---gold: Syria', 'extr: Chinatown---gold: China', 'extr: non-European---gold: Europe', 'extr: Australian---gold: Australia', 'extr: Japanese---gold: Japan', 'extr: Iranian---gold: Iran', 'extr: Russian---gold: Russia', 'extr: Iraqi---gold: Iraq', 'extr: Urumqi-Beijing---gold: Beijing', 'extr: Miller-Great Neck---gold: Great Neck', 'extr: Israelis---gold: Israel', 'extr: Indian-restaurant-packed---gold: India', 'extr: New Yorker---gold: New York', 'extr: Indian---gold: India', 'extr: Cuban---gold: Cuba', 'extr: Bolivian---gold: Bolivia', 'extr: European---gold: Europe', 'extr: Russian-Israeli---gold: Israel', 'extr: South Korean---gold: South Korea', 'extr: Israeli-Palestinian---gold: Israel', 'extr: Iraqis---gold: Iraq', 'extr: Israeli---gold: Israel', 'extr: African---gold: Africa', 'extr: Pakistani---gold: Pakistan', 'extr




In [39]:
# Re-bind the split handles to the cleaned data; counts drop relative to the raw
# load because samples are removed during cleaning.
train_data, dev_data, test_data = (
    file_name2data[split_key] for split_key in ('raw_train', 'raw_valid', 'raw_test')
)
print('Train: {} | Dev: {} | Test: {}'.format(*map(len, (train_data, dev_data, test_data))))

Train: 56032 | Dev: 4983 | Test: 4979


In [42]:
# Split keys, in the order the os.walk load inserted them (used for the output files below).
file_name2data.keys()

dict_keys(['raw_train', 'raw_test', 'raw_valid'])

In [43]:
# A fully processed sample: each relation now carries char- and token-level spans for subject and object.
file_name2data['raw_train'][0]

{'text': 'Massachusetts ASTON MAGNA Great Barrington ; also at Bard College , Annandale-on-Hudson , N.Y. , July 1-Aug .',
 'id': 'train_0',
 'relation_list': [{'subject': 'Annandale-on-Hudson',
   'predicate': '/location/location/contains',
   'object': 'Bard College',
   'subj_char_span': [68, 87],
   'obj_char_span': [53, 65],
   'subj_tok_span': [11, 12],
   'obj_tok_span': [8, 10]}]}

# Output to Disk

In [44]:
def _dump_json(obj, path, **dump_kwargs):
    """Write ``obj`` to ``path`` as UTF-8 JSON, keeping non-ASCII characters verbatim.

    The file is opened in a ``with`` block, so it is flushed and closed before
    returning rather than whenever the garbage collector gets to it.
    Extra keyword arguments (e.g. ``indent``) are forwarded to ``json.dump``.
    """
    with open(path, "w", encoding = "utf-8") as handle:
        json.dump(obj, handle, ensure_ascii = False, **dump_kwargs)


# Sort predicates so relation ids are stable across runs (set order depends on string hashing).
rel_set = sorted(rel_set)
rel2id = {rel: ind for ind, rel in enumerate(rel_set)}
data_statistics = {
    "relation_num": len(rel2id),
}

for file_name, data in file_name2data.items():
    data_path = os.path.join(data_out_dir, "{}.json".format(file_name))
    # Each split is written as ONE JSON array (not JSON Lines); read it back with load_json(..., by_line=False).
    _dump_json(data, data_path)
    logging.info("{} is output to {}".format(file_name, data_path))
    data_statistics[file_name] = len(data)

rel2id_path = os.path.join(data_out_dir, "rel2id.json")
_dump_json(rel2id, rel2id_path)
logging.info("rel2id is output to {}".format(rel2id_path))

data_statistics_path = os.path.join(data_out_dir, "data_statistics.txt")
# Despite the .txt extension, this file holds indented JSON.
_dump_json(data_statistics, data_statistics_path, indent = 4)
logging.info("data_statistics is output to {}".format(data_statistics_path))

pprint(data_statistics)

INFO:root:raw_train is output to /Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_bilstm/raw_train.json
INFO:root:raw_test is output to /Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_bilstm/raw_test.json
INFO:root:raw_valid is output to /Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_bilstm/raw_valid.json
INFO:root:rel2id is output to /Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_bilstm/rel2id.json
INFO:root:data_statistics is output to /Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_bilstm/data_statistics.txt


{'raw_test': 4979, 'raw_train': 56032, 'raw_valid': 4983, 'relation_num': 24}


# Generate Word Dictionary (BiLSTM vocabulary)

In [45]:
if config["encoder"] in {"BiLSTM", }:
    # Pool every processed split: the vocabulary is built over train, valid and test together.
    all_data = [sample for data in file_name2data.values() for sample in data]

    # Token frequencies ordered most-common first; equal counts keep first-encountered order.
    token2num = dict(Counter(
        tok
        for sample in tqdm(all_data, desc = "Tokenizing")
        for tok in tokenize(sample['text'])
    ).most_common())

    max_token_num = 50000  # vocabulary cap, not counting <PAD>/<UNK>
    min_token_freq = 3     # tokens seen fewer times than this are left out of the vocabulary
    special_tokens = ("<PAD>", "<UNK>")

    token_set = set()
    for tok, num in token2num.items():
        if num < min_token_freq:
            break  # counts are descending, so every remaining token is rarer still
        if tok in special_tokens:
            continue  # ids 0/1 are reserved; a literal "<PAD>"/"<UNK>" in the text must not claim another id
        token_set.add(tok)
        if len(token_set) == max_token_num:
            break

    # Regular tokens get ids 2.. in sorted order; 0 and 1 are reserved for <PAD> / <UNK>.
    token2idx = {tok: idx + 2 for idx, tok in enumerate(sorted(token_set))}
    token2idx["<PAD>"] = 0
    token2idx["<UNK>"] = 1

    dict_path = os.path.join(data_out_dir, "token2idx.json")
    with open(dict_path, "w", encoding = "utf-8") as dict_file:
        json.dump(token2idx, dict_file, ensure_ascii = False, indent = 4)
    logging.info("token2idx is output to {}, total token num: {}".format(dict_path, len(token2idx)))

Tokenizing: 100%|██████████| 65994/65994 [00:00<00:00, 90326.63it/s]
Filter uncommon words: 100%|██████████| 90818/90818 [00:00<00:00, 1714357.79it/s]
INFO:root:token2idx is output to /Users/georgestoica/Desktop/Research/TPlinker-joint-extraction/ori_data/nyt_bilstm/token2idx.json, total token num: 34790
