In [21]:
import pandas as pd
import numpy as np
import collections
from typing import List, Tuple
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt




## Utility functions

In [22]:
def check_subject_id_in_mimic(embeddings_df: pd.DataFrame, mimic_df: pd.DataFrame) -> bool:

    embeddings_ids = embeddings_df['subject_id'].unique()
    mimic_ids = mimic_df['subject_id'].unique() 

    embeddings_id_set = set(embeddings_ids)
    mimic_id_set = set(mimic_ids)

    # Check for any common elements
    if embeddings_id_set.intersection(mimic_id_set):
        return True
    else:
        return False

In [23]:
def split_by_patient_id(df, test_size=0.2, random_state=42, id_col='subject_id'):
    unique_ids = df[id_col].unique()
    train_ids, test_ids = train_test_split(unique_ids, test_size=test_size, random_state=random_state)

    train_df = df[df[id_col].isin(train_ids)].copy()
    test_df = df[df[id_col].isin(test_ids)].copy()

    return train_df, test_df, train_ids, test_ids 

In [24]:
def assign_group(row):
    if (row['osa'] == True) and (row['hf'] == False):
        return 0  # OSA only
    elif (row['osa'] == True) and (row['hf'] == True):
        return 1  # OSA + HF
    elif (row['osa'] == False) and (row['hf'] == True):
        return 2  # HF only

## load two main data inputs (mimic + embeddings)

In [25]:
# load data from notebook
mimic = pd.read_csv('/global/cfs/projectdirs/m1532/Projects_MVP/Sophia/2025/cohort/osa_hf_demo_lab_diagnosis_med_procedure_processed_with_id.csv')
text_embeddings = pd.read_csv('/global/cfs/projectdirs/m1532/Projects_MVP/adong/2024/data/bioclinical_embeddings.csv')
mimic_copy = mimic.copy()

## explore before processing

In [26]:
text_embeddings.head(10)

Unnamed: 0.1,Unnamed: 0,subject_id,hadm_id,note_id,0,1,2,3,4,5,...,758,759,760,761,762,763,764,765,766,767
0,0,10000980,20897796,10000980-DS-26,1.684451,-6.711592,0.158224,2.127581,3.598444,-3.479481,...,0.084615,1.073412,-4.436857,-1.260715,4.461096,-0.996854,1.113404,-0.272567,-5.099933,1.122292
1,1,10000980,24947999,10000980-DS-22,0.803342,-3.435924,-0.270163,0.267355,1.387757,-2.100942,...,-0.580867,0.013614,-2.584508,-1.386371,3.283558,-0.690634,1.070792,-1.081939,-4.329378,0.4667
2,2,10000980,25242409,10000980-DS-23,1.593685,-4.284272,-1.291691,-0.358811,1.320054,-1.134105,...,-1.410985,0.392991,-3.36079,-2.833561,4.08387,-1.133995,2.588758,-0.529504,-5.102419,1.339941
3,3,10000980,26913865,10000980-DS-21,1.14877,-3.489123,0.396844,0.233212,2.278701,-2.507564,...,0.021621,0.873185,-3.169588,-0.723119,2.53361,-1.914018,1.323397,-0.110037,-3.557183,0.725257
4,4,10000980,29654838,10000980-DS-20,0.872454,-4.563379,-0.083856,-0.138723,1.607062,-2.468211,...,-1.355164,0.408696,-2.376289,-0.435094,4.197835,-2.647004,0.990659,-0.034303,-3.766662,1.293148
5,5,10001667,22672901,10001667-DS-10,0.347112,-4.697465,-0.622722,-0.322343,2.201615,-2.085006,...,0.098884,0.952318,-2.824231,-2.408501,4.743751,-1.328005,1.648815,-0.879433,-5.251045,1.21046
6,6,10001877,21320596,10001877-DS-20,1.11463,-3.658412,-0.492278,-0.560952,-0.086992,-1.721783,...,-1.439749,0.155669,-1.656865,-1.760112,3.460104,-1.247112,1.285931,-0.810649,-3.569918,1.493029
7,7,10002013,21763296,10002013-DS-12,0.521005,-4.810385,-0.941124,-0.181864,3.671541,-2.954611,...,-0.292088,1.157761,-3.920733,0.02802,4.923773,-2.913791,1.383603,0.612273,-4.959509,1.939288
8,8,10002013,21975601,10002013-DS-6,0.027079,-4.193328,0.091312,-0.366392,0.889691,-2.177943,...,-1.305785,0.147394,-3.476484,-0.990173,3.157353,-2.544333,1.556366,-0.295789,-4.404293,0.884647
9,9,10002013,23581541,10002013-DS-7,0.13532,-4.190922,-0.571324,-0.607218,1.851192,-2.291935,...,-1.302045,0.946209,-2.788171,-0.716852,4.532771,-2.562863,1.183258,0.088128,-4.535398,1.611459


