This notebook contains the code to run the variationally regularized GNN with additional modifications to help with VRAM consumption. Originally the authors had a P100 (16gb VRAM) to run their models, and the code itself was unoptimized for GPUs with smaller VRAMs. Hence taking advantage of the newer versions of PyTorch as well as the developments in 3rd party libraries, a more VRAM efficient version of the code has been produced. 

In addition, improvements such as residual connections to the attention mechanism was also implemented as it is known to help with the performance (something that the authors were missing)

Furthermore dynamic GPU VRAM management was also implemented to prevent the VRAM usage from exploding and leading to either OOM errors or shared GPU VRAM usage (shifting data between GPU and RAM lead to ~2-3x slowdown).

More details can be found on the Github page, where the scientific improvements (eg. Focal loss, Residual connections) and as well as the engineering (eg. dynamic vram management, downcasting to fp16) are explained. 

## Import Libraries

In [1]:
# Import libraries
# Import all torch + DL related libraries
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import autocast, GradScaler
import bitsandbytes as bnb
# Import other libraries
import copy
import numpy as np
from collections import Counter
import pickle
from tqdm import tqdm
from datetime import datetime
import os
import logging
from sklearn.metrics import precision_recall_curve, auc
import gc
from matplotlib import pyplot as plt
# For aggregate logging
import pandas as pd

In [2]:
# Set the device to be cuda if cuda GPU exists
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


## Helper Functions

In [3]:
# Use AMP to reduce VRAM load and to speed up training
scaler = GradScaler()

# Modify the train function to use autocast and return the loss as a Tensor for scaling
def train(data, model, optim, criterion, lbd, max_clip_norm=5):
    model.train()
    input = data[:, :-1].to(device)
    label = data[:, -1].float().to(device)
    
    optim.zero_grad()
    
    with torch.autocast(device_type="cuda", dtype=torch.float16):  # Enable AMP here
        logits, kld = model(input)
        logits = logits.squeeze(-1)
        kld = kld.sum()
        
        bce = criterion(logits, label)
        loss = bce + lbd * kld
    
    # Scale the loss before backward pass
    scaler.scale(loss).backward()
    
    # Apply gradient clipping only when needed
    if max_clip_norm:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
    
    # Update the optimizer with scaled gradients
    scaler.step(optim)
    scaler.update()  # Adjust scaler for next iteration
    
    return loss.item(), kld.item(), bce.item()


def evaluate(model, data_iter, length):
    model.eval()
    y_pred = np.zeros(length)
    y_true = np.zeros(length)
    y_prob = np.zeros(length)
    pointer = 0
    for data in data_iter:
        input = data[:, :-1].to(device)
        label = data[:, -1]
        batch_size = len(label)
        probability, _ = model(input)
        probability = torch.sigmoid(probability.squeeze(-1).detach())
        predicted = probability > 0.5
        y_true[pointer: pointer + batch_size] = label.cpu().numpy()
        y_pred[pointer: pointer + batch_size] = predicted.cpu().numpy()
        y_prob[pointer: pointer + batch_size] = probability.cpu().numpy()
        pointer += batch_size
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    return auc(recall, precision), (y_pred, y_prob, y_true), (recall, precision)

class EHRData(Dataset):
    def __init__(self, data, cla):
        self.data = data
        self.cla = cla

    def __len__(self):
        return len(self.cla)

    def __getitem__(self, idx):
        return self.data[idx], self.cla[idx]

# Reduce the for loops as much as possible
def collate_fn(data):
    # Convert the sparse matrices to dense arrays in a batch operation
    features = np.array([datum[0].toarray().ravel() for datum in data], dtype=np.float32)
    labels = np.array([datum[1] for datum in data], dtype=np.float32).reshape(-1, 1)
    
    # Stack features and labels along the last axis
    data_combined = np.hstack((features, labels))

    # Convert directly to a PyTorch tensor
    return torch.from_numpy(data_combined).long().to(device)

class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Compute the binary cross entropy with logits
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        # Compute the Focal Loss
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        # Allow for different types of calculation for focal loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        elif self.reduction == 'mean':
            return focal_loss.mean()
        else:
            return focal_loss

