In [1]:
import sys
sys.path.remove('/home/sentic/.local/lib/python3.6/site-packages')

import torch
# torch.backends.cudnn.benchmark=True

torch.cuda.set_enabled_lms(True)
torch.cuda.set_size_lms(120009999)

device_id = 0
torch.cuda.set_device(device_id)

#root = "../train"
root = "/home/sentic/MICCAI/data/train/"
use_gpu = True
n_epochs = 200
batch_size = 1
use_amp = False
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import LambdaLR

from torch.utils.data import DataLoader

from model import LargeCascadedModel
from dataset import BraTS
from losses import DiceLoss
from tqdm import tqdm_notebook, tqdm
from learning_rate import GradualWarmupScheduler, PolyLR

import pytorch_warmup as warmup

In [2]:
path_resume = "./checkpoints/checkpoint_188.pt"
checkpoint_optimizer = True
model = LargeCascadedModel(inplanes_encoder_1=4, channels_encoder_1=16, num_classes_1=3,
                           inplanes_encoder_2=7, channels_encoder_2=32, num_classes_2=3)

if use_gpu:
    model = model.to("cuda")

start_lr = 1e-4
optimizer = optim.Adam(list(filter(lambda p: p.requires_grad, model.parameters())), lr=start_lr, weight_decay=1e-5)
scheduler = PolyLR(optimizer, max_decay_steps=n_epochs, end_learning_rate=1e-8, power=0.9)

if path_resume is not None:
    dict_state = torch.load(path_resume)
    model.load_state_dict(dict_state['state_dict'])
    if checkpoint_optimizer:
        optimizer.load_state_dict(dict_state['optimizer'])
        last_epoch = dict_state['epoch']
    else:
        last_epoch = 0
else:
    last_epoch = 0
#scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.85)


# warm_up = GradualWarmupScheduler(optimizer, multiplier=2, total_epoch=5, after_scheduler=scheduler)
# warm_up = warmup.LinearWarmup(optimizer, warmup_period=369*5)

BraTSDataset = BraTS(root=root, phase="train", desired_depth=128, desired_height=240, desired_width=240, 
                     normalize_flag=True, scale_intensity_flag=True, shift_intesity_flag=True, flip_axes_flag=True)
dataloader = DataLoader(BraTSDataset, batch_size=batch_size, shuffle=True)

diceLoss = DiceLoss()

In [None]:
string = ""
for epoch in range(n_epochs - last_epoch):
    epoch_loss = 0
    for param_group in optimizer.param_groups:
        print("Learning rate =", param_group['lr'])
        break
    for ix, loader in tqdm(enumerate(dataloader)):
        volume, mask, contour, patient = loader
        torch.cuda.empty_cache()
        if use_gpu:
            volume = volume.to("cuda")
            mask = mask.to("cuda")
            contour = contour.to("cuda")
            
        
        optimizer.zero_grad()
        decoded_region1, decoded_region2, decoded_region3, decoded_contour = model(volume)
        
        loss_dice_region1 = diceLoss(decoded_region1, mask)
        loss_dice_region2 = diceLoss(decoded_region2, mask)
        loss_dice_region3 = diceLoss(decoded_region3, mask)
        loss_dice_contour = diceLoss(decoded_contour, contour)
        loss = loss_dice_region1 + loss_dice_region2 + loss_dice_region3 + loss_dice_contour
        
        with torch.no_grad():
            string += "P: {} DC R1: {}, DC R2: {}, DC R3: {},DC C: {}".format(patient, - np.round(loss_dice_region1.item(), 4),
                                                                              - np.round(loss_dice_region2.item(), 4),
                                                                              - np.round(loss_dice_region3.item(), 4),
                                                                               - np.round(loss_dice_contour.item(), 4))
            print("P: {} DC R1: {}, DC R2: {}, DC R3: {}, DC C: {}".format(patient, - np.round(loss_dice_region1.item(), 4),
                                                                           - np.round(loss_dice_region2.item(), 4),
                                                                           - np.round(loss_dice_region3.item(), 4),
                                                                           - np.round(loss_dice_contour.item(), 4)))


            string += "\n"
            
        with torch.no_grad():
            epoch_loss += loss_dice_region3.item()
        
        loss.backward()
        optimizer.step()
        
        scheduler.step(step=(epoch + last_epoch))
        # warm_up.dampen()
        
        
        del volume
        del mask
        del contour
        del decoded_region1
        del decoded_region2
        del decoded_region3
        del decoded_contour
        del loss
        del loss_dice_region1
        del loss_dice_region2
        del loss_dice_region3
        del loss_dice_contour
          
    epoch_loss = epoch_loss / (ix + 1)
    string += "Epoch {}: loss {}".format(epoch, epoch_loss)
    print("Epoch {}: loss {}".format(epoch, epoch_loss))
    for param_group in optimizer.param_groups:
        print("Learning rate =", param_group['lr'])
        break
    string += "\n"
    if epoch % 10 == 0 and epoch != 0:
        with open("results.txt", "w") as fhandle:
            print("Logged the results")
            fhandle.write(string)
    if (epoch + last_epoch) % 1 == 0 and epoch != 0:
        path_checkpoint = "/home/sentic/MICCAI/Madu/APPROACH_6/checkpoints/checkpoint_" + str(epoch + last_epoch) + ".pt"
        dict_state = {'epoch': epoch + last_epoch,
                     'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict()
                     }
        print(path_checkpoint)
        torch.save(dict_state, path_checkpoint)

