In [1]:
import sys
import os

import pandas as pd

import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pprint

from monai.networks.nets.densenet import DenseNet121, DenseNet169, DenseNet201, DenseNet264
from monai.networks.nets.efficientnet import EfficientNetBN
from monai.networks.nets.resnet import ResNet, resnet34, resnet50, resnet101, resnet152, resnet200

from warnings import filterwarnings
filterwarnings("ignore")

sys.path.append(os.path.join(str(os.path.abspath('')), "..", "..", ".."))

from src.train_one_epoch import train_one_epoch
from src.get_data_loaders import prepare_train_valid_dataloader
from src.validate_func import valid_func

In [2]:
class CFG:
    debug = False # change this to run on full data
    
    image_size = 256
    folds = [0, 1, 2, 3, 4]
    
    kernel_type = "resnet34"
    
    train_batch_size = 6
    valid_batch_size = 24
    
    num_images = 64
    mri_type = 'T1wCE'
    
    init_lr = 1e-4
    weight_decay=0
    
    n_epochs = 20
    num_workers = 4

    use_amp=True
    early_stop = 5

    data_dir = PATH_TO_DATA # !!! DEFINE "PATH_TO_DATA" on your local machine
    model_dir = f'weights/'
    seed=12345
    

In [3]:
results_dir = CFG.mri_type + "_weights/"

In [4]:
! mkdir $results_dir

mkdir: cannot create directory ‘T1wCE_weights/’: File exists


In [5]:
df_train = pd.read_csv('../../crossval/train_df_folds.csv')
if CFG.debug:
    df_train = df_train.sample(frac=0.1)
df_train.head()

Unnamed: 0,BraTS21ID,MGMT_value,fold
0,0,1,2
1,2,1,1
2,3,0,1
3,5,1,4
4,6,1,1


In [6]:
criterion = nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
for fold in CFG.folds:
    train_loader, valid_loader = prepare_train_valid_dataloader(
        df=df_train, fold=fold, num_images=CFG.num_images,
        img_size=CFG.image_size, data_directory=CFG.data_dir, mri_type=CFG.mri_type,
        train_batch_size=CFG.train_batch_size, valid_batch_size=CFG.valid_batch_size,
        num_workers=CFG.num_workers
    )
    
