In [3]:
import json
import numpy as np
import torch
import sys
from transformers import BertTokenizer
from random import sample
from math import ceil

type_dict = {"none":0, "name":1, "location":2, "time":3, "contact":4,
             "ID":5, "profession":6, "biomarker":7, "family":8,
             "clinical_event":9, "special_skills":10, "unique_treatment":11,
             "account":12, "organization":13, "education":14, "money":15,
             "belonging_mark":16, "med_exam":17, "others":18}

def type2str(t):
    for key, value in type_dict.items():  # for name, age in dictionary.iteritems():  (for Python 2.x)
        if t == value:
            return key

In [8]:
PRETRAINED_LM = "hfl/chinese-bert-wwm"
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM)
bert_data = []
with open ('./dataset/train_1.json', 'r') as json_file:
    data_file = json.load(json_file)
    print("start preprocessing...")
    c = 0
    for data in data_file:
        article = data['article']
        type_list = []
        for i, item in enumerate(data['item']):
            article = article[:item[1] + i*2] + "_" + item[3] + "_" + article[item[2] + i*2:]
            type_list.append(type_dict[item[4]])
        article = article.replace("醫師：", "[SEP]") \
        .replace("民眾：", "[SEP]").replace("家屬：", "[SEP]") \
        .replace("個管師：", "[SEP]").replace("護理師：", "[SEP]")
        tokens = tokenizer.tokenize(article)
        
        
        start_pos_label = np.full(len(tokens),0)
        end_pos_label = np.full(len(tokens),0)
        type_label = np.full(len(tokens), 0)
        count_back = 0
        begin = 0
        j = 0
        remove_list = []
        for i in range(len(tokens)):
            if tokens[i] == '_':
                remove_list.append(i)
                if count_back == 0:
                    start_pos_label[i+1] = 1
                    begin = i
                    count_back += 1
                else:
                    end_pos_label[i-1] = 1
                    type_label[begin : i] = type_list[j]
                    j += 1
                    count_back = 0

        start_pos_label = start_pos_label.tolist()
        end_pos_label = end_pos_label.tolist()
        type_label = type_label.tolist()

        for i in sorted(remove_list, reverse=True):
            del tokens[i], start_pos_label[i], end_pos_label[i] ,  type_label[i]
        
        # tokens[0] = "[CLS]"
        if tokens[0] == "[SEP]":
            del tokens[0], start_pos_label[0], end_pos_label[0] ,  type_label[0]
        tokens.append("[SEP]")
        start_pos_label.append(0)
        end_pos_label.append(0)
        type_label.append(0)

        ids = tokenizer.convert_tokens_to_ids(tokens)

        pt_dict = {'input_ids':ids, 
                   "start_pos_label":start_pos_label, 
                   'end_pos_label' :end_pos_label,
                   "type_label":type_label,
                   'article_id':data['id']}
        bert_data.append(pt_dict)
        c += 1
#         for i in range(len(start_pos_label)):
#             if start_pos_label[i] == 1:
#                 start = i
#             if end_pos_label[i] == 1:
#                 end = i
#                 print(tokens[start:end+1] , type2str(type_label[start]))

        print("\rprocessed %d data" %c, end="")
    
print("")


"""to length 512"""
bert_data_train_512 = []
bert_data_test_512 = []
c = 0
c1 = 0
c2 = 0
length = len(bert_data)
try:
    test_list = set(eval(sys.argv[2]))
except:
    test_list = sample(range(length), ceil(length * 0.33)) # split 1/3 testing data


error_count = 0

for data in bert_data:
    c += 1
    ids = data['input_ids']
    start_pos_label = data['start_pos_label']
    end_pos_label = data['end_pos_label']
    type_ = data['type_label']
    pos = 0
    flag = 0
    sep_pos = 0
    new_pos = 0
    while (pos < len(ids)):
        ids_512 = ids[pos : pos + 512]
        count_back = 0
        for i in range(min(511, len(ids_512)-1), 0, -1):
            if (ids_512[i] == 102): # 102 = [SEP]
                count_back += 1
                if (count_back == 1):
                    sep_pos = i
                    new_pos = pos + i + 1
                elif (count_back <= 3): # overlap n-1 sentences 
                    new_pos = pos + i + 1
        if(count_back == 0):
            sep_pos = 510
            new_pos = pos + 510 - 5 # overlap n tokens

        ids_512 = [101] + ids[pos : pos+sep_pos+1] + [0] * (512 - sep_pos - 2)
        seg_512 = [0] * 512

        att_512 = [1] * (sep_pos + 2) + [0] * (512 - sep_pos - 2)
        start_pos_label_512 = [0] + start_pos_label[pos : pos+sep_pos+1] + [0] * (512 - sep_pos - 2)
        end_pos_label_512 = [0] + end_pos_label[pos : pos+sep_pos+1] + [0] * (512 - sep_pos - 2)
        type_512 = [0] + type_[pos : pos+sep_pos+1] + [0] * (512 - sep_pos - 2)
        flag = 1 if (pos+sep_pos+1 >= len(ids)) else 0
        pos = new_pos
        
#         print("")
#         print('ids',len(ids_512))
#         print('seg',len(seg_512))
#         print('att',len(att_512))
#         print('start',len(start_pos_label_512))
#         print('end',len(end_pos_label_512))
#         print('type',len(type_512))
        
        if len(ids_512)!= 512:
            error_count += 1
            continue

       

        pt_dict = {"input_ids":ids_512, 
                   "seg":seg_512, 
                   "att":att_512,
                   "start_pos_label":start_pos_label_512,
                   "end_pos_label":end_pos_label_512,
                   "type_label":type_512,
                   "article_id":data['article_id']}
    
        
        if (data['article_id'] not in test_list):
            bert_data_train_512.append(pt_dict)
            c1 += 1
        else:
            bert_data_test_512.append(pt_dict)
            c2 += 1
        
        print("\rprocessed %d data to length 512" %(c1+c2), end="")

        if (flag): # read single talk
            break

torch.save(bert_data_train_512, "./dataset/train1_train_512_bert_data.pt")
torch.save(bert_data_test_512, "./dataset/train1_test_512_bert_data.pt")

print("")
print("processed %d origin datas to %d train datas and %d test datas in length 512"
    % (c, c1, c2))
print("Preprocess Done !!")
print("Testing set id list: ", test_list)
print(error_count)

start preprocessing...
processed 120 data
processed 522 data to length 512
processed 120 origin datas to 345 train datas and 177 test datas in length 512
Preprocess Done !!
Testing set id list:  [65, 26, 106, 74, 20, 76, 72, 31, 92, 17, 115, 7, 43, 113, 19, 44, 86, 27, 85, 93, 11, 91, 63, 90, 82, 107, 51, 84, 0, 77, 35, 71, 119, 9, 98, 79, 97, 45, 10, 2]
22
