In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import copy

import argparse
from torch import optim
from torch.utils.data import DataLoader
from collections import Counter
import pickle
from tqdm import tqdm
from datetime import datetime
import os
import logging

from sklearn.metrics import precision_recall_curve, auc
from torch.utils.data import Dataset
import gc
from torch.cuda.amp import autocast, GradScaler
import bitsandbytes as bnb
from torch.utils.checkpoint import checkpoint

In [2]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [3]:
scaler = GradScaler()

# Modify the train function to use autocast and return the loss as a Tensor for scaling
def train(data, model, optim, criterion, lbd, max_clip_norm=5):
    model.train()
    input = data[:, :-1].to(device)
    label = data[:, -1].float().to(device)
    
    optim.zero_grad()
    
    with torch.autocast(device_type="cuda", dtype=torch.float16):  # Enable AMP here
        logits, kld = model(input)
        logits = logits.squeeze(-1)
        kld = kld.sum()
        
        bce = criterion(logits, label)
        loss = bce + lbd * kld
    
    # Scale the loss before backward pass
    scaler.scale(loss).backward()
    
    # Apply gradient clipping only when needed
    if max_clip_norm:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
    
    # Update the optimizer with scaled gradients
    scaler.step(optim)
    scaler.update()  # Adjust scaler for next iteration
    
    return loss.item(), kld.item(), bce.item()


def evaluate(model, data_iter, length):
    model.eval()
    y_pred = np.zeros(length)
    y_true = np.zeros(length)
    y_prob = np.zeros(length)
    pointer = 0
    for data in data_iter:
        input = data[:, :-1].to(device)
        label = data[:, -1]
        batch_size = len(label)
        probability, _ = model(input)
        probability = torch.sigmoid(probability.squeeze(-1).detach())
        predicted = probability > 0.5
        y_true[pointer: pointer + batch_size] = label.cpu().numpy()
        y_pred[pointer: pointer + batch_size] = predicted.cpu().numpy()
        y_prob[pointer: pointer + batch_size] = probability.cpu().numpy()
        pointer += batch_size
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    return auc(recall, precision), (y_pred, y_prob, y_true)

class EHRData(Dataset):
    def __init__(self, data, cla):
        self.data = data
        self.cla = cla

    def __len__(self):
        return len(self.cla)

    def __getitem__(self, idx):
        return self.data[idx], self.cla[idx]

# Reduce the for loops as much as possible
def collate_fn(data):
    # Convert the sparse matrices to dense arrays in a batch operation
    features = np.array([datum[0].toarray().ravel() for datum in data], dtype=np.float32)
    labels = np.array([datum[1] for datum in data], dtype=np.float32).reshape(-1, 1)
    
    # Stack features and labels along the last axis
    data_combined = np.hstack((features, labels))

    # Convert directly to a PyTorch tensor
    return torch.from_numpy(data_combined).long().to(device)

  scaler = GradScaler()


