In [1]:
import pickle
import pandas as pd
import numpy as np
import gzip
from collections import Counter

In [None]:
# Loading the data from the file
with open('/home/yzq/paper/HetMS-AMRGNN/het/data2/unified_map.pkl', 'rb') as f:
    unified_map = pickle.load(f)

# 获取需要处理的 subject_id 列表
subject_ids = [id for (type, id), _ in unified_map.items() if type == 'p']

# 从 diagnosis_features.csv 获取已有的诊断特征
diagnosis_features = pd.read_csv('/home/yzq/paper/HetMS-AMRGNN/het/data/diagnosis_features.csv')
existing_diagnoses = diagnosis_features.columns.tolist()
diagnosis_features = diagnosis_features.astype(int)  # 将 True/False 转换为 1/0

# 创建诊断到全局索引的映射
diagnosis_to_global_index = {diag: unified_map[('d', diag)] for diag in existing_diagnoses if ('d', diag) in unified_map}

# 创建诊断到索引的映射
diagnosis_to_index = {diag: idx for idx, diag in enumerate(existing_diagnoses)}

# 读取 admissions 数据以获取住院时间信息
with gzip.open('/home/yzq/mimiciv/3.0/hosp/admissions.csv.gz', 'rt') as f:
    admissions = pd.read_csv(f)
admissions = admissions[admissions['subject_id'].isin(subject_ids)]
admissions = admissions.sort_values(['subject_id', 'admittime'])

# 读取历史诊断数据
with gzip.open('/home/yzq/mimiciv/3.0/hosp/diagnoses_icd.csv.gz', 'rt') as f:
    historical_diagnoses = pd.read_csv(f)

In [3]:
# 合并 admissions 和 diagnoses 数据
historical_diagnoses = historical_diagnoses.merge(admissions[['subject_id', 'hadm_id', 'admittime']], on=['subject_id', 'hadm_id'])

# 过滤和排序历史诊断
historical_diagnoses = historical_diagnoses[historical_diagnoses['subject_id'].isin(subject_ids)]
historical_diagnoses = historical_diagnoses.sort_values(['subject_id', 'admittime'])

# 获取每个患者的最后一次住院 ID
last_admissions = admissions.groupby('subject_id').last()['hadm_id'].to_dict()

In [12]:
# 计算诊断频次
diagnosis_freq = Counter(historical_diagnoses['icd_code'])

# 选择高频诊断（例如，选择前100个高频诊断）
top_diagnoses = [diag for diag, _ in diagnosis_freq.most_common(100) if diag not in existing_diagnoses]

In [13]:
# 创建诊断到索引的映射
all_diagnoses = existing_diagnoses + top_diagnoses
diagnosis_to_index = {diag: idx for idx, diag in enumerate(all_diagnoses)}

In [15]:
from collections import defaultdict, Counter
# 处理每个患者的历史诊断
patient_history = defaultdict(list)

for subject_id, group in historical_diagnoses.groupby('subject_id'):
    for hadm_id, adm_group in group.groupby('hadm_id'):
        if hadm_id != last_admissions[subject_id]:  # 排除最后一次住院
            admission_diagnoses = []
            for _, row in adm_group.iterrows():
                icd_code = row['icd_code']
                if icd_code in diagnosis_to_index:
                    admission_diagnoses.append(diagnosis_to_index[icd_code])
            if admission_diagnoses:  # 只添加非空的住院记录
                patient_history[subject_id].append(admission_diagnoses)


In [16]:
# 转换为普通字典
patient_history = dict(patient_history)

In [17]:
patient_history

