In [None]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt

import networkx as nx
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader

import ast
from itertools import chain
import gc

In [None]:
# Display option for the column-width limit, set to make DataFrames that hold long strings easier to read
pd.options.display.max_colwidth = 0

# Load data

In [None]:
# Load the ICD-9 (first 3 characters) -> symptoms lookup that was built from ChatGPT responses.
# Both columns are read as plain strings, so codes keep any leading zeros.
temp_fp = "icd9_symptom_map_v2.csv"
icd9_symptoms_map = pd.read_csv(
    temp_fp,
    dtype={"icd9_first_3": str, "symptoms": str},
)

In [None]:
# Some ChatGPT responses included some extra preamble and formatting - remove this for ease of analysis.
# Keep only the text after the first "Symptoms: " marker. A response without the marker is kept
# whole instead of raising IndexError (as `.split(...)[1]` would). Surrounding whitespace/newlines
# are stripped so they don't end up glued to the last symptom name. Missing values stay NaN.
icd9_symptoms_map["symptoms_clean"] = (
    icd9_symptoms_map["symptoms"]
    .str.split("Symptoms: ", n=1)
    .str[-1]
    .str.strip()
)

In [None]:
# Turn each cleaned, ", "-separated response into a list of lower-cased symptom strings
icd9_symptoms_map["symptoms_list"] = icd9_symptoms_map["symptoms_clean"].apply(
    lambda text: text.lower().split(", ")
)

In [None]:
# Load the symptoms that were extracted from each admission's clinical notes
temp_fp = "notes_and_symptoms.csv"
notes_and_symptoms = pd.read_csv(filepath_or_buffer=temp_fp)

In [None]:
# Keep a single record per admission: the first row seen for each HADM_ID
notes_and_symptoms = notes_and_symptoms.drop_duplicates(subset=["HADM_ID"], keep="first")

In [None]:
# The column stores a Python literal as text (presumably a list such as "['fever', 'cough']");
# ast.literal_eval parses literals only, so unlike eval() it cannot execute code from the file.
notes_and_symptoms["symptoms"] = notes_and_symptoms["symptoms"].apply(ast.literal_eval)

In [None]:
# Get unique symptoms for each admission.
# dict.fromkeys removes duplicates while keeping first-appearance order. list(set(...)) produces an
# order that changes between kernel sessions, because Python randomizes string hashing per process.
notes_and_symptoms["symptoms_unique"] = notes_and_symptoms["symptoms"].apply(
    lambda symptoms: list(dict.fromkeys(symptoms))
)

In [None]:
# Get diagnoses.
# Read ICD9_CODE as text. ICD-9 codes are identifiers that may have leading zeros (e.g. "0389") or
# letter prefixes (e.g. "V3001"). With default type inference, pandas may parse purely numeric codes
# as integers. That drops the leading zeros and breaks the 3-character slicing done below.
diagnoses = pd.read_csv("DIAGNOSES_ICD.csv.gz", dtype={"ICD9_CODE": str})

In [None]:
# For simplicity, keep only the first-sequenced diagnosis (SEQ_NUM == 1) of each admission,
# together with just the admission ID and the ICD9 code
diagnoses = diagnoses.loc[diagnoses["SEQ_NUM"] == 1, ["HADM_ID", "ICD9_CODE"]]

In [None]:
# Get the first 3 characters of each ICD9 code.
# `diagnoses` was produced by row/column selection in the previous cell. Inserting a column into it
# directly triggers pandas' SettingWithCopyWarning; .assign returns a new frame instead.
# Slicing stays strict: a non-string code raises rather than silently becoming NaN.
diagnoses = diagnoses.assign(
    icd9_first_3=diagnoses["ICD9_CODE"].apply(lambda code: code[:3])
)

In [None]:
# Inner-join primary diagnoses with the note-derived symptoms; only admissions present in both remain
merge_df = diagnoses.merge(notes_and_symptoms, how="inner", on="HADM_ID")

In [None]:
# Keep only the map rows whose 3-character ICD9 code actually occurs in the merged admissions
icd9_symptoms_map = icd9_symptoms_map.loc[
    lambda icd_map: icd_map["icd9_first_3"].isin(merge_df["icd9_first_3"])
]