In [4]:
# Original 4 -Ok
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def clone_params(param, N):
    return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class GraphLayer(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_of_nodes = num_of_nodes
        self.num_of_heads = num_of_heads

        # Single W Linear layer for all heads
        self.W = nn.Linear(in_features, hidden_features * num_of_heads, bias=False)
        self.a = nn.Parameter(torch.rand((num_of_heads, 2 * hidden_features), requires_grad=True))

        # Define V based on whether heads are concatenated
        self.V = nn.Linear(num_of_heads * hidden_features if concat else hidden_features, out_features)

        # Dropout and LeakyReLU
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

        # Layer normalization
        self.norm = LayerNorm(num_of_heads * hidden_features if concat else hidden_features)

    def initialize(self):
        nn.init.xavier_normal_(self.W.weight.data)
        nn.init.xavier_normal_(self.a.data)
        nn.init.xavier_normal_(self.V.weight.data)

    def attention(self, N, data, edge):
        # Project data to (N, num_heads, hidden_features)
        data_proj = self.W(data).view(N, self.num_of_heads, self.hidden_features)

        # Gather source and destination features for each edge
        edge_src, edge_dst = edge
        h_src = data_proj[edge_src, :, :]  # (E, num_heads, hidden_features)
        h_dst = data_proj[edge_dst, :, :]  # (E, num_heads, hidden_features)

        # Concatenate features of edge endpoints and compute attention scores
        edge_h = torch.cat([h_src, h_dst], dim=-1)  # (E, num_heads, 2 * hidden_features)
        edge_e = self.leakyrelu((self.a.unsqueeze(0) * edge_h).sum(dim=-1))  # (E, num_heads)

        e_rowsum = torch.zeros((N, self.num_of_heads), device=data.device)  # Shape: (N, num_heads)
        h_prime = torch.zeros((N, self.num_of_heads, self.hidden_features), device=data.device)  # Shape: (N, num_heads, hidden_features)

        # Aggregate across all edges in one pass
        e_rowsum.index_add_(0, edge_dst, edge_e)  # Shape: (N, num_heads)
        h_prime.index_add_(0, edge_dst, edge_e.unsqueeze(-1) * h_src)  # Shape: (N, num_heads, hidden_features)

        # Normalize in-place to avoid creating new tensors
        e_rowsum.clamp_(min=1.0)  # Prevent division by zero
        h_prime.div_(e_rowsum.unsqueeze(-1))

        return h_prime

    def forward(self, edge, data):
        N = self.num_of_nodes
        h_prime = self.attention(N, data, edge)

        # Concatenate or average heads based on `concat`
        if self.concat:
            h_prime = h_prime.view(N, -1)  # Concatenate heads (N, num_heads * hidden_features)
            h_prime = F.elu(self.norm(h_prime))  # Apply ELU activation in-place
        else:
            h_prime = self.V(F.relu(self.norm(h_prime.mean(dim=1))))  # Apply ReLU activation in-place

        # Apply dropout
        h_prime = self.dropout(h_prime)

        return h_prime

In [5]:
# Opt 1 - OK
class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Convert data to edges with device allocation at the end
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero(as_tuple=False)
        
        if nonzero.numel() == 0:
            return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        if self.training:
            mask = (torch.rand(nonzero.size(0), device=data.device) > 0.05)
            nonzero = nonzero[mask]
            if nonzero.numel() == 0:
                return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        nonzero = nonzero.T + 1
        lengths = nonzero.size(1)
        
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        
        # Extend nonzero and avoid redundant device transfer
        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(data.device)), dim=1)
        lengths = nonzero.size(1)
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar):
        if self.training:
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std, device=mu.device)
            return eps.mul(std).add_(mu)
        return mu

    def encoder_decoder(self, data):
        # Process batch all at once instead of one by one for efficiency
        input_edges, output_edges = self.data_to_edges(data)
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())
        
        for attn in self.in_att:
            h_prime = attn(input_edges, h_prime)
        
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]
        
        h_prime = self.out_att(output_edges, h_prime)
        
        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        batch_size = data.size(0)
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i]) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*outputs)
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*[(self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]) for out in outputs])
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))

In [6]:
# With gradient checkpointing
class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Convert data to edges with device allocation at the end
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero(as_tuple=False)
        
        if nonzero.numel() == 0:
            return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        if self.training:
            mask = (torch.rand(nonzero.size(0), device=data.device) > 0.05)
            nonzero = nonzero[mask]
            if nonzero.numel() == 0:
                return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        nonzero = nonzero.T + 1
        lengths = nonzero.size(1)
        
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        
        # Extend nonzero and avoid redundant device transfer
        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(data.device)), dim=1)
        lengths = nonzero.size(1)
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar):
        if self.training:
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std, device=mu.device)
            return eps.mul(std).add_(mu)
        return mu

    def encoder_decoder(self, data):
        # Apply gradient checkpointing to memory-intensive attention layers
        input_edges, output_edges = self.data_to_edges(data)
        
        # Embed the nodes
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())
        
        # Apply gradient checkpointing on each layer in in_att
        for attn in self.in_att:
            h_prime = checkpoint(attn, input_edges, h_prime)

        # Variational encoding step
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]

        # Apply checkpointing to the output attention layer as well
        h_prime = checkpoint(self.out_att, output_edges, h_prime)
        
        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        batch_size = data.size(0)
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i]) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*outputs)
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*[(self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]) for out in outputs])
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))

