In [1]:
import sys
import os

import pandas as pd

import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pprint

from monai.networks.nets.densenet import DenseNet121, DenseNet169, DenseNet201, DenseNet264
from monai.networks.nets.efficientnet import EfficientNetBN
from monai.networks.nets.resnet import ResNet, resnet34, resnet50, resnet101, resnet152, resnet200

from warnings import filterwarnings
filterwarnings("ignore")

sys.path.append(os.path.join(str(os.path.abspath('')), "..", "..", ".."))

from src.train_one_epoch import train_one_epoch
from src.get_data_loaders import prepare_train_valid_dataloader
from src.validate_func import valid_func

In [2]:
class CFG:
    debug = False # change this to run on full data
    
    image_size = 256
    folds = [0, 1, 2, 3, 4]
    
    kernel_type = "resnet34"
    
    train_batch_size = 6
    valid_batch_size = 24
    
    num_images = 64
    mri_type = 'T1w'
    
    init_lr = 1e-4
    weight_decay=0
    
    n_epochs = 20
    num_workers = 4

    use_amp=True
    early_stop = 5

    data_dir = PATH_TO_DATA # !!! DEFINE "PATH_TO_DATA" on your local machine
    model_dir = f'weights/'
    seed=12345
    

In [3]:
results_dir = CFG.mri_type + "_weights/"

In [4]:
! mkdir $results_dir

mkdir: cannot create directory ‘T1w_weights/’: File exists


In [5]:
df_train = pd.read_csv('../../crossval/train_df_folds.csv')
if CFG.debug:
    df_train = df_train.sample(frac=0.1)
df_train.head()

Unnamed: 0,BraTS21ID,MGMT_value,fold
0,0,1,2
1,2,1,1
2,3,0,1
3,5,1,4
4,6,1,1


In [6]:
criterion = nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
for fold in CFG.folds:
    train_loader, valid_loader = prepare_train_valid_dataloader(
        df=df_train, fold=fold, num_images=CFG.num_images,
        img_size=CFG.image_size, data_directory=CFG.data_dir, mri_type=CFG.mri_type,
        train_batch_size=CFG.train_batch_size, valid_batch_size=CFG.valid_batch_size,
        num_workers=CFG.num_workers
    )
    
