In [None]:
import argparse
import logging
import sys

import monai
import torch
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import numpy as np

from seg_data import getSegmentationDataset
from seg_model import getUNetForSegmentation, getUNETRForSegmentation
from transforms_dict import getSegmentationPostProcessingForLabel, getSegmentationPostProcessingForLabelOutput
from utils import compute_mean_dice, getReducePlateauScheduler, getAdamOptimizer, loadExistingModel
from utils import print_model_output, check_model_name, getDevice

from monai.metrics import DiceMetric

from monai.transforms import Activations, AsDiscrete, Activations
from torch import nn

import matplotlib.pyplot as plt
import nibabel as nib

In [None]:
# Candidate training losses. Each one applies a softmax over dim=1 of the
# network output (via `other_act`) before computing the loss.
loss_gdice = monai.losses.GeneralizedDiceLoss(
    include_background=True,
    other_act=nn.Softmax(dim=1),
)
loss_gdicefoc = monai.losses.GeneralizedDiceFocalLoss(
    include_background=True,
    other_act=nn.Softmax(dim=1),
)
# Unlike the two losses above, this one is built with include_background=False.
loss_dicece = monai.losses.DiceCELoss(
    other_act=nn.Softmax(dim=1),
    include_background=False,
)

#Metric MUNet
def metric_munet(preds, labels):
    """Per-channel Dice score between predictions and hard-assigned labels.

    ``labels`` is binarised along the channel axis (axis 1): every entry equal
    to the channel-wise maximum becomes 1 and everything else becomes 0. The
    score per channel is then

        2 * sum(labels * preds) / (sum(labels + preds) + 1)

    with the sums taken over axes (0, 2, 3, 4); the ``+ 1`` keeps the
    denominator non-zero for channels that are empty in both inputs.

    Args:
        preds: 5-D torch tensor whose axis 1 is the channel axis.
        labels: torch tensor with the same shape as ``preds``. It is not
            modified by this function.

    Returns:
        1-D numpy array holding one Dice value per channel.
    """
    labels_np = labels.detach().cpu().numpy()
    preds_np = preds.detach().cpu().numpy()

    # keepdims=True keeps the channel axis so the comparison broadcasts
    # correctly for any batch size (without it, it only lines up for batch 1).
    channel_max = np.amax(labels_np, axis=1, keepdims=True)

    # Build a fresh array instead of writing into `labels_np`: for CPU tensors
    # that array shares memory with the caller's tensor.
    hard_labels = (labels_np == channel_max).astype(labels_np.dtype)

    reduce_axes = (0, 2, 3, 4)
    overlap = np.sum(hard_labels * preds_np, reduce_axes)
    total = np.sum(hard_labels + preds_np, reduce_axes) + 1
    dice = 2 * overlap / total
    return dice

In [None]:
# Parameters
dataset = "Femina3"
ft = None  # forwarded to loadExistingModel
ct = None  # forwarded to loadExistingModel
batchsize = 1
num_epochs = 500
lr = 0.001
factor = 0.9   # forwarded to getReducePlateauScheduler
patience = 7   # forwarded to getReducePlateauScheduler
augment = True
N4 = False
NUM_CLASSES = 4  # number of channels produced by the one-hot encoding below

# Modelname and device
torch.multiprocessing.set_sharing_strategy('file_system')
#modelname = check_model_name(modelname)
#print_model_output(modelname)
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
device = getDevice()

# postprocessing
outputs_processing = getSegmentationPostProcessingForLabelOutput()
labels_processing = getSegmentationPostProcessingForLabel()

# activations
softmax = Activations(other=nn.Softmax(dim=1))


modelnames = [#"test_labels_femina3_gdice.pth",
              #"test_labels_femina3_gdicefoc.pth",
              #"test_labels_femina3_dicece.pth",
              "test_labels_femina3_unetr_dicece.pth",
]

# Checkpoint name -> loss used to train it. An explicit table means an
# unexpected name fails immediately instead of leaving `loss` undefined (or
# silently reusing the value from a previous iteration).
loss_by_modelname = {
    "test_labels_femina3_gdice.pth": loss_gdice,
    "test_labels_femina3_gdicefoc.pth": loss_gdicefoc,
    "test_labels_femina3_dicece.pth": loss_dicece,
    "test_labels_femina3_unetr_dicece.pth": loss_dicece,
}