In [7]:
# With gradient checkpointing -90% vram to speed up - Still best so far for speed vs performance
class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Convert data to edges with device allocation at the end
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero(as_tuple=False)
        
        if nonzero.numel() == 0:
            return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        if self.training:
            mask = (torch.rand(nonzero.size(0), device=data.device) > 0.05)
            nonzero = nonzero[mask]
            if nonzero.numel() == 0:
                return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        nonzero = nonzero.T + 1
        lengths = nonzero.size(1)
        
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        
        # Extend nonzero and avoid redundant device transfer
        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(data.device)), dim=1)
        lengths = nonzero.size(1)
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar):
        if self.training:
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std, device=mu.device)
            return eps.mul(std).add_(mu)
        return mu

    def encoder_decoder(self, data):
        # Calculate edges
        input_edges, output_edges = self.data_to_edges(data)

        # Embed the nodes
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())

        # Apply gradient checkpointing with VRAM monitoring
        for attn in self.in_att:
            # Check VRAM usage before applying checkpoint
            if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
                torch.cuda.empty_cache()  # Conditionally clear cache when VRAM is > 90%
            h_prime = checkpoint(attn, input_edges, h_prime)

        # Variational encoding step
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]

        # Apply checkpointing to the output attention layer with VRAM monitoring
        if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
            torch.cuda.empty_cache()  # Clear cache if VRAM > 90%
        h_prime = checkpoint(self.out_att, output_edges, h_prime)

        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        batch_size = data.size(0)
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i]) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*outputs)
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*[(self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]) for out in outputs])
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))

In [5]:
# This isnt better :(
# More complex attention
class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers, 
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes, 
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes, 
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2 * n_heads)
        
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Convert data to edges with device allocation at the end
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero(as_tuple=False)
        
        if nonzero.numel() == 0:
            return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        if self.training:
            mask = (torch.rand(nonzero.size(0), device=data.device) > 0.05)
            nonzero = nonzero[mask]
            if nonzero.numel() == 0:
                return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        nonzero = nonzero.T + 1
        lengths = nonzero.size(1)
        
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        
        # Extend nonzero and avoid redundant device transfer
        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(data.device)), dim=1)
        lengths = nonzero.size(1)
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar, num_heads=3):
        if self.training:
            samples = []
            for _ in range(num_heads):
                std = (0.5 * logvar).exp()
                eps = torch.randn_like(std, device=mu.device)
                samples.append(eps.mul(std).add_(mu))
            return torch.stack(samples).mean(dim=0)
        return mu

    def encoder_decoder(self, data):
        input_edges, output_edges = self.data_to_edges(data)
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())

        for attn in self.in_att:
            h_residual = h_prime
            h_prime = attn(input_edges, h_prime)
            h_prime = F.relu(h_prime + h_residual)
            h_prime = self.dropout(h_prime)
        
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]

        h_prime = self.out_att(output_edges, h_prime)
        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        batch_size = data.size(0)
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i]) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*outputs)
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*[(self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]) for out in outputs])
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))

In [None]:
# This isnt better :(
class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers, 
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes, 
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes, 
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2 * n_heads)
        
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Original data_to_edges code here
        pass

    def reparameterise(self, mu, logvar, num_heads=1):
        if self.training and num_heads > 1:
            samples = []
            for _ in range(num_heads):
                std = (0.5 * logvar).exp()
                eps = torch.randn_like(std, device=mu.device)
                samples.append(eps.mul(std).add_(mu))
            return torch.stack(samples).mean(dim=0)
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std, device=mu.device)
        return eps.mul(std).add_(mu)

    def encoder_decoder(self, data):
        input_edges, output_edges = self.data_to_edges(data)
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())

        for attn in self.in_att:
            h_residual = h_prime
            h_prime = checkpoint(attn, input_edges, h_prime)
            h_prime = F.relu(h_prime + h_residual)
            h_prime = self.dropout(h_prime)
        
        if self.variational:
            h_prime = checkpoint(self.parameterize, h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]

        h_prime = self.out_att(output_edges, h_prime)
        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        # Original forward method code here, with optional checkpointing for `encoder_decoder`
        pass