#     model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)
#     model = EfficientNetBN(spatial_dims=3, in_channels=1, num_classes=1, model_name="efficientnet-b0").to(device)
    model = resnet34(spatial_dims=3, n_input_channels=1, num_classes=1).to(device)

    optimizer = optim.Adam(model.parameters(), lr=CFG.init_lr, weight_decay=CFG.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

#     scheduler = ReduceLROnPlateau(
#         optimizer, mode='min', patience=1, min_lr=1e-6, factor=0.1, verbose=True, eps=1e-8
#     )

    num_epochs = CFG.n_epochs

    print("-----------------------------------------------------------------------------------------------------")
    print("                                        FOLD: ", fold)
    print("-----------------------------------------------------------------------------------------------------")
    
    roc_auc_max = 0.0
    loss_min = 99999
    ap_max = 0.0
    not_improving = 0
    metrics_list = list()
    
    for epoch in range(CFG.n_epochs):
        
        loss_train, roc_auc_train = train_one_epoch(
            model, device, criterion, optimizer, train_loader, CFG.use_amp)
        
        loss_valid, roc_auc_valid = valid_func(
            model, device, criterion, valid_loader)
        
        scheduler.step()
        
#         scheduler.step(loss_valid)
        
        metrics_dictionary = {}
        metrics_dictionary['epoch'] = epoch
        metrics_dictionary['loss_train'] = loss_train
        metrics_dictionary['loss_valid'] = loss_valid
        metrics_dictionary['roc_auc_train'] = roc_auc_train
        metrics_dictionary['roc_auc_valid'] = roc_auc_valid
        metrics_dictionary['fold'] = fold
        pprint.pprint(metrics_dictionary)
        metrics_list.append(metrics_dictionary)
        
        not_improving += 1
        if roc_auc_valid > roc_auc_max:
            print(f'roc_auc_max ({roc_auc_max:.6f} --> {roc_auc_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_AUC_{CFG.mri_type}_mri_type.pth')
            roc_auc_max = roc_auc_valid
            not_improving = 0

        if loss_valid < loss_min:
            loss_min = loss_valid
            print(f'loss_min ({loss_min:.6f} --> {loss_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_loss_{CFG.mri_type}_mri_type.pth')

            
        if not_improving == CFG.early_stop:
            print('Early Stopping...')
            break

    
    metrics = pd.DataFrame(metrics_list)
    metrics.to_csv(f'{results_dir}{CFG.kernel_type}_fold{fold}_final.csv', index=False)
    torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_final_{CFG.mri_type}_mri_type.pth')


-----------------------------------------------------------------------------------------------------
                                        FOLD:  0
-----------------------------------------------------------------------------------------------------


loss: 0.74724, total_loss: 0.78179: 100%|████████████| 78/78 [05:04<00:00,  3.91s/it]
loss: 0.73284, total_loss: 0.76877: 100%|██████████████| 5/5 [01:27<00:00, 17.56s/it]


{'epoch': 0,
 'fold': 0,
 'loss_train': 0.7817926479455752,
 'loss_valid': 0.7687693238258362,
 'roc_auc_train': 0.5445812625176174,
 'roc_auc_valid': 0.5363636363636365}
roc_auc_max (0.000000 --> 0.536364). Saving model ...
loss_min (0.768769 --> 0.768769). Saving model ...


loss: 0.69417, total_loss: 0.69618: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.75588, total_loss: 0.72503: 100%|██████████████| 5/5 [00:19<00:00,  3.92s/it]


{'epoch': 1,
 'fold': 0,
 'loss_train': 0.6961806351557757,
 'loss_valid': 0.7250278711318969,
 'roc_auc_train': 0.5417624805281507,
 'roc_auc_valid': 0.5718475073313783}
roc_auc_max (0.536364 --> 0.571848). Saving model ...
loss_min (0.725028 --> 0.725028). Saving model ...


loss: 0.61893, total_loss: 0.70019: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.72505, total_loss: 0.72580: 100%|██████████████| 5/5 [00:19<00:00,  3.90s/it]


{'epoch': 2,
 'fold': 0,
 'loss_train': 0.7001904249191284,
 'loss_valid': 0.7258048176765441,
 'roc_auc_train': 0.5194162154142867,
 'roc_auc_valid': 0.5739002932551319}
roc_auc_max (0.571848 --> 0.573900). Saving model ...


loss: 0.59287, total_loss: 0.69690: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.69148, total_loss: 0.69143: 100%|██████████████| 5/5 [00:19<00:00,  3.88s/it]


{'epoch': 3,
 'fold': 0,
 'loss_train': 0.6968953357293055,
 'loss_valid': 0.691432785987854,
 'roc_auc_train': 0.5594725910540761,
 'roc_auc_valid': 0.5730205278592375}
loss_min (0.691433 --> 0.691433). Saving model ...


loss: 0.78877, total_loss: 0.69323: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.68789, total_loss: 0.69010: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 4,
 'fold': 0,
 'loss_train': 0.6932311638807639,
 'loss_valid': 0.6900953888893128,
 'roc_auc_train': 0.5513222312884801,
 'roc_auc_valid': 0.5607038123167154}
loss_min (0.690095 --> 0.690095). Saving model ...


loss: 0.90612, total_loss: 0.69107: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.73047, total_loss: 0.70037: 100%|██████████████| 5/5 [00:19<00:00,  3.91s/it]


{'epoch': 5,
 'fold': 0,
 'loss_train': 0.6910734405884376,
 'loss_valid': 0.7003687739372253,
 'roc_auc_train': 0.5656572212743862,
 'roc_auc_valid': 0.5829912023460411}
roc_auc_max (0.573900 --> 0.582991). Saving model ...


loss: 0.79466, total_loss: 0.69388: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67996, total_loss: 0.69046: 100%|██████████████| 5/5 [00:19<00:00,  3.91s/it]


{'epoch': 6,
 'fold': 0,
 'loss_train': 0.6938804747202457,
 'loss_valid': 0.6904624462127685,
 'roc_auc_train': 0.5712020621615608,
 'roc_auc_valid': 0.5489736070381231}


loss: 0.70150, total_loss: 0.68691: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.69267, total_loss: 0.68808: 100%|██████████████| 5/5 [00:19<00:00,  3.90s/it]


