In [1]:
import numpy as np
from scipy import sparse
import pandas as pd
import torch
import argparse
import json
import os

In [2]:
import torch
import torch_geometric
from torch_geometric.data import Data

In [3]:

# Read the project path configuration, then pick out the two directories
# this notebook uses: the HDF input directory and the graph output directory.
with open("../paths.json", "r") as f:
    paths = json.load(f)

hdf = paths['hdf_dir']
graph = paths['graph_dir']

In [4]:
# Load the diagnosis table of every split (train / val / test) from HDF5.
train_diagnoses, val_diagnoses, test_diagnoses = (
    pd.read_hdf(h5_path, key='table')
    for h5_path in (
        f'{hdf}train/diagnoses.h5',
        f'{hdf}val/diagnoses.h5',
        f'{hdf}test/diagnoses.h5',
    )
)

# Stack the three splits row-wise so the graph is built over all patients.
all_diagnoses = pd.concat(
    [train_diagnoses, val_diagnoses, test_diagnoses],
    axis=0,
)

print("the size of all diagnoses is: ", all_diagnoses.shape)

the size of all diagnoses is:  (11698, 124)


In [5]:
# Graph-construction settings.
# "freq_adjust" is a real boolean: the previous value was the argparse action
# name 'store_true', which only worked because any non-empty string is truthy
# (so even the string 'False' would have enabled the adjustment).
args = {
    "k": 3,  # Number of nearest neighbors for k_closest mode
    "mode": 'k_closest',  # Graph mode: k_closest or threshold
    "freq_adjust": True,  # Apply frequency adjustment
}

In [6]:
# Column-wise sum of the diagnosis table (one total per diagnosis column),
# or None when frequency adjustment is switched off.
if args["freq_adjust"]:
    freq_adjustment = all_diagnoses.sum(axis=0)
else:
    freq_adjustment = None

In [7]:
# Display the combined diagnosis table (rows are indexed by patient id,
# columns are diagnosis indicators); the notebook renders a truncated preview.
all_diagnoses

Unnamed: 0_level_0,Cardiovascular (R),Cardiovascular (R)|AICD,Cardiovascular (R)|Angina,Cardiovascular (R)|Arrhythmias,Cardiovascular (R)|Arrhythmias|atrial fibrillation - chronic,Cardiovascular (R)|Arrhythmias|atrial fibrillation - intermittent,Cardiovascular (R)|Congestive Heart Failure,Cardiovascular (R)|Congestive Heart Failure|CHF,Cardiovascular (R)|Congestive Heart Failure|CHF - severity unknown,Cardiovascular (R)|Coronary Artery Bypass,...,"apacheadmissiondx_Rhythm disturbance (atrial, supraventricular)",apacheadmissiondx_Rhythm disturbance (conduction defect),apacheadmissiondx_Seizures (primary-no structural brain disease),"apacheadmissiondx_Sepsis, GI","apacheadmissiondx_Sepsis, cutaneous/soft tissue","apacheadmissiondx_Sepsis, other","apacheadmissiondx_Sepsis, pulmonary","apacheadmissiondx_Sepsis, renal/UTI (including bladder)","apacheadmissiondx_Sepsis, unknown","grouped_apacheadmissiondx_Overdose,"
patient,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2869970,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
2237473,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,...,0,0,0,0,0,0,0,1,0,0
2700691,1.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,...,0,0,0,0,0,0,1,0,0,0
1752854,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
3144222,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2885433,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
2857227,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,...,0,0,0,0,0,0,1,0,0,0
1840603,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,1,0,0,0
989666,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,1,0