for modelname in modelnames:
    if modelname not in loss_by_modelname:
        raise ValueError(f"no loss function configured for model name {modelname!r}")
    criterion = loss_by_modelname[modelname]

    # dataloaders
    dataloaders, size = getSegmentationDataset(dataset=dataset, batch=batchsize, augment=False, training=False, n4=N4, labels=True, eval_augment=False)

    # Train loop state. Start from +inf so the first validation epoch always
    # writes a checkpoint, even when the loss is >= 1.
    best_loss = float("inf")
    writer = SummaryWriter()

    # Model optimizer and scheduler
    #model = getUNetForSegmentation()
    model = getUNETRForSegmentation()
    optimizer = getAdamOptimizer(model, lr)
    scheduler = getReducePlateauScheduler(optimizer, patience=patience, factor=factor)
    loadExistingModel(model, optimizer, ft, ct)

    # Atlas geometry; only referenced by the commented-out debug export below.
    ptdr = "dataset3/Atlas/P56_Atlas_128_norm_id.nii.gz"
    atlas_img = nib.load(ptdr)
    affine = atlas_img.affine
    header = atlas_img.header

    # One cumulative metric per model, reset at the start of every phase.
    # "mean_batch" averages over samples and keeps one value per class.
    dice_metric = DiceMetric(include_background=True, reduction="mean_batch", get_not_nans=False)

    for epoch in range(num_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{num_epochs}")

        train_loss = 0
        valid_loss = 0

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            dice_metric.reset()

            # NOTE: `data` intentionally keeps its name; a later cell reads the
            # last batch through this variable.
            for i, data in enumerate(dataloaders[phase]):
                print(i, end='\r')
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    inputs, labels = data["img"].to(device), data["seg"].to(device)

                    outputs = model(inputs)
                    # assumes `labels` is an integer label map with a singleton
                    # channel axis at dim=1 — TODO confirm against the dataset
                    onehot_labels = monai.networks.utils.one_hot(labels, num_classes=NUM_CLASSES, dim=1)

                    loss = criterion(outputs, onehot_labels)

                    # DiceMetric expects binarised, batch-first predictions, so
                    # feed it the argmax class as one-hot rather than squeezed
                    # softmax probabilities (squeezing a batch of 1 makes the
                    # class axis look like the batch axis).
                    pred_onehot = monai.networks.utils.one_hot(
                        outputs.argmax(dim=1, keepdim=True), num_classes=NUM_CLASSES, dim=1
                    )
                    dice_metric(y_pred=pred_onehot, y=onehot_labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    #if phase=='valid' and i==5:
                    #    print(data['img_meta_dict']['filename_or_obj'])
                    #    print(data['seg_meta_dict']['filename_or_obj'])
                    #    image = inputs.squeeze().detach().cpu().numpy()
                    #    plt.imshow(image[64,:,:])
                    #    outputs_labels = outputs.detach().cpu().numpy()
                    #    outputs_labels = np.argmax(outputs_labels, axis=1).squeeze()
                    #    plt.imshow(outputs_labels[64,:,:], cmap='jet', alpha=0.5)
                    #    plt.show()
                    #    nib.save(nib.Nifti1Image(outputs_labels, affine, header), "test.nii.gz")

                # Weight by batch size so the division by size[phase] below
                # yields a per-sample average.
                running_loss += loss.item() * inputs.size(0)

            running_loss /= size[phase]
            # Per-class Dice averaged over every sample seen in this phase.
            metrics_mean = dice_metric.aggregate().tolist()
            running_metric = np.mean(metrics_mean)
            print(
                "{}: loss: {:.4f}, dice: {:.4f}".format(
                    phase, running_loss, running_metric
                )
            )
            print(
                "dices: {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(
                    metrics_mean[0], metrics_mean[1], metrics_mean[2], metrics_mean[3]
                )
            )

            if phase == 'train':
                train_loss = running_loss
            elif phase == 'valid':
                valid_loss = running_loss
                scheduler.step(running_loss)
                # Keep only the checkpoint with the lowest validation loss.
                if running_loss < best_loss:
                    best_loss = running_loss
                    best_epoch = epoch + 1
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()
                    },
                        './models/' + str(modelname))

                    print(
                        "best loss {:.4f} at epoch {}".format(
                            best_loss, best_epoch
                        )
                    )
        writer.add_scalars('epoch_loss', {
            'train': train_loss,
            'valid': valid_loss,
        }, epoch + 1)

    print("train completed")
    writer.close()