{'epoch': 7,
 'fold': 0,
 'loss_train': 0.6869126244997367,
 'loss_valid': 0.6880834817886352,
 'roc_auc_train': 0.5592593279430309,
 'roc_auc_valid': 0.5674486803519061}
loss_min (0.688083 --> 0.688083). Saving model ...


loss: 0.67495, total_loss: 0.67614: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.88254, total_loss: 0.89074: 100%|██████████████| 5/5 [00:19<00:00,  3.95s/it]


{'epoch': 8,
 'fold': 0,
 'loss_train': 0.6761445670555799,
 'loss_valid': 0.8907400369644165,
 'roc_auc_train': 0.6129552703805354,
 'roc_auc_valid': 0.42346041055718475}


loss: 0.80442, total_loss: 0.67815: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67891, total_loss: 0.68693: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 9,
 'fold': 0,
 'loss_train': 0.6781494105473543,
 'loss_valid': 0.686934518814087,
 'roc_auc_train': 0.5985924634671018,
 'roc_auc_valid': 0.5970674486803519}
roc_auc_max (0.582991 --> 0.597067). Saving model ...
loss_min (0.686935 --> 0.686935). Saving model ...


loss: 0.64638, total_loss: 0.66482: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.69326, total_loss: 0.70144: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 10,
 'fold': 0,
 'loss_train': 0.6648166863582073,
 'loss_valid': 0.701440966129303,
 'roc_auc_train': 0.6440731399747792,
 'roc_auc_valid': 0.5568914956011731}


loss: 0.53990, total_loss: 0.61863: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.78711, total_loss: 0.77419: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 11,
 'fold': 0,
 'loss_train': 0.6186256825159757,
 'loss_valid': 0.774185061454773,
 'roc_auc_train': 0.7220254432163784,
 'roc_auc_valid': 0.45923753665689154}


loss: 0.65160, total_loss: 0.59605: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.78858, total_loss: 0.77177: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 12,
 'fold': 0,
 'loss_train': 0.59604590290632,
 'loss_valid': 0.7717692255973816,
 'roc_auc_train': 0.7458830947259105,
 'roc_auc_valid': 0.4935483870967742}


loss: 0.74848, total_loss: 0.52695: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.87904, total_loss: 0.82363: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 13,
 'fold': 0,
 'loss_train': 0.5269452663950431,
 'loss_valid': 0.8236349940299987,
 'roc_auc_train': 0.8195145018915511,
 'roc_auc_valid': 0.5395894428152492}


loss: 0.70560, total_loss: 0.38385: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 1.28480, total_loss: 0.99691: 100%|██████████████| 5/5 [00:19<00:00,  3.97s/it]


{'epoch': 14,
 'fold': 0,
 'loss_train': 0.38385424772516274,
 'loss_valid': 0.9969134330749512,
 'roc_auc_train': 0.9153530895334174,
 'roc_auc_valid': 0.5551319648093842}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  1
-----------------------------------------------------------------------------------------------------


loss: 0.70069, total_loss: 0.76471: 100%|████████████| 78/78 [01:57<00:00,  1.51s/it]
loss: 0.65999, total_loss: 0.76980: 100%|██████████████| 5/5 [00:19<00:00,  3.91s/it]


{'epoch': 0,
 'fold': 1,
 'loss_train': 0.764708487651287,
 'loss_valid': 0.7697958946228027,
 'roc_auc_train': 0.47770871985157703,
 'roc_auc_valid': 0.5550351288056207}
roc_auc_max (0.000000 --> 0.555035). Saving model ...
loss_min (0.769796 --> 0.769796). Saving model ...


loss: 0.69522, total_loss: 0.71094: 100%|████████████| 78/78 [01:56<00:00,  1.49s/it]
loss: 0.75428, total_loss: 0.70809: 100%|██████████████| 5/5 [00:19<00:00,  3.90s/it]


{'epoch': 1,
 'fold': 1,
 'loss_train': 0.7109371702640485,
 'loss_valid': 0.7080919146537781,
 'roc_auc_train': 0.5304174397031539,
 'roc_auc_valid': 0.5035128805620608}
loss_min (0.708092 --> 0.708092). Saving model ...


