In [1]:
from math import ceil
import random

import pandas as pd
import torch
import numpy as np
import os
import json

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sqlalchemy.orm import sessionmaker
from tqdm import tqdm
from torch_geometric.data import Data
from mimic.orm_create.mimiciv_v3_orm import Labels, Note, PreprocessedRevisedNote
from sqlalchemy import create_engine, and_, func
from torch_geometric.nn.conv import GATConv
from torch_geometric.nn import global_mean_pool
from torch import optim
from torch.nn import functional as F
from sklearn import metrics
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files/Graphviz/bin/'
DIR_NAME = "20250816_qwen32b_kgs_out"#"20250812_qwen_1_7b_kgs_out" #"20250810_qwen_14b_kgs_out"

  import pkg_resources


In [2]:
def set_all_seeds(seed):
    """Set seeds for reproducibility ."""
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    g = torch.Generator()
    g.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
    return g

def seed_worker(worker_id):
    """Set seed for DataLoader workers."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [3]:
def get_session():
    DB_URI = "postgresql://postgres:password@localhost:5432/mimicIV_v3"
    engine = create_engine(DB_URI)
    Session = sessionmaker(bind=engine)
    session = Session()
    return session

In [4]:
from torch_geometric.utils import add_self_loops

def get_row_id(graph_dir):
	file_name = f"batched_notes/{graph_dir}.json"
	with open(file_name, "r") as f:
		data = json.load(f)
	return data["row_id"]

def get_edges(graph):
    ## TODO Build up edge vocabulary in graph
    entity_dict = {key: i for i, key in enumerate(graph["entities"])}
    edge_dict = {key: i for i, key in enumerate(graph["edges"])}
    edge_idx, edge_attr = [], []
    for relation in graph["relations"]:
        src, rel, trg = relation
        edge_idx.append((entity_dict[src], entity_dict[trg]))
    edge_idx = np.array(edge_idx)
    edge_idx = torch.from_numpy(edge_idx).type(torch.int64)
    edge_idx = edge_idx.transpose(-1, 0)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float32)
    if edge_idx.ndim == 1:
        edge_idx = torch.tensor([[], []], dtype=torch.long)
    if edge_attr.ndim == 1:
        edge_attr = torch.tensor([], dtype=torch.float)
    return edge_idx, edge_attr

def get_label(session, dir):
    row_id = get_row_id(dir)
    label = session.query(Labels.label).where(Labels.row_id == row_id).one_or_none()
    label = int(label[0])
    label = torch.tensor([label], dtype=torch.float)
    return label, row_id
  
def get_entities_list(graph):
    graph["entities"] = list(map(lambda e: e, graph["entities"])) ## prevent nothing and use def. tokenizer -> Lets just map multiple tokens back to one word via sum of the tokenizer
    entities_list = graph["entities"]
    return entities_list

def add_readout_node(data):
    data.x = torch.concatenate([data.x, torch.mean(data.x, dim = 0).unsqueeze(0)])
    data.readout_mask = torch.zeros(data.x.shape[0], dtype=torch.bool)
    data.readout_mask[-1] = 1
    """Connect all nodes in the KG to the last read-out node"""
    new_src = torch.arange(data.x.shape[0])
    new_trg = torch.ones_like(new_src)*(data.x.shape[0]- 1)
    ## -1 and -1 represent mean that I dont want to add a self edge on the readout node, i.e., theoretically on attention heads from KG nodes can contribute to the read out node
    readout_edges = torch.stack([new_src[:-1], new_trg[:-1]])
    data.edge_index = torch.cat([data.edge_index, readout_edges], dim = -1)
    #data.edge_attr = torch.cat([data.edge_attr, torch.ones((readout_edges.shape[-1], 768), dtype = torch.float)], dim = 0)
    
def create_graph(session, dir):
    graph = json.load(open(f"{DIR_NAME}/{dir}/graph.json")) #json.load(open(os.path.join("revised_kgs", dir, "graph.json"), "r"))
    if len(graph["entities"]) == 0: return None
    edge_index, _ = get_edges(graph)
    y, row_id = get_label(session, dir)
    x = torch.zeros(len(graph["entities"]), 1, dtype=torch.float32)
    entities_list = get_entities_list(graph)
    edge_index = add_self_loops(edge_index)[0]
    data = Data(x=x, edge_index=edge_index, edge_attr=None, y=y, entities_list = entities_list, row_id= row_id, dir = dir)
    add_readout_node(data)
    return data

In [5]:
def get_valid_note_row_ids(session):
    """Necessary because in my prevoius filtering steps I didnt excluded them and KG creation is still running (dont want interrupt)"""
    db_note_row_ids = session.query(PreprocessedRevisedNote.row_id).filter(and_(func.lower(PreprocessedRevisedNote.text).not_like("%sepsis%"), func.lower(PreprocessedRevisedNote.text).not_like("%septic%"), func.lower(PreprocessedRevisedNote.text).not_like("%shock%"))).all()
    db_note_row_ids = list(map(lambda n: n[0], db_note_row_ids))
    return db_note_row_ids

In [6]:
def get_graphs():
    data_graphs = []
    session = get_session()
    db_note_row_ids = get_valid_note_row_ids(session)

    print(os.getcwd())
    kg_dirs = os.listdir(DIR_NAME) #os.listdir("revised_kgs")
    kg_dirs = list(filter(lambda n: "." not in n, kg_dirs))
    kg_dirs = list(map(int, kg_dirs))
    kg_dirs.sort()
    for dir in tqdm(kg_dirs):
        if int(get_row_id(dir)) not in db_note_row_ids: continue
        data = create_graph(session, str(dir))
        data_graphs.append(data)
    session.close()
    return data_graphs

In [7]:
def get_train_and_test_graphs():
    data_graphs = get_graphs()
    ## TODO Use time-based splits across all experiments
    train_idx, test_idx = train_test_split(np.arange(len(data_graphs)), test_size=0.2, stratify=[data.y for data in data_graphs], random_state=42) #
    
    train_data = [data_graphs[idx] for idx in train_idx]
    test_data = [data_graphs[idx] for idx in test_idx]
    return train_data, test_data

In [8]:
import numpy as np
import torch
import os
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import gc

def get_medical_embeddings(word_groups: list[str], model: SentenceTransformer) -> np.ndarray:
    embeddings = model.encode(word_groups, show_progress_bar=False)
    return embeddings

def get_vectorized_train_and_test_graphs():
    base_processed_directory = "processed_pyg_graphs"
    train_directory = os.path.join(base_processed_directory, "train")
    test_directory = os.path.join(base_processed_directory, "test")

    os.makedirs(train_directory, exist_ok=True)
    os.makedirs(test_directory, exist_ok=True)

    print("Loading Sentence Transformer model into memory...")
    embedding_model = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO')

    raw_train_data, raw_test_data = get_train_and_test_graphs()

    vectorized_train_graphs = []
    print("Checking for and processing training graphs...")
    for data in tqdm(raw_train_data, desc="Train Set"):
        graph_filename = f"{data.row_id}.pt"
        graph_path = os.path.join(train_directory, graph_filename)

        if os.path.exists(graph_path):
            # --- FIX APPLIED HERE ---
            processed_graph = torch.load(graph_path, weights_only=False)
            vectorized_train_graphs.append(processed_graph)
        else:
            emb = get_medical_embeddings(data.entities_list, embedding_model)
            vectorized_entities = torch.from_numpy(emb).to(torch.float)
            num_features = vectorized_entities.shape[-1]
            readout_node = torch.zeros((1, num_features), dtype=torch.float32)
            data.x = torch.cat([vectorized_entities, readout_node], dim=0)

            torch.save(data, graph_path)
            vectorized_train_graphs.append(data)

            del emb, vectorized_entities, readout_node
            gc.collect()
            torch.cuda.empty_cache()

    vectorized_test_graphs = []
    print("Checking for and processing testing graphs...")
    for data in tqdm(raw_test_data, desc="Test Set"):
        graph_filename = f"{data.row_id}.pt"
        graph_path = os.path.join(test_directory, graph_filename)

        if os.path.exists(graph_path):
            # --- FIX APPLIED HERE ---
            processed_graph = torch.load(graph_path, weights_only=False)
            vectorized_test_graphs.append(processed_graph)
        else:
            emb = get_medical_embeddings(data.entities_list, embedding_model)
            vectorized_entities = torch.from_numpy(emb).to(torch.float)
            num_features = vectorized_entities.shape[-1]
            readout_node = torch.zeros((1, num_features), dtype=torch.float32)
            data.x = torch.cat([vectorized_entities, readout_node], dim=0)

            torch.save(data, graph_path)
            vectorized_test_graphs.append(data)

            del emb, vectorized_entities, readout_node
            gc.collect()
            torch.cuda.empty_cache()

    return vectorized_train_graphs, vectorized_test_graphs

In [9]:
train_graphs, test_graphs = get_vectorized_train_and_test_graphs()

Loading Sentence Transformer model into memory...
C:\Users\danie\git\KARDIA\mimic\read\training


100%|██████████| 16348/16348 [10:28<00:00, 26.01it/s] 


Checking for and processing training graphs...


Train Set: 100%|██████████| 12527/12527 [00:13<00:00, 922.83it/s]


Checking for and processing testing graphs...


Test Set: 100%|██████████| 3132/3132 [00:03<00:00, 923.92it/s]


In [122]:
from torch_geometric.nn.conv import SAGEConv, GCNConv, GATConv

class GNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout, heads=1, heads_dropout = .0):
       super(GNN, self).__init__()
       self.conv1 = GATConv(in_dim, hidden_dim,  add_self_loops=True, concat=True, heads=heads, dropout=heads_dropout, residual=True)
       # self.conv2 = GATConv(hidden_dim*heads, hidden_dim,  add_self_loops=True, concat=True, heads=heads, dropout=heads_dropout, residual=True)
       self.conv3 = GATConv(hidden_dim*heads, 1, add_self_loops=False, concat=False, heads=heads, dropout=heads_dropout, residual=True) # edge_dim=hidden_dim,
       self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, x, edge_index, **kwargs):
        # x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        # x = self.dropout(x)
        x = self.conv1(x, edge_index)#, edge_attr
        x = F.elu(x)
        x = self.dropout(x)
        x, (edge_index_attn, alpha) = self.conv3(x, edge_index, return_attention_weights=True) #, edge_attr
        return x, (edge_index_attn, alpha)



In [123]:
train_graphs[0].x.shape

torch.Size([38, 768])

In [124]:
 def evaluate(model, loader, device):
    """Evaluates the model and returns AUROC and AUPRC."""
    model = model.to(device)
    with torch.inference_mode():
       model.eval()
       pred_probas = []
       y_trues = []
       for batch_data in loader:
          batch_data = batch_data.to(device)
          logit, _ = model(batch_data.x, batch_data.edge_index, edge_attr=batch_data.edge_attr)
          logit = logit[batch_data.readout_mask]
          pred_proba = torch.sigmoid(logit)
          pred_probas.extend(pred_proba.cpu().tolist())
          y_trues.extend(batch_data.y.cpu().tolist())

    y_true = np.array(y_trues)
    y_pred_proba = np.array(pred_probas)
    auroc = metrics.roc_auc_score(y_true, y_pred_proba)
    auprc = metrics.average_precision_score(y_true, y_pred_proba)
    return auroc, auprc

In [125]:
N_SPLITS = 3
MAX_EVALS = 50
SEED = 50 #0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
final_generator = set_all_seeds(SEED)

In [126]:
space = {
    # "hidden_dim": hp.choice('hidden_dim', [8, 16, 32]),
    "dropout": hp.loguniform('dropout', np.log(1e-2), np.log(1e-1)),
    "heads_dropout": hp.loguniform('heads_dropout', np.log(1e-2), np.log(2e-1)),
    "lr": hp.loguniform('lr', np.log(1e-3), np.log(1e-2)),
    # "weight_decay": hp.loguniform('weight_decay', np.log(1e-6), np.log(1e-3)),
    "weight_decay": hp.choice('weight_decay', [0]),
    "heads": hp.quniform('heads', 1, 6, 1),
    # "batch_size": hp.quniform('batch_size', 16, 128, 16),
    "epochs": hp.quniform('epochs', 10, 200, 1) # Epochs are now tunable
}

In [149]:
y_train_full = [d.y.item() for d in train_graphs]
# best_hyperparams = {'batch_size': 256, 'dropout': 0.8, 'epochs': 600, 'heads': 2, 'hidden_dim': 16, 'lr': .01, 'weight_decay': 1e-5, "heads_dropout": .1}
# best_hyperparams = {'batch_size': 256, 'dropout': 0.0, 'epochs': 40, 'heads': 2, 'hidden_dim': 1, 'lr': .01, 'weight_decay': 1e-5, "heads_dropout": .1}
best_hyperparams = {'batch_size': 256, 'dropout': 0.5, 'epochs': 100, 'heads': 8, 'hidden_dim': 256, 'lr': .005, 'weight_decay': 0, "heads_dropout": 0.2}

In [150]:
def train(loss_fn, loader, best_hyperparams):
    model = GNN(
                in_dim=train_graphs[0].x.shape[1],
                hidden_dim=best_hyperparams["hidden_dim"],
                dropout=best_hyperparams["dropout"],
                heads=best_hyperparams["heads"],
                heads_dropout=best_hyperparams["heads_dropout"]
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=best_hyperparams["lr"], weight_decay=best_hyperparams["weight_decay"])
    for epoch in range(best_hyperparams['epochs']):
        epoch_loss = 0
        model.train()
        for batch_data in loader:
           optimizer.zero_grad()
           batch_data = batch_data.to(device)
           logit, _ = model(batch_data.x, batch_data.edge_index)
           logit = logit[batch_data.readout_mask]
           loss = loss_fn(logit.squeeze(), batch_data.y)
           loss.backward()
           optimizer.step()
           epoch_loss += loss.item()
        avg_loss = epoch_loss / len(final_train_loader)
        #print(f"Epoch: {epoch} Train AUROC: {auroc_train:.4f} | Test AUROC: {auroc_test:.4f}")
    return model

In [151]:
from operator import itemgetter

def tune_hyperparameters(train_graphs, y_train_full, space_params, k = 3, max_evals = 50, maximize_metric = True):

    def objective(params):
        if 'epochs' in params:
            params['epochs'] = int(params['epochs'])
        if 'heads' in params:
            params['heads'] = int(params['heads'])

        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)
        scores = []

        for train_index, val_index in skf.split(train_graphs, y_train_full):
            inner_train_graphs = itemgetter(*train_index)(train_graphs)
            inner_val_graphs = itemgetter(*val_index)(train_graphs)
            inner_train_graphs = [graph.to(device) for graph in inner_train_graphs]
            inner_val_graphs = [graph.to(device) for graph in inner_val_graphs]
            inner_train_loader = DataLoader(
                inner_train_graphs,
                batch_size=256,
                shuffle=True,
                worker_init_fn=seed_worker,
                generator=final_generator
            )
            inner_val_loader = DataLoader(
                inner_val_graphs,
                batch_size=256,
                shuffle=False
            )
            pos_weight = ceil(sum([data.y[0] == 0 for data in inner_train_graphs]) / sum([data.y[0] == 1 for data in inner_train_graphs]))
            loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))
            model = train(loss_fn, inner_train_loader, best_hyperparams)
            auroc_test, auprc_test = evaluate(model, inner_val_loader, device)
            #auroc_train, auprc_train = evaluate(model, final_train_loader, device)
            scores.append(auroc_test)

        average_score = np.mean(scores)
        
        loss = -average_score if maximize_metric else average_score

        return {'loss': loss, 'status': STATUS_OK}

    trials = Trials()
    best_params = fmin(
        fn=objective,
        space=space_params,
        algo=tpe.suggest,
        max_evals=max_evals,
        trials=trials,
        rstate=np.random.default_rng(42)
    )

    best_metric_score = -trials.best_trial['result']['loss'] if maximize_metric else trials.best_trial['result']['loss']

    return best_params, best_metric_score

In [152]:
# best_hyperparams, best_metric_score = tune_hyperparameters(train_graphs, y_train_full, space, max_evals=100, maximize_metric = True, k = 5)
# best_hyperparams, best_metric_score

In [153]:
# DataLoaders for final training
# train_graphs = [graph.to(device) for graph in train_graphs]
# test_graphs = [graph.to(device) for graph in test_graphs]
final_train_loader = DataLoader(
    train_graphs,
    batch_size=256,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=final_generator
)
# Test loader does not need to be shuffled
test_loader = DataLoader(
    test_graphs,
    batch_size=32,
    shuffle=False
)


# Instantiate the final model
final_model = GNN(
    in_dim=train_graphs[0].x.shape[1],
    hidden_dim=1,
    dropout=best_hyperparams["dropout"],
    heads=int(best_hyperparams["heads"]),
    heads_dropout=best_hyperparams["heads_dropout"]
).to(device)

In [154]:
optimizer = optim.Adam(
    final_model.parameters(),
    lr=best_hyperparams["lr"],
    weight_decay=best_hyperparams["weight_decay"]
)

# Recalculate pos_weight on the full training data
pos_weight = ceil(sum([data.y[0] == 0 for data in train_graphs]) / sum([data.y[0] == 1 for data in train_graphs]))
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

In [155]:
num_epochs_final = 60 #int(best_hyperparams['epochs'])
for epoch in range(num_epochs_final):
    epoch_loss = 0
    final_model.train()
    for batch_data in final_train_loader:
       optimizer.zero_grad()
       batch_data = batch_data.to(device)
       logit, _ = final_model(batch_data.x, batch_data.edge_index)
       logit = logit[batch_data.readout_mask]
       loss = loss_fn(logit.squeeze(), batch_data.y)
       loss.backward()
       optimizer.step()
       epoch_loss += loss.item()
    avg_loss = epoch_loss / len(final_train_loader)
    auroc_test, auprc_test = evaluate(final_model, test_loader, device)
    auroc_train, auprc_train = evaluate(final_model, final_train_loader, device)

    print(f"Epoch: {epoch} Train AUROC: {auroc_train:.4f} | Test AUROC: {auroc_test:.4f}")

Epoch: 0 Train AUROC: 0.7078 | Test AUROC: 0.7342
Epoch: 1 Train AUROC: 0.7006 | Test AUROC: 0.7326
Epoch: 2 Train AUROC: 0.6944 | Test AUROC: 0.7127
Epoch: 3 Train AUROC: 0.7328 | Test AUROC: 0.7472
Epoch: 4 Train AUROC: 0.7585 | Test AUROC: 0.7735
Epoch: 5 Train AUROC: 0.7809 | Test AUROC: 0.7905
Epoch: 6 Train AUROC: 0.7916 | Test AUROC: 0.7918
Epoch: 7 Train AUROC: 0.8084 | Test AUROC: 0.8023
Epoch: 8 Train AUROC: 0.8193 | Test AUROC: 0.8055
Epoch: 9 Train AUROC: 0.8290 | Test AUROC: 0.8082
Epoch: 10 Train AUROC: 0.8327 | Test AUROC: 0.8124
Epoch: 11 Train AUROC: 0.8356 | Test AUROC: 0.8156
Epoch: 12 Train AUROC: 0.8417 | Test AUROC: 0.8161
Epoch: 13 Train AUROC: 0.8400 | Test AUROC: 0.8116
Epoch: 14 Train AUROC: 0.8430 | Test AUROC: 0.8145
Epoch: 15 Train AUROC: 0.8526 | Test AUROC: 0.8194
Epoch: 16 Train AUROC: 0.8516 | Test AUROC: 0.8211
Epoch: 17 Train AUROC: 0.8536 | Test AUROC: 0.8203
Epoch: 18 Train AUROC: 0.8575 | Test AUROC: 0.8186
Epoch: 19 Train AUROC: 0.8562 | Test AURO

In [174]:
from torch_geometric.utils import degree

out_degree_list = []
in_degree_list = []
num_edges_list = []
row_ids = []

for i in range(len(train_graphs)):
    out_degree = degree(train_graphs[i].edge_index[0]).mean().item()
    in_degree = degree(train_graphs[i].edge_index[1]).mean().item()
    num_edges = train_graphs[i].edge_index.shape[-1] - train_graphs[i].x.shape[0] + 1 # +1 to not substract the readout node
    out_degree_list.append(out_degree)
    in_degree_list.append(in_degree)
    num_edges_list.append(num_edges)
    row_ids.append(train_graphs[i].dir)

df = pd.DataFrame()
df["out_degree"] = out_degree_list
df["in_degree"] = in_degree_list
df["num_edges"] = num_edges_list
df["row_id"] = row_ids
df.describe()

Unnamed: 0,out_degree,in_degree,num_edges
count,12527.0,12527.0,12527.0
mean,2.488514,2.400479,50.289215
std,0.497382,0.478058,28.836492
min,1.0,0.833333,0.0
25%,2.25,2.177778,31.0
50%,2.56,2.466667,46.0
75%,2.866667,2.75,65.0
max,6.5,6.157895,413.0


In [173]:
df[df["num_edges"] == 0]

Unnamed: 0,out_degree,in_degree,num_edges,row_id
79,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,194
88,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,4086
94,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,8641
105,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,7413
108,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,11791
...,...,...,...,...
12387,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,15129
12394,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,7830
12416,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,5851
12502,tensor(1.),"[tensor(0.), tensor(0.), tensor(0.), tensor(0....",0,12828