def check_and_clear_vram(threshold=0.9):
    # Get total and reserved memory
    total_vram = torch.cuda.get_device_properties(0).total_memory
    used_vram = torch.cuda.memory_reserved(0)
    usage_ratio = used_vram / total_vram
    
    # Clear cache if VRAM usage exceeds threshold
    if usage_ratio >= threshold:
        gc.collect()
        torch.cuda.empty_cache()
    return usage_ratio

  scaler = GradScaler()


## Graph Layer

In [4]:
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def clone_params(param, N):
    return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class GraphLayer(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_of_nodes = num_of_nodes
        self.num_of_heads = num_of_heads

        # Single W Linear layer for all heads
        self.W = nn.Linear(in_features, hidden_features * num_of_heads, bias=False)
        self.a = nn.Parameter(torch.rand((num_of_heads, 2 * hidden_features), requires_grad=True))

        # Define V based on whether heads are concatenated
        self.V = nn.Linear(num_of_heads * hidden_features if concat else hidden_features, out_features)

        # Dropout and LeakyReLU
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

        # Layer normalization
        self.norm = LayerNorm(num_of_heads * hidden_features if concat else hidden_features)

    def initialize(self):
        nn.init.xavier_normal_(self.W.weight.data)
        nn.init.xavier_normal_(self.a.data)
        nn.init.xavier_normal_(self.V.weight.data)

    def attention(self, N, data, edge):
        # Project data to (N, num_heads, hidden_features)
        data_proj = self.W(data).view(N, self.num_of_heads, self.hidden_features)

        # Gather source and destination features for each edge
        edge_src, edge_dst = edge
        h_src = data_proj[edge_src, :, :]  # (E, num_heads, hidden_features)
        h_dst = data_proj[edge_dst, :, :]  # (E, num_heads, hidden_features)

        # Concatenate features of edge endpoints and compute attention scores
        # Standard attention structure using concat
        edge_h = torch.cat([h_src, h_dst], dim=-1)  # (E, num_heads, 2 * hidden_features)
        edge_e = self.leakyrelu((self.a.unsqueeze(0) * edge_h).sum(dim=-1))  # (E, num_heads)

        e_rowsum = torch.zeros((N, self.num_of_heads), device=data.device)  # Shape: (N, num_heads)
        h_prime = torch.zeros((N, self.num_of_heads, self.hidden_features), device=data.device)  # Shape: (N, num_heads, hidden_features)

        # Aggregate across all edges in one pass to improve efficiency
        e_rowsum.index_add_(0, edge_dst, edge_e)  # Shape: (N, num_heads)
        h_prime.index_add_(0, edge_dst, edge_e.unsqueeze(-1) * h_src)  # Shape: (N, num_heads, hidden_features)

        # Normalize in-place to avoid creating new tensors
        e_rowsum.clamp_(min=1.0)  # Prevent division by zero
        h_prime.div_(e_rowsum.unsqueeze(-1))

        return h_prime

    def forward(self, edge, data):
        N = self.num_of_nodes
        h_prime = self.attention(N, data, edge)

        # Concatenate or average heads based on `concat`
        if self.concat:
            h_prime = h_prime.view(N, -1)  # Concatenate heads (N, num_heads * hidden_features)
            h_prime = F.elu(self.norm(h_prime))  # Apply ELU activation in-place
        else:
            h_prime = self.V(F.relu(self.norm(h_prime.mean(dim=1))))  # Apply ReLU activation in-place

        # Apply dropout
        h_prime = self.dropout(h_prime)

        return h_prime

## Variational GNN

In [5]:
# With gradient checkpointing - conditional -90% vram to speed up - residual connections -ok and good
#This isnt really giving the speed up unfortunately
def conditional_checkpoint(layer, *inputs):
    # Calculate current VRAM usage
    vram_usage_ratio = torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory
    # Apply checkpointing if VRAM usage exceeds the threshold
    if vram_usage_ratio > VRAM_THRESHOLD:
        return checkpoint(layer, *inputs)
    else:
        return layer(*inputs)

class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Convert data to edges with device allocation at the end
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero(as_tuple=False)
        
        if nonzero.numel() == 0:
            return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        if self.training:
            mask = (torch.rand(nonzero.size(0), device=data.device) > 0.05)
            nonzero = nonzero[mask]
            if nonzero.numel() == 0:
                return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        nonzero = nonzero.T + 1
        lengths = nonzero.size(1)
        
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        
        # Extend nonzero and avoid redundant device transfer
        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(data.device)), dim=1)
        lengths = nonzero.size(1)
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar):
        if self.training:
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std, device=mu.device)
            return eps.mul(std).add_(mu)
        return mu

    def encoder_decoder(self, data):
        # Calculate edges
        input_edges, output_edges = self.data_to_edges(data)

        # Embed the nodes
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())
        if h_prime.requires_grad is False:
            h_prime.requires_grad = True

        # Apply gradient checkpointing with VRAM monitoring
        for attn in self.in_att:
            if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
                torch.cuda.empty_cache()  # Conditionally clear cache when VRAM is > 90%

            # Residual connection for each layer in in_att
            h_prime_res = h_prime
            h_prime = conditional_checkpoint(attn, input_edges, h_prime) + h_prime_res  # Add residual connection but with conditional checkpointing
            #h_prime = checkpoint(attn, input_edges, h_prime) + h_prime_res  # Add residual connection
            
        # Variational encoding step
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]

        # Residual connection for the out_att layer as well
        h_prime_res = h_prime
        if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
            torch.cuda.empty_cache()  # Clear cache if VRAM > 90%
        h_prime = checkpoint(self.out_att, output_edges, h_prime) + h_prime_res  # Add residual connection

        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        batch_size = data.size(0)
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i]) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*outputs)
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*[(self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]) for out in outputs])
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))

