## Train SparseChem on Chembl_mini 
Output to `experiments/SparseChem`

In [28]:
# from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:90% !important; }</style>"))
%load_ext autoreload
%autoreload 2

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"

import argparse
import sys
import os.path
import time
import json
import functools
import types
import wandb
from datetime import datetime
import pprint
import csv
import copy 
from contextlib import redirect_stdout
import sparsechem as sc
from sparsechem import Nothing
from sparsechem.notebook_modules import (check_for_improvement,init_wandb, initialize,
                                        assertions)
import scipy.io
import scipy.sparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from pytorch_memlab import MemReporter
from pynvml import *

pp = pprint.PrettyPrinter(indent=4)
np.set_printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')
torch.set_printoptions( linewidth=132)
os.environ['WANDB_NOTEBOOK_NAME'] = 'SparseChem_Train_mini.ipynb'
if torch.cuda.is_available():
    nvmlInit()

#import warnings
# from torch.serialization import SourceChangeWarning 
#warnings.filterwarnings("ignore", category=UserWarning)    

# import multiprocessing
# multiprocessing.set_start_method('fork', force=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Setup command line parameters

In [29]:
# datadir="../MLDatasets/chembl23_mini"
datadir="../MLDatasets/chembl29"
outdir ="../experiments/cb29-SparseChem"

cmd = (
  f" --data_dir                    {datadir} " +
  f" --output_dir                   {outdir} " +
  f" --x                 chembl_29_x.npy " +
  f" --y_class        chembl_29_thresh_y.npy " +
  f" --folding                   folding.npy " +
  f" --dev                            cuda:0 " +
  f" --fold_va                             0 " +
  f" --fold_inputs                     32000 " +
  f" --batch_ratio                      0.01 " +
  f" --batch_size                       2048 " +
  f" --hidden_sizes                2000 2000 " +
  f" --dropouts_trunk              0.70 0.70 " +
  f" --weight_decay                     1e-5 " +    
  f" --dropouts_class                      0 " +
  f" --epochs                            100 " +
  f" --lr                               1e-3 " +
  f" --lr_steps                           10 " +
  f" --lr_alpha                          0.3 " + 
  f" --prefix                             sc " +
  f" --min_samples_class                   5 "
)

# f" --dev              cuda:0 "
# f" --dev              cuda:0 "

### Initializations 

In [30]:
args = initialize(cmd)
def vprint(s=""):
    if args.verbose:
        print(s)


  command line parms : 
------------------------
 data_dir.................  ../MLDatasets/chembl29
 output_dir...............  ../experiments/cb29-SparseChem
 x........................  chembl_29_x.npy
 y_class..................  chembl_29_thresh_y.npy
 project_name.............  SparseChem-Mini
 exp_id...................  None
 exp_name.................  None
 exp_desc.................  
 folder_sfx...............  None
 hidden_sizes.............  [2000, 2000]
 dropouts_trunk...........  [0.7, 0.7]
 class_feature_size.......  -1
 last_hidden_sizes........  None
 epochs...................  100
 batch_size...............  2048
 weight_decay.............  1e-05
 last_non_linearity.......  relu
 middle_non_linearity.....  relu
 input_transform..........  none
 lr.......................  0.001
 lr_alpha.................  0.3
 lr_steps.................  [10]
 weights_class............  None
 weights_regr.............  None
 fold_va..................  0
 fold_te..................  None
 ba

In [31]:
pp.pprint(vars(args))