In [5]:
# With gradient checkpointing -90% vram to speed up - residual connections -ok and good
class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Convert data to edges with device allocation at the end
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero(as_tuple=False)
        
        if nonzero.numel() == 0:
            return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        if self.training:
            mask = (torch.rand(nonzero.size(0), device=data.device) > 0.05)
            nonzero = nonzero[mask]
            if nonzero.numel() == 0:
                return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        nonzero = nonzero.T + 1
        lengths = nonzero.size(1)
        
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        
        # Extend nonzero and avoid redundant device transfer
        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(data.device)), dim=1)
        lengths = nonzero.size(1)
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar):
        if self.training:
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std, device=mu.device)
            return eps.mul(std).add_(mu)
        return mu

    def encoder_decoder(self, data):
        # Calculate edges
        input_edges, output_edges = self.data_to_edges(data)

        # Embed the nodes
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())

        # Apply gradient checkpointing with VRAM monitoring
        for attn in self.in_att:
            if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
                torch.cuda.empty_cache()  # Conditionally clear cache when VRAM is > 90%

            # Residual connection for each layer in in_att
            h_prime_res = h_prime
            h_prime = checkpoint(attn, input_edges, h_prime) + h_prime_res  # Add residual connection

        # Variational encoding step
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]

        # Residual connection for the out_att layer as well
        h_prime_res = h_prime
        if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
            torch.cuda.empty_cache()  # Clear cache if VRAM > 90%
        h_prime = checkpoint(self.out_att, output_edges, h_prime) + h_prime_res  # Add residual connection

        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        batch_size = data.size(0)
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i]) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*outputs)
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*[(self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]) for out in outputs])
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))

In [6]:
embedding_size = 256
# Modify Hyperparameters here
result_path = "./output/"
data_path = "./data/mimic/output/"
in_feature = embedding_size
out_feature =embedding_size
n_layers = 2
lr = 1e-4
reg = True
n_heads = 1
dropout = 0.2
alpha = 0.1
batch_size = 32
number_of_epochs = 20
eval_freq = 1389
lbd = 1

In [7]:
# Load data
train_x, train_y = pickle.load(open(data_path + '/train_csr.pkl', 'rb'))
val_x, val_y = pickle.load(open(data_path + '/validation_csr.pkl', 'rb'))
test_x, test_y = pickle.load(open(data_path + '/test_csr.pkl', 'rb'))

# Upsample training data
train_upsampling = np.concatenate((np.arange(len(train_y)), np.repeat(np.where(train_y == 1)[0], 1)))
train_x = train_x[train_upsampling]
train_y = train_y[train_upsampling]

# Create result root
s = datetime.now().strftime('%Y%m%d%H%M%S')
result_root = f'{result_path}/lr_{lr}-input_{embedding_size}-output_{embedding_size}-dropout_{dropout}'
os.makedirs(result_root, exist_ok=True)
logging.basicConfig(filename=f'{result_root}/train.log', format='%(asctime)s %(message)s', level=logging.INFO)
logging.info(f"Time: {s}")

# Initialize model
num_of_nodes = train_x.shape[1] + 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VariationalGNN(embedding_size, embedding_size, num_of_nodes, n_heads, n_layers - 1,
                       dropout=dropout, alpha=alpha, variational=reg, none_graph_features=0).to(device)

#model = torch.compile(model, mode="reduce-overhead")
model = nn.DataParallel(model)
val_loader = DataLoader(dataset=EHRData(val_x, val_y), batch_size=batch_size,
                        collate_fn=collate_fn, shuffle=False
                        #,num_workers=4
                        #,pin_memory=True
                       )
