In [1]:
import sys
import os

import pandas as pd

import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pprint

from monai.networks.nets.densenet import DenseNet121, DenseNet169, DenseNet201, DenseNet264
from monai.networks.nets.efficientnet import EfficientNetBN
from monai.networks.nets.resnet import ResNet, resnet34, resnet50, resnet101, resnet152, resnet200

from warnings import filterwarnings
filterwarnings("ignore")

sys.path.append(os.path.join(str(os.path.abspath('')), "..", "..", ".."))

from src.train_one_epoch import train_one_epoch
from src.get_data_loaders import prepare_train_valid_dataloader
from src.validate_func import valid_func

In [2]:
class CFG:
    debug = False # change this to run on full data
    
    image_size = 256
    folds = [0, 1, 2, 3, 4]
    
    kernel_type = "resnet34"
    
    train_batch_size = 6
    valid_batch_size = 24
    
    num_images = 64
    mri_type = 'FLAIR'
    
    init_lr = 1e-4
    weight_decay=0
    
    n_epochs = 20
    num_workers = 4

    use_amp=True
    early_stop = 5

    data_dir = PATH_TO_DATA # !!! DEFINE "PATH_TO_DATA" on your local machine
    model_dir = f'weights/'
    seed=12345
    

In [3]:
results_dir = CFG.mri_type + "_weights/"

In [4]:
! mkdir $results_dir

mkdir: cannot create directory ‘FLAIR_weights/’: File exists


In [5]:
df_train = pd.read_csv('../../crossval/train_df_folds.csv')
if CFG.debug:
    df_train = df_train.sample(frac=0.1)
df_train.head()

Unnamed: 0,BraTS21ID,MGMT_value,fold
0,0,1,2
1,2,1,1
2,3,0,1
3,5,1,4
4,6,1,1


In [6]:
criterion = nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
for fold in CFG.folds:
    train_loader, valid_loader = prepare_train_valid_dataloader(
        df=df_train, fold=fold, num_images=CFG.num_images,
        img_size=CFG.image_size, data_directory=CFG.data_dir, mri_type=CFG.mri_type,
        train_batch_size=CFG.train_batch_size, valid_batch_size=CFG.valid_batch_size,
        num_workers=CFG.num_workers
    )
    