## main code for structured data portion




In [27]:
structured_train, structured_test, train_ids, test_ids = split_by_patient_id(mimic_copy)

In [28]:
# Prepare data
drop_cols_struct = ['subject_id', 'group']
X_train_struct = structured_train.drop(columns=drop_cols_struct)
y_train_struct = structured_train['group']
X_test_struct = structured_test.drop(columns=drop_cols_struct)
y_test_struct = structured_test['group']


In [29]:
# Scale features
scaler_struct = StandardScaler()
X_train_struct_scaled = scaler_struct.fit_transform(X_train_struct)
X_test_struct_scaled = scaler_struct.transform(X_test_struct)

In [30]:
# Train logistic regression
structured_model = LogisticRegression(max_iter=1000, multi_class='multinomial', solver='lbfgs')
structured_model.fit(X_train_struct_scaled, y_train_struct)

# Evaluate
y_pred_struct = structured_model.predict(X_test_struct_scaled)
print("Accuracy:", accuracy_score(y_test_struct, y_pred_struct))
print("Classification Report:\n", classification_report(y_test_struct, y_pred_struct, target_names=["OSA", "OSA + HF", "HF"]))
print("Confusion Matrix:\n", confusion_matrix(y_test_struct, y_pred_struct))

Accuracy: 0.7986341823551126
Classification Report:
               precision    recall  f1-score   support

         OSA       0.85      0.81      0.83      3758
    OSA + HF       0.46      0.26      0.33      1142
          HF       0.80      0.89      0.85      5936

    accuracy                           0.80     10836
   macro avg       0.70      0.66      0.67     10836
weighted avg       0.78      0.80      0.79     10836

Confusion Matrix:
 [[3040  102  616]
 [ 159  302  681]
 [ 366  258 5312]]


## code for text embeddings

In [31]:
mimic_copy.head(10)

Unnamed: 0,subject_id,hadm_id,group,hospital_expire_flag,los,age,charlson,bmi,"Calculated Bicarbonate, Whole Blood",Calculated Total CO2,...,insurance_Medicare,insurance_Other,admission_type_DIRECT EMER.,admission_type_DIRECT OBSERVATION,admission_type_ELECTIVE,admission_type_EU OBSERVATION,admission_type_EW EMER.,admission_type_OBSERVATION ADMIT,admission_type_SURGICAL SAME DAY ADMISSION,admission_type_URGENT
0,12300663,20000588,0,0,0.073929,0.783685,0.78761,0.607774,0.0,0.0,...,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,16679562,20001395,2,0,0.398989,0.799803,0.683011,0.522518,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2,11389019,20001863,0,0,0.157144,0.342316,0.0,0.773812,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
3,19476485,20002267,2,0,0.355609,0.791799,0.639151,0.486267,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
4,17429492,20002493,0,0,0.448281,0.775458,0.683011,0.566813,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
5,13390157,20002497,1,0,0.434407,0.507205,0.528634,0.866653,0.0,0.27704,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
6,18303950,20002634,0,0,0.169115,0.705135,0.528634,0.793241,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
7,12527107,20002636,0,0,0.402574,0.635849,0.360849,0.520669,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
8,15497465,20002661,2,0,0.191351,0.908377,0.683011,0.206134,0.0,0.0,...,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
9,10160622,20002800,2,0,0.122597,0.783685,0.756304,0.499905,0.0,0.0,...,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


In [32]:
text_embeddings['subject_id'] = text_embeddings['subject_id'].astype(int)
mimic_copy['subject_id'] = mimic_copy['subject_id'].astype(int)

In [33]:
print("Text Embeddings Columns:")
print(text_embeddings.columns.tolist())

# Check for subject_id and embedding dimension columns 