#     model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)
#     model = EfficientNetBN(spatial_dims=3, in_channels=1, num_classes=1, model_name="efficientnet-b0").to(device)
    model = resnet34(spatial_dims=3, n_input_channels=1, num_classes=1).to(device)

    optimizer = optim.Adam(model.parameters(), lr=CFG.init_lr, weight_decay=CFG.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

#     scheduler = ReduceLROnPlateau(
#         optimizer, mode='min', patience=1, min_lr=1e-6, factor=0.1, verbose=True, eps=1e-8
#     )

    num_epochs = CFG.n_epochs

    print("-----------------------------------------------------------------------------------------------------")
    print("                                        FOLD: ", fold)
    print("-----------------------------------------------------------------------------------------------------")
    
    roc_auc_max = 0.0
    loss_min = 99999
    ap_max = 0.0
    not_improving = 0
    metrics_list = list()
    
    for epoch in range(CFG.n_epochs):
        
        loss_train, roc_auc_train = train_one_epoch(
            model, device, criterion, optimizer, train_loader, CFG.use_amp)
        
        loss_valid, roc_auc_valid = valid_func(
            model, device, criterion, valid_loader)
        
        scheduler.step()
        
#         scheduler.step(loss_valid)
        
        metrics_dictionary = {}
        metrics_dictionary['epoch'] = epoch
        metrics_dictionary['loss_train'] = loss_train
        metrics_dictionary['loss_valid'] = loss_valid
        metrics_dictionary['roc_auc_train'] = roc_auc_train
        metrics_dictionary['roc_auc_valid'] = roc_auc_valid
        metrics_dictionary['fold'] = fold
        pprint.pprint(metrics_dictionary)
        metrics_list.append(metrics_dictionary)
        
        not_improving += 1
        if roc_auc_valid > roc_auc_max:
            print(f'roc_auc_max ({roc_auc_max:.6f} --> {roc_auc_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_AUC_{CFG.mri_type}_mri_type.pth')
            roc_auc_max = roc_auc_valid
            not_improving = 0

        if loss_valid < loss_min:
            loss_min = loss_valid
            print(f'loss_min ({loss_min:.6f} --> {loss_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_loss_{CFG.mri_type}_mri_type.pth')

            
        if not_improving == CFG.early_stop:
            print('Early Stopping...')
            break

    
    metrics = pd.DataFrame(metrics_list)
    metrics.to_csv(f'{results_dir}{CFG.kernel_type}_fold{fold}_final.csv', index=False)
    torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_final_{CFG.mri_type}_mri_type.pth')


-----------------------------------------------------------------------------------------------------
                                        FOLD:  0
-----------------------------------------------------------------------------------------------------


loss: 1.08692, total_loss: 0.75703: 100%|████████████| 78/78 [04:07<00:00,  3.17s/it]
loss: 0.72036, total_loss: 0.76632: 100%|██████████████| 5/5 [01:02<00:00, 12.53s/it]


{'epoch': 0,
 'fold': 0,
 'loss_train': 0.7570336044598849,
 'loss_valid': 0.7663166284561157,
 'roc_auc_train': 0.506713151843335,
 'roc_auc_valid': 0.5181818181818182}
roc_auc_max (0.000000 --> 0.518182). Saving model ...
loss_min (0.766317 --> 0.766317). Saving model ...


loss: 0.78457, total_loss: 0.69848: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.67737, total_loss: 0.70092: 100%|██████████████| 5/5 [00:19<00:00,  3.95s/it]


{'epoch': 1,
 'fold': 0,
 'loss_train': 0.6984758266271689,
 'loss_valid': 0.700924551486969,
 'roc_auc_train': 0.5394629478525332,
 'roc_auc_valid': 0.5759530791788856}
roc_auc_max (0.518182 --> 0.575953). Saving model ...
loss_min (0.700925 --> 0.700925). Saving model ...


loss: 0.83644, total_loss: 0.69347: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.71290, total_loss: 0.72292: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 2,
 'fold': 0,
 'loss_train': 0.6934679750448618,
 'loss_valid': 0.7229192614555359,
 'roc_auc_train': 0.5584619093539055,
 'roc_auc_valid': 0.5662756598240469}


loss: 0.75555, total_loss: 0.68775: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.81702, total_loss: 0.81288: 100%|██████████████| 5/5 [00:19<00:00,  3.91s/it]


{'epoch': 3,
 'fold': 0,
 'loss_train': 0.687749540194487,
 'loss_valid': 0.8128776907920837,
 'roc_auc_train': 0.5774515985461018,
 'roc_auc_valid': 0.5656891495601174}


loss: 0.71537, total_loss: 0.68672: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.71704, total_loss: 0.70651: 100%|██████████████| 5/5 [00:19<00:00,  3.96s/it]


{'epoch': 4,
 'fold': 0,
 'loss_train': 0.6867207579123669,
 'loss_valid': 0.7065068125724793,
 'roc_auc_train': 0.583107707143387,
 'roc_auc_valid': 0.5753665689149561}


loss: 0.66415, total_loss: 0.68373: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.70601, total_loss: 0.71782: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 5,
 'fold': 0,
 'loss_train': 0.6837337995186831,
 'loss_valid': 0.7178163051605224,
 'roc_auc_train': 0.5861212076255471,
 'roc_auc_valid': 0.5865102639296187}
roc_auc_max (0.575953 --> 0.586510). Saving model ...