In [None]:
# Conversely, keep only admissions whose diagnosis has a symptom lookup in the (filtered) map,
# i.e. drop admissions whose code was not looked up
merge_df = merge_df.loc[
    lambda admissions: admissions["icd9_first_3"].isin(icd9_symptoms_map["icd9_first_3"])
]

In [None]:
# Split admissions into train and test sets.
# Both index arrays are *positional* (used with .iloc and NumPy indexing below).
train_perc = 0.5
np.random.seed(777)

n_total = merge_df.shape[0]
# Sampling from n_total draws exactly the same indices as sampling from list(range(n_total))
train_idx = np.random.choice(n_total, size=round(train_perc * n_total), replace=False)
# Every position not in train_idx, in ascending order. setdiff1d is O(n log n), whereas a
# membership test against a NumPy array inside a list comprehension is O(n^2).
test_idx = np.setdiff1d(np.arange(n_total), train_idx)

assert len(train_idx) + len(test_idx) == n_total, "train/test split does not cover every admission"

# Build graph

In [None]:
G = nx.Graph(incoming_graph_data=None)  # empty undirected graph; nodes and edges are added below

In [None]:
all_icd = icd9_symptoms_map["icd9_first_3"].values
# Union of symptoms named in the ICD9 map and symptoms extracted from admissions.
# sorted() gives a fixed order. Iterating a raw set of strings yields a different order in every
# kernel session (string hashing is randomized), which would reshuffle node indices, node features,
# and therefore the model results on each Restart & Run All.
all_symptoms = sorted(
    set(chain.from_iterable(icd9_symptoms_map["symptoms_list"]))
    | set(chain.from_iterable(merge_df["symptoms_unique"]))
)
all_admissions = merge_df["HADM_ID"].values

In [None]:
# Add nodes: one prefixed label per ICD9 code, then per symptom, then per admission
for prefix, values in (("icd_", all_icd), ("symptom_", all_symptoms)):
    G.add_nodes_from(prefix + value for value in values)
G.add_nodes_from("hadm_" + str(value) for value in all_admissions)

In [None]:
# Give every node label a consecutive integer index, following the graph's insertion order
node_mapping = {label: position for position, label in enumerate(G)}

In [None]:
# Connect every ICD9 code node to each symptom the lookup lists for it
G.add_edges_from(
    ("icd_" + code, "symptom_" + symptom)
    for code, symptoms in zip(icd9_symptoms_map["icd9_first_3"], icd9_symptoms_map["symptoms_list"])
    for symptom in symptoms
)

In [None]:
# Connect every admission node to each symptom extracted from its notes
# (a repeated symptom just re-adds an existing edge, which an undirected nx.Graph ignores)
G.add_edges_from(
    ("hadm_" + str(admission), "symptom_" + symptom)
    for admission, symptoms in zip(merge_df["HADM_ID"], merge_df["symptoms"])
    for symptom in symptoms
)

In [None]:
# Convert to PyTorch Geometric data.
# node_mapping was built before the edges were added. A node that networkx created implicitly
# while adding an edge would be missing from it (and from the identity features below).
assert len(node_mapping) == G.number_of_nodes(), "graph contains nodes missing from node_mapping"

edge_pairs = [(node_mapping[u], node_mapping[v]) for u, v in G.edges()]
# reshape(-1, 2) keeps a valid (2, num_edges) layout even when there are no edges
edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t()
# G.edges() reports each undirected edge once, but GCNConv passes messages along edge_index as
# directed source -> target pairs. Without the reverse direction, one endpoint of every edge never
# receives messages from the other.
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1).contiguous()

# Use identity matrix as node features (one-hot node IDs)
x = torch.eye(len(G.nodes()))

PyG_data = Data(x=x, edge_index=edge_index)

In [None]:
# Get node indices for ICD9 codes, in all_icd order (so row k corresponds to label class k).
# An explicit long dtype keeps the tensor usable as an index even if the list were empty.
icd_node_idx = torch.tensor([node_mapping["icd_" + code] for code in all_icd], dtype=torch.long)

In [None]:
# Get node indices for admissions - training set.
# An explicit long dtype: torch.tensor([]) would default to float32, which cannot be used to index
# (e.g. when train_perc is set to 0).
hadm_node_idx_train = torch.tensor(
    [node_mapping["hadm_" + str(admission)] for admission in all_admissions[train_idx]],
    dtype=torch.long,
)