In [None]:
import nibabel as nib
import os
from glob import glob

def get_labels_as_one_hot(seg, num_classes=4):
    """Convert an integer label volume into a channel-first one-hot array.

    Args:
        seg: array-like label volume; voxels equal to ``k`` end up in
            channel ``k``.
        num_classes: number of output channels. Defaults to 4, the class
            count used elsewhere in this notebook.

    Returns:
        Float array of shape ``(num_classes, *seg.shape)`` with 1 where
        ``seg == k`` in channel ``k`` and 0 elsewhere. Classes absent from
        ``seg`` yield all-zero channels.

    Raises:
        ValueError: if ``seg`` contains a label >= ``num_classes``.
    """
    seg = np.asarray(seg)
    # An empty volume has no maximum; skip the range check in that case.
    if seg.size and int(seg.max()) >= num_classes:
        raise ValueError(
            "segmentation contains label {} but num_classes is {}".format(
                int(seg.max()), num_classes
            )
        )
    labels = np.zeros((num_classes,) + seg.shape)
    for k in range(num_classes):
        # The boolean mask is cast to 0.0 / 1.0 on assignment.
        labels[k] = seg == k
    return labels

# NOTE: `data` is the last batch left in the kernel by the training-loop cell
# (hidden notebook state) -- that cell has to run before this one.
path = data['seg_file'][0]
print(path)
data_raw = nib.load(path).get_fdata()
labels_raw = get_labels_as_one_hot(data_raw)

print(data_raw.shape)
print(labels_raw.shape)

# Total voxel count taken from the loaded volume instead of a hard-coded shape,
# so the percentages stay correct for any input size.
maxV = data_raw.size

# Percentage of voxels that belong to none of the class channels.
unassigned = np.count_nonzero((labels_raw == 0).all(axis=0))
print(unassigned / maxV * 100)

# Percentage of the volume covered by each class channel, in channel order.
for channel in labels_raw:
    print(np.count_nonzero(channel == 1) / maxV * 100)

In [None]:
from monai.transforms import Resize

target_shape = (128, 128, 128)

# Nearest-neighbour resampling keeps the one-hot values at exactly 0 or 1.
# MONAI's default interpolation for Resize blends neighbouring voxels, which
# makes the `== 1` / `== 0` tests below undercount voxels near class borders.
labels_raw_resize = Resize(spatial_size=target_shape, mode="nearest")(labels_raw)
print(labels_raw_resize.shape)

# Resize may return a tensor-like object; work on a plain numpy view.
resized = np.asarray(labels_raw_resize)
maxV = int(np.prod(target_shape))

# Percentage of voxels that belong to none of the class channels.
unassigned = np.count_nonzero((resized == 0).all(axis=0))
print(unassigned / maxV * 100)

# Percentage of the resized volume covered by each class channel.
for channel in resized:
    print(np.count_nonzero(channel == 1) / maxV * 100)

In [None]:
# Pull one batch from the training loader and report how much of the volume
# each label channel occupies.
# NOTE: `dataloaders` comes from the training-loop cell (hidden notebook state).
data = next(iter(dataloaders['train']))
labels = data['seg']
print(data['seg'].shape)
print(labels.shape)

# NOTE(review): the indexing below needs a 6-D `seg` tensor with the class
# channels on axis 2 — TODO confirm against the dataset output.
a, b, c, d = (np.where(labels[0, 0, k, :, :, :] == 1) for k in range(4))

# Voxels that are zero in every one of the four channels.
no_class = labels[0, 0, 0, :, :, :] == 0
for k in range(1, 4):
    no_class = no_class & (labels[0, 0, k, :, :, :] == 0)
e = np.where(no_class)

maxV = 128*128*128

# Unassigned share first, then channels 0..3, as percentages of maxV.
for hits in (e, a, b, c, d):
    print(len(hits[0])/maxV*100)