This notebook contains the code to run the variationally regularized GNN with additional modifications to help with VRAM consumption. Originally the authors had a P100 (16gb VRAM) to run their models, and the code itself was unoptimized for GPUs with smaller VRAMs. Hence taking advantage of the newer versions of PyTorch as well as the developments in 3rd party libraries, a more VRAM efficient version of the code has been produced. 

In addition, improvements such as residual connections to the attention mechanism was also implemented as it is known to help with the performance (something that the authors were missing)

Furthermore dynamic GPU VRAM management was also implemented to prevent the VRAM usage from exploding and leading to either OOM errors or shared GPU VRAM usage (shifting data between GPU and RAM lead to ~2-3x slowdown).

More details can be found on the Github page, where the scientific improvements (eg. Focal loss, Residual connections) and as well as the engineering (eg. dynamic vram management, downcasting to fp16) are explained. 

## Import Libraries

In [1]:
# Import libraries
# Import all torch + DL related libraries
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import autocast, GradScaler
import bitsandbytes as bnb
# Import other libraries
import copy
import numpy as np
from collections import Counter
import pickle
from tqdm import tqdm
from datetime import datetime
import os
import logging
from sklearn.metrics import precision_recall_curve, auc
import gc
from matplotlib import pyplot as plt
# For aggregate logging
import pandas as pd

In [2]:
# Set the device to be cuda if cuda GPU exists
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


## Helper Functions

In [3]:
# Use AMP to reduce VRAM load and to speed up training
scaler = GradScaler()

# Modify the train function to use autocast and return the loss as a Tensor for scaling
def train(data, model, optim, criterion, lbd, max_clip_norm=5):
    model.train()
    input = data[:, :-1].to(device)
    label = data[:, -1].float().to(device)
    
    optim.zero_grad()
    
    with torch.autocast(device_type="cuda", dtype=torch.float16):  # Enable AMP here
        logits, kld = model(input)
        logits = logits.squeeze(-1)
        kld = kld.sum()
        
        bce = criterion(logits, label)
        loss = bce + lbd * kld
    
    # Scale the loss before backward pass
    scaler.scale(loss).backward()
    
    # Apply gradient clipping only when needed
    if max_clip_norm:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
    
    # Update the optimizer with scaled gradients
    scaler.step(optim)
    scaler.update()  # Adjust scaler for next iteration
    
    return loss.item(), kld.item(), bce.item()


def evaluate(model, data_iter, length):
    model.eval()
    y_pred = np.zeros(length)
    y_true = np.zeros(length)
    y_prob = np.zeros(length)
    pointer = 0
    for data in data_iter:
        input = data[:, :-1].to(device)
        label = data[:, -1]
        batch_size = len(label)
        probability, _ = model(input)
        probability = torch.sigmoid(probability.squeeze(-1).detach())
        predicted = probability > 0.5
        y_true[pointer: pointer + batch_size] = label.cpu().numpy()
        y_pred[pointer: pointer + batch_size] = predicted.cpu().numpy()
        y_prob[pointer: pointer + batch_size] = probability.cpu().numpy()
        pointer += batch_size
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    return auc(recall, precision), (y_pred, y_prob, y_true), (recall, precision)

class EHRData(Dataset):
    def __init__(self, data, cla):
        self.data = data
        self.cla = cla

    def __len__(self):
        return len(self.cla)

    def __getitem__(self, idx):
        return self.data[idx], self.cla[idx]

# Reduce the for loops as much as possible
def collate_fn(data):
    # Convert the sparse matrices to dense arrays in a batch operation
    features = np.array([datum[0].toarray().ravel() for datum in data], dtype=np.float32)
    labels = np.array([datum[1] for datum in data], dtype=np.float32).reshape(-1, 1)
    
    # Stack features and labels along the last axis
    data_combined = np.hstack((features, labels))

    # Convert directly to a PyTorch tensor
    return torch.from_numpy(data_combined).long().to(device)

class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Compute the binary cross entropy with logits
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        # Compute the Focal Loss
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        # Allow for different types of calculation for focal loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        elif self.reduction == 'mean':
            return focal_loss.mean()
        else:
            return focal_loss

def check_and_clear_vram(threshold=0.9):
    # Get total and reserved memory
    total_vram = torch.cuda.get_device_properties(0).total_memory
    used_vram = torch.cuda.memory_reserved(0)
    usage_ratio = used_vram / total_vram
    
    # Clear cache if VRAM usage exceeds threshold
    if usage_ratio >= threshold:
        gc.collect()
        torch.cuda.empty_cache()
    return usage_ratio

  scaler = GradScaler()