In [8]:
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def calculate_score_matrix(diagnoses, freq_adjustment=None, debug=False, batch_size=500):
    """Calculate the pairwise patient score matrix from diagnosis data.

    Entry ``(i, j)`` of the result is the dot product between the diagnosis
    rows of patient ``i`` and patient ``j``. When ``freq_adjustment`` is given,
    every diagnosis column is first scaled by ``1 / (frequency + 1e-8)``; the
    scaling is applied to both operands of the product.

    All arithmetic is done in float32. float16 is not safe here: ``1e-8``
    rounds to zero in half precision, so a diagnosis column with frequency 0
    produced ``1 / 0 = inf`` and then ``0 * inf = NaN`` across the matrix, and
    counts above 2048 are no longer exactly representable.

    Args:
        diagnoses: DataFrame with one row per patient and one numeric column
            per diagnosis (only ``.values`` is used).
        freq_adjustment: Optional Series with one frequency per diagnosis
            column (only ``.values`` is used). ``None`` disables the scaling.
        debug: When True, only the first 1000 rows are scored.
        batch_size: Number of rows multiplied per step; must be >= 1.

    Returns:
        numpy.ndarray of dtype float32 and shape ``(n, n)``, where ``n`` is the
        number of rows that were scored.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    print('==> Calculating score matrix')
    device = get_device()

    # Convert diagnoses to a PyTorch tensor (float32, see docstring).
    features = torch.tensor(diagnoses.values, dtype=torch.float32, device=device)

    if debug:
        features = features[:1000]  # Limit data size in debug mode

    if freq_adjustment is not None:
        counts = torch.tensor(freq_adjustment.values, dtype=torch.float32, device=device)
        # The epsilon is representable in float32, so a zero count yields a
        # large finite weight instead of inf; multiplied by an all-zero column
        # it contributes exactly 0.
        inverse_freq = 1.0 / (counts + 1e-8)
        # unsqueeze(0) makes the per-column weights broadcast over all rows.
        features = features * inverse_freq.unsqueeze(0)

    num_rows = features.size(0)
    scores = torch.zeros((num_rows, num_rows), dtype=torch.float32, device=device)

    print(f'==> Processing in batches (batch size: {batch_size})...')

    # Compute the score matrix one horizontal stripe at a time.
    features_t = features.T
    for start in range(0, num_rows, batch_size):
        end = min(start + batch_size, num_rows)
        scores[start:end] = torch.mm(features[start:end], features_t)

    # Releasing cached blocks is only meaningful when CUDA was actually used.
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Convert to CPU numpy array
    return scores.cpu().numpy()

In [9]:
# Calculate score matrix
scores = calculate_score_matrix(all_diagnoses, freq_adjustment=freq_adjustment)
print(f'Score matrix shape: {scores.shape}')

==> Calculating score matrix
==> Processing in batches (batch size: 500)...
Score matrix shape: (11698, 11698)


In [10]:

def create_graph_pyg(diagnoses, scores, k=3, penalize=True):
    """
    Use the score matrix to create a k-nearest-neighbour graph in PyG format.

    For every patient (row) the ``k`` highest-scoring *other* patients are
    selected and one directed edge ``row -> neighbour`` is emitted per pick.
    With ``penalize=True`` the score used for ranking (and stored as the edge
    weight) is ``5 * score - (diagnoses_of_i + diagnoses_of_j)``.

    Self-connections are excluded by setting the diagonal to ``-inf`` *after*
    the penalty. Zeroing the diagonal before the penalty is not enough: the
    self-score then becomes ``-2 * diagnoses_of_i`` and can still rank first.

    NOTE(review): when the scores are much smaller than the diagnosis counts
    the penalty term dominates the ranking, so patients with the fewest
    diagnoses are picked as neighbours by everyone -- confirm this is intended.

    Args:
        diagnoses: DataFrame with one row per patient; its index supplies the
            patient ids and its values the per-patient diagnosis indicators.
        scores: Square array-like of pairwise scores, one row per patient in
            the same order as ``diagnoses``. It is copied, not modified.
        k: Neighbours per patient. Clamped to ``n - 1`` so a node is never
            forced to connect to itself.
        penalize: Whether to apply the diagnosis-count penalty.

    Returns:
        A ``Data`` object with ``edge_index`` (2 x E, long), ``edge_attr``
        (E, float32), ``num_nodes`` and ``patient_ids``. Edges are ordered by
        source row, then by descending score.

    Raises:
        ValueError: If ``scores`` is not square or does not have one row per
            row of ``diagnoses``.
    """
    print('==> Step 1: calculate the  Penalty Similarity ')
    patient_ids = diagnoses.index.values
    diagnoses = torch.tensor(diagnoses.values).float()
    # torch.tensor copies, so the caller's score matrix is left untouched.
    scores = torch.tensor(scores, dtype=torch.float32)

    num_nodes = len(diagnoses)
    if scores.dim() != 2 or scores.shape[0] != num_nodes or scores.shape[1] != num_nodes:
        raise ValueError(
            f'scores must have shape ({num_nodes}, {num_nodes}), got {tuple(scores.shape)}'
        )

    if penalize:
        diags_per_pt = diagnoses.sum(dim=1)
        total_combined_diags = diags_per_pt.view(-1, 1) + diags_per_pt.view(1, -1)
        scores = 5 * scores - total_combined_diags  # penalty term

    print('==> Step 2: select the top k edges')
    neighbours = min(int(k), num_nodes - 1)
    if neighbours > 0:
        # Remove self-connections: -inf can never be among the top k when
        # k <= n - 1.
        scores.fill_diagonal_(float('-inf'))
        top = torch.topk(scores, k=neighbours, dim=1)
        # Each source row is repeated once per selected neighbour.
        sources = torch.arange(num_nodes).repeat_interleave(neighbours)
        edge_index = torch.stack([sources, top.indices.reshape(-1)])  # PyG layout (2, E)
        edge_attr = top.values.reshape(-1).float()  # edge weights
    else:
        # Fewer than two nodes (or k <= 0): no edges can be formed.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty(0, dtype=torch.float32)

    print(f'==> generated {len(edge_attr)} edges')

    print('==> Step 3: generate the PyG data object')

    data = Data(
        edge_index=edge_index,
        edge_attr=edge_attr,
        num_nodes=num_nodes,
        patient_ids=torch.tensor(patient_ids, dtype=torch.long)  # Store patient_id
    )
    return data


In [11]:
# Build the graph with the configured number of neighbours per patient and
# the diagnosis-count penalty enabled.
data = create_graph_pyg(all_diagnoses, scores, k=args["k"], penalize=True)

==> Step 1: calculate the  Penalty Similarity 
==> Step 2: select the top k edges
==> generated 35094 edges
==> Step 3: generate the PyG data object


In [12]:
# Save the graph. os.path.join yields the same file whether or not the
# configured graph directory ends with a path separator.
graph_path = os.path.join(graph, f'diagnosis_graph_{args["mode"]}_k{args["k"]}.pt')
torch.save(data, graph_path)

# Load the graph back (kept for reference):
# NOTE(review): recent PyTorch releases may need weights_only=False to
# unpickle a PyG Data object -- confirm against the installed version.
# loaded_data = torch.load(graph_path)
 

In [13]:

# Flatten the graph into an edge list: one row per directed edge, with the
# source node, target node and edge weight taken from the PyG tensors.
edges_df = pd.DataFrame({"source": data.edge_index[0].cpu().numpy(),
                         "target": data.edge_index[1].cpu().numpy(),
                         "weight": data.edge_attr.cpu().numpy()})

# save the graph
# Joining with os.path.join instead of `graph + name` keeps the file inside
# the graph directory even when the configured path has no trailing slash.
edges_df.to_csv(os.path.join(graph, "graph_edges.csv"), index=False)

In [None]:
# edges_df = pd.read_csv(graph+"graph_edges.csv")

In [14]:
# Display the edge list (source, target, weight) for a quick visual check.
edges_df

Unnamed: 0,source,target,weight
0,0,9101,-9.999999
1,0,10780,-9.999999
2,0,1228,-10.000000
3,1,9101,-14.999999
4,1,10780,-14.999999
...,...,...,...
35089,11696,10780,-8.999999
35090,11696,1228,-9.000000
35091,11697,10780,-9.999999
35092,11697,9101,-9.999999