0it [00:00, ?it/s]

Learning rate = 7.958637544298869e-06
P: ('BraTS20_Training_148',) DC R1: 0.8937, DC R2: 0.9039, DC R3: 0.9132, DC C: 0.6045


1it [00:22, 22.53s/it]

P: ('BraTS20_Training_239',) DC R1: 0.9209, DC R2: 0.9287, DC R3: 0.944, DC C: 0.7322


2it [00:42, 21.63s/it]

P: ('BraTS20_Training_209',) DC R1: 0.9478, DC R2: 0.9473, DC R3: 0.9525, DC C: 0.7093


3it [01:01, 21.09s/it]

P: ('BraTS20_Training_127',) DC R1: 0.9752, DC R2: 0.9772, DC R3: 0.9824, DC C: 0.8533


4it [01:21, 20.71s/it]

P: ('BraTS20_Training_088',) DC R1: 0.9153, DC R2: 0.921, DC R3: 0.9262, DC C: 0.6622


5it [01:41, 20.50s/it]

P: ('BraTS20_Training_124',) DC R1: 0.9699, DC R2: 0.9721, DC R3: 0.9775, DC C: 0.7905


6it [02:01, 20.29s/it]

P: ('BraTS20_Training_210',) DC R1: 0.9646, DC R2: 0.9672, DC R3: 0.9669, DC C: 0.6851


7it [02:21, 20.16s/it]

P: ('BraTS20_Training_035',) DC R1: 0.8949, DC R2: 0.9024, DC R3: 0.9138, DC C: 0.6325


8it [02:41, 20.06s/it]

P: ('BraTS20_Training_134',) DC R1: 0.9065, DC R2: 0.9114, DC R3: 0.9221, DC C: 0.6753


9it [03:01, 20.02s/it]

P: ('BraTS20_Training_016',) DC R1: 0.8825, DC R2: 0.8837, DC R3: 0.8905, DC C: 0.5291


10it [03:21, 19.99s/it]

P: ('BraTS20_Training_105',) DC R1: 0.9587, DC R2: 0.9617, DC R3: 0.9676, DC C: 0.8138


11it [03:40, 19.91s/it]

P: ('BraTS20_Training_106',) DC R1: 0.9544, DC R2: 0.9582, DC R3: 0.9647, DC C: 0.7648


12it [04:00, 19.88s/it]

P: ('BraTS20_Training_179',) DC R1: 0.952, DC R2: 0.9579, DC R3: 0.9703, DC C: 0.6881


13it [04:20, 19.84s/it]

P: ('BraTS20_Training_031',) DC R1: 0.9278, DC R2: 0.93, DC R3: 0.9343, DC C: 0.7021


14it [04:40, 19.86s/it]

P: ('BraTS20_Training_232',) DC R1: 0.9026, DC R2: 0.9068, DC R3: 0.9173, DC C: 0.689


15it [05:00, 19.90s/it]

