In [None]:
import logging
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
from collections import Counter
import random
import pickle as pkl
import os
from sklearn.cluster import KMeans

## Creating patient similarity graphs

The patient similarity graph used by the GNN is built from a cosine similarity matrix, computed between patient feature vectors that are derived from medical codes.  
At the end, we attach these feature vectors to the nodes as node features and reindex the graph (node IDs from 0 to n-1) to avoid indexing bugs in the GNNs.  
First, we will create a single graph for our transductive learning setting.  

In [None]:
# Load the per-patient table; the CSV carries a saved row index as
# 'Unnamed: 0', which is discarded right after reading.
path = r"PATH/patient_emb_df.csv"
patient_data = pd.read_csv(path).drop(columns=['Unnamed: 0'])

patient_data

In [None]:
# Load the embedding columns and discard the saved row index ('Unnamed: 0').
# NOTE(review): this path starts with "C:PATH" while every other path in the
# notebook starts with "PATH" — confirm which prefix is intended.
# Presumably one row per patient, in the same order as patient_data — verify.
embeddings_df = pd.read_csv(r"C:PATH/patient_emb_cols.csv") 
embeddings_df = embeddings_df.drop(columns=['Unnamed: 0'])
embeddings_df

In [None]:
# Attach the embedding columns to the patient table.
# DataFrame.join aligns rows on the index, so frames of different length (or
# with different indexes) would silently yield NaN-filled rows; fail loudly.
assert patient_data.index.equals(embeddings_df.index), (
    "patient_data and embeddings_df must share the same index before joining."
)
patient_data = patient_data.join(embeddings_df)
patient_data

In [None]:
# Remove the 'patient_embedding' column from the merged table.
# Note: this cell is not idempotent — running it twice raises KeyError because
# the column is already gone.
patient_data = patient_data.drop(columns=["patient_embedding"])
patient_data

In [None]:
# Persist the merged patient table. The row index is written as well (pandas
# default), so reading this file back yields an extra 'Unnamed: 0' column.
patient_data.to_csv(r"PATH/patient_data.csv")

#### Loading the patient similarity matrix
Based on the cosine similarity among patient vectors, dim=300.

In [None]:
def generate_similarity_graph(k, sim_matrix, patient_df):
    """Build an undirected k-nearest-neighbour graph from a similarity matrix.

    Every patient becomes a node. For each patient, edges are added to its
    ``k`` most similar *other* patients, weighted by the similarity score.
    Because the graph is undirected, a node can end up with more than ``k``
    neighbours (it may be in the top-k of many other patients).

    Args:
        k: Number of neighbours to connect each patient to (must be >= 0).
        sim_matrix: Square (n x n) similarity matrix, as a DataFrame or
            array-like; row/column ``i`` must correspond to row ``i`` of
            ``patient_df``.
        patient_df: DataFrame with a ``patient_id`` column, one row per patient.

    Returns:
        networkx.Graph with ``patient_id`` values as nodes and a ``weight``
        attribute on every edge.

    Raises:
        AssertionError: if the matrix is not square or its size does not match
            ``patient_df``.
        ValueError: if ``k`` is negative.
    """
    logging.info('Generating similarity graph...')
    G_knn = nx.Graph()

    # ensure the patient_df and sim_matrix are of the same length
    assert len(patient_df) == len(sim_matrix), "The patient df and similarity matrix must have the same number of rows."
    if k < 0:
        raise ValueError("k must be a non-negative integer.")

    # Work on a plain float array: positional lookups are much cheaper than
    # DataFrame.iloc inside the loop, and NumPy sorting handles NaN predictably.
    sims = np.asarray(sim_matrix, dtype=float)
    assert sims.ndim == 2 and sims.shape[0] == sims.shape[1], "The similarity matrix must be square."

    patient_ids = patient_df['patient_id'].values
    G_knn.add_nodes_from(patient_ids) # patients as nodes

    n_patients = len(patient_ids)
    n_neighbors = min(k, n_patients - 1)  # a patient cannot have more than n-1 neighbours

    for i, patient_i in tqdm(enumerate(patient_ids), total=n_patients, desc='Generating graph'):
        if n_neighbors <= 0:
            break

        scores = sims[i].copy()
        # Exclude the patient itself by position. Relying on the self-similarity
        # being the last element after sorting breaks when other entries tie
        # with it (or when the diagonal is not the row maximum).
        scores[i] = -np.inf
        # Missing similarities must never be selected as neighbours.
        scores[np.isnan(scores)] = -np.inf

        # indices of the n_neighbors highest scores, most similar first
        top_k_neighbors_indices = np.argsort(scores, kind='stable')[-n_neighbors:][::-1]

        for j in top_k_neighbors_indices:
            if not np.isfinite(scores[j]):
                # only excluded/missing entries remain from here on
                break
            patient_j = patient_ids[j]
            # duplicate patient_ids in patient_df would otherwise create self-loops
            if patient_i != patient_j:
                G_knn.add_edge(patient_i, patient_j, weight=sims[i, j])

    logging.info('Graph generation complete.')
    return G_knn

