In [1]:
import pandas as pd
# show all columns
pd.set_option('display.max_columns', None)
import numpy as np
import sys
import os
import wandb
# mute wandb outputs
os.environ["WANDB_SILENT"] = "true"
import torch
from sklearn.metrics import f1_score

sys.path.append(os.path.abspath("/home/lideyi/AKI_GNN/notebooks/utils"))
from metrics import performance_per_class, visualize_embeddings
import copy

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cuda


In [3]:
# Log in to Weights & Biases.
# SECURITY: never hard-code the API key in the notebook. Without an explicit key,
# wandb.login() uses the credentials it finds in the environment (the
# WANDB_API_KEY variable) or those stored by a previous `wandb login`.
# The key that used to be embedded in this cell has been exposed and should be
# revoked and replaced.
wandb.login()

True

# Read Dataset

In [4]:
# Load the pilot table. Besides the feature columns it carries the AKI_TARGET
# label and the TRAIN_SET / VAL_SET / TEST_SET split indicators used below.
pilot_csv_path = '/blue/yonghui.wu/lideyi/AKI_GNN/raw_data/norm_df_pilot.csv'
onset_df_pilot = pd.read_csv(pilot_csv_path)

# Build PyG Data Object

In [5]:
from torch_geometric.data import Data
from sklearn.neighbors import kneighbors_graph

In [6]:
# Columns that hold the label or the split indicators rather than model inputs.
non_feature_columns = ['AKI_TARGET', 'TRAIN_SET', 'VAL_SET', 'TEST_SET']
feature_columns = [col for col in onset_df_pilot.columns if col not in non_feature_columns]

# Every remaining column is a node feature.
node_features = onset_df_pilot[feature_columns].copy(deep=True).values

# Label and split indicators, each extracted from a deep copy of its column.
node_labels, train_mask, val_mask, test_mask = (
    onset_df_pilot[column].copy(deep=True).values
    for column in ('AKI_TARGET', 'TRAIN_SET', 'VAL_SET', 'TEST_SET')
)

In [None]:
# Build a k-NN graph (k=5) over the node features using cosine distance.
# kneighbors_graph returns a SciPy sparse matrix that is generally NOT symmetric
# (i being one of j's neighbours does not imply the reverse). The matrix is kept
# sparse on purpose: densifying it with .toarray() allocates an N x N array only
# to read off the coordinates of its non-zero entries.
k = 5
A = kneighbors_graph(node_features, k, mode='connectivity', metric='cosine', include_self=False, n_jobs=-1)
# Make the adjacency matrix symmetric so every edge exists in both directions.
A = A + A.T
# Mutual neighbours sum to 2 in the step above; binarise so each stored entry is 1.
A = (A > 0).astype(int)
# The (row, col) coordinates of the non-zero entries form the (2, num_edges)
# edge list; PyG expects it as a contiguous int64 tensor.
src_nodes, dst_nodes = A.nonzero()
edge_index = torch.from_numpy(np.vstack((src_nodes, dst_nodes))).to(torch.long).contiguous()

In [None]:
# Gather everything the graph object stores, converting the NumPy arrays to
# tensors with explicit dtypes (float features, integer labels, boolean masks).
graph_fields = {
    'x': torch.tensor(node_features, dtype=torch.float),
    'edge_index': edge_index,
    'y': torch.tensor(node_labels, dtype=torch.long),
    'num_classes': len(np.unique(node_labels)),
    'train_mask': torch.tensor(train_mask, dtype=torch.bool),
    'val_mask': torch.tensor(val_mask, dtype=torch.bool),
    'test_mask': torch.tensor(test_mask, dtype=torch.bool),
}
# Order the edges by destination node (edge_index[1]); per the original note,
# some models' aggregation requires this ordering.
data = Data(**graph_fields).sort(sort_by_row=False)

In [None]:
# Summarise the graph: size, label coverage and structural sanity checks.
graph_report = (
    f'Number of features: {data.num_features}',
    f'Number of classes: {data.num_classes}',
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Number of training nodes: {data.train_mask.sum()}',
    f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    f'Is sorted by destination nodes: {data.is_sorted(sort_by_row = False)}',
)
for report_line in graph_report:
    print(report_line)

