# step0: preprocess task labels (readmission and expired)

In [1]:
import _pickle as cPickle
import csv
import os
import sys
import datetime
import random
import pandas as pd
from tqdm import tqdm
import sklearn.model_selection as ms
encounter_dict = {}

In [2]:
class EncounterInfo:
    """Record for one hospital admission (encounter) and its task labels.

    The constructor stores the identifiers, the encounter timestamp and the
    outcome labels exactly as given.  ``gender`` starts as an empty string and
    every feature collection (diagnoses, prescriptions, labs, microbiology,
    physicals, procedures) starts as its own empty list, to be filled later.
    """

    # Feature collections that start out empty on every new instance.
    _EMPTY_LIST_FIELDS = (
        'dx_ids', 'dx_ids_lvl1', 'dx_names', 'dx_labels',
        'rx_ids', 'rx_names',
        'lab_ids', 'lab_names',
        'microbiology_ids', 'microbiology_names',
        'physicals',
        'procedures_ids', 'procedures_names',
    )

    def __init__(self, patient_id, encounter_id, encounter_timestamp, expired,
                 readmission, los_3day, los_7day, Death30, Death180, Death365,
                 admission_type=99, marital_status=99):
        # Identifiers and timing.
        self.patient_id = patient_id
        self.encounter_id = encounter_id
        self.encounter_timestamp = encounter_timestamp

        # Outcome labels and categorical attributes supplied by the caller.
        self.expired = expired
        self.readmission = readmission
        self.admission_type = admission_type
        self.marital_status = marital_status
        self.los_3day = los_3day
        self.los_7day = los_7day

        # Filled in after construction.
        self.gender = ''
        for field_name in self._EMPTY_LIST_FIELDS:
            # A fresh list per attribute and per instance (never shared).
            setattr(self, field_name, [])

        # Death-within-N-days labels.
        self.Death30 = Death30
        self.Death180 = Death180
        self.Death365 = Death365


def minNums(startTime, endTime):
    """Return the number of whole minutes from ``startTime`` to ``endTime``.

    Args:
        startTime: timestamp string in ``"%Y-%m-%d %H:%M:%S"`` format.
        endTime: timestamp string in the same format.

    Returns:
        int: the elapsed minutes, truncated toward zero.  The value is
        negative when ``endTime`` is earlier than ``startTime``.

    Raises:
        ValueError: if either string does not match the expected format.
    """
    time_format = "%Y-%m-%d %H:%M:%S"
    start = datetime.datetime.strptime(startTime, time_format)
    end = datetime.datetime.strptime(endTime, time_format)
    # total_seconds() covers the full difference; timedelta.seconds would drop
    # the day component and be wrong whenever the two timestamps fall on
    # different days.
    return int((end - start).total_seconds() / 60)

