In [1]:
from math import ceil
import random

import pandas as pd
import torch
import numpy as np
import os
import json

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sqlalchemy.orm import sessionmaker
from tqdm import tqdm
from torch_geometric.data import Data
from mimic.orm_create.mimiciv_v3_orm import Labels, Note, PreprocessedRevisedNote
from sqlalchemy import create_engine, and_, func
from torch_geometric.nn.conv import GATConv
from torch_geometric.nn import global_mean_pool
from torch import optim
from torch.nn import functional as F
from sklearn import metrics
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
import os
import gc
os.environ["PATH"] += os.pathsep + 'C:/Program Files/Graphviz/bin/'
DIR_NAME = "20250816_qwen32b_kgs_out"#"20250812_qwen_1_7b_kgs_out" #"20250810_qwen_14b_kgs_out"

  import pkg_resources


In [2]:
def set_all_seeds(seed):
    """Set seeds for reproducibility ."""
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    g = torch.Generator()
    g.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
    return g

def seed_worker(worker_id):
    """Set seed for DataLoader workers."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [3]:
def get_session():
    DB_URI = "postgresql://postgres:password@localhost:5432/mimicIV_v3"
    engine = create_engine(DB_URI)
    Session = sessionmaker(bind=engine)
    session = Session()
    return session

In [4]:
from torch_geometric.utils import add_self_loops
import re
def get_note(graph_dir):
	file_name = f"batched_notes/{graph_dir}.json"
	with open(file_name, "r") as f:
		data = json.load(f)
	return data

def get_edges(graph):
    ## TODO Build up edge vocabulary in graph
    entity_dict = {key: i for i, key in enumerate(graph["entities"])}
    edge_dict = {key: i for i, key in enumerate(graph["edges"])}
    edge_idx, edge_attr = [], []
    for relation in graph["relations"]:
        src, rel, trg = relation
        edge_idx.append((entity_dict[src], entity_dict[trg]))
    edge_idx = np.array(edge_idx)
    edge_idx = torch.from_numpy(edge_idx).type(torch.int64)
    edge_idx = edge_idx.transpose(-1, 0)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float32)
    if edge_idx.ndim == 1:
        edge_idx = torch.tensor([[], []], dtype=torch.long)
    if edge_attr.ndim == 1:
        edge_attr = torch.tensor([], dtype=torch.float)
    return edge_idx, edge_attr

def get_label(session, dir, note):
    row_id = note["row_id"]
    label = session.query(Labels.label).where(Labels.row_id == row_id).one_or_none()
    label = int(label[0])
    label = torch.tensor([label], dtype=torch.float)
    return label, row_id
  
def get_entities_list(graph):
    graph["entities"] = list(map(lambda e: e, graph["entities"])) ## prevent nothing and use def. tokenizer -> Lets just map multiple tokens back to one word via sum of the tokenizer
    entities_list = graph["entities"]
    return entities_list

def add_readout_node(data):
    data.x = torch.concatenate([data.x, torch.mean(data.x, dim = 0).unsqueeze(0)])
    data.readout_mask = torch.zeros(data.x.shape[0], dtype=torch.bool)
    data.readout_mask[-1] = 1
    """Connect all nodes in the KG to the last read-out node"""
    new_src = torch.arange(data.x.shape[0])
    new_trg = torch.ones_like(new_src)*(data.x.shape[0]- 1)
    ## -1 and -1 represent mean that I dont want to add a self edge on the readout node, i.e., theoretically on attention heads from KG nodes can contribute to the read out node
    readout_edges = torch.stack([new_src[:-1], new_trg[:-1]])
    data.edge_index = torch.cat([data.edge_index, readout_edges], dim = -1)
    #data.edge_attr = torch.cat([data.edge_attr, torch.ones((readout_edges.shape[-1], 768), dtype = torch.float)], dim = 0)
    
def create_graph(session, dir):
    graph = json.load(open(f"{DIR_NAME}/{dir}/graph.json")) #json.load(open(os.path.join("revised_kgs", dir, "graph.json"), "r"))
    note = get_note(dir)
    graph["entities"] = list(set(graph["entities"] + re.findall(r"(?u)\b\w\w+\b", note["text"])))
    if len(graph["entities"]) == 0: return None
    edge_index, _ = get_edges(graph)
    y, row_id = get_label(session, dir, note)
    x = torch.zeros(len(graph["entities"]), 1, dtype=torch.float32)
    entities_list = get_entities_list(graph)
    edge_index = add_self_loops(edge_index)[0]
    data = Data(x=x, edge_index=edge_index, edge_attr=None, y=y, entities_list = entities_list, row_id= row_id, dir = dir)
    add_readout_node(data)
    gc.collect()
    return data

In [5]:
def get_valid_note_row_ids(session):
    """Necessary because in my prevoius filtering steps I didnt excluded them and KG creation is still running (dont want interrupt)"""
    db_note_row_ids = session.query(PreprocessedRevisedNote.row_id).filter(and_(func.lower(PreprocessedRevisedNote.text).not_like("%sepsis%"), func.lower(PreprocessedRevisedNote.text).not_like("%septic%"), func.lower(PreprocessedRevisedNote.text).not_like("%shock%"))).all()
    db_note_row_ids = list(map(lambda n: n[0], db_note_row_ids))
    return db_note_row_ids

In [15]:
def get_graphs():
    data_graphs = []
    session = get_session()
    db_note_row_ids = get_valid_note_row_ids(session)
    kg_dirs = os.listdir(DIR_NAME) #os.listdir("revised_kgs")
    kg_dirs = list(filter(lambda n: "." not in n, kg_dirs))
    kg_dirs = list(map(int, kg_dirs))
    kg_dirs.sort()
    for dir in tqdm(kg_dirs[:30]):
        note = get_note(dir)
        if int(note["row_id"]) not in db_note_row_ids: continue
        data = create_graph(session, str(dir))
        data_graphs.append(data)
    session.close()
    return data_graphs

In [16]:
def get_train_and_test_graphs():
    data_graphs = get_graphs()
    ## TODO Use time-based splits across all experiments
    train_idx, test_idx = train_test_split(np.arange(len(data_graphs)), test_size=0.2, stratify=[data.y for data in data_graphs], random_state=42) #
    
    train_data = [data_graphs[idx] for idx in train_idx]
    test_data = [data_graphs[idx] for idx in test_idx]
    return train_data, test_data

In [17]:
def get_vectorized_train_and_test_graphs():
    train_data, test_data = get_train_and_test_graphs()
    vectorizer = TfidfVectorizer()
    print("Fit tokenizer")
    vectorizer.fit(list(map(lambda d: "\t".join(d.entities_list), train_data)))
    num_features = len(vectorizer.vocabulary_)

    # Process both train and test datasets
    os.makedirs("vectorized_graphs", exist_ok = True)
    for data_set_name, data_set in [("train", train_data), ("test", test_data)]:
        print(f"Processing {data_set_name} set...")
        processed_graphs = []
        for data in tqdm(data_set, desc=f"Vectorizing {data_set_name} graphs"):
            # Create a unique file path based on the graph's dir attribute
            file_path = os.path.join("vectorized_graphs", f"{data.dir}.pt")

            if os.path.exists(file_path):
                data = torch.load(file_path)
            else:
                # If not, create the vectorized graph
                vectorized_entities = vectorizer.transform(data.entities_list)
                vectorized_entities = vectorized_entities.toarray()
                vectorized_entities = torch.from_numpy(vectorized_entities).float()

                # Ensure feature dimensions match, handling potential empty lists
                if vectorized_entities.shape[0] == 0:
                    # Handle graphs with no entities
                    vectorized_entities = torch.zeros((0, num_features), dtype=torch.float32)

                readout_node = torch.zeros((1, num_features), dtype=torch.float32)
                data.x = torch.cat([vectorized_entities, readout_node], dim=0)

                # Clean up attributes that are no longer needed
                if hasattr(data, 'entities_list'):
                    del data.entities_list
                if hasattr(data, 'entities'):
                    del data.entities
                print(data.x.shape)
                print(data.edge_index.shape)
                print(data)
                # Save the newly created graph to disk for future use
                torch.save(data, file_path)

            processed_graphs.append(data)
            gc.collect()

        # Replace the original dataset with the processed one
        if data_set_name == "train":
            train_data = processed_graphs
        else:
            test_data = processed_graphs

    return train_data, test_data, vectorizer

In [18]:
gc.collect()

283489

In [19]:
train_graphs, test_graphs, vectorizer = get_vectorized_train_and_test_graphs()

100%|██████████| 30/30 [00:05<00:00,  5.49it/s]


Fit tokenizer
Processing train set...


Vectorizing train graphs:   4%|▍         | 1/24 [00:00<00:03,  6.04it/s]

torch.Size([241, 2254])
torch.Size([2, 463])
Data(x=[241, 2254], edge_index=[2, 463], y=[1], row_id=10056, dir='29', readout_mask=[241])
torch.Size([370, 2254])
torch.Size([2, 755])
Data(x=[370, 2254], edge_index=[2, 755], y=[1], row_id=100406, dir='18', readout_mask=[370])


Vectorizing train graphs:  12%|█▎        | 3/24 [00:00<00:04,  5.03it/s]

torch.Size([214, 2254])
torch.Size([2, 391])
Data(x=[214, 2254], edge_index=[2, 391], y=[1], row_id=100495, dir='25', readout_mask=[214])


Vectorizing train graphs:  17%|█▋        | 4/24 [00:00<00:04,  4.82it/s]

torch.Size([220, 2254])
torch.Size([2, 443])
Data(x=[220, 2254], edge_index=[2, 443], y=[1], row_id=100150, dir='5', readout_mask=[220])
torch.Size([282, 2254])
torch.Size([2, 559])
Data(x=[282, 2254], edge_index=[2, 559], y=[1], row_id=100581, dir='30', readout_mask=[282])


Vectorizing train graphs:  25%|██▌       | 6/24 [00:01<00:03,  5.00it/s]

torch.Size([240, 2254])
torch.Size([2, 487])
Data(x=[240, 2254], edge_index=[2, 487], y=[1], row_id=100341, dir='14', readout_mask=[240])
torch.Size([262, 2254])
torch.Size([2, 505])
Data(x=[262, 2254], edge_index=[2, 505], y=[1], row_id=100137, dir='4', readout_mask=[262])


Vectorizing train graphs:  33%|███▎      | 8/24 [00:01<00:03,  5.03it/s]

torch.Size([394, 2254])
torch.Size([2, 837])
Data(x=[394, 2254], edge_index=[2, 837], y=[1], row_id=100194, dir='10', readout_mask=[394])
torch.Size([235, 2254])
torch.Size([2, 499])
Data(x=[235, 2254], edge_index=[2, 499], y=[1], row_id=100494, dir='24', readout_mask=[235])


Vectorizing train graphs:  42%|████▏     | 10/24 [00:01<00:02,  5.07it/s]

torch.Size([249, 2254])
torch.Size([2, 530])
Data(x=[249, 2254], edge_index=[2, 530], y=[1], row_id=100557, dir='28', readout_mask=[249])
torch.Size([123, 2254])
torch.Size([2, 269])
Data(x=[123, 2254], edge_index=[2, 269], y=[1], row_id=100156, dir='6', readout_mask=[123])


Vectorizing train graphs:  50%|█████     | 12/24 [00:02<00:02,  5.60it/s]

torch.Size([283, 2254])
torch.Size([2, 574])
Data(x=[283, 2254], edge_index=[2, 574], y=[1], row_id=100312, dir='13', readout_mask=[283])


Vectorizing train graphs:  54%|█████▍    | 13/24 [00:02<00:02,  5.24it/s]

torch.Size([90, 2254])
torch.Size([2, 175])
Data(x=[90, 2254], edge_index=[2, 175], y=[1], row_id=10045, dir='21', readout_mask=[90])
torch.Size([311, 2254])
torch.Size([2, 614])
Data(x=[311, 2254], edge_index=[2, 614], y=[1], row_id=100422, dir='20', readout_mask=[311])


Vectorizing train graphs:  62%|██████▎   | 15/24 [00:02<00:01,  5.66it/s]

torch.Size([122, 2254])
torch.Size([2, 246])
Data(x=[122, 2254], edge_index=[2, 246], y=[1], row_id=100374, dir='17', readout_mask=[122])
torch.Size([222, 2254])
torch.Size([2, 386])
Data(x=[222, 2254], edge_index=[2, 386], y=[1], row_id=100069, dir='1', readout_mask=[222])


Vectorizing train graphs:  71%|███████   | 17/24 [00:03<00:01,  5.40it/s]

torch.Size([218, 2254])
torch.Size([2, 466])
Data(x=[218, 2254], edge_index=[2, 466], y=[1], row_id=10019, dir='9', readout_mask=[218])
torch.Size([290, 2254])
torch.Size([2, 536])
Data(x=[290, 2254], edge_index=[2, 536], y=[1], row_id=100230, dir='12', readout_mask=[290])


Vectorizing train graphs:  79%|███████▉  | 19/24 [00:03<00:00,  5.72it/s]

torch.Size([287, 2254])
torch.Size([2, 567])
Data(x=[287, 2254], edge_index=[2, 567], y=[1], row_id=100541, dir='27', readout_mask=[287])
torch.Size([274, 2254])
torch.Size([2, 563])
Data(x=[274, 2254], edge_index=[2, 563], y=[1], row_id=100130, dir='3', readout_mask=[274])


Vectorizing train graphs:  88%|████████▊ | 21/24 [00:03<00:00,  5.91it/s]

torch.Size([237, 2254])
torch.Size([2, 482])
Data(x=[237, 2254], edge_index=[2, 482], y=[1], row_id=100418, dir='19', readout_mask=[237])
torch.Size([547, 2254])
torch.Size([2, 1147])
Data(x=[547, 2254], edge_index=[2, 1147], y=[1], row_id=100360, dir='15', readout_mask=[547])


Vectorizing train graphs:  96%|█████████▌| 23/24 [00:04<00:00,  6.07it/s]

torch.Size([321, 2254])
torch.Size([2, 663])
Data(x=[321, 2254], edge_index=[2, 663], y=[1], row_id=100535, dir='26', readout_mask=[321])
torch.Size([135, 2254])
torch.Size([2, 292])
Data(x=[135, 2254], edge_index=[2, 292], y=[1], row_id=100101, dir='2', readout_mask=[135])


Vectorizing train graphs: 100%|██████████| 24/24 [00:04<00:00,  5.54it/s]


Processing test set...


Vectorizing test graphs:  17%|█▋        | 1/6 [00:00<00:00,  6.29it/s]

torch.Size([296, 2254])
torch.Size([2, 573])
Data(x=[296, 2254], edge_index=[2, 573], y=[1], row_id=100490, dir='23', readout_mask=[296])
torch.Size([170, 2254])
torch.Size([2, 364])
Data(x=[170, 2254], edge_index=[2, 364], y=[1], row_id=100182, dir='7', readout_mask=[170])


Vectorizing test graphs:  50%|█████     | 3/6 [00:00<00:00,  5.63it/s]

torch.Size([238, 2254])
torch.Size([2, 353])
Data(x=[238, 2254], edge_index=[2, 353], y=[1], row_id=100213, dir='11', readout_mask=[238])
torch.Size([446, 2254])
torch.Size([2, 882])
Data(x=[446, 2254], edge_index=[2, 882], y=[1], row_id=100362, dir='16', readout_mask=[446])


Vectorizing test graphs:  83%|████████▎ | 5/6 [00:00<00:00,  5.52it/s]

torch.Size([261, 2254])
torch.Size([2, 555])
Data(x=[261, 2254], edge_index=[2, 555], y=[1], row_id=100185, dir='8', readout_mask=[261])
torch.Size([176, 2254])
torch.Size([2, 354])
Data(x=[176, 2254], edge_index=[2, 354], y=[1], row_id=100477, dir='22', readout_mask=[176])


Vectorizing test graphs: 100%|██████████| 6/6 [00:01<00:00,  5.77it/s]


In [10]:
train_graphs[0]

Data(x=[241, 2254], edge_index=[2, 489], y=[1], entities_list=[240], row_id=10056, dir='29', readout_mask=[241])

In [11]:
np.where(np.array(["sepsis" in  " ".join(train_graphs[i].entities_list).lower() for i in range(len(train_graphs))]))

(array([ 6049,  7658, 11966]),)

In [13]:
train_graphs[6049].row_id

48495

In [14]:
train_graphs[6049].entities_list

['GI Bleed',
 'Whipple disease',
 'Diabetes Mellitus (DM)',
 'Hypertension (HTN)',
 'Hyperlipidemia (HLD)',
 'Urosepsis',
 'Melena',
 'Obstructive stone (5 mm in right UPJ)',
 'Percutaneous Nephrostomy (PNT)',
 'Esophagogastroduodenoscopy (EGD)',
 'Capsule study',
 'Gastroesophageal (GE) junction blockage',
 'Hematocrit (HCT) drop',
 'Red Blood Cell (RBC) transfusions',
 'Colonoscopy',
 'Outpatient follow-up',
 'Productive cough',
 'Left leg knee pain',
 'Intramuscular injection (back of knee)',
 'Bilateral buttock rash',
 'Hemoglobin (Hb) 5.7',
 'Hematocrit (HCT) 19.1',
 'Melanotic stools',
 'Hypotension',
 'Intravenous (IV) Proton Pump Inhibitor (PPI)',
 'Packed Red Blood Cells (PRBC)',
 'NPO (nothing by mouth) for scope',
 'Urinalysis (UA) with infection',
 'Leukocytes in urine',
 'Blood in urine',
 'Red Blood Cells (RBC) in urine',
 'White Blood Cells (WBC) in urine',
 'Bacteria in urine',
 'Ceftriaxone (CTX)',
 'Insulin',
 'Dextrose',
 'Calcium gluconate',
 'Creatinine (Cr) 3.3',


In [15]:
note = json.load(open(f"batched_notes/{train_graphs[6049].dir}.json", "r"))
print(note["text"])

 
Name:  ___                  Unit No:   ___
 
Admission Date:  ___              Discharge Date:   ___
 
Date of Birth:  ___             Sex:   M
 
Service: MEDICINE
 
Allergies: 
Penicillins / flu vaccine
 
Attending: ___.
 
Chief Complaint:
GI Bleed
 
History of Present Illness:
 Mr. ___ is an ___ year old man with history of Whipple in 
___, now insulin dependent, HTN, HLD who is admitted to the 
MICU for GI bleed. 

Patient was recently admitted from ___ to ___ for 
urospesis complicated by melena. He was found to have a 5 mm 
obstructive stone in his right UPJ s/p PNT. On ___ patient's 
hct dropped from 30 to 20 requiring 6 RBC transfusions. Plan was 
to follow up as an outpatient. GI was consulted at the time and 
performed EGD showing no source for the bleed. Capsule study was 
then attempted but incomplete because it became blocked at the 
GE junction. Given that his HCT remained stable, GI recommended 
colonoscopy and capsule study as an outpatient. Hct on discharge 
29.6, HCT

In [16]:
session = get_session()
valid_row_ids = get_valid_note_row_ids(session)

KeyboardInterrupt: 

In [53]:
from torch_geometric.nn.conv import SAGEConv, GCNConv

class GNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout, heads=1, heads_dropout = .0):
       super(GNN, self).__init__()
       self.conv1 = SAGEConv(in_dim, hidden_dim, project=False, aggr="sum") #SAGEConv(in_dim, hidden_dim, project=False, aggr="sum")#GATConv(in_dim, hidden_dim,  add_self_loops=False, heads=heads, concat=True) #edge_dim=hidden_dim,
       self.conv2 = GATConv(hidden_dim*1, 1, add_self_loops=False, concat=False, heads=heads, dropout=heads_dropout) # edge_dim=hidden_dim,
       self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, x, edge_index, **kwargs):
        # x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = self.dropout(x)
        x = self.conv1(x, edge_index)#, edge_attr
        # x = F.relu(x)
        x = self.dropout(x)
        x, (edge_index_attn, alpha) = self.conv2(x, edge_index, return_attention_weights=True) #, edge_attr
        return x, (edge_index_attn, alpha)



In [54]:
 def evaluate(model, loader, device):
    """Evaluates the model and returns AUROC and AUPRC."""
    model = model.to(device)
    with torch.inference_mode():
       model.eval()
       pred_probas = []
       y_trues = []
       for batch_data in loader:
          batch_data = batch_data.to(device)
          logit, _ = model(batch_data.x, batch_data.edge_index, edge_attr=batch_data.edge_attr)
          logit = logit[batch_data.readout_mask]
          pred_proba = torch.sigmoid(logit)
          pred_probas.extend(pred_proba.cpu().tolist())
          y_trues.extend(batch_data.y.cpu().tolist())

    y_true = np.array(y_trues)
    y_pred_proba = np.array(pred_probas)
    auroc = metrics.roc_auc_score(y_true, y_pred_proba)
    auprc = metrics.average_precision_score(y_true, y_pred_proba)
    return auroc, auprc

In [55]:
N_SPLITS = 3
MAX_EVALS = 50
SEED = 50 #0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
final_generator = set_all_seeds(SEED)

In [56]:
space = {
    # "hidden_dim": hp.choice('hidden_dim', [8, 16, 32]),
    "dropout": hp.loguniform('dropout', np.log(1e-2), np.log(1e-1)),
    "heads_dropout": hp.loguniform('heads_dropout', np.log(1e-2), np.log(2e-1)),
    "lr": hp.loguniform('lr', np.log(1e-3), np.log(1e-2)),
    # "weight_decay": hp.loguniform('weight_decay', np.log(1e-6), np.log(1e-3)),
    "weight_decay": hp.choice('weight_decay', [0]),
    "heads": hp.quniform('heads', 1, 6, 1),
    # "batch_size": hp.quniform('batch_size', 16, 128, 16),
    "epochs": hp.quniform('epochs', 10, 200, 1) # Epochs are now tunable
}

In [73]:
y_train_full = [d.y.item() for d in train_graphs]
# best_hyperparams = {'batch_size': 256, 'dropout': 0.8, 'epochs': 600, 'heads': 2, 'hidden_dim': 16, 'lr': .01, 'weight_decay': 1e-5, "heads_dropout": .1}
# best_hyperparams = {'batch_size': 256, 'dropout': 0.0, 'epochs': 40, 'heads': 2, 'hidden_dim': 1, 'lr': .01, 'weight_decay': 1e-5, "heads_dropout": .1}
best_hyperparams = {'batch_size': 256, 'dropout': 0.3, 'epochs': 100, 'heads': 1, 'hidden_dim': 4, 'lr': .01, 'weight_decay': 0, "heads_dropout": 0.2}

In [74]:
def train(loss_fn, loader, best_hyperparams):
    model = GNN(
                in_dim=train_graphs[0].x.shape[1],
                hidden_dim=best_hyperparams["hidden_dim"],
                dropout=best_hyperparams["dropout"],
                heads=best_hyperparams["heads"],
                heads_dropout=best_hyperparams["heads_dropout"]
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=best_hyperparams["lr"], weight_decay=best_hyperparams["weight_decay"])
    for epoch in range(best_hyperparams['epochs']):
        epoch_loss = 0
        model.train()
        for batch_data in loader:
           optimizer.zero_grad()
           batch_data = batch_data.to(device)
           logit, _ = model(batch_data.x, batch_data.edge_index)
           logit = logit[batch_data.readout_mask]
           loss = loss_fn(logit.squeeze(), batch_data.y)
           loss.backward()
           optimizer.step()
           epoch_loss += loss.item()
        avg_loss = epoch_loss / len(final_train_loader)
        #print(f"Epoch: {epoch} Train AUROC: {auroc_train:.4f} | Test AUROC: {auroc_test:.4f}")
    return model

In [75]:
from operator import itemgetter

def tune_hyperparameters(train_graphs, y_train_full, space_params, k = 3, max_evals = 50, maximize_metric = True):

    def objective(params):
        if 'epochs' in params:
            params['epochs'] = int(params['epochs'])
        if 'heads' in params:
            params['heads'] = int(params['heads'])

        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)
        scores = []

        for train_index, val_index in skf.split(train_graphs, y_train_full):
            inner_train_graphs = itemgetter(*train_index)(train_graphs)
            inner_val_graphs = itemgetter(*val_index)(train_graphs)
            inner_train_graphs = [graph.to(device) for graph in inner_train_graphs]
            inner_val_graphs = [graph.to(device) for graph in inner_val_graphs]
            inner_train_loader = DataLoader(
                inner_train_graphs,
                batch_size=256,
                shuffle=True,
                worker_init_fn=seed_worker,
                generator=final_generator
            )
            inner_val_loader = DataLoader(
                inner_val_graphs,
                batch_size=256,
                shuffle=False
            )
            pos_weight = ceil(sum([data.y[0] == 0 for data in inner_train_graphs]) / sum([data.y[0] == 1 for data in inner_train_graphs]))
            loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))
            model = train(loss_fn, inner_train_loader, best_hyperparams)
            auroc_test, auprc_test = evaluate(model, inner_val_loader, device)
            #auroc_train, auprc_train = evaluate(model, final_train_loader, device)
            scores.append(auroc_test)

        average_score = np.mean(scores)
        
        loss = -average_score if maximize_metric else average_score

        return {'loss': loss, 'status': STATUS_OK}

    trials = Trials()
    best_params = fmin(
        fn=objective,
        space=space_params,
        algo=tpe.suggest,
        max_evals=max_evals,
        trials=trials,
        rstate=np.random.default_rng(42)
    )

    best_metric_score = -trials.best_trial['result']['loss'] if maximize_metric else trials.best_trial['result']['loss']

    return best_params, best_metric_score

In [76]:
# best_hyperparams, best_metric_score = tune_hyperparameters(train_graphs, y_train_full, space, max_evals=100, maximize_metric = True, k = 5)
# best_hyperparams, best_metric_score

In [77]:
## TODO add self loop for all besides the last node

In [78]:
# DataLoaders for final training
# train_graphs = [graph.to(device) for graph in train_graphs]
# test_graphs = [graph.to(device) for graph in test_graphs]
final_train_loader = DataLoader(
    train_graphs,
    batch_size=256,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=final_generator
)
# Test loader does not need to be shuffled
test_loader = DataLoader(
    test_graphs,
    batch_size=32,
    shuffle=False
)


# Instantiate the final model
final_model = GNN(
    in_dim=train_graphs[0].x.shape[1],
    hidden_dim=1,
    dropout=best_hyperparams["dropout"],
    heads=int(best_hyperparams["heads"]),
    heads_dropout=best_hyperparams["heads_dropout"]
).to(device)

In [79]:
optimizer = optim.Adam(
    final_model.parameters(),
    lr=best_hyperparams["lr"],
    weight_decay=best_hyperparams["weight_decay"]
)

# Recalculate pos_weight on the full training data
pos_weight = ceil(sum([data.y[0] == 0 for data in train_graphs]) / sum([data.y[0] == 1 for data in train_graphs]))
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

In [80]:
num_epochs_final = 10 #int(best_hyperparams['epochs'])
for epoch in range(num_epochs_final):
    epoch_loss = 0
    final_model.train()
    for batch_data in final_train_loader:
       optimizer.zero_grad()
       batch_data = batch_data.to(device)
       logit, _ = final_model(batch_data.x, batch_data.edge_index)
       logit = logit[batch_data.readout_mask]
       loss = loss_fn(logit.squeeze(), batch_data.y)
       loss.backward()
       optimizer.step()
       epoch_loss += loss.item()
    avg_loss = epoch_loss / len(final_train_loader)
    auroc_test, auprc_test = evaluate(final_model, test_loader, device)
    auroc_train, auprc_train = evaluate(final_model, final_train_loader, device)

    print(f"Epoch: {epoch} Train AUROC: {auroc_train:.4f} | Test AUROC: {auroc_test:.4f}")

Epoch: 0 Train AUROC: 0.7884 | Test AUROC: 0.7268
Epoch: 1 Train AUROC: 0.8686 | Test AUROC: 0.8130
Epoch: 2 Train AUROC: 0.8954 | Test AUROC: 0.8356
Epoch: 3 Train AUROC: 0.9204 | Test AUROC: 0.8304
Epoch: 4 Train AUROC: 0.9398 | Test AUROC: 0.8125
Epoch: 5 Train AUROC: 0.9531 | Test AUROC: 0.8023
Epoch: 6 Train AUROC: 0.9633 | Test AUROC: 0.7990
Epoch: 7 Train AUROC: 0.9717 | Test AUROC: 0.8012
Epoch: 8 Train AUROC: 0.9755 | Test AUROC: 0.8031
Epoch: 9 Train AUROC: 0.9791 | Test AUROC: 0.7984


In [81]:
train_labels = np.array(list(map(lambda x: x.y[0].item(), train_graphs)))
np.where(train_labels == 1)

(array([   39,    72,   112,   152,   227,   229,   233,   256,   266,
          292,   347,   397,   406,   441,   470,   473,   494,   512,
          517,   573,   593,   606,   646,   657,   658,   670,   757,
          759,   782,   810,   849,   853,   858,   920,   943,   953,
          980,  1032,  1044,  1051,  1081,  1091,  1116,  1151,  1166,
         1167,  1212,  1229,  1245,  1249,  1270,  1275,  1292,  1429,
         1486,  1520,  1558,  1581,  1647,  1673,  1677,  1719,  1723,
         1753,  1769,  1784,  1815,  1932,  1952,  1999,  2029,  2059,
         2063,  2067,  2074,  2089,  2111,  2117,  2118,  2120,  2136,
         2141,  2150,  2201,  2251,  2272,  2288,  2292,  2317,  2335,
         2336,  2345,  2369,  2423,  2468,  2481,  2551,  2602,  2641,
         2669,  2730,  2741,  2774,  2800,  2804,  2857,  2874,  2905,
         2962,  2977,  2988,  3032,  3066,  3067,  3086,  3093,  3100,
         3117,  3176,  3183,  3186,  3227,  3233,  3248,  3299,  3300,
      

In [None]:
test_labels = np.array(list(map(lambda x: x.y[0].item(), test_graphs)))
np.where(test_labels == 1)

In [None]:
def get_logits_and_attn_heads(batch_data):
    with torch.inference_mode():
        final_model.eval()
        logits, (edge_idx, attn_out) = final_model(batch_data.x,batch_data.edge_index)
        pred_proba = torch.sigmoid(logits[-1])
    return logits[-1], attn_out

def get_influence_scores(batch_data):
    logits_control, attn_out = get_logits_and_attn_heads(batch_data)
    mask = batch_data.edge_index[1] == (batch_data.x.shape[0]-1)
    assert attn_out[mask].shape[0] == len(batch_data.entities_list), "Doesnt match entities"
    assert (torch.sort(batch_data.edge_index[0][mask])[0] != batch_data.edge_index[0][mask]).sum() == 0, "Is not sorted"
    #print(torch.sigmoid(logits_control), batch_data.y)
    
    w_l = final_model.state_dict()["conv1.lin_l.weight"]
    w_r = final_model.state_dict()["conv1.lin_r.weight"]
    rlt = final_model.state_dict()["conv2.lin.weight"]
    bl = final_model.state_dict()['conv1.lin_l.bias']
    b2 = final_model.state_dict()['conv2.bias']
        
    left_aggr = torch.zeros_like(batch_data.x[:-1])
    lifted_nodes = torch.index_select(batch_data.x[:-1], 0, batch_data.edge_index[0, ~mask])
    left_aggr.scatter_reduce_(0,batch_data.edge_index[1, ~mask].repeat(batch_data.x.shape[-1], 1).t(), lifted_nodes, reduce="sum")
    
    transformed_left = left_aggr @ w_l.t()
    transformed_right = batch_data.x[:-1] @ w_r.t()
    sage_out = transformed_left + transformed_right + bl
    sage_out_transformed = sage_out @ rlt.t()
    sage_out_weighted_transformed = sage_out_transformed * attn_out[mask]
    logits = (sage_out_weighted_transformed.mean(-1)).sum() + b2
    assert torch.allclose(logits, logits_control, rtol=1e-4, atol=1e-9), f"Influence scores are wrong {logits.item()} - {logits_control.item()}"
    return sage_out_weighted_transformed.mean(-1).cpu()

In [None]:
get_influence_scores(test_graphs[100])

In [None]:
import matplotlib
from operator import itemgetter
import matplotlib.pyplot as plt

def plot_influence_scores(batch_data):
    ## Overall influence
    font = {'family' : 'arial',
            'weight' : 'bold',
            'size'   : 14}
    influence_scores = get_influence_scores(batch_data)
    matplotlib.rc('font', **font)
    plt.figure(figsize=(25, 4))
    sorted_idx = influence_scores.abs().argsort(descending=True)
    plt.bar(itemgetter(*sorted_idx)(batch_data.entities_list), influence_scores[sorted_idx])
    plt.xticks(rotation=45, ha='right')
    plt.title("Overall influence")
    plt.xlabel("Entities")
    plt.ylabel("Influence weight on prediction")
    plt.grid(which="both")
    plt.show()
    
plot_influence_scores(train_graphs[-1])

In [None]:
def get_global_influence_scores():
    summed_aggregated_influence_scores = torch.zeros(train_graphs[0].x.shape[-1])
    graphs = train_graphs
    for batch_data in graphs:
        influence_scores = get_influence_scores(batch_data)
        aggregated_influence_scores = torch.zeros_like(batch_data.x).cpu()
        non_zero_idx = torch.nonzero(batch_data.x).cpu()
        ## TODO Think about some kind of normalization here ? 
        aggregated_influence_scores[non_zero_idx[:, 0], non_zero_idx[:, 1]] = influence_scores[non_zero_idx[:, 0]]
        summed_aggregated_influence_scores += aggregated_influence_scores.sum(0)
    return summed_aggregated_influence_scores / len(graphs)

global_influence_scores = get_global_influence_scores()

In [None]:
top_k = 20
global_influence_scores_sort_idx = global_influence_scores.argsort(descending=True)[:top_k]
plt.figure(figsize=(25, 4))
plt.bar(vectorizer.get_feature_names_out()[global_influence_scores_sort_idx], global_influence_scores[global_influence_scores_sort_idx])
plt.xticks(rotation=45, ha='right')
plt.title("Global influence of tokens")
plt.xlabel("Tokens")
plt.ylabel("Global influence weight")
plt.grid(which="both")
plt.show()

In [None]:
## Check how important edges are overall - is solely readout enough and dont we even need the links between entities or are they important? -> If not KG procedure is not necessary

In [None]:
sage_out_weighted_transformed.mean(-1).cpu().sum()

In [None]:

itemgetter(*attn_out.mean(-1)[mask].argsort(descending=True))(batch_data.entities_list), attn_out.mean(-1)[mask].sort(descending=True)

In [None]:
#overall_influence_copy = torch.clone(overall_influence)
left_aggr_copy = torch.clone(left_aggr)

In [None]:
import torch

# Let's assume 'final_model' and 'batch_data' are defined as in your example.
# batch_data.x is assumed to have shape [49, F]

# --- Step 1: Get reference output from the full model ---
with torch.inference_mode():
    final_model.eval()
    # model_logits will have shape [49, 1]
    model_logits, (edge_index_attn, alpha) = final_model(batch_data.x, batch_data.edge_index)

# --- Step 2: Manually reproduce the forward pass on the FULL graph ---
with torch.inference_mode():
    final_model.eval()
    # Weight extraction is correct
    w_l = final_model.state_dict()["conv1.lin_l.weight"]
    w_r = final_model.state_dict()["conv1.lin_r.weight"]
    rlt = final_model.state_dict()["conv2.lin.weight"]
    bl = final_model.state_dict()['conv1.lin_l.bias']
    b2 = final_model.state_dict()['conv2.bias']

    all_nodes = batch_data.x
    num_nodes = all_nodes.shape[0]  # This will now be 49

    left_aggr = torch.zeros_like(all_nodes)
    
    lifted_nodes = torch.index_select(all_nodes, 0, batch_data.edge_index[0, ~mask])
    
    index_for_sage = batch_data.edge_index[1, ~mask].repeat(batch_data.x.shape[-1], 1).t()
    left_aggr.scatter_reduce_(0, index_for_sage, lifted_nodes, reduce="sum", include_self=False)

    overall_influence = (all_nodes @ w_r.t()) + (left_aggr @ w_l.t()  + bl)
    x_transformed = overall_influence @ rlt.t()

    source_node_indices = edge_index_attn[0]
    target_node_indices = edge_index_attn[1]
    
    source_features_transformed = x_transformed[source_node_indices]
    weighted_messages = source_features_transformed * alpha
    aggregated_output = torch.zeros(num_nodes, rlt.shape[0], device=all_nodes.device)
    index_for_gat = target_node_indices.unsqueeze(1).expand_as(weighted_messages)
    aggregated_output.scatter_add_(0, index_for_gat, weighted_messages)
    
    manual_logits = aggregated_output.mean(dim=1, keepdim=True) + b2

    print("Manual logits match model logits:", torch.allclose(manual_logits, model_logits))
    print("Max Difference:", (manual_logits - model_logits).abs().max()) 

In [None]:
model_logits

In [None]:
final_model = final_model.to('cuda')

In [None]:
import torch

# --- 1. SET UP THE DEVICE ---
# This is the most important step. Define the device to use.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# --- 2. MOVE MODEL AND DATA TO THE DEVICE ---
# Let's assume 'final_model' and 'batch_data' are loaded and are on the CPU initially.

final_model.to(device)
final_model.eval() # Set to evaluation mode

# This is the crucial fix: move the entire data batch to the selected device.
batch_data = batch_data.to(device)


# Now, all tensors involved in the calculation will be on the same device.
with torch.inference_mode():
    # --- Step 1: Get reference output from the full model ---
    # This will now run entirely on the specified `device`.
    model_logits, (edge_index_attn, alpha) = final_model(batch_data.x, batch_data.edge_index)

    # --- Step 2: Extract weights (they are already on the correct device) ---
    state_dict = final_model.state_dict()
    w_l = state_dict["conv1.lin_l.weight"]
    w_r = state_dict["conv1.lin_r.weight"]
    rlt = state_dict["conv2.lin.weight"]
    bl = state_dict['conv1.lin_l.bias']
    b2 = state_dict['conv2.bias']

    # --- Step 3: Manual SAGEConv Layer ---
    # `all_nodes` is now on the correct device because `batch_data` was moved.
    all_nodes = batch_data.x 
    num_nodes = all_nodes.shape[0]

    # `zeros_like` will create a tensor on the same device as `all_nodes`.
    left_aggr = torch.zeros_like(all_nodes)
    
    lifted_nodes = torch.index_select(all_nodes, 0, batch_data.edge_index[0, ~mask])
    
    index_for_sage = batch_data.edge_index[1, ~mask].unsqueeze(1).expand_as(lifted_nodes)
    left_aggr.scatter_reduce_(0, index_for_sage, lifted_nodes, reduce="sum", include_self=False)

    # This matrix multiplication will now work, as all tensors are on the same device.
    overall_influence = (all_nodes @ w_r.t()) + (left_aggr @ w_l.t()) + bl
    
    # --- Step 4: Direct GATConv Calculation for the Readout Node ---
    x_transformed = overall_influence @ rlt.t()

    readout_node_idx = num_nodes - 1
    is_edge_to_readout = (edge_index_attn[1] == readout_node_idx)
    
    source_nodes_for_readout = edge_index_attn[0][is_edge_to_readout]
    attention_for_readout = alpha[is_edge_to_readout]

    source_features = x_transformed[source_nodes_for_readout]
    
    weighted_messages = source_features * attention_for_readout
    aggregated_message = weighted_messages.sum(dim=0)

    manual_readout_logit = aggregated_message.mean() + b2

    # --- Verification ---
    final_model_logit = model_logits[-1]
    
    print(f"Manual Readout Logit: {manual_readout_logit.item()}")
    print(f"Model Readout Logit:    {final_model_logit.item()}")
    print("Logits Match:", torch.allclose(manual_readout_logit, final_model_logit))

In [None]:
torch.allclose(x_transformed[:-1].mean(-1), sage_out_transformed)

In [None]:
torch.allclose(left_aggr @ w_l.t(), left_aggr * w_l.squeeze())

In [None]:
manual_logits[-1], model_logits[-1]

In [None]:
weighted_messages.mean(-1)[mask].sum(), model_logits

In [None]:
with torch.inference_mode():
    final_model.eval()
    logits = final_model(batch_data.x, batch_data.edge_index)
logits[0]

In [None]:
((overall_influence.scatter_reduce(0, batch_data.edge_index[1, ~mask].cpu(), lifted_neighbor_influence, reduce="sum").cpu() * rlt[:].mean(0).cpu()) * attn_out[mask, :].mean(-1).cpu()).sum() 

In [None]:
batch_data.edge_index

In [None]:
attn_out[mask, :].cpu().shape

In [None]:
batch_data.x[:-1].nonzero() @ w_r.squeeze() * rlt[:].mean() * attn_out[mask, :].mean(-1)

In [None]:
batch_data.x[:-1].nonzero()

In [None]:

# import matplotlib
# import matplotlib.pyplot as plt
# ## Overall  Self influence
# font = {'family' : 'arial',
#         'weight' : 'bold',
#         'size'   : 14}
# 
# matplotlib.rc('font', **font)
# plt.figure(figsize=(25, 4))
attn_identity_matrix = torch.zeros_like(batch_data.x[:-1].cpu()).cpu()
indices = batch_data.x[:-1].nonzero().cpu()
attn_identity_matrix[indices[:, 0], indices[:, 1]] = 1
attn_identity_matrix = attn_identity_matrix.to(device)
## TODO: Need to get global influence from here

self_influence = (attn_identity_matrix @ w_r.squeeze() * rlt[:].mean() * attn_out[mask, :].mean(-1)).cpu()

# plt.bar(batch_data.entities_list, self_influence)
# plt.xticks(rotation=45, ha='right')
# plt.title("Overall Self influence")
# plt.xlabel("Entities")
# plt.ylabel("Influence weight on prediction")
# plt.grid(which="both")
# plt.show()

In [None]:
batch_data.x[:-1].nonzero().cpu()

In [None]:
attn_identity_matrix.shape

In [None]:
attn_identity_matrix

In [None]:
non_zero_idx = batch_data.x.nonzero() #w_r.squeeze() * rlt[:].mean() * attn_out[mask, :].mean(-1)
torch.index_select(attn_out[mask], 0, non_zero_idx[:, 0])

In [None]:
stacked_gcn_influence = torch.cat((final_model.state_dict()["conv1.lin_l.weight"], final_model.state_dict()["conv1.lin_r.weight"]), dim = 0).transpose(-1, 0)
stacked_gcn_influence.mean(1), stacked_gcn_influence.std(1)

In [None]:
(vectorizer.get_feature_names_out().shape, final_model.state_dict()["conv1.lin_r.weight"].shape)

In [None]:
final_model.state_dict()["conv1.lin_r.weight"].squeeze()

In [None]:
import matplotlib.pyplot as plt
top_n = 50
fig = plt.figure(figsize=(20, 5))
sort_idx = final_model.state_dict()["conv1.lin_r.weight"].squeeze().abs().argsort(descending=True).cpu()
plt.bar(vectorizer.get_feature_names_out()[sort_idx][:top_n], final_model.state_dict()["conv1.lin_r.weight"].squeeze()[sort_idx][:top_n].cpu())


In [None]:
## Least important
vectorizer.get_feature_names_out()[sort_idx][-top_n:]

In [None]:
final_model.state_dict()

In [None]:
sepsis_attn_outs, control_attn_outs = [], []
sepsis_entities, control_entities = [], []
for i, batch_data in enumerate(train_graphs):
    batch_data = batch_data.to(device)
    with torch.inference_mode():
        final_model.eval()
        logits, (edge_idx, attn_out) = final_model(batch_data.x,batch_data.edge_index)
        # logits = torch.sigmoid(logits[-1])
    mask = edge_idx[1] == (batch_data.x.shape[0]-1)
    assert attn_out[mask].shape[0] == len(batch_data.entities_list), "Doesnt match entities"
    assert (torch.sort(edge_idx[0][mask])[0] != edge_idx[0][mask]).sum() == 0, "Is not sorted"
    if batch_data.y[0] == 1:
        sepsis_entities.extend(batch_data.entities_list)
        print(batch_data.entities_list)
        print(batch_data.x.nonzero())
        raise Exception("")
        sepsis_attn_outs.append(attn_out[mask])
    if batch_data.y[0] == 0:
        control_entities.extend(batch_data.entities_list)
        control_attn_outs.append(attn_out[mask])
    # if i == 2:
    #     break
sepsis_attn_outs = torch.cat(sepsis_attn_outs, dim  = 0)
control_attn_outs = torch.cat(control_attn_outs, dim  = 0)

In [None]:
vectorizer.transform(["IDDM with periperal neuropathy"]).toarray().nonzero()

In [None]:
from operator import itemgetter
print(itemgetter(*[2482, 3900, 5143, 5733, 8203])(vectorizer.get_feature_names_out()))


In [None]:
import pandas as pd
sepsis_attn_df = pd.DataFrame()
sepsis_attn_df["entities"] = sepsis_entities
sepsis_attn_df["attn_out_0"] = sepsis_attn_outs[:, 0].cpu()
sepsis_attn_df["attn_out_1"] = sepsis_attn_outs[:, 1].cpu()
sepsis_attn_df["attn_out_mean"] = sepsis_attn_outs.mean(-1).cpu()
sepsis_attn_df.sort_values(by=["attn_out_mean"], ascending=False, inplace=True)
sepsis_attn_df

In [None]:
import pandas as pd
control_attn_df = pd.DataFrame()
control_attn_df["entities"] = control_entities
control_attn_df["attn_out_0"] = control_attn_outs[:, 0].cpu()
control_attn_df["attn_out_1"] = control_attn_outs[:, 1].cpu()
control_attn_df["attn_out_mean"] = control_attn_outs.mean(-1).cpu()
control_attn_df.sort_values(by=["attn_out_mean"], ascending=False, inplace=True)
control_attn_df

In [None]:
pd.concat((control_attn_df, sepsis_attn_df) , axis = 0)

In [None]:
print(sepsis_attn_df.iloc[:20, :])

In [None]:
data = train_graphs[80].to(device)
with torch.inference_mode():
    final_model.eval()
    logits, (edge_idx, attn_out) = final_model(data.x,data.edge_index)
    logits = torch.sigmoid(logits[-1])
attn_out.shape

In [None]:
edge_idx

In [None]:
import torch
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from typing import List
import math

def plot_graph(
    edge_index: torch.Tensor,
    node_labels: List[str],
    edge_weights: torch.Tensor,
    node_size: int = 2000,
    node_color: str = 'skyblue',
    central_node_color: str = 'lightcoral',
    font_size: int = 12,
    font_color: str = 'black',
    edge_width: float = 2.0,
    figure_size: tuple = (12, 12)
):
    """
    Plots a graph with a custom layout where the last node is central.
    Node distance is based on the inverse of edge weight to the central node.

    Args:
        edge_index (torch.Tensor): A tensor of shape (2, E) representing the edges in COO format.
        node_labels (List[str]): A list of N strings for node labels.
        edge_weights (torch.Tensor): A tensor of shape (1, E) or (E,) with edge weights between 0 and 1.
        node_size (int): The size of the nodes.
        node_color (str): The color of the non-central nodes.
        central_node_color (str): The color of the central node.
        font_size (int): The font size of the node labels.
        font_color (str): The color of the node labels.
        edge_width (float): The width of the edges.
        figure_size (tuple): The size of the plot figure.
    """
    # --- Graph and Data Preparation ---
    graph = nx.Graph()
    num_nodes = len(node_labels)
    graph.add_nodes_from(range(num_nodes))
    labels = {i: label for i, label in enumerate(node_labels)}
    edge_list = edge_index.t().tolist()
    weights_list = edge_weights.squeeze().tolist()
    
    # Create a lookup dictionary for edge weights for efficient access
    edge_to_weight = {tuple(sorted(edge)): weight for edge, weight in zip(edge_list, weights_list)}

    # --- Custom Layout Calculation ---
    pos = {}
    central_node_idx = num_nodes - 1
    other_node_indices = [i for i in range(num_nodes) if i != central_node_idx]
    
    # Place central node at the origin
    pos[central_node_idx] = (0, 0)
    
    # Arrange other nodes in a circle around the central node
    angle_step = 2 * math.pi / len(other_node_indices)
    max_radius = 1.0 # Base radius for layout scaling
    
    for i, node_idx in enumerate(other_node_indices):
        weight = edge_to_weight.get(tuple(sorted((node_idx, central_node_idx))))
        
        # Distance is inversely proportional to weight. Add a small constant to avoid zero distance.
        # If no edge exists, place it at the maximum distance.
        radius = max_radius * (1.1 - weight) if weight is not None else max_radius * 1.2
        
        angle = i * angle_step
        pos[node_idx] = (radius * math.cos(angle), radius * math.sin(angle))

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=figure_size)
    custom_cmap = LinearSegmentedColormap.from_list('blue_purple_red', ['blue', 'purple', 'red'])

    # Assign colors to nodes
    node_colors = [central_node_color if i == central_node_idx else node_color for i in range(num_nodes)]

    # Draw nodes and labels
    nx.draw_networkx_nodes(graph, pos, node_size=node_size, node_color=node_colors, ax=ax)
    nx.draw_networkx_labels(graph, pos, labels, font_size=font_size, font_color=font_color, ax=ax)
    
    # Draw edges with heatmap
    edges = nx.draw_networkx_edges(
        graph,
        pos,
        edgelist=edge_list,
        edge_color=weights_list,
        edge_cmap=custom_cmap,
        width=edge_width,
        ax=ax
    )

    # Add a colorbar for edge weights
    sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, shrink=0.8)
    cbar.set_label('Edge Weight', rotation=270, labelpad=15)
    
    ax.set_title("Graph with Central Node Layout")
    ax.axis('equal') # Ensure the circular layout is not distorted
    ax.margins(0.1)
    plt.show()


data = test_graphs[80].to(device)
with torch.inference_mode():
    final_model.eval()
    logits, (edge_idx, attn_out) = final_model(data.x,data.edge_index)
    logits = torch.sigmoid(logits[-1])
    print(logits)
# Define the graph data
edge_index_tensor = data.edge_index
node_names = [*data.entities_list, "Read-Out"]
edge_weights_tensor = attn_out.mean(-1)

# Plot the graph
plot_graph(
    edge_index=edge_index_tensor,
    node_labels=node_names,
    edge_weights=edge_weights_tensor
)

In [None]:
readout_mask = edge_index_tensor[1] == (data.x.shape[0]-1)

data.entities_list[edge_index_tensor[0, readout_mask][2]]

In [None]:
final_model.state_dict()

In [None]:
edge_index_tensor[:, readout_mask]

In [None]:
edge_weights_tensor[readout_mask]

In [None]:
sorted_weights_idx = torch.argsort(edge_weights_tensor[readout_mask], descending=True)

In [None]:
from operator import itemgetter
print(itemgetter(*sorted_weights_idx)(data.entities_list))

In [None]:
print(itemgetter(*sorted_weights_idx)(edge_weights_tensor[readout_mask]))

In [None]:
edge_index_tensor[:, readout_mask][0, sorted_weights_idx]

In [None]:
data.x[7][data.x[7].nonzero().squeeze()]

In [None]:
final_model.state_dict()

In [None]:
tf_idf_weights = final_model.state_dict()['conv1.lin_r.weight']
tf_idf_weights_l = final_model.state_dict()['conv1.lin_l.weight']
#[data.x[7].nonzero().squeeze().cpu()]

In [None]:
sorted_w_l_idx = tf_idf_weights_l.abs().argsort(descending = True).cpu()
vectorizer.get_feature_names_out()[sorted_w_l_idx][0, :50]

In [None]:

data.x @ tf_idf_weights_l.squeeze()

In [None]:
data.x @ tf_idf_weights.squeeze()

In [None]:
[data.x[7].nonzero().squeeze().cpu()]

In [None]:
tf_idf_weights.squeeze()[data.x[7].nonzero().squeeze().cpu()]

In [None]:
raise Exception("")

In [None]:
vectorizer.get_feature_names_out()[data.x[7].nonzero().squeeze().cpu()]

In [None]:
read_out_mask = edge_idx[1] == data.x.shape[0]-1
attn_out[read_out_mask].mean(1)

In [None]:
final_model.state_dict()

In [None]:
## TODO Edge attr with edge attr TfIDF vectorizer

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer

device = torch.device("cpu")
final_model.eval()
final_model = final_model.to(device)
# Assume 'data' is your graph data object (e.g., from a PyG dataset)
data = test_graphs[0]

explainer = Explainer(
    model=final_model,
    algorithm=GNNExplainer(epochs=10),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='raw',
    ),
)
node_index = -1 # which node index to explain
explanation = explainer(data.x, data.edge_index, index=node_index)

In [None]:
explanation.visualize_feature_importance(top_k=5)