In [None]:
# Load the cosine-similarity matrix, discard the saved row index and clip
# negative similarities to 0 (they are not needed for the graph).
similarity_matrix_cos = (
    pd.read_csv(r"PATH/cosine_sim_df.csv")
    .drop(columns=['Unnamed: 0'])
    .clip(lower=0)
)
similarity_matrix_cos.head()

In [None]:
# we should get a square matrix, where the diagonal = 1 (patient is 100% similar to itself) and the rest are sim measurements
# The squareness is asserted so a malformed file is caught here, not deep inside graph generation.
assert similarity_matrix_cos.shape[0] == similarity_matrix_cos.shape[1], (
    f"Similarity matrix is not square: {similarity_matrix_cos.shape}"
)
similarity_matrix_cos.shape

In [None]:
# Summary statistics of the similarity matrix (NaN entries are ignored).
# Note: the diagonal (self-similarity) is included in these numbers.
similarity_matrix = similarity_matrix_cos.to_numpy()

mean_similarity = np.nanmean(similarity_matrix)
median_similarity = np.nanmedian(similarity_matrix)
std_similarity = np.nanstd(similarity_matrix)

print(f"Mean similarity: {mean_similarity}")
print(f"Median similarity: {median_similarity}")
print(f"Standard deviation: {std_similarity}")

In [None]:
# calculate quartiles
# NaN-aware variants are used so that a single missing similarity does not turn
# every quartile into NaN.
q1, q2, q3 = np.nanpercentile(similarity_matrix, [25, 50, 75])
q4 = np.nanmax(similarity_matrix)

print(f"1st quartile: {q1}")
print(f"2nd quartile (median): {q2}")
print(f"3rd quartile: {q3}")
print(f"4th quartile (max): {q4}")

In [None]:
# histogram of sim scores
# Flatten once and reuse the result: the matrix is n x n, so a second
# flatten() would allocate another full copy for nothing.
flattened = similarity_matrix.flatten()
plt.hist(flattened, bins=30, edgecolor='k', alpha=0.7)
plt.title('Histogram of similarity scores')
plt.xlabel('Similarity score')
plt.ylabel('Frequency')
plt.show()

In [None]:
# Zoom into the 0.75-0.9 band of the similarity distribution.
fig, ax = plt.subplots()
ax.hist(flattened, bins=30, range=(0.75, 0.9), edgecolor='k', alpha=0.7)
ax.set_title('Similarity scores (0.75 to 0.9)')
ax.set_xlabel('Similarity score')
ax.set_ylabel('Frequency')
plt.show()

### Generating the similarity graph
Now that we have our similarity data, let's generate the transductive training similarity graph.  
We will use the calculated cosine similarity to measure how similar patients are, and a k-nearest-neighbours (KNN) rule to connect each patient to its top k most similar patients. K will be decided with the elbow method.  
The edges will be weighted by the similarity score (note: not all GNN architectures make use of edge weights).  

In [None]:
import plotly.express as px