In [26]:
def process_admission(admission_file, icustays_file, patients_file, encounter_dict,
                      max_patient_num=500, hour_threshold=24, random_seed=None):
    """Build labelled ``EncounterInfo`` records from the admissions CSV.

    The admissions file is read twice: once to group admissions per patient
    (needed for patient sampling and the readmission label) and once to create
    one ``EncounterInfo`` per retained admission.

    Labels assigned to each admission:
      * ``expired``: 1 when ``HOSPITAL_EXPIRE_FLAG`` is ``'1'``.
      * ``readmission``: 1 for every admission of a patient except that
        patient's last one (admissions ordered by ``ADMITTIME``); the last
        admission gets 0.
      * ``los_3day`` / ``los_7day``: 1 when the stay is longer than 3 / 7 days.
      * ``Death30`` / ``Death180`` / ``Death365``: 1 when ``DEATHTIME`` is set
        and lies more than 0 and at most 30 / 180 / 365 days after
        ``ADMITTIME``.

    Args:
        admission_file: CSV with the columns ``SUBJECT_ID``, ``HADM_ID``,
            ``ADMITTIME``, ``DISCHTIME``, ``DEATHTIME`` and
            ``HOSPITAL_EXPIRE_FLAG``.
        icustays_file: currently unused; kept for interface compatibility.
        patients_file: currently unused; kept for interface compatibility.
        encounter_dict: dict that receives ``HADM_ID -> EncounterInfo``; it is
            mutated in place and also returned.
        max_patient_num: maximum number of patients to keep.  When fewer
            patients exist, all of them are kept.
        hour_threshold: admissions shorter than this many hours are skipped.
        random_seed: optional seed for the patient sampling.  ``None`` (the
            default) draws from the module-level ``random`` generator.

    Returns:
        tuple: ``(encounter_dict, admission_list)``.  ``admission_list`` holds
        the ``HADM_ID`` of every admission of the sampled patients, including
        admissions that were later skipped by ``hour_threshold``.

    Note:
        If an ``HADM_ID`` is already present in ``encounter_dict`` the function
        prints a message and calls ``sys.exit(0)``.
    """
    # NOTE: two variants existed here only as disabled code and are not
    # applied: an ICU filter (single MICU care unit, age > 18) based on
    # icustays_file / patients_file, and a readmission label restricted to
    # re-admissions within 30 days of the previous discharge.

    # ---- Pass 1: group every admission row by patient. ----
    patient_dict = {}
    with open(admission_file, 'r') as inff:
        for count, line in enumerate(csv.DictReader(inff)):
            if count % 1000 == 0:
                # Lightweight progress indicator, overwritten in place.
                sys.stdout.write('%d\r' % count)
                sys.stdout.flush()

            admittime = line['ADMITTIME']
            dischtime = line['DISCHTIME']
            # Length of the hospital stay in minutes.
            stay_minutes = minNums(admittime, dischtime)
            patient_dict.setdefault(line['SUBJECT_ID'], []).append(
                (admittime, stay_minutes, line['HADM_ID'], dischtime))

    # ---- Randomly select at most max_patient_num patients. ----
    max_patient_num = min(max_patient_num, len(patient_dict))
    rng = random if random_seed is None else random.Random(random_seed)
    # random.sample() needs a sequence (a dict view raises TypeError on
    # Python >= 3.11); a set makes the membership test below O(1).
    selected_patients = set(rng.sample(list(patient_dict), max_patient_num))
    patient_dict = {pid: encounters for pid, encounters in patient_dict.items()
                    if pid in selected_patients}

    # Admissions of the selected patients, in first-seen order without
    # duplicates.  The set mirrors the list for O(1) membership tests.
    admission_list = []
    selected_admissions = set()
    for encounters in patient_dict.values():
        for encounter in encounters:
            hadm_id = encounter[2]
            if hadm_id not in selected_admissions:
                selected_admissions.add(hadm_id)
                admission_list.append(hadm_id)

    # ---- Readmission label: every admission but the patient's last. ----
    enc_readmission_dict = {}
    for encounters in patient_dict.values():
        # Tuples start with ADMITTIME, so sorting orders them chronologically.
        ordered = sorted(encounters)
        for encounter in ordered[:-1]:
            enc_readmission_dict[encounter[2]] = 1
        enc_readmission_dict[ordered[-1][2]] = 0

    # ---- Pass 2: create one EncounterInfo per retained admission. ----
    minutes_per_day = 24. * 60
    count = 0
    with open(admission_file, 'r') as inff:
        for line in tqdm(csv.DictReader(inff)):
            encounter_id = line['HADM_ID']
            if encounter_id not in selected_admissions:
                continue

            patient_id = line['SUBJECT_ID']
            admittime = line['ADMITTIME']
            encounter_timestamp = minNums(admittime, line['DISCHTIME'])
            duration_minute = encounter_timestamp
            if duration_minute < 60. * hour_threshold:
                # Stay shorter than the threshold: no record is created.
                continue

            deathtime = line['DEATHTIME']
            # An empty DEATHTIME means no death was recorded for this row;
            # 0 keeps all Death* labels at 0.
            death_timestamp = minNums(admittime, deathtime) if deathtime else 0

            losday = duration_minute / minutes_per_day
            deathday = death_timestamp / minutes_per_day

            # HOSPITAL_EXPIRE_FLAG: '1' marks death during this
            # hospitalisation, anything else is treated as survival.
            expired = 1 if line['HOSPITAL_EXPIRE_FLAG'] == '1' else 0
            readmission = 1 if enc_readmission_dict[encounter_id] == 1 else 0
            los_3day = 1 if losday > 3 else 0
            los_7day = 1 if losday > 7 else 0

            Death30 = 1 if 30 >= deathday > 0 else 0
            Death180 = 1 if 180 >= deathday > 0 else 0
            Death365 = 1 if 365 >= deathday > 0 else 0

            ei = EncounterInfo(patient_id, encounter_id, encounter_timestamp, expired,
                               readmission, los_3day, los_7day, Death30, Death180, Death365)
            if encounter_id in encounter_dict:
                print('Duplicate encounter ID!!')
                sys.exit(0)
            encounter_dict[encounter_id] = ei
            count += 1

    print('Accepted Patients: {}'.format(max_patient_num))
    print('Accepted admissions: {}'.format(count))
    print('')
    return encounter_dict, admission_list