# 8 bit optimizer
optimizer = bnb.optim.Adam8bit(
    [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8
)
#optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [8]:
VRAM_THRESHOLD = 0.9  # 90% usage

def check_and_clear_vram(threshold=VRAM_THRESHOLD):
    """
    Check if VRAM usage exceeds the threshold; if so, clear the cache.
    """
    # Get total and reserved memory
    total_vram = torch.cuda.get_device_properties(0).total_memory
    used_vram = torch.cuda.memory_reserved(0)
    usage_ratio = used_vram / total_vram
    
    # Clear cache if VRAM usage exceeds threshold
    if usage_ratio >= threshold:
        print(f"Clearing VRAM cache... Current usage: {usage_ratio * 100:.2f}%")
        gc.collect()
        torch.cuda.empty_cache()
    return usage_ratio

# Main training loop
for epoch in range(number_of_epochs):
    print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
    ratio = Counter(train_y)
    train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                              collate_fn=collate_fn, shuffle=True)
    pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
    criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)
    
    t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
    model.train()
    total_loss = np.zeros(3)
    
    for idx, batch_data in enumerate(t):
        # Train the model on this batch
        loss, kld, bce = train(batch_data, model, optimizer, criterion, lbd, max_clip_norm=5)
        total_loss += np.array([loss, bce, kld])

        # Check VRAM usage and clear cache if needed
        vram_usage = check_and_clear_vram()  # Check VRAM only when threshold is met

        # Periodic evaluation and saving
        if idx % eval_freq == 0 and idx > 0:
            torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")

            # Free resources before evaluation
            with torch.no_grad():
                model.eval()
                val_auprc, _ = evaluate(model, val_loader, len(val_y))

            logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')

            del val_auprc  # Clear evaluation results
            model.train()

        # Update progress display
        if idx % 50 == 0 and idx > 0:
            t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            t.refresh()

    # Update scheduler and check VRAM at the end of each epoch
    scheduler.step()
    check_and_clear_vram()  # Check VRAM usage once more after each epoch

Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5986756537258878; loss: 43.4463, bce: 23.1734, kld: 20.2729
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6343070712910882; loss: 38.3987, bce: 17.9859, kld: 20.4128
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6510007184945322; loss: 37.7184, bce: 16.9546, kld: 20.7638
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6647215965718658; loss: 36.7798, bce: 15.5750, kld: 21.2048
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6670706979595098; loss: 36.0039, bce: 14.5923, kld: 21.4116
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6659855695603792; loss: 35.1342, bce: 13.4790, kld: 21.6552
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6748988995367622; loss: 34.8331, bce: 12.9343, kld: 21.8988
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6717231425902123; loss: 34.1422, bce: 11.9812, kld: 22.1610
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.668310021243541; loss: 34.1355, bce: 11.7734, kld: 22.3621
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

KeyboardInterrupt: 

In [13]:
for epoch in range(number_of_epochs):
    print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
    ratio = Counter(train_y)
    train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                              collate_fn=collate_fn, shuffle=True)
    pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
    criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)
    
    t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
    model.train()
    total_loss = np.zeros(3)
    
    for idx, batch_data in enumerate(t):
        # Call the modified train function with autocast enabled
        loss, kld, bce = train(batch_data, model, optimizer, criterion, lbd, max_clip_norm=5)
        total_loss += np.array([loss, bce, kld])

        # Remove unneeded variables from memory immediately after processing
        del batch_data, loss, kld, bce  # Remove references to free memory
        gc.collect()  # Manually trigger garbage collection to free memory
        torch.cuda.empty_cache()  # Clear GPU cache to free VRAM

        # Periodic evaluation and saving
        if idx % eval_freq == 0 and idx > 0:
            torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")

            # Free resources before evaluation
            with torch.no_grad():  # Disable gradient tracking
                model.eval()  # Set the model to evaluation mode
                val_auprc, _ = evaluate(model, val_loader, len(val_y))

            # Logging
            logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')

            # Clear evaluation variables
            del val_auprc  # Clear evaluation results if not needed anymore
            gc.collect()  # Force garbage collection
            torch.cuda.empty_cache()  # Free up GPU memory

            # Reset the model to training mode
            model.train()  # Switch back to training mode

        # Update progress display
        if idx % 50 == 0 and idx > 0:
            t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            t.refresh()

    # After each epoch, free memory and update scheduler
    scheduler.step()
    gc.collect()  # Collect garbage after epoch
    torch.cuda.empty_cache()  # Clear GPU cache after epoch

Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5634297207023011; loss: 32.9255, bce: 24.4059, kld: 8.5195




