# step0: preprocessing task labels (readmission and expired)

In [1]:
import _pickle as cPickle
import csv
import os
import sys
import datetime
import random
import pandas as pd
from tqdm import tqdm
import sklearn.model_selection as ms
encounter_dict = {}

In [2]:
class EncounterInfo(object):
    """Record holding the labels and (initially empty) features of one encounter."""

    # Names of the per-encounter feature lists; each instance gets its own
    # fresh empty list for every one of them.
    _LIST_FIELDS = (
        'dx_ids', 'dx_ids_lvl1', 'dx_names', 'dx_labels',
        'rx_ids', 'rx_names',
        'lab_ids', 'lab_names',
        'microbiology_ids', 'microbiology_names',
        'physicals',
        'procedures_ids', 'procedures_names',
    )

    def __init__(self, patient_id, encounter_id, encounter_timestamp, expired,
                 readmission, los_3day, los_7day, Death30, Death180, Death365, admission_type=99, marital_status=99):
        # Identifiers and timing.
        self.patient_id, self.encounter_id = patient_id, encounter_id
        self.encounter_timestamp = encounter_timestamp

        # Task labels supplied by the caller.
        self.expired, self.readmission = expired, readmission
        self.los_3day, self.los_7day = los_3day, los_7day
        self.Death30, self.Death180, self.Death365 = Death30, Death180, Death365

        # Optional categorical attributes (both default to 99).
        self.admission_type = admission_type
        self.marital_status = marital_status

        # Filled in by later processing steps.
        self.gender = ''
        for field_name in self._LIST_FIELDS:
            setattr(self, field_name, [])


def minNums(startTime, endTime):
    """Return the number of whole minutes from ``startTime`` to ``endTime``.

    Both arguments are strings in ``"%Y-%m-%d %H:%M:%S"`` format. The result
    is truncated toward zero and is negative when ``endTime`` precedes
    ``startTime``.

    Raises:
        ValueError: if either string does not match the expected format.
    """
    time_format = "%Y-%m-%d %H:%M:%S"
    start = datetime.datetime.strptime(startTime, time_format)
    end = datetime.datetime.strptime(endTime, time_format)
    # total_seconds() covers the whole interval; timedelta.seconds would only
    # give the sub-day remainder and be wrong for intervals spanning days.
    return int((end - start).total_seconds() / 60)

In [3]:
# random_seed 用于随机生成
def process_admission(admission_file, icustays_file, patients_file, encounter_dict, max_patient_num=500, hour_threshold=24):
    """Sample patients and build one labelled EncounterInfo per admission.

    The admissions CSV is read twice: first to group admissions by patient
    and pick a random subset of patients, then to create the label records
    for every admission of the selected patients.

    Args:
        admission_file: CSV with SUBJECT_ID, HADM_ID, ADMITTIME, DISCHTIME,
            DEATHTIME and HOSPITAL_EXPIRE_FLAG columns.
        icustays_file: unused; only needed by the ICU filter that is
            currently disabled.
        patients_file: CSV with SUBJECT_ID, GENDER and DOB columns.
        encounter_dict: dict updated in place, HADM_ID -> EncounterInfo.
        max_patient_num: maximum number of patients to sample; capped at the
            number of patients present in ``admission_file``.
        hour_threshold: admissions whose ADMITTIME-to-DISCHTIME duration is
            shorter than this many hours are not added to ``encounter_dict``.

    Returns:
        ``(encounter_dict, admission_list)`` where ``admission_list`` holds
        the HADM_IDs of all admissions of the sampled patients (including
        those later dropped by ``hour_threshold``).

    Labels stored on each EncounterInfo:
        expired: HOSPITAL_EXPIRE_FLAG == '1'.
        readmission: 1 unless this is the patient's last admission in
            sorted order.
        los_3day / los_7day: stay longer than 3 / 7 days.
        Death30 / Death180 / Death365: DEATHTIME is present and lies within
            (0, 30] / (0, 180] / (0, 365] days after ADMITTIME.
    """
    # NOTE: this lookup is only consumed by the (currently disabled) filter
    # that restricts encounters to adult MICU stays; the read is kept so the
    # function still touches patients_file exactly as before.
    patient_base_dict = {}
    with open(patients_file, 'r') as inff:
        for line in csv.DictReader(inff):
            patient_base_dict.setdefault(line['SUBJECT_ID'], []).append(
                (line['GENDER'], line['DOB']))

    # --------------- pass 1: group admissions by patient ---------------
    patient_dict = {}
    count = 0
    with open(admission_file, 'r') as inff:
        for line in csv.DictReader(inff):
            if count % 1000 == 0:
                sys.stdout.write('%d\r' % count)
                sys.stdout.flush()

            admittime = line['ADMITTIME']
            dischtime = line['DISCHTIME']
            # Length of the hospital stay in minutes.
            encounter_timestamp = minNums(admittime, dischtime)
            patient_dict.setdefault(line['SUBJECT_ID'], []).append(
                (admittime, encounter_timestamp, line['HADM_ID'], dischtime))
            count += 1

    # The sample size cannot exceed the number of patients.
    max_patient_num = min(max_patient_num, len(patient_dict))
    # random.sample needs a sequence: dict views are rejected on Python 3.11+.
    selected_patients = set(random.sample(list(patient_dict), max_patient_num))
    # Keep only the sampled patients (original file order is preserved).
    patient_dict = {pid: encs for pid, encs in patient_dict.items()
                    if pid in selected_patients}

    # HADM_IDs of the sampled patients, de-duplicated in first-seen order.
    # The set gives O(1) membership tests here and in the second pass.
    admission_list = []
    admission_set = set()
    for time_enc_tuples in patient_dict.values():
        for time_enc_tuple in time_enc_tuples:
            hadm_id = time_enc_tuple[2]
            if hadm_id not in admission_set:
                admission_set.add(hadm_id)
                admission_list.append(hadm_id)

    # Readmission label: every admission except the patient's last one (in
    # ascending tuple order, i.e. by ADMITTIME string first) is followed by
    # another admission and gets label 1.
    # NOTE: a 30-day-window variant of this label existed but is disabled.
    enc_readmission_dict = {}
    for time_enc_tuples in patient_dict.values():
        ordered = sorted(time_enc_tuples)
        for time_enc_tuple in ordered[:-1]:
            enc_readmission_dict[time_enc_tuple[2]] = 1
        enc_readmission_dict[ordered[-1][2]] = 0

    # --------------- pass 2: build the labelled encounters ---------------
    count = 0
    with open(admission_file, 'r') as inff:
        for line in tqdm(csv.DictReader(inff)):
            encounter_id = line['HADM_ID']
            if encounter_id not in admission_set:
                continue
            patient_id = line['SUBJECT_ID']
            admittime = line['ADMITTIME']
            deathtime = line['DEATHTIME']

            encounter_timestamp = minNums(admittime, line['DISCHTIME'])
            # 0 means "no recorded death time".
            death_timestamp = minNums(admittime, deathtime) if deathtime else 0

            duration_minute = encounter_timestamp
            if duration_minute < 60. * hour_threshold:
                continue

            losday = duration_minute / (24. * 60)
            deathday = death_timestamp / (24. * 60)

            # HOSPITAL_EXPIRE_FLAG is compared against the string '1'.
            expired = 1 if line['HOSPITAL_EXPIRE_FLAG'] == '1' else 0
            readmission = 1 if enc_readmission_dict[encounter_id] == 1 else 0
            los_3day = 1 if losday > 3 else 0
            los_7day = 1 if losday > 7 else 0
            Death30 = 1 if 30 >= deathday > 0 else 0
            Death180 = 1 if 180 >= deathday > 0 else 0
            Death365 = 1 if 365 >= deathday > 0 else 0

            ei = EncounterInfo(patient_id, encounter_id, encounter_timestamp, expired,
                               readmission, los_3day, los_7day, Death30, Death180, Death365)
            if encounter_id in encounter_dict:
                print('Duplicate encounter ID!!')
                sys.exit(0)
            encounter_dict[encounter_id] = ei
            count = count + 1

    print('Accepted Patients: {}'.format(max_patient_num))
    print('Accepted admissions: {}'.format(count))
    print('')
    return encounter_dict, admission_list

In [4]:
def process_patients(patients_file, encounter_dict):
    """Copy each patient's GENDER onto that patient's encounters.

    Args:
        patients_file: CSV with SUBJECT_ID and GENDER columns.
        encounter_dict: dict of encounter objects exposing ``patient_id``
            and ``gender`` attributes; the objects are updated in place.

    Returns:
        The same ``encounter_dict`` object.

    The printed count is the number of encounters that received a gender.
    """
    # Index encounters by patient once, instead of rescanning every
    # encounter for every row of the patients file.
    keys_by_patient = {}
    for enc_key, enc in encounter_dict.items():
        keys_by_patient.setdefault(enc.patient_id, []).append(enc_key)

    updated_keys = set()
    with open(patients_file, 'r') as inff:
        for line in csv.DictReader(inff):
            # If a patient appears on several rows, the last row wins.
            for enc_key in keys_by_patient.get(line['SUBJECT_ID'], ()):
                encounter_dict[enc_key].gender = line['GENDER']
                updated_keys.add(enc_key)
    count = len(updated_keys)

    print('Accepted admissions: %d' % count)
    print('')
    return encounter_dict