loss: 0.63159, total_loss: 0.69343: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.67213, total_loss: 0.68083: 100%|██████████████| 5/5 [00:19<00:00,  3.91s/it]


{'epoch': 6,
 'fold': 0,
 'loss_train': 0.6934266663514651,
 'loss_valid': 0.6808262467384338,
 'roc_auc_train': 0.5559676581855946,
 'roc_auc_valid': 0.5950146627565982}
roc_auc_max (0.586510 --> 0.595015). Saving model ...
loss_min (0.680826 --> 0.680826). Saving model ...


loss: 0.68955, total_loss: 0.67189: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67025, total_loss: 0.68053: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 7,
 'fold': 0,
 'loss_train': 0.6718869667786819,
 'loss_valid': 0.6805254697799683,
 'roc_auc_train': 0.6193809806394184,
 'roc_auc_valid': 0.6170087976539589}
roc_auc_max (0.595015 --> 0.617009). Saving model ...
loss_min (0.680525 --> 0.680525). Saving model ...


loss: 0.82449, total_loss: 0.67131: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67017, total_loss: 0.68085: 100%|██████████████| 5/5 [00:19<00:00,  3.98s/it]


{'epoch': 8,
 'fold': 0,
 'loss_train': 0.6713144481182098,
 'loss_valid': 0.6808513045310974,
 'roc_auc_train': 0.6113697055114606,
 'roc_auc_valid': 0.5956011730205278}


loss: 0.58467, total_loss: 0.66185: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.66163, total_loss: 0.67312: 100%|██████████████| 5/5 [00:19<00:00,  3.97s/it]


{'epoch': 9,
 'fold': 0,
 'loss_train': 0.6618459133001474,
 'loss_valid': 0.6731209754943848,
 'roc_auc_train': 0.6481436837029895,
 'roc_auc_valid': 0.6117302052785925}
loss_min (0.673121 --> 0.673121). Saving model ...


loss: 0.57320, total_loss: 0.67310: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67829, total_loss: 0.70922: 100%|██████████████| 5/5 [00:19<00:00,  3.97s/it]


{'epoch': 10,
 'fold': 0,
 'loss_train': 0.6731020112832388,
 'loss_valid': 0.7092248201370239,
 'roc_auc_train': 0.6372487204213337,
 'roc_auc_valid': 0.5240469208211144}


loss: 0.71949, total_loss: 0.63910: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67575, total_loss: 0.75851: 100%|██████████████| 5/5 [00:19<00:00,  3.95s/it]


{'epoch': 11,
 'fold': 0,
 'loss_train': 0.6391044541811332,
 'loss_valid': 0.7585105776786805,
 'roc_auc_train': 0.6756546250278168,
 'roc_auc_valid': 0.5612903225806453}


loss: 0.52219, total_loss: 0.60404: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.72123, total_loss: 0.73990: 100%|██████████████| 5/5 [00:19<00:00,  3.96s/it]


{'epoch': 12,
 'fold': 0,
 'loss_train': 0.6040397343727258,
 'loss_valid': 0.739900553226471,
 'roc_auc_train': 0.7289240412432312,
 'roc_auc_valid': 0.5193548387096774}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  1
-----------------------------------------------------------------------------------------------------


loss: 0.68812, total_loss: 0.76711: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.65692, total_loss: 0.77315: 100%|██████████████| 5/5 [00:19<00:00,  3.97s/it]


{'epoch': 0,
 'fold': 1,
 'loss_train': 0.7671057111941851,
 'loss_valid': 0.7731543183326721,
 'roc_auc_train': 0.5371892393320965,
 'roc_auc_valid': 0.526639344262295}
roc_auc_max (0.000000 --> 0.526639). Saving model ...
loss_min (0.773154 --> 0.773154). Saving model ...


loss: 0.52718, total_loss: 0.69410: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67069, total_loss: 0.69134: 100%|██████████████| 5/5 [00:19<00:00,  3.96s/it]


