In [1]:
import argparse
import logging
import sys

import monai
import torch
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import numpy as np

from seg_data import getDataset
from transforms_dict import getSegmentationPostProcessingForLabel, getSegmentationPostProcessingForLabelOutput
from utils import compute_mean_dice, getReducePlateauScheduler, getAdamOptimizer, loadExistingModel
from utils import print_model_output, check_model_name, getDevice
from seg_model import getUNetForSegmentation

from monai.transforms import Activations
from torch import nn
from tqdm import tqdm

from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.metrics import DiceMetric

In [2]:
#Parameters
# Run configuration read by the later cells of this notebook.
# Checkpoint file name; passed through check_model_name() in the setup cell and
# later appended to './models/' when the best model is saved.
modelname = "test_seg_labels.pth"
# Dataset identifier forwarded to getDataset().
dataset = "painfactlabels"
# ft / ct are forwarded unchanged to loadExistingModel(model, optimizer, ft, ct).
# NOTE(review): presumably "fine-tune" / "continue-training" checkpoint selectors —
# confirm against utils.loadExistingModel.
ft = None
ct = None
# Forwarded to getDataset() as `batch`.
batchsize = 1
# Number of iterations of the epoch loop in the training cell.
num_epochs = 200
# factor / patience are forwarded to getReducePlateauScheduler().
factor = 0.9
patience = 10
# Forwarded to getDataset(augment=...).
augment = True
# Forwarded to getDataset(n4=...).
# NOTE(review): presumably toggles N4 bias-field-corrected inputs — confirm in seg_data.
N4 = False

In [3]:
#Modelname and device
# Global, order-dependent setup: run this cell before any cell that uses
# `modelname` or `device`.
# Use the 'file_system' strategy for sharing tensors between processes.
torch.multiprocessing.set_sharing_strategy('file_system')
# Project helper; its result replaces the configured name.
# NOTE(review): presumably validates/normalizes the file name — confirm in utils.
modelname = check_model_name(modelname)
# Project helper; the recorded cell output shows it prints "=> Saving to <modelname>".
print_model_output(modelname)
# Seed MONAI/torch/numpy random state for reproducibility.
set_determinism(seed=0)
# Send INFO-level (and above) log records to stdout so they appear in the cell output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Project helper returning the device that inputs and loss weights are moved to.
device = getDevice()

=> Saving to test_seg_labels.pth


In [4]:
#postprocessing
# Callables applied per item (after decollate_batch) in the training cell:
# `outputs_processing` to each network output, `labels_processing` to each label,
# before both are handed to metric_munet().
outputs_processing = getSegmentationPostProcessingForLabelOutput()
labels_processing = getSegmentationPostProcessingForLabel()

#activations
# Softmax over dim 1 wrapped as a MONAI transform.
# NOTE(review): `softmax` is not referenced by any other cell visible here.
softmax = Activations(other=nn.Softmax(dim=1))

In [5]:
#dataloaders
# `dataloaders` and `size` are both indexed by phase name ('train' / 'valid') in
# the training cell; `size[phase]` is used there to normalize the summed loss.
# NOTE(review): `size` presumably holds the number of samples per split — confirm in seg_data.
dataloaders, size = getDataset(dataset=dataset, batch=batchsize, augment=augment, training=True, n4=N4)

=> Using Painfact-Segmentation dataset.
=> Using augmented transforms for train set


In [6]:
#Metric MUNet
def metric_munet(preds, labels):
    """Per-channel Dice-style overlap between a prediction and a label volume.

    Both arguments are torch tensors whose leading axis is the channel axis and
    which have four axes in total (the sums run over axes 1, 2 and 3).

    The labels are first turned into a hard mask: every position that holds the
    maximum over the channel axis becomes 1, everything else that is not
    exactly 1 becomes 0. The score per channel is
    ``2 * sum(labels * preds) / (sum(labels + preds) + 1)``; the ``+ 1`` in the
    denominator avoids division by zero for empty channels.

    Args:
        preds: prediction tensor, channel-first.
        labels: label tensor, same shape as ``preds``. It is NOT modified.

    Returns:
        numpy array with one score per channel.
    """
    # Tensor.numpy() shares memory with a CPU tensor, so the in-place masking
    # below would otherwise overwrite the caller's labels. Work on a copy.
    labels = labels.detach().cpu().numpy().copy()
    preds = preds.detach().cpu().numpy()
    labels[np.where(labels == np.amax(labels, axis=0))] = 1
    labels[labels != 1] = 0
    dice = 2 * np.sum(labels * preds, (1, 2, 3)) / (np.sum((labels + preds), (1, 2, 3)) + 1)
    return dice

