In [1]:
import os
import random
import copy
import re
import json
import jsonlines
import numpy as np
import pandas as pd
from tqdm import tqdm
import numpy as np
from collections import Counter

In [2]:
# Field names of one sample record. Not referenced by any other cell visible in this notebook.
key_columns = ["input", "target", "answer_choices", "task_type", "task_dataset", "sample_id"]
# Prompt templates for the CHIP-CTC classification prompts.
# [INPUT_TEXT] and [LIST_LABELS] are placeholders: get_re() turns each template into a
# regex pattern, and get_input() substitutes the sentence and the joined choices.
# "\\n" below is the two-character sequence backslash + "n", not a newline character:
# used as a regex it matches a real newline, while get_input() emits it verbatim.
templates = [
    "判断临床试验筛选标准的类型：\\n[INPUT_TEXT]\\n选项：[LIST_LABELS]\\n答：",
    "确定试验筛选标准的类型：\\n[INPUT_TEXT]\\n类型选项：[LIST_LABELS]\\n答：",
    "[INPUT_TEXT]\\n这句话是什么临床试验筛选标准类型？\\n类型选项：[LIST_LABELS]\\n答：",
    "[INPUT_TEXT]\\n是什么临床试验筛选标准类型？\\n选项：[LIST_LABELS]\\n答：",
    "请问是什么类型？\\n[INPUT_TEXT]\\n临床试验筛选标准选项：[LIST_LABELS]\\n答："
  ]

In [3]:
def read_data(data_path):
    """Read a JSON Lines file and return its records as a list, in file order."""
    with jsonlines.open(data_path, "r") as reader:
        records = [record for record in reader]
    return records

In [4]:
def save_data(data_path, data):
    """Write every record in ``data`` to ``data_path`` as JSON Lines (file opened in "w" mode)."""
    with jsonlines.open(data_path, "w") as writer:
        # write_all emits the records one by one, in iteration order
        writer.write_all(data)

In [5]:
def get_choice_num(data):
    """Summarise the answer choices of a list of samples.

    Args:
        data: sequence of dicts, each holding an "answer_choices" list.

    Returns:
        A tuple ``(mean, n_unique, flat)`` where ``mean`` is the average number
        of choices per sample, ``n_unique`` is the number of distinct choices and
        ``flat`` is a numpy array with every choice of every sample, in order.
        For empty ``data`` the result is ``(0.0, 0, <empty array>)`` rather than
        a ZeroDivisionError.
    """
    choices = []
    for meta_data in data:
        choices.extend(meta_data["answer_choices"])

    all_choices = np.array(choices)

    # Guard the mean: there is nothing to average over without samples.
    if len(data) == 0:
        return 0.0, 0, all_choices

    return len(choices) / len(data), len(np.unique(all_choices)), all_choices

In [6]:
# Keep only the flat array of choices (third return value) for each split.
# The value is picked by index instead of unpacking into `_`, because assigning `_`
# in an IPython session shadows the automatic "last output" variable.
train_choice = get_choice_num(read_data("train.json"))[2]
dev_choice = get_choice_num(read_data("dev.json"))[2]
test_choice = get_choice_num(read_data("test.json"))[2]

In [7]:
def get_re(templates):
    """Build two regex pattern strings per prompt template.

    Args:
        templates: template strings containing the placeholders ``[INPUT_TEXT]``
            and ``[LIST_LABELS]``. The remaining template text is used as regex
            source as-is (it is not escaped), so the backslash-n sequences in the
            templates match real newline characters.

    Returns:
        ``(re1, re2)``: two lists parallel to ``templates``. Group 1 of
        ``re1[i]`` captures the text in the ``[INPUT_TEXT]`` slot; group 1 of
        ``re2[i]`` captures the text in the ``[LIST_LABELS]`` slot.
    """
    # Raw strings: "\s" in a normal string literal is an invalid escape sequence.
    any_text = r"[\s\S]*"       # match and ignore anything, newlines included
    captured = r"([\s\S]*)"     # same, but captured as group 1
    answer_suffix = "\\n答："    # trailing answer cue, dropped from the pattern

    re1, re2 = [], []

    for temp in templates:
        # With the suffix removed, a pattern whose template ends in
        # [LIST_LABELS] + suffix ends with the wildcard/capture group, which
        # therefore also consumes the trailing answer cue of a prompt.
        re1.append(
            temp.replace("[INPUT_TEXT]", captured)
            .replace("[LIST_LABELS]", any_text)
            .replace(answer_suffix, "")
        )
        re2.append(
            temp.replace("[LIST_LABELS]", captured)
            .replace("[INPUT_TEXT]", any_text)
            .replace(answer_suffix, "")
        )

    return re1, re2

In [8]:
# Recover the raw sentence, target and choice list of every training sample by
# matching its prompt against the per-template regexes.
train_data = read_data("train.json")
temp_re1, temp_re2 = get_re(templates)
input_list, answer_list, choice_list = [], [], []
anti_input = []  # prompts that match none of the templates

for meta_data in train_data:
    prompt = meta_data["input"]
    for pattern in temp_re1:
        matched = re.match(pattern, prompt)
        if matched:
            # The first matching template wins; group 1 is the [INPUT_TEXT] slot.
            input_list.append(matched[1])
            answer_list.append(meta_data["target"])
            choice_list.append(meta_data["answer_choices"])
            break
    else:
        # Reached only when no pattern matched (the inner loop never hit `break`).
        anti_input.append(prompt)