{'epoch': 1,
 'fold': 1,
 'loss_train': 0.6940974490000651,
 'loss_valid': 0.6913392782211304,
 'roc_auc_train': 0.548051948051948,
 'roc_auc_valid': 0.5512295081967212}
roc_auc_max (0.526639 --> 0.551230). Saving model ...
loss_min (0.691339 --> 0.691339). Saving model ...


loss: 0.75639, total_loss: 0.69715: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.74465, total_loss: 0.70287: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 2,
 'fold': 1,
 'loss_train': 0.6971460596109048,
 'loss_valid': 0.7028677701950073,
 'roc_auc_train': 0.5291280148423005,
 'roc_auc_valid': 0.5231264637002342}


loss: 1.00748, total_loss: 0.69392: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.65100, total_loss: 0.68941: 100%|██████████████| 5/5 [00:19<00:00,  3.98s/it]


{'epoch': 3,
 'fold': 1,
 'loss_train': 0.6939231963493885,
 'loss_valid': 0.6894148230552674,
 'roc_auc_train': 0.5545825602968459,
 'roc_auc_valid': 0.5749414519906324}
roc_auc_max (0.551230 --> 0.574941). Saving model ...
loss_min (0.689415 --> 0.689415). Saving model ...


loss: 0.62179, total_loss: 0.68475: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.70626, total_loss: 0.68765: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 4,
 'fold': 1,
 'loss_train': 0.6847496475928869,
 'loss_valid': 0.6876543402671814,
 'roc_auc_train': 0.5751484230055659,
 'roc_auc_valid': 0.5802107728337236}
roc_auc_max (0.574941 --> 0.580211). Saving model ...
loss_min (0.687654 --> 0.687654). Saving model ...


loss: 0.57004, total_loss: 0.67717: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67447, total_loss: 0.72825: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 5,
 'fold': 1,
 'loss_train': 0.6771749200729223,
 'loss_valid': 0.7282495498657227,
 'roc_auc_train': 0.5968923933209648,
 'roc_auc_valid': 0.5313231850117096}


loss: 0.87782, total_loss: 0.69025: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.69571, total_loss: 0.71743: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 6,
 'fold': 1,
 'loss_train': 0.6902469381307944,
 'loss_valid': 0.7174272298812866,
 'roc_auc_train': 0.5670871985157699,
 'roc_auc_valid': 0.5383489461358314}


loss: 0.77376, total_loss: 0.67765: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.66884, total_loss: 0.71359: 100%|██████████████| 5/5 [00:19<00:00,  3.98s/it]


{'epoch': 7,
 'fold': 1,
 'loss_train': 0.6776452836317893,
 'loss_valid': 0.7135894179344178,
 'roc_auc_train': 0.6093877551020409,
 'roc_auc_valid': 0.5459601873536299}


loss: 0.74619, total_loss: 0.66727: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.78001, total_loss: 0.68628: 100%|██████████████| 5/5 [00:19<00:00,  3.98s/it]


{'epoch': 8,
 'fold': 1,
 'loss_train': 0.6672673435547413,
 'loss_valid': 0.6862785696983338,
 'roc_auc_train': 0.6288961038961038,
 'roc_auc_valid': 0.6378805620608899}
roc_auc_max (0.580211 --> 0.637881). Saving model ...
loss_min (0.686279 --> 0.686279). Saving model ...


loss: 0.68731, total_loss: 0.65633: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.66345, total_loss: 0.69734: 100%|██████████████| 5/5 [00:19<00:00,  3.94s/it]


{'epoch': 9,
 'fold': 1,
 'loss_train': 0.6563333158309643,
 'loss_valid': 0.6973383069038391,
 'roc_auc_train': 0.650556586270872,
 'roc_auc_valid': 0.6115339578454332}