Learning rate: 0.0001


  return fn(*args, **kwargs)


epoch:2 AUPRC:0.632972596779331; loss: 27.9923, bce: 19.1867, kld: 8.8055


                                                                                                                       

Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

KeyboardInterrupt: 

In [None]:
# Free memory after eval!
for epoch in range(number_of_epochs):
    print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
    ratio = Counter(train_y)
    train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                              collate_fn=collate_fn, shuffle=True
                              #,num_workers=4
                              #,pin_memory=True
                             )
    pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
    criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)
    
    t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
    model.train()
    total_loss = np.zeros(3)
    
    for idx, batch_data in enumerate(t):
        loss, kld, bce = train(batch_data, model, optimizer, criterion, lbd, 5)
        total_loss += np.array([loss, bce, kld])
        
        if idx % eval_freq == 0 and idx > 0:
            torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")

            # Free resources before evaluation
            with torch.no_grad():  # Disable gradient tracking
                model.eval()  # Set the model to evaluation mode
                val_auprc, _ = evaluate(model, val_loader, len(val_y))

            logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')

            # Clear unnecessary variables
            del val_auprc  # Clear evaluation results if not needed anymore
            gc.collect()  # Run garbage collection
            torch.cuda.empty_cache()  # Free up GPU memory

            # Reset the model to training mode
            model.train()  # Switch back to training mode

        if idx % 50 == 0 and idx > 0:
            t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            t.refresh()

    scheduler.step()

In [None]:
for epoch in range(number_of_epochs):
    print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
    ratio = Counter(train_y)
    train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                              collate_fn=collate_fn, shuffle=True)
    
    pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
    criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)
    
    t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
    model.train()
    total_loss = np.zeros(3)
    
    for idx, batch_data in enumerate(t):
        # Call the modified train function with autocast enabled
        loss, kld, bce = train(batch_data, model, optimizer, criterion, lbd, max_clip_norm=5)
        total_loss += np.array([loss, bce, kld])
        
        # Save model and evaluate periodically
        if idx % eval_freq == 0 and idx > 0:
            torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")
            
            # Free resources before evaluation
            with torch.no_grad():
                model.eval()
                val_auprc, _ = evaluate(model, val_loader, len(val_y))

            logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            
            del val_auprc
            gc.collect()
            torch.cuda.empty_cache()
            
            model.train()

        # Update progress display
        if idx % 50 == 0 and idx > 0:
            t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            t.refresh()
    
    scheduler.step()

In [None]:
# Initialize GradScaler
scaler = GradScaler()

for epoch in range(number_of_epochs):
    print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
    
    # Set up data loader and criterion
    ratio = Counter(train_y)
    train_loader = DataLoader(
        dataset=EHRData(train_x, train_y), batch_size=batch_size,
        collate_fn=collate_fn, shuffle=True
    )
    pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
    criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)

    # Progress bar
    t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
    model.train()
    total_loss = np.zeros(3)
    
    for idx, batch_data in enumerate(t):
        optimizer.zero_grad()  # Clear gradients

        # Enable mixed precision for forward pass
        with torch.cuda.amp.autocast(True):
            loss, kld, bce = train(batch_data, model, optimizer, criterion, lbd, 5)
        
        # Scale loss and backpropagate
        scaler.scale(loss).backward()

        # Step optimizer with scaled gradients
        scaler.step(optimizer)
        scaler.update()  # Update the scaler for the next iteration

        total_loss += np.array([loss.item(), bce, kld])  # Ensure loss is not a tensor here

        # Evaluation and saving checkpoints
        if idx % eval_freq == 0 and idx > 0:
            torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")

            with torch.no_grad():  # No gradient tracking
                model.eval()  # Switch to eval mode
                val_auprc, _ = evaluate(model, val_loader, len(val_y))

            # Log and print evaluation results
            logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')

            # Free up memory
            del val_auprc
            gc.collect()
            torch.cuda.empty_cache()
            model.train()  # Switch back to training mode

        # Update progress bar
        if idx % 50 == 0 and idx > 0:
            t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, bce: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
            t.refresh()

    # Step the learning rate scheduler
    scheduler.step()