In [1]:
import sys
import os

import pandas as pd

import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pprint

from monai.networks.nets.densenet import DenseNet121, DenseNet169, DenseNet201, DenseNet264
from monai.networks.nets.efficientnet import EfficientNetBN
from monai.networks.nets.resnet import ResNet, resnet34, resnet50, resnet101, resnet152, resnet200

from warnings import filterwarnings
filterwarnings("ignore")

sys.path.append(os.path.join(str(os.path.abspath('')), "..", "..", ".."))

from src.train_one_epoch import train_one_epoch
from src.get_data_loaders import prepare_train_valid_dataloader
from src.validate_func import valid_func

In [2]:
class CFG:
    debug = False # change this to run on full data
    
    image_size = 256
    folds = [0, 1, 2, 3, 4]
    
    kernel_type = "resnet34"
    
    train_batch_size = 6
    valid_batch_size = 24
    
    num_images = 64
    mri_type = 'T2w'
    
    init_lr = 1e-4
    weight_decay=0
    
    n_epochs = 20
    num_workers = 4

    use_amp=True
    early_stop = 5

    data_dir = PATH_TO_DATA # !!! DEFINE "PATH_TO_DATA" on your local machine
    model_dir = f'weights/'
    seed=12345
    

In [3]:
results_dir = CFG.mri_type + "_weights/"

In [4]:
! mkdir $results_dir

mkdir: cannot create directory ‘T2w_weights/’: File exists


In [5]:
df_train = pd.read_csv('../../crossval/train_df_folds.csv')
if CFG.debug:
    df_train = df_train.sample(frac=0.1)
df_train.head()

Unnamed: 0,BraTS21ID,MGMT_value,fold
0,0,1,2
1,2,1,1
2,3,0,1
3,5,1,4
4,6,1,1


In [6]:
criterion = nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
for fold in CFG.folds:
    train_loader, valid_loader = prepare_train_valid_dataloader(
        df=df_train, fold=fold, num_images=CFG.num_images,
        img_size=CFG.image_size, data_directory=CFG.data_dir, mri_type=CFG.mri_type,
        train_batch_size=CFG.train_batch_size, valid_batch_size=CFG.valid_batch_size,
        num_workers=CFG.num_workers
    )
    