loss: 0.87470, total_loss: 0.62667: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.70156, total_loss: 0.75862: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 10,
 'fold': 1,
 'loss_train': 0.6266686736773222,
 'loss_valid': 0.7586153507232666,
 'roc_auc_train': 0.7064935064935065,
 'roc_auc_valid': 0.5755269320843092}


loss: 0.58502, total_loss: 0.58663: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.68851, total_loss: 0.74558: 100%|██████████████| 5/5 [00:19<00:00,  3.95s/it]


{'epoch': 11,
 'fold': 1,
 'loss_train': 0.5866348176048353,
 'loss_valid': 0.7455797076225281,
 'roc_auc_train': 0.758265306122449,
 'roc_auc_valid': 0.5424473067915692}


loss: 0.25446, total_loss: 0.53481: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.89789, total_loss: 0.88022: 100%|██████████████| 5/5 [00:19<00:00,  3.93s/it]


{'epoch': 12,
 'fold': 1,
 'loss_train': 0.5348126464165174,
 'loss_valid': 0.8802168488502502,
 'roc_auc_train': 0.8068274582560298,
 'roc_auc_valid': 0.5120023419203747}


loss: 0.35440, total_loss: 0.46184: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.92581, total_loss: 0.88539: 100%|██████████████| 5/5 [00:19<00:00,  3.96s/it]


{'epoch': 13,
 'fold': 1,
 'loss_train': 0.46184180982601947,
 'loss_valid': 0.8853938698768615,
 'roc_auc_train': 0.8620500927643786,
 'roc_auc_valid': 0.5289812646370023}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  2
-----------------------------------------------------------------------------------------------------


loss: 0.56130, total_loss: 0.75445: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.60316, total_loss: 0.71138: 100%|██████████████| 5/5 [00:18<00:00,  3.65s/it]


{'epoch': 0,
 'fold': 2,
 'loss_train': 0.7544496605793635,
 'loss_valid': 0.711379611492157,
 'roc_auc_train': 0.5255425247021885,
 'roc_auc_valid': 0.5493293591654248}
roc_auc_max (0.000000 --> 0.549329). Saving model ...
loss_min (0.711380 --> 0.711380). Saving model ...


loss: 0.77893, total_loss: 0.71668: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.72211, total_loss: 0.70109: 100%|██████████████| 5/5 [00:18<00:00,  3.71s/it]


{'epoch': 1,
 'fold': 2,
 'loss_train': 0.7166797342972878,
 'loss_valid': 0.7010917425155639,
 'roc_auc_train': 0.48184504571059195,
 'roc_auc_valid': 0.6181818181818182}
roc_auc_max (0.549329 --> 0.618182). Saving model ...
loss_min (0.701092 --> 0.701092). Saving model ...


loss: 0.89032, total_loss: 0.70202: 100%|████████████| 78/78 [01:55<00:00,  1.48s/it]
loss: 0.69075, total_loss: 0.72700: 100%|██████████████| 5/5 [00:18<00:00,  3.63s/it]


{'epoch': 2,
 'fold': 2,
 'loss_train': 0.7020211804371613,
 'loss_valid': 0.7269954800605773,
 'roc_auc_train': 0.540908671160772,
 'roc_auc_valid': 0.5555886736214606}


loss: 0.57609, total_loss: 0.69048: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.65334, total_loss: 0.67581: 100%|██████████████| 5/5 [00:18<00:00,  3.71s/it]


{'epoch': 3,
 'fold': 2,
 'loss_train': 0.6904828387957352,
 'loss_valid': 0.6758059978485107,
 'roc_auc_train': 0.5569674023455536,
 'roc_auc_valid': 0.6068554396423249}
loss_min (0.675806 --> 0.675806). Saving model ...


loss: 0.65616, total_loss: 0.69935: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67683, total_loss: 0.69546: 100%|██████████████| 5/5 [00:18<00:00,  3.66s/it]