def plot_elbow_curve(sim_matrix, max_k=10):
    """Plot the KMeans distortion (inertia) for 1..max_k clusters.

    KMeans is fitted on the rows of ``sim_matrix`` once per cluster count and
    the resulting inertia values are shown as an interactive plotly line plot.

    NOTE(review): the K on this plot is a number of KMeans clusters; it is not
    the same quantity as the neighbour count ``k`` passed to
    ``generate_similarity_graph``.

    Args:
        sim_matrix: DataFrame or 2-D array; each row is treated as one sample.
        max_k: Largest number of clusters to try (inclusive).

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    samples = sim_matrix.values if isinstance(sim_matrix, pd.DataFrame) else sim_matrix

    cluster_counts = range(1, max_k + 1)
    # one KMeans fit per candidate cluster count; inertia_ is the distortion
    distortions = [
        KMeans(n_clusters=n_clusters, random_state=22, n_init='auto').fit(samples).inertia_
        for n_clusters in cluster_counts
    ]

    elbow_data = pd.DataFrame({
        "Number of clusters (K)": cluster_counts,
        "Distortion": distortions
    })

    fig = px.scatter(
        elbow_data,
        x="Number of clusters (K)",
        y="Distortion",
        title="Elbow plot: patient similarity graph",
        labels={
            "Number of clusters (K)": "Number of clusters (K)",
            "Distortion": "Distortion"
        },
    )

    # connect the markers so the elbow is visible
    fig.update_traces(mode="lines+markers", line=dict(color='#636EFA'))
    fig.update_layout(
        xaxis=dict(tickmode="linear", tick0=1, dtick=1),
        width=900,
        height=700,
    )

    fig.show()

In [None]:
# Elbow plot on the clipped cosine-similarity matrix, trying 1 to 10 clusters.
plot_elbow_curve(similarity_matrix_cos, max_k=10)

In [None]:
# creating the graph with k = 3
# (each patient is linked to its 3 most similar patients; pass patient_data_merged
# instead of patient_data to use the alternative merged table)
G_knn = generate_similarity_graph(3, similarity_matrix_cos, patient_data) #patient_data_merged

### Evaluating the generated graph

In [None]:
# Basic size statistics of the kNN graph.
n_nodes = G_knn.number_of_nodes()
n_edges = G_knn.number_of_edges()
total_degree = sum(degree for _, degree in G_knn.degree())

print(f"Number of nodes: {n_nodes}")
print(f"Number of edges: {n_edges}")
print(f"Average degree: {total_degree / n_nodes}")

In [None]:
# counting occurrences of each degree value
degree_sequence = sorted((degree for _, degree in G_knn.degree()), reverse=True)
degree_count = Counter(degree_sequence)
# one row per distinct degree, ordered by increasing degree
degree_df = pd.DataFrame(sorted(degree_count.items()), columns=['Degree', 'Count'])

In [None]:
# Bar chart: how many nodes have each degree.
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(degree_df['Degree'], degree_df['Count'], color='skyblue', edgecolor='black')
ax.set_title('Histogram of node degrees')
ax.set_xlabel('Degree')
ax.set_ylabel('No. of nodes')
ax.grid(axis='y')
plt.show()

In [None]:
# Sanity check: the kNN construction should not have produced any self-loops.
num_self_loops = nx.number_of_selfloops(G_knn)
print(f"self-loops in the graph: {num_self_loops}")

In [None]:
import matplotlib.patches as mpatches
# attention - run again after loading node features to see label colors
def bfs_sample(graph, start_node, num_nodes):
    """Collect up to ``num_nodes`` nodes around ``start_node`` by breadth-first search.

    Args:
        graph: Object exposing ``neighbors(node)`` (e.g. a networkx graph).
        start_node: Node the search starts from; it is always part of the
            sample when ``num_nodes`` >= 1.
        num_nodes: Maximum number of nodes to return.

    Returns:
        List of sampled nodes in the order they were discovered. Fewer than
        ``num_nodes`` nodes are returned when the connected component of
        ``start_node`` is smaller than that.
    """
    from collections import deque  # local import: only Counter is imported at file level

    if num_nodes <= 0:
        return []

    visited = {start_node}
    order = [start_node]
    # deque gives O(1) pops from the left; list.pop(0) is O(n)
    queue = deque([start_node])

    while queue and len(order) < num_nodes:
        node = queue.popleft()
        # Iterate neighbours in the graph's own order (not set order) so the
        # sample is reproducible; mark on discovery so nothing is queued twice.
        for neighbor in graph.neighbors(node):
            if neighbor in visited:
                continue
            visited.add(neighbor)
            order.append(neighbor)
            queue.append(neighbor)
            if len(order) == num_nodes:
                break

    return order

# Seed the RNG before drawing so the start node is reproducible for a given node list.
random.seed(22)
start_node = random.choice(list(G_knn.nodes()))
# Collect up to 200 nodes around the start node by breadth-first search.
sample_nodes = bfs_sample(G_knn, start_node, 200)

# Subgraph restricted to the sampled nodes (used only for plotting).
subG = G_knn.subgraph(sample_nodes)

def color_map(graph):
    """Return one colour per node: blue for label 0 (or no label), yellow otherwise.

    Nodes without a 'label' attribute are treated as label 0.0, so the node
    features/labels must be added before the colours are meaningful.
    Note: a later cell redefines ``color_map`` with an extra optional argument.
    """
    colors = []
    for node in graph.nodes():
        label = graph.nodes[node].get('label', 0.0)
        if label == 0.0:
            colors.append('#636EFA')
        else:
            colors.append('#e8e337')
    return colors

# Colour each sampled node by its label and draw the sampled subgraph.
node_colors = color_map(subG)
plt.figure(figsize=(12, 12))
pos = nx.kamada_kawai_layout(subG) # or spectral_layout, spring_layout, circular_layout, shell_layout
nx.draw(subG, pos, with_labels=False, node_size=150, node_color=node_colors)

# Manual legend; the colours must stay in sync with the ones returned by color_map.
blue_patch = mpatches.Patch(color='#636EFA', label='Negative (0)')
yellow_patch = mpatches.Patch(color='#e8e337', label='Positive (1)')
plt.legend(handles=[blue_patch, yellow_patch])
plt.title("Patient similarity: subgraph visualization")
plt.show()

### Interpretability: visualizing node neighborhood

In [None]:
def color_map(graph, highlight_node=None):
    """Return one colour per node, with an optional highlighted node.

    The highlighted node is orange; other nodes are blue when their 'label'
    attribute is 0 (or missing) and yellow otherwise.
    This definition replaces the one-argument ``color_map`` defined earlier.
    """
    def pick_color(node):
        if node == highlight_node:
            return '#FF5733'
        is_negative = graph.nodes[node].get('label', 0.0) == 0.0
        return '#636EFA' if is_negative else '#e8e337'

    return [pick_color(node) for node in graph.nodes()]

def show_node_neighborhood(graph, patient_id):
    """Plot the 2-hop neighbourhood of ``patient_id``.

    The target node is drawn larger with a red outline, its immediate
    neighbours get a dashed circle, and all nodes are coloured by their
    'label' attribute (missing labels count as 0).

    Args:
        graph: networkx graph whose nodes may carry a 'label' attribute.
        patient_id: Node to centre the plot on.

    Returns:
        None. Prints a message and returns early if the node is not in the graph.
    """
    if patient_id not in graph:
        print(f"Node {patient_id} not found in the graph.")
        return

    immediate_neighbors = list(graph.neighbors(patient_id))

    # target + 1-hop + 2-hop nodes; a set removes the duplicates directly
    neighborhood_nodes = {patient_id, *immediate_neighbors}
    for neighbor in immediate_neighbors:
        neighborhood_nodes.update(graph.neighbors(neighbor))

    neighborhood_subG = graph.subgraph(neighborhood_nodes)

    node_colors = color_map(neighborhood_subG, highlight_node=patient_id)

    plt.figure(figsize=(10, 10))
    pos = nx.shell_layout(neighborhood_subG)

    nx.draw(neighborhood_subG, pos, with_labels=True, node_size=300, node_color=node_colors, edge_color='#BBBBBB')

    # target node: redrawn on top in its label colour, with a thick red outline
    node_color = '#636EFA' if graph.nodes[patient_id].get('label', 0.0) == 0.0 else '#e8e337'
    nx.draw_networkx_nodes(neighborhood_subG, pos, nodelist=[patient_id], node_size=1000,
                           node_color=node_color, edgecolors='red', linewidths=5)

    # Legend covers the label colours only. The unused 'Target node' patch that
    # was built here but never passed to the legend has been removed.
    blue_patch = mpatches.Patch(color='#636EFA', label='Negative (0)')
    yellow_patch = mpatches.Patch(color='#e8e337', label='Positive (1)')
    plt.legend(handles=[blue_patch, yellow_patch], title="True label")

    # dashed circle around immediate neighbors
    for neighbor in immediate_neighbors:
        circle = plt.Circle(pos[neighbor], radius=0.1, color='gray', fill=False, linestyle='dashed')
        plt.gca().add_patch(circle)

    # NOTE(review): the title always says "FP patient", whatever node is passed in.
    plt.title(f"Neighborhood of FP patient, ID = {patient_id}")
    plt.show()

# Example: plot the neighbourhood of one hand-picked patient ID.
patient_id = 11745
show_node_neighborhood(G_knn, patient_id)


### Adding node features
For each node, let's add feature vectors - the same feature vector we used to build the graph. With this, we will then run the GNN for HF prediction in a transductive setting.

In [None]:
def add_features(label_df, graph):
    """Attach a feature vector and a label to every graph node listed in ``label_df``.

    For each row, the node named by ``patient_id`` receives:
      * ``features``: the row without the 'patient_id' and 'label' entries
        (stored as a pandas Series),
      * ``label``: the row's 'label' value.
    Rows whose patient_id is not a node of the graph are skipped with a warning.

    Args:
        label_df: DataFrame with 'patient_id' and 'label' columns; every other
            column is treated as a feature.
        graph: networkx graph to annotate (modified in place).

    Returns:
        The same ``graph`` object, for convenience.
    """
    logging.info('Adding node features...')

    progress = tqdm(label_df.iterrows(), total=label_df.shape[0], desc='Processing features')
    for _, record in progress:
        patient_id = record['patient_id']

        if not graph.has_node(patient_id):
            logging.warning(f"Warning: patient_id {patient_id} does not exist in the graph.")
            continue

        node_attrs = graph.nodes[patient_id]
        # everything except the identifier and the target is a feature
        node_attrs['features'] = record.drop(['patient_id', 'label'])
        node_attrs['label'] = record['label']

    logging.info('Node features added.')
    return graph

In [None]:
# Compare the patient IDs in the graph with those in patient_data.
graph_patient_ids = set(G_knn.nodes())
df_patient_ids = set(patient_data['patient_id'])

# patient IDs present in patient_data but absent from the graph
missing_in_graph = df_patient_ids.difference(graph_patient_ids)

# patient IDs present in the graph but absent from patient_data
extra_in_graph = graph_patient_ids.difference(df_patient_ids)

print("Missing in graph:", len(missing_in_graph))
print("Extra in graph:", len(extra_in_graph))

In [None]:
# for ablation/feature studies - replace node features
# Every other CSV in this notebook carries a saved row index as 'Unnamed: 0'.
# Drop it here too (if present) so it can never end up as a node feature.
patient_data_diag = pd.read_csv(r"PATH/patient_emb_cols_ablation-diag.csv").drop(columns=['Unnamed: 0'], errors='ignore')
patient_data_proc = pd.read_csv(r"PATH/patient_emb_cols_ablation-proc.csv").drop(columns=['Unnamed: 0'], errors='ignore')
patient_data_pres = pd.read_csv(r"PATH/patient_emb_cols_ablation-pres.csv").drop(columns=['Unnamed: 0'], errors='ignore')

patient_data_demo = pd.read_csv(r"PATH/patient_emb_cols_ablation-demo.csv").drop(columns=['Unnamed: 0'], errors='ignore')

patient_data_without_diag = pd.read_csv(r"PATH/patient_emb_cols_ablation-without_diag.csv").drop(columns=['Unnamed: 0'], errors='ignore')
patient_data_without_proc = pd.read_csv(r"PATH/patient_emb_cols_ablation-without_proc.csv").drop(columns=['Unnamed: 0'], errors='ignore')
patient_data_without_pres = pd.read_csv(r"PATH/patient_emb_cols_ablation-without_pres.csv").drop(columns=['Unnamed: 0'], errors='ignore')

In [None]:
# adding node features
G_knn = add_features(patient_data, G_knn)  # pass patient_data_merged instead to use the other one-hot features

# For ablation runs, replace the first argument above with one of the ablation frames:
# patient_data_diag, patient_data_proc, patient_data_pres, patient_data_demo (demographics).

### Inspecting the graph with features

In [None]:
# check positive and negative labeled nodes
positive_nodes = [node for node, data in G_knn.nodes(data=True) if data['label'] == 1]
negative_nodes = [node for node, data in G_knn.nodes(data=True) if data['label'] == 0]

print(f"No. positive HF patients: {len(positive_nodes)}")
print(f"No. negative HF patients: {len(negative_nodes)}")

In [None]:
# print a node's attributes
def inspect_node(node):
    data = G_knn.nodes[node]
    neighbors = list(G_knn.neighbors(node))

    edge_weights = [G_knn[node][neighbor]['weight'] for neighbor in neighbors]
    
    return {
        "Node/Patient ID": node,
        "Label": data['label'],
        "Features": data['features'],
        "Neighbors": neighbors,
        "Edge weights": edge_weights,
        "Degrees": len(neighbors)
    }

In [None]:
# evaluating random positive and negative samples
sample_pos_node = positive_nodes[0]
sample_neg_node = negative_nodes[0]

inspect_node(sample_pos_node), inspect_node(sample_neg_node)

In [None]:
# edge weights distribution
edge_weights = [data['weight'] for _, _, data in G_knn.edges(data=True)]
plt.figure(figsize=(10,6))
plt.hist(edge_weights, bins=50, edgecolor='black', alpha=0.7)
plt.title("Distribution of edge weights")
plt.xlabel("Edge weight (similarity score)")
plt.ylabel("No. of edges")
plt.grid(True)
plt.show()

In [None]:
# check if there are any isolated nodes
isolated_nodes = [node for node, degree in G_knn.degree() if degree == 0]
print(f"No. isolated nodes: {len(isolated_nodes)}")

### Reindexing and saving the final graph
We have to reindex the final graph so we won't have any problems with the GNN processing.

In [None]:
def reindex_with_patient_id(G):
    """Return a copy of ``G`` whose nodes are relabelled 0..n-1.

    Nodes are numbered in the iteration order of ``G.nodes()``. The original
    node name is kept in the ``patient_id`` node attribute; all other node and
    edge attributes are copied over.

    If a node already carries a ``patient_id`` attribute (e.g. the graph was
    reindexed before), that value is kept. Passing it alongside the explicit
    ``patient_id=`` keyword would otherwise raise a TypeError.

    Args:
        G: Graph to reindex. It is not modified.

    Returns:
        A new undirected ``nx.Graph`` with integer node IDs.
    """
    node_mapping = {node: i for i, node in enumerate(G.nodes())}

    G_reindexed = nx.Graph()

    # reindex nodes and transfer node attributes
    for node, data in G.nodes(data=True):
        # an existing 'patient_id' in `data` takes precedence over the node name
        node_attrs = {'patient_id': node, **data}
        G_reindexed.add_node(node_mapping[node], **node_attrs)

    # reindex edges and transfer edge attributes
    for u, v, data in G.edges(data=True):
        G_reindexed.add_edge(node_mapping[u], node_mapping[v], **data)

    return G_reindexed

def check_patient_id_consistency(original_graph, reindexed_graph):
    """Check that node ``i`` of the reindexed graph stores the i-th original node as ``patient_id``.

    Args:
        original_graph: Graph before reindexing.
        reindexed_graph: Graph returned by ``reindex_with_patient_id``.

    Returns:
        True if every original node matches; False at the first mismatch,
        which is also printed. A node missing from the reindexed graph counts
        as a mismatch.
    """
    # mapping: reindexed node -> stored patient_id
    reindexed_to_patient_id = nx.get_node_attributes(reindexed_graph, 'patient_id')

    # enumerate() yields the position of each node directly; calling
    # list(nodes).index(node) for every node would be quadratic in graph size
    for reindexed_node, original_node in enumerate(original_graph.nodes()):
        stored_patient_id = reindexed_to_patient_id.get(reindexed_node)
        if reindexed_node not in reindexed_to_patient_id or stored_patient_id != original_node:
            print(f"Mismatch found: Node {reindexed_node} in reindexed graph should have patient_id {original_node}, but has {stored_patient_id}")
            return False

    print("All nodes in the reindexed graph have the correct patient_id.")
    return True


In [None]:
# Relabel the nodes to 0..n-1; the original patient IDs are kept as a node attribute.
G_knn_reindexed = reindex_with_patient_id(G_knn)

In [None]:
# number_of_nodes is a method: without the call the cell only displays the
# bound method object instead of the node count.
G_knn_reindexed.number_of_nodes()

In [None]:
# Verify that every reindexed node still points at its original patient ID.
is_consistent = check_patient_id_consistency(G_knn, G_knn_reindexed)
print(f"Consistency check result: {is_consistent}")

In [None]:
# Normalise every edge to (low, high) order and count the distinct pairs; also
# show whether the graph is directed.
canonical_edges = {tuple(sorted((u, v))) for u, v in G_knn_reindexed.edges()}
len(canonical_edges), G_knn_reindexed.is_directed()

In [None]:
# saving final graph, with features and reindexed
graph_path = r"PATH/medicalcodegraph.pkl"

# Previously an existing pickle was always loaded here, silently replacing the
# graph that had just been built (e.g. after switching ablation features).
# Loading is now opt-in; by default the freshly built graph is written to disk.
LOAD_CACHED_GRAPH = False

if LOAD_CACHED_GRAPH and os.path.exists(graph_path):
    # NOTE: pickle can execute arbitrary code on load; only open files you created yourself.
    with open(graph_path, 'rb') as f:
        G_knn_reindexed = pkl.load(f)
else:
    with open(graph_path, 'wb') as f:
        pkl.dump(G_knn_reindexed, f)