In [7]:
#Loss
def loss_Dice(preds, labels):
    """Soft Dice loss computed over all elements of both tensors at once.

    Returns ``1 - 2 * sum(labels * preds) / (sum(preds**2) + sum(labels**2))``
    as a scalar tensor. No smoothing term is applied, so the result is
    undefined (NaN) when both inputs are entirely zero.
    """
    overlap = ((labels * preds) * 2).sum()
    squared_mass = (preds * preds).sum() + (labels * labels).sum()
    return 1 - overlap / squared_mass

def loss_CE(input, target, class_weights=(1.0, 5.0, 5.0, 20.0)):
    """Class-weighted, mean-reduced cross-entropy loss.

    Args:
        input: logits with the class axis at dim 1.
        target: either a tensor with the same number of channels at dim 1 as
            ``input`` (converted to class indices with argmax over dim 1), or a
            tensor with a single channel at dim 1 holding class indices (that
            axis is squeezed away).
        class_weights: one weight per class, as a sequence or tensor. Defaults
            to the weights previously hard-coded in this function.

    Returns:
        Scalar tensor produced by ``nn.CrossEntropyLoss(reduction="mean")``.
    """
    n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
    if n_pred_ch == n_target_ch:
        target = torch.argmax(target, dim=1)
    else:
        target = torch.squeeze(target, dim=1)
    target = target.long()
    # Build the weights on the logits' own device/dtype instead of a globally
    # looked-up device, so the loss works wherever `input` lives.
    weight = torch.as_tensor(class_weights, dtype=input.dtype, device=input.device)
    return nn.CrossEntropyLoss(reduction="mean", weight=weight)(input, target)

# Generalized Dice loss; softmax over dim 1 is applied to the network output
# inside the loss (other_act). Used by loss_GDiceCE().
loss_GDice = monai.losses.GeneralizedDiceLoss(other_act=nn.Softmax(dim=1))
# Per-class cross-entropy weights (four classes), moved to the device chosen in
# the setup cell. Same values as the weights built inside loss_CE().
weight = torch.FloatTensor([1.0, 5.0, 5.0, 20.0]).to(device)
# Combined Dice + weighted cross-entropy loss from MONAI.
# NOTE(review): not referenced by the training cell visible here, which uses loss_GDiceCE.
loss_DiceCE = monai.losses.DiceCELoss(other_act=nn.Softmax(dim=1), ce_weight=weight)

def loss_GDiceCE(input, target, lambda_gdice=1.0, lambda_ce=1.0):
    """Weighted sum of the generalized Dice loss and the weighted cross-entropy.

    Args:
        input: network output, handed unchanged to both loss terms.
        target: ground-truth tensor, handed unchanged to both loss terms.
        lambda_gdice: multiplier for the ``loss_GDice`` term.
        lambda_ce: multiplier for the ``loss_CE`` term.

    Returns:
        ``lambda_gdice * loss_GDice(input, target) + lambda_ce * loss_CE(input, target)``.
    """
    gdice_term = lambda_gdice * loss_GDice(input, target)
    ce_term = lambda_ce * loss_CE(input, target)
    return gdice_term + ce_term

In [8]:
#Train loop
# Trains the segmentation UNet for `num_epochs` epochs. Validation runs only on
# every 4th epoch index (never on index 0); the checkpoint with the lowest
# validation loss is written to './models/<modelname>'.
best_loss = np.inf
writer = SummaryWriter()

# NOTE(review): set_determinism(seed=0) is called in the setup cell; turning on
# cuDNN benchmark mode afterwards may work against it — confirm this is intended.
torch.backends.cudnn.benchmark = True

#Model optimizer and scheduler
model = getUNetForSegmentation()
lr = 1e-3#/np.sqrt(6)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = getReducePlateauScheduler(optimizer, patience=patience, factor=factor)
loadExistingModel(model, optimizer, ft, ct)

# Mean Dice per epoch, appended only for phases that actually ran.
train_dices = []
valid_dices = []