In [5]:

# input_path = argv[1]
# output_path = argv[2]
# input_path = '/home/caoyu/jupyterNotebook/data_test'
# output_path = '/home/caoyu/project/healthRecords/data/mimiciv'

input_path = '/home/caoyu/project/GraphCLHealth/data/mimiciii'
output_path = '/home/caoyu/project/GraphCLHealth/processed_data/mimiciii'
minimum_cnt = 5
max_patient_num = 100000
print(f'max_patient_num:{max_patient_num}')

# flag_test_flag == 0 selects the small "...10" sample tables used for
# debugging; any other value selects the full tables.
flag_test_flag = 1

# Dictionary tables (the same files in both modes).
icd_diagnoses_file = f'{input_path}/D_ICD_DIAGNOSES.csv'
icd_procedures_file = f'{input_path}/D_ICD_PROCEDURES.csv'
icd_labItems_file = f'{input_path}/D_LABITEMS.csv'

# Event tables: the sample variants carry a "10" suffix before ".csv".
# PATIENTS.csv has no sample variant.
_table_suffix = '10' if flag_test_flag == 0 else ''
admission_file = f'{input_path}/ADMISSIONS{_table_suffix}.csv'
diagnosis_file = f'{input_path}/DIAGNOSES_ICD{_table_suffix}.csv'
procedures_file = f'{input_path}/PROCEDURES_ICD{_table_suffix}.csv'
labevents_file = f'{input_path}/LABEVENTS{_table_suffix}.csv'
microbiology_file = f'{input_path}/microbiologyevents{_table_suffix}.csv'
patients_file = f'{input_path}/PATIENTS.csv'
icustays_file = f'{input_path}/ICUSTAYS{_table_suffix}.csv'

encounter_dict = {}

# max_patient_num caps how many patients are sampled from the admissions.
print('Processing ADMISSIONS.csv')
encounter_dict, admission_list = process_admission(
    admission_file,
    icustays_file,
    patients_file,
    encounter_dict,
    max_patient_num,
    hour_threshold=24,
)
# print('Processing PATIENTS.csv')
# encounter_dict = process_patients(patients_file, encounter_dict)




max_patient_num:100000
Processing ADMISSIONS.csv
58000

58976it [00:36, 1605.72it/s]


Accepted Patients: 46520
Accepted admissions: 56684



In [6]:
# Per-task label lookups keyed by encounter id (HADM_ID).
expired = {}
readmission = {}
death30 = {}
death180 = {}
death365 = {}
for enc in encounter_dict.values():
    enc_id = enc.encounter_id
    # setdefault keeps the first label seen for an encounter id.
    expired.setdefault(enc_id, enc.expired)
    readmission.setdefault(enc_id, enc.readmission)
    death30.setdefault(enc_id, enc.Death30)
    death180.setdefault(enc_id, enc.Death180)
    # Previously guarded by a membership test on death30 (already filled
    # above), which left death365 permanently empty.
    death365.setdefault(enc_id, enc.Death365)
        
#     print(key,encounter_dict[key].encounter_id,encounter_dict[key].expired,encounter_dict[key].readmission)

In [7]:
# Number of positive (== 1) labels per task.
cnt = sum(1 for flag in readmission.values() if flag == 1)
print('readmission', cnt)

cnt = sum(1 for flag in death30.values() if flag == 1)
print('death30', cnt)

cnt = sum(1 for flag in death180.values() if flag == 1)
print('death180', cnt)

cnt = sum(1 for flag in expired.values() if flag == 1)
print('expired', cnt)

# Total number of labelled encounters in each lookup.
print(len(expired))
print(len(readmission))

readmission 12134
death30 4492
death180 4875
expired 4883
56684
56684


In [8]:
# Positive rate of the death30 label: 4492 positives out of 56684 encounters
# (both numbers are hard-coded copies of the counts printed earlier).
print(4492/56684)

0.07924634817585209


In [9]:
# NOTE(review): 3275 does not match any count printed in this notebook;
# confirm which label this ratio (out of 56684 encounters) refers to.
print(3275/56684)

0.05777644485216287


In [10]:
# Number of encounter ids that have an `expired` label.
print(len(expired))

56684


In [11]:
# Number of encounter ids that have a `readmission` label.
print(len(readmission))

56684


# 5. Convert DB for ReAdmPred task

In [57]:
import os
import torch
import numpy as np
import pickle as pkl
from numpy.random import choice
from tqdm import tqdm
# Settings for the DB conversion loop that follows.
# out_db_name: directory-name stem under knowmix_PATH and in the output path.
out_db_name = 'dx,prx'
# db_name: directory-name stem under SOURCE_PATH.
db_name = 'dxprx'
# size: numeric suffix of the knowmix directory ("{out_db_name}_{size}").
size = 2000
# DB variants to convert; 'NoKGenc' must come first because the loop records
# its per-sample labels and reuses them for the other variants.
DB_type = ['NoKGenc','UniKGenc','UnifiedNoKGenc','UnifiedUniKGenc']
SPLIT = ['train','valid','test']
# SOURCE_PATH holds id2label / entity2id.txt / unified_node vocabularies.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'
# knowmix_PATH holds the per-split `db` files that get relabelled.
knowmix_PATH='/home/caoyu/project/MultiModalMed/gtx/data/knowmix'