P: ('BraTS20_Training_158',) DC R1: 0.9137, DC R2: 0.9179, DC R3: 0.9344, DC C: 0.7704


16it [05:20, 19.89s/it]

P: ('BraTS20_Training_348',) DC R1: 0.9597, DC R2: 0.9624, DC R3: 0.9658, DC C: 0.7667


17it [05:40, 20.00s/it]

P: ('BraTS20_Training_191',) DC R1: 0.9608, DC R2: 0.962, DC R3: 0.9653, DC C: 0.7179


18it [06:00, 19.93s/it]

P: ('BraTS20_Training_001',) DC R1: 0.9427, DC R2: 0.9461, DC R3: 0.9481, DC C: 0.6689


19it [06:19, 19.85s/it]

P: ('BraTS20_Training_023',) DC R1: 0.9411, DC R2: 0.9439, DC R3: 0.9491, DC C: 0.7591


20it [06:39, 19.78s/it]

P: ('BraTS20_Training_298',) DC R1: 0.7843, DC R2: 0.7939, DC R3: 0.8244, DC C: 0.5771


21it [06:59, 19.83s/it]

P: ('BraTS20_Training_081',) DC R1: 0.9062, DC R2: 0.904, DC R3: 0.9211, DC C: 0.6702


22it [07:19, 19.86s/it]

P: ('BraTS20_Training_111',) DC R1: 0.955, DC R2: 0.957, DC R3: 0.965, DC C: 0.7746


23it [07:39, 19.93s/it]

P: ('BraTS20_Training_048',) DC R1: 0.8866, DC R2: 0.8891, DC R3: 0.9047, DC C: 0.7253


24it [07:59, 19.90s/it]

P: ('BraTS20_Training_357',) DC R1: 0.9353, DC R2: 0.9384, DC R3: 0.9515, DC C: 0.7448


25it [08:18, 19.82s/it]

P: ('BraTS20_Training_024',) DC R1: 0.9169, DC R2: 0.9186, DC R3: 0.9282, DC C: 0.611


26it [08:38, 19.89s/it]

P: ('BraTS20_Training_245',) DC R1: 0.9265, DC R2: 0.9326, DC R3: 0.9312, DC C: 0.5452


27it [08:58, 19.86s/it]

P: ('BraTS20_Training_334',) DC R1: 0.9516, DC R2: 0.9566, DC R3: 0.9642, DC C: 0.7073


28it [09:18, 19.89s/it]

P: ('BraTS20_Training_208',) DC R1: 0.9423, DC R2: 0.9437, DC R3: 0.9491, DC C: 0.6926


29it [09:38, 19.91s/it]

P: ('BraTS20_Training_142',) DC R1: 0.954, DC R2: 0.9578, DC R3: 0.9633, DC C: 0.7661


30it [09:58, 19.90s/it]

P: ('BraTS20_Training_333',) DC R1: 0.9113, DC R2: 0.9145, DC R3: 0.9297, DC C: 0.6232


31it [10:18, 19.97s/it]

P: ('BraTS20_Training_229',) DC R1: 0.9689, DC R2: 0.9715, DC R3: 0.9742, DC C: 0.7469


32it [10:38, 19.87s/it]

P: ('BraTS20_Training_100',) DC R1: 0.9546, DC R2: 0.9587, DC R3: 0.9653, DC C: 0.7784


33it [10:58, 19.84s/it]

P: ('BraTS20_Training_246',) DC R1: 0.9499, DC R2: 0.9538, DC R3: 0.961, DC C: 0.679


34it [11:17, 19.79s/it]

P: ('BraTS20_Training_013',) DC R1: 0.8778, DC R2: 0.88, DC R3: 0.8978, DC C: 0.7053


35it [11:37, 19.82s/it]

P: ('BraTS20_Training_278',) DC R1: 0.6203, DC R2: 0.6299, DC R3: 0.634, DC C: 0.4038


36it [11:57, 19.75s/it]

P: ('BraTS20_Training_007',) DC R1: 0.9271, DC R2: 0.9306, DC R3: 0.9363, DC C: 0.7172


37it [12:16, 19.76s/it]

P: ('BraTS20_Training_093',) DC R1: 0.972, DC R2: 0.9738, DC R3: 0.979, DC C: 0.8088