## Training + Evaluation

In [6]:
def train_evaluate(result_path,data_path,in_feature,out_feature,n_layers,lr,reg,n_heads,dropout,
                  alpha,batch_size,number_of_epochs,eval_freq,lbd,VRAM_THRESHOLD
                  ):
    # Load data
    train_x, train_y = pickle.load(open(data_path + '/train_csr.pkl', 'rb'))
    val_x, val_y = pickle.load(open(data_path + '/validation_csr.pkl', 'rb'))
    test_x, test_y = pickle.load(open(data_path + '/test_csr.pkl', 'rb'))

    # Upsample training data
    train_upsampling = np.concatenate((np.arange(len(train_y)), np.repeat(np.where(train_y == 1)[0], 1)))
    train_x = train_x[train_upsampling]
    train_y = train_y[train_upsampling]

    # Create result root
    s = datetime.now().strftime('%Y%m%d%H%M%S')
    result_root = f'{result_path}/lr_{lr}-input_{embedding_size}-output_{embedding_size}-dropout_{dropout}'
    os.makedirs(result_root, exist_ok=True)
    logging.basicConfig(filename=f'{result_root}/train.log', format='%(asctime)s %(message)s', level=logging.INFO)
    logging.info(f"Time: {s}")

    # Initialize model
    num_of_nodes = train_x.shape[1] + 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VariationalGNN(embedding_size, embedding_size, num_of_nodes, n_heads, n_layers - 1,
                           dropout=dropout, alpha=alpha, variational=reg, none_graph_features=0).to(device)

    model = nn.DataParallel(model)
    val_loader = DataLoader(dataset=EHRData(val_x, val_y), batch_size=batch_size,
                            collate_fn=collate_fn, shuffle=False
                           )
    # 8 bit optimizer to speed things up and reduce memory load
    optimizer = bnb.optim.Adam8bit(
        [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    pr_list = []
    for epoch in range(number_of_epochs):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        ratio = Counter(train_y)
        train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                                  collate_fn=collate_fn, shuffle=True)

        # Initialize the Focal Loss criterion with calculated pos_weight
        pos_weight = ratio[False] / ratio[True]
        criterion = FocalLoss(alpha=pos_weight, gamma=1.5, reduction="sum").to(device)  # Modify gamma and alpha as needed

        t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
        model.train()
        total_loss = np.zeros(3)

        for idx, batch_data in enumerate(t):
            # Train the model on this batch
            loss, kld, focal_loss = train(batch_data, model, optimizer, criterion, lbd, max_clip_norm=5)
            total_loss += np.array([loss, focal_loss, kld])

            # Check VRAM usage and clear cache if needed
            vram_usage = check_and_clear_vram(VRAM_THRESHOLD)  # Check VRAM only when threshold is met

            # Periodic evaluation and saving
            if idx % eval_freq == 0 and idx > 0:
                torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")

                # Free resources before evaluation
                with torch.no_grad():
                    model.eval()
                    val_auprc, _,prc = evaluate(model, val_loader, len(val_y))
                    pr_list.append(prc)

                logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
                print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')

                del val_auprc  # Clear evaluation results
                model.train()

            # Update progress display
            if idx % 50 == 0 and idx > 0:
                t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
                t.refresh()

        # Update scheduler and check VRAM at the end of each epoch
        scheduler.step()
        check_and_clear_vram()
        np.savetxt(f'{result_root}/pr_list.npy', pr_list, delimiter=',')

In [7]:
VRAM_THRESHOLD = 0.9  # 90% usage (prevent leak or OOM dynamically to improve training throughput)
# Folder paths
result_path = "./output/focal/"
data_path = "./data/mimic/output/"
# Hyperparameters
embedding_size = 32
in_feature = embedding_size
out_feature =embedding_size
n_layers = 2
lr = 1e-4
reg = True
n_heads = 1
dropout = 0.2
alpha = 0.1
batch_size = 256
number_of_epochs = 2
#eval_freq = 1389
#eval_freq = 173
lbd = 1
focal_loss = False

In [18]:
def train_evaluate(result_path,data_path,in_feature,out_feature,n_layers,reg,n_heads,
                  alpha,batch_size,number_of_epochs,lbd,VRAM_THRESHOLD,lr,dropout,focal_loss
                  ):
    # Load data
    train_x, train_y = pickle.load(open(data_path + '/train_csr.pkl', 'rb'))
    val_x, val_y = pickle.load(open(data_path + '/validation_csr.pkl', 'rb'))
    test_x, test_y = pickle.load(open(data_path + '/test_csr.pkl', 'rb'))

    # Upsample training data
    train_upsampling = np.concatenate((np.arange(len(train_y)), np.repeat(np.where(train_y == 1)[0], 1)))
    train_x = train_x[train_upsampling]
    train_y = train_y[train_upsampling]

    # Create result root
    s = datetime.now().strftime('%Y%m%d%H%M%S')
    if focal_loss == True:
        result_root = f'{result_path}/lr_{lr}-input_{embedding_size}-output_{embedding_size}-dropout_{dropout}-focal_True'
    else:
        result_root = f'{result_path}/lr_{lr}-input_{embedding_size}-output_{embedding_size}-dropout_{dropout}-focal_False'
    os.makedirs(result_root, exist_ok=True)
    logging.basicConfig(filename=f'{result_root}/train.log', format='%(asctime)s %(message)s', level=logging.INFO)
    logging.info(f"Time: {s}")

    # Initialize model
    num_of_nodes = train_x.shape[1] + 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VariationalGNN(embedding_size, embedding_size, num_of_nodes, n_heads, n_layers - 1,
                           dropout=dropout, alpha=alpha, variational=reg, none_graph_features=0).to(device)

    model = nn.DataParallel(model)
    val_loader = DataLoader(dataset=EHRData(val_x, val_y), batch_size=batch_size,
                            collate_fn=collate_fn, shuffle=False
                           )
    # 8 bit optimizer to speed things up and reduce memory load
    optimizer = bnb.optim.Adam8bit(
        [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    pr_list = []
    val_list = []
    for epoch in range(number_of_epochs):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        ratio = Counter(train_y)
        train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                                  collate_fn=collate_fn, shuffle=True)

        if focal_loss == True:
            # Initialize the Focal Loss criterion with calculated pos_weight
            pos_weight = ratio[False] / ratio[True]
            criterion = FocalLoss(alpha=pos_weight, gamma=1.5, reduction="sum").to(device)  # Modify gamma and alpha as needed
        else:
            pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
            criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)

        t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
        model.train()
        total_loss = np.zeros(3)

        for idx, batch_data in enumerate(t):
            # Train the model on this batch
            loss, kld, focal_loss = train(batch_data, model, optimizer, criterion, lbd, max_clip_norm=5)
            total_loss += np.array([loss, focal_loss, kld])

            # Check VRAM usage and clear cache if needed
            vram_usage = check_and_clear_vram(VRAM_THRESHOLD)  # Check VRAM only when threshold is met

            # Periodic evaluation and saving
            if idx % (len(train_loader)-1) == 0 and idx > 0:
                torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")

                # Free resources before evaluation
                with torch.no_grad():
                    model.eval()
                    val_auprc, _,prc = evaluate(model, val_loader, len(val_y))
                    pr_list.append(prc)
                    val_list.append(val_auprc)

                logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
                print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')

                del val_auprc  # Clear evaluation results
                model.train()

            # Update progress display
            if idx % 50 == 0 and idx > 0:
                t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
                t.refresh()

        # Update scheduler and check VRAM at the end of each epoch
        scheduler.step()
        check_and_clear_vram()

    data = []
    for epoch_idx, (recall,precision) in enumerate(pr_list, start=1):
        data.append({
            "Embedding_Size": embedding_size,
            "Learning_Rate" : lr,
            "Dropout": dropout,
            "Alpha":alpha,
            "Focal" : str(focal_loss),
            "epoch": epoch_idx,
            "precision": precision.tolist(),
            "recall": recall.tolist(),
            "auprc": val_list[epoch_idx - 1]
        })

    df = pd.DataFrame(data)
    df.to_excel(f'{result_root}/summary_runs.xlsx', index=False)
    
    return df

In [9]:
# df_run = train_evaluate(result_path,data_path,in_feature,out_feature,n_layers,reg,n_heads,
#                   alpha,batch_size,number_of_epochs,lbd,VRAM_THRESHOLD,lr,dropout,focal_loss
#                   )

## Hyperparameter Optimization (Optional)

In [10]:
# Set up hyperparameter tuning space
# Need to fit 50
# Experiment space on embedding size, lr, dropout, alpha
# emb_size = [128,256,512]
# lr = [0.0001,5e-5,1e-5]
# dropout = [0.2,0.3,0.4]
# F_loss = [true,false]
# So 3*3*3*2 = 54 runs
# Drop epochs to 10 
# 540 epochs * ~7min = 3780 mins
# Split between 3 machines
# So each takes 18 runs -> 18*10 epochs*~10min (upper bound) -> 1800 mins = 30h
# So do 4 days of training straight -> k=3 cv

In [19]:
# Fixed HPO
VRAM_THRESHOLD = 0.9  # 90% usage (prevent leak or OOM dynamically to improve training throughput)
# Folder paths
result_path = "./output/focal/"
data_path = "./data/mimic/output/"
# Hyperparameters
embedding_size = 32
in_feature = embedding_size
out_feature =embedding_size
n_layers = 2
reg = True
n_heads = 1
alpha = 0.1
batch_size = 256
lbd = 1
number_of_epochs = 1

# So for this we do 256, kaggle = 128, 3080ti = 512
# lr_list = [1e-4,5e-5,1e-5]
# dropout_list = [0.2,0.3,0.4]
# focal_list = [True, False]

lr_list = [1e-4]
dropout_list = [0.2,0.3]
focal_list = [True]

# Set focal loss = true then switch out for false
df_run_all = pd.DataFrame()
for i in lr_list:
    for j in dropout_list:
        for k in focal_list:
            df_run = train_evaluate(result_path,data_path,in_feature,out_feature,n_layers,reg,n_heads,
                  alpha,batch_size,number_of_epochs,lbd,VRAM_THRESHOLD,lr=i,dropout=j,focal_loss=k)
            df_run_all = pd.concat([df_run_all, df_run], ignore_index=True)
            df_run_all.to_excel(f'{result_path}/full_summary_runs.xlsx', index=False)

Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.16961619864653804; loss: 277.2233, focal_loss: 207.8182, kld: 69.4051
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.12761971194863092; loss: 304.0126, focal_loss: 222.1381, kld: 81.8745




## Visualization

In [None]:
# from sklearn.metrics import PrecisionRecallDisplay

In [None]:
# auprc_val = auc(pr_list[5][0],pr_list[5][1])

In [None]:
# # plt.plot(pr_list[0][0],pr_list[0][1])
# # plt.plot(pr_list[1][0],pr_list[1][1])
# # plt.plot(pr_list[2][0],pr_list[2][1])
# # plt.plot(pr_list[3][0],pr_list[3][1])
# # plt.plot(pr_list[4][0],pr_list[4][1])
# plt.plot(pr_list[5][0],pr_list[5][1])
# plt.xlabel('Recall')
# plt.ylabel('Precision')
# plt.title('Precision-Recall Curve')
# plt.fill_between(pr_list[5][0],pr_list[5][1], alpha=0.6)
# x_label = 0.3
# y_label = 0.4
# label_text = f"AUPRC: {auprc_val:.2f}"
# #plt.text(x_label, y_label, label_text, color="Black", fontsize=12, ha="center", va="center")

# #PrecisionRecallDisplay(precision=pr_list[5][1], recall=pr_list[5][0]).plot()