In [14]:
# 📦 Imports
import pandas as pd
import os
from collections import defaultdict

# 📂 路径配置
DATA_DIR = '../data'
notes_file = os.path.join(DATA_DIR, 'NOTEEVENTS.csv')
diagnosis_file = os.path.join(DATA_DIR, 'DIAGNOSES_ICD.csv')
icd_desc_file = os.path.join(DATA_DIR, 'D_ICD_DIAGNOSES.csv')

# ✅ Step 1: 读取出院总结
def load_discharge_summaries(file_path):
    print("Loading discharge summaries...")
    df = pd.read_csv(file_path, low_memory=False)
    discharge_df = df[df['CATEGORY'] == 'Discharge summary']
    discharge_df = discharge_df.sort_values(['HADM_ID', 'CHARTDATE']).drop_duplicates('HADM_ID', keep='last')
    return discharge_df[['SUBJECT_ID', 'HADM_ID', 'TEXT']]

discharge_df = load_discharge_summaries(notes_file)
print(f"Discharge summaries: {discharge_df.shape}")

# ✅ Step 2: ICD 标签处理
def load_icd_labels(diag_file):
    print("Loading ICD codes...")
    diag_df = pd.read_csv(diag_file)
    grouped = diag_df.groupby('HADM_ID')['ICD9_CODE'].apply(list).reset_index()
    return grouped

icd_df = load_icd_labels(diagnosis_file)
print(f"ICD labels: {icd_df.shape}")

from collections import Counter

# 标签展开后计数
all_labels = [code for label_list in data['ICD9_CODE'] for code in label_list]
label_counts = Counter(all_labels)

# 保留出现次数 ≥ 10 的标签（例如）
frequent_labels = set([code for code, cnt in label_counts.items() if cnt >= 10])

# 过滤样本（只保留仍有标签的）
def filter_labels(label_list):
    return [code for code in label_list if code in frequent_labels]

data['ICD9_CODE'] = data['ICD9_CODE'].apply(filter_labels)
data = data[data['ICD9_CODE'].map(len) > 0]


# ✅ Step 3: 合并文本与标签
print("Merging text and labels...")
data = pd.merge(discharge_df, icd_df, on='HADM_ID')
data = data.dropna(subset=['TEXT', 'ICD9_CODE'])

# ✅ Step 3.5: 类型标准化，避免 downstream 错误
def ensure_text(x):
    return str(x) if not isinstance(x, str) else x

def ensure_list_of_str(x):
    if isinstance(x, str):  # 如果是单个字符串，转为单元素 list
        return [x]
    return list(map(str, x))  # 全部转为 str

data['TEXT'] = data['TEXT'].apply(ensure_text)
data['ICD9_CODE'] = data['ICD9_CODE'].apply(ensure_list_of_str)
print(f"ICD labels: {icd_df.shape}")

# ✅ 可选断言检查（保证每行格式正确）
assert isinstance(data['TEXT'].iloc[0], str)
assert isinstance(data['ICD9_CODE'].iloc[0], list)
assert isinstance(data['ICD9_CODE'].iloc[0][0], str)


# ✅ Step 4: 保存为 pickle，供 baseline 和后续使用
output_path = os.path.join(DATA_DIR, 'mimic3_data_test.pkl')
data.to_pickle(output_path)
print(f"Saved merged dataset to {output_path}")


Loading discharge summaries...
Discharge summaries: (52726, 3)
Loading ICD codes...
ICD labels: (58976, 2)
Merging text and labels...
ICD labels: (58976, 2)
Saved merged dataset to ../data/mimic3_data_test.pkl