{'epoch': 4,
 'fold': 2,
 'loss_train': 0.6993524722563915,
 'loss_valid': 0.6954554915428162,
 'roc_auc_train': 0.5302428663773202,
 'roc_auc_valid': 0.5886736214605067}


loss: 0.61398, total_loss: 0.68673: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.67066, total_loss: 0.67239: 100%|██████████████| 5/5 [00:19<00:00,  3.82s/it]


{'epoch': 5,
 'fold': 2,
 'loss_train': 0.6867295977396842,
 'loss_valid': 0.6723854064941406,
 'roc_auc_train': 0.5774402068519716,
 'roc_auc_valid': 0.6152011922503726}
loss_min (0.672385 --> 0.672385). Saving model ...


loss: 0.65340, total_loss: 0.69169: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.69416, total_loss: 0.68299: 100%|██████████████| 5/5 [00:18<00:00,  3.76s/it]


{'epoch': 6,
 'fold': 2,
 'loss_train': 0.6916875594701523,
 'loss_valid': 0.6829899072647094,
 'roc_auc_train': 0.5631267891772094,
 'roc_auc_valid': 0.6092399403874813}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  3
-----------------------------------------------------------------------------------------------------


loss: 0.72215, total_loss: 0.77979: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.73641, total_loss: 0.81990: 100%|██████████████| 5/5 [00:18<00:00,  3.67s/it]


{'epoch': 0,
 'fold': 3,
 'loss_train': 0.779788307272471,
 'loss_valid': 0.8198972105979919,
 'roc_auc_train': 0.5232708468002586,
 'roc_auc_valid': 0.4929955290611029}
roc_auc_max (0.000000 --> 0.492996). Saving model ...
loss_min (0.819897 --> 0.819897). Saving model ...


loss: 0.88924, total_loss: 0.71273: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.70122, total_loss: 0.68867: 100%|██████████████| 5/5 [00:17<00:00,  3.60s/it]


{'epoch': 1,
 'fold': 3,
 'loss_train': 0.7127263630047823,
 'loss_valid': 0.6886723399162292,
 'roc_auc_train': 0.5053652230122818,
 'roc_auc_valid': 0.5597615499254843}
roc_auc_max (0.492996 --> 0.559762). Saving model ...
loss_min (0.688672 --> 0.688672). Saving model ...


loss: 0.73264, total_loss: 0.70177: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.71028, total_loss: 0.69986: 100%|██████████████| 5/5 [00:17<00:00,  3.53s/it]


{'epoch': 2,
 'fold': 3,
 'loss_train': 0.7017715496894641,
 'loss_valid': 0.6998627066612244,
 'roc_auc_train': 0.535645027241666,
 'roc_auc_valid': 0.5776453055141579}
roc_auc_max (0.559762 --> 0.577645). Saving model ...


loss: 0.75222, total_loss: 0.70332: 100%|████████████| 78/78 [01:55<00:00,  1.47s/it]
loss: 0.72087, total_loss: 0.72604: 100%|██████████████| 5/5 [00:17<00:00,  3.56s/it]


{'epoch': 3,
 'fold': 3,
 'loss_train': 0.7033221973822668,
 'loss_valid': 0.7260392665863037,
 'roc_auc_train': 0.5047188106011635,
 'roc_auc_valid': 0.5618479880774964}


loss: 0.64110, total_loss: 0.68318: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.72732, total_loss: 0.69606: 100%|██████████████| 5/5 [00:16<00:00,  3.40s/it]


{'epoch': 4,
 'fold': 3,
 'loss_train': 0.6831795489176725,
 'loss_valid': 0.696061646938324,
 'roc_auc_train': 0.585243328100471,
 'roc_auc_valid': 0.5394932935916543}


loss: 0.55912, total_loss: 0.68618: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.74094, total_loss: 0.73722: 100%|██████████████| 5/5 [00:17<00:00,  3.55s/it]