# W&B Training Functions

In [None]:
from wandb.sdk.wandb_config import Config
# turn the data into loader
from torch_geometric.loader import ClusterData, ClusterLoader

In [None]:
def evaluate_GNN(data: Data, build_GNN_func: callable, wandb_project_name: str, parameters: dict) -> pd.DataFrame:
    """Run a W&B sweep for one GNN family and return its test-set performance.

    Args:
        data: Graph to train and evaluate on.
        build_GNN_func: Factory that turns a config object into a model.
        wandb_project_name: W&B project in which the sweep is registered.
        parameters: Sweep search space, in W&B ``parameters`` format.

    Returns:
        Whatever ``test_best_GNN`` returns for this sweep.
    """
    def run_single_trial() -> None:
        # The sweep agent supplies the trial's hyperparameters through
        # wandb.config, so no explicit config is passed.
        train_GNN_main(data=data, build_GNN_func=build_GNN_func, config=None)

    sweep_id = wandb.sweep(build_sweep_config(parameters), project=wandb_project_name)
    wandb.agent(sweep_id, run_single_trial)
    return test_best_GNN(data, sweep_id, build_GNN_func)

In [None]:
def build_sweep_config(parameters: dict) -> dict:
    """Wrap a search space in a W&B grid-sweep config that maximises ``val_F1``.

    Args:
        parameters: Search space; stored under the ``'parameters'`` key as-is
            (the same object, not a copy).

    Returns:
        Dict with the keys ``'method'``, ``'metric'`` and ``'parameters'``.
    """
    return dict(
        method='grid',
        metric=dict(name='val_F1', goal='maximize'),
        parameters=parameters,
    )

In [None]:
def train_GNN_main(data: Data, build_GNN_func: callable, config = None) -> None:
    """Execute one W&B run: build model, optimizer and loader, then train.

    Args:
        data: Graph to train on.
        build_GNN_func: Factory that turns the run's config into a model.
        config: Optional config forwarded to ``wandb.init``.
    """
    with wandb.init(config=config):
        # Read hyperparameters from wandb.config rather than from the argument,
        # so values injected by a sweep agent are picked up.
        run_config = wandb.config
        model = build_GNN_func(run_config)
        train_GNN(
            model,
            run_config.epochs,
            build_optimizer(model, run_config.optimizer, run_config.lr),
            build_dataloader(data, run_config.graph_num_parts, run_config.batch_size),
        )

In [None]:
def build_optimizer(model: torch.nn.Module, optimizer: str, lr: float) -> torch.optim.Optimizer:
    """Create the optimizer selected by name for ``model``'s parameters.

    Args:
        model: Module whose parameters will be optimised.
        optimizer: ``"sgd"`` (SGD with momentum 0.9) or ``"adam"``.
        lr: Learning rate.

    Returns:
        The configured optimizer.

    Raises:
        ValueError: If ``optimizer`` is not a supported name. Previously the
            name string itself was returned and the failure only surfaced later,
            inside the training loop.
    """
    if optimizer == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    if optimizer == "adam":
        return torch.optim.Adam(model.parameters(), lr=lr)
    raise ValueError(f"Unsupported optimizer {optimizer!r}. Choose from 'sgd' or 'adam'.")

In [None]:
def build_dataloader(data: Data, graph_num_parts: int, batch_size: int) -> ClusterLoader:
    """Partition the graph and wrap the partitions in a shuffling mini-batch loader.

    Args:
        data: Full graph.
        graph_num_parts: Number of partitions requested from ``ClusterData``.
        batch_size: Batch size forwarded to ``ClusterLoader``.

    Returns:
        A ``ClusterLoader`` over the partitions, with shuffling enabled.
    """
    # Reset torch's global seed so every call starts from the same RNG state.
    torch.manual_seed(888)
    partitions = ClusterData(data, num_parts=graph_num_parts, log=False)
    return ClusterLoader(partitions, batch_size=batch_size, shuffle=True)