38it [12:36, 19.82s/it]

P: ('BraTS20_Training_129',) DC R1: 0.8892, DC R2: 0.8938, DC R3: 0.8961, DC C: 0.6863


39it [12:56, 19.90s/it]

P: ('BraTS20_Training_349',) DC R1: 0.9624, DC R2: 0.9656, DC R3: 0.9692, DC C: 0.7353


40it [13:16, 19.88s/it]

P: ('BraTS20_Training_097',) DC R1: 0.9582, DC R2: 0.9602, DC R3: 0.9689, DC C: 0.7825


41it [13:36, 19.86s/it]

P: ('BraTS20_Training_082',) DC R1: 0.8921, DC R2: 0.9003, DC R3: 0.9135, DC C: 0.6984


42it [13:56, 19.88s/it]

P: ('BraTS20_Training_202',) DC R1: 0.9643, DC R2: 0.9652, DC R3: 0.9652, DC C: 0.7189


43it [14:16, 19.86s/it]

P: ('BraTS20_Training_014',) DC R1: 0.9143, DC R2: 0.9174, DC R3: 0.9208, DC C: 0.6527


44it [14:36, 19.95s/it]

P: ('BraTS20_Training_259',) DC R1: 0.9426, DC R2: 0.9498, DC R3: 0.9582, DC C: 0.7176


45it [14:56, 19.86s/it]

P: ('BraTS20_Training_292',) DC R1: 0.9025, DC R2: 0.9053, DC R3: 0.9222, DC C: 0.6504


46it [15:15, 19.80s/it]

P: ('BraTS20_Training_270',) DC R1: 0.9763, DC R2: 0.9783, DC R3: 0.9831, DC C: 0.8296


47it [15:35, 19.73s/it]

P: ('BraTS20_Training_019',) DC R1: 0.9147, DC R2: 0.9167, DC R3: 0.921, DC C: 0.6752


48it [15:55, 19.71s/it]

P: ('BraTS20_Training_066',) DC R1: 0.918, DC R2: 0.9221, DC R3: 0.9274, DC C: 0.7267


49it [16:15, 19.79s/it]

P: ('BraTS20_Training_050',) DC R1: 0.9394, DC R2: 0.9424, DC R3: 0.9481, DC C: 0.7215


50it [16:35, 19.86s/it]

P: ('BraTS20_Training_045',) DC R1: 0.9476, DC R2: 0.9489, DC R3: 0.9532, DC C: 0.7169


51it [16:54, 19.82s/it]

P: ('BraTS20_Training_315',) DC R1: 0.9439, DC R2: 0.9475, DC R3: 0.9688, DC C: 0.8073


52it [17:14, 19.87s/it]

P: ('BraTS20_Training_146',) DC R1: 0.9383, DC R2: 0.9406, DC R3: 0.9456, DC C: 0.6398


53it [17:34, 19.82s/it]

P: ('BraTS20_Training_300',) DC R1: 0.9339, DC R2: 0.938, DC R3: 0.9469, DC C: 0.6761


54it [17:54, 19.77s/it]

P: ('BraTS20_Training_157',) DC R1: 0.9632, DC R2: 0.9662, DC R3: 0.9757, DC C: 0.7982


55it [18:14, 19.83s/it]

P: ('BraTS20_Training_118',) DC R1: 0.9645, DC R2: 0.9675, DC R3: 0.9734, DC C: 0.8161


56it [18:34, 19.92s/it]

P: ('BraTS20_Training_055',) DC R1: 0.9475, DC R2: 0.9491, DC R3: 0.9533, DC C: 0.7455


57it [18:54, 19.98s/it]

P: ('BraTS20_Training_342',) DC R1: 0.9433, DC R2: 0.9466, DC R3: 0.9564, DC C: 0.6943


58it [19:14, 20.07s/it]

P: ('BraTS20_Training_369',) DC R1: 0.9684, DC R2: 0.9699, DC R3: 0.9748, DC C: 0.7981


59it [19:35, 20.20s/it]

P: ('BraTS20_Training_231',) DC R1: 0.9648, DC R2: 0.9669, DC R3: 0.9688, DC C: 0.698


60it [19:54, 20.05s/it]