## Graph Layer

In [4]:
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def clone_params(param, N):
    return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class GraphLayer(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_of_nodes = num_of_nodes
        self.num_of_heads = num_of_heads

        # Single W Linear layer for all heads
        self.W = nn.Linear(in_features, hidden_features * num_of_heads, bias=False)
        self.a = nn.Parameter(torch.rand((num_of_heads, 2 * hidden_features), requires_grad=True))

        # Define V based on whether heads are concatenated
        self.V = nn.Linear(num_of_heads * hidden_features if concat else hidden_features, out_features)

        # Dropout and LeakyReLU
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

        # Layer normalization
        self.norm = LayerNorm(num_of_heads * hidden_features if concat else hidden_features)

    def initialize(self):
        nn.init.xavier_normal_(self.W.weight.data)
        nn.init.xavier_normal_(self.a.data)
        nn.init.xavier_normal_(self.V.weight.data)

    def attention(self, N, data, edge):
        # Project data to (N, num_heads, hidden_features)
        data_proj = self.W(data).view(N, self.num_of_heads, self.hidden_features)

        # Gather source and destination features for each edge
        edge_src, edge_dst = edge
        h_src = data_proj[edge_src, :, :]  # (E, num_heads, hidden_features)
        h_dst = data_proj[edge_dst, :, :]  # (E, num_heads, hidden_features)

        # Concatenate features of edge endpoints and compute attention scores
        # Standard attention structure using concat
        edge_h = torch.cat([h_src, h_dst], dim=-1)  # (E, num_heads, 2 * hidden_features)
        edge_e = self.leakyrelu((self.a.unsqueeze(0) * edge_h).sum(dim=-1))  # (E, num_heads)

        e_rowsum = torch.zeros((N, self.num_of_heads), device=data.device)  # Shape: (N, num_heads)
        h_prime = torch.zeros((N, self.num_of_heads, self.hidden_features), device=data.device)  # Shape: (N, num_heads, hidden_features)

        # Aggregate across all edges in one pass to improve efficiency
        e_rowsum.index_add_(0, edge_dst, edge_e)  # Shape: (N, num_heads)
        h_prime.index_add_(0, edge_dst, edge_e.unsqueeze(-1) * h_src)  # Shape: (N, num_heads, hidden_features)

        # Normalize in-place to avoid creating new tensors
        e_rowsum.clamp_(min=1.0)  # Prevent division by zero
        h_prime.div_(e_rowsum.unsqueeze(-1))

        return h_prime

    def forward(self, edge, data):
        N = self.num_of_nodes
        h_prime = self.attention(N, data, edge)

        # Concatenate or average heads based on `concat`
        if self.concat:
            h_prime = h_prime.view(N, -1)  # Concatenate heads (N, num_heads * hidden_features)
            h_prime = F.elu(self.norm(h_prime))  # Apply ELU activation in-place
        else:
            h_prime = self.V(F.relu(self.norm(h_prime.mean(dim=1))))  # Apply ReLU activation in-place

        # Apply dropout
        h_prime = self.dropout(h_prime)

        return h_prime

## Variational GNN

In [None]:
# With gradient checkpointing - conditional -90% vram to speed up - residual connections -ok and good
#This isnt really giving the speed up unfortunately
def conditional_checkpoint(layer, *inputs):
    # Calculate current VRAM usage
    vram_usage_ratio = torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory
    # Apply checkpointing if VRAM usage exceeds the threshold
    if vram_usage_ratio > VRAM_THRESHOLD:
        return checkpoint(layer, *inputs)
    else:
        return layer(*inputs)

class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        # Convert data to edges with device allocation at the end
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero(as_tuple=False)
        
        if nonzero.numel() == 0:
            return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        if self.training:
            mask = (torch.rand(nonzero.size(0), device=data.device) > 0.05)
            nonzero = nonzero[mask]
            if nonzero.numel() == 0:
                return torch.LongTensor([[0], [0]]).to(device), torch.LongTensor([[length + 1], [length + 1]]).to(device)
        
        nonzero = nonzero.T + 1
        lengths = nonzero.size(1)
        
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        
        # Extend nonzero and avoid redundant device transfer
        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(data.device)), dim=1)
        lengths = nonzero.size(1)
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).T.contiguous().view(1, lengths ** 2)), dim=0)
        return input_edges, output_edges

    def reparameterise(self, mu, logvar):
        if self.training:
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std, device=mu.device)
            return eps.mul(std).add_(mu)
        return mu

    def encoder_decoder(self, data):
        # Calculate edges
        input_edges, output_edges = self.data_to_edges(data)

        # Embed the nodes
        h_prime = self.embed(torch.arange(self.num_of_nodes, device=data.device).long())
        if h_prime.requires_grad is False:
            h_prime.requires_grad = True

        # Apply gradient checkpointing with VRAM monitoring
        for attn in self.in_att:
            if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
                torch.cuda.empty_cache()  # Conditionally clear cache when VRAM is > 90%

            # Residual connection for each layer in in_att
            h_prime_res = h_prime
            h_prime = conditional_checkpoint(attn, input_edges, h_prime) + h_prime_res  # Add residual connection but with conditional checkpointing

        # Variational encoding step
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            mu, logvar = h_prime[:, 0, :], h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu, logvar = mu[data], logvar[data]

        # Residual connection for the out_att layer as well
        h_prime_res = h_prime
        if torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory > 0.9:
            torch.cuda.empty_cache()  # Clear cache if VRAM > 90%
        h_prime = checkpoint(self.out_att, output_edges, h_prime) + h_prime_res  # Add residual connection

        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size(0) if self.variational else 0
        return h_prime[-1], kld

    def forward(self, data):
        batch_size = data.size(0)
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i]) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*outputs)
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            outputs_h, kld_sum = zip(*[(self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]) for out in outputs])
            return self.out_layer(F.relu(torch.stack(outputs_h))), torch.sum(torch.tensor(kld_sum))

