In [1]:
import os
import random
import copy
import re
import json
import jsonlines
import numpy as np
import pandas as pd
from tqdm import tqdm
import numpy as np
from collections import Counter

In [2]:
key_columns = ["input", "target", "answer_choices", "task_type", "task_dataset", "sample_id"]
templates = [
    "找出指定的三元组：\\n[INPUT_TEXT]\\n实体间关系：[LIST_LABELS]\\n答：",
    "根据给定的实体间的关系，抽取具有这些关系的实体对：\\n[INPUT_TEXT]\\n实体间关系标签：[LIST_LABELS]\\n答：",
    "找出句子中的具有[LIST_LABELS]关系类型的头尾实体对：\\n[INPUT_TEXT]\\n答：",
    "[INPUT_TEXT]\\n问题：句子中的[LIST_LABELS]等关系类型三元组是什么？\\n答：",
    "给出句子中的[LIST_LABELS]等关系类型的实体对：[INPUT_TEXT]\\n答：",
    "[INPUT_TEXT]\\n这个句子里面具有一定医学关系的实体组有哪些？\\n三元组关系选项：[LIST_LABELS]\\n答：",
    "同时完成实体识别与关系识别：\\n[INPUT_TEXT]\\n三元组关系类型：[LIST_LABELS]\\n答："
  ]

In [3]:
def read_data(data_path):
    '''read data from jsonlines file'''
    data = []

    with jsonlines.open(data_path, "r") as f:
        for meta_data in f:
            data.append(meta_data)

    return data

In [4]:
def save_data(data_path, data):
    '''write all_data list to a new jsonl'''
    with jsonlines.open(data_path, "w") as w:
        for meta_data in data:
            w.write(meta_data)

In [5]:
def get_re(templates):
    '''get the re formula for input and labels'''
    re1, re2 = [], []

    for temp in templates:

        re1.append(temp.replace("[INPUT_TEXT]", "([\s\S]*)").replace("[LIST_LABELS]", "[\s\S]*").replace("\\n答：", ""))   # match and ignore
        re2.append(temp.replace("[LIST_LABELS]", "([\s\S]*)").replace("[INPUT_TEXT]", "[\s\S]*").replace("\\n答：", ""))

    return re1, re2

In [6]:
train_data = read_data("train.json")
temp_re1, temp_re2 = get_re(templates)
input_list, answer_list = [], []
anti_input = [] # check the sample that can't match the template

for meta_data in train_data:
    flag = True
    for temp in temp_re1:
        if re.match(temp, meta_data["input"]):
            input_list.append(re.match(temp, meta_data["input"])[1])
            answer_list.append(meta_data["target"])
            flag = False
            break
    if flag:    
        anti_input.append(meta_data["input"])

In [8]:
len(input_list), len(set(input_list)), len(answer_list), len(set(answer_list)), len(anti_input)

(3000, 2828, 3000, 2273, 0)

In [15]:
def get_input(raw_input, raw_choices):

    meta_temp = random.choice(templates)
    new_input = copy.deepcopy(meta_temp)
    choice_str = ""
    for j, meta_choice in enumerate(raw_choices):
        if j == (len(raw_choices) - 1):
            choice_str = choice_str + meta_choice
        else:
            choice_str = choice_str + meta_choice + "，"
    new_input = new_input.replace("[INPUT_TEXT]", raw_input)
    new_input = new_input.replace("[LIST_LABELS]", choice_str)

    return new_input

In [19]:
raw_data = read_data("CMeIE_train.json")
answer_temp = "具有[PRED]关系的头尾实体对如下：头实体为[SUB]，尾实体为[OBJ]。\n"
aug_input_list, aug_answer_list, aug_choices_list = [], [], []

for meta_d in raw_data:
    meta_aug_input = meta_d["text"]
    meta_aug_target = ""
    meta_aug_choice = []
    
    for meta_event in meta_d["spo_list"]:
        meta_event_str = copy.deepcopy(answer_temp)
        meta_event_str = meta_event_str.replace("[PRED]", meta_event["predicate"])
        meta_event_str = meta_event_str.replace("[SUB]", meta_event["subject"])
        meta_event_str = meta_event_str.replace("[OBJ]", meta_event["object"]["@value"])
        if meta_event["predicate"] not in meta_aug_choice:
            meta_aug_choice.append(meta_event["predicate"])
        
        meta_aug_target += meta_event_str

    meta_aug_target = meta_aug_target[:-2]
    aug_input_list.append(meta_aug_input)
    aug_answer_list.append(meta_aug_target)
    aug_choices_list.append(meta_aug_choice)

In [20]:
len(set(input_list)), len(set(aug_input_list)), len(set(input_list+aug_input_list))

(2828, 14339, 14850)

In [28]:
all_choices = []
for meta_data in train_data:
    all_choices += meta_data["answer_choices"]

In [30]:
len(all_choices)

7519

In [32]:
aug_data = []
count = 0

for i in range(len(aug_input_list)):

    if aug_input_list[i] not in input_list:
        no_choice = random.random()
        if no_choice >= 0.25:
            meta_input = get_input(aug_input_list[i], aug_choices_list[i])
            meta_data = {"input": meta_input, "target": aug_answer_list[i], "answer_choices": aug_choices_list[i],
                            "task_type": "spo_generation", "task_dataset": "CMeIE", "sample_id": "train-"+str(count)}
        else:
            random_choice = [random.choice(list(set(all_choices)-set(aug_choices_list[i])))]
            meta_input = get_input(aug_input_list[i], random_choice)
            meta_data = {"input": meta_input, "target": "没有指定类型的三元组", "answer_choices": random_choice,
                            "task_type": "spo_generation", "task_dataset": "CMeIE", "sample_id": "train-"+str(count)}
        aug_data.append(meta_data)
        count+=1

    if count >= 3000:
        break

In [33]:
save_data("train_aug.json", train_data+aug_data)