P: ('BraTS20_Training_197',) DC R1: 0.9471, DC R2: 0.949, DC R3: 0.9532, DC C: 0.6099


61it [20:14, 19.95s/it]

P: ('BraTS20_Training_162',) DC R1: 0.9691, DC R2: 0.9711, DC R3: 0.9733, DC C: 0.7504


62it [20:34, 19.84s/it]

P: ('BraTS20_Training_033',) DC R1: 0.9532, DC R2: 0.9545, DC R3: 0.957, DC C: 0.7346


63it [20:53, 19.77s/it]

P: ('BraTS20_Training_295',) DC R1: 0.7747, DC R2: 0.7781, DC R3: 0.7966, DC C: 0.5763


64it [21:13, 19.74s/it]

P: ('BraTS20_Training_275',) DC R1: 0.5625, DC R2: 0.5709, DC R3: 0.5806, DC C: 0.3339


65it [21:33, 19.69s/it]

P: ('BraTS20_Training_212',) DC R1: 0.9142, DC R2: 0.9205, DC R3: 0.9366, DC C: 0.6755


66it [21:52, 19.71s/it]

P: ('BraTS20_Training_240',) DC R1: 0.8991, DC R2: 0.9083, DC R3: 0.9197, DC C: 0.6208


67it [22:12, 19.76s/it]

P: ('BraTS20_Training_268',) DC R1: 0.5955, DC R2: 0.5983, DC R3: 0.5985, DC C: 0.3126


68it [22:32, 19.79s/it]

P: ('BraTS20_Training_364',) DC R1: 0.9534, DC R2: 0.9535, DC R3: 0.9624, DC C: 0.7746


69it [22:52, 19.93s/it]

P: ('BraTS20_Training_108',) DC R1: 0.9695, DC R2: 0.972, DC R3: 0.9775, DC C: 0.8377


70it [23:12, 19.89s/it]

P: ('BraTS20_Training_107',) DC R1: 0.9656, DC R2: 0.9678, DC R3: 0.9732, DC C: 0.8138


71it [23:32, 19.90s/it]

P: ('BraTS20_Training_308',) DC R1: 0.8836, DC R2: 0.8891, DC R3: 0.9014, DC C: 0.5912


72it [23:52, 19.82s/it]

P: ('BraTS20_Training_277',) DC R1: 0.9, DC R2: 0.9121, DC R3: 0.917, DC C: 0.69


73it [24:12, 19.87s/it]

P: ('BraTS20_Training_069',) DC R1: 0.9113, DC R2: 0.9153, DC R3: 0.9222, DC C: 0.7088


74it [24:31, 19.81s/it]

P: ('BraTS20_Training_073',) DC R1: 0.9298, DC R2: 0.9328, DC R3: 0.9382, DC C: 0.7403


75it [24:51, 19.80s/it]

P: ('BraTS20_Training_238',) DC R1: 0.9297, DC R2: 0.9357, DC R3: 0.9451, DC C: 0.6558


76it [25:11, 19.79s/it]

P: ('BraTS20_Training_227',) DC R1: 0.919, DC R2: 0.9272, DC R3: 0.9395, DC C: 0.7411


77it [25:31, 19.78s/it]

P: ('BraTS20_Training_041',) DC R1: 0.9173, DC R2: 0.9186, DC R3: 0.9261, DC C: 0.6875


78it [25:50, 19.80s/it]

P: ('BraTS20_Training_130',) DC R1: 0.9442, DC R2: 0.9477, DC R3: 0.9494, DC C: 0.676


79it [26:10, 19.79s/it]

P: ('BraTS20_Training_358',) DC R1: 0.9634, DC R2: 0.9668, DC R3: 0.9734, DC C: 0.8055


In [None]:
# path_checkpoint = "/home/sentic/MICCAI/Madu/APPROACH_6/checkpoints/checkpoint_" + str(42) + ".pt"
# dict_state = {'epoch': 42,
#              'state_dict': model.state_dict(),
#              'optimizer': optimizer.state_dict()
#              }
# print(path_checkpoint)
# torch.save(dict_state, path_checkpoint)

In [None]:
for param_group in optimizer.param_groups:
    print("Learning rate =", param_group['lr'])
    break