#     model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)
#     model = EfficientNetBN(spatial_dims=3, in_channels=1, num_classes=1, model_name="efficientnet-b0").to(device)
    model = resnet34(spatial_dims=3, n_input_channels=1, num_classes=1).to(device)

    optimizer = optim.Adam(model.parameters(), lr=CFG.init_lr, weight_decay=CFG.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

#     scheduler = ReduceLROnPlateau(
#         optimizer, mode='min', patience=1, min_lr=1e-6, factor=0.1, verbose=True, eps=1e-8
#     )

    num_epochs = CFG.n_epochs

    print("-----------------------------------------------------------------------------------------------------")
    print("                                        FOLD: ", fold)
    print("-----------------------------------------------------------------------------------------------------")
    
    roc_auc_max = 0.0
    loss_min = 99999
    ap_max = 0.0
    not_improving = 0
    metrics_list = list()
    
    for epoch in range(CFG.n_epochs):
        
        loss_train, roc_auc_train = train_one_epoch(
            model, device, criterion, optimizer, train_loader, CFG.use_amp)
        
        loss_valid, roc_auc_valid = valid_func(
            model, device, criterion, valid_loader)
        
        scheduler.step()
        
#         scheduler.step(loss_valid)
        
        metrics_dictionary = {}
        metrics_dictionary['epoch'] = epoch
        metrics_dictionary['loss_train'] = loss_train
        metrics_dictionary['loss_valid'] = loss_valid
        metrics_dictionary['roc_auc_train'] = roc_auc_train
        metrics_dictionary['roc_auc_valid'] = roc_auc_valid
        metrics_dictionary['fold'] = fold
        pprint.pprint(metrics_dictionary)
        metrics_list.append(metrics_dictionary)
        
        not_improving += 1
        if roc_auc_valid > roc_auc_max:
            print(f'roc_auc_max ({roc_auc_max:.6f} --> {roc_auc_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_AUC_{CFG.mri_type}_mri_type.pth')
            roc_auc_max = roc_auc_valid
            not_improving = 0

        if loss_valid < loss_min:
            loss_min = loss_valid
            print(f'loss_min ({loss_min:.6f} --> {loss_valid:.6f}). Saving model ...')
            torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_best_loss_{CFG.mri_type}_mri_type.pth')

            
        if not_improving == CFG.early_stop:
            print('Early Stopping...')
            break

    
    metrics = pd.DataFrame(metrics_list)
    metrics.to_csv(f'{results_dir}{CFG.kernel_type}_fold{fold}_final.csv', index=False)
    torch.save(model.state_dict(), f'{results_dir}{CFG.kernel_type}_fold{fold}_final_{CFG.mri_type}_mri_type.pth')


-----------------------------------------------------------------------------------------------------
                                        FOLD:  0
-----------------------------------------------------------------------------------------------------


loss: 0.58869, total_loss: 0.75831: 100%|████████████| 78/78 [06:14<00:00,  4.80s/it]
loss: 0.75587, total_loss: 0.69554: 100%|██████████████| 5/5 [01:29<00:00, 17.93s/it]


{'epoch': 0,
 'fold': 0,
 'loss_train': 0.7583113924050943,
 'loss_valid': 0.6955398797988892,
 'roc_auc_train': 0.524608708552778,
 'roc_auc_valid': 0.5879765395894428}
roc_auc_max (0.000000 --> 0.587977). Saving model ...
loss_min (0.695540 --> 0.695540). Saving model ...


loss: 0.71088, total_loss: 0.69585: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.82749, total_loss: 0.84316: 100%|██████████████| 5/5 [00:19<00:00,  3.96s/it]


{'epoch': 1,
 'fold': 0,
 'loss_train': 0.695853475576792,
 'loss_valid': 0.8431619167327881,
 'roc_auc_train': 0.5464728135894963,
 'roc_auc_valid': 0.5428152492668622}


loss: 0.76383, total_loss: 0.71340: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.71458, total_loss: 0.73449: 100%|██████████████| 5/5 [00:19<00:00,  3.98s/it]


{'epoch': 2,
 'fold': 0,
 'loss_train': 0.7133990480349615,
 'loss_valid': 0.7344889521598816,
 'roc_auc_train': 0.5246643424078332,
 'roc_auc_valid': 0.6208211143695015}
roc_auc_max (0.587977 --> 0.620821). Saving model ...


loss: 0.66633, total_loss: 0.69976: 100%|████████████| 78/78 [01:55<00:00,  1.48s/it]
loss: 0.76708, total_loss: 0.82728: 100%|██████████████| 5/5 [00:20<00:00,  4.02s/it]


{'epoch': 3,
 'fold': 0,
 'loss_train': 0.6997567025514749,
 'loss_valid': 0.8272830963134765,
 'roc_auc_train': 0.5337048438543136,
 'roc_auc_valid': 0.48504398826979467}


loss: 0.73981, total_loss: 0.71294: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.73102, total_loss: 0.78220: 100%|██████████████| 5/5 [00:19<00:00,  3.99s/it]


{'epoch': 4,
 'fold': 0,
 'loss_train': 0.7129444457017459,
 'loss_valid': 0.7821999311447143,
 'roc_auc_train': 0.4908018692975299,
 'roc_auc_valid': 0.544574780058651}


loss: 0.61460, total_loss: 0.69944: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.68721, total_loss: 0.67143: 100%|██████████████| 5/5 [00:19<00:00,  3.98s/it]


{'epoch': 5,
 'fold': 0,
 'loss_train': 0.6994409362475077,
 'loss_valid': 0.6714315533638,
 'roc_auc_train': 0.5324159928788665,
 'roc_auc_valid': 0.6293255131964809}
roc_auc_max (0.620821 --> 0.629326). Saving model ...
loss_min (0.671432 --> 0.671432). Saving model ...


loss: 0.76883, total_loss: 0.69238: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.68562, total_loss: 0.68053: 100%|██████████████| 5/5 [00:20<00:00,  4.00s/it]


{'epoch': 6,
 'fold': 0,
 'loss_train': 0.6923793730063316,
 'loss_valid': 0.6805298924446106,
 'roc_auc_train': 0.5385913507900008,
 'roc_auc_valid': 0.5909090909090909}


loss: 0.51992, total_loss: 0.68247: 100%|████████████| 78/78 [01:55<00:00,  1.47s/it]
loss: 0.70386, total_loss: 0.67482: 100%|██████████████| 5/5 [00:20<00:00,  4.02s/it]


{'epoch': 7,
 'fold': 0,
 'loss_train': 0.6824694978885162,
 'loss_valid': 0.6748202204704284,
 'roc_auc_train': 0.581299606854091,
 'roc_auc_valid': 0.6126099706744869}


loss: 0.61665, total_loss: 0.69161: 100%|████████████| 78/78 [01:55<00:00,  1.47s/it]
loss: 0.69122, total_loss: 0.68693: 100%|██████████████| 5/5 [00:20<00:00,  4.02s/it]


{'epoch': 8,
 'fold': 0,
 'loss_train': 0.6916094758571723,
 'loss_valid': 0.6869284749031067,
 'roc_auc_train': 0.5525925376455753,
 'roc_auc_valid': 0.5991202346041055}


loss: 0.47601, total_loss: 0.67283: 100%|████████████| 78/78 [01:55<00:00,  1.47s/it]
loss: 0.74474, total_loss: 0.70598: 100%|██████████████| 5/5 [00:19<00:00,  4.00s/it]


{'epoch': 9,
 'fold': 0,
 'loss_train': 0.6728261334773822,
 'loss_valid': 0.7059778928756714,
 'roc_auc_train': 0.6032842519100957,
 'roc_auc_valid': 0.5900293255131965}


loss: 0.64912, total_loss: 0.68888: 100%|████████████| 78/78 [01:55<00:00,  1.47s/it]
loss: 0.70500, total_loss: 0.68206: 100%|██████████████| 5/5 [00:19<00:00,  4.00s/it]


{'epoch': 10,
 'fold': 0,
 'loss_train': 0.6888803060238178,
 'loss_valid': 0.6820550680160522,
 'roc_auc_train': 0.571851123803872,
 'roc_auc_valid': 0.6079178885630498}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  1
-----------------------------------------------------------------------------------------------------


loss: 0.54901, total_loss: 0.76683: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.77791, total_loss: 0.69203: 100%|██████████████| 5/5 [00:20<00:00,  4.06s/it]


{'epoch': 0,
 'fold': 1,
 'loss_train': 0.7668263216813406,
 'loss_valid': 0.6920268535614014,
 'roc_auc_train': 0.504239332096475,
 'roc_auc_valid': 0.5945550351288057}
roc_auc_max (0.000000 --> 0.594555). Saving model ...
loss_min (0.692027 --> 0.692027). Saving model ...


loss: 0.74868, total_loss: 0.70494: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.69470, total_loss: 0.70753: 100%|██████████████| 5/5 [00:20<00:00,  4.02s/it]


{'epoch': 1,
 'fold': 1,
 'loss_train': 0.7049402434092301,
 'loss_valid': 0.7075337529182434,
 'roc_auc_train': 0.5497588126159555,
 'roc_auc_valid': 0.4812646370023419}


loss: 0.59311, total_loss: 0.69287: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.73094, total_loss: 0.70498: 100%|██████████████| 5/5 [00:20<00:00,  4.04s/it]


{'epoch': 2,
 'fold': 1,
 'loss_train': 0.6928730217310098,
 'loss_valid': 0.7049793124198913,
 'roc_auc_train': 0.5543413729128015,
 'roc_auc_valid': 0.5383489461358314}


loss: 0.71213, total_loss: 0.68918: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.69738, total_loss: 0.70185: 100%|██████████████| 5/5 [00:20<00:00,  4.05s/it]


{'epoch': 3,
 'fold': 1,
 'loss_train': 0.6891801800483313,
 'loss_valid': 0.7018494367599487,
 'roc_auc_train': 0.5702504638218924,
 'roc_auc_valid': 0.5225409836065573}


loss: 0.61174, total_loss: 0.67379: 100%|████████████| 78/78 [01:55<00:00,  1.47s/it]
loss: 0.67213, total_loss: 0.79084: 100%|██████████████| 5/5 [00:20<00:00,  4.00s/it]


{'epoch': 4,
 'fold': 1,
 'loss_train': 0.6737852696424875,
 'loss_valid': 0.7908355236053467,
 'roc_auc_train': 0.6106493506493507,
 'roc_auc_valid': 0.49473067915690866}


loss: 0.54432, total_loss: 0.67855: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67781, total_loss: 0.68832: 100%|██████████████| 5/5 [00:20<00:00,  4.03s/it]


{'epoch': 5,
 'fold': 1,
 'loss_train': 0.6785543167438263,
 'loss_valid': 0.6883157372474671,
 'roc_auc_train': 0.6100371057513914,
 'roc_auc_valid': 0.5518149882903981}
loss_min (0.688316 --> 0.688316). Saving model ...
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  2
-----------------------------------------------------------------------------------------------------


loss: 0.65741, total_loss: 0.74448: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.71110, total_loss: 0.71987: 100%|██████████████| 5/5 [00:18<00:00,  3.65s/it]


{'epoch': 0,
 'fold': 2,
 'loss_train': 0.7444801685901788,
 'loss_valid': 0.7198694348335266,
 'roc_auc_train': 0.5367162249515192,
 'roc_auc_valid': 0.5257824143070045}
roc_auc_max (0.000000 --> 0.525782). Saving model ...
loss_min (0.719869 --> 0.719869). Saving model ...


loss: 0.64041, total_loss: 0.70117: 100%|████████████| 78/78 [01:58<00:00,  1.52s/it]
loss: 0.71855, total_loss: 0.70519: 100%|██████████████| 5/5 [00:18<00:00,  3.64s/it]


{'epoch': 1,
 'fold': 2,
 'loss_train': 0.7011667016224984,
 'loss_valid': 0.7051853537559509,
 'roc_auc_train': 0.5374642164558131,
 'roc_auc_valid': 0.5257824143070046}
roc_auc_max (0.525782 --> 0.525782). Saving model ...
loss_min (0.705185 --> 0.705185). Saving model ...


loss: 0.66954, total_loss: 0.68415: 100%|████████████| 78/78 [01:55<00:00,  1.48s/it]
loss: 0.70051, total_loss: 0.69864: 100%|██████████████| 5/5 [00:18<00:00,  3.77s/it]


{'epoch': 2,
 'fold': 2,
 'loss_train': 0.6841508448123932,
 'loss_valid': 0.6986369490623474,
 'roc_auc_train': 0.5819743281928156,
 'roc_auc_valid': 0.5222056631892698}
loss_min (0.698637 --> 0.698637). Saving model ...


loss: 0.69416, total_loss: 0.68600: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.67690, total_loss: 0.70257: 100%|██████████████| 5/5 [00:18<00:00,  3.64s/it]


{'epoch': 3,
 'fold': 2,
 'loss_train': 0.6860010677423233,
 'loss_valid': 0.702571177482605,
 'roc_auc_train': 0.5767660910518053,
 'roc_auc_valid': 0.5341281669150522}
roc_auc_max (0.525782 --> 0.534128). Saving model ...


loss: 0.61979, total_loss: 0.68085: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.73581, total_loss: 0.71972: 100%|██████████████| 5/5 [00:18<00:00,  3.62s/it]


{'epoch': 4,
 'fold': 2,
 'loss_train': 0.6808488858051789,
 'loss_valid': 0.7197165012359619,
 'roc_auc_train': 0.5944223843383507,
 'roc_auc_valid': 0.5204172876304024}


loss: 0.80914, total_loss: 0.68962: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 0.69777, total_loss: 0.70911: 100%|██████████████| 5/5 [00:18<00:00,  3.66s/it]


{'epoch': 5,
 'fold': 2,
 'loss_train': 0.6896163095266391,
 'loss_valid': 0.7091142654418945,
 'roc_auc_train': 0.5720842183027056,
 'roc_auc_valid': 0.5219076005961252}


loss: 0.64684, total_loss: 0.66092: 100%|████████████| 78/78 [01:54<00:00,  1.46s/it]
loss: 1.16309, total_loss: 1.09950: 100%|██████████████| 5/5 [00:18<00:00,  3.61s/it]


{'epoch': 6,
 'fold': 2,
 'loss_train': 0.6609209451155785,
 'loss_valid': 1.099501943588257,
 'roc_auc_train': 0.6408624988456921,
 'roc_auc_valid': 0.4944858420268256}


loss: 0.74440, total_loss: 0.66668: 100%|████████████| 78/78 [01:57<00:00,  1.50s/it]
loss: 0.72309, total_loss: 0.73305: 100%|██████████████| 5/5 [00:18<00:00,  3.77s/it]


{'epoch': 7,
 'fold': 2,
 'loss_train': 0.6666780286110364,
 'loss_valid': 0.7330497026443481,
 'roc_auc_train': 0.6316926770708283,
 'roc_auc_valid': 0.5108792846497764}


loss: 0.74954, total_loss: 0.65658: 100%|████████████| 78/78 [02:01<00:00,  1.55s/it]
loss: 0.73510, total_loss: 0.71617: 100%|██████████████| 5/5 [00:18<00:00,  3.77s/it]


{'epoch': 8,
 'fold': 2,
 'loss_train': 0.6565776860866791,
 'loss_valid': 0.7161687016487122,
 'roc_auc_train': 0.6563117554714195,
 'roc_auc_valid': 0.5394932935916542}
roc_auc_max (0.534128 --> 0.539493). Saving model ...


loss: 0.54935, total_loss: 0.65248: 100%|████████████| 78/78 [01:53<00:00,  1.46s/it]
loss: 0.68611, total_loss: 0.71708: 100%|██████████████| 5/5 [00:17<00:00,  3.60s/it]


{'epoch': 9,
 'fold': 2,
 'loss_train': 0.6524789306597832,
 'loss_valid': 0.7170772910118103,
 'roc_auc_train': 0.6505586850124666,
 'roc_auc_valid': 0.5165424739195231}


loss: 0.64883, total_loss: 0.61010: 100%|████████████| 78/78 [02:10<00:00,  1.67s/it]
loss: 0.65387, total_loss: 0.73098: 100%|██████████████| 5/5 [00:19<00:00,  3.81s/it]


{'epoch': 10,
 'fold': 2,
 'loss_train': 0.6100989924027369,
 'loss_valid': 0.7309792995452881,
 'roc_auc_train': 0.7212854372518238,
 'roc_auc_valid': 0.49239940387481373}


loss: 0.53622, total_loss: 0.55657: 100%|████████████| 78/78 [02:14<00:00,  1.73s/it]
loss: 1.05990, total_loss: 0.92500: 100%|██████████████| 5/5 [00:18<00:00,  3.72s/it]


{'epoch': 11,
 'fold': 2,
 'loss_train': 0.556571799211013,
 'loss_valid': 0.9249996185302735,
 'roc_auc_train': 0.7804321728691476,
 'roc_auc_valid': 0.5049180327868853}


loss: 0.41736, total_loss: 0.50354: 100%|████████████| 78/78 [02:05<00:00,  1.61s/it]
loss: 0.90668, total_loss: 1.02879: 100%|██████████████| 5/5 [00:18<00:00,  3.76s/it]


{'epoch': 12,
 'fold': 2,
 'loss_train': 0.5035408343642186,
 'loss_valid': 1.0287936329841614,
 'roc_auc_train': 0.8315079878105088,
 'roc_auc_valid': 0.46110283159463483}


loss: 0.83939, total_loss: 0.37072: 100%|████████████| 78/78 [01:57<00:00,  1.51s/it]
loss: 1.52904, total_loss: 1.53235: 100%|██████████████| 5/5 [00:18<00:00,  3.71s/it]


{'epoch': 13,
 'fold': 2,
 'loss_train': 0.37071823940063137,
 'loss_valid': 1.5323458790779114,
 'roc_auc_train': 0.9144057623049219,
 'roc_auc_valid': 0.46825633383010434}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  3
-----------------------------------------------------------------------------------------------------


loss: 0.78981, total_loss: 0.76858: 100%|████████████| 78/78 [02:16<00:00,  1.74s/it]
loss: 0.75659, total_loss: 0.72296: 100%|██████████████| 5/5 [00:18<00:00,  3.68s/it]


{'epoch': 0,
 'fold': 3,
 'loss_train': 0.7685831724069058,
 'loss_valid': 0.7229572057723999,
 'roc_auc_train': 0.5063163727029274,
 'roc_auc_valid': 0.5418777943368107}
roc_auc_max (0.000000 --> 0.541878). Saving model ...
loss_min (0.722957 --> 0.722957). Saving model ...


loss: 0.78951, total_loss: 0.70619: 100%|████████████| 78/78 [02:16<00:00,  1.75s/it]
loss: 0.79989, total_loss: 0.70738: 100%|██████████████| 5/5 [00:19<00:00,  3.82s/it]


{'epoch': 1,
 'fold': 3,
 'loss_train': 0.7061885190315735,
 'loss_valid': 0.7073768496513366,
 'roc_auc_train': 0.552645673654077,
 'roc_auc_valid': 0.5776453055141578}
roc_auc_max (0.541878 --> 0.577645). Saving model ...
loss_min (0.707377 --> 0.707377). Saving model ...


loss: 0.66309, total_loss: 0.69841: 100%|████████████| 78/78 [02:18<00:00,  1.77s/it]
loss: 0.71603, total_loss: 0.68886: 100%|██████████████| 5/5 [00:19<00:00,  3.80s/it]


{'epoch': 2,
 'fold': 3,
 'loss_train': 0.6984097804778662,
 'loss_valid': 0.6888598322868347,
 'roc_auc_train': 0.5576692215347677,
 'roc_auc_valid': 0.600596125186289}
roc_auc_max (0.577645 --> 0.600596). Saving model ...
loss_min (0.688860 --> 0.688860). Saving model ...


loss: 0.67330, total_loss: 0.69485: 100%|████████████| 78/78 [02:15<00:00,  1.74s/it]
loss: 0.68800, total_loss: 0.68813: 100%|██████████████| 5/5 [00:19<00:00,  3.81s/it]


{'epoch': 3,
 'fold': 3,
 'loss_train': 0.6948513098252125,
 'loss_valid': 0.6881259322166443,
 'roc_auc_train': 0.5548711792409272,
 'roc_auc_valid': 0.555290611028316}
loss_min (0.688126 --> 0.688126). Saving model ...


loss: 0.67639, total_loss: 0.68707: 100%|████████████| 78/78 [02:17<00:00,  1.76s/it]
loss: 0.73333, total_loss: 0.69234: 100%|██████████████| 5/5 [00:18<00:00,  3.77s/it]


{'epoch': 4,
 'fold': 3,
 'loss_train': 0.6870720141973251,
 'loss_valid': 0.692339813709259,
 'roc_auc_train': 0.565943300397082,
 'roc_auc_valid': 0.574962742175857}


loss: 0.59040, total_loss: 0.68235: 100%|████████████| 78/78 [02:17<00:00,  1.76s/it]
loss: 0.76431, total_loss: 0.70099: 100%|██████████████| 5/5 [00:18<00:00,  3.76s/it]


{'epoch': 5,
 'fold': 3,
 'loss_train': 0.6823474654020407,
 'loss_valid': 0.7009888410568237,
 'roc_auc_train': 0.5818635146366239,
 'roc_auc_valid': 0.5836065573770493}


loss: 0.61030, total_loss: 0.67585: 100%|████████████| 78/78 [02:16<00:00,  1.75s/it]
loss: 0.71224, total_loss: 0.69181: 100%|██████████████| 5/5 [00:18<00:00,  3.74s/it]


{'epoch': 6,
 'fold': 3,
 'loss_train': 0.6758514237709534,
 'loss_valid': 0.6918053388595581,
 'roc_auc_train': 0.6075722596731001,
 'roc_auc_valid': 0.5833084947839046}


loss: 0.64143, total_loss: 0.66981: 100%|████████████| 78/78 [02:17<00:00,  1.76s/it]
loss: 0.80869, total_loss: 0.75181: 100%|██████████████| 5/5 [00:18<00:00,  3.69s/it]


{'epoch': 7,
 'fold': 3,
 'loss_train': 0.6698111853538415,
 'loss_valid': 0.7518052697181702,
 'roc_auc_train': 0.616695909132884,
 'roc_auc_valid': 0.5594634873323399}
Early Stopping...
-----------------------------------------------------------------------------------------------------
                                        FOLD:  4
-----------------------------------------------------------------------------------------------------


loss: 0.71022, total_loss: 0.76039: 100%|████████████| 78/78 [02:17<00:00,  1.76s/it]
loss: 0.85153, total_loss: 0.79046: 100%|██████████████| 5/5 [00:18<00:00,  3.72s/it]


{'epoch': 0,
 'fold': 4,
 'loss_train': 0.7603858739901812,
 'loss_valid': 0.7904561400413513,
 'roc_auc_train': 0.4955489888262997,
 'roc_auc_valid': 0.3684053651266766}
roc_auc_max (0.000000 --> 0.368405). Saving model ...
loss_min (0.790456 --> 0.790456). Saving model ...


loss: 0.73488, total_loss: 0.70290: 100%|████████████| 78/78 [02:17<00:00,  1.77s/it]
loss: 0.69092, total_loss: 0.68081: 100%|██████████████| 5/5 [00:17<00:00,  3.59s/it]


{'epoch': 1,
 'fold': 4,
 'loss_train': 0.7028991618217566,
 'loss_valid': 0.6808054327964783,
 'roc_auc_train': 0.5355434481484902,
 'roc_auc_valid': 0.6396423248882265}
roc_auc_max (0.368405 --> 0.639642). Saving model ...
loss_min (0.680805 --> 0.680805). Saving model ...


loss: 0.81199, total_loss: 0.69573: 100%|████████████| 78/78 [02:01<00:00,  1.55s/it]
loss: 0.68243, total_loss: 0.68475: 100%|██████████████| 5/5 [00:18<00:00,  3.67s/it]


{'epoch': 2,
 'fold': 4,
 'loss_train': 0.6957292403930273,
 'loss_valid': 0.6847487330436707,
 'roc_auc_train': 0.5512143318866009,
 'roc_auc_valid': 0.5752608047690014}


loss: 0.86387, total_loss: 0.69379: 100%|████████████| 78/78 [02:09<00:00,  1.66s/it]
loss: 0.72243, total_loss: 0.67855: 100%|██████████████| 5/5 [00:18<00:00,  3.70s/it]


{'epoch': 3,
 'fold': 4,
 'loss_train': 0.6937901293620085,
 'loss_valid': 0.678551435470581,
 'roc_auc_train': 0.5429217840982546,
 'roc_auc_valid': 0.5916542473919523}
loss_min (0.678551 --> 0.678551). Saving model ...


loss: 0.71007, total_loss: 0.68301: 100%|████████████| 78/78 [01:58<00:00,  1.51s/it]
loss: 0.67361, total_loss: 0.67009: 100%|██████████████| 5/5 [00:17<00:00,  3.53s/it]


{'epoch': 4,
 'fold': 4,
 'loss_train': 0.6830070717976644,
 'loss_valid': 0.6700934767723083,
 'roc_auc_train': 0.589768214978299,
 'roc_auc_valid': 0.6205663189269746}
loss_min (0.670093 --> 0.670093). Saving model ...


loss: 0.66793, total_loss: 0.69108: 100%|████████████| 78/78 [01:55<00:00,  1.49s/it]
loss: 0.69891, total_loss: 0.69786: 100%|██████████████| 5/5 [00:17<00:00,  3.59s/it]


{'epoch': 5,
 'fold': 4,
 'loss_train': 0.6910755748932178,
 'loss_valid': 0.6978554606437684,
 'roc_auc_train': 0.5690737833594977,
 'roc_auc_valid': 0.5684053651266765}


loss: 0.72314, total_loss: 0.67176: 100%|████████████| 78/78 [01:54<00:00,  1.47s/it]
loss: 0.70782, total_loss: 0.69808: 100%|██████████████| 5/5 [00:17<00:00,  3.57s/it]


{'epoch': 6,
 'fold': 4,
 'loss_train': 0.6717615169592392,
 'loss_valid': 0.6980847835540771,
 'roc_auc_train': 0.6069720195770616,
 'roc_auc_valid': 0.5463487332339791}
Early Stopping...