{   'batch_ratio': 0.01,
    'batch_size': 2048,
    'censored_loss': 1,
    'class_feature_size': -1,
    'data_dir': '../MLDatasets/chembl29',
    'dev': 'cuda:0',
    'dropouts_class': [0.0],
    'dropouts_reg': [],
    'dropouts_trunk': [0.7, 0.7],
    'enable_cat_fusion': 0,
    'epochs': 100,
    'eval_frequency': 1,
    'eval_train': 0,
    'exp_desc': '',
    'exp_id': 'f0eet9ke',
    'exp_name': '0507_1613',
    'fold_inputs': 32000,
    'fold_te': None,
    'fold_va': 0,
    'folder_sfx': None,
    'folding': '../MLDatasets/chembl29/folding.npy',
    'hdn_layer_size': 2000,
    'hidden_sizes': [2000, 2000],
    'input_size_freq': None,
    'input_transform': 'none',
    'internal_batch_max': None,
    'inverse_normalization': 0,
    'last_hidden_sizes': None,
    'last_hidden_sizes_class': None,
    'last_hidden_sizes_reg': None,
    'last_non_linearity': 'relu',
    'lr': 0.001,
    'lr_alpha': 0.3,
    'lr_steps': [10],
    'middle_non_linearity': 'relu',
    'min_samples_a

### Assertions

In [32]:
assertions(args)

All assertions passed successfully


### Summary writer

In [33]:
if args.profile == 1:
    assert (args.save_board==1), "Tensorboard should be enabled to be able to profile memory usage."
if args.save_board:
    # tb_name = os.path.join(args.output_dir, "", args.name)
    writer  = SummaryWriter(args.output_dir)
else:
    writer = Nothing()
    

### Load datasets

In [34]:
ecfp     = sc.load_sparse(args.x)
y_class  = sc.load_sparse(args.y_class)
y_regr   = sc.load_sparse(args.y_regr)
y_censor = sc.load_sparse(args.y_censor)

if (y_regr is None) and (y_censor is not None):
    raise ValueError("y_censor provided please also provide --y_regr.")
if y_class is None:
    y_class = scipy.sparse.csr_matrix((ecfp.shape[0], 0))
if y_regr is None:
    y_regr  = scipy.sparse.csr_matrix((ecfp.shape[0], 0))
if y_censor is None:
    y_censor = scipy.sparse.csr_matrix(y_regr.shape)

# Load folding
folding = np.load(args.folding)
assert ecfp.shape[0] == folding.shape[0], "x and folding must have same number of rows"

## Loading task weights
tasks_class = sc.load_task_weights(args.weights_class, y=y_class, label="y_class")
tasks_regr  = sc.load_task_weights(args.weights_regr, y=y_regr, label="y_regr")

## Input and folding transformation
ecfp = sc.fold_transform_inputs(ecfp, folding_size=args.fold_inputs, transform=args.input_transform)
print(f"count non zero:{ecfp[0].count_nonzero()}")


num_pos    = np.array((y_class == +1).sum(0)).flatten()
num_neg    = np.array((y_class == -1).sum(0)).flatten()
num_class  = np.array((y_class != 0).sum(0)).flatten()
if (num_class != num_pos + num_neg).any():
    raise ValueError("For classification all y values (--y_class/--y) must be 1 or -1.")

num_regr   = np.bincount(y_regr.indices, minlength=y_regr.shape[1])

assert args.min_samples_auc is None, "Parameter 'min_samples_auc' is obsolete. Use '--min_samples_class' that specifies how many samples a task needs per FOLD and per CLASS to be aggregated."

## Aggregation Weights 
if tasks_class.aggregation_weight is None:
    ## using min_samples rule
    fold_pos, fold_neg = sc.class_fold_counts(y_class, folding)
    n = args.min_samples_class
    tasks_class.aggregation_weight = ((fold_pos >= n).all(0) & (fold_neg >= n)).all(0).astype(np.float64)

if tasks_regr.aggregation_weight is None:
    if y_censor.nnz == 0:
        y_regr2 = y_regr.copy()
        y_regr2.data[:] = 1
    else:
        ## only counting uncensored data
        y_regr2      = y_censor.copy()
        y_regr2.data = (y_regr2.data == 0).astype(np.int32)
    fold_regr, _ = sc.class_fold_counts(y_regr2, folding)
    del y_regr2
    tasks_regr.aggregation_weight = (fold_regr >= args.min_samples_regr).all(0).astype(np.float64)

vprint(f"Input dimension: {ecfp.shape[1]}")
vprint(f"#samples:        {ecfp.shape[0]}")
vprint(f"#classification tasks:  {y_class.shape[1]}")
vprint(f"#regression tasks:      {y_regr.shape[1]}")
vprint(f"Using {(tasks_class.aggregation_weight > 0).sum()} classification tasks for calculating aggregated metrics (AUCROC, F1_max, etc).")
vprint(f"Using {(tasks_regr.aggregation_weight > 0).sum()} regression tasks for calculating metrics (RMSE, Rsquared, correlation).")


count non zero:79
Input dimension: 32000
#samples:        423736
#classification tasks:  3552
#regression tasks:      0
Using 1352 classification tasks for calculating aggregated metrics (AUCROC, F1_max, etc).
Using 0 regression tasks for calculating metrics (RMSE, Rsquared, correlation).


In [35]:
(tasks_class.aggregation_weight > 0).sum()
fold_pos.shape
print(fold_pos.sum(), fold_pos.sum(axis=-1))
print(fold_pos)
print()
print(fold_neg.sum(), fold_neg.sum(axis=-1))
print(fold_neg)

933144 [184010 183446 187477 194491 183720]
[[28 17  1 ... 31  6  0]
 [23 15 12 ... 44  4  1]
 [25 23  8 ... 61  5  0]
 [ 8  2  0 ... 37  1  0]
 [22 15  8 ... 83  4  0]]

1679106 [333475 335883 337963 337527 334258]
[[  9  20  36 ...  72  89  95]
 [  0   8  11 ... 147 167 170]
 [  1   3  18 ...  56  82  87]
 [  3   9  11 ...  90 121 122]
 [  1   8  15 ...  57 116 117]]


In [36]:
print(tasks_class.aggregation_weight.sum())
print(tasks_class.aggregation_weight)
print(tasks_class.training_weight)
 

1352.0
[0. 0. 0. ... 1. 0. 0.]
tensor([1., 1., 1.,  ..., 1., 1., 1.])


In [37]:
## Separation of test data
if args.fold_te is not None and args.fold_te >= 0:
    ## removing test data
    assert args.fold_te != args.fold_va, "fold_va and fold_te must not be equal."
    keep    = folding != args.fold_te
    ecfp    = ecfp[keep]
    y_class = y_class[keep]
    y_regr  = y_regr[keep]
    y_censor = y_censor[keep]
    folding = folding[keep]

## Regression Normalization    
normalize_inv = None
if args.normalize_regression == 1 and args.normalize_regr_va == 1:
   y_regr, mean_save, var_save = sc.normalize_regr(y_regr)

## Separation of train and Validation data
fold_va = args.fold_va
idx_tr  = np.where(folding != fold_va)[0]
idx_va  = np.where(folding == fold_va)[0]

y_class_tr = y_class[idx_tr]
y_class_va = y_class[idx_va]
y_regr_tr  = y_regr[idx_tr]
y_regr_va  = y_regr[idx_va]
y_censor_tr = y_censor[idx_tr]
y_censor_va = y_censor[idx_va]

## REgression normalization
if args.normalize_regression == 1 and args.normalize_regr_va == 0:
   y_regr_tr, mean_save, var_save = sc.normalize_regr(y_regr_tr) 
   if args.inverse_normalization == 1:
      normalize_inv = {}
      normalize_inv["mean"] = mean_save
      normalize_inv["var"]  = var_save
    
num_pos_va  = np.array((y_class_va == +1).sum(0)).flatten()
num_neg_va  = np.array((y_class_va == -1).sum(0)).flatten()
num_regr_va = np.bincount(y_regr_va.indices, minlength=y_regr.shape[1])
pos_rate = num_pos_va/(num_pos_va+num_neg_va)
pos_rate_ref = args.pi_zero
pos_rate = np.clip(pos_rate, 0, 0.99)
cal_fact_aucpr = pos_rate*(1-pos_rate_ref)/(pos_rate_ref*(1-pos_rate))

vprint(f"Input dimension   : {ecfp.shape[1]}")
vprint(f"Input dimension   : {ecfp.shape[1]}")
vprint(f"Training dataset  : {ecfp[idx_tr].shape}")
vprint(f"Validation dataset: {ecfp[idx_va].shape}")
vprint()
vprint(f"#classification tasks:  {y_class.shape[1]}")
vprint(f"#regression tasks    :      {y_regr.shape[1]}")
vprint(f"Using {(tasks_class.aggregation_weight > 0).sum():3d} classification tasks for calculating aggregated metrics (AUCROC, F1_max, etc).")
vprint(f"Using {(tasks_regr.aggregation_weight > 0).sum():3d} regression tasks for calculating metrics (RMSE, Rsquared, correlation).")

Input dimension   : 32000
Input dimension   : 32000
Training dataset  : (340296, 32000)
Validation dataset: (83440, 32000)

#classification tasks:  3552
#regression tasks    :      0
Using 1352 classification tasks for calculating aggregated metrics (AUCROC, F1_max, etc).
Using   0 regression tasks for calculating metrics (RMSE, Rsquared, correlation).


  pos_rate = num_pos_va/(num_pos_va+num_neg_va)


### Batch Size Calculation

In [38]:
num_int_batches = 1
if args.batch_size is not None:
    batch_size = args.batch_size
else:
    batch_size = int(np.ceil(args.batch_ratio * idx_tr.shape[0]))

print(f"orig batch size:   {batch_size}")
print(f"orig num int batches:   {num_int_batches}")

if args.internal_batch_max is not None:
    if args.internal_batch_max < batch_size:
        num_int_batches = int(np.ceil(batch_size / args.internal_batch_max))
        batch_size      = int(np.ceil(batch_size / num_int_batches))
print(f"batch size:   {batch_size}")
print(f"num_int_batches:   {num_int_batches}")

orig batch size:   2048
orig num int batches:   1
batch size:   2048
num_int_batches:   1


In [39]:
# #import ipdb; ipdb.set_trace()
# batch_size  = int(np.ceil(args.batch_ratio * idx_tr.shape[0]))
# num_int_batches = 1

# if args.internal_batch_max is not None:
#     if args.internal_batch_max < batch_size:
#         num_int_batches = int(np.ceil(batch_size / args.internal_batch_max))
#         batch_size      = int(np.ceil(batch_size / num_int_batches))
# vprint(f"#internal batch size:   {batch_size}")

In [40]:
tasks_cat_id_list = None
select_cat_ids = None
if tasks_class.cat_id is not None:
    tasks_cat_id_list = [[x,i] for i,x in enumerate(tasks_class.cat_id) if str(x) != 'nan']
    tasks_cat_ids = [i for i,x in enumerate(tasks_class.cat_id) if str(x) != 'nan']
    select_cat_ids = np.array(tasks_cat_ids)
    cat_id_size = len(tasks_cat_id_list)
else:
    cat_id_size = 0

### Dataloaders

In [41]:
dataset_tr = sc.ClassRegrSparseDataset(x=ecfp[idx_tr], y_class=y_class_tr, y_regr=y_regr_tr, y_censor=y_censor_tr, y_cat_columns=select_cat_ids)
dataset_va = sc.ClassRegrSparseDataset(x=ecfp[idx_va], y_class=y_class_va, y_regr=y_regr_va, y_censor=y_censor_va, y_cat_columns=select_cat_ids)

loader_tr = DataLoader(dataset_tr, batch_size=batch_size, num_workers = 8, pin_memory=True, collate_fn=dataset_tr.collate, shuffle=True)
loader_va = DataLoader(dataset_va, batch_size=batch_size, num_workers = 4, pin_memory=True, collate_fn=dataset_va.collate, shuffle=False)

args.input_size  = dataset_tr.input_size
args.output_size = dataset_tr.output_size

args.class_output_size = dataset_tr.class_output_size
args.regr_output_size  = dataset_tr.regr_output_size
args.cat_id_size = cat_id_size



In [42]:

print(f"\n dataset_tr.y_class                                 :  {dataset_tr.y_class.shape}",
      f"\n dataset_va.y_class                                 :  {dataset_va.y_class.shape}",
#       f"\n dataset_test.y_class                                 :  {dataset_va.y_class.shape}",
      f"\n                                ",
      f'\n size of training set                               :  {len(dataset_tr)}',
      f'\n size of validation set                             :  {len(dataset_va)}',
#     #   f'\n size of test set                                   :  {len(dldrs.testset)}',
#     #   f'\n                               Total                :  {len(dldrs.trainset0)+len(dldrs.trainset1)+len(dldrs.trainset2)+len(dldrs.valset)+ len(dldrs.testset)}',
      f"\n                                ",
      f"\n Number of batches in training                      :  {len(loader_tr)}",
      f"\n Number of batches in validation dataset            :  {len(loader_va)}",
    #   f"\n lenght (# batches) in test dataset                 :  {len(dldrs.test_loader)}",
      f"\n                                ")
                


 dataset_tr.y_class                                 :  (340296, 3552) 
 dataset_va.y_class                                 :  (83440, 3552) 
                                 
 size of training set                               :  340296 
 size of validation set                             :  83440 
                                 
 Number of batches in training                      :  167 
 Number of batches in validation dataset            :  41 
                                


###  WandB setup

In [43]:
#------------------------------------------------------------------
# ### WandB setup
#------------------------------------------------------------------
ns = types.SimpleNamespace()
ns.current_epoch  = 0
ns.current_iter   = 0
ns.best_results   = {}
ns.best_metrics   = None
ns.best_value     = 0 
ns.best_iter      = 0
ns.best_epoch     = 0
ns.p_epoch        = 0
ns.num_prints     = 0

init_wandb(ns, args)
wandb.define_metric("best_accuracy", summary="last")
wandb.define_metric("best_epoch", summary="last")

f0eet9ke 0507_1613 SparseChem-Mini


 PROJECT NAME: SparseChem-Mini
 RUN ID      : f0eet9ke 
 RUN NAME    : 0507_1613


<wandb.sdk.wandb_metric.Metric at 0x2af7ce2dc700>

### Network

In [44]:
#------------------------------------------------------------------
# ### Network
#------------------------------------------------------------------
dev  = torch.device(args.dev)

net  = sc.SparseFFN(args).to(dev)
loss_class = torch.nn.BCEWithLogitsLoss(reduction="none")
loss_regr  = sc.censored_mse_loss

if not args.censored_loss:
    loss_regr = functools.partial(loss_regr, censored_enabled=False)

tasks_class.training_weight = tasks_class.training_weight.to(dev)
tasks_regr.training_weight  = tasks_regr.training_weight.to(dev)
tasks_regr.censored_weight  = tasks_regr.censored_weight.to(dev)

###  Optimizer, Scheduler, GradScaler

In [45]:
#------------------------------------------------------------------
# ###  Optimizer, Scheduler, GradScaler
#------------------------------------------------------------------
optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_alpha)
scaler = torch.cuda.amp.GradScaler()