In [60]:
# Relabel every knowmix DB split with the readmission label of its encounter.
# Samples whose HADM_ID has no readmission label are dropped. The keep/drop
# decision and label are computed once on the 'NoKGenc' variant and replayed
# by position for the other variants, so 'NoKGenc' must be first in DB_type
# and all variants must list their samples in the same order.
# NOTE(review): torch.load unpickles its input; only load trusted files.
for db_type in DB_type:
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/id2label')
        label2id = {v:k for k,v in id2label.items()}
        # entity2id.txt: the first line is skipped; each remaining line is
        # "<entity>\t<id>". Ids are shifted by the special-token count.
        with open(os.path.join(f'{SOURCE_PATH}/{db_name}', 'entity2id.txt')) as entity_file:
            id2entity = {int(fields[1]) + NUM_SPECIAL_TOKENS: fields[0].split('^^')[0]
                         for fields in (row.split('\t') for row in entity_file.read().splitlines()[1:])}
        label2entity = {k:id2entity[v] for k,v in label2id.items()}
    else:
        id2entity = {v: k.split('\t')[0].split('^^')[0] for k, v in
             torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/unified_node').items()}

    if db_type == 'NoKGenc':
        global_label = list()
    idx = 0  # position in global_label; runs across all splits of a variant
    for split in SPLIT:
        print('#####',db_type,' -- ',split)
        db = torch.load(f'{knowmix_PATH}/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}/db')
        # 'label_mask' and 'rc_index' are not carried over to the new DB.
        db_new = {k: list() for k in db if k not in ['label_mask','rc_index']}
        # Every kept key except 'label' is copied from the source DB.
        copied_keys = [k for k in db_new if k != 'label']
        for in_db_idx, _input in tqdm(enumerate(db['input'])):
            if db_type == 'NoKGenc':
                # The token at position 1 maps to an entity ending in
                # ".../<hadm_id>>".
                hadm_id = id2entity[_input[1]].split('/')[-1].replace('>','')
                sample_label = readmission.get(hadm_id)  # None -> drop sample
                global_label.append(sample_label)
            else:
                sample_label = global_label[idx]
            if sample_label is not None:
                for k in copied_keys:
                    db_new[k].append(db[k][in_db_idx])
                db_new['label'].append(sample_label)
            idx += 1
        print(f'{db_name},{size},{split}',len(db_new['label']))
        if "Unified" in db_type:
            # Only the "Unified" variants are written to disk.
            out_dir = f'../gtx/data/readm/knowmix/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}'
            os.makedirs(out_dir, exist_ok=True)
            torch.save(db_new, f'{out_dir}/db')
            print("*** Saved! ***")
        # Sanity-check printout; skipped when the split kept too few samples.
        if len(db_new['label']) > 20:
            print("20th sample",[id2entity[x] for x in db_new['input'][20][1:10]])
            print("20th label", db_new['label'][20])

##### NoKGenc  --  train


28915it [00:00, 216234.00it/s]


dxprx,2000,train 28349
20th sample ['</hadm_id/136706>', '</diagnoses/6695>', '</diagnoses/6693>', '</diagnoses/6692>', '</diagnoses/6697>', '</diagnoses/6696>', '</procedures/2746>', '</diagnoses/6694>', '</diagnoses_icd9_code/4280>']
20th label 0
##### NoKGenc  --  valid


2000it [00:00, 237853.24it/s]


dxprx,2000,valid 1963
20th sample ['</hadm_id/167487>', '</diagnoses/75303>', '</procedures/31300>', '</diagnoses/75306>', '</procedures/31299>', '</procedures/31301>', '</diagnoses/75301>', '</procedures/31298>', '</diagnoses/75302>']
20th label 0
##### NoKGenc  --  test


2000it [00:00, 265840.85it/s]


dxprx,2000,test 1964
20th sample ['</hadm_id/189391>', '</diagnoses/248511>', '</diagnoses/248498>', '</diagnoses/248515>', '</diagnoses/248522>', '</diagnoses/248512>', '</procedures/94257>', '</diagnoses/248503>', '</diagnoses/248514>']
20th label 0
##### UniKGenc  --  train


28915it [00:00, 570418.08it/s]


dxprx,2000,train 28349
20th sample ['</hadm_id/136706>', '</diagnoses/6695>', '</diagnoses/6693>', '</diagnoses/6692>', '</diagnoses/6697>', '</diagnoses/6696>', '</procedures/2746>', '</diagnoses/6694>', '</diagnoses_icd9_code/4280>']
20th label 0
##### UniKGenc  --  valid


2000it [00:00, 404348.21it/s]


dxprx,2000,valid 1963
20th sample ['</hadm_id/167487>', '</diagnoses/75303>', '</procedures/31300>', '</diagnoses/75306>', '</procedures/31299>', '</procedures/31301>', '</diagnoses/75301>', '</procedures/31298>', '</diagnoses/75302>']
20th label 0
##### UniKGenc  --  test


2000it [00:00, 539842.20it/s]


dxprx,2000,test 1964
20th sample ['</hadm_id/189391>', '</diagnoses/248511>', '</diagnoses/248498>', '</diagnoses/248515>', '</diagnoses/248522>', '</diagnoses/248512>', '</procedures/94257>', '</diagnoses/248503>', '</diagnoses/248514>']
20th label 0
##### UnifiedNoKGenc  --  train


28915it [00:00, 513078.00it/s]


dxprx,2000,train 28349
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses_icd9_code']
20th label 0
##### UnifiedNoKGenc  --  valid


2000it [00:00, 433094.53it/s]


dxprx,2000,valid 1963
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'procedures', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedNoKGenc  --  test


2000it [00:00, 464614.12it/s]


dxprx,2000,test 1964
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  train


28915it [00:00, 506472.14it/s]


dxprx,2000,train 28349
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses_icd9_code']
20th label 0
##### UnifiedUniKGenc  --  valid


2000it [00:00, 369444.55it/s]


dxprx,2000,valid 1963
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'procedures', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  test


2000it [00:00, 461698.94it/s]


dxprx,2000,test 1964
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses']
20th label 0


# 5. Convert DB for the mortality task (Death180: death within 180 days of admission)

In [61]:
import os
import torch
import numpy as np
import pickle as pkl
from numpy.random import choice
from tqdm import tqdm
# Settings for the DB conversion loop that follows.
# out_db_name: directory-name stem under knowmix_PATH and in the output path.
out_db_name = 'dx,prx'
# db_name: directory-name stem under SOURCE_PATH.
db_name = 'dxprx'
# size: numeric suffix of the knowmix directory ("{out_db_name}_{size}").
size = 2000
# DB variants to convert; 'NoKGenc' must come first because the loop records
# its per-sample labels and reuses them for the other variants.
DB_type = ['NoKGenc','UniKGenc','UnifiedNoKGenc','UnifiedUniKGenc']
SPLIT = ['train','valid','test']
# SOURCE_PATH holds id2label / entity2id.txt / unified_node vocabularies.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'
# knowmix_PATH holds the per-split `db` files that get relabelled.
knowmix_PATH='/home/caoyu/project/MultiModalMed/gtx/data/knowmix'

In [62]:
# Relabel every knowmix DB split with the Death180 label of its encounter.
# Samples whose HADM_ID has no Death180 label are dropped. The keep/drop
# decision and label are computed once on the 'NoKGenc' variant and replayed
# by position for the other variants, so 'NoKGenc' must be first in DB_type
# and all variants must list their samples in the same order.
# NOTE(review): torch.load unpickles its input; only load trusted files.
for db_type in DB_type:
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/id2label')
        label2id = {v:k for k,v in id2label.items()}
        # entity2id.txt: the first line is skipped; each remaining line is
        # "<entity>\t<id>". Ids are shifted by the special-token count.
        with open(os.path.join(f'{SOURCE_PATH}/{db_name}', 'entity2id.txt')) as entity_file:
            id2entity = {int(fields[1]) + NUM_SPECIAL_TOKENS: fields[0].split('^^')[0]
                         for fields in (row.split('\t') for row in entity_file.read().splitlines()[1:])}
        label2entity = {k:id2entity[v] for k,v in label2id.items()}
    else:
        id2entity = {v: k.split('\t')[0].split('^^')[0] for k, v in
             torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/unified_node').items()}

    if db_type == 'NoKGenc':
        global_label = list()
    idx = 0  # position in global_label; runs across all splits of a variant
    for split in SPLIT:
        print('#####',db_type,' -- ',split)
        db = torch.load(f'{knowmix_PATH}/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}/db')
        # 'label_mask' and 'rc_index' are not carried over to the new DB.
        db_new = {k: list() for k in db if k not in ['label_mask','rc_index']}
        # Every kept key except 'label' is copied from the source DB.
        copied_keys = [k for k in db_new if k != 'label']
        for in_db_idx, _input in tqdm(enumerate(db['input'])):
            if db_type == 'NoKGenc':
                # The token at position 1 maps to an entity ending in
                # ".../<hadm_id>>".
                hadm_id = id2entity[_input[1]].split('/')[-1].replace('>','')
                sample_label = death180.get(hadm_id)  # None -> drop sample
                global_label.append(sample_label)
            else:
                sample_label = global_label[idx]
            if sample_label is not None:
                for k in copied_keys:
                    db_new[k].append(db[k][in_db_idx])
                db_new['label'].append(sample_label)
            idx += 1
        print(f'{db_name},{size},{split}',len(db_new['label']))
        if "Unified" in db_type:
            # Only the "Unified" variants are written to disk.
            out_dir = f'../gtx/data/Death180/knowmix/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}'
            os.makedirs(out_dir, exist_ok=True)
            torch.save(db_new, f'{out_dir}/db')
            print("*** Saved! ***")
        # Sanity-check printout; skipped when the split kept too few samples.
        if len(db_new['label']) > 20:
            print("20th sample",[id2entity[x] for x in db_new['input'][20][1:10]])
            print("20th label", db_new['label'][20])

##### NoKGenc  --  train


28915it [00:00, 300238.40it/s]


dxprx,2000,train 28349
20th sample ['</hadm_id/136706>', '</diagnoses/6695>', '</diagnoses/6693>', '</diagnoses/6692>', '</diagnoses/6697>', '</diagnoses/6696>', '</procedures/2746>', '</diagnoses/6694>', '</diagnoses_icd9_code/4280>']
20th label 0
##### NoKGenc  --  valid


2000it [00:00, 300225.76it/s]


dxprx,2000,valid 1963
20th sample ['</hadm_id/167487>', '</diagnoses/75303>', '</procedures/31300>', '</diagnoses/75306>', '</procedures/31299>', '</procedures/31301>', '</diagnoses/75301>', '</procedures/31298>', '</diagnoses/75302>']
20th label 0
##### NoKGenc  --  test


2000it [00:00, 280414.78it/s]


dxprx,2000,test 1964
20th sample ['</hadm_id/189391>', '</diagnoses/248511>', '</diagnoses/248498>', '</diagnoses/248515>', '</diagnoses/248522>', '</diagnoses/248512>', '</procedures/94257>', '</diagnoses/248503>', '</diagnoses/248514>']
20th label 0
##### UniKGenc  --  train


28915it [00:00, 564873.31it/s]


dxprx,2000,train 28349
20th sample ['</hadm_id/136706>', '</diagnoses/6695>', '</diagnoses/6693>', '</diagnoses/6692>', '</diagnoses/6697>', '</diagnoses/6696>', '</procedures/2746>', '</diagnoses/6694>', '</diagnoses_icd9_code/4280>']
20th label 0
##### UniKGenc  --  valid


2000it [00:00, 437179.90it/s]


dxprx,2000,valid 1963
20th sample ['</hadm_id/167487>', '</diagnoses/75303>', '</procedures/31300>', '</diagnoses/75306>', '</procedures/31299>', '</procedures/31301>', '</diagnoses/75301>', '</procedures/31298>', '</diagnoses/75302>']
20th label 0
##### UniKGenc  --  test


2000it [00:00, 586862.18it/s]


dxprx,2000,test 1964
20th sample ['</hadm_id/189391>', '</diagnoses/248511>', '</diagnoses/248498>', '</diagnoses/248515>', '</diagnoses/248522>', '</diagnoses/248512>', '</procedures/94257>', '</diagnoses/248503>', '</diagnoses/248514>']
20th label 0
##### UnifiedNoKGenc  --  train


28915it [00:00, 615707.15it/s]


dxprx,2000,train 28349
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses_icd9_code']
20th label 0
##### UnifiedNoKGenc  --  valid


2000it [00:00, 514354.53it/s]


dxprx,2000,valid 1963
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'procedures', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedNoKGenc  --  test


2000it [00:00, 536561.85it/s]


dxprx,2000,test 1964
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  train


28915it [00:00, 517423.17it/s]


dxprx,2000,train 28349
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses_icd9_code']
20th label 0
##### UnifiedUniKGenc  --  valid


2000it [00:00, 357845.24it/s]


dxprx,2000,valid 1963
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'procedures', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  test


2000it [00:00, 470187.10it/s]


dxprx,2000,test 1964
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses']
20th label 0


# 5. Convert DB for the mortality task (Death30: death within 30 days of admission)

In [63]:
import os
import torch
import numpy as np
import pickle as pkl
from numpy.random import choice
from tqdm import tqdm
# Settings for the DB conversion loop that follows.
# out_db_name: directory-name stem under knowmix_PATH and in the output path.
out_db_name = 'dx,prx'
# db_name: directory-name stem under SOURCE_PATH.
db_name = 'dxprx'
# size: numeric suffix of the knowmix directory ("{out_db_name}_{size}").
size = 2000
# DB variants to convert; 'NoKGenc' must come first because the loop records
# its per-sample labels and reuses them for the other variants.
DB_type = ['NoKGenc','UniKGenc','UnifiedNoKGenc','UnifiedUniKGenc']
SPLIT = ['train','valid','test']
# SOURCE_PATH holds id2label / entity2id.txt / unified_node vocabularies.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'
# knowmix_PATH holds the per-split `db` files that get relabelled.
knowmix_PATH='/home/caoyu/project/MultiModalMed/gtx/data/knowmix'

In [64]:
# Relabel every knowmix DB split with the Death30 label of its encounter.
# Samples whose HADM_ID has no Death30 label are dropped. The keep/drop
# decision and label are computed once on the 'NoKGenc' variant and replayed
# by position for the other variants, so 'NoKGenc' must be first in DB_type
# and all variants must list their samples in the same order.
# NOTE(review): torch.load unpickles its input; only load trusted files.
for db_type in DB_type:
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/id2label')
        label2id = {v:k for k,v in id2label.items()}
        # entity2id.txt: the first line is skipped; each remaining line is
        # "<entity>\t<id>". Ids are shifted by the special-token count.
        with open(os.path.join(f'{SOURCE_PATH}/{db_name}', 'entity2id.txt')) as entity_file:
            id2entity = {int(fields[1]) + NUM_SPECIAL_TOKENS: fields[0].split('^^')[0]
                         for fields in (row.split('\t') for row in entity_file.read().splitlines()[1:])}
        label2entity = {k:id2entity[v] for k,v in label2id.items()}
    else:
        id2entity = {v: k.split('\t')[0].split('^^')[0] for k, v in
             torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/unified_node').items()}

    if db_type == 'NoKGenc':
        global_label = list()
    idx = 0  # position in global_label; runs across all splits of a variant
    for split in SPLIT:
        print('#####',db_type,' -- ',split)
        db = torch.load(f'{knowmix_PATH}/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}/db')
        # 'label_mask' and 'rc_index' are not carried over to the new DB.
        db_new = {k: list() for k in db if k not in ['label_mask','rc_index']}
        # Every kept key except 'label' is copied from the source DB.
        copied_keys = [k for k in db_new if k != 'label']
        for in_db_idx, _input in tqdm(enumerate(db['input'])):
            if db_type == 'NoKGenc':
                # The token at position 1 maps to an entity ending in
                # ".../<hadm_id>>".
                hadm_id = id2entity[_input[1]].split('/')[-1].replace('>','')
                sample_label = death30.get(hadm_id)  # None -> drop sample
                global_label.append(sample_label)
            else:
                sample_label = global_label[idx]
            if sample_label is not None:
                for k in copied_keys:
                    db_new[k].append(db[k][in_db_idx])
                db_new['label'].append(sample_label)
            idx += 1
        print(f'{db_name},{size},{split}',len(db_new['label']))
        if "Unified" in db_type:
            # Only the "Unified" variants are written to disk.
            out_dir = f'../gtx/data/Death30/knowmix/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}'
            os.makedirs(out_dir, exist_ok=True)
            torch.save(db_new, f'{out_dir}/db')
            print("*** Saved! ***")
        # Sanity-check printout; skipped when the split kept too few samples.
        if len(db_new['label']) > 20:
            print("20th sample",[id2entity[x] for x in db_new['input'][20][1:10]])
            print("20th label", db_new['label'][20])

##### NoKGenc  --  train


28915it [00:00, 220011.18it/s]


dxprx,2000,train 28349
20th sample ['</hadm_id/136706>', '</diagnoses/6695>', '</diagnoses/6693>', '</diagnoses/6692>', '</diagnoses/6697>', '</diagnoses/6696>', '</procedures/2746>', '</diagnoses/6694>', '</diagnoses_icd9_code/4280>']
20th label 0
##### NoKGenc  --  valid


2000it [00:00, 228311.14it/s]


dxprx,2000,valid 1963
20th sample ['</hadm_id/167487>', '</diagnoses/75303>', '</procedures/31300>', '</diagnoses/75306>', '</procedures/31299>', '</procedures/31301>', '</diagnoses/75301>', '</procedures/31298>', '</diagnoses/75302>']
20th label 0
##### NoKGenc  --  test


2000it [00:00, 234318.66it/s]


dxprx,2000,test 1964
20th sample ['</hadm_id/189391>', '</diagnoses/248511>', '</diagnoses/248498>', '</diagnoses/248515>', '</diagnoses/248522>', '</diagnoses/248512>', '</procedures/94257>', '</diagnoses/248503>', '</diagnoses/248514>']
20th label 0
##### UniKGenc  --  train


28915it [00:00, 536055.64it/s]


dxprx,2000,train 28349
20th sample ['</hadm_id/136706>', '</diagnoses/6695>', '</diagnoses/6693>', '</diagnoses/6692>', '</diagnoses/6697>', '</diagnoses/6696>', '</procedures/2746>', '</diagnoses/6694>', '</diagnoses_icd9_code/4280>']
20th label 0
##### UniKGenc  --  valid


2000it [00:00, 380902.15it/s]


dxprx,2000,valid 1963
20th sample ['</hadm_id/167487>', '</diagnoses/75303>', '</procedures/31300>', '</diagnoses/75306>', '</procedures/31299>', '</procedures/31301>', '</diagnoses/75301>', '</procedures/31298>', '</diagnoses/75302>']
20th label 0
##### UniKGenc  --  test


2000it [00:00, 520934.48it/s]


dxprx,2000,test 1964
20th sample ['</hadm_id/189391>', '</diagnoses/248511>', '</diagnoses/248498>', '</diagnoses/248515>', '</diagnoses/248522>', '</diagnoses/248512>', '</procedures/94257>', '</diagnoses/248503>', '</diagnoses/248514>']
20th label 0
##### UnifiedNoKGenc  --  train


28915it [00:00, 610693.84it/s]


dxprx,2000,train 28349
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses_icd9_code']
20th label 0
##### UnifiedNoKGenc  --  valid


2000it [00:00, 416845.96it/s]


dxprx,2000,valid 1963
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'procedures', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedNoKGenc  --  test


2000it [00:00, 554728.74it/s]


dxprx,2000,test 1964
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  train


28915it [00:00, 504504.33it/s]


dxprx,2000,train 28349
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses_icd9_code']
20th label 0
##### UnifiedUniKGenc  --  valid


2000it [00:00, 282596.95it/s]


dxprx,2000,valid 1963
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'procedures', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  test


2000it [00:00, 433049.82it/s]


dxprx,2000,test 1964
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses', 'diagnoses']
20th label 0


# 3. Convert DB for ErrorDetection for Px,Dx

In [18]:
# Load D_ICD files
# Reads the two ICD dictionary CSVs into DataFrames. The following cells
# access their columns by position (1 is passed to reformat() as the code,
# 3 is used as the code name), so the column order of these files matters.
import torch
import pandas as pd
# Procedure-code dictionary.
d_proc = pd.read_csv("/home/caoyu/project/GraphCLHealth/data/mimiciii//D_ICD_PROCEDURES.csv")
# Diagnosis-code dictionary.
d_diag = pd.read_csv("/home/caoyu/project/GraphCLHealth/data/mimiciii//D_ICD_DIAGNOSES.csv")

In [19]:
def reformat(code, is_diag):
    """
        Re-insert the period that the MIMIC-3 data files leave out of ICD-9 codes.

        Any period already present in ``code`` is dropped first.

        * Procedure codes (``is_diag`` false) always get a period after the
          first two characters, even when nothing follows it.
        * Diagnosis codes get a period after the first three characters, or
          after the first four when the code starts with an uppercase 'E';
          a code with nothing after that position is returned without a period.
    """
    digits = code.replace('.', '')
    if not is_diag:
        return f'{digits[:2]}.{digits[2:]}'
    # E-codes carry one extra leading character before the period.
    split_at = 4 if digits.startswith('E') else 3
    if len(digits) <= split_at:
        return digits
    return f'{digits[:split_at]}.{digits[split_at:]}'

In [20]:
# Build the dotted ("absolute") form of every code. The raw code is read from
# column position 1 of each table (presumably ICD9_CODE -- confirm against the
# CSV header). `.iloc` makes the positional access explicit: a bare integer key
# on a label-indexed row Series (`row[1]`) is deprecated in pandas and is
# slated to be interpreted as a label instead of a position.
d_diag['absolute_code'] = d_diag.apply(lambda row: str(reformat(str(row.iloc[1]), True)), axis=1)
d_proc['absolute_code'] = d_proc.apply(lambda row: str(reformat(str(row.iloc[1]), False)), axis=1)

In [21]:
# Normalised (lower-cased, stripped) name of every code, read from column
# position 3 of each table (presumably LONG_TITLE -- confirm against the CSV
# header). `.iloc` is used because positional access through a bare integer
# key (`row[3]`) on a label-indexed Series is deprecated in pandas.
d_diag['code_name'] = d_diag.apply(lambda row: str(row.iloc[3]).lower().strip(), axis=1)
d_proc['code_name'] = d_proc.apply(lambda row: str(row.iloc[3]).lower().strip(), axis=1)

In [22]:
def _split_code(absolute_code):
    """Split a dotted code into its 'large' (before the first period) and
    'small' (after the last period) parts; 'small' is "" when there is no period."""
    parts = absolute_code.split(".")
    return {'large': parts[0], 'small': parts[-1] if len(parts) > 1 else ""}


# Diagnoses: code -> name, and name -> {'large': ..., 'small': ...}.
# The row lookups below use 0..len-1 as keys, i.e. they rely on the default
# integer index that the DataFrames have after read_csv.
d_diag_dict = d_diag[['absolute_code','code_name']].to_dict()
d_code2name = {d_diag_dict['absolute_code'][idx]: d_diag_dict['code_name'][idx] for idx in range(len(d_diag))}
d_name2codecat = {d_diag_dict['code_name'][idx]: _split_code(d_diag_dict['absolute_code'][idx]) for idx in range(len(d_diag))}

# Procedures: same two mappings.
d_proc_dict = d_proc[['absolute_code','code_name']].to_dict()
p_code2name = {d_proc_dict['absolute_code'][idx]: d_proc_dict['code_name'][idx] for idx in range(len(d_proc))}
# Fix: the sub-category of a procedure is now derived from the procedure code
# itself. Previously the "has a period?" test looked at the *diagnosis* code
# stored at the same row number, so a procedure's 'small' part was blanked
# whenever the unrelated diagnosis at that row had no period (and the lookup
# would fail outright if the procedure table were longer than the diagnosis table).
p_name2codecat = {d_proc_dict['code_name'][idx]: _split_code(d_proc_dict['absolute_code'][idx]) for idx in range(len(d_proc))}


In [23]:
# Merge the diagnosis and procedure tables into one lookup from lower-cased
# dotted code to code name; procedure entries are applied last, so they win
# if a key occurs in both tables.
code2name = {}
for table in (d_code2name, p_code2name):
    code2name.update({code.lower(): f'{name}' for code, name in table.items()})

In [24]:
# Merge the per-table name -> {'large', 'small'} mappings into one dict.
# The 'large' part is lower-cased in place, so the category dicts held by
# d_name2codecat / p_name2codecat are modified as well (they are shared).
name2codecat = {}
for name, cat in list(d_name2codecat.items()) + list(p_name2codecat.items()):
    cat.update(large=cat['large'].lower())
    name2codecat[f'{name}'] = cat

In [25]:
# Lower-cased dotted code -> code name, over diagnoses and procedures.
code2name_abs = dict()
for d in [d_code2name, p_code2name]:
    for k, v in d.items():
        code2name_abs[k.lower()] = f'{v}'

# codebook: 'large' category -> list of the 'small' sub-categories that exist
# under it. A key that does not consist of exactly two period-separated parts
# is stored whole as the category with an empty sub-category.
codebook = dict()
for k in code2name_abs:
    parts = k.split(".")
    # Explicit length check instead of a bare `except:` around tuple
    # unpacking, so unrelated errors (and Ctrl-C) are no longer swallowed.
    if len(parts) == 2:
        large, small = parts
    else:
        large, small = k, ""
    codebook.setdefault(large, list()).append(small)
    

In [44]:
import os
import torch
from copy import deepcopy
import numpy as np
from numpy.random import choice
from tqdm import tqdm
# Configuration for the error-detection (ED) conversion cells below.
# db_name is the dataset name used inside the '../gtx/data/...' paths.
db_name = 'dx,prx'
# db_name_source is the dataset name used inside the SOURCE_PATH paths.
db_name_source = 'dxprx'
# Dataset size tag that is part of the '../gtx/data/...' directory names.
size = 2000
# Order matters: the conversion loop creates the corrupted samples while
# processing 'NoKGenc' and the other DB types reuse them, so 'NoKGenc'
# has to come first.
DB_type = ['NoKGenc','UniKGenc','UnifiedNoKGenc','UnifiedUniKGenc']#,'UnifiedNoKGenc','UnifiedUniKGenc']
SPLIT = ['train','valid','test']
# Root directory of the preprocessed source databases and vocabularies.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'

In [47]:
db_type = 'UnifiedNoKGenc'

# Invert the unified node vocabulary: id -> cleaned entity text (type suffix
# after '^^' dropped, surrounding quotes stripped, escaped quotes restored).
unified_nodes = torch.load(f'{SOURCE_PATH}/{db_name_source}_UnifiedNoKGenc/unified_node')
uniid2entity = {}
for raw_key, node_id in unified_nodes.items():
    uniid2entity[node_id] = raw_key.split('\t')[0].split('^^')[0].strip('"').replace('\\"','"')

# Count how often every non-special token id occurs over all splits.
# Ids 0-7 are the special/structural tokens (PAD ... procedures) in the
# vocabulary displayed below, hence the `x>7` filter.
tot_amt = dict()
for split in SPLIT:
    db = torch.load(f'../gtx/data/knowmix/{db_name}_{size}/{db_name}_{db_type}/{split}/db')
    for idx in range(len(db['input'])):
        ids = [x for x in db['input'][idx] if x>7]
        for _id in ids:
            tot_amt[_id] = tot_amt.get(_id, 0) + 1

# Inverse-frequency weight per dotted code ("large.small").
code_frequency = {}
for uni_id, count in tot_amt.items():
    cat = name2codecat[uniid2entity[uni_id]]
    code_frequency['.'.join([cat['large'], cat['small']])] = 1/count
uniid2entity

{0: 'PAD',
 1: 'MASK',
 2: 'CLS',
 3: 'hadm',
 4: 'diagnoses_icd9_code',
 5: 'diagnoses',
 6: 'procedures_icd9_code',
 7: 'procedures',
 8: 'anastomosis of hepatic duct to gastrointestinal tract',
 9: 'closed fracture of scapula, unspecified part',
 10: 'closure of other rectal fistula',
 11: 'other repair of rectum',
 12: 'malignant neoplasm of trachea',
 13: 'pica',
 14: 'closed fracture of patella',
 15: 'bone graft, carpals and metacarpals',
 16: 'dental caries, unspecified',
 17: 'other intubation of respiratory tract',
 18: 'carcinoma in situ of liver and biliary system',
 19: 'insertion of sphenoidal electrodes',
 20: 'repair of arteriovenous fistula',
 21: 'asbestosis',
 22: 'gamma globulin causing adverse effects in therapeutic use',
 23: 'other and unspecified intracranial hemorrhage following injury with open intracranial wound, with loss of consciousness of unspecified duration',
 24: 'arteriography of femoral and other lower extremity arteries',
 25: 'administration of inh

In [28]:
def reformat(code, is_diag):
    """
        Put the period back into an ICD-9 code (the MIMIC-3 files omit it).

        This redefinition recognises E-codes by a *lowercase* leading 'e'.
        Existing periods are removed before the new one is placed:

        * procedure codes: period after the first two characters, always;
        * diagnosis codes: period after the first three characters (four for
          codes starting with 'e'), and only when characters follow it.
    """
    bare = ''.join(piece for piece in code.split('.'))
    if is_diag:
        cut = 4 if bare.startswith('e') else 3
        head, tail = bare[:cut], bare[cut:]
        # Nothing after the split position -> leave the code without a period.
        return head + '.' + tail if tail else bare
    return '.'.join((bare[:2], bare[2:]))

In [29]:
def corrupt_input_and_generate_label(inputs, mode, id2entity, entity2id, code_frequency=None):
    """Corrupt some ICD-9 codes of one sample and mark the changed positions.

    Every distinct ``icd9_code`` entity of the sample whose 'large' category has
    at least two entries in the global ``codebook`` is a corruption candidate.
    A quarter of the candidates (at least one) is drawn without replacement,
    weighted by ``code_frequency``. For each drawn code, a different
    sub-category of the same 'large' category is picked at random (up to 50
    attempts to find one whose name exists in ``entity2id``), and both the code
    token and its name (literal) token are overwritten in the sample.

    Relies on the notebook globals ``name2codecat``, ``code2name``,
    ``codebook``, ``reformat``, ``np`` and ``choice``.

    Args:
        inputs: sequence of token ids of one sample.
        mode: corruption mode; only ``'s'`` is implemented.
        id2entity: token id -> entity string.
        entity2id: entity string -> token id.
        code_frequency: dotted code -> sampling weight; required.

    Returns:
        ``(inputs, labels)`` where ``inputs`` is the corrupted id array and
        ``labels`` a boolean array that is True where ``inputs`` differs from
        the original; ``(None, None)`` when the sample could not be corrupted
        (any error raised while processing it).

    Raises:
        ValueError: if ``code_frequency`` is None, or if ``mode`` is not 's'
            (previously the latter surfaced as an UnboundLocalError on return).
    """
    if code_frequency is None:
        raise ValueError("Must turn on non-uniform sampling")
    if mode != 's':
        raise ValueError(f"Unsupported corruption mode: {mode!r}")
    try:
        inputs = np.array(inputs)
        inputs_ori = inputs.copy()
        input_entities = np.array([id2entity[x] if x in id2entity else x for x in inputs])

        # Collect the corruptible code entities and their sampling weights.
        codes = list()
        f = list()
        for x in input_entities:
            if ("icd9_code" in x) and (x not in codes):
                is_diag = 'diag' in x.split('/')[1]
                code = name2codecat[code2name[reformat(x.split('/')[-1].strip(">"), is_diag=is_diag)]]
                # A category with a single entry offers no alternative code.
                if len(codebook[code['large']]) >= 2:
                    codes.append(x)
                    f.append(code_frequency[".".join([code['large'], code['small']])])
        total = sum(f)
        p = [w / total for w in f]
        # With no candidates `choice` raises, which drops the sample below.
        corruption_target_codes = choice(codes, size=max(int(len(codes) * 0.25), 1), replace=False, p=p)
        corruption_targets_idx = np.array([np.where(input_entities == code)[0][0] for code in corruption_target_codes])
        for corruption_target in corruption_targets_idx:
            code_entity = input_entities[corruption_target].split("/")
            header = '/'.join(code_entity[:-1])
            is_diag = 'diag' in header
            code = reformat(code_entity[-1].strip(">"), is_diag=is_diag)

            # Position of the name (literal) token that belongs to this code.
            target_literal_idx = np.where(entity2id[code2name[code]] == inputs)[0][0]

            icd_code = code2name[code]
            codecat = name2codecat[icd_code]
            large_cat, small_cat = codecat['large'], codecat['small']
            small_cat_lists = codebook[large_cat].copy()

            small_cat_lists.remove(small_cat)
            for _ in range(50):
                corrupted_small_cat = choice(small_cat_lists)
                if code2name[".".join([large_cat, corrupted_small_cat])] in entity2id:
                    break
            else:
                # No replacement found in the vocabulary within 50 draws.
                raise ValueError()
            inputs[target_literal_idx] = entity2id[code2name[".".join([large_cat, corrupted_small_cat])]]
            inputs[corruption_target] = entity2id['/'.join([header, "".join([large_cat, corrupted_small_cat])]) + ">"]
        labels = ~(inputs_ori == inputs)
    except Exception:
        # Best effort: a sample that cannot be corrupted is dropped by the
        # caller. `Exception` (not a bare except) keeps Ctrl-C working.
        inputs = None
        labels = None
    return inputs, labels

In [48]:
# Build the error-detection (ED) databases for every DB type and split.
# The corrupted samples are generated once, while processing 'NoKGenc', and
# cached in `global_sample` (None for samples that could not be corrupted).
# All later DB types reuse the cached sample at the same running position
# `idx` (which continues across the splits), so every DB type receives the
# same corruptions and drops the same samples.
for db_type in DB_type:
    print(db_type)
    if db_type == 'NoKGenc':
        global_sample = list()
    idx = 0
    for split in SPLIT:
        # Prepare essential files
        # NOTE(review): these vocabularies do not depend on `split` or
        # `db_type` and are rebuilt on every iteration; the file returned by
        # open() is never closed explicitly.
        NUM_SPECIAL_TOKENS = 3
        # Non-unified vocabulary: ids from entity2id.txt (first line skipped)
        # shifted by the three special tokens assigned below.
        id2entity = {int(line.split('\t')[1])+NUM_SPECIAL_TOKENS:line.split('\t')[0].split('^^')[0].strip('"').replace('\\"','"') for line in open(os.path.join(f'../preprocessing/legacy/13-01-2021/dxprx/','entity2id.txt')).read().splitlines()[1:]}
        id2entity[0]='PAD'
        id2entity[1]='MASK'
        id2entity[2]='CLS'
        entity2id = {v:k for k,v in id2entity.items()}
        # Unified vocabulary and the mapping from non-unified to unified ids
        # (only for entities that exist in both vocabularies).
        uniid2entity = {v:k.split('\t')[0].split('^^')[0].strip('"').replace('\\"','"') for k,v in torch.load(f'{SOURCE_PATH}/{db_name_source}_UnifiedNoKGenc/unified_node').items()}
        entity2uniid = {v:k for k,v in uniid2entity.items()}
        id2uniid = {k:entity2uniid[v] for k,v in id2entity.items() if v in entity2uniid}
        db = torch.load(f'../gtx/data/knowmix/{db_name}_{size}/{db_name}_{db_type}/{split}/db')
#         db = torch.load(f'{SOURCE_PATH}/{db_name_source}_{db_type}/{split}/db')
        # The output keeps every field of the source db except 'rc_index'
        # and 'label_mask'.
        db_new = dict()
        for k in db:
            if k in ['rc_index', 'label_mask']:
                continue
            db_new[k] = list()
        for in_db_idx, _input in tqdm(enumerate(db['input'])):
            if db_type == 'NoKGenc':
                # Generate the corruption; x/y are None when it failed.
                x, y = corrupt_input_and_generate_label(inputs=db['input'][in_db_idx],mode='s',id2entity=id2entity, entity2id=entity2id, code_frequency=code_frequency)
                sample = {
                    'input': x,
                    'label': y,
                } if x is not None else None
                if sample is not None:
                    for k in db_new:
                        if k in sample:
                            db_new[k].append(sample[k].tolist())
                        else:
                            db_new[k].append(db[k][in_db_idx])
                # Failed samples are cached as None so positions stay aligned.
                global_sample.append(sample)
            else:
                if global_sample[idx] is not None:
                    # Work on copies so the cached arrays are not modified.
                    sample = {k:global_sample[idx][k].copy() for k in global_sample[idx]}
                    actual_input = np.array(db['input'][in_db_idx].copy())
                    # Convert Non-unified sample to unifieid sample
                    if 'Unified' in db_type:
                        # Corrupted (non-unified) ids at the labelled positions.
                        living_ids = sample['input'][sample['label']]
                        # Keep a label only where the corrupted id has a
                        # unified counterpart ...
                        convertable_ids = np.array([True if living_id in id2uniid else False for living_id in living_ids])
                        sample['label'][sample['label']==True] = convertable_ids
                        # ... and write the unified ids into this db's own
                        # input at exactly those positions.
                        actual_input[sample['label']] = np.array([id2uniid[x] for x in living_ids[convertable_ids]])
                    for k in db_new:
                        if k not in sample:
                            db_new[k].append(db[k][in_db_idx])
                    db_new['label'].append(sample['label'].tolist())
                    if 'Unified' in db_type:
                        db_new['input'].append(actual_input.tolist())
                    else:
                        db_new['input'].append(sample['input'].tolist())

            idx += 1
        # Preview: original vs. corrupted entities at the labelled positions
        # of the first kept sample.
        # NOTE(review): db['input'][0] is the first *source* sample while
        # db_new[...][0] is the first *kept* sample; the two only match if
        # the first source sample was not dropped -- confirm.
        print('*'*50)
        print([uniid2entity[x] if 'Unified' in db_type else id2entity[x] for x in np.array(db['input'][0])[np.where(np.array(db_new['label'][0])==True)]])
        print([uniid2entity[x] if 'Unified' in db_type else id2entity[x] for x in np.array(db_new['input'][0])[np.where(np.array(db_new['label'][0])==True)]])
        print(np.where(np.array(db_new['label'][0])==True))
        print('-'*50)
        print(f'{db_name},{size},{split}',len(db_new['label']))
        print(list(db_new.keys()))
        print('*'*50)
        os.makedirs(f'../gtx/data/ed/{db_name}_{size}/{db_name}_{db_type}/{split}', exist_ok=True)
        torch.save(db_new,f'../gtx/data/ed/{db_name}_{size}/{db_name}_{db_type}/{split}/db')

NoKGenc


28915it [00:23, 1224.90it/s]


**************************************************
['</diagnoses_icd9_code/5533>', '</diagnoses_icd9_code/45340>', '</procedures_icd9_code/3179>', 'diaphragmatic hernia without mention of obstruction or gangrene', 'acute venous embolism and thrombosis of unspecified deep vessels of lower extremity', 'other repair and plastic operations on trachea']
['</diagnoses_icd9_code/5531>', '</diagnoses_icd9_code/45372>', '</procedures_icd9_code/315>', 'umbilical hernia without mention of obstruction or gangrene', 'chronic venous embolism and thrombosis of deep veins of upper extremity', 'local excision or destruction of lesion or tissue of trachea']
(array([27, 32, 41, 47, 52, 61]),)
--------------------------------------------------
dx,prx,2000,train 28282
['input', 'label', 'text']
**************************************************


2000it [00:01, 1246.40it/s]


**************************************************
['</procedures_icd9_code/151>', 'excision of lesion or tissue of cerebral meninges']
['</procedures_icd9_code/152>', 'hemispherectomy']
(array([10, 17]),)
--------------------------------------------------
dx,prx,2000,valid 1956
['input', 'label', 'text']
**************************************************


2000it [00:01, 1261.74it/s]


**************************************************
['</diagnoses_icd9_code/28521>', '</diagnoses_icd9_code/4589>', '</diagnoses_icd9_code/2749>', '</diagnoses_icd9_code/2761>', 'anemia in chronic kidney disease', 'hypotension, unspecified', 'gout, unspecified', 'hyposmolality and/or hyponatremia']
['</diagnoses_icd9_code/2859>', '</diagnoses_icd9_code/4581>', '</diagnoses_icd9_code/27401>', '</diagnoses_icd9_code/2767>', 'anemia, unspecified', 'chronic hypotension', 'acute gouty arthropathy', 'hyperpotassemia']
(array([26, 29, 36, 38, 45, 48, 55, 57]),)
--------------------------------------------------
dx,prx,2000,test 1946
['input', 'label', 'text']
**************************************************
UniKGenc


28915it [00:01, 27002.55it/s]


**************************************************
['</diagnoses_icd9_code/5533>', '</diagnoses_icd9_code/45340>', '</procedures_icd9_code/3179>', 'diaphragmatic hernia without mention of obstruction or gangrene', 'acute venous embolism and thrombosis of unspecified deep vessels of lower extremity', 'other repair and plastic operations on trachea']
['</diagnoses_icd9_code/5531>', '</diagnoses_icd9_code/45372>', '</procedures_icd9_code/315>', 'umbilical hernia without mention of obstruction or gangrene', 'chronic venous embolism and thrombosis of deep veins of upper extremity', 'local excision or destruction of lesion or tissue of trachea']
(array([27, 32, 41, 47, 52, 61]),)
--------------------------------------------------
dx,prx,2000,train 28282
['input', 'mask', 'label', 'text']
**************************************************


2000it [00:00, 26353.41it/s]


**************************************************
['</procedures_icd9_code/151>', 'excision of lesion or tissue of cerebral meninges']
['</procedures_icd9_code/152>', 'hemispherectomy']
(array([10, 17]),)
--------------------------------------------------
dx,prx,2000,valid 1956
['input', 'mask', 'label', 'text']
**************************************************


2000it [00:00, 27180.14it/s]


**************************************************
['</diagnoses_icd9_code/28521>', '</diagnoses_icd9_code/4589>', '</diagnoses_icd9_code/2749>', '</diagnoses_icd9_code/2761>', 'anemia in chronic kidney disease', 'hypotension, unspecified', 'gout, unspecified', 'hyposmolality and/or hyponatremia']
['</diagnoses_icd9_code/2859>', '</diagnoses_icd9_code/4581>', '</diagnoses_icd9_code/27401>', '</diagnoses_icd9_code/2767>', 'anemia, unspecified', 'chronic hypotension', 'acute gouty arthropathy', 'hyperpotassemia']
(array([26, 29, 36, 38, 45, 48, 55, 57]),)
--------------------------------------------------
dx,prx,2000,test 1946
['input', 'mask', 'label', 'text']
**************************************************
UnifiedNoKGenc


28915it [00:01, 19536.38it/s]


**************************************************
['diaphragmatic hernia without mention of obstruction or gangrene', 'acute venous embolism and thrombosis of unspecified deep vessels of lower extremity', 'other repair and plastic operations on trachea']
['umbilical hernia without mention of obstruction or gangrene', 'chronic venous embolism and thrombosis of deep veins of upper extremity', 'local excision or destruction of lesion or tissue of trachea']
(array([47, 52, 61]),)
--------------------------------------------------
dx,prx,2000,train 28282
['input', 'label', 'text', 'knowledge']
**************************************************


2000it [00:00, 18709.37it/s]


**************************************************
['excision of lesion or tissue of cerebral meninges']
['hemispherectomy']
(array([17]),)
--------------------------------------------------
dx,prx,2000,valid 1956
['input', 'label', 'text', 'knowledge']
**************************************************


2000it [00:00, 20561.42it/s]


**************************************************
['anemia in chronic kidney disease', 'hypotension, unspecified', 'gout, unspecified', 'hyposmolality and/or hyponatremia']
['anemia, unspecified', 'chronic hypotension', 'acute gouty arthropathy', 'hyperpotassemia']
(array([45, 48, 55, 57]),)
--------------------------------------------------
dx,prx,2000,test 1946
['input', 'label', 'text', 'knowledge']
**************************************************
UnifiedUniKGenc


28915it [00:02, 12944.72it/s]


**************************************************
['diaphragmatic hernia without mention of obstruction or gangrene', 'acute venous embolism and thrombosis of unspecified deep vessels of lower extremity', 'other repair and plastic operations on trachea']
['umbilical hernia without mention of obstruction or gangrene', 'chronic venous embolism and thrombosis of deep veins of upper extremity', 'local excision or destruction of lesion or tissue of trachea']
(array([47, 52, 61]),)
--------------------------------------------------
dx,prx,2000,train 28282
['input', 'mask', 'label', 'text', 'knowledge']
**************************************************


2000it [00:00, 19654.66it/s]


**************************************************
['excision of lesion or tissue of cerebral meninges']
['hemispherectomy']
(array([17]),)
--------------------------------------------------
dx,prx,2000,valid 1956
['input', 'mask', 'label', 'text', 'knowledge']
**************************************************


2000it [00:00, 18444.37it/s]


**************************************************
['anemia in chronic kidney disease', 'hypotension, unspecified', 'gout, unspecified', 'hyposmolality and/or hyponatremia']
['anemia, unspecified', 'chronic hypotension', 'acute gouty arthropathy', 'hyperpotassemia']
(array([45, 48, 55, 57]),)
--------------------------------------------------
dx,prx,2000,test 1946
['input', 'mask', 'label', 'text', 'knowledge']
**************************************************


# 3-1. Convert DB for ErrorDetection for Rx

# 4. Modify Pretraining Input for KnowMix strategy

In [31]:
import os
import torch
from tqdm import tqdm, trange
# Configuration for the KnowMix cells below.
# db_name is the dataset name used inside the '../gtx/data/...' paths.
db_name = 'dx,prx'
# db_name_source is the dataset name used inside the SOURCE_PATH paths.
db_name_source = 'dxprx'
# Dataset size tag that is part of the '../gtx/data/...' directory names.
size = 2000
# Order matters: the knowledge lists are built while processing
# 'UnifiedNoKGenc' and reused by position for 'UnifiedUniKGenc'.
DB_type = ['UnifiedNoKGenc','UnifiedUniKGenc']
SPLIT = ['train','valid','test']
# Root directory of the preprocessed source databases and vocabularies.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'

In [32]:
# Add a 'knowledge' field (one description string per input token) to every
# error-detection db and save the result under '../gtx/data/ed/knowmix/'.
# The descriptions are computed while processing 'UnifiedNoKGenc' and stored
# in `know_list`; every other DB type reuses them by running position
# (`global_idx`, which continues across the splits), so 'UnifiedNoKGenc' has
# to be processed first and all DB types must hold the same samples in the
# same order.
know_list = list()
for db_type in DB_type:
    print('DB_type>>>>>>>>>>>',db_type)
    global_idx=0
    # Prepare essential files
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'../gtx/data/{db_name}_{size}/{db_name}_{db_type}/id2label')
        label2id = {v:k for k,v in id2label.items()}
        # Read the vocabulary through a context manager so the file handle
        # is closed (it was previously left open). The first line is skipped.
        with open(os.path.join(f'{SOURCE_PATH}/{db_name_source}','entity2id.txt')) as entity_file:
            entity_lines = entity_file.read().splitlines()[1:]
        id2entity = {int(line.split('\t')[1])+NUM_SPECIAL_TOKENS:line.split('\t')[0].split('^^')[0] for line in entity_lines}
        label2entity = {k:id2entity[v] for k,v in label2id.items()}
    else:
        id2entity = {v:k.split('\t')[0].split('^^')[0] for k,v in torch.load(f'{SOURCE_PATH}/{db_name_source}_{db_type}/unified_node').items()}
    for split in SPLIT:
        db = torch.load(f'../gtx/data/ed/{db_name}_{size}/{db_name}_{db_type}/{split}/db')
        db_new = {k:list() for k in db}
        db_new['knowledge'] = list()
        for idx in trange(len(db['input'])):
            for k in db:
                db_new[k].append(db[k][idx])
            if db_type == 'UnifiedNoKGenc':
                # Token ids up to 5 get an empty description.
                # NOTE(review): the frequency cell treats ids 0-7 as special
                # (`x>7`); confirm that 5 is the intended cut-off here.
                desc = [id2entity[x].replace('"','') if x>5 else "" for x in db_new['input'][idx]]
                know_list.append(desc)
                db_new['knowledge'].append(desc)
            else:
                db_new['knowledge'].append(know_list[global_idx])
            global_idx += 1
        os.makedirs(f'../gtx/data/ed/knowmix/{db_name}_{size}/{db_name}_{db_type}/{split}', exist_ok=True)
        torch.save(db_new,f'../gtx/data/ed/knowmix/{db_name}_{size}/{db_name}_{db_type}/{split}/db')
    # NOTE(review): the values of id2entity were already cut at '^^' above,
    # so `'^^' in v` is never true and every description saved here is "".
    # Confirm whether that is intended.
    id2desc = {k:v.split('^^')[0].replace('"','') if '^^' in v else "" for k,v in id2entity.items()}
    print(len(id2desc))
    torch.save(id2desc,f'../gtx/data/ed/knowmix/{db_name}_{size}/{db_name}_{db_type}/id2desc')
                        

DB_type>>>>>>>>>>> UnifiedNoKGenc


100%|██████████| 28274/28274 [00:00<00:00, 38166.89it/s]
100%|██████████| 1951/1951 [00:00<00:00, 42107.93it/s]
100%|██████████| 1942/1942 [00:00<00:00, 43011.23it/s]


7871
DB_type>>>>>>>>>>> UnifiedUniKGenc


100%|██████████| 28274/28274 [00:00<00:00, 666507.90it/s]
100%|██████████| 1951/1951 [00:00<00:00, 511890.85it/s]
100%|██████████| 1942/1942 [00:00<00:00, 540070.17it/s]


7871


In [33]:
# Dataset name as it appears in the source (legacy preprocessing) paths.
db_name_source = 'dxprx'

In [34]:
# Add a 'knowledge' field (one description string per input token) to every
# source db and save the result under '../gtx/data/knowmix/'.
# The descriptions are computed while processing 'UnifiedNoKGenc' and stored
# in `know_list`; every other DB type reuses them by running position
# (`global_idx`, which continues across the splits), so 'UnifiedNoKGenc' has
# to be processed first and all DB types must hold the same samples in the
# same order.
know_list = list()
for db_type in DB_type:
    global_idx=0
    # Prepare essential files
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021/{db_name_source}_{db_type}/id2label')
        label2id = {v:k for k,v in id2label.items()}
        # Read the vocabulary through a context manager so the file handle
        # is closed (it was previously left open). The first line is skipped.
        with open(os.path.join(f'/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021/{db_name_source}','entity2id.txt')) as entity_file:
            entity_lines = entity_file.read().splitlines()[1:]
        id2entity = {int(line.split('\t')[1])+NUM_SPECIAL_TOKENS:line.split('\t')[0].split('^^')[0] for line in entity_lines}
        label2entity = {k:id2entity[v] for k,v in label2id.items()}
    else:
        id2entity = {v:k.split('\t')[0].split('^^')[0] for k,v in torch.load(f'/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021/{db_name_source}_{db_type}/unified_node').items()}
    for split in SPLIT:
        db = torch.load(f'/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021/{db_name_source}_{db_type}/{split}/db')
        db_new = {k:list() for k in db}
        db_new['knowledge'] = list()
        for idx in trange(len(db['input'])):
            for k in db:
                db_new[k].append(db[k][idx])
            if db_type == 'UnifiedNoKGenc':
                # Token ids up to 5 get an empty description.
                # NOTE(review): the frequency cell treats ids 0-7 as special
                # (`x>7`); confirm that 5 is the intended cut-off here.
                desc = [id2entity[x].replace('"','') if x>5 else "" for x in db_new['input'][idx]]
                know_list.append(desc)
                db_new['knowledge'].append(desc)
            else:
                db_new['knowledge'].append(know_list[global_idx])
            global_idx += 1
        os.makedirs(f'../gtx/data/knowmix/{db_name}_{size}/{db_name}_{db_type}/{split}', exist_ok=True)
        torch.save(db_new,f'../gtx/data/knowmix/{db_name}_{size}/{db_name}_{db_type}/{split}/db')
    # NOTE(review): the values of id2entity were already cut at '^^' above,
    # so `'^^' in v` is never true and every description saved here is "".
    # Confirm whether that is intended.
    id2desc = {k:v.split('^^')[0].replace('"','') if '^^' in v else "" for k,v in id2entity.items()}
    print(len(id2desc))
    torch.save(id2desc,f'../gtx/data/knowmix/{db_name}_{size}/{db_name}_{db_type}/id2desc')
                
            

100%|██████████| 28915/28915 [00:00<00:00, 39815.49it/s]
100%|██████████| 2000/2000 [00:00<00:00, 39097.34it/s]
100%|██████████| 2000/2000 [00:00<00:00, 40511.96it/s]


7871


100%|██████████| 28915/28915 [00:00<00:00, 514724.26it/s]
100%|██████████| 2000/2000 [00:00<00:00, 423410.46it/s]
100%|██████████| 2000/2000 [00:00<00:00, 403337.24it/s]


7871


#### check

In [67]:
# Sanity check: load one saved db and print every field of a single sample.
# db = torch.load(f'/home/caoyu/project/MedGTX/gtx/data/readm/knowmix/dx,prx_2000/dx,prx_UnifiedUniKGenc/valid/db')
# db = torch.load(f'/home/caoyu/project/MultiModalMed/gtx/data/readm/knowmix/dx,prx_2000/dx,prx_UnifiedUniKGenc/valid/db')
db = torch.load(f'/home/caoyu/project/MultiModalMed/gtx/data/readm/knowmix/px_1000/px_UnifiedUniKGenc/valid/db')

# Index of the sample to inspect. It is now actually used below; the loop
# previously hard-coded `v[1]`, so changing IDX had no effect.
IDX = 1
for k, v in db.items():
    print(k)
    print(v[IDX])

input
[2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 105, 326, 247, 407, 16, 142, 10, 24, 403, 25, 16, 10, 154, 412, 3877, 10, 16, 3877, 26, 154, 10, 577, 16, 71, 673, 250, 16, 45, 10, 374, 154, 352, 17, 110, 16, 841, 10, 719, 24, 7, 515, 11, 10, 214, 16, 365, 16, 109, 65, 120, 16, 17, 606, 154, 352, 185, 46, 10, 5, 16, 437, 816, 10, 16, 437, 158, 440, 102, 16, 95, 10, 75, 16, 17, 247, 105, 267, 305, 16, 10, 5, 159, 10, 99, 460, 24, 99, 16, 37, 14, 37, 230, 9, 24, 10, 216, 86, 17, 577, 1241, 16, 361, 332, 16, 139, 214, 17, 17, 2426, 158, 16, 146, 5, 105, 201, 10, 400, 16, 10, 403, 25, 142, 24, 25, 24, 10, 142, 403, 121, 16, 62, 17, 95, 577, 10, 16, 312, 673, 577, 17, 312, 291, 16, 16, 673, 577, 361, 10, 78, 277, 16, 10, 71, 17, 577, 291, 71, 16, 93, 24, 27, 25, 10, 154, 412, 267, 10, 16, 392, 16, 65, 365, 71, 24, 841, 10, 719, 7, 719, 7, 10, 5, 841, 24, 10, 27, 24, 93, 25, 5, 1457, 199, 17, 16, 

#### Note: the value printed by `print(len(id2desc))` corresponds to `num_kg_labels` in the configuration table