In [None]:
def train_GNN(model: torch.nn.Module, epochs: int, optimizer: torch.optim.Optimizer, data_loader: ClusterLoader,
              log: bool = True) -> None:
    """Train for a fixed number of epochs, scoring train and val F1 after each.

    Args:
        model: Network to train (updated in place).
        epochs: Number of passes over ``data_loader``.
        optimizer: Optimizer bound to ``model``'s parameters.
        data_loader: Mini-batch loader used for both training and scoring.
        log: When true, send the per-epoch metrics to W&B.
    """
    for _epoch in range(epochs):
        epoch_loss = train_epoch(model, optimizer, data_loader)
        f1_train, f1_val = val_epoch(model, data_loader)
        if not log:
            continue
        wandb.log({"train_loss": epoch_loss, "train_F1": f1_train, "val_F1": f1_val})

In [None]:
def train_epoch(model: torch.nn.Module, optimizer: torch.optim.Optimizer, data_loader: ClusterLoader) -> float:
    """Run one optimisation pass over all mini-batches.

    The loss of each mini-batch is the cross-entropy over that batch's training
    nodes only.

    Args:
        model: Network to train (updated in place).
        optimizer: Optimizer bound to ``model``'s parameters.
        data_loader: Loader yielding sub-graphs with ``x``, ``edge_index``,
            ``y`` and ``train_mask``.

    Returns:
        Mean loss over the mini-batches that contained at least one training
        node, or NaN if no mini-batch did.
    """
    model.train()
    total_loss = 0.0
    n_batches_used = 0
    for sub_data in data_loader:
        sub_data = sub_data.to(device)
        train_mask = sub_data.train_mask
        # Cross-entropy over an empty selection evaluates to NaN, which would
        # turn the whole epoch average into NaN; such batches carry no training
        # signal, so skip them.
        if not bool(train_mask.any()):
            continue
        optimizer.zero_grad()
        out = model(sub_data.x, sub_data.edge_index)
        loss = torch.nn.functional.cross_entropy(out[train_mask], sub_data.y[train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches_used += 1
    if n_batches_used == 0:
        return float("nan")
    return total_loss / n_batches_used

In [None]:
def val_epoch(model: torch.nn.Module, data_loader: ClusterLoader) -> list:
    """Compute macro-F1 on the training and validation nodes.

    Args:
        model: Network to score; switched to evaluation mode.
        data_loader: Loader yielding sub-graphs with ``x``, ``edge_index``,
            ``y``, ``train_mask`` and ``val_mask``.

    Returns:
        Two-element list ``[train_F1, val_F1]``. (The previous ``-> float``
        annotation did not match this.)
    """
    model.eval()  # Set the model to evaluation mode.

    splits = ("train", "val")
    # Per-split predictions and ground truths, accumulated across mini-batches.
    y_true_masks = {key: [] for key in splits}
    y_pred_masks = {key: [] for key in splits}

    with torch.no_grad():  # No gradients are needed for scoring.
        for sub_data in data_loader:
            sub_data = sub_data.to(device)
            out = model(sub_data.x, sub_data.edge_index)
            y_pred = out.argmax(dim=1)  # Class with the highest score.

            for key, mask in (("train", sub_data.train_mask), ("val", sub_data.val_mask)):
                y_pred_masks[key].append(y_pred[mask].cpu())
                y_true_masks[key].append(sub_data.y[mask].cpu())

    # Score each split once, over the nodes of all mini-batches together.
    F1_scores = []
    for key in splits:
        y_true_combined = torch.cat(y_true_masks[key], dim=0).numpy()
        y_pred_combined = torch.cat(y_pred_masks[key], dim=0).numpy()
        F1_scores.append(f1_score(y_true_combined, y_pred_combined, average="macro"))

    return F1_scores

In [None]:
def test_best_GNN(data: Data, sweep_id: str, build_GNN_func: callable) -> pd.DataFrame:
    """Retrain a model with the sweep's best configuration and score it on the test set.

    The weights learned during the sweep are not stored anywhere, so a model
    built from the best configuration starts from its initial weights.
    Previously that freshly built model was evaluated directly, which reported
    the test performance of an untrained network. It is now trained first.

    Args:
        data: Graph to train and evaluate on.
        sweep_id: Identifier of the finished W&B sweep.
        build_GNN_func: Factory that turns a config object into a model.

    Returns:
        The performance table produced by ``evaluate_on_test_set``.
    """
    best_config = turn_config_dict_to_config(fetch_best_config(sweep_id))
    # Build model, optimizer and loader in the same order as train_GNN_main,
    # because the model constructor and the loader both reseed the global RNG.
    best_GNN = build_GNN_func(best_config)
    optimizer = build_optimizer(best_GNN, best_config.optimizer, best_config.lr)
    data_loader = build_dataloader(data, best_config.graph_num_parts, best_config.batch_size)
    # log=False: this retraining happens outside of any active W&B run.
    train_GNN(best_GNN, best_config.epochs, optimizer, data_loader, log=False)
    return evaluate_on_test_set(best_GNN, data_loader)

In [None]:
def turn_config_dict_to_config(config_dict: dict) -> Config:
    """Copy every entry of a plain dict onto a fresh wandb ``Config`` as attributes.

    Args:
        config_dict: Mapping of hyperparameter name to value.

    Returns:
        A ``Config`` exposing each entry as ``config.<name>``.
    """
    wandb_config = Config()
    for name in config_dict:
        setattr(wandb_config, name, config_dict[name])
    return wandb_config

In [None]:
def fetch_best_config(sweep_id: str) -> dict:
    """Return the configuration of the sweep run with the highest ``val_F1``.

    Args:
        sweep_id: Identifier passed to ``wandb.Api().sweep``.

    Returns:
        The ``config`` of the best-ranked run.
    """
    def summary_val_f1(run):
        # Runs whose summary has no val_F1 rank below every run that has one.
        return run.summary.get("val_F1", float("-inf"))

    sweep_runs = wandb.Api().sweep(sweep_id).runs
    return max(sweep_runs, key=summary_val_f1).config

In [None]:
def evaluate_on_test_set(model: torch.nn.Module, data_loader: ClusterLoader) -> pd.DataFrame:
    """Predict on the test nodes of every mini-batch and summarise per-class performance.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        data_loader: Loader yielding sub-graphs with ``x``, ``edge_index``,
            ``y`` and ``test_mask``.

    Returns:
        The result of ``performance_per_class`` on the true labels, predicted
        labels and softmax probabilities of all test nodes.
    """
    collected = {"true": [], "pred": [], "proba": []}

    model.eval()
    with torch.no_grad():
        for sub_data in data_loader:
            sub_data = sub_data.to(device)
            # Keep only the test nodes' outputs; labels and probabilities are
            # both derived from these.
            test_logits = model(sub_data.x, sub_data.edge_index)[sub_data.test_mask]
            collected["pred"].append(test_logits.argmax(dim=1).cpu())
            collected["proba"].append(test_logits.softmax(dim=1).cpu())
            collected["true"].append(sub_data.y[sub_data.test_mask].cpu())

    # Merge the per-batch pieces into one NumPy array each.
    merged = {name: torch.cat(chunks, dim=0).numpy() for name, chunks in collected.items()}
    return performance_per_class(merged["true"], merged["pred"], merged["proba"])

# GNN Models

In [None]:
import torch
from torch.nn import Sequential, Dropout, Linear
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.nn.conv import MessagePassing

In [None]:
def build_activation(activation: str) -> torch.nn.Module:
    """Instantiate the activation module selected by name.

    Args:
        activation: One of ``'relu'``, ``'sigmoid'`` or ``'tanh'``.

    Returns:
        A new ``ReLU``, ``Sigmoid`` or ``Tanh`` module.

    Raises:
        ValueError: If ``activation`` is none of the supported names.
    """
    supported = (
        ('relu', torch.nn.ReLU),
        ('sigmoid', torch.nn.Sigmoid),
        ('tanh', torch.nn.Tanh),
    )
    for name, module_cls in supported:
        if activation == name:
            return module_cls()
    raise ValueError("Unsupported activation function. Choose from 'relu', 'sigmoid', or 'tanh'.")

In [None]:
# Base class shared by all GNN variants.
class BaseGNN(torch.nn.Module):
    """Stack of graph-convolution blocks followed by a linear classifier.

    Each entry of ``hidden_dims`` adds one block of convolution, activation and
    dropout; a final ``Linear`` layer maps the last hidden width to ``n_class``
    outputs.

    Args:
        input_dim: Number of input features per node.
        n_class: Number of output classes.
        hidden_dims: Output width of each convolution block, in order.
        dropout: Dropout probability applied after every activation.
        activation: Activation name understood by ``build_activation``.
        conv_layer: Factory called as ``conv_layer(in_dim, out_dim)`` that
            returns a convolution layer. The layer must emit exactly
            ``out_dim`` features, because the following layer is sized from it.
    """

    def __init__(self, input_dim: int, n_class: int, hidden_dims: list, dropout: float, activation: str, conv_layer: callable):
        super().__init__()
        # Fix the global seed so weight initialisation is repeatable.
        torch.manual_seed(888)

        # One activation module instance is reused in every block.
        activation_fn = build_activation(activation)
        layers = []
        prev_dim = input_dim

        # Add convolutional blocks
        for h_dim in hidden_dims:
            layers.append(conv_layer(prev_dim, h_dim))
            layers.append(activation_fn)
            layers.append(Dropout(dropout))
            prev_dim = h_dim

        # Append classifier layer
        layers.append(Linear(prev_dim, n_class))
        self.network = Sequential(*layers)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return the classifier outputs for every node.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every convolution layer.
        """
        for layer in self.network:
            # Graph convolutions derive from MessagePassing and need the
            # connectivity. Checking the base class, instead of a fixed list of
            # three layer types, keeps this correct for any other conv layer
            # supplied through conv_layer.
            if isinstance(layer, MessagePassing):
                x = layer(x, edge_index)
            else:
                x = layer(x)  # Other layers only need x
        return x

In [None]:
class GCN(BaseGNN):
    """``BaseGNN`` whose convolution blocks use ``GCNConv``."""

    def __init__(self, input_dim: int, n_class: int, hidden_dims: list, dropout: float, activation: str):
        super().__init__(
            input_dim,
            n_class,
            hidden_dims,
            dropout,
            activation,
            conv_layer=GCNConv,
        )

In [None]:
class GraphSAGE(BaseGNN):
    """``BaseGNN`` whose convolution blocks use ``SAGEConv`` with a chosen aggregator."""

    def __init__(self, input_dim: int, n_class: int, hidden_dims: list, dropout: float, activation: str, aggr: str):
        def make_sage_conv(in_dim, out_dim):
            # Every block uses the same neighbourhood aggregation scheme.
            return SAGEConv(in_dim, out_dim, aggr=aggr)

        super().__init__(input_dim, n_class, hidden_dims, dropout, activation,
                         conv_layer=make_sage_conv)

In [None]:
class GAT(BaseGNN):
    """``BaseGNN`` whose convolution blocks use multi-head ``GATConv``.

    Args:
        input_dim: Number of input features per node.
        n_class: Number of output classes.
        hidden_dims: Output width of each attention block, in order.
        dropout: Dropout probability applied after every activation.
        activation: Activation name understood by ``build_activation``.
        heads: Number of attention heads per block.
    """

    def __init__(self, input_dim: int, n_class: int, hidden_dims: list, dropout: float, activation: str, heads: int):
        # By default GATConv concatenates its heads and emits out_dim * heads
        # features, while BaseGNN sizes the next layer for out_dim. With
        # heads > 1 that is a shape mismatch at the following layer.
        # concat=False averages the heads instead, so the output width is out_dim.
        conv_layer = lambda in_dim, out_dim: GATConv(in_dim, out_dim, heads=heads, concat=False)
        super().__init__(input_dim, n_class, hidden_dims, dropout, activation, conv_layer)

In [None]:
# Hyperparameters that the sweep varies for every GNN variant.
searched_parameters = {
    'hidden_dims': {'values': [[64, 32], [128, 64, 32]]},
    'dropout': {'values': [0.1, 0.3, 0.5]},
    'activation': {'values': ['relu', 'sigmoid', 'tanh']},
    'optimizer': {'values': ['sgd', 'adam']},
    'lr': {'values': [0.001, 0.01, 0.1]},
}
# Settings held constant across all runs; the two dimensions come from the graph.
fixed_parameters = {
    'n_class': {'value': data.num_classes},
    'input_dim': {'value': data.num_features},
    'graph_num_parts': {'value': 128},
    'batch_size': {'value': 32},
    'epochs': {'value': 20},
}
# Shared search space; each model family copies it and adds its own entries.
base_GNN_parameters = {**searched_parameters, **fixed_parameters}

# GraphSAGE

In [None]:
def build_GraphSAGE(config: Config) -> torch.nn.Module:
    """Build a ``GraphSAGE`` model from a config object and move it to ``device``.

    Reads ``input_dim``, ``n_class``, ``hidden_dims``, ``dropout``,
    ``activation`` and ``aggr`` from ``config``.
    """
    sage = GraphSAGE(
        config.input_dim,
        config.n_class,
        config.hidden_dims,
        config.dropout,
        config.activation,
        config.aggr,
    )
    return sage.to(device)

In [None]:
# GraphSAGE search space: an independent copy of the shared space plus the
# neighbourhood aggregator.
GraphSAGE_parameters = {
    **copy.deepcopy(base_GNN_parameters),
    'aggr': {'values': ['mean', 'max']},
}

In [None]:
# Run the GraphSAGE sweep in the "AKI_GNN_GraphSAGE" W&B project and keep the
# performance table that evaluate_GNN returns.
GraphSAGE_performance = evaluate_GNN(data, build_GraphSAGE, "AKI_GNN_GraphSAGE", GraphSAGE_parameters)

In [None]:
# Display the GraphSAGE performance table as the cell output.
GraphSAGE_performance

# GAT

In [None]:
def build_GAT(config: Config) -> torch.nn.Module:
    """Build a ``GAT`` model from a config object and move it to ``device``.

    Reads ``input_dim``, ``n_class``, ``hidden_dims``, ``dropout``,
    ``activation`` and ``heads`` from ``config``.
    """
    gat = GAT(
        config.input_dim,
        config.n_class,
        config.hidden_dims,
        config.dropout,
        config.activation,
        config.heads,
    )
    return gat.to(device)

In [None]:
# GAT search space: an independent copy of the shared space plus the number of
# attention heads.
GAT_parameters = {
    **copy.deepcopy(base_GNN_parameters),
    'heads': {'values': [2, 4, 8]},
}


In [None]:
# Run the GAT sweep in the "AKI_GNN_GAT" W&B project and keep the performance
# table that evaluate_GNN returns.
GAT_performance = evaluate_GNN(data, build_GAT, "AKI_GNN_GAT", GAT_parameters)

In [None]:
# Display the GAT performance table as the cell output.
GAT_performance

# GCN

In [None]:
def build_GCN(config: Config) -> torch.nn.Module:
    """Build a ``GCN`` model from a config object and move it to ``device``.

    Reads ``input_dim``, ``n_class``, ``hidden_dims``, ``dropout`` and
    ``activation`` from ``config``.
    """
    gcn = GCN(
        config.input_dim,
        config.n_class,
        config.hidden_dims,
        config.dropout,
        config.activation,
    )
    return gcn.to(device)

In [None]:
# GCN adds no hyperparameters of its own; it uses an independent deep copy of
# the shared search space.
GCN_parameters = copy.deepcopy(base_GNN_parameters)

In [None]:
# Run the GCN sweep in the "AKI_GNN_GCN" W&B project and keep the performance
# table that evaluate_GNN returns.
GCN_performance = evaluate_GNN(data, build_GCN, "AKI_GNN_GCN", GCN_parameters)

In [None]:
# Display the GCN performance table as the cell output.
GCN_performance