try:
    for epoch in range(num_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{num_epochs}")

        train_loss = 0
        valid_loss = 0
        # Losses of the phases that really ran this epoch. Only these are sent
        # to TensorBoard, so skipped validation epochs no longer show up as 0.
        epoch_losses = {}

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            # One list of per-sample scores for each of the 4 output channels.
            metrics = [[] for _ in range(4)]

            # Training runs every epoch; validation every 4th epoch index except 0.
            if phase == 'train' or (epoch % 4 == 0 and epoch != 0):
                for i, data in enumerate(dataloaders[phase]):
                    print("{}/{}".format(
                        i, len(dataloaders[phase])), end='\r'
                    )
                    if phase == 'train':
                        optimizer.zero_grad()

                    # Gradients are tracked only while training.
                    with torch.set_grad_enabled(phase == 'train'):
                        inputs, labels = data["img"].to(device), data["seg"].to(device)
                        # Drop axis 2 of the label tensor (only if it has size 1).
                        labels = labels.squeeze(2)

                        if phase == 'train':
                            outputs = model(inputs)
                        else:
                            # Validation: sliding-window inference with 96^3
                            # windows, 4 windows per forward pass.
                            outputs = sliding_window_inference(inputs, (96, 96, 96), 4, model)
                        loss = loss_GDiceCE(outputs, labels)

                        # Post-process each item of the batch separately, then
                        # score it per channel.
                        preds = [outputs_processing(pred) for pred in decollate_batch(outputs)]
                        label_items = [labels_processing(label) for label in decollate_batch(labels)]
                        for pred, label in zip(preds, label_items):
                            metric = metric_munet(pred, label)
                            for k in range(4):
                                metrics[k].append(metric[k])

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    running_loss += loss.item()

                # NOTE(review): the sum of per-batch mean losses is divided by
                # size[phase]; this is a per-batch average only if size[phase]
                # equals the number of batches (true for batchsize == 1 if
                # `size` is the sample count) — confirm.
                running_loss /= size[phase]
                metrics_mean = [np.mean(x) for x in metrics]
                running_metric = np.mean(metrics_mean)

                print(
                    "{}: loss: {:.4f}, dice: {:.4f}".format(
                        phase, running_loss, running_metric
                    )
                )
                print(
                    "dices: {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(
                        metrics_mean[0], metrics_mean[1], metrics_mean[2], metrics_mean[3]
                    )
                )

                epoch_losses[phase] = running_loss

                if phase == 'train':
                    train_loss = running_loss
                    train_dices.append(running_metric)
                elif phase == 'valid':
                    valid_loss = running_loss
                    valid_dices.append(running_metric)
                    # The scheduler only sees validation losses, so `patience`
                    # counts validation rounds, not epochs.
                    scheduler.step(running_loss)
                    if running_loss < best_loss:
                        best_loss = running_loss
                        best_epoch = epoch + 1
                        torch.save({
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()
                        },
                            './models/' + str(modelname))

                        print(
                            "best loss {:.4f} at epoch {}".format(
                                best_loss, best_epoch
                            )
                        )

        writer.add_scalars('epoch_loss', epoch_losses, epoch + 1)

    print("train completed")
finally:
    # Always flush and close the TensorBoard writer, including when the run is
    # interrupted from the notebook (KeyboardInterrupt).
    writer.close()

----------
epoch 1/200
train: loss: 1.6316, dice: 0.5564
dices: 0.8353, 0.7898, 0.4775, 0.1230
----------
epoch 2/200
train: loss: 1.4521, dice: 0.6257
dices: 0.8757, 0.8273, 0.5901, 0.2099
----------
epoch 3/200
train: loss: 1.3752, dice: 0.6474
dices: 0.8995, 0.8350, 0.6154, 0.2397
----------
epoch 4/200
train: loss: 1.3141, dice: 0.6708
dices: 0.9249, 0.8434, 0.6552, 0.2595
----------
epoch 5/200
train: loss: 1.2672, dice: 0.6814
dices: 0.9293, 0.8530, 0.6693, 0.2740
valid: loss: 1.0909, dice: 0.7204
dices: 0.9785, 0.8797, 0.7621, 0.2611
best loss 1.0909 at epoch 5
----------
epoch 6/200
train: loss: 1.2368, dice: 0.6901
dices: 0.9304, 0.8598, 0.6857, 0.2845
----------
epoch 7/200
train: loss: 1.2215, dice: 0.6923
dices: 0.9353, 0.8645, 0.6861, 0.2835
----------
epoch 8/200
train: loss: 1.1934, dice: 0.7048
dices: 0.9347, 0.8667, 0.7152, 0.3024
----------
epoch 9/200
train: loss: 1.1751, dice: 0.7067
dices: 0.9376, 0.8731, 0.7110, 0.3051
valid: loss: 1.0028, dice: 0.7427
dices: 0.98

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6f6f78ca60>
Traceback (most recent call last):
  File "/home/valentini/dev/Mousenet/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/valentini/dev/Mousenet/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1322, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 

KeyboardInterrupt