{'epoch': 5,
 'fold': 3,
 'loss_train': 0.6861773637624887,
 'loss_valid': 0.7372214674949646,
 'roc_auc_train': 0.5837658140179149,
 'roc_auc_valid': 0.5391952309985097}


loss: 0.65000, total_loss: 0.68463: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.74974, total_loss: 0.70589: 100%|██████████████| 5/5 [00:17<00:00,  3.52s/it]


{'epoch': 6,
 'fold': 3,
 'loss_train': 0.684633310024555,
 'loss_valid': 0.7058885335922241,
 'roc_auc_train': 0.5866469664788992,
 'roc_auc_valid': 0.5436661698956782}


loss: 0.54557, total_loss: 0.68357: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.72225, total_loss: 0.68889: 100%|██████████████| 5/5 [00:17<00:00,  3.54s/it]


{'epoch': 7,
 'fold': 3,
 'loss_train': 0.6835660957373105,
 'loss_valid': 0.6888906478881835,
 'roc_auc_train': 0.590183765814018,
 'roc_auc_valid': 0.5532041728763041}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  4
-----------------------------------------------------------------------------------------------------


loss: 1.03188, total_loss: 0.73392: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 1.57311, total_loss: 1.40092: 100%|██████████████| 5/5 [00:18<00:00,  3.75s/it]


{'epoch': 0,
 'fold': 4,
 'loss_train': 0.7339223329073343,
 'loss_valid': 1.4009219646453857,
 'roc_auc_train': 0.539782066672823,
 'roc_auc_valid': 0.5910581222056632}
roc_auc_max (0.000000 --> 0.591058). Saving model ...
loss_min (1.400922 --> 1.400922). Saving model ...


loss: 0.74702, total_loss: 0.70926: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.83828, total_loss: 0.76626: 100%|██████████████| 5/5 [00:18<00:00,  3.78s/it]


{'epoch': 1,
 'fold': 4,
 'loss_train': 0.7092610750442896,
 'loss_valid': 0.7662597894668579,
 'roc_auc_train': 0.554935820482039,
 'roc_auc_valid': 0.568107302533532}
loss_min (0.766260 --> 0.766260). Saving model ...


loss: 0.57565, total_loss: 0.69048: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.70008, total_loss: 0.68394: 100%|██████████████| 5/5 [00:18<00:00,  3.65s/it]


{'epoch': 2,
 'fold': 4,
 'loss_train': 0.6904818377433679,
 'loss_valid': 0.683942997455597,
 'roc_auc_train': 0.5597746791024103,
 'roc_auc_valid': 0.5725782414307005}
loss_min (0.683943 --> 0.683943). Saving model ...


loss: 0.60682, total_loss: 0.68465: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.74261, total_loss: 0.69901: 100%|██████████████| 5/5 [00:18<00:00,  3.67s/it]


{'epoch': 3,
 'fold': 4,
 'loss_train': 0.6846491954265497,
 'loss_valid': 0.6990083932876587,
 'roc_auc_train': 0.5821774863791671,
 'roc_auc_valid': 0.5809239940387481}


loss: 0.62070, total_loss: 0.69496: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.70530, total_loss: 0.69023: 100%|██████████████| 5/5 [00:18<00:00,  3.65s/it]


{'epoch': 4,
 'fold': 4,
 'loss_train': 0.6949615088792948,
 'loss_valid': 0.6902297616004944,
 'roc_auc_train': 0.5529596453966202,
 'roc_auc_valid': 0.5845007451564829}


loss: 0.79095, total_loss: 0.67984: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.70250, total_loss: 0.73386: 100%|██████████████| 5/5 [00:18<00:00,  3.65s/it]


{'epoch': 5,
 'fold': 4,
 'loss_train': 0.6798403698664445,
 'loss_valid': 0.7338621616363525,
 'roc_auc_train': 0.6007295225782621,
 'roc_auc_valid': 0.44679582712369603}
Early Stopping...