In [None]:
# Get node indices for admissions - test set.
# An explicit long dtype: torch.tensor([]) would default to float32, which cannot be used to index
# (e.g. when train_perc is set to 1.0 and the test set is empty).
hadm_node_idx_test = torch.tensor(
    [node_mapping["hadm_" + str(admission)] for admission in all_admissions[test_idx]],
    dtype=torch.long,
)

In [None]:
# One-hot label per training admission over the fixed class order all_icd:
# a 1 in the column of its primary diagnosis, 0 everywhere else
mlb = MultiLabelBinarizer(classes=all_icd)
train_codes = merge_df["icd9_first_3"].iloc[train_idx]
train_labels = torch.tensor(mlb.fit_transform([[code] for code in train_codes]), dtype=torch.float)

In [None]:
# Create label vector for each admission - test set.
# Uses its own binarizer with the same fixed class order (all_icd) instead of the shared `mlb`
# variable, so re-running this cell gives the same columns regardless of what `mlb` holds by then.
# With `classes` fixed, fitting only records that class order; nothing is learned from test data.
test_binarizer = MultiLabelBinarizer(classes=all_icd)
test_labels = test_binarizer.fit_transform(
    [[code] for code in merge_df["icd9_first_3"].iloc[test_idx]]
)
test_labels = torch.tensor(test_labels, dtype=torch.float)

In [None]:
# Create label vector for each ICD9 node: the k-th ICD9 node's target is class k,
# so the label matrix is simply the identity
icd_node_labels = torch.eye(icd_node_idx.numel())

# Train GNN model

In [None]:
# Create simple GNN model with two graph convolutional layers
class GNN(torch.nn.Module):
    """Two-layer GCN that returns per-node log-probabilities over the label classes.

    conv1: num_features -> hidden_dim, followed by ReLU
    conv2: hidden_dim -> num_classes, followed by log_softmax over dim 1
    """

    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)

    def forward(self, data):
        # `data` must expose .x (node features) and .edge_index (graph connectivity)
        hidden = torch.relu(self.conv1(data.x, data.edge_index))
        logits = self.conv2(hidden, data.edge_index)
        return logits.log_softmax(dim=1)

In [None]:
# Initialize model.
# Seed torch so the weight initialization, and therefore every result below, is reproducible.
torch.manual_seed(777)
# One output class per ICD9 code: both the admission labels and the ICD9-node labels are
# binarized over `all_icd` (the previous reference `all_diags` is not defined anywhere).
model = GNN(num_features=PyG_data.num_features, hidden_dim=16, num_classes=len(all_icd))

In [None]:
# Set loss function and optimizer.
# NOTE: the model already returns log-probabilities. CrossEntropyLoss applies log_softmax again,
# which leaves log-probabilities unchanged, so this amounts to negative log-likelihood on the output.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
# Train model
N_EPOCHS = 10

model.train()
for epoch in range(N_EPOCHS):
    optimizer.zero_grad()
    full_output = model(PyG_data)

    # Get output for ICD9 nodes
    icd_output = full_output[icd_node_idx]

    # Get output for admission nodes
    hadm_output = full_output[hadm_node_idx_train]

    # Joint objective: classify training admissions AND have each ICD9 node predict its own code
    loss = criterion(hadm_output, train_labels) + criterion(icd_output, icd_node_labels)
    print(f"epoch {epoch}: loss = {loss.item():.4f}")
    loss.backward()
    optimizer.step()

# The outputs inside the loop were computed *before* the last optimizer.step(), so they lag the
# trained weights by one update. Recompute them once with the final weights (eval mode, no autograd
# graph) so the evaluation cells below score the trained model.
model.eval()
with torch.no_grad():
    full_output = model(PyG_data)
icd_output = full_output[icd_node_idx]
hadm_output = full_output[hadm_node_idx_train]

# Evaluate GNN model

In [None]:
# Training accuracy: share of training admissions whose top-scoring class is the true diagnosis
pred_idx = torch.argmax(hadm_output, dim=1)
gold_idx = torch.argmax(train_labels, dim=1)
train_accuracy = (pred_idx == gold_idx).sum().item() / gold_idx.numel()
print(train_accuracy)

In [None]:
# Accuracy on ICD9 nodes: share of code nodes whose top-scoring class is their own code
pred_idx = torch.argmax(icd_output, dim=1)
gold_idx = torch.argmax(icd_node_labels, dim=1)
icd_node_accuracy = (pred_idx == gold_idx).sum().item() / gold_idx.numel()
print(icd_node_accuracy)