Text Embeddings Columns:
['Unnamed: 0', 'subject_id', 'hadm_id', 'note_id', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147'

In [34]:
# Merge to bring in group label into the text embeddings
group_info = mimic_copy[['subject_id', 'group']].drop_duplicates()
text_embeddings_merged = text_embeddings.drop(columns=['group'], errors='ignore')  # remove if it exists
text_embeddings_merged = text_embeddings_merged.merge(group_info, on='subject_id', how='inner')


In [35]:
unstructured_train, unstructured_test, train_ids, test_ids = split_by_patient_id(text_embeddings_merged)


In [36]:
# Prepare features and labels
drop_cols_text = ['Unnamed: 0', 'subject_id', 'hadm_id', 'note_id', 'group']
X_train_text = unstructured_train.drop(columns=drop_cols_text)
y_train_text = unstructured_train['group']
X_test_text = unstructured_test.drop(columns=drop_cols_text)
y_test_text = unstructured_test['group']

In [37]:
# Scale features
scaler_text = StandardScaler()
X_train_text_scaled = scaler_text.fit_transform(X_train_text)
X_test_text_scaled = scaler_text.transform(X_test_text)

In [38]:
# Train logistic regression with class weighting
unstructured_model = LogisticRegression(
    max_iter=2500,
    multi_class='multinomial',
    solver='lbfgs',
    class_weight='balanced'  # ← This is the key addition
)
unstructured_model.fit(X_train_text_scaled, y_train_text)

# Evaluate
y_pred_text = unstructured_model.predict(X_test_text_scaled)
print("Accuracy:", accuracy_score(y_test_text, y_pred_text))
print("Classification Report:\n", classification_report(y_test_text, y_pred_text, target_names=["OSA", "OSA + HF", "HF"]))
print("Confusion Matrix:\n", confusion_matrix(y_test_text, y_pred_text))


Accuracy: 0.5713310114047668
Classification Report:
               precision    recall  f1-score   support

         OSA       0.62      0.64      0.63      4936
    OSA + HF       0.28      0.37      0.32      2580
          HF       0.70      0.59      0.64      7127

    accuracy                           0.57     14643
   macro avg       0.53      0.54      0.53     14643
weighted avg       0.60      0.57      0.58     14643

Confusion Matrix:
 [[3172  876  888]
 [ 717  955  908]
 [1262 1626 4239]]


### combine results and get classification report

In [None]:
# Find shared subject_ids
shared_ids = set(structured_test['subject_id']) & set(unstructured_test['subject_id'])

# Filter both test sets
structured_filtered = structured_test[structured_test['subject_id'].isin(shared_ids)].copy()
unstructured_filtered = unstructured_test[unstructured_test['subject_id'].isin(shared_ids)].copy()

In [None]:
#  Deduplicate by subject_id (take first occurrence)
structured_filtered = structured_filtered.drop_duplicates(subset='subject_id')
unstructured_filtered = unstructured_filtered.drop_duplicates(subset='subject_id')

# Sort both by subject_id to align rows
structured_filtered = structured_filtered.sort_values('subject_id')
unstructured_filtered = unstructured_filtered.sort_values('subject_id')

In [None]:
# Extract x and y 
X_test_struct = structured_filtered.drop(columns=drop_cols_struct)
X_test_text = unstructured_filtered.drop(columns=drop_cols_text)
y_test_struct = structured_filtered['group']  # can also use unstructured_filtered['group']

In [39]:

# transform and predict
X_test_struct_scaled = scaler_struct.transform(X_test_struct)
X_test_text_scaled = scaler_text.transform(X_test_text)

probs_structured = structured_model.predict_proba(X_test_struct_scaled)
probs_unstructured = unstructured_model.predict_proba(X_test_text_scaled)

# average and predict
fused_probs = (probs_structured + probs_unstructured) / 2
final_preds = fused_probs.argmax(axis=1)

# evaluation metrics / report
print("\n Late Fusion Model (Structured + Unstructured) ")
print("Fused Accuracy:", accuracy_score(y_test_struct, final_preds))
print("Fused Classification Report:\n", classification_report(y_test_struct, final_preds, target_names=["OSA", "OSA + HF", "HF"]))
print("Fused Confusion Matrix:\n", confusion_matrix(y_test_struct, final_preds))



=== Late Fusion Model (Structured + Unstructured) ===
Fused Accuracy: 0.8342792281498297
Fused Classification Report:
               precision    recall  f1-score   support

         OSA       0.89      0.89      0.89       384
    OSA + HF       0.38      0.26      0.31        58
          HF       0.83      0.86      0.85       439

    accuracy                           0.83       881
   macro avg       0.70      0.67      0.68       881
weighted avg       0.82      0.83      0.83       881

Fused Confusion Matrix:
 [[341   5  38]
 [  4  15  39]
 [ 40  20 379]]


In [None]:
print(mimic['group'].value_counts())