In [9]:
# Sanity check: matched samples vs. distinct sentences, targets vs. distinct labels,
# and the number of prompts that matched no template.
len(input_list), len(set(input_list)), len(answer_list), len(set(answer_list)), len(anti_input)

(6000, 3622, 6000, 44, 0)

In [10]:
def get_input(raw_input, raw_choices):
    """Render a prompt from a randomly chosen entry of the module-level ``templates``.

    Args:
        raw_input: sentence substituted for ``[INPUT_TEXT]``.
        raw_choices: list of label strings; joined with the full-width comma
            "，" and substituted for ``[LIST_LABELS]``.

    Returns:
        The filled-in prompt string.

    NOTE(review): the templates contain the two-character sequence backslash + "n"
    rather than a newline, and it is emitted verbatim here — confirm this matches
    the format expected of the "input" field.
    """
    # Strings are immutable, so the chosen template can be used without copying.
    template = random.choice(templates)
    choice_str = "，".join(raw_choices)

    # Same substitution order as before: sentence first, then the label list.
    return template.replace("[INPUT_TEXT]", raw_input).replace("[LIST_LABELS]", choice_str)

In [11]:
# Load the raw CHIP-CTC training file and collect its distinct sentences with labels.
# The file is opened in a `with` block so the handle is closed, and read as UTF-8
# explicitly instead of relying on the platform default encoding
# (assumes the file is UTF-8 encoded — TODO confirm).
with open("CHIP-CTC_train.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)
aug_input_list, aug_answer_list = [], []
seen_texts = set()  # constant-time duplicate check instead of scanning the list

for meta_d in tqdm(raw_data):
    text = meta_d["text"]

    # Drop one leading space BEFORE the duplicate check, so " foo" and "foo" are
    # recognised as the same sentence. startswith() is also safe for empty text.
    if text.startswith(" "):
        text = text[1:]
        meta_d["text"] = text  # the record is updated in place, as before

    if text in seen_texts:
        continue

    seen_texts.add(text)
    aug_input_list.append(text)
    aug_answer_list.append(meta_d["label"])

100%|██████████| 22962/22962 [00:04<00:00, 5572.75it/s] 


In [12]:
# Distinct sentences recovered from train.json, distinct sentences from the raw
# CHIP-CTC file, and the size of their union.
len(set(input_list)), len(set(aug_input_list)), len(set(input_list+aug_input_list))

(3622, 22304, 22356)

In [31]:
# Map each label of the raw CHIP-CTC file to the label wording used as "target"
# in train.json, using sentences that occur in both sources.
label_dict = {}

# First position of every raw sentence: O(1) lookups instead of list.index scans.
first_index = {}
for j, text in enumerate(aug_input_list):
    first_index.setdefault(text, j)

# Stop once every distinct target has a mapping (previously hard-coded as 44).
n_labels = len(set(answer_list))

for text, answer in zip(input_list, answer_list):
    if answer not in label_dict.values():
        j = first_index.get(text)
        if j is None:
            # Sentence absent from the raw file. This replaces a bare `except:`
            # around list.index, which would also have hidden unrelated errors.
            continue
        label_dict[aug_answer_list[j]] = answer

    if len(label_dict) == n_labels:
        break

In [36]:
# Keep the first occurrence of every sentence together with its target and choices.
new_input_list, new_answer_list, new_choice_list = [], [], []
seen_inputs = set()  # set membership avoids rescanning new_input_list for every sample

for text, answer, choices in zip(input_list, answer_list, choice_list):
    if text in seen_inputs:
        continue
    seen_inputs.add(text)
    new_input_list.append(text)
    new_answer_list.append(answer)
    new_choice_list.append(choices)

In [37]:
# Number of sentences kept after de-duplication.
len(new_input_list)

3622

In [38]:
# Draw distinct positions in aug_input_list, enough to bring the de-duplicated set
# back up to the original number of training samples. Both sizes are derived from
# the data instead of the previously hard-coded 22304 and 6000-3622.
# NOTE(review): no random seed is set in the visible cells, so the draw differs
# between runs; drawn sentences may also already be in new_input_list — confirm
# that is intended.
random_index = random.sample(range(len(aug_input_list)), k=len(input_list) - len(new_input_list))

In [39]:
# Append the drawn raw sentences, each with its mapped target and a random choice list.
all_labels = list(label_dict.values())  # label_dict is not modified inside the loop

for ri in random_index:
    new_input_list.append(aug_input_list[ri])
    target = label_dict[aug_answer_list[ri]]
    new_answer_list.append(target)

    # Between 1 and 42 distractor candidates, drawn without replacement.
    n_choice = random.choice(range(1, 43))
    meta_choice = random.sample(all_labels, n_choice)

    # The target must be among the choices; when missing it is added at the end.
    if target not in meta_choice:
        meta_choice.append(target)
    new_choice_list.append(meta_choice)

In [40]:
# Number of samples after appending the drawn raw sentences.
len(new_input_list)

6000

In [41]:
# Assemble the final records: one rendered prompt per (sentence, target, choices) triple.
aug_data = []
count = 0

for text, answer, choices in zip(new_input_list, new_answer_list, new_choice_list):

    meta_input = get_input(text, choices)
    meta_data = {"input": meta_input, "target": answer, "answer_choices": choices,
                    "task_type": "cls", "task_dataset": "CHIP-CTC", "sample_id": "train-"+str(count)}
    aug_data.append(meta_data)
    # Advance the running id; without this every record was written as "train-0".
    count += 1


In [42]:
# Write the assembled records to train_aug.json via save_data (JSON Lines).
save_data("train_aug.json", aug_data)