wandb.watch(net, log='all', log_freq= 100)     ###  Weights and Biases Initialization 
reporter = None
h = None

### setup memory profiling reporter

In [46]:
if args.profile == 1:
   torch_gpu_id = torch.cuda.current_device()
   if "CUDA_VISIBLE_DEVICES" in os.environ:
      ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
      nvml_gpu_id = ids[torch_gpu_id] # remap
   else:
      nvml_gpu_id = torch_gpu_id
   h = nvmlDeviceGetHandleByIndex(nvml_gpu_id)

if args.profile == 1:
   #####   output saving   #####
   if not os.path.exists(args.output_dir):
       os.makedirs(args.output_dir)

   reporter = MemReporter(net)

   with open(f"{args.output_dir}/memprofile.txt", "w+") as profile_file:
        with redirect_stdout(profile_file):
             profile_file.write(f"\nInitial model detailed report:\n\n")
             reporter.report()

In [47]:
#------------------------------------------------------------------
# ### Display network and other values
#------------------------------------------------------------------
print("Network:")
print(net)
print(optimizer)
print(f"dev                  :    {dev}")
print(f"args.lr              :    {args.lr}")
print(f"args.weight_decay    :    {args.weight_decay}")
print(f"args.lr_steps        :    {args.lr_steps}")
print(f"args.lr_steps        :    {args.lr_steps}")
print(f"num_int_batches      :    {num_int_batches}")
print(f"batch_size           :    {batch_size}")
print(f"current epoch        :    {ns.current_epoch}")
print(f"epochs               :    {args.epochs}")
print(f"scaler               :    {scaler}")
print(f"args.normalize_loss  :    {args.normalize_loss}")
print(f"loss_class           :    {loss_class}")
print(f"mixed precision      :    {args.mixed_precision}")
print(f"args.eval_train      :    {args.eval_train}")