In [None]:
# Test accuracy: same metric, computed on the held-out admission nodes
hadm_output_test = full_output[hadm_node_idx_test]
pred_idx = torch.argmax(hadm_output_test, dim=1)
gold_idx = torch.argmax(test_labels, dim=1)
test_accuracy = (pred_idx == gold_idx).sum().item() / gold_idx.numel()
print(test_accuracy)

In [None]:
# Define function to help plot ROC curves
def plot_roc(y_prob, y_actual, title=None):
    """Plot one binary ROC curve and show its AUC in the legend.

    Parameters
    ----------
    y_prob : array-like, 1-D
        Predicted scores/probabilities for the positive class.
    y_actual : array-like, 1-D
        Binary ground-truth labels (0/1), aligned element-wise with ``y_prob``.
    title : str, optional
        Figure title. When omitted, no title is set (same output as before).
    """
    fpr, tpr, _ = roc_curve(y_actual, y_prob)
    auc = roc_auc_score(y_actual, y_prob)

    # Explicit figure/axes objects rather than the implicit pyplot state machine
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, color="blue", lw=2, label=f"ROC curve (area = {auc:.4f})")
    # Diagonal reference line: a classifier that ranks at random
    ax.plot([0, 1], [0, 1], color="gray", lw=2, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    if title is not None:
        ax.set_title(title)
    ax.legend(loc="lower right")
    plt.show()

In [None]:
# ROC curve for the training set. Every (admission, class) pair counts as one binary prediction,
# i.e. a micro-averaged ROC over all classes; exp() turns log-probabilities into probabilities.
plot_roc(hadm_output.detach().exp().numpy().ravel(), train_labels.numpy().ravel())

In [None]:
# ROC curve for the test set (micro-averaged over all (admission, class) pairs)
plot_roc(hadm_output_test.detach().exp().numpy().ravel(), test_labels.numpy().ravel())

# Compare with baseline model

In [None]:
# Initialize logistic regression model.
# C is the inverse regularization strength: 0.1 regularizes more strongly than sklearn's default of 1.0.
model_logit = LogisticRegression(C=0.1, max_iter=200)

In [None]:
# Convert symptoms extracted from admission data into one-hot encoded columns (one per symptom).
# This binarizer gets its own name so it does not overwrite `mlb`, the ICD9-label binarizer defined
# earlier; rebinding `mlb` would break the label cells if they were re-run afterwards.
symptom_binarizer = MultiLabelBinarizer()
hadm_symptom_matrix = symptom_binarizer.fit_transform(merge_df["symptoms_unique"])

In [None]:
# Slice features and 3-character ICD9 targets with the same positional split used for the GNN
X_train, X_test = hadm_symptom_matrix[train_idx], hadm_symptom_matrix[test_idx]

icd_codes = merge_df["icd9_first_3"].to_numpy()
y_train, y_test = icd_codes[train_idx], icd_codes[test_idx]

In [None]:
# Fit the baseline (fit returns the estimator itself, which the notebook displays)
model_logit.fit(X=X_train, y=y_train)

In [None]:
# Mean accuracy of the baseline on the training admissions
model_logit.score(X=X_train, y=y_train)

In [None]:
# Mean accuracy of the baseline on the held-out admissions
model_logit.score(X=X_test, y=y_test)

In [None]:
# ROC curve for the baseline on the training set.
# predict_proba columns follow model_logit.classes_, so pick the matching one-hot label columns.
# NOTE: codes that never occur in y_train have no probability column and are left out here,
# so this covers a (possibly) smaller class set than the GNN ROC.
temp_idx = [np.flatnonzero(all_icd == code)[0] for code in model_logit.classes_]
plot_roc(model_logit.predict_proba(X_train).ravel(), train_labels[:, temp_idx].reshape(-1).numpy())

In [None]:
# ROC curve for the baseline on the test set (label columns aligned to model_logit.classes_;
# test codes absent from the training labels have no probability column and are left out)
temp_idx = [np.flatnonzero(all_icd == code)[0] for code in model_logit.classes_]
plot_roc(model_logit.predict_proba(X_test).ravel(), test_labels[:, temp_idx].reshape(-1).numpy())