loss: 0.62135, total_loss: 0.70014: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.68951, total_loss: 0.70265: 100%|██████████████| 5/5 [00:19<00:00,  3.89s/it]


{'epoch': 2,
 'fold': 1,
 'loss_train': 0.7001435886590909,
 'loss_valid': 0.7026509284973145,
 'roc_auc_train': 0.5211502782931354,
 'roc_auc_valid': 0.5357142857142858}
loss_min (0.702651 --> 0.702651). Saving model ...


loss: 0.69771, total_loss: 0.69271: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.68696, total_loss: 0.74724: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 3,
 'fold': 1,
 'loss_train': 0.6927131869089909,
 'loss_valid': 0.747244942188263,
 'roc_auc_train': 0.5612894248608534,
 'roc_auc_valid': 0.5193208430913349}


loss: 0.65892, total_loss: 0.69263: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.70535, total_loss: 0.69620: 100%|██████████████| 5/5 [00:19<00:00,  3.90s/it]


{'epoch': 4,
 'fold': 1,
 'loss_train': 0.6926332391225375,
 'loss_valid': 0.696199357509613,
 'roc_auc_train': 0.5646753246753247,
 'roc_auc_valid': 0.5175644028103045}
loss_min (0.696199 --> 0.696199). Saving model ...


loss: 0.64392, total_loss: 0.68741: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.75341, total_loss: 0.69831: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 5,
 'fold': 1,
 'loss_train': 0.6874067714581122,
 'loss_valid': 0.6983139753341675,
 'roc_auc_train': 0.5645732838589982,
 'roc_auc_valid': 0.5342505854800936}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  2
-----------------------------------------------------------------------------------------------------


loss: 0.80810, total_loss: 0.77443: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.68567, total_loss: 0.68601: 100%|██████████████| 5/5 [00:18<00:00,  3.60s/it]


{'epoch': 0,
 'fold': 2,
 'loss_train': 0.7744262079015757,
 'loss_valid': 0.6860054850578308,
 'roc_auc_train': 0.4907563025210084,
 'roc_auc_valid': 0.5767511177347243}
roc_auc_max (0.000000 --> 0.576751). Saving model ...
loss_min (0.686005 --> 0.686005). Saving model ...


loss: 0.68255, total_loss: 0.70806: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.71366, total_loss: 0.69805: 100%|██████████████| 5/5 [00:18<00:00,  3.65s/it]


{'epoch': 1,
 'fold': 2,
 'loss_train': 0.7080558026448275,
 'loss_valid': 0.6980455160140991,
 'roc_auc_train': 0.5220426632191338,
 'roc_auc_valid': 0.5442622950819672}


loss: 0.62302, total_loss: 0.69874: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.70611, total_loss: 0.69303: 100%|██████████████| 5/5 [00:18<00:00,  3.62s/it]


{'epoch': 2,
 'fold': 2,
 'loss_train': 0.6987365675278199,
 'loss_valid': 0.6930344939231873,
 'roc_auc_train': 0.5241111829347124,
 'roc_auc_valid': 0.5526080476900149}


loss: 0.69057, total_loss: 0.69240: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.74201, total_loss: 0.71339: 100%|██████████████| 5/5 [00:18<00:00,  3.63s/it]


{'epoch': 3,
 'fold': 2,
 'loss_train': 0.6923971137939355,
 'loss_valid': 0.7133877992630004,
 'roc_auc_train': 0.5451749930741527,
 'roc_auc_valid': 0.4965722801788376}


loss: 0.79075, total_loss: 0.67902: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.68973, total_loss: 0.69244: 100%|██████████████| 5/5 [00:18<00:00,  3.61s/it]


{'epoch': 4,
 'fold': 2,
 'loss_train': 0.6790242668909904,
 'loss_valid': 0.692444920539856,
 'roc_auc_train': 0.6061593868316557,
 'roc_auc_valid': 0.5514157973174367}


loss: 0.57440, total_loss: 0.68453: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.68970, total_loss: 0.69551: 100%|██████████████| 5/5 [00:18<00:00,  3.66s/it]


{'epoch': 5,
 'fold': 2,
 'loss_train': 0.6845331619947385,
 'loss_valid': 0.6955051422119141,
 'roc_auc_train': 0.586000554067781,
 'roc_auc_valid': 0.5225037257824143}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  3