In [27]:
def process_patients(patients_file, encounter_dict):
    """Copy each patient's gender onto that patient's encounters.

    Args:
        patients_file: CSV with the columns ``SUBJECT_ID`` and ``GENDER``.
        encounter_dict: mapping whose values expose ``patient_id`` and
            ``gender`` attributes; matching values are updated in place.

    Returns:
        The same ``encounter_dict`` object that was passed in.

    Encounters whose ``patient_id`` does not appear in the file keep their
    current ``gender``.  If a ``SUBJECT_ID`` appears more than once, the last
    row wins.  The printed count is the number of encounters that received a
    gender.
    """
    # Read the table once into a lookup so that each encounter is visited a
    # single time instead of once per patient row.
    gender_by_patient = {}
    with open(patients_file, 'r') as inff:
        for line in csv.DictReader(inff):
            gender_by_patient[line['SUBJECT_ID']] = line['GENDER']

    count = 0
    for enc in encounter_dict.values():
        if enc.patient_id in gender_by_patient:
            enc.gender = gender_by_patient[enc.patient_id]
            count += 1

    print('Accepted admissions: %d' % count)
    print('')
    return encounter_dict


In [28]:

# Dataset locations (these could instead be taken from argv[1] / argv[2]).
input_path = '/home/caoyu/project/GraphCLHealth/data/mimiciii'
output_path = '/home/caoyu/project/GraphCLHealth/processed_data/mimiciii'
minimum_cnt = 5
max_patient_num = 100000
print(f'max_patient_num:{max_patient_num}')

# Dictionary tables for ICD diagnoses / procedures and lab items.
icd_diagnoses_file = f'{input_path}/D_ICD_DIAGNOSES.csv'
icd_procedures_file = f'{input_path}/D_ICD_PROCEDURES.csv'
icd_labItems_file = f'{input_path}/D_LABITEMS.csv'

# The patients table is the same file in both modes.
patients_file = f'{input_path}/PATIENTS.csv'

# flag_test_flag == 0 selects the "...10.csv" files (presumably the reduced
# extracts used for debugging); any other value selects the full tables.
flag_test_flag = 1
if flag_test_flag == 0:
    admission_file = f'{input_path}/ADMISSIONS10.csv'
    diagnosis_file = f'{input_path}/DIAGNOSES_ICD10.csv'
    procedures_file = f'{input_path}/PROCEDURES_ICD10.csv'
    labevents_file = f'{input_path}/LABEVENTS10.csv'
    microbiology_file = f'{input_path}/microbiologyevents10.csv'
    icustays_file = f'{input_path}/ICUSTAYS10.csv'
else:
    admission_file = f'{input_path}/ADMISSIONS.csv'
    diagnosis_file = f'{input_path}/DIAGNOSES_ICD.csv'
    procedures_file = f'{input_path}/PROCEDURES_ICD.csv'
    labevents_file = f'{input_path}/LABEVENTS.csv'
    microbiology_file = f'{input_path}/microbiologyevents.csv'
    icustays_file = f'{input_path}/ICUSTAYS.csv'

# Start from an empty encounter table; max_patient_num caps how many patients
# are sampled and hour_threshold=24 skips stays shorter than 24 hours.
encounter_dict = {}
print('Processing ADMISSIONS.csv')
encounter_dict, admission_list = process_admission(
    admission_file,
    icustays_file,
    patients_file,
    encounter_dict,
    max_patient_num,
    hour_threshold=24,
)
# Gender enrichment via process_patients(patients_file, encounter_dict) is
# currently not run.




max_patient_num:100000
Processing ADMISSIONS.csv
58000

58976it [00:45, 1290.23it/s]

Accepted Patients: 46520
Accepted admissions: 56684






In [29]:
# Per-task label tables keyed by encounter (admission) id, built from the
# EncounterInfo records in encounter_dict.  The first value seen for an id is
# kept.
expired = {}
readmission = {}
death30 = {}
death180 = {}
death365 = {}
for enc in encounter_dict.values():
    enc_id = enc.encounter_id
    expired.setdefault(enc_id, enc.expired)
    readmission.setdefault(enc_id, enc.readmission)
    death30.setdefault(enc_id, enc.Death30)
    death180.setdefault(enc_id, enc.Death180)
    # The membership test for this table used to look at death30, which was
    # already filled on the line above, so death365 always stayed empty.
    death365.setdefault(enc_id, enc.Death365)
        
#     print(key,encounter_dict[key].encounter_id,encounter_dict[key].expired,encounter_dict[key].readmission)

In [30]:
# Print the number of positive (== 1) labels per task, in a fixed order,
# followed by the sizes of the expired and readmission tables.
for task_name, task_labels in (
        ('readmission', readmission),
        ('death30', death30),
        ('death180', death180),
        ('expired', expired),
):
    cnt = sum(1 for flag in task_labels.values() if flag == 1)
    print(task_name, cnt)

print(len(expired))
print(len(readmission))