#     model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)
#     model = EfficientNetBN(spatial_dims=3, in_channels=1, num_classes=1, model_name="efficientnet-b0").to(device)
    model = resnet34(spatial_dims=3, n_input_channels=1, num_classes=1).to(device)

    optimizer = optim.Adam(model.parameters(), lr=CFG.init_lr, weight_decay=CFG.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

#     scheduler = ReduceLROnPlateau(
#         optimizer, mode='min', patience=1, min_lr=1e-6, factor=0.1, verbose=True, eps=1e-8
#     )

    num_epochs = CFG.n_epochs

    print("-----------------------------------------------------------------------------------------------------")
    print("                                        FOLD: ", fold)
    print("-----------------------------------------------------------------------------------------------------")
    
    roc_auc_max = 0.0
    loss_min = 99999
    ap_max = 0.0
    not_improving = 0
    metrics_list = list()
    
    for epoch in range(CFG.n_epochs):
        
        loss_train, roc_auc_train = train_one_epoch(
            model, device, criterion, optimizer, train_loader, CFG.use_amp)
        
        loss_valid, roc_auc_valid = valid_func(
            model, device, criterion, valid_loader)
        
        scheduler.step()
        
#         scheduler.step(loss_valid)
        
        metrics_dictionary = {}
        metrics_dictionary['epoch'] = epoch
        metrics_dictionary['loss_train'] = loss_train
        metrics_dictionary['loss_valid'] = loss_valid
        metrics_dictionary['roc_auc_train'] = roc_auc_train
        metrics_dictionary['roc_auc_valid'] = roc_auc_valid
        metrics_dictionary['fold'] = fold
        pprint.pprint(metrics_dictionary)
        metrics_list.append(metrics_dictionary)
        
        not_improving += 1
        if roc_auc_valid > roc_auc_max:
            print(f'roc_auc_max ({roc_auc_max:.6f} --> {roc_auc_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_AUC_{CFG.mri_type}_mri_type.pth')
            roc_auc_max = roc_auc_valid
            not_improving = 0

        if loss_valid < loss_min:
            loss_min = loss_valid
            print(f'loss_min ({loss_min:.6f} --> {loss_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_loss_{CFG.mri_type}_mri_type.pth')

            
        if not_improving == CFG.early_stop:
            print('Early Stopping...')
            break

    
    metrics = pd.DataFrame(metrics_list)
    metrics.to_csv(f'{results_dir}{CFG.kernel_type}_fold{fold}_final.csv', index=False)
    torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_final_{CFG.mri_type}_mri_type.pth')


-----------------------------------------------------------------------------------------------------
                                        FOLD:  0
-----------------------------------------------------------------------------------------------------


loss: 0.57748, total_loss: 0.76216: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 1.01524, total_loss: 0.94008: 100%|██████████████| 5/5 [00:19<00:00,  3.99s/it]


{'epoch': 0,
 'fold': 0,
 'loss_train': 0.7621604868998895,
 'loss_valid': 0.9400766134262085,
 'roc_auc_train': 0.5190731399747793,
 'roc_auc_valid': 0.4322580645161291}
roc_auc_max (0.000000 --> 0.432258). Saving model ...
loss_min (0.940077 --> 0.940077). Saving model ...


loss: 0.55375, total_loss: 0.70003: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.71357, total_loss: 0.72886: 100%|██████████████| 5/5 [00:20<00:00,  4.04s/it]


{'epoch': 1,
 'fold': 0,
 'loss_train': 0.7000334301056006,
 'loss_valid': 0.7288556814193725,
 'roc_auc_train': 0.548688895482531,
 'roc_auc_valid': 0.5304985337243402}
roc_auc_max (0.432258 --> 0.530499). Saving model ...
loss_min (0.728856 --> 0.728856). Saving model ...


loss: 0.81040, total_loss: 0.68607: 100%|████████████| 78/78 [02:01<00:00,  1.56s/it]
loss: 0.74902, total_loss: 0.73659: 100%|██████████████| 5/5 [00:20<00:00,  4.08s/it]


{'epoch': 2,
 'fold': 0,
 'loss_train': 0.6860737067002517,
 'loss_valid': 0.7365878462791443,
 'roc_auc_train': 0.5873915139826422,
 'roc_auc_valid': 0.5873900293255132}
roc_auc_max (0.530499 --> 0.587390). Saving model ...


loss: 0.50475, total_loss: 0.69692: 100%|████████████| 78/78 [02:02<00:00,  1.57s/it]
loss: 0.70214, total_loss: 0.69192: 100%|██████████████| 5/5 [00:20<00:00,  4.11s/it]


{'epoch': 3,
 'fold': 0,
 'loss_train': 0.6969179373521072,
 'loss_valid': 0.6919240355491638,
 'roc_auc_train': 0.571934574586455,
 'roc_auc_valid': 0.5791788856304986}
loss_min (0.691924 --> 0.691924). Saving model ...


loss: 0.57273, total_loss: 0.68460: 100%|████████████| 78/78 [02:04<00:00,  1.60s/it]
loss: 0.69818, total_loss: 0.68892: 100%|██████████████| 5/5 [00:20<00:00,  4.07s/it]


{'epoch': 4,
 'fold': 0,
 'loss_train': 0.684599077854401,
 'loss_valid': 0.6889237880706787,
 'roc_auc_train': 0.5847303612491654,
 'roc_auc_valid': 0.6058651026392962}
roc_auc_max (0.587390 --> 0.605865). Saving model ...
loss_min (0.688924 --> 0.688924). Saving model ...


loss: 0.69981, total_loss: 0.68458: 100%|████████████| 78/78 [02:06<00:00,  1.62s/it]
loss: 0.70140, total_loss: 0.67348: 100%|██████████████| 5/5 [00:20<00:00,  4.15s/it]


{'epoch': 5,
 'fold': 0,
 'loss_train': 0.6845782299836477,
 'loss_valid': 0.6734752297401428,
 'roc_auc_train': 0.5796862250574883,
 'roc_auc_valid': 0.6093841642228739}
roc_auc_max (0.605865 --> 0.609384). Saving model ...
loss_min (0.673475 --> 0.673475). Saving model ...


loss: 0.61778, total_loss: 0.68591: 100%|████████████| 78/78 [02:01<00:00,  1.56s/it]
loss: 0.70937, total_loss: 0.69265: 100%|██████████████| 5/5 [00:20<00:00,  4.09s/it]


{'epoch': 6,
 'fold': 0,
 'loss_train': 0.685905838624025,
 'loss_valid': 0.6926473617553711,
 'roc_auc_train': 0.5817724946220606,
 'roc_auc_valid': 0.552492668621701}


loss: 0.62391, total_loss: 0.67628: 100%|████████████| 78/78 [01:59<00:00,  1.53s/it]
loss: 0.69485, total_loss: 0.68002: 100%|██████████████| 5/5 [00:20<00:00,  4.02s/it]


{'epoch': 7,
 'fold': 0,
 'loss_train': 0.676284032754409,
 'loss_valid': 0.6800158381462097,
 'roc_auc_train': 0.5977208664045693,
 'roc_auc_valid': 0.5932551319648094}


loss: 0.61289, total_loss: 0.67655: 100%|████████████| 78/78 [02:03<00:00,  1.58s/it]
loss: 0.68469, total_loss: 0.67192: 100%|██████████████| 5/5 [00:20<00:00,  4.11s/it]


{'epoch': 8,
 'fold': 0,
 'loss_train': 0.6765464995151911,
 'loss_valid': 0.6719173073768616,
 'roc_auc_train': 0.5985275573028708,
 'roc_auc_valid': 0.6536656891495602}
roc_auc_max (0.609384 --> 0.653666). Saving model ...
loss_min (0.671917 --> 0.671917). Saving model ...


loss: 0.72126, total_loss: 0.66901: 100%|████████████| 78/78 [02:00<00:00,  1.54s/it]
loss: 0.72862, total_loss: 0.70013: 100%|██████████████| 5/5 [00:20<00:00,  4.13s/it]


{'epoch': 9,
 'fold': 0,
 'loss_train': 0.6690108967133057,
 'loss_valid': 0.7001328706741333,
 'roc_auc_train': 0.6257232401157183,
 'roc_auc_valid': 0.5909090909090909}


loss: 0.63813, total_loss: 0.67465: 100%|████████████| 78/78 [02:02<00:00,  1.57s/it]
loss: 0.70588, total_loss: 0.69039: 100%|██████████████| 5/5 [00:20<00:00,  4.05s/it]


{'epoch': 10,
 'fold': 0,
 'loss_train': 0.6746469483925746,
 'loss_valid': 0.6903860330581665,
 'roc_auc_train': 0.6121393071730584,
 'roc_auc_valid': 0.5941348973607039}


loss: 0.49017, total_loss: 0.65750: 100%|████████████| 78/78 [01:58<00:00,  1.51s/it]
loss: 0.71994, total_loss: 0.69984: 100%|██████████████| 5/5 [00:20<00:00,  4.06s/it]


{'epoch': 11,
 'fold': 0,
 'loss_train': 0.6574991169648293,
 'loss_valid': 0.6998438835144043,
 'roc_auc_train': 0.6511757288035012,
 'roc_auc_valid': 0.549266862170088}


loss: 0.59638, total_loss: 0.65657: 100%|████████████| 78/78 [01:58<00:00,  1.52s/it]
loss: 0.69879, total_loss: 0.68816: 100%|██████████████| 5/5 [00:20<00:00,  4.03s/it]


{'epoch': 12,
 'fold': 0,
 'loss_train': 0.6565728390063995,
 'loss_valid': 0.6881574630737305,
 'roc_auc_train': 0.6552091832950079,
 'roc_auc_valid': 0.5818181818181818}


loss: 0.45386, total_loss: 0.63303: 100%|████████████| 78/78 [01:58<00:00,  1.51s/it]
loss: 0.75090, total_loss: 0.71800: 100%|██████████████| 5/5 [00:20<00:00,  4.07s/it]


{'epoch': 13,
 'fold': 0,
 'loss_train': 0.6330341459848942,
 'loss_valid': 0.7180023193359375,
 'roc_auc_train': 0.6973703731177212,
 'roc_auc_valid': 0.6008797653958944}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  1
-----------------------------------------------------------------------------------------------------


loss: 0.75587, total_loss: 0.76692: 100%|████████████| 78/78 [02:01<00:00,  1.56s/it]
loss: 0.64539, total_loss: 0.75446: 100%|██████████████| 5/5 [00:20<00:00,  4.05s/it]


{'epoch': 0,
 'fold': 1,
 'loss_train': 0.7669208607612512,
 'loss_valid': 0.7544562339782714,
 'roc_auc_train': 0.5000742115027829,
 'roc_auc_valid': 0.47950819672131145}
roc_auc_max (0.000000 --> 0.479508). Saving model ...
loss_min (0.754456 --> 0.754456). Saving model ...


loss: 0.70776, total_loss: 0.69271: 100%|████████████| 78/78 [01:57<00:00,  1.51s/it]
loss: 0.66301, total_loss: 0.70237: 100%|██████████████| 5/5 [00:20<00:00,  4.03s/it]


{'epoch': 1,
 'fold': 1,
 'loss_train': 0.6927062976054656,
 'loss_valid': 0.7023653745651245,
 'roc_auc_train': 0.5650834879406308,
 'roc_auc_valid': 0.5690866510538642}
roc_auc_max (0.479508 --> 0.569087). Saving model ...
loss_min (0.702365 --> 0.702365). Saving model ...


loss: 0.50071, total_loss: 0.68367: 100%|████████████| 78/78 [01:57<00:00,  1.51s/it]
loss: 0.71106, total_loss: 0.70848: 100%|██████████████| 5/5 [00:20<00:00,  4.03s/it]


{'epoch': 2,
 'fold': 1,
 'loss_train': 0.6836728197641861,
 'loss_valid': 0.7084830284118653,
 'roc_auc_train': 0.5857328385899814,
 'roc_auc_valid': 0.5477166276346606}


loss: 0.64564, total_loss: 0.68667: 100%|████████████| 78/78 [01:57<00:00,  1.51s/it]
loss: 0.72344, total_loss: 0.69251: 100%|██████████████| 5/5 [00:20<00:00,  4.00s/it]


{'epoch': 3,
 'fold': 1,
 'loss_train': 0.6866664259861677,
 'loss_valid': 0.6925091505050659,
 'roc_auc_train': 0.5716883116883117,
 'roc_auc_valid': 0.5374707259953162}
loss_min (0.692509 --> 0.692509). Saving model ...


loss: 0.59396, total_loss: 0.68873: 100%|████████████| 78/78 [01:56<00:00,  1.49s/it]
loss: 0.76759, total_loss: 0.69684: 100%|██████████████| 5/5 [00:20<00:00,  4.05s/it]


{'epoch': 4,
 'fold': 1,
 'loss_train': 0.6887348531148373,
 'loss_valid': 0.6968430399894714,
 'roc_auc_train': 0.5657142857142857,
 'roc_auc_valid': 0.549473067915691}


loss: 0.61859, total_loss: 0.69636: 100%|████████████| 78/78 [01:58<00:00,  1.51s/it]
loss: 0.68559, total_loss: 0.74895: 100%|██████████████| 5/5 [00:20<00:00,  4.02s/it]


{'epoch': 5,
 'fold': 1,
 'loss_train': 0.6963576326767603,
 'loss_valid': 0.7489516854286193,
 'roc_auc_train': 0.5563358070500928,
 'roc_auc_valid': 0.5357142857142857}


loss: 0.75198, total_loss: 0.68287: 100%|████████████| 78/78 [01:59<00:00,  1.54s/it]
loss: 0.74619, total_loss: 0.69590: 100%|██████████████| 5/5 [00:20<00:00,  4.03s/it]


{'epoch': 6,
 'fold': 1,
 'loss_train': 0.6828707135640658,
 'loss_valid': 0.6958992719650269,
 'roc_auc_train': 0.5872820037105752,
 'roc_auc_valid': 0.5327868852459017}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  2
-----------------------------------------------------------------------------------------------------


loss: 0.73232, total_loss: 0.77598: 100%|████████████| 78/78 [01:57<00:00,  1.50s/it]
loss: 0.62965, total_loss: 0.68547: 100%|██████████████| 5/5 [00:18<00:00,  3.64s/it]


{'epoch': 0,
 'fold': 2,
 'loss_train': 0.7759760186458246,
 'loss_valid': 0.6854722499847412,
 'roc_auc_train': 0.5278973127712623,
 'roc_auc_valid': 0.608047690014903}
roc_auc_max (0.000000 --> 0.608048). Saving model ...
loss_min (0.685472 --> 0.685472). Saving model ...


loss: 0.77827, total_loss: 0.71337: 100%|████████████| 78/78 [01:55<00:00,  1.48s/it]
loss: 0.68387, total_loss: 0.69130: 100%|██████████████| 5/5 [00:18<00:00,  3.68s/it]


{'epoch': 1,
 'fold': 2,
 'loss_train': 0.7133727356409415,
 'loss_valid': 0.6913033485412597,
 'roc_auc_train': 0.5256902761104442,
 'roc_auc_valid': 0.5344262295081967}


loss: 0.71566, total_loss: 0.68906: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.66465, total_loss: 0.69014: 100%|██████████████| 5/5 [00:18<00:00,  3.64s/it]


{'epoch': 2,
 'fold': 2,
 'loss_train': 0.6890636812417935,
 'loss_valid': 0.6901381134986877,
 'roc_auc_train': 0.5749376673746421,
 'roc_auc_valid': 0.5377049180327869}


loss: 0.50620, total_loss: 0.68730: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.65815, total_loss: 0.70915: 100%|██████████████| 5/5 [00:18<00:00,  3.71s/it]


{'epoch': 3,
 'fold': 2,
 'loss_train': 0.6873039591770905,
 'loss_valid': 0.7091546773910522,
 'roc_auc_train': 0.5745313510019392,
 'roc_auc_valid': 0.5496274217585693}


loss: 0.70205, total_loss: 0.69364: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.66839, total_loss: 0.70642: 100%|██████████████| 5/5 [00:18<00:00,  3.67s/it]


{'epoch': 4,
 'fold': 2,
 'loss_train': 0.6936385127214285,
 'loss_valid': 0.706421172618866,
 'roc_auc_train': 0.5601348231600333,
 'roc_auc_valid': 0.5549925484351714}


loss: 0.81279, total_loss: 0.68311: 100%|████████████| 78/78 [01:56<00:00,  1.49s/it]
loss: 0.69900, total_loss: 0.69252: 100%|██████████████| 5/5 [00:18<00:00,  3.66s/it]


{'epoch': 5,
 'fold': 2,
 'loss_train': 0.6831068006845621,
 'loss_valid': 0.692519462108612,
 'roc_auc_train': 0.5869517037584264,
 'roc_auc_valid': 0.5624441132637854}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  3
-----------------------------------------------------------------------------------------------------


loss: 0.20802, total_loss: 0.74394: 100%|████████████| 78/78 [01:57<00:00,  1.51s/it]
loss: 0.69870, total_loss: 0.73853: 100%|██████████████| 5/5 [00:17<00:00,  3.45s/it]


{'epoch': 0,
 'fold': 3,
 'loss_train': 0.743944356456781,
 'loss_valid': 0.7385300993919373,
 'roc_auc_train': 0.5421276202788808,
 'roc_auc_valid': 0.6220566318926974}
roc_auc_max (0.000000 --> 0.622057). Saving model ...
loss_min (0.738530 --> 0.738530). Saving model ...


loss: 0.73417, total_loss: 0.72856: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.68381, total_loss: 0.68962: 100%|██████████████| 5/5 [00:17<00:00,  3.42s/it]


{'epoch': 1,
 'fold': 3,
 'loss_train': 0.7285616222100381,
 'loss_valid': 0.6896203517913818,
 'roc_auc_train': 0.5378058915874042,
 'roc_auc_valid': 0.5690014903129659}
loss_min (0.689620 --> 0.689620). Saving model ...


loss: 0.56297, total_loss: 0.68961: 100%|████████████| 78/78 [01:53<00:00,  1.45s/it]
loss: 0.69413, total_loss: 0.69521: 100%|██████████████| 5/5 [00:17<00:00,  3.58s/it]


{'epoch': 2,
 'fold': 3,
 'loss_train': 0.6896138359338809,
 'loss_valid': 0.6952112436294555,
 'roc_auc_train': 0.5476313602364022,
 'roc_auc_valid': 0.5776453055141579}


loss: 0.71241, total_loss: 0.67948: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.72473, total_loss: 0.70707: 100%|██████████████| 5/5 [00:18<00:00,  3.69s/it]


{'epoch': 3,
 'fold': 3,
 'loss_train': 0.6794821337247506,
 'loss_valid': 0.7070745587348938,
 'roc_auc_train': 0.5970172684458399,
 'roc_auc_valid': 0.5150521609538002}


loss: 0.66445, total_loss: 0.68780: 100%|████████████| 78/78 [01:55<00:00,  1.48s/it]
loss: 0.72817, total_loss: 0.69561: 100%|██████████████| 5/5 [00:17<00:00,  3.53s/it]


{'epoch': 4,
 'fold': 3,
 'loss_train': 0.6878014489626273,
 'loss_valid': 0.6956098437309265,
 'roc_auc_train': 0.5774956136300674,
 'roc_auc_valid': 0.5630402384500744}


loss: 0.80527, total_loss: 0.68351: 100%|████████████| 78/78 [02:01<00:00,  1.56s/it]
loss: 0.80271, total_loss: 0.75560: 100%|██████████████| 5/5 [00:17<00:00,  3.50s/it]


{'epoch': 5,
 'fold': 3,
 'loss_train': 0.6835081592584268,
 'loss_valid': 0.755600655078888,
 'roc_auc_train': 0.5949302798042294,
 'roc_auc_valid': 0.513859910581222}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  4
-----------------------------------------------------------------------------------------------------


loss: 0.68925, total_loss: 0.86275: 100%|████████████| 78/78 [02:05<00:00,  1.61s/it]
loss: 0.70097, total_loss: 0.72141: 100%|██████████████| 5/5 [00:19<00:00,  3.82s/it]


{'epoch': 0,
 'fold': 4,
 'loss_train': 0.8627506295839945,
 'loss_valid': 0.7214102029800415,
 'roc_auc_train': 0.4823252377874227,
 'roc_auc_valid': 0.5481371087928465}
roc_auc_max (0.000000 --> 0.548137). Saving model ...
loss_min (0.721410 --> 0.721410). Saving model ...


loss: 1.01136, total_loss: 0.71053: 100%|████████████| 78/78 [02:02<00:00,  1.56s/it]
loss: 0.78828, total_loss: 0.70062: 100%|██████████████| 5/5 [00:18<00:00,  3.71s/it]


{'epoch': 1,
 'fold': 4,
 'loss_train': 0.7105267598078802,
 'loss_valid': 0.700621509552002,
 'roc_auc_train': 0.5381383322559793,
 'roc_auc_valid': 0.5928464977645306}
roc_auc_max (0.548137 --> 0.592846). Saving model ...
loss_min (0.700622 --> 0.700622). Saving model ...


loss: 1.12310, total_loss: 0.68656: 100%|████████████| 78/78 [02:01<00:00,  1.56s/it]
loss: 0.68156, total_loss: 0.73012: 100%|██████████████| 5/5 [00:19<00:00,  3.83s/it]


{'epoch': 2,
 'fold': 4,
 'loss_train': 0.6865570965485696,
 'loss_valid': 0.7301161766052247,
 'roc_auc_train': 0.5856404100101579,
 'roc_auc_valid': 0.46706408345752615}


loss: 0.59868, total_loss: 0.69573: 100%|████████████| 78/78 [01:58<00:00,  1.52s/it]
loss: 0.70261, total_loss: 0.67655: 100%|██████████████| 5/5 [00:18<00:00,  3.78s/it]


{'epoch': 3,
 'fold': 4,
 'loss_train': 0.6957346609769723,
 'loss_valid': 0.6765473008155822,
 'roc_auc_train': 0.5565795548988827,
 'roc_auc_valid': 0.6110283159463487}
roc_auc_max (0.592846 --> 0.611028). Saving model ...
loss_min (0.676547 --> 0.676547). Saving model ...


loss: 0.72547, total_loss: 0.68974: 100%|████████████| 78/78 [02:13<00:00,  1.71s/it]
loss: 0.72912, total_loss: 0.68952: 100%|██████████████| 5/5 [00:19<00:00,  3.86s/it]


{'epoch': 4,
 'fold': 4,
 'loss_train': 0.689739194435951,
 'loss_valid': 0.6895205736160278,
 'roc_auc_train': 0.5721673284698494,
 'roc_auc_valid': 0.5764530551415796}


loss: 0.77635, total_loss: 0.67898: 100%|████████████| 78/78 [02:09<00:00,  1.66s/it]
loss: 0.71823, total_loss: 0.68671: 100%|██████████████| 5/5 [00:18<00:00,  3.76s/it]


{'epoch': 5,
 'fold': 4,
 'loss_train': 0.6789777737397414,
 'loss_valid': 0.6867068767547607,
 'roc_auc_train': 0.5932126696832579,
 'roc_auc_valid': 0.5764530551415796}


loss: 0.51377, total_loss: 0.67792: 100%|████████████| 78/78 [02:10<00:00,  1.67s/it]
loss: 0.72188, total_loss: 0.68641: 100%|██████████████| 5/5 [00:19<00:00,  3.87s/it]


{'epoch': 6,
 'fold': 4,
 'loss_train': 0.6779151704066839,
 'loss_valid': 0.6864084362983703,
 'roc_auc_train': 0.6029273247760644,
 'roc_auc_valid': 0.588375558867362}


loss: 0.77120, total_loss: 0.67953: 100%|████████████| 78/78 [02:09<00:00,  1.66s/it]
loss: 0.63842, total_loss: 0.68688: 100%|██████████████| 5/5 [00:19<00:00,  3.89s/it]


{'epoch': 7,
 'fold': 4,
 'loss_train': 0.6795313285711484,
 'loss_valid': 0.6868777394294738,
 'roc_auc_train': 0.6121340844029919,
 'roc_auc_valid': 0.5800298062593144}


loss: 0.78616, total_loss: 0.66672: 100%|████████████| 78/78 [01:58<00:00,  1.52s/it]
loss: 0.75957, total_loss: 0.72044: 100%|██████████████| 5/5 [00:18<00:00,  3.75s/it]


{'epoch': 8,
 'fold': 4,
 'loss_train': 0.6667175006407958,
 'loss_valid': 0.7204410672187805,
 'roc_auc_train': 0.6310832025117739,
 'roc_auc_valid': 0.5001490312965723}
Early Stopping...