Network:
SparseFFN(
  (net): Sequential(
    (0): SparseInputNet(
      (net_freq): SparseLinear(in_features=32000, out_features=2000, bias=True)
    )
    (1): MiddleNet(
      (net): Sequential(
        (layer_0): Sequential(
          (0): ReLU()
          (1): Dropout(p=0.7, inplace=False)
          (2): Linear(in_features=2000, out_features=2000, bias=True)
        )
      )
    )
  )
  (classLast): LastNet(
    (net): Sequential(
      (initial_layer): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.7, inplace=False)
        (2): Linear(in_features=2000, out_features=3552, bias=True)
      )
    )
  )
  (regrLast): Sequential(
    (0): LastNet(
      (net): Sequential(
        (initial_layer): Sequential(
          (0): Tanh()
          (1): Dropout(p=0.7, inplace=False)
          (2): Linear(in_features=2000, out_features=0, bias=True)
        )
      )
    )
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.001
 

##  Training Loop

In [49]:
import warnings
# from torch.serialization import SourceChangeWarning 
warnings.filterwarnings("ignore", category=UserWarning) 
progress = True
args.epochs = 250
ns.end_epoch = ns.current_epoch + args.epochs
print(f" Last Epoch: {ns.current_epoch}   # of epochs to do:  {args.epochs} - Run epochs {ns.current_epoch+1} to {ns.end_epoch}")

 Last Epoch: 0   # of epochs to do:  250 - Run epochs 1 to 250


In [50]:
ns.end_epoch = ns.current_epoch + args.epochs

for ns.current_epoch in range(ns.current_epoch+1, ns.end_epoch+1, 1):
    t0 = time.time()
    sc.train_class_regr(
        net, optimizer,
        loader          = loader_tr,
        loss_class      = loss_class,
        loss_regr       = loss_regr,
        dev             = dev,
        weights_class   = tasks_class.training_weight * (1-args.regression_weight) * 2,
        weights_regr    = tasks_regr.training_weight * args.regression_weight * 2,
        censored_weight = tasks_regr.censored_weight,
        normalize_loss  = args.normalize_loss,
        num_int_batches = num_int_batches,
        progress        = progress,
        writer          = writer,
        epoch           = ns.current_epoch,
        args            = args,
        scaler          = scaler,
        nvml_handle     = h)

    if args.profile == 1:
       with open(f"{args.output_dir}/memprofile.txt", "a+") as profile_file:
            profile_file.write(f"\nAfter epoch {epoch} model detailed report:\n\n")
            with redirect_stdout(profile_file):
                 reporter.report()

    t1 = time.time()
    eval_round = (args.eval_frequency > 0) and ((ns.current_epoch + 1) % args.eval_frequency == 0)
    last_round = ns.current_epoch == args.epochs - 1

    if eval_round or last_round:

        results_va = sc.evaluate_class_regr(net, loader_va, loss_class, loss_regr, 
                                            tasks_class= tasks_class, 
                                            tasks_regr = tasks_regr, 
                                            dev        = dev, 
                                            progress   = progress, 
                                            normalize_inv=normalize_inv, 
                                            cal_fact_aucpr=cal_fact_aucpr)
        
        for key, val in results_va["classification_agg"].items():
            writer.add_scalar("val_metrics:aggregated/"+key, val, ns.current_epoch * batch_size)


        if args.eval_train:
            results_tr = sc.evaluate_class_regr(net, loader_tr, loss_class, loss_regr, 
                                                tasks_class = tasks_class, 
                                                tasks_regr  = tasks_regr, 
                                                dev         = dev, 
                                                progress    = progress)
            for key, val in results_tr["classification_agg"].items():
                writer.add_scalar("trn_metrics:aggregated/"+key, val, ns.current_epoch * batch_size)

        else:
            results_tr = None

        if args.verbose:
            ## printing a new header every 20 lines
            header = ns.num_prints % 20 == 0
            ns.num_prints += 1
            sc.print_metrics_cr(ns.current_epoch, t1 - t0, results_tr, results_va, header)
            
        wandb.log(results_va["classification_agg"].to_dict())

        check_for_improvement(ns, results_va)
    
    scheduler.step()

print(f"Best Epoch :       {ns.best_epoch}\n"
      f"Best Iteration :   {ns.best_iter} \n"
      f"Best Precision :   {ns.best_value:.5f}\n")

                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
1      |   0.38637   0.50471   0.67062   0.74233   0.66394   0.38039   0.70275 |       nan       nan       nan |   10.2 
Previous best_epoch:     0   best iter:     0,   best_value: 0.00000
New      best_epoch:     1   best iter:     0,   best_value: 0.67062


                                                                                                                                                                    

2      |   0.36877   0.48001   0.69124   0.76607   0.68448   0.40954   0.71445 |       nan       nan       nan |   10.0 
Previous best_epoch:     1   best iter:     0,   best_value: 0.67062
New      best_epoch:     2   best iter:     0,   best_value: 0.69124


                                                                                                                                                                    

3      |   0.36653   0.47516   0.69868   0.77418   0.69191   0.42079   0.71843 |       nan       nan       nan |   10.0 
Previous best_epoch:     2   best iter:     0,   best_value: 0.69124
New      best_epoch:     3   best iter:     0,   best_value: 0.69868


                                                                                                                                                                    

4      |   0.36383   0.47085   0.70158   0.77716   0.69479   0.42404   0.72003 |       nan       nan       nan |   10.0 
Previous best_epoch:     3   best iter:     0,   best_value: 0.69868
New      best_epoch:     4   best iter:     0,   best_value: 0.70158


                                                                                                                                                                    

5      |   0.36390   0.46997   0.70717   0.78173   0.70051   0.43222   0.72297 |       nan       nan       nan |   10.0 
Previous best_epoch:     4   best iter:     0,   best_value: 0.70158
New      best_epoch:     5   best iter:     0,   best_value: 0.70717


                                                                                                                                                                    

6      |   0.36418   0.47147   0.70752   0.78088   0.70110   0.43404   0.72258 |       nan       nan       nan |   10.3 
Previous best_epoch:     5   best iter:     0,   best_value: 0.70717
New      best_epoch:     6   best iter:     0,   best_value: 0.70752


                                                                                                                                                                    

7      |   0.36467   0.47229   0.70725   0.78150   0.70065   0.43349   0.72319 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

8      |   0.36624   0.47375   0.70941   0.78261   0.70276   0.43448   0.72340 |       nan       nan       nan |   10.3 
Previous best_epoch:     6   best iter:     0,   best_value: 0.70752
New      best_epoch:     8   best iter:     0,   best_value: 0.70941


                                                                                                                                                                    

9      |   0.36718   0.47442   0.70855   0.78134   0.70212   0.43113   0.72397 |       nan       nan       nan |   11.3 


                                                                                                                                                                    

10     |   0.36708   0.47532   0.71100   0.78249   0.70477   0.43517   0.72419 |       nan       nan       nan |   10.3 
Previous best_epoch:     8   best iter:     0,   best_value: 0.70941
New      best_epoch:    10   best iter:     0,   best_value: 0.71100


                                                                                                                                                                    

11     |   0.37109   0.48367   0.71139   0.78317   0.70519   0.43688   0.72472 |       nan       nan       nan |   11.9 
Previous best_epoch:    10   best iter:     0,   best_value: 0.71100
New      best_epoch:    11   best iter:     0,   best_value: 0.71139


                                                                                                                                                                    

12     |   0.37328   0.48739   0.71217   0.78442   0.70581   0.43744   0.72570 |       nan       nan       nan |   10.3 
Previous best_epoch:    11   best iter:     0,   best_value: 0.71139
New      best_epoch:    12   best iter:     0,   best_value: 0.71217


                                                                                                                                                                    

13     |   0.37741   0.49387   0.71262   0.78493   0.70613   0.43774   0.72604 |       nan       nan       nan |   10.9 
Previous best_epoch:    12   best iter:     0,   best_value: 0.71217
New      best_epoch:    13   best iter:     0,   best_value: 0.71262


                                                                                                                                                                    

14     |   0.38180   0.49897   0.71273   0.78454   0.70629   0.43850   0.72624 |       nan       nan       nan |   10.0 
Previous best_epoch:    13   best iter:     0,   best_value: 0.71262
New      best_epoch:    14   best iter:     0,   best_value: 0.71273


                                                                                                                                                                    

15     |   0.38644   0.50522   0.71255   0.78422   0.70611   0.43760   0.72595 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

16     |   0.39101   0.51149   0.71313   0.78467   0.70672   0.43838   0.72574 |       nan       nan       nan |   10.0 
Previous best_epoch:    14   best iter:     0,   best_value: 0.71273
New      best_epoch:    16   best iter:     0,   best_value: 0.71313


                                                                                                                                                                    

17     |   0.39477   0.51655   0.71315   0.78465   0.70680   0.43783   0.72543 |       nan       nan       nan |   10.2 
Previous best_epoch:    16   best iter:     0,   best_value: 0.71313
New      best_epoch:    17   best iter:     0,   best_value: 0.71315


                                                                                                                                                                    

18     |   0.40020   0.52260   0.71363   0.78466   0.70720   0.43948   0.72563 |       nan       nan       nan |   10.4 
Previous best_epoch:    17   best iter:     0,   best_value: 0.71315
New      best_epoch:    18   best iter:     0,   best_value: 0.71363


                                                                                                                                                                    

19     |   0.40315   0.52741   0.71286   0.78363   0.70655   0.43760   0.72508 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

20     |   0.40977   0.53549   0.71367   0.78400   0.70751   0.43979   0.72537 |       nan       nan       nan |   10.4 
Previous best_epoch:    18   best iter:     0,   best_value: 0.71363
New      best_epoch:    20   best iter:     0,   best_value: 0.71367


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
21     |   0.41302   0.53904   0.71300   0.78389   0.70665   0.43849   0.72522 |       nan       nan       nan |   11.1 


                                                                                                                                                                    

22     |   0.41693   0.54415   0.71360   0.78383   0.70735   0.43805   0.72509 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

23     |   0.42087   0.54916   0.71245   0.78339   0.70605   0.43602   0.72507 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

24     |   0.42728   0.55897   0.71232   0.78294   0.70597   0.43508   0.72506 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

25     |   0.43301   0.56517   0.71218   0.78286   0.70572   0.43468   0.72511 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

26     |   0.43561   0.56811   0.71234   0.78308   0.70602   0.43484   0.72521 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

27     |   0.43949   0.57256   0.71178   0.78262   0.70530   0.43469   0.72513 |       nan       nan       nan |   11.8 


                                                                                                                                                                    

28     |   0.44499   0.57906   0.71246   0.78314   0.70610   0.43487   0.72503 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

29     |   0.45118   0.58617   0.71272   0.78317   0.70638   0.43537   0.72541 |       nan       nan       nan |   11.0 


                                                                                                                                                                    

30     |   0.45181   0.58767   0.71179   0.78249   0.70525   0.43337   0.72475 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

31     |   0.45559   0.59255   0.71216   0.78305   0.70564   0.43379   0.72535 |       nan       nan       nan |   10.6 


                                                                                                                                                                    

32     |   0.45829   0.59677   0.71258   0.78267   0.70614   0.43442   0.72462 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

33     |   0.46360   0.60186   0.71164   0.78230   0.70514   0.43381   0.72412 |       nan       nan       nan |   11.2 


                                                                                                                                                                    

34     |   0.47004   0.60899   0.71148   0.78175   0.70507   0.43299   0.72365 |       nan       nan       nan |   10.6 


                                                                                                                                                                    

35     |   0.47139   0.61104   0.71226   0.78190   0.70588   0.43356   0.72411 |       nan       nan       nan |   11.0 


                                                                                                                                                                    

36     |   0.47927   0.62028   0.71115   0.78148   0.70469   0.43232   0.72354 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

37     |   0.47804   0.62001   0.71146   0.78138   0.70493   0.43260   0.72401 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

38     |   0.48469   0.62710   0.71152   0.78123   0.70508   0.43294   0.72388 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

39     |   0.48559   0.62906   0.71148   0.78140   0.70524   0.43325   0.72367 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

40     |   0.48996   0.63508   0.71123   0.78078   0.70495   0.43290   0.72335 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
41     |   0.48872   0.63376   0.71182   0.78089   0.70545   0.43349   0.72331 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

42     |   0.49568   0.64127   0.71084   0.78045   0.70436   0.43160   0.72332 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

43     |   0.49703   0.64301   0.71067   0.78032   0.70414   0.43100   0.72294 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

44     |   0.50252   0.64989   0.71079   0.78052   0.70436   0.43116   0.72303 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

45     |   0.50316   0.65045   0.71093   0.78076   0.70445   0.43149   0.72313 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

46     |   0.50672   0.65644   0.71051   0.78000   0.70392   0.43072   0.72366 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

47     |   0.51053   0.66107   0.71034   0.77943   0.70389   0.43002   0.72251 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

48     |   0.51633   0.66739   0.71006   0.77940   0.70359   0.42945   0.72214 |       nan       nan       nan |   12.9 


                                                                                                                                                                    

49     |   0.51713   0.66776   0.70993   0.77968   0.70347   0.42947   0.72224 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

50     |   0.51637   0.66648   0.70995   0.77933   0.70362   0.42927   0.72216 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

51     |   0.52210   0.67279   0.71044   0.77999   0.70398   0.43063   0.72252 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

52     |   0.52208   0.67387   0.71008   0.77945   0.70361   0.42981   0.72215 |       nan       nan       nan |   11.5 


                                                                                                                                                                    

53     |   0.52206   0.67592   0.71026   0.77947   0.70380   0.42968   0.72236 |       nan       nan       nan |   12.8 


                                                                                                                                                                    

54     |   0.52798   0.68087   0.70941   0.77913   0.70299   0.42782   0.72189 |       nan       nan       nan |   12.8 


                                                                                                                                                                    

55     |   0.53129   0.68668   0.70876   0.77838   0.70216   0.42702   0.72166 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

56     |   0.52854   0.68287   0.70908   0.77846   0.70252   0.42690   0.72173 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

57     |   0.53715   0.69363   0.70803   0.77825   0.70138   0.42567   0.72110 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

58     |   0.53713   0.69319   0.70843   0.77833   0.70195   0.42645   0.72141 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

59     |   0.53692   0.69266   0.70804   0.77788   0.70154   0.42556   0.72071 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

60     |   0.53622   0.69251   0.70909   0.77804   0.70268   0.42710   0.72124 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
61     |   0.54237   0.70124   0.70825   0.77767   0.70176   0.42570   0.72090 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

62     |   0.54518   0.70442   0.70777   0.77749   0.70131   0.42553   0.72075 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

63     |   0.54137   0.70015   0.70814   0.77739   0.70167   0.42501   0.72062 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

64     |   0.54557   0.70610   0.70739   0.77733   0.70089   0.42433   0.72083 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

65     |   0.54968   0.70974   0.70755   0.77717   0.70102   0.42412   0.72097 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

66     |   0.54748   0.70592   0.70758   0.77750   0.70106   0.42400   0.72090 |       nan       nan       nan |   10.6 


                                                                                                                                                                    

67     |   0.55453   0.71482   0.70773   0.77761   0.70130   0.42413   0.72069 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

68     |   0.55204   0.71066   0.70752   0.77721   0.70106   0.42281   0.72063 |       nan       nan       nan |   11.4 


                                                                                                                                                                    

69     |   0.54952   0.70919   0.70777   0.77767   0.70127   0.42373   0.72117 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

70     |   0.55984   0.72035   0.70742   0.77721   0.70096   0.42378   0.72074 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

71     |   0.56021   0.72097   0.70752   0.77684   0.70106   0.42373   0.72103 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

72     |   0.56122   0.72200   0.70750   0.77710   0.70100   0.42335   0.72101 |       nan       nan       nan |   13.4 


                                                                                                                                                                    

73     |   0.56298   0.72455   0.70716   0.77677   0.70060   0.42229   0.72146 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

74     |   0.56417   0.72597   0.70656   0.77632   0.69990   0.42158   0.72086 |       nan       nan       nan |   10.7 


                                                                                                                                                                    

75     |   0.56105   0.72119   0.70702   0.77634   0.70044   0.42320   0.72110 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

76     |   0.56600   0.72740   0.70695   0.77653   0.70031   0.42224   0.72134 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

77     |   0.57031   0.73365   0.70741   0.77665   0.70082   0.42204   0.72060 |       nan       nan       nan |   10.9 


                                                                                                                                                                    

78     |   0.56719   0.72867   0.70693   0.77644   0.70027   0.42232   0.72047 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

79     |   0.56998   0.73158   0.70701   0.77620   0.70049   0.42133   0.72060 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

80     |   0.57153   0.73549   0.70641   0.77523   0.69995   0.42070   0.71999 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
81     |   0.57261   0.73547   0.70585   0.77522   0.69937   0.42017   0.71959 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

82     |   0.57544   0.73848   0.70520   0.77479   0.69862   0.41991   0.71959 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

83     |   0.57435   0.73823   0.70622   0.77490   0.69966   0.42035   0.71984 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

84     |   0.57917   0.74704   0.70586   0.77511   0.69933   0.41961   0.71898 |       nan       nan       nan |   11.8 


                                                                                                                                                                    

85     |   0.57740   0.74084   0.70579   0.77533   0.69912   0.41906   0.71921 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

86     |   0.57783   0.74111   0.70612   0.77547   0.69957   0.42032   0.71920 |       nan       nan       nan |   11.1 


                                                                                                                                                                    

87     |   0.57874   0.74206   0.70555   0.77456   0.69898   0.41939   0.71967 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

88     |   0.58240   0.74623   0.70556   0.77493   0.69908   0.41993   0.71980 |       nan       nan       nan |   11.3 


                                                                                                                                                                    

89     |   0.58037   0.74247   0.70565   0.77463   0.69899   0.41922   0.71956 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

90     |   0.58378   0.74921   0.70506   0.77351   0.69857   0.41827   0.71885 |       nan       nan       nan |   10.9 


                                                                                                                                                                    

91     |   0.58192   0.74581   0.70509   0.77377   0.69858   0.41851   0.71930 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

92     |   0.58744   0.75245   0.70473   0.77383   0.69820   0.41791   0.71895 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

93     |   0.58690   0.75412   0.70439   0.77347   0.69777   0.41765   0.71874 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

94     |   0.59044   0.75554   0.70385   0.77315   0.69727   0.41649   0.71850 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

95     |   0.58469   0.74879   0.70470   0.77386   0.69807   0.41753   0.71916 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

96     |   0.58504   0.74972   0.70494   0.77393   0.69832   0.41756   0.71932 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

97     |   0.58752   0.75326   0.70425   0.77356   0.69775   0.41687   0.71920 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

98     |   0.59021   0.75450   0.70509   0.77339   0.69865   0.41700   0.71919 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

99     |   0.58961   0.75171   0.70448   0.77333   0.69797   0.41643   0.71878 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

100    |   0.58851   0.75431   0.70402   0.77281   0.69763   0.41650   0.71840 |       nan       nan       nan |   10.5 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
101    |   0.59044   0.75388   0.70371   0.77333   0.69708   0.41584   0.71829 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

102    |   0.59603   0.76059   0.70463   0.77368   0.69817   0.41657   0.71911 |       nan       nan       nan |   10.1 


                                                                                                                                                                    

103    |   0.59562   0.76339   0.70331   0.77262   0.69674   0.41421   0.71858 |       nan       nan       nan |   11.6 


                                                                                                                                                                    

104    |   0.59474   0.76198   0.70306   0.77229   0.69658   0.41454   0.71826 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

105    |   0.59527   0.76146   0.70391   0.77279   0.69747   0.41551   0.71809 |       nan       nan       nan |   10.5 


                                                                                                                                                                    

106    |   0.59425   0.76173   0.70513   0.77407   0.69868   0.41681   0.71874 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

107    |   0.60017   0.76595   0.70421   0.77322   0.69770   0.41475   0.71850 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

108    |   0.59363   0.75936   0.70471   0.77330   0.69820   0.41510   0.71877 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

109    |   0.59518   0.75998   0.70475   0.77359   0.69816   0.41532   0.71867 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

110    |   0.59289   0.75952   0.70503   0.77349   0.69856   0.41637   0.71946 |       nan       nan       nan |   10.5 


                                                                                                                                                                    

111    |   0.59972   0.76856   0.70378   0.77218   0.69722   0.41331   0.71891 |       nan       nan       nan |   12.1 


                                                                                                                                                                    

112    |   0.59477   0.76129   0.70385   0.77269   0.69718   0.41301   0.71829 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

113    |   0.60180   0.76892   0.70342   0.77200   0.69683   0.41290   0.71812 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

114    |   0.60173   0.76909   0.70268   0.77149   0.69605   0.41134   0.71771 |       nan       nan       nan |   10.1 


                                                                                                                                                                    

115    |   0.60262   0.77039   0.70317   0.77175   0.69669   0.41315   0.71781 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

116    |   0.60368   0.77036   0.70300   0.77159   0.69637   0.41150   0.71759 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

117    |   0.60038   0.76738   0.70335   0.77209   0.69678   0.41287   0.71789 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

118    |   0.60454   0.77425   0.70271   0.77144   0.69605   0.41197   0.71779 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

119    |   0.60212   0.77195   0.70301   0.77139   0.69639   0.41272   0.71782 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

120    |   0.60021   0.76896   0.70312   0.77155   0.69663   0.41284   0.71768 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
121    |   0.60708   0.77498   0.70338   0.77181   0.69688   0.41311   0.71850 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

122    |   0.60309   0.77297   0.70240   0.77135   0.69588   0.41207   0.71763 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

123    |   0.60409   0.77421   0.70246   0.77094   0.69589   0.41168   0.71686 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

124    |   0.60831   0.77637   0.70275   0.77130   0.69614   0.41217   0.71726 |       nan       nan       nan |   11.4 


                                                                                                                                                                    

125    |   0.61204   0.78044   0.70308   0.77171   0.69651   0.41226   0.71770 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

126    |   0.60643   0.77374   0.70447   0.77257   0.69804   0.41449   0.71861 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

127    |   0.60589   0.77197   0.70446   0.77211   0.69805   0.41435   0.71832 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

128    |   0.60957   0.77883   0.70356   0.77195   0.69702   0.41366   0.71815 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

129    |   0.61189   0.77996   0.70315   0.77118   0.69668   0.41292   0.71774 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

130    |   0.61044   0.77938   0.70317   0.77072   0.69677   0.41165   0.71766 |       nan       nan       nan |   10.5 


                                                                                                                                                                    

131    |   0.61269   0.78003   0.70256   0.77118   0.69596   0.41104   0.71703 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

132    |   0.60867   0.77538   0.70251   0.77116   0.69590   0.41122   0.71743 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

133    |   0.61449   0.78303   0.70255   0.77110   0.69595   0.41179   0.71712 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

134    |   0.60981   0.77942   0.70274   0.77142   0.69622   0.41221   0.71751 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

135    |   0.60776   0.77727   0.70232   0.77133   0.69594   0.41048   0.71760 |       nan       nan       nan |   10.6 


                                                                                                                                                                    

136    |   0.60939   0.77685   0.70232   0.77075   0.69584   0.41120   0.71743 |       nan       nan       nan |   10.7 


                                                                                                                                                                    

137    |   0.60938   0.77814   0.70200   0.77049   0.69550   0.41074   0.71770 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

138    |   0.61632   0.78658   0.70090   0.77059   0.69430   0.40914   0.71711 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

139    |   0.60807   0.77625   0.70142   0.77085   0.69467   0.41064   0.71702 |       nan       nan       nan |   10.1 


                                                                                                                                                                    

140    |   0.61662   0.78695   0.70098   0.76989   0.69430   0.40949   0.71719 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
141    |   0.61446   0.78330   0.70156   0.77047   0.69502   0.40976   0.71664 |       nan       nan       nan |   12.2 


                                                                                                                                                                    

142    |   0.61905   0.78767   0.70065   0.77004   0.69382   0.40960   0.71620 |       nan       nan       nan |   11.9 


                                                                                                                                                                    

143    |   0.61458   0.78298   0.70128   0.76998   0.69473   0.40971   0.71596 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

144    |   0.60987   0.77811   0.70191   0.77011   0.69530   0.41059   0.71727 |       nan       nan       nan |   10.9 


                                                                                                                                                                    

145    |   0.61232   0.77968   0.70133   0.77000   0.69470   0.40985   0.71723 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

146    |   0.61785   0.78672   0.70035   0.76934   0.69366   0.40888   0.71663 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

147    |   0.61323   0.78232   0.70048   0.76959   0.69385   0.40824   0.71620 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

148    |   0.61282   0.78102   0.70044   0.76922   0.69371   0.40850   0.71665 |       nan       nan       nan |   11.3 


                                                                                                                                                                    

149    |   0.61680   0.78572   0.70026   0.76954   0.69352   0.40743   0.71629 |       nan       nan       nan |   10.5 


                                                                                                                                                                    

150    |   0.62613   0.79834   0.69944   0.76908   0.69278   0.40722   0.71574 |       nan       nan       nan |   12.3 


                                                                                                                                                                    

151    |   0.61815   0.78819   0.69957   0.76916   0.69281   0.40597   0.71616 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

152    |   0.62667   0.79953   0.69910   0.76825   0.69233   0.40515   0.71566 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

153    |   0.62027   0.79201   0.69959   0.76900   0.69295   0.40627   0.71622 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

154    |   0.62214   0.79228   0.69983   0.76910   0.69312   0.40727   0.71605 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

155    |   0.61410   0.78502   0.69964   0.76877   0.69315   0.40723   0.71596 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

156    |   0.61797   0.78715   0.69916   0.76922   0.69256   0.40669   0.71548 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

157    |   0.61785   0.78744   0.70037   0.76936   0.69384   0.40746   0.71617 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

158    |   0.61918   0.78726   0.70053   0.76927   0.69402   0.40821   0.71557 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

159    |   0.62193   0.79098   0.70042   0.76928   0.69393   0.40844   0.71560 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

160    |   0.62269   0.79115   0.70051   0.76922   0.69393   0.40728   0.71609 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
161    |   0.62463   0.79417   0.70049   0.76871   0.69400   0.40743   0.71539 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

162    |   0.62098   0.79073   0.70039   0.76904   0.69400   0.40686   0.71641 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

163    |   0.61918   0.78988   0.70022   0.76874   0.69371   0.40603   0.71566 |       nan       nan       nan |   11.3 


                                                                                                                                                                    

164    |   0.62632   0.79720   0.70003   0.76887   0.69342   0.40681   0.71589 |       nan       nan       nan |   10.7 


                                                                                                                                                                    

165    |   0.61931   0.78604   0.70027   0.76932   0.69368   0.40693   0.71570 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

166    |   0.62712   0.79510   0.70033   0.76892   0.69375   0.40647   0.71586 |       nan       nan       nan |   12.7 


                                                                                                                                                                    

167    |   0.62703   0.79595   0.69964   0.76827   0.69305   0.40560   0.71566 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

168    |   0.62471   0.79483   0.70062   0.76890   0.69411   0.40676   0.71623 |       nan       nan       nan |   11.4 


                                                                                                                                                                    

169    |   0.62411   0.79199   0.69978   0.76895   0.69309   0.40707   0.71601 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

170    |   0.62651   0.79409   0.70026   0.76863   0.69354   0.40627   0.71670 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

171    |   0.62783   0.79591   0.70083   0.76948   0.69420   0.40800   0.71624 |       nan       nan       nan |   10.7 


                                                                                                                                                                    

172    |   0.62765   0.79612   0.70020   0.76877   0.69354   0.40717   0.71537 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

173    |   0.62126   0.78983   0.70024   0.76883   0.69360   0.40725   0.71538 |       nan       nan       nan |   11.3 


                                                                                                                                                                    

174    |   0.62506   0.79360   0.70031   0.76914   0.69366   0.40712   0.71515 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

175    |   0.62367   0.79046   0.69995   0.76858   0.69337   0.40724   0.71560 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

176    |   0.63184   0.80351   0.69911   0.76789   0.69233   0.40535   0.71487 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

177    |   0.62343   0.79300   0.69967   0.76837   0.69323   0.40657   0.71469 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

178    |   0.62606   0.79683   0.69951   0.76839   0.69300   0.40577   0.71556 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

179    |   0.62365   0.79347   0.69972   0.76889   0.69319   0.40652   0.71519 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

180    |   0.62697   0.79722   0.69960   0.76859   0.69304   0.40690   0.71555 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
181    |   0.62806   0.79771   0.69981   0.76882   0.69330   0.40766   0.71526 |       nan       nan       nan |   12.2 


                                                                                                                                                                    

182    |   0.62516   0.79307   0.69932   0.76863   0.69281   0.40639   0.71535 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

183    |   0.62863   0.79968   0.69940   0.76868   0.69282   0.40561   0.71523 |       nan       nan       nan |   11.4 


                                                                                                                                                                    

184    |   0.62328   0.79191   0.69893   0.76822   0.69219   0.40491   0.71527 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

185    |   0.62683   0.79610   0.69957   0.76841   0.69302   0.40562   0.71497 |       nan       nan       nan |   10.9 


                                                                                                                                                                    

186    |   0.62927   0.79972   0.69924   0.76782   0.69272   0.40541   0.71509 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

187    |   0.62420   0.79387   0.69950   0.76812   0.69295   0.40602   0.71531 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

188    |   0.62562   0.79651   0.69927   0.76799   0.69277   0.40550   0.71529 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

189    |   0.63145   0.80145   0.69898   0.76809   0.69233   0.40532   0.71510 |       nan       nan       nan |   11.6 


                                                                                                                                                                    

190    |   0.62988   0.79890   0.69810   0.76730   0.69141   0.40279   0.71431 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

191    |   0.62690   0.79769   0.69834   0.76744   0.69163   0.40373   0.71439 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

192    |   0.62819   0.79847   0.69725   0.76605   0.69058   0.40220   0.71359 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

193    |   0.62639   0.79660   0.69751   0.76589   0.69090   0.40229   0.71350 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

194    |   0.63743   0.80727   0.69740   0.76591   0.69069   0.40174   0.71393 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

195    |   0.62900   0.79936   0.69723   0.76641   0.69066   0.40231   0.71355 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

196    |   0.62945   0.80139   0.69834   0.76642   0.69188   0.40297   0.71387 |       nan       nan       nan |   11.2 


                                                                                                                                                                    

197    |   0.63002   0.80054   0.69818   0.76672   0.69174   0.40329   0.71432 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

198    |   0.63062   0.80089   0.69824   0.76709   0.69171   0.40359   0.71436 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

199    |   0.62909   0.80122   0.69822   0.76695   0.69171   0.40336   0.71424 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

200    |   0.62642   0.79734   0.69864   0.76730   0.69217   0.40412   0.71452 |       nan       nan       nan |   10.9 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
201    |   0.62420   0.79397   0.69913   0.76809   0.69262   0.40567   0.71459 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

202    |   0.62986   0.79916   0.69906   0.76769   0.69264   0.40521   0.71462 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

203    |   0.62985   0.79970   0.69772   0.76624   0.69126   0.40279   0.71373 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

204    |   0.63184   0.80509   0.69749   0.76640   0.69074   0.40131   0.71386 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

205    |   0.62986   0.80119   0.69753   0.76666   0.69092   0.40247   0.71402 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

206    |   0.63187   0.80158   0.69801   0.76622   0.69136   0.40254   0.71392 |       nan       nan       nan |   11.1 


                                                                                                                                                                    

207    |   0.62817   0.79824   0.69777   0.76652   0.69100   0.40188   0.71385 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

208    |   0.63244   0.80124   0.69812   0.76684   0.69152   0.40299   0.71366 |       nan       nan       nan |   11.3 


                                                                                                                                                                    

209    |   0.63510   0.80812   0.69756   0.76641   0.69091   0.40171   0.71399 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

210    |   0.63017   0.79968   0.69775   0.76675   0.69113   0.40156   0.71432 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

211    |   0.63289   0.80249   0.69772   0.76647   0.69100   0.40156   0.71357 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

212    |   0.63184   0.80269   0.69641   0.76638   0.68957   0.39979   0.71337 |       nan       nan       nan |   10.9 


                                                                                                                                                                    

213    |   0.63402   0.80806   0.69687   0.76616   0.69006   0.40051   0.71342 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

214    |   0.63267   0.80607   0.69733   0.76625   0.69057   0.39998   0.71422 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

215    |   0.63356   0.80410   0.69660   0.76575   0.68979   0.39897   0.71367 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

216    |   0.63210   0.80301   0.69653   0.76515   0.68977   0.39883   0.71374 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

217    |   0.63072   0.80078   0.69671   0.76584   0.68987   0.39909   0.71363 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

218    |   0.63689   0.81035   0.69651   0.76520   0.68966   0.39822   0.71338 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

219    |   0.63555   0.80674   0.69746   0.76630   0.69069   0.40064   0.71415 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

220    |   0.63265   0.80420   0.69700   0.76608   0.69028   0.39965   0.71389 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
221    |   0.63155   0.80244   0.69765   0.76681   0.69091   0.40154   0.71398 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

222    |   0.63306   0.80409   0.69688   0.76608   0.69007   0.39949   0.71380 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

223    |   0.63628   0.80738   0.69726   0.76628   0.69050   0.40069   0.71344 |       nan       nan       nan |   13.0 


                                                                                                                                                                    

224    |   0.63371   0.80464   0.69664   0.76597   0.68997   0.39974   0.71318 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

225    |   0.63731   0.80700   0.69718   0.76533   0.69045   0.39982   0.71354 |       nan       nan       nan |   11.0 


                                                                                                                                                                    

226    |   0.63817   0.80930   0.69735   0.76611   0.69056   0.40079   0.71401 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

227    |   0.63219   0.80371   0.69703   0.76609   0.69031   0.39990   0.71392 |       nan       nan       nan |   11.0 


                                                                                                                                                                    

228    |   0.63301   0.80414   0.69612   0.76555   0.68929   0.39827   0.71395 |       nan       nan       nan |   10.6 


                                                                                                                                                                    

229    |   0.63642   0.80645   0.69690   0.76654   0.69005   0.39931   0.71421 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

230    |   0.63740   0.80884   0.69655   0.76581   0.68981   0.39916   0.71369 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

231    |   0.63530   0.80454   0.69665   0.76550   0.68992   0.40010   0.71335 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

232    |   0.63435   0.80203   0.69627   0.76529   0.68969   0.39927   0.71394 |       nan       nan       nan |   10.4 


                                                                                                                                                                    

233    |   0.64137   0.81100   0.69615   0.76534   0.68948   0.39845   0.71434 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

234    |   0.63410   0.80188   0.69649   0.76535   0.68978   0.39861   0.71384 |       nan       nan       nan |   10.2 


                                                                                                                                                                    

235    |   0.63635   0.80601   0.69632   0.76502   0.68959   0.39948   0.71387 |       nan       nan       nan |   10.6 


                                                                                                                                                                    

236    |   0.64109   0.81142   0.69574   0.76430   0.68910   0.39865   0.71317 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

237    |   0.63687   0.80846   0.69586   0.76408   0.68922   0.39901   0.71287 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

238    |   0.63798   0.80631   0.69626   0.76516   0.68955   0.39966   0.71380 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

239    |   0.63813   0.80639   0.69673   0.76527   0.69007   0.39965   0.71400 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

240    |   0.64325   0.81113   0.69677   0.76483   0.69009   0.39997   0.71366 |       nan       nan       nan |    9.9 


                                                                                                                                                                    

Epoch  |      logl   bceloss  avg prec   auc roc    auc pr aucpr_cal    f1_max |      rmse  rsquared  corrcoef | tr_time 
241    |   0.64456   0.81346   0.69618   0.76469   0.68949   0.39844   0.71312 |       nan       nan       nan |   10.0 


                                                                                                                                                                    

242    |   0.63942   0.80895   0.69706   0.76499   0.69028   0.39936   0.71359 |       nan       nan       nan |   11.7 


                                                                                                                                                                    

243    |   0.64301   0.81473   0.69614   0.76448   0.68943   0.39816   0.71374 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

244    |   0.64186   0.81096   0.69584   0.76474   0.68896   0.39834   0.71342 |       nan       nan       nan |   10.8 


                                                                                                                                                                    

245    |   0.64045   0.81071   0.69626   0.76493   0.68950   0.39899   0.71328 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

246    |   0.64038   0.80913   0.69646   0.76474   0.68978   0.39948   0.71346 |       nan       nan       nan |   11.3 


                                                                                                                                                                    

247    |   0.64203   0.81189   0.69535   0.76415   0.68847   0.39776   0.71245 |       nan       nan       nan |   10.3 


                                                                                                                                                                    

248    |   0.64024   0.80895   0.69533   0.76446   0.68865   0.39847   0.71279 |       nan       nan       nan |   10.9 


                                                                                                                                                                    

249    |   0.64387   0.81397   0.69513   0.76459   0.68825   0.39800   0.71290 |       nan       nan       nan |   10.7 


                                                                                                                                                                    

250    |   0.63683   0.80534   0.69553   0.76505   0.68875   0.39846   0.71313 |       nan       nan       nan |   10.6 
Best Epoch :       20
Best Iteration :   0 
Best Precision :   0.71367



In [51]:
print(f"Best Epoch :       {ns.best_epoch}\n"
      f"Best Iteration :   {ns.best_iter} \n"
      f"Best Precision :   {ns.best_value:.5f}\n")
print()
for k in results_va['classification_agg'].keys():
    print(f" {k:20s}  {results_va['classification_agg'][k]:.4f}")

Best Epoch :       20
Best Iteration :   0 
Best Precision :   0.71367


 roc_auc_score         0.7650
 auc_pr                0.6888
 avg_prec_score        0.6955
 f1_max                0.7131
 p_f1_max              0.2840
 kappa                 0.3299
 kappa_max             0.4507
 p_kappa_max           0.5042
 bceloss               0.8053
 auc_pr_cal            0.3985
 logloss               0.6368
 num_tasks_total       3552.0000
 num_tasks_agg         1352.0000


In [52]:
pd.options.display.width = 150
df = results_va['classification']
print(df[pd.notna(df.roc_auc_score)])

      roc_auc_score    auc_pr  avg_prec_score    f1_max  p_f1_max     kappa  kappa_max  p_kappa_max   bceloss  auc_pr_cal
task                                                                                                                     
0          0.712302  0.881268        0.879764  0.870968  0.506769  0.241026   0.364807     0.586293  0.639831    0.332115
1          0.745588  0.693746        0.703821  0.708333  0.000592  0.134165   0.409289     0.012586  1.294072    0.297643
2          0.541667  0.027778        0.055556  0.105263  0.000681  0.000000   0.056972     0.000681  0.204779    0.095238
4          0.721374  0.963379        0.963562  0.952727  0.231102  0.150805   0.224599     0.986254  0.535628    0.500362
5          0.819940  0.901806        0.902461  0.845771  0.214211  0.481162   0.526847     0.767619  0.742919    0.492405
...             ...       ...             ...       ...       ...       ...        ...          ...       ...         ...
3546       0.762585  0.8

## Post Training 

In [53]:
#print("DEBUG data for hidden spliting")
#print (f"Classification mask: Sum = {net.classmask.sum()}\t Uniques: {np.unique(net.classmask)}")
#print (f"Regression mask:     Sum = {net.regmask.sum()}\t Uniques: {np.unique(net.regmask)}")
#print (f"overlap: {(net.regmask * net.classmask).sum()}")

writer.close()
vprint()
if args.profile == 1:
   multiplexer = sc.create_multiplexer(tb_name)
#   sc.export_scalars(multiplexer, '.', "GPUmem", "testcsv.csv")
   data = sc.extract_scalars(multiplexer, '.', "GPUmem")
   vprint(f"Peak GPU memory used: {sc.return_max_val(data)}MB")
vprint("Saving performance metrics (AUCs) and model.")

#####   model saving   #####
if not os.path.exists(args.output_dir):
   os.makedirs(args.output_dir)

model_file = f"{args.output_dir}/{args.name}.pt"
out_file   = f"{args.output_dir}/{args.name}.json"

if args.save_model:
   torch.save(net.state_dict(), model_file)
   vprint(f"Saved model weights into '{model_file}'.")

results_va["classification"]["num_pos"] = num_pos_va
results_va["classification"]["num_neg"] = num_neg_va
results_va["regression"]["num_samples"] = num_regr_va

if results_tr is not None:
    results_tr["classification"]["num_pos"] = num_pos - num_pos_va
    results_tr["classification"]["num_neg"] = num_neg - num_neg_va
    results_tr["regression"]["num_samples"] = num_regr - num_regr_va

stats=None
if args.normalize_regression == 1 :
   stats={}
   stats["mean"] = mean_save
   stats["var"]  = np.array(var_save)[0]
sc.save_results(out_file, args, validation=results_va, training=results_tr, stats=stats)

vprint(f"Saved config and results into '{out_file}'.\nYou can load the results by:\n  import sparsechem as sc\n  res = sc.load_results('{out_file}')")


Saving performance metrics (AUCs) and model.
Saved model weights into '../experiments/cb29-SparseChem/2000x1_0507_1613_lr0.001_do0.7/sc_2000.2000_lr0.001_do0.7.pt'.
Saved config and results into '../experiments/cb29-SparseChem/2000x1_0507_1613_lr0.001_do0.7/sc_2000.2000_lr0.001_do0.7.json'.
You can load the results by:
  import sparsechem as sc
  res = sc.load_results('../experiments/cb29-SparseChem/2000x1_0507_1613_lr0.001_do0.7/sc_2000.2000_lr0.001_do0.7.json')


In [54]:
print()
print(results_va['classification'][0:20])
print()
print(results_va.keys())
pp.pprint(results_va['classification_agg'])


      roc_auc_score    auc_pr  avg_prec_score    f1_max  p_f1_max     kappa  kappa_max  p_kappa_max   bceloss  auc_pr_cal  num_pos  num_neg
task                                                                                                                                       
0          0.712302  0.881268        0.879764  0.870968  0.506769  0.241026   0.364807     0.586293  0.639831    0.332115       28        9
1          0.745588  0.693746        0.703821  0.708333  0.000592  0.134165   0.409289     0.012586  1.294072    0.297643       17       20
2          0.541667  0.027778        0.055556  0.105263  0.000681  0.000000   0.056972     0.000681  0.204779    0.095238        1       36
3               NaN       NaN             NaN       NaN       NaN       NaN        NaN          NaN       NaN         NaN        0       37
4          0.721374  0.963379        0.963562  0.952727  0.231102  0.150805   0.224599     0.986254  0.535628    0.500362      131       14
5          0.819940

In [55]:
ns.wandb_run.finish()




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
auc_pr,▁▇████▇▇▇▆▆▆▆▆▅▅▅▅▅▄▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▃▂
auc_pr_cal,▃▇██▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▄▃▃▃▃▃▂▃▂▂▂▂▂▂▁▁▁▁▁▁
avg_prec_score,▁▇████▇▇▇▆▆▆▆▆▅▅▅▅▅▄▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▃▂
bceloss,▁▁▂▂▃▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████
best_accuracy,▁▄▆▆▇▇▇█████████
best_epoch,▁▁▂▂▂▃▄▄▅▅▅▆▇▇▇█
f1_max,▂▇██▇▇▆▇▆▅▅▅▅▄▄▄▄▄▃▄▄▃▃▃▃▃▃▂▂▂▁▂▂▂▁▁▁▁▁▁
kappa,▁▇████▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅
kappa_max,▂▇███▇▇▆▆▅▅▅▅▄▄▄▄▄▃▃▄▄▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁
logloss,▁▁▁▂▃▃▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇████████████████

0,1
auc_pr,0.68875
auc_pr_cal,0.39846
avg_prec_score,0.69553
bceloss,0.80534
f1_max,0.71313
kappa,0.32993
kappa_max,0.45065
logloss,0.63683
num_tasks_agg,1352.0
num_tasks_total,3552.0


## Results 

In [23]:
print()
print(results_va['classification'][0:50])


      roc_auc_score    auc_pr  avg_prec_score    f1_max  p_f1_max     kappa  \
task                                                                          
0          0.991560  0.998782        0.998752  0.982301  0.980404  0.786753   
1               NaN       NaN             NaN       NaN       NaN       NaN   
2               NaN       NaN             NaN       NaN       NaN       NaN   
3               NaN       NaN             NaN       NaN       NaN       NaN   
4               NaN       NaN             NaN       NaN       NaN       NaN   
5          0.944444  0.996995        0.997076  0.972973  0.562853  0.000000   
6          1.000000  1.000000        1.000000  1.000000  0.973055  0.797642   
7          0.666667  0.461111        0.588889  0.750000  0.763078  0.000000   
8          1.000000  1.000000        1.000000  1.000000  0.987239  0.000000   
9          1.000000  1.000000        1.000000  1.000000  0.998716  0.660870   
10              NaN       NaN             NaN      

In [45]:
print( num_neg.sum())
print( num_pos.sum())
print( num_neg_va.sum())
print( num_pos_va.sum())

2445
18704
505
3804


In [None]:
for i in zip(dldrs.valset.num_pos, dldrs.valset.num_neg):
    print(f" {i[0]:4d}  {i[1]:4d}")
for i in zip(dldrs.valset.num_pos, dldrs.valset.num_neg):
    print(f" {i[0]:4d}  {i[1]:4d}")

In [24]:
# pp.pprint(results_va)
print()

pp.pprint(results_va['classification_agg'])
print()


roc_auc_score        0.900757
auc_pr               0.943594
avg_prec_score       0.945930
f1_max               0.938011
p_f1_max             0.723872
kappa                0.611823
kappa_max            0.803714
p_kappa_max          0.826011
bceloss              0.296447
auc_pr_cal           0.827893
logloss              0.087415
num_tasks_total    100.000000
num_tasks_agg       20.000000
dtype: float64



In [40]:
print(num_pos_va)
print(num_neg_va)
print(num_regr_va)
for i in zip(num_pos_va, num_neg_va, num_pos, num_neg):
    print(f" {i[0]:4d}  {i[1]:4d}    trianing: {i[2]:4d}   {i[3]:4d}")

[115  31  17  54   7  18 203   3  55  76  63   1  42   0  75  36   0   1  43   2   7   7   0   3   0  21   1  37 330  11  86   0   0 142  27  25   4
   7 111  12   7  79   0  97 129  15  10   9   9  38 272   1   2   0 285  63 185   1   0   4  19  21  13   0  18  12   0   0  93  55   1  73  29  41
  18   0  46  37  14   0  40  18   6   8  10 330   6   4   1   0   0   0   8   0   0   0   3   0   1   0]
[17  0  0  0  0  1  3  4  2  2  0  0  0  1  2  0  0  0  0  2  0  0 13  0  0  0  0  0  6  0  0 30  0  0  0  3  3  3  2  0  0  1  1  1 18  0  2  0 85
  0  0  2 13 47  1 11  0  0  0  2  0  1  1  1  3  0  0  0  0  0 18  0  9  3  0  4  1  0 23 11  6  0  5  0  0 46  4  4 73  2  0  7  1  0  0  0  0  0
  4  0]
[]
  115    17    trianing:  884     83
   31     0    trianing:   59      0
   17     0    trianing:   59      0
   54     0    trianing:  284      0
    7     0    trianing:   55      0
   18     1    trianing:   43      3
  203     3    trianing:  749     29
    3     4    trianing:   14 

In [26]:
df[pd.notna(df.roc_auc_score)].mean()

roc_auc_score      0.874262
auc_pr             0.922464
avg_prec_score     0.932781
f1_max             0.938123
p_f1_max           0.746278
kappa              0.400306
kappa_max          0.754262
p_kappa_max        0.830485
bceloss            0.385833
auc_pr_cal         0.803581
num_pos           58.512821
num_neg            9.948718
dtype: float64

In [27]:
del net

## Misc

In [33]:
import wandb
def restart_wandb(exp_id, exp_name, project_name, resume = "allow" ):
    print(exp_id, exp_name, project_name) 
    wandb_run = wandb.init(project = project_name, 
                                     entity  = "kbardool", 
                                     id      = exp_id, 
                                     name    = exp_name,
                                     resume=resume )
    
    print(f" PROJECT NAME: {wandb_run.project}\n"
          f" RUN ID      : {wandb_run.id} \n"
          f" RUN NAME    : {wandb_run.name}")     
 
    return wandb_run 


In [34]:
run = restart_wandb("d2rw3bdq","0413_0509","SparseChem-Mini")

d2rw3bdq 0413_0509 SparseChem-Mini





VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…


!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

You should always run with libnvidia-ml.so that is installed with your
NVIDIA Display Driver. By default it's installed in /usr/lib and /usr/lib64.
libnvidia-ml.so in GDK package is a stub library that is attached only for
build purposes (e.g. machine that you build your application doesn't have
to have Display Driver installed).
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Linked to libnvidia-ml library at wrong path : /usr/lib64/libnvidia-ml.so.1



 PROJECT NAME: SparseChem-Mini
 RUN ID      : d2rw3bdq 
 RUN NAME    : 0413_0509


In [43]:
print(run)

NameError: name 'run' is not defined

In [41]:
run.finish()

In [42]:
del run

In [7]:
# cmd = (
#   f" --x       /home/kbardool/kusanagi/MLDatasets/chembl_23mini_synthetic/chembl_23mini_x.npy " +
#   f" --y_class /home/kbardool/kusanagi/MLDatasets/chembl_23mini_synthetic/chembl_23mini_adashare_y_all_bin_sparse.npy " +
#   f" --folding /home/kbardool/kusanagi/MLDatasets/chembl_23mini_synthetic/chembl_23mini_folds.npy " +
#   f" --output_dir {output_dir}" +    
#   f" --fold_va           0 " +
#   f" --batch_ratio    0.02 " +
#   f" --hidden_sizes   25 25 25 25 25 25 " +
#   f" --dropouts_trunk  0  0  0  0  0  0 " +
#   f" --weight_decay   1e-4 " 
#   f" --epochs           40 " +
#   f" --lr             1e-3 " +
#   f" --lr_steps         10 " +
#   f" --lr_alpha        0.3" 
# )

# cmd = (
#   f" --x       {data_dir}/chembl_23mini_x.npy " +
#   f" --y_class {data_dir}/chembl_23mini_adashare_y_all_bin_sparse.npy " +
#   f" --folding {data_dir}/chembl_23mini_folds.npy " +
#   f" --output_dir {output_dir}" +    
#   f" --fold_va            0 " +
#   f" --batch_ratio     0.02 " +
#   f" --hidden_sizes   40 40 " +
#   f" --dropouts_trunk  0  0 " +
#   f" --weight_decay   1e-4 " +
#   f" --epochs           20 " +
#   f" --lr             1e-3 " +
#   f" --lr_steps         10 " +
#   f" --lr_alpha        0.3 " 
# )

#   f" --hidden_sizes   400 400 " +
#   f" --last_dropout   0.2 " +
#   f" --middle_dropout 0.2 " +
#   f" --x       ./{data_dir}/chembl_23_x.mtx " +
#   f" --y_class ./{data_dir}/chembl_23_y.mtx " +
#   f" --folding ./{data_dir}/folding_hier_0.6.npy " +

#### copied from SparseChemDev 

# cmd = (
#         f" --x       ./{data_dir}/chembl_23mini_x.npy" +
#         f" --y_class ./{data_dir}/chembl_23mini_y.npy" +
#         f" --folding ./{data_dir}/chembl_23mini_folds.npy" +
#         f" --hidden_sizes 20 30 40 " +  
#         f" --output_dir {output_dir}" +
#         f" --batch_ratio 0.1" +
#         f" --epochs 2" +
#         f" --lr 1e-3" +
#         f" --lr_steps 1" +
#         f" --dev {dev}" +
#         f" --verbose 1")
#         f" --input_size_freq  40"
#         f" --tail_hidden_size  10"

In [None]:
# data_dir="chembl23_data"
# data_dir="chembl23_run_01152022"
# rstr = "synthetic_data_model" ##random_str(12)
# rstr = "synthetic_data_model_03042022" ##random_str(12)
# output_dir = f"./models-{rstr}/"
# output_dir = f"./{data_dir}/models-{rstr}/"
# output dir kbardool/kusanagi/experiments/SparseChem/0116_0843


In [4]:
# dev = "gpu" 
# data_dir="chembl23_data"
# data_dir="chembl23_run_01152022"
# data_dir = "/home/kbardool/kusanagi/MLDatasets/chembl_23mini_synthetic"

# rm_output=False

# rstr = datetime.now().strftime("%m%d_%H%M")
# rstr = "synthetic_data_model" ##random_str(12)
# rstr = "synthetic_data_model_03042022" ##random_str(12)

# output_dir = f"./models-{rstr}/"
# output_dir = f"./{data_dir}/models-{rstr}/"

# output dir kbardool/kusanagi/experiments/SparseChem/0116_0843
# output_dir = f"/home/kbardool/kusanagi/experiments/SparseChem/{rstr}"
# print(output_dir)