readmission 12134
death30 4492
death180 4875
expired 4883
56684
56684


In [19]:
# Share of accepted admissions with death30 == 1 (4492 of 56684, the counts
# printed by the previous cell).
print(4492/56684)

0.07924634817585209


In [20]:
# NOTE(review): 3275 does not match any count printed in this notebook --
# presumably a value from an earlier run; confirm which label it refers to.
print(3275/56684)

0.05777644485216287


In [21]:
# Number of encounters that have an entry in the expired label table.
print(len(expired))

56684


In [22]:
# Number of encounters that have an entry in the readmission label table.
print(len(readmission))

56684


# 5. Convert DB for ReAdmPred task

In [31]:
import os
import torch
import numpy as np
import pickle as pkl
from numpy.random import choice
from tqdm import tqdm
# Directory-name prefix of the knowmix DBs (input and output paths).
out_db_name = 'dx,prx'
# Directory-name prefix of the source files under SOURCE_PATH.
db_name = 'dxprx'
# Part of the knowmix directory name ("{out_db_name}_{size}").
# NOTE(review): presumably the size of the valid/test splits -- confirm.
size = 2000
# DB variants, processed in this order.  'NoKGenc' has to come first: the
# conversion loop builds global_label during that pass and reuses it afterwards.
DB_type = ['NoKGenc','UniKGenc','UnifiedNoKGenc','UnifiedUniKGenc']
# Splits, processed in this order for every DB variant.
SPLIT = ['train','valid','test']
# Root of the id2label / entity2id.txt / unified_node files.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'
# Root of the existing knowmix DBs that are re-labelled.
knowmix_PATH='/home/caoyu/project/MultiModalMed/gtx/data/knowmix'