-----------------------------------------------------------------------------------------------------


loss: 0.69598, total_loss: 0.73929: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.67567, total_loss: 0.74562: 100%|██████████████| 5/5 [00:17<00:00,  3.49s/it]


{'epoch': 0,
 'fold': 3,
 'loss_train': 0.7392890881269406,
 'loss_valid': 0.7456230759620667,
 'roc_auc_train': 0.5137870532828516,
 'roc_auc_valid': 0.5067064083457526}
roc_auc_max (0.000000 --> 0.506706). Saving model ...
loss_min (0.745623 --> 0.745623). Saving model ...


loss: 0.66717, total_loss: 0.73051: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.79446, total_loss: 0.71653: 100%|██████████████| 5/5 [00:17<00:00,  3.58s/it]


{'epoch': 1,
 'fold': 3,
 'loss_train': 0.7305116882691016,
 'loss_valid': 0.7165308475494385,
 'roc_auc_train': 0.5046726382860837,
 'roc_auc_valid': 0.6002980625931446}
roc_auc_max (0.506706 --> 0.600298). Saving model ...
loss_min (0.716531 --> 0.716531). Saving model ...


loss: 0.44676, total_loss: 0.69524: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.94117, total_loss: 0.94276: 100%|██████████████| 5/5 [00:18<00:00,  3.61s/it]


{'epoch': 2,
 'fold': 3,
 'loss_train': 0.6952442771349198,
 'loss_valid': 0.9427633285522461,
 'roc_auc_train': 0.565149136577708,
 'roc_auc_valid': 0.4008941877794337}


loss: 0.73971, total_loss: 0.69069: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.79948, total_loss: 0.74006: 100%|██████████████| 5/5 [00:17<00:00,  3.54s/it]


{'epoch': 3,
 'fold': 3,
 'loss_train': 0.6906865980380621,
 'loss_valid': 0.7400572299957275,
 'roc_auc_train': 0.5683257918552036,
 'roc_auc_valid': 0.6035767511177348}
roc_auc_max (0.600298 --> 0.603577). Saving model ...


loss: 0.67687, total_loss: 0.68783: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.69328, total_loss: 0.69166: 100%|██████████████| 5/5 [00:17<00:00,  3.52s/it]


{'epoch': 4,
 'fold': 3,
 'loss_train': 0.6878311848029112,
 'loss_valid': 0.6916606187820434,
 'roc_auc_train': 0.5741989103333641,
 'roc_auc_valid': 0.5195230998509688}
loss_min (0.691661 --> 0.691661). Saving model ...


loss: 0.54202, total_loss: 0.68790: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.71102, total_loss: 0.68794: 100%|██████████████| 5/5 [00:17<00:00,  3.59s/it]


{'epoch': 5,
 'fold': 3,
 'loss_train': 0.6878986794214982,
 'loss_valid': 0.6879362463951111,
 'roc_auc_train': 0.556071659433004,
 'roc_auc_valid': 0.5666169895678094}
loss_min (0.687936 --> 0.687936). Saving model ...


loss: 0.81665, total_loss: 0.68649: 100%|████████████| 78/78 [01:52<00:00,  1.45s/it]
loss: 0.74241, total_loss: 0.75150: 100%|██████████████| 5/5 [00:17<00:00,  3.47s/it]


{'epoch': 6,
 'fold': 3,
 'loss_train': 0.6864928075900445,
 'loss_valid': 0.7515038132667542,
 'roc_auc_train': 0.5798411672361252,
 'roc_auc_valid': 0.5374068554396423}


loss: 0.65795, total_loss: 0.68600: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.73116, total_loss: 0.69885: 100%|██████████████| 5/5 [00:17<00:00,  3.59s/it]


{'epoch': 7,
 'fold': 3,
 'loss_train': 0.6860006497456477,
 'loss_valid': 0.6988525152206421,
 'roc_auc_train': 0.5733031674208144,
 'roc_auc_valid': 0.5248882265275708}


loss: 0.62865, total_loss: 0.66761: 100%|████████████| 78/78 [01:52<00:00,  1.44s/it]
loss: 0.75206, total_loss: 0.71173: 100%|██████████████| 5/5 [00:17<00:00,  3.47s/it]


{'epoch': 8,
 'fold': 3,
 'loss_train': 0.6676129370163648,
 'loss_valid': 0.7117303729057312,
 'roc_auc_train': 0.625754917351556,
 'roc_auc_valid': 0.5421758569299553}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  4
-----------------------------------------------------------------------------------------------------


loss: 0.79259, total_loss: 0.76137: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.75380, total_loss: 0.71772: 100%|██████████████| 5/5 [00:18<00:00,  3.60s/it]


{'epoch': 0,
 'fold': 4,
 'loss_train': 0.7613689933831875,
 'loss_valid': 0.7177247405052185,
 'roc_auc_train': 0.48168805983932034,
 'roc_auc_valid': 0.5365126676602086}
roc_auc_max (0.000000 --> 0.536513). Saving model ...
loss_min (0.717725 --> 0.717725). Saving model ...


loss: 0.56364, total_loss: 0.70057: 100%|████████████| 78/78 [01:52<00:00,  1.45s/it]
loss: 0.70012, total_loss: 0.72973: 100%|██████████████| 5/5 [00:18<00:00,  3.62s/it]


{'epoch': 1,
 'fold': 4,
 'loss_train': 0.7005666410311674,
 'loss_valid': 0.7297281622886658,
 'roc_auc_train': 0.5299843014128729,
 'roc_auc_valid': 0.5514157973174367}
roc_auc_max (0.536513 --> 0.551416). Saving model ...


loss: 0.62334, total_loss: 0.68563: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.72735, total_loss: 0.70341: 100%|██████████████| 5/5 [00:17<00:00,  3.57s/it]


{'epoch': 2,
 'fold': 4,
 'loss_train': 0.6856308273780041,
 'loss_valid': 0.7034149527549743,
 'roc_auc_train': 0.5871640964077939,
 'roc_auc_valid': 0.5713859910581222}
roc_auc_max (0.551416 --> 0.571386). Saving model ...
loss_min (0.703415 --> 0.703415). Saving model ...


loss: 0.77142, total_loss: 0.69129: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.68951, total_loss: 0.69822: 100%|██████████████| 5/5 [00:17<00:00,  3.56s/it]


{'epoch': 3,
 'fold': 4,
 'loss_train': 0.691287657389274,
 'loss_valid': 0.6982153058052063,
 'roc_auc_train': 0.5567642441592022,
 'roc_auc_valid': 0.568107302533532}
loss_min (0.698215 --> 0.698215). Saving model ...


loss: 0.72438, total_loss: 0.67661: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.73378, total_loss: 0.72269: 100%|██████████████| 5/5 [00:18<00:00,  3.65s/it]


{'epoch': 4,
 'fold': 4,
 'loss_train': 0.6766118888671582,
 'loss_valid': 0.7226863980293274,
 'roc_auc_train': 0.6003047372795272,
 'roc_auc_valid': 0.5239940387481371}


loss: 0.83966, total_loss: 0.68123: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.70950, total_loss: 0.70248: 100%|██████████████| 5/5 [00:18<00:00,  3.60s/it]


{'epoch': 5,
 'fold': 4,
 'loss_train': 0.6812341266717666,
 'loss_valid': 0.7024819135665894,
 'roc_auc_train': 0.5954012374180442,
 'roc_auc_valid': 0.5368107302533532}


loss: 0.66036, total_loss: 0.66007: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.88992, total_loss: 0.76449: 100%|██████████████| 5/5 [00:18<00:00,  3.66s/it]


{'epoch': 6,
 'fold': 4,
 'loss_train': 0.6600687851508459,
 'loss_valid': 0.7644879698753357,
 'roc_auc_train': 0.6438636993258842,
 'roc_auc_valid': 0.48971684053651265}


loss: 0.42422, total_loss: 0.67475: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.69925, total_loss: 0.74069: 100%|██████████████| 5/5 [00:17<00:00,  3.57s/it]


{'epoch': 7,
 'fold': 4,
 'loss_train': 0.674750471344361,
 'loss_valid': 0.7406906127929688,
 'roc_auc_train': 0.6201495983008588,
 'roc_auc_valid': 0.5388971684053652}
Early Stopping...