{10001217: [[61, 5]],
 10003372: [[30, 35, 49], [33, 29, 22, 30, 20, 49, 44], [30, 35, 20]],
 10003637: [[37, 95, 31, 84, 16, 68, 30, 20],
  [31, 30, 98, 47, 84],
  [82, 16, 22, 23, 90, 86, 31, 98, 89, 84, 49],
  [31, 98, 84, 20, 68],
  [82, 37, 22, 23, 98, 90, 89, 84, 20, 49, 86],
  [61, 5, 53, 79, 88],
  [82, 37, 20, 31, 90, 98, 47, 84, 49]],
 10004990: [[5, 52], [37, 24, 42, 30]],
 10010231: [[35], [35], [42], [35]],
 10010278: [[30], [30, 43]],
 10011365: [[5, 75, 2, 52, 93], [38, 30, 98, 31], [5, 75, 52, 53, 2, 79, 93]],
 10017644: [[30, 83, 49]],
 10021927: [[11, 71, 5, 2, 52, 51, 4, 41],
  [8, 61, 5, 50, 2],
  [10, 4, 5, 2, 52, 50, 51, 93],
  [5, 50],
  [10, 50, 52, 5],
  [38, 60, 35, 30, 20, 49],
  [9, 71, 7, 5, 2, 50, 51, 52],
  [63, 50, 5, 2, 52, 51],
  [38, 60, 35, 30, 26, 49, 74, 99],
  [61, 2, 50, 5, 70],
  [50, 2, 5, 61]],
 10023486: [[38, 30, 20, 32, 47, 16],
  [32, 97, 64, 67, 17],
  [36, 37, 21, 15, 23, 30, 62, 46, 20, 24, 32],
  [34, 15, 64, 62, 20, 32, 17, 67, 35, 42

In [18]:
diagnosis_to_index

{'0389': 0,
 '25000': 1,
 '2724': 2,
 '2762': 3,
 '2859': 4,
 '4019': 5,
 '42731': 6,
 '4280': 7,
 '486': 8,
 '51881': 9,
 '5849': 10,
 '5990': 11,
 '78552': 12,
 '99592': 13,
 'A419': 14,
 'D62': 15,
 'D649': 16,
 'D696': 17,
 'E039': 18,
 'E119': 19,
 'E785': 20,
 'E870': 21,
 'E871': 22,
 'E872': 23,
 'E875': 24,
 'E876': 25,
 'F329': 26,
 'F419': 27,
 'G4733': 28,
 'G92': 29,
 'I10': 30,
 'I2510': 31,
 'I4891': 32,
 'J189': 33,
 'J9601': 34,
 'K219': 35,
 'N170': 36,
 'N179': 37,
 'N390': 38,
 'R6521': 39,
 'U071': 40,
 'V4986': 41,
 'Y92230': 42,
 'Z20822': 43,
 'Z515': 44,
 'Z66': 45,
 'Z781': 46,
 'Z7901': 47,
 'Z794': 48,
 'Z87891': 49,
 '53081': 50,
 'V1582': 51,
 '311': 52,
 '41401': 53,
 '40390': 54,
 '2449': 55,
 'V5861': 56,
 'E1122': 57,
 'V5867': 58,
 '5859': 59,
 'J449': 60,
 '3051': 61,
 'E669': 62,
 '2761': 63,
 'I129': 64,
 '32723': 65,
 'J45909': 66,
 'N189': 67,
 'F17210': 68,
 '5856': 69,
 '30000': 70,
 '496': 71,
 'Z86718': 72,
 '49390': 73,
 'G8929': 74,
 '2720'

In [None]:
# 保存患者历史诊断序列
import pickle
with open('/home/yzq/paper/HetMS-AMRGNN/het/data2/patient_history_sequences.pkl', 'wb') as f:
    pickle.dump(patient_history, f)

# 保存诊断索引映射
with open('/home/yzq/paper/HetMS-AMRGNN/het/data2/diagnosis_index_map.pkl', 'wb') as f:
    pickle.dump(diagnosis_to_index, f)

In [4]:
# 处理每个患者的历史诊断
patient_history = {subject_id: [] for subject_id in subject_ids}
all_diagnoses = []

for subject_id, group in historical_diagnoses.groupby('subject_id'):
    last_hadm_id = last_admissions[subject_id]
    patient_diagnoses = []
    for _, row in group.iterrows():
        if row['hadm_id'] != last_hadm_id:  # 排除最后一次住院
            icd_code = row['icd_code']
            patient_diagnoses.append(icd_code)
            all_diagnoses.append(icd_code)
    patient_history[subject_id] = patient_diagnoses

In [5]:
# 计算诊断频次
diagnosis_freq = Counter(all_diagnoses)

# 选择高频诊断（例如，选择前50个高频诊断）
top_diagnoses = [diag for diag, _ in diagnosis_freq.most_common(50) if diag not in existing_diagnoses]

# 更新诊断到索引的映射
for diag in top_diagnoses:
    if diag not in diagnosis_to_index:
        diagnosis_to_index[diag] = len(diagnosis_to_index)

# 创建患者历史诊断序列
patient_history_sequences = {}
for subject_id, diagnoses in patient_history.items():
    sequence = [diagnosis_to_index[d] for d in diagnoses if d in diagnosis_to_index]
    patient_history_sequences[subject_id] = sequence

In [9]:
diagnosis_to_index

{'0389': 0,
 '25000': 1,
 '2724': 2,
 '2762': 3,
 '2859': 4,
 '4019': 5,
 '42731': 6,
 '4280': 7,
 '486': 8,
 '51881': 9,
 '5849': 10,
 '5990': 11,
 '78552': 12,
 '99592': 13,
 'A419': 14,
 'D62': 15,
 'D649': 16,
 'D696': 17,
 'E039': 18,
 'E119': 19,
 'E785': 20,
 'E870': 21,
 'E871': 22,
 'E872': 23,
 'E875': 24,
 'E876': 25,
 'F329': 26,
 'F419': 27,
 'G4733': 28,
 'G92': 29,
 'I10': 30,
 'I2510': 31,
 'I4891': 32,
 'J189': 33,
 'J9601': 34,
 'K219': 35,
 'N170': 36,
 'N179': 37,
 'N390': 38,
 'R6521': 39,
 'U071': 40,
 'V4986': 41,
 'Y92230': 42,
 'Z20822': 43,
 'Z515': 44,
 'Z66': 45,
 'Z781': 46,
 'Z7901': 47,
 'Z794': 48,
 'Z87891': 49,
 '53081': 50,
 'V1582': 51,
 '311': 52,
 '41401': 53,
 '40390': 54,
 'V5861': 55,
 '2449': 56,
 'V5867': 57,
 '5859': 58,
 'E1122': 59,
 '3051': 60,
 '32723': 61,
 '5856': 62,
 'J449': 63,
 '49390': 64,
 '30000': 65,
 '2720': 66,
 '40391': 67,
 'V1251': 68,
 '496': 69,
 '412': 70,
 'E669': 71,
 '2761': 72,
 'J45909': 73,
 'V4582': 74}

In [11]:
patient_history_sequences

{14012080: [],
 10003372: [30, 35, 49, 30, 35, 20, 33, 29, 22, 30, 20, 49, 44],
 14013810: [],
 14015628: [10,
  2,
  52,
  4,
  51,
  5,
  2,
  72,
  50,
  69,
  5,
  50,
  2,
  51,
  2,
  50,
  4,
  52,
  65],
 14018555: [10, 11, 6, 55, 5, 57],
 14009087: [],
 14019276: [],
 19985545: [8,
  11,
  72,
  6,
  5,
  52,
  50,
  65,
  8,
  8,
  5,
  52,
  8,
  5,
  5,
  4,
  5,
  56,
  4,
  50,
  65,
  5,
  5,
  50,
  65,
  4,
  5,
  65,
  52,
  52,
  65,
  41,
  50,
  5,
  64,
  65,
  52,
  50,
  41,
  65,
  52,
  50,
  5,
  65,
  52,
  51,
  25,
  49,
  26,
  37,
  22,
  16,
  17,
  49,
  26,
  17,
  26,
  73,
  26,
  17,
  73,
  49,
  34,
  37,
  32,
  26,
  25,
  27,
  26,
  49,
  17,
  37,
  23,
  27,
  26,
  37,
  38,
  23,
  32,
  73,
  26,
  47,
  35,
  25,
  49,
  33,
  73,
  32,
  27,
  25,
  26,
  49,
  34,
  35,
  49,
  49,
  25,
  26,
  27,
  73,
  49,
  30,
  25,
  27,
  26,
  73,
  23,
  25,
  27,
  35,
  26,
  37,
  73,
  26,
  49,
  29,
  22,
  26,
  27,
  49,
  49,
  33,

In [7]:
len(top_diagnoses)

25

In [None]:
# 保存患者历史诊断序列
import pickle
with open('patient_history_sequences.pkl', 'wb') as f:
    pickle.dump(patient_history_sequences, f)

# 保存诊断索引映射
with open('diagnosis_index_map.pkl', 'wb') as f:
    pickle.dump(diagnosis_to_index, f)