In [32]:
# Re-label the knowmix DBs for the readmission task.
#
# For every DB variant and split the existing `db` is loaded, only the samples
# whose admission (hadm_id) has a readmission label are kept, and `label` is
# replaced by that label.  The hadm_id of a sample is resolved only during the
# 'NoKGenc' pass; the outcome is cached in `global_label` (one entry per
# sample across train/valid/test, None = no label) and reused by position for
# the other variants.  This assumes 'NoKGenc' is first in DB_type and that all
# variants list the same samples in the same order.
for db_type in DB_type:
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/id2label')
        label2id = {v: k for k, v in id2label.items()}
        # entity2id.txt: the first line is skipped, every other line is
        # "<entity>\t<id>"; ids are shifted by the number of special tokens.
        # The file is closed as soon as it has been read.
        with open(os.path.join(f'{SOURCE_PATH}/{db_name}', 'entity2id.txt')) as entity_file:
            entity_lines = entity_file.read().splitlines()[1:]
        id2entity = {int(line.split('\t')[1]) + NUM_SPECIAL_TOKENS: line.split('\t')[0].split('^^')[0]
                     for line in entity_lines}
        label2entity = {k: id2entity[v] for k, v in label2id.items()}
    else:
        id2entity = {v: k.split('\t')[0].split('^^')[0] for k, v in
                     torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/unified_node').items()}

    if db_type == 'NoKGenc':
        global_label = list()
    # Position of the current sample across all splits of this variant.
    idx = 0
    for split in SPLIT:
        print('#####', db_type, ' -- ', split)
        db = torch.load(f'{knowmix_PATH}/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}/db')
        # 'label_mask' and 'rc_index' are not carried over to the new DB.
        db_new = {k: list() for k in db if k not in ['label_mask', 'rc_index']}
        copied_keys = [k for k in db_new if k != 'label']
        for in_db_idx, _input in tqdm(enumerate(db['input'])):
            if db_type == 'NoKGenc':
                # The second token of the input is the admission entity,
                # e.g. '</hadm_id/129414>' -> '129414'.
                hadm_id = id2entity[_input[1]].split('/')[-1].replace('>', '')
                sample_label = readmission[hadm_id] if hadm_id in readmission else None
                global_label.append(sample_label)
            else:
                sample_label = global_label[idx]
            idx += 1
            if sample_label is None:
                # No label for this admission: the sample is dropped.
                continue
            for k in copied_keys:
                db_new[k].append(db[k][in_db_idx])
            db_new['label'].append(sample_label)
        print(f'{db_name},{size},{split}', len(db_new['label']))
        # NOTE(review): only the "Unified" variants are written to disk --
        # confirm that the other variants are intentionally not saved.
        if "Unified" in db_type:
            out_dir = f'../gtx/data/readm/knowmix/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}'
            os.makedirs(out_dir, exist_ok=True)
            torch.save(db_new, f'{out_dir}/db')
            print("*** Saved! ***")
        # Sanity preview; skipped when the split kept 20 samples or fewer so
        # that a small split cannot abort the conversion with an IndexError.
        if len(db_new['input']) > 20:
            print("20th sample", [id2entity[x] for x in db_new['input'][20][1:10]])
            print("20th label", db_new['label'][20])

##### NoKGenc  --  train


28915it [00:00, 338373.01it/s]


dxprx,2000,train 28358
20th sample ['</hadm_id/129414>', '</diagnoses/406607>', '</diagnoses/406605>', '</diagnoses/406616>', '</diagnoses/406614>', '</diagnoses/406615>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406611>']
20th label 1
##### NoKGenc  --  valid


2000it [00:00, 319176.93it/s]


dxprx,2000,valid 1957
20th sample ['</hadm_id/167090>', '</diagnoses/67993>', '</diagnoses/67990>', '</procedures/28197>', '</procedures/28192>', '</diagnoses/67995>', '</diagnoses/67996>', '</procedures/28190>', '</diagnoses/67992>']
20th label 1
##### NoKGenc  --  test


2000it [00:00, 328836.06it/s]


dxprx,2000,test 1961
20th sample ['</hadm_id/185204>', '</diagnoses/299964>', '</diagnoses/299958>', '</diagnoses/299967>', '</diagnoses/299965>', '</diagnoses/299961>', '</diagnoses/299963>', '</diagnoses/299962>', '</procedures/109432>']
20th label 0
##### UniKGenc  --  train


28915it [00:00, 620616.02it/s]


dxprx,2000,train 28358
20th sample ['</hadm_id/129414>', '</diagnoses/406607>', '</diagnoses/406605>', '</diagnoses/406616>', '</diagnoses/406614>', '</diagnoses/406615>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406611>']
20th label 1
##### UniKGenc  --  valid


2000it [00:00, 578006.48it/s]


dxprx,2000,valid 1957
20th sample ['</hadm_id/167090>', '</diagnoses/67993>', '</diagnoses/67990>', '</procedures/28197>', '</procedures/28192>', '</diagnoses/67995>', '</diagnoses/67996>', '</procedures/28190>', '</diagnoses/67992>']
20th label 1
##### UniKGenc  --  test


2000it [00:00, 615496.96it/s]


dxprx,2000,test 1961
20th sample ['</hadm_id/185204>', '</diagnoses/299964>', '</diagnoses/299958>', '</diagnoses/299967>', '</diagnoses/299965>', '</diagnoses/299961>', '</diagnoses/299963>', '</diagnoses/299962>', '</procedures/109432>']
20th label 0
##### UnifiedNoKGenc  --  train


28915it [00:00, 349362.22it/s]


dxprx,2000,train 28358
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses']
20th label 1
##### UnifiedNoKGenc  --  valid


2000it [00:00, 338591.64it/s]


dxprx,2000,valid 1957
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses']
20th label 1
##### UnifiedNoKGenc  --  test


2000it [00:00, 345295.46it/s]


dxprx,2000,test 1961
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures']
20th label 0
##### UnifiedUniKGenc  --  train


28915it [00:00, 301590.53it/s]


dxprx,2000,train 28358
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses']
20th label 1
##### UnifiedUniKGenc  --  valid


2000it [00:00, 269600.13it/s]


dxprx,2000,valid 1957
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses']
20th label 1
##### UnifiedUniKGenc  --  test


2000it [00:00, 301185.12it/s]


dxprx,2000,test 1961
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures']
20th label 0


# 5. Convert DB for the expired task (Death180 label)

In [13]:
import os
import torch
import numpy as np
import pickle as pkl
from numpy.random import choice
from tqdm import tqdm
# Directory-name prefix of the knowmix DBs (input and output paths).
out_db_name = 'dx,prx'
# Directory-name prefix of the source files under SOURCE_PATH.
db_name = 'dxprx'
# Part of the knowmix directory name ("{out_db_name}_{size}").
# NOTE(review): presumably the size of the valid/test splits -- confirm.
size = 2000
# DB variants, processed in this order.  'NoKGenc' has to come first: the
# conversion loop builds global_label during that pass and reuses it afterwards.
DB_type = ['NoKGenc','UniKGenc','UnifiedNoKGenc','UnifiedUniKGenc']
# Splits, processed in this order for every DB variant.
SPLIT = ['train','valid','test']
# Root of the id2label / entity2id.txt / unified_node files.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'
# Root of the existing knowmix DBs that are re-labelled.
knowmix_PATH='/home/caoyu/project/MultiModalMed/gtx/data/knowmix'

In [14]:
# Re-label the knowmix DBs for the Death180 task.
#
# For every DB variant and split the existing `db` is loaded, only the samples
# whose admission (hadm_id) has a death180 label are kept, and `label` is
# replaced by that label.  The hadm_id of a sample is resolved only during the
# 'NoKGenc' pass; the outcome is cached in `global_label` (one entry per
# sample across train/valid/test, None = no label) and reused by position for
# the other variants.  This assumes 'NoKGenc' is first in DB_type and that all
# variants list the same samples in the same order.
for db_type in DB_type:
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/id2label')
        label2id = {v: k for k, v in id2label.items()}
        # entity2id.txt: the first line is skipped, every other line is
        # "<entity>\t<id>"; ids are shifted by the number of special tokens.
        # The file is closed as soon as it has been read.
        with open(os.path.join(f'{SOURCE_PATH}/{db_name}', 'entity2id.txt')) as entity_file:
            entity_lines = entity_file.read().splitlines()[1:]
        id2entity = {int(line.split('\t')[1]) + NUM_SPECIAL_TOKENS: line.split('\t')[0].split('^^')[0]
                     for line in entity_lines}
        label2entity = {k: id2entity[v] for k, v in label2id.items()}
    else:
        id2entity = {v: k.split('\t')[0].split('^^')[0] for k, v in
                     torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/unified_node').items()}

    if db_type == 'NoKGenc':
        global_label = list()
    # Position of the current sample across all splits of this variant.
    idx = 0
    for split in SPLIT:
        print('#####', db_type, ' -- ', split)
        db = torch.load(f'{knowmix_PATH}/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}/db')
        # 'label_mask' and 'rc_index' are not carried over to the new DB.
        db_new = {k: list() for k in db if k not in ['label_mask', 'rc_index']}
        copied_keys = [k for k in db_new if k != 'label']
        for in_db_idx, _input in tqdm(enumerate(db['input'])):
            if db_type == 'NoKGenc':
                # The second token of the input is the admission entity,
                # e.g. '</hadm_id/129414>' -> '129414'.
                hadm_id = id2entity[_input[1]].split('/')[-1].replace('>', '')
                sample_label = death180[hadm_id] if hadm_id in death180 else None
                global_label.append(sample_label)
            else:
                sample_label = global_label[idx]
            idx += 1
            if sample_label is None:
                # No label for this admission: the sample is dropped.
                continue
            for k in copied_keys:
                db_new[k].append(db[k][in_db_idx])
            db_new['label'].append(sample_label)
        print(f'{db_name},{size},{split}', len(db_new['label']))
        # NOTE(review): only the "Unified" variants are written to disk --
        # confirm that the other variants are intentionally not saved.
        if "Unified" in db_type:
            out_dir = f'../gtx/data/Death180/knowmix/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}'
            os.makedirs(out_dir, exist_ok=True)
            torch.save(db_new, f'{out_dir}/db')
            print("*** Saved! ***")
        # Sanity preview; skipped when the split kept 20 samples or fewer so
        # that a small split cannot abort the conversion with an IndexError.
        if len(db_new['input']) > 20:
            print("20th sample", [id2entity[x] for x in db_new['input'][20][1:10]])
            print("20th label", db_new['label'][20])

##### NoKGenc  --  train


28915it [00:00, 196609.06it/s]


dxprx,2000,train 28358
20th sample ['</hadm_id/129414>', '</diagnoses/406607>', '</diagnoses/406605>', '</diagnoses/406616>', '</diagnoses/406614>', '</diagnoses/406615>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406611>']
20th label 0
##### NoKGenc  --  valid


2000it [00:00, 260160.28it/s]


dxprx,2000,valid 1957
20th sample ['</hadm_id/167090>', '</diagnoses/67993>', '</diagnoses/67990>', '</procedures/28197>', '</procedures/28192>', '</diagnoses/67995>', '</diagnoses/67996>', '</procedures/28190>', '</diagnoses/67992>']
20th label 0
##### NoKGenc  --  test


2000it [00:00, 271475.99it/s]


dxprx,2000,test 1961
20th sample ['</hadm_id/185204>', '</diagnoses/299964>', '</diagnoses/299958>', '</diagnoses/299967>', '</diagnoses/299965>', '</diagnoses/299961>', '</diagnoses/299963>', '</diagnoses/299962>', '</procedures/109432>']
20th label 0
##### UniKGenc  --  train


28915it [00:00, 620032.21it/s]


dxprx,2000,train 28358
20th sample ['</hadm_id/129414>', '</diagnoses/406607>', '</diagnoses/406605>', '</diagnoses/406616>', '</diagnoses/406614>', '</diagnoses/406615>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406611>']
20th label 0
##### UniKGenc  --  valid


2000it [00:00, 584571.99it/s]


dxprx,2000,valid 1957
20th sample ['</hadm_id/167090>', '</diagnoses/67993>', '</diagnoses/67990>', '</procedures/28197>', '</procedures/28192>', '</diagnoses/67995>', '</diagnoses/67996>', '</procedures/28190>', '</diagnoses/67992>']
20th label 0
##### UniKGenc  --  test


2000it [00:00, 617172.45it/s]


dxprx,2000,test 1961
20th sample ['</hadm_id/185204>', '</diagnoses/299964>', '</diagnoses/299958>', '</diagnoses/299967>', '</diagnoses/299965>', '</diagnoses/299961>', '</diagnoses/299963>', '</diagnoses/299962>', '</procedures/109432>']
20th label 0
##### UnifiedNoKGenc  --  train


28915it [00:00, 334000.63it/s]


dxprx,2000,train 28358
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses']
20th label 0
##### UnifiedNoKGenc  --  valid


2000it [00:00, 320775.80it/s]


dxprx,2000,valid 1957
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedNoKGenc  --  test


2000it [00:00, 132979.42it/s]


dxprx,2000,test 1961
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures']
20th label 0
##### UnifiedUniKGenc  --  train


28915it [00:00, 304359.93it/s]


dxprx,2000,train 28358
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  valid


2000it [00:00, 249779.90it/s]


dxprx,2000,valid 1957
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  test


2000it [00:00, 267093.58it/s]


dxprx,2000,test 1961
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures']
20th label 0


# 5. Convert DB for the expired task (Death30 label)

In [11]:
import os
import torch
import numpy as np
import pickle as pkl
from numpy.random import choice
from tqdm import tqdm
# Directory-name prefix of the knowmix DBs (input and output paths).
out_db_name = 'dx,prx'
# Directory-name prefix of the source files under SOURCE_PATH.
db_name = 'dxprx'
# Part of the knowmix directory name ("{out_db_name}_{size}").
# NOTE(review): presumably the size of the valid/test splits -- confirm.
size = 2000
# DB variants, processed in this order.  'NoKGenc' has to come first: the
# conversion loop builds global_label during that pass and reuses it afterwards.
DB_type = ['NoKGenc','UniKGenc','UnifiedNoKGenc','UnifiedUniKGenc']
# Splits, processed in this order for every DB variant.
SPLIT = ['train','valid','test']
# Root of the id2label / entity2id.txt / unified_node files.
SOURCE_PATH = '/home/caoyu/project/MultiModalMed/preprocessing/legacy/13-01-2021'
# Root of the existing knowmix DBs that are re-labelled.
knowmix_PATH='/home/caoyu/project/MultiModalMed/gtx/data/knowmix'

In [12]:
# Re-label the knowmix DBs for the Death30 task.
#
# For every DB variant and split the existing `db` is loaded, only the samples
# whose admission (hadm_id) has a death30 label are kept, and `label` is
# replaced by that label.  The hadm_id of a sample is resolved only during the
# 'NoKGenc' pass; the outcome is cached in `global_label` (one entry per
# sample across train/valid/test, None = no label) and reused by position for
# the other variants.  This assumes 'NoKGenc' is first in DB_type and that all
# variants list the same samples in the same order.
for db_type in DB_type:
    if "Unified" not in db_type:
        NUM_SPECIAL_TOKENS = 3
        id2label = torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/id2label')
        label2id = {v: k for k, v in id2label.items()}
        # entity2id.txt: the first line is skipped, every other line is
        # "<entity>\t<id>"; ids are shifted by the number of special tokens.
        # The file is closed as soon as it has been read.
        with open(os.path.join(f'{SOURCE_PATH}/{db_name}', 'entity2id.txt')) as entity_file:
            entity_lines = entity_file.read().splitlines()[1:]
        id2entity = {int(line.split('\t')[1]) + NUM_SPECIAL_TOKENS: line.split('\t')[0].split('^^')[0]
                     for line in entity_lines}
        label2entity = {k: id2entity[v] for k, v in label2id.items()}
    else:
        id2entity = {v: k.split('\t')[0].split('^^')[0] for k, v in
                     torch.load(f'{SOURCE_PATH}/{db_name}_{db_type}/unified_node').items()}

    if db_type == 'NoKGenc':
        global_label = list()
    # Position of the current sample across all splits of this variant.
    idx = 0
    for split in SPLIT:
        print('#####', db_type, ' -- ', split)
        db = torch.load(f'{knowmix_PATH}/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}/db')
        # 'label_mask' and 'rc_index' are not carried over to the new DB.
        db_new = {k: list() for k in db if k not in ['label_mask', 'rc_index']}
        copied_keys = [k for k in db_new if k != 'label']
        for in_db_idx, _input in tqdm(enumerate(db['input'])):
            if db_type == 'NoKGenc':
                # The second token of the input is the admission entity,
                # e.g. '</hadm_id/129414>' -> '129414'.
                hadm_id = id2entity[_input[1]].split('/')[-1].replace('>', '')
                sample_label = death30[hadm_id] if hadm_id in death30 else None
                global_label.append(sample_label)
            else:
                sample_label = global_label[idx]
            idx += 1
            if sample_label is None:
                # No label for this admission: the sample is dropped.
                continue
            for k in copied_keys:
                db_new[k].append(db[k][in_db_idx])
            db_new['label'].append(sample_label)
        print(f'{db_name},{size},{split}', len(db_new['label']))
        # NOTE(review): only the "Unified" variants are written to disk --
        # confirm that the other variants are intentionally not saved.
        if "Unified" in db_type:
            out_dir = f'../gtx/data/Death30/knowmix/{out_db_name}_{size}/{out_db_name}_{db_type}/{split}'
            os.makedirs(out_dir, exist_ok=True)
            torch.save(db_new, f'{out_dir}/db')
            print("*** Saved! ***")
        # Sanity preview; skipped when the split kept 20 samples or fewer so
        # that a small split cannot abort the conversion with an IndexError.
        if len(db_new['input']) > 20:
            print("20th sample", [id2entity[x] for x in db_new['input'][20][1:10]])
            print("20th label", db_new['label'][20])

##### NoKGenc  --  train


28915it [00:00, 222512.66it/s]


dxprx,2000,train 28358
20th sample ['</hadm_id/129414>', '</diagnoses/406607>', '</diagnoses/406605>', '</diagnoses/406616>', '</diagnoses/406614>', '</diagnoses/406615>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406611>']
20th label 0
##### NoKGenc  --  valid


2000it [00:00, 310482.20it/s]


dxprx,2000,valid 1957
20th sample ['</hadm_id/167090>', '</diagnoses/67993>', '</diagnoses/67990>', '</procedures/28197>', '</procedures/28192>', '</diagnoses/67995>', '</diagnoses/67996>', '</procedures/28190>', '</diagnoses/67992>']
20th label 0
##### NoKGenc  --  test


2000it [00:00, 321920.64it/s]


dxprx,2000,test 1961
20th sample ['</hadm_id/185204>', '</diagnoses/299964>', '</diagnoses/299958>', '</diagnoses/299967>', '</diagnoses/299965>', '</diagnoses/299961>', '</diagnoses/299963>', '</diagnoses/299962>', '</procedures/109432>']
20th label 0
##### UniKGenc  --  train


28915it [00:00, 574615.28it/s]


dxprx,2000,train 28358
20th sample ['</hadm_id/129414>', '</diagnoses/406607>', '</diagnoses/406605>', '</diagnoses/406616>', '</diagnoses/406614>', '</diagnoses/406615>', '</diagnoses/406613>', '</diagnoses/406606>', '</diagnoses/406611>']
20th label 0
##### UniKGenc  --  valid


2000it [00:00, 566721.25it/s]


dxprx,2000,valid 1957
20th sample ['</hadm_id/167090>', '</diagnoses/67993>', '</diagnoses/67990>', '</procedures/28197>', '</procedures/28192>', '</diagnoses/67995>', '</diagnoses/67996>', '</procedures/28190>', '</diagnoses/67992>']
20th label 0
##### UniKGenc  --  test


2000it [00:00, 592164.90it/s]


dxprx,2000,test 1961
20th sample ['</hadm_id/185204>', '</diagnoses/299964>', '</diagnoses/299958>', '</diagnoses/299967>', '</diagnoses/299965>', '</diagnoses/299961>', '</diagnoses/299963>', '</diagnoses/299962>', '</procedures/109432>']
20th label 0
##### UnifiedNoKGenc  --  train


28915it [00:00, 338582.72it/s]


dxprx,2000,train 28358
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses']
20th label 0
##### UnifiedNoKGenc  --  valid


2000it [00:00, 323148.35it/s]


dxprx,2000,valid 1957
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedNoKGenc  --  test


2000it [00:00, 230241.20it/s]


dxprx,2000,test 1961
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures']
20th label 0
##### UnifiedUniKGenc  --  train


28915it [00:00, 302928.42it/s]


dxprx,2000,train 28358
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  valid


2000it [00:00, 261743.21it/s]


dxprx,2000,valid 1957
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'procedures', 'procedures', 'diagnoses', 'diagnoses', 'procedures', 'diagnoses']
20th label 0
##### UnifiedUniKGenc  --  test


2000it [00:00, 291980.79it/s]


dxprx,2000,test 1961
*** Saved! ***
20th sample ['hadm', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'diagnoses', 'procedures']
20th label 0