## Training + Evaluation

## Standard Run

In [None]:
## Parameters
VRAM_THRESHOLD = 0.9  # 90% usage (prevent leak or OOM dynamically to improve training throughput)
# Folder paths
result_path = "./output/focal/"
data_path = "./Preprocessing/mimic/output/"
# Hyperparameters
embedding_size = 32
in_feature = embedding_size
out_feature =embedding_size
n_layers = 2
lr = 1e-4
reg = True
n_heads = 1
dropout = 0.2
alpha = 0.1
batch_size = 256
number_of_epochs = 2
lbd = 1
focal_loss = False

In [8]:
def train_evaluate(result_path,data_path,in_feature,out_feature,n_layers,reg,n_heads,
                  alpha,batch_size,number_of_epochs,lbd,VRAM_THRESHOLD,lr,dropout,focal_loss_val
                  ):
    # Load data
    train_x, train_y = pickle.load(open(data_path + '/train_csr.pkl', 'rb'))
    val_x, val_y = pickle.load(open(data_path + '/validation_csr.pkl', 'rb'))
    test_x, test_y = pickle.load(open(data_path + '/test_csr.pkl', 'rb'))

    # Upsample training data
    train_upsampling = np.concatenate((np.arange(len(train_y)), np.repeat(np.where(train_y == 1)[0], 1)))
    train_x = train_x[train_upsampling]
    train_y = train_y[train_upsampling]

    # Create result root
    s = datetime.now().strftime('%Y%m%d%H%M%S')
    if focal_loss_val == True:
        result_root = f'{result_path}/lr_{lr}-input_{embedding_size}-output_{embedding_size}-dropout_{dropout}-focal_True'
    else:
        result_root = f'{result_path}/lr_{lr}-input_{embedding_size}-output_{embedding_size}-dropout_{dropout}-focal_False'
    os.makedirs(result_root, exist_ok=True)
    logging.basicConfig(filename=f'{result_root}/train.log', format='%(asctime)s %(message)s', level=logging.INFO)
    logging.info(f"Time: {s}")

    # Initialize model
    num_of_nodes = train_x.shape[1] + 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VariationalGNN(embedding_size, embedding_size, num_of_nodes, n_heads, n_layers - 1,
                           dropout=dropout, alpha=alpha, variational=reg, none_graph_features=0).to(device)

    model = nn.DataParallel(model)
    val_loader = DataLoader(dataset=EHRData(val_x, val_y), batch_size=batch_size,
                            collate_fn=collate_fn, shuffle=False
                           )
    # 8 bit optimizer to speed things up and reduce memory load
    optimizer = bnb.optim.Adam8bit(
        [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    pr_list = []
    val_list = []
    for epoch in range(number_of_epochs):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        ratio = Counter(train_y)
        train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=batch_size,
                                  collate_fn=collate_fn, shuffle=True)

        if focal_loss_val == True:
            # Initialize the Focal Loss criterion with calculated pos_weight
            pos_weight = ratio[False] / ratio[True]
            criterion = FocalLoss(alpha=pos_weight, gamma=1.5, reduction="sum").to(device)  # Modify gamma and alpha as needed
        else:
            pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
            criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)

        t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
        model.train()
        total_loss = np.zeros(3)

        for idx, batch_data in enumerate(t):
            # Train the model on this batch
            loss, kld, focal_loss = train(batch_data, model, optimizer, criterion, lbd, max_clip_norm=5)
            total_loss += np.array([loss, focal_loss, kld])

            # Check VRAM usage and clear cache if needed
            vram_usage = check_and_clear_vram(VRAM_THRESHOLD)  # Check VRAM only when threshold is met

            # Periodic evaluation and saving
            if idx % (len(train_loader)-1) == 0 and idx > 0:
                torch.save(model.state_dict(), f"{result_root}/parameter{epoch}_{idx}")

                # Free resources before evaluation
                with torch.no_grad():
                    model.eval()
                    val_auprc, _,prc = evaluate(model, val_loader, len(val_y))
                    pr_list.append(prc)
                    val_list.append(val_auprc)

                logging.info(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
                print(f'epoch:{epoch + 1} AUPRC:{val_auprc}; loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')

                del val_auprc  # Clear evaluation results
                model.train()

            # Update progress display
            if idx % 50 == 0 and idx > 0:
                t.set_description(f'[epoch:{epoch + 1}] loss: {total_loss[0]/idx:.4f}, focal_loss: {total_loss[1]/idx:.4f}, kld: {total_loss[2]/idx:.4f}')
                t.refresh()

        # Update scheduler and check VRAM at the end of each epoch
        scheduler.step()
        check_and_clear_vram()

    data = []
    for epoch_idx, (recall,precision) in enumerate(pr_list, start=1):
        data.append({
            "Embedding_Size": embedding_size,
            "Learning_Rate" : lr,
            "Dropout": dropout,
            "Alpha":alpha,
            "Focal" : str(focal_loss_val),
            "epoch": epoch_idx,
            "precision": precision.tolist(),
            "recall": recall.tolist(),
            "auprc": val_list[epoch_idx - 1]
        })

    df = pd.DataFrame(data)
    df.to_excel(f'{result_root}/summary_runs.xlsx', index=False)
    
    return df

In [None]:
# Uncomment here to generate a single run based on the parameters
# df_run = train_evaluate(result_path,data_path,in_feature,out_feature,n_layers,reg,n_heads,
#                   alpha,batch_size,number_of_epochs,lbd,VRAM_THRESHOLD,lr,dropout,focal_loss
#                   )

## Hyperparameter Optimization (Optional)

In [None]:
VRAM_THRESHOLD = 0.9  # 90% usage (prevent leak or OOM dynamically to improve training throughput)
# Folder paths
result_path = "./output/focal/Run4/"
data_path = "./data/mimic/output/"
# Hyperparameters
embedding_size = 256
in_feature = embedding_size
out_feature =embedding_size
n_layers = 2
reg = True
n_heads = 1
alpha = 0.1
batch_size = 32
lbd = 1
number_of_epochs = 10

# Set the hyperparameters range for grid search HPO
lr_list = [1e-4,5e-5,1e-5]
dropout_list = [0.2,0.3,0.4]
focal_list = [True, False]

# Set focal loss = true then switch out for false
df_run_all = pd.DataFrame()
for i in lr_list:
    for j in dropout_list:
        for k in focal_list:
            df_run = train_evaluate(result_path,data_path,in_feature,out_feature,n_layers,reg,n_heads,
                  alpha,batch_size,number_of_epochs,lbd,VRAM_THRESHOLD,lr=i,dropout=j,focal_loss_val=k)
            df_run_all = pd.concat([df_run_all, df_run], ignore_index=True)
            df_run_all.to_excel(f'{result_path}/full_summary_runs_4.xlsx', index=False)

Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5644101276217018; loss: 40.5494, focal_loss: 17.3645, kld: 23.1849
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6478064945774141; loss: 37.3423, focal_loss: 13.5659, kld: 23.7764
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6700081306428388; loss: 36.5557, focal_loss: 12.6496, kld: 23.9061
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6876039723505643; loss: 35.9975, focal_loss: 11.7791, kld: 24.2184
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6937254915600751; loss: 35.2588, focal_loss: 10.9078, kld: 24.3510
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6987698146256565; loss: 34.6928, focal_loss: 10.1841, kld: 24.5087
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6959119629875391; loss: 34.4724, focal_loss: 9.9529, kld: 24.5195
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6965758504510567; loss: 34.1985, focal_loss: 9.4810, kld: 24.7175
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.7019886950195212; loss: 33.9783, focal_loss: 9.0998, kld: 24.8785
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.7024944549411217; loss: 33.8938, focal_loss: 8.7641, kld: 25.1298




Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.56237119317267; loss: 50.0972, focal_loss: 24.3773, kld: 25.7199
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6373850814049984; loss: 45.1951, focal_loss: 19.1630, kld: 26.0321
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6661345783489524; loss: 42.9156, focal_loss: 16.7427, kld: 26.1729
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.674259380385307; loss: 41.9879, focal_loss: 15.6945, kld: 26.2934
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6830048704955447; loss: 41.3297, focal_loss: 14.8851, kld: 26.4446
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.68313875053233; loss: 40.5755, focal_loss: 13.8118, kld: 26.7637
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6920049837353717; loss: 40.2536, focal_loss: 13.2495, kld: 27.0040
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6979329147163902; loss: 39.5593, focal_loss: 12.3276, kld: 27.2317
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6953166475524939; loss: 39.4337, focal_loss: 11.9055, kld: 27.5283
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6952002521813855; loss: 39.6903, focal_loss: 11.8324, kld: 27.8579




Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5667807479573248; loss: 38.8134, focal_loss: 17.6467, kld: 21.1667
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6303976382129595; loss: 35.0011, focal_loss: 13.7504, kld: 21.2507
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6607453670303814; loss: 34.4046, focal_loss: 12.8998, kld: 21.5048
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6824466824603983; loss: 33.3046, focal_loss: 11.5351, kld: 21.7696
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6932675455442472; loss: 33.0600, focal_loss: 11.0845, kld: 21.9755
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.696388615259498; loss: 32.3097, focal_loss: 9.9884, kld: 22.3213
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6976342035976585; loss: 32.0681, focal_loss: 9.5867, kld: 22.4813
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.7109741638704775; loss: 31.9773, focal_loss: 9.1984, kld: 22.7789
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.7068047652992151; loss: 31.7718, focal_loss: 8.8231, kld: 22.9487
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.7062867928919042; loss: 31.9042, focal_loss: 8.6603, kld: 23.2439




Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5869975966720999; loss: 49.7097, focal_loss: 23.8158, kld: 25.8939
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6513735385576712; loss: 44.8116, focal_loss: 18.8499, kld: 25.9617
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.671733649194377; loss: 43.3853, focal_loss: 17.0942, kld: 26.2911
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6710504242875336; loss: 42.5590, focal_loss: 15.6209, kld: 26.9381
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6774916612279194; loss: 42.6067, focal_loss: 14.8803, kld: 27.7264
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6891055674247178; loss: 41.6517, focal_loss: 13.4328, kld: 28.2189
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6918224898240496; loss: 41.5237, focal_loss: 13.0639, kld: 28.4598
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6955084728191996; loss: 40.8082, focal_loss: 12.1059, kld: 28.7023
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6981187421475402; loss: 41.1777, focal_loss: 12.0083, kld: 29.1694
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.7006975103573143; loss: 41.0759, focal_loss: 11.3097, kld: 29.7661




Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5332791104352094; loss: 41.7626, focal_loss: 18.3215, kld: 23.4411
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.634684797140691; loss: 37.2308, focal_loss: 13.9221, kld: 23.3087
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6616044492999374; loss: 35.8940, focal_loss: 12.7076, kld: 23.1864
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6702036732925035; loss: 35.3262, focal_loss: 12.0032, kld: 23.3230
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.7002617424291004; loss: 34.7698, focal_loss: 11.2542, kld: 23.5156
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.7077552217933889; loss: 34.1353, focal_loss: 10.4798, kld: 23.6556
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.706987371641434; loss: 33.8107, focal_loss: 10.0452, kld: 23.7655
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.7121656087262798; loss: 33.5346, focal_loss: 9.5943, kld: 23.9403
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.7125755396094522; loss: 33.3296, focal_loss: 9.2829, kld: 24.0467
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.7112055791707222; loss: 33.4330, focal_loss: 9.2068, kld: 24.2262




Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5773909511209532; loss: 47.0810, focal_loss: 24.6650, kld: 22.4160
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.611389325570053; loss: 41.7700, focal_loss: 19.1944, kld: 22.5756
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6533684925430137; loss: 40.2334, focal_loss: 17.2663, kld: 22.9671
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6852315174377853; loss: 39.1699, focal_loss: 15.9989, kld: 23.1709
Learning rate: 0.0001


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6944820014711517; loss: 38.7873, focal_loss: 15.0474, kld: 23.7400
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6955529087007191; loss: 37.5948, focal_loss: 13.5665, kld: 24.0283
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.7028341448610842; loss: 37.6946, focal_loss: 13.4776, kld: 24.2170
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.7021459503166989; loss: 37.0938, focal_loss: 12.7035, kld: 24.3902
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.7059200930133958; loss: 36.6841, focal_loss: 12.1369, kld: 24.5472
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6995968260845966; loss: 36.5386, focal_loss: 11.8246, kld: 24.7140




Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5365359721260768; loss: 39.8322, focal_loss: 18.2399, kld: 21.5923
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.5913215213691905; loss: 36.5478, focal_loss: 14.8824, kld: 21.6654
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6215771367932743; loss: 35.4321, focal_loss: 13.7589, kld: 21.6732
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6485661284898122; loss: 34.4668, focal_loss: 12.6452, kld: 21.8216
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6678586288767079; loss: 33.9064, focal_loss: 11.8739, kld: 22.0325
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6742728389681126; loss: 33.4883, focal_loss: 11.3363, kld: 22.1521
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.669531332204465; loss: 33.2966, focal_loss: 11.0536, kld: 22.2431
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6796375788237307; loss: 33.1264, focal_loss: 10.8285, kld: 22.2978
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6809519613399011; loss: 32.9695, focal_loss: 10.5493, kld: 22.4202
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6762137289735867; loss: 32.8549, focal_loss: 10.3407, kld: 22.5143




Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5736325829081348; loss: 48.0139, focal_loss: 24.3623, kld: 23.6517
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.592514115385837; loss: 43.3224, focal_loss: 19.2537, kld: 24.0686
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6057878676740912; loss: 42.4100, focal_loss: 18.1536, kld: 24.2564
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.625479860849707; loss: 41.5195, focal_loss: 17.0191, kld: 24.5004
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6338493763397234; loss: 40.9068, focal_loss: 16.3455, kld: 24.5612
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6482937001661516; loss: 39.7933, focal_loss: 15.1113, kld: 24.6819
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6468189667976261; loss: 39.4353, focal_loss: 14.7005, kld: 24.7348
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6581842529447597; loss: 39.2989, focal_loss: 14.5016, kld: 24.7973
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6594948073899884; loss: 39.1336, focal_loss: 14.2376, kld: 24.8960
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6605481527165387; loss: 38.7995, focal_loss: 13.7688, kld: 25.0307




Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5900686748367515; loss: 41.1575, focal_loss: 18.2622, kld: 22.8953
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6143707228913137; loss: 37.1871, focal_loss: 14.3519, kld: 22.8352
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6383145660889815; loss: 36.3601, focal_loss: 13.5463, kld: 22.8138
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6509553666941371; loss: 35.5685, focal_loss: 12.6469, kld: 22.9216
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6700637683046774; loss: 35.1225, focal_loss: 12.1235, kld: 22.9990
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.673957291070449; loss: 34.5016, focal_loss: 11.4899, kld: 23.0117
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6775163587513314; loss: 34.3196, focal_loss: 11.2963, kld: 23.0233
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6770696440418468; loss: 34.2453, focal_loss: 11.1950, kld: 23.0503
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.68210153991693; loss: 33.9733, focal_loss: 10.8176, kld: 23.1558
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6832861651707263; loss: 33.7543, focal_loss: 10.5146, kld: 23.2397




Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.407881197775933; loss: 52.7812, focal_loss: 29.0281, kld: 23.7531
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.5762956735782537; loss: 47.2776, focal_loss: 23.3425, kld: 23.9350
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6104817146921816; loss: 45.0593, focal_loss: 20.8658, kld: 24.1935
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6403714717154991; loss: 44.0089, focal_loss: 19.5708, kld: 24.4381
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6402499211595124; loss: 43.1904, focal_loss: 18.5737, kld: 24.6167
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6538941304979242; loss: 42.3989, focal_loss: 17.5793, kld: 24.8195
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6588317552161657; loss: 42.0037, focal_loss: 17.0949, kld: 24.9088
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6576564763072633; loss: 42.0372, focal_loss: 17.0152, kld: 25.0220
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6644677679827056; loss: 41.7151, focal_loss: 16.5383, kld: 25.1768
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6705521484775538; loss: 41.4075, focal_loss: 16.1660, kld: 25.2414




Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5865366258587374; loss: 41.9956, focal_loss: 18.7166, kld: 23.2790
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6251796625710916; loss: 38.1624, focal_loss: 14.5009, kld: 23.6615
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6391103299270772; loss: 37.5561, focal_loss: 13.6777, kld: 23.8784
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6495954994440298; loss: 37.2271, focal_loss: 13.0257, kld: 24.2014
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6661203994187541; loss: 36.8285, focal_loss: 12.3161, kld: 24.5125
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6729009463115194; loss: 36.6363, focal_loss: 11.9230, kld: 24.7133
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6743781744700599; loss: 36.7824, focal_loss: 11.8077, kld: 24.9747
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6785786842356244; loss: 36.9246, focal_loss: 11.5237, kld: 25.4009
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6781116269903046; loss: 36.9239, focal_loss: 11.2362, kld: 25.6878
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6790407907981318; loss: 36.9251, focal_loss: 10.9464, kld: 25.9788




Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.5840579967753734; loss: 50.4213, focal_loss: 27.4881, kld: 22.9333
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.6224938487627113; loss: 43.2390, focal_loss: 20.1280, kld: 23.1110
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.6433842220212654; loss: 41.6461, focal_loss: 18.4621, kld: 23.1840
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.6462341742491288; loss: 40.6877, focal_loss: 17.4267, kld: 23.2610
Learning rate: 5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.6654404302356707; loss: 40.2118, focal_loss: 16.7817, kld: 23.4301
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6659235078632978; loss: 39.4503, focal_loss: 15.8980, kld: 23.5522
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6703135444654977; loss: 39.1389, focal_loss: 15.5178, kld: 23.6211
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6705634142745585; loss: 38.7201, focal_loss: 15.0248, kld: 23.6954
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.675654716404566; loss: 38.3217, focal_loss: 14.5848, kld: 23.7368
Learning rate: 2.5e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6790354017218572; loss: 38.2602, focal_loss: 14.4425, kld: 23.8177




Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.32235001524499574; loss: 45.7550, focal_loss: 22.5452, kld: 23.2098
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.4568007706790579; loss: 41.6554, focal_loss: 18.5373, kld: 23.1181
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.5449483691833202; loss: 39.8010, focal_loss: 16.7168, kld: 23.0841
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.5866714838531301; loss: 38.4143, focal_loss: 15.3714, kld: 23.0430
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.5997409697892764; loss: 37.6446, focal_loss: 14.6501, kld: 22.9945
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.6038999630653704; loss: 37.3539, focal_loss: 14.3750, kld: 22.9789
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.6140570929373939; loss: 37.1751, focal_loss: 14.2100, kld: 22.9651
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6185583851256045; loss: 36.9434, focal_loss: 13.9945, kld: 22.9490
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6223494719516586; loss: 36.7711, focal_loss: 13.8316, kld: 22.9395
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6256217569916843; loss: 36.6494, focal_loss: 13.7176, kld: 22.9318




Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.35763862994767986; loss: 53.2012, focal_loss: 33.3295, kld: 19.8717
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.4688517083316531; loss: 44.1823, focal_loss: 24.2458, kld: 19.9365
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.5463747191865623; loss: 42.7483, focal_loss: 22.7069, kld: 20.0414
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.5731230376051679; loss: 41.7175, focal_loss: 21.5844, kld: 20.1331
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.5866978565900355; loss: 41.1188, focal_loss: 20.9509, kld: 20.1679
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.5963408202835783; loss: 40.4773, focal_loss: 20.2942, kld: 20.1830
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.600286194521256; loss: 40.2848, focal_loss: 20.0966, kld: 20.1882
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.6074997605410746; loss: 40.1794, focal_loss: 19.9843, kld: 20.1951
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.6130094790379634; loss: 39.9636, focal_loss: 19.7505, kld: 20.2131
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.6160740408966541; loss: 39.8425, focal_loss: 19.6246, kld: 20.2179




Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.3416604049519897; loss: 44.3168, focal_loss: 23.0348, kld: 21.2821
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.43728054407070677; loss: 40.3319, focal_loss: 19.0011, kld: 21.3308
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.48191415734693943; loss: 38.9006, focal_loss: 17.4628, kld: 21.4378
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.518135830018276; loss: 38.1220, focal_loss: 16.6062, kld: 21.5158
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.5471902529291484; loss: 37.5066, focal_loss: 15.8888, kld: 21.6178
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.5578629945822504; loss: 37.1159, focal_loss: 15.4569, kld: 21.6590
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.5683872055276995; loss: 36.8290, focal_loss: 15.1339, kld: 21.6951
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.5758844450157736; loss: 36.6537, focal_loss: 14.9340, kld: 21.7197
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.5851797627307238; loss: 36.5231, focal_loss: 14.7716, kld: 21.7515
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.5918224187263104; loss: 36.3593, focal_loss: 14.5800, kld: 21.7793




Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.2456172097637306; loss: 56.9166, focal_loss: 35.9289, kld: 20.9877
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.34509874357207243; loss: 52.0437, focal_loss: 31.0895, kld: 20.9542
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.4457897475027427; loss: 46.4927, focal_loss: 25.4913, kld: 21.0014
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.519465312896805; loss: 44.2604, focal_loss: 23.1947, kld: 21.0657
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.5406117933455654; loss: 43.3627, focal_loss: 22.2827, kld: 21.0800
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.5487758322787746; loss: 42.5096, focal_loss: 21.4100, kld: 21.0996
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.5583351577282324; loss: 42.1953, focal_loss: 21.0953, kld: 21.1000
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.5635528940263894; loss: 41.8330, focal_loss: 20.7320, kld: 21.1010
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.5683590999647538; loss: 41.5764, focal_loss: 20.4750, kld: 21.1014
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.5749388180908309; loss: 41.4309, focal_loss: 20.3154, kld: 21.1155




Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.3215499596989051; loss: 45.0667, focal_loss: 24.1558, kld: 20.9109
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.4348850306984193; loss: 40.8799, focal_loss: 19.8739, kld: 21.0061
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.5132617686221047; loss: 38.8941, focal_loss: 17.8217, kld: 21.0724
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.5403334001307153; loss: 37.8033, focal_loss: 16.6636, kld: 21.1398
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.5532603067862975; loss: 37.2201, focal_loss: 15.9990, kld: 21.2210
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.5617891864575284; loss: 36.8887, focal_loss: 15.6122, kld: 21.2764
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.5683414969572471; loss: 36.6223, focal_loss: 15.3179, kld: 21.3044
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.576392524340746; loss: 36.4578, focal_loss: 15.1258, kld: 21.3320
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.5813586903254374; loss: 36.2359, focal_loss: 14.8769, kld: 21.3589
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.5882508665572632; loss: 36.1056, focal_loss: 14.7084, kld: 21.3972




Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:1 AUPRC:0.3370004826796585; loss: 56.4997, focal_loss: 34.7994, kld: 21.7003
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:2 AUPRC:0.4202690536929887; loss: 49.2501, focal_loss: 27.5211, kld: 21.7291
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:3 AUPRC:0.5263242019595066; loss: 45.8602, focal_loss: 24.0635, kld: 21.7967
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:4 AUPRC:0.5560387780349547; loss: 44.4764, focal_loss: 22.6345, kld: 21.8420
Learning rate: 1e-05


  return fn(*args, **kwargs)
                                                                                                                       

epoch:5 AUPRC:0.5750260121458174; loss: 43.4581, focal_loss: 21.6005, kld: 21.8577
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:6 AUPRC:0.5815295163279977; loss: 42.7219, focal_loss: 20.8295, kld: 21.8924
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:7 AUPRC:0.5866831444367091; loss: 42.3840, focal_loss: 20.4995, kld: 21.8845
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:8 AUPRC:0.5912604757930133; loss: 42.1861, focal_loss: 20.3071, kld: 21.8790
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:9 AUPRC:0.595028220801723; loss: 41.9139, focal_loss: 20.0296, kld: 21.8843
Learning rate: 5e-06


  return fn(*args, **kwargs)
                                                                                                                       

epoch:10 AUPRC:0.5975803149785204; loss: 41.7603, focal_loss: 19.8793, kld: 21.8810


