In [1]:
from brats_data_loader import get_list_of_patients, get_train_transform, iterate_through_patients, BRATSDataLoader
from train_test_function import ModelTrainer
from jonas_net import AlbuNet3D34

from batchgenerators.utilities.data_splitting import get_split_deterministic
from batchgenerators.dataloading import MultiThreadedAugmenter

In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

In [3]:
# Preprocessed BraTS 2017 training cases (one entry per patient).
patients = get_list_of_patients('brats_data_preprocessed/Brats17TrainingData')

# Patch-based training: every batch holds 24 patches of size 24 x 128 x 128.
batch_size = 24
patch_size = [24, 128, 128]

# MRI modalities used as network input channels.
in_channels = ['t1c', 't2', 'flair']

In [4]:
# Deterministic 5-fold split: fold 0 holds out one fifth of the patients for validation.
patients_train, patients_val = get_split_deterministic(
    patients,
    fold=0,
    num_splits=5,
    random_state=12345,
)

In [5]:
# BraTS 2018 validation cases, loaded as a separate test list.
patients_test = get_list_of_patients(
    'brats_data_preprocessed/Brats18ValidationData'
)

In [6]:
# Training and validation loaders differ only in the patient list they draw from.
train_dl = BRATSDataLoader(patients_train, batch_size=batch_size,
                           patch_size=patch_size, in_channels=in_channels)

val_dl = BRATSDataLoader(patients_val, batch_size=batch_size,
                         patch_size=patch_size, in_channels=in_channels)

In [7]:
# Augmentation pipeline for training patches (built for the patch size above).
tr_transforms = get_train_transform(
    patch_size
)

In [8]:
# Wrap the loaders in multi-process augmenters that are used for training.
# pin_memory stays off: pinning is a PyTorch-specific concern.
tr_gen = MultiThreadedAugmenter(
    train_dl,
    tr_transforms,
    num_processes=4,
    num_cached_per_queue=3,
    seeds=None,
    pin_memory=False,
)

# Validation applies no transformations (transform=None), so it needs fewer
# worker processes: half of the 4 used for training.
val_gen = MultiThreadedAugmenter(
    val_dl,
    None,
    num_processes=2,  # == max(1, 4 // 2)
    num_cached_per_queue=1,
    seeds=None,
    pin_memory=False,
)

In [9]:
# (Re)start both augmenters, training first, then validation.
for _augmenter in (tr_gen, val_gen):
    _augmenter.restart()

## Start Training

In [44]:
# Sanity check of the foreground mask used by the dice functions below:
# labels 1 and 3 count as foreground (presumably the tumor-core classes after
# preprocessing -- TODO confirm), every other label as background.
# The example tensor is defined here so the cell also runs on a fresh kernel.
example_labels = torch.tensor([[1, 3], [1, 0]])
(example_labels == 1) | (example_labels == 3)

tensor([[1, 1],
        [1, 0]], dtype=torch.uint8)

In [10]:
def dice(outputs, targets, smooth=1e-15):
    """Hard (thresholded) dice score between network logits and a label map.

    Not differentiable; used as the reported metric during training.

    Args:
        outputs: Raw network outputs (logits). A voxel is predicted foreground
            when its logit is > 0, i.e. when sigmoid(logit) > 0.5.
        targets: Integer label tensor. Labels 1 and 3 are foreground, every
            other label is background.
        smooth: Small constant keeping the ratio defined when both masks are
            empty.

    Returns:
        Scalar tensor in [0, 1]; exactly 1.0 when prediction and target are
        both empty.
    """
    pred_fg = (outputs > 0).float()
    true_fg = ((targets == 1) | (targets == 3)).float()

    intersection = (pred_fg * true_fg).sum()
    union = pred_fg.sum() + true_fg.sum()

    # `smooth` is added once to numerator and denominator. The previous form
    # 2 * (I + s) / (U + s) evaluated to 2.0 for two empty masks.
    return (2 * intersection + smooth) / (union + smooth)

In [11]:
# Differentiable version of the dice metric
class SimpleDiceLoss():
    """Soft (differentiable) dice loss, returned as ``1 - soft_dice``.

    Unlike a thresholded dice metric, the logits are squashed with a sigmoid
    instead of being binarised, so the loss has gradients w.r.t. the outputs.
    Labels 1 and 3 in the target are treated as foreground.
    """

    def __init__(self, smooth=1e-15):
        # Keeps the ratio defined when prediction and target are both empty.
        self.smooth = smooth

    def __call__(self, outputs, targets):
        """Compute the loss.

        Args:
            outputs: Raw network outputs (logits).
            targets: Integer label tensor (labels 1 and 3 are foreground).

        Returns:
            Scalar tensor in [0, 1]; 0 for a perfect (or jointly empty) match.
        """
        probs = torch.sigmoid(outputs)
        true_fg = ((targets == 1) | (targets == 3)).float()

        intersection = (probs * true_fg).sum()
        union = probs.sum() + true_fg.sum()

        # Smooth once in numerator and denominator; the previous
        # 2 * (I + s) / (U + s) gave a loss of -1 for two empty masks.
        soft_dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - soft_dice

In [12]:
# 3D AlbuNet variant with pretrained weights; is_deconv=True presumably
# selects transposed-convolution upsampling -- confirm in jonas_net.
net_3d = AlbuNet3D34(is_deconv=True, pretrained=True)

In [13]:
# previously we used learning rates ranging from 1e-2 to 1e-1
# Wang uses 1e-3; Isensee uses 5e-4 (1e-4 * 5) and decays it by a factor of 0.985 every epoch; the original AlbuNet goes from 1e-3 to 1e-4
# Wang uses 1e-7 weight decay, Isensee 1e-5
# optimizer = optim.Adam(net_3d.parameters(), lr=1e-2, weight_decay=1e-6)

In [14]:
# Trainer: soft dice loss for optimisation, thresholded dice as reported metric.
model_trainer = ModelTrainer(
    'jonas_net_3d_brats17_pretr',
    net_3d,
    tr_gen,
    val_gen,
    SimpleDiceLoss(),
    dice,
    lr=1e-4,
    epochs=50,
    num_batches_per_epoch=100,
    num_validation_batches_per_epoch=100,
    use_gpu=True,
)

In [None]:
# Run: proposed augmentations, pretrained network, BraTS 2017 data.
# Settings: lr=1e-4, epochs=50, 100 train / 100 validation batches per epoch,
# batch_size=24, patch_size=[24, 128, 128]; took roughly 4.5 hours.
# NOTE(review): the checkpoint name says "epochs_100" but the trainer is
# configured with epochs=50 -- rename the file if that is not intentional.
model_trainer.run()
model_trainer.save_model(
    'models/pr_augm_lr_0001_epochs_100_bs_24_brats17.pt'
)

[Val] Avg. Loss: 0.96, Avg. Metric: 0.04

# Epoch 1 #



In [14]:
# Run: proposed augmentations, pretrained network, labelled "2019".
# Settings: lr=1e-4, epochs=50, 100 train / 100 validation batches per epoch,
# batch_size=24, patch_size=[24, 128, 128]; took roughly 4.5 hours.
# NOTE(review): the data cell at the top loads Brats17TrainingData, so a fresh
# top-to-bottom run does not reproduce a 2019 run; the checkpoint name also
# says "epochs_100" while the trainer uses epochs=50.
model_trainer.run()
model_trainer.save_model(
    'models/pr_augm_lr_0001_epochs_100_bs_24.pt'
)

[Val] Avg. Loss: 0.96, Avg. Metric: 0.04

# Epoch 1 #

[Train] Avg. Loss: 0.95, Avg. Metric: 0.05
[Val] Avg. Loss: 0.92, Avg. Metric: 0.07

# Epoch 2 #

[Train] Avg. Loss: 0.93, Avg. Metric: 0.08
[Val] Avg. Loss: 0.92, Avg. Metric: 0.13

# Epoch 3 #

[Train] Avg. Loss: 0.93, Avg. Metric: 0.18
[Val] Avg. Loss: 0.91, Avg. Metric: 0.27

# Epoch 4 #

[Train] Avg. Loss: 0.93, Avg. Metric: 0.26
[Val] Avg. Loss: 0.91, Avg. Metric: 0.46

# Epoch 5 #

[Train] Avg. Loss: 0.93, Avg. Metric: 0.52
[Val] Avg. Loss: 0.90, Avg. Metric: 0.64

# Epoch 6 #

[Train] Avg. Loss: 0.93, Avg. Metric: 0.61
[Val] Avg. Loss: 0.90, Avg. Metric: 0.53

# Epoch 7 #

[Train] Avg. Loss: 0.92, Avg. Metric: 0.66
[Val] Avg. Loss: 0.90, Avg. Metric: 0.69

# Epoch 8 #

[Train] Avg. Loss: 0.91, Avg. Metric: 0.65
[Val] Avg. Loss: 0.90, Avg. Metric: 0.72

# Epoch 9 #

[Train] Avg. Loss: 0.91, Avg. Metric: 0.67
[Val] Avg. Loss: 0.89, Avg. Metric: 0.72

# Epoch 10 #

[Train] Avg. Loss: 0.90, Avg. Metric: 0.69
[Val] Avg. Loss: 0.

In [62]:
# Restore the checkpoint saved by the "2019" training run.
model_trainer.load_model(
    'models/pr_augm_lr_0001_epochs_100_bs_24.pt'
)

In [63]:
try:
    import SimpleITK as sitk
except ImportError:
    print("You need to have SimpleITK installed to run this example!")
    raise ImportError("SimpleITK not found")

def save_segmentation_as_nifti(segmentation, metadata, output_file):
    """Paste a cropped segmentation back into the original grid and save it.

    Args:
        segmentation: 3D array covering only the ``nonzero_region`` crop of
            the original volume; its shape must match that region's extent.
        metadata: dict providing 'original_shape', 'nonzero_region' (indexed
            as [axis, 0] = start and [axis, 1] = end, end treated as
            inclusive), 'direction', 'origin' and 'spacing' (a numpy array in
            numpy axis order; reversed here for SimpleITK).
        output_file: Destination path; missing parent directories are
            created first.
    """
    original_shape = metadata['original_shape']
    seg_original_shape = np.zeros(original_shape, dtype=np.uint8)
    nonzero = metadata['nonzero_region']
    # End indices are inclusive, hence the +1.
    seg_original_shape[nonzero[0, 0]:nonzero[0, 1] + 1,
                       nonzero[1, 0]:nonzero[1, 1] + 1,
                       nonzero[2, 0]:nonzero[2, 1] + 1] = segmentation
    sitk_image = sitk.GetImageFromArray(seg_original_shape)
    sitk_image.SetDirection(metadata['direction'])
    sitk_image.SetOrigin(metadata['origin'])
    # remember to revert spacing back to sitk order again
    sitk_image.SetSpacing(tuple(metadata['spacing'][[2, 1, 0]]))

    # Output paths include a per-dataset sub-folder; make sure it exists so a
    # fresh checkout does not fail on the first write.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    print(output_file)
    sitk.WriteImage(sitk_image, output_file)

In [64]:
def np_dice(outputs, targets, smooth=1e-15):
    """NumPy version of the thresholded dice score.

    Args:
        outputs: Array of logits (or already-binarised 0/1 predictions);
            values > 0 count as foreground.
        targets: Integer label array; labels 1 and 3 are foreground.
        smooth: Small constant keeping the ratio defined when both masks are
            empty.

    Returns:
        Dice in [0, 1]; exactly 1.0 when prediction and target are both empty.
    """
    pred_fg = np.float32(outputs > 0)
    true_fg = np.float32((targets == 1) | (targets == 3))

    intersection = np.sum(pred_fg * true_fg)
    union = np.sum(pred_fg) + np.sum(true_fg)

    # Smooth once in numerator and denominator; the previous
    # 2 * (I + s) / (U + s) returned 2.0 for two empty masks.
    return (2 * intersection + smooth) / (union + smooth)

In [80]:
import skimage

def predict_patient_in_patches(patient_data, model, pad_shape=(144, 192, 192),
                               window_shape=(1, 3, 24, 128, 128),
                               steps=(1, 1, 24, 32, 32)):
    """Sliding-window prediction over a whole patient volume.

    The volume is padded, cut into overlapping windows, and each window is
    passed through ``model``. The caller is responsible for ``model.eval()``
    and ``torch.no_grad()``.

    Args:
        patient_data: numpy array shaped (1, C, S1, S2, S3): a batch of one,
            C channels, three spatial dims (the evaluation loop prints e.g.
            (1, 4, 138, 169, 141)). Only the first ``window_shape[1]``
            channels are fed to the network.
        model: network mapping a (1, window_shape[1], *window) tensor to a
            single-channel output of the same spatial size.
        pad_shape: minimum padded spatial shape. The volume is padded further
            when needed so the windows cover every voxel.
        window_shape: 5D window (batch, channels, *spatial patch size).
        steps: 5D stride between neighbouring windows.

    Returns:
        Tensor shaped (1, 1, *padded_spatial_shape) on the model's device.
        Overlapping windows are summed, not averaged: the sign (and hence a
        > 0 threshold) matches the mean logit, but magnitudes do not.
    """
    import itertools

    # Imported here so the function does not depend on a later cell's import.
    from batchgenerators.augmentations.utils import pad_nd_image

    # Each padded size must be at least `pad_shape` and of the form
    # window + k * step, so that the window grid tiles the whole volume.
    padded_shape = []
    for size, minimum, win, step in zip(patient_data.shape[2:], pad_shape,
                                        window_shape[2:], steps[2:]):
        size = max(size, minimum)
        n_steps = max(0, -(-(size - win) // step))  # ceil((size - win) / step)
        padded_shape.append(win + n_steps * step)
    patient_data_pd = pad_nd_image(patient_data, padded_shape)

    # patches has shape (1, 1, nz, ny, nx, *window_shape): one window per grid cell.
    n_channels = window_shape[1]
    patches = skimage.util.view_as_windows(patient_data_pd[:, :n_channels],
                                           window_shape=window_shape, step=steps)

    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    prediction_shape = list(patient_data_pd.shape)
    prediction_shape[1] = 1  # only one output channel
    prediction = torch.zeros(prediction_shape, device=device)

    for i, j, k in itertools.product(*(range(n) for n in patches.shape[2:5])):
        # view_as_windows returns a read-only strided view; copy it into a
        # contiguous, writable buffer before handing it to torch.
        window = np.ascontiguousarray(patches[0, 0, i, j, k])
        # Call the module (not .forward) so registered hooks run.
        output = model(torch.from_numpy(window).to(device))

        z0, y0, x0 = i * steps[2], j * steps[3], k * steps[4]
        prediction[:, :,
                   z0:z0 + window_shape[2],
                   y0:y0 + window_shape[3],
                   x0:x0 + window_shape[4]] += output

    return prediction

In [81]:
from batchgenerators.augmentations.utils import pad_nd_image
from batchgenerators.augmentations.utils import center_crop_3D_image

# Evaluate on the validation split: predict each patient with sliding windows,
# binarise, score against the label channel and save the result as NIfTI.
target_patients = patients_val

dices = []

# Switch to inference mode once; predict_patient_in_patches does not change it.
model_trainer.model.eval()

# The label map is appended after the input modalities, so its channel index
# equals the number of input channels.
seg_channel = len(in_channels)

for idx, (patient_data, meta_data) in enumerate(iterate_through_patients(target_patients, in_channels + ['seg'])):
    print(patient_data.shape)

    with torch.no_grad():
        prediction = predict_patient_in_patches(patient_data, model_trainer.model)

    # Summed logits > 0 -> foreground (tumor core), everything else -> 0.
    np_prediction = (prediction.cpu().numpy() > 0).astype(np.float32)

    np_cut = center_crop_3D_image(np_prediction[0, 0], patient_data.shape[2:])

    # Named `patient_dice` so it does not overwrite the `dice` metric function
    # that the ModelTrainer cell passes by name.
    patient_dice = np_dice(np_cut, patient_data[0, seg_channel])
    print(idx, patient_dice)
    dices.append(patient_dice)

    output_path = '/'.join(target_patients[idx].split('/')[-2:])
    save_segmentation_as_nifti(np_cut, meta_data, os.path.join('segmentation_output', output_path + '.nii.gz'))

print('Mean:', np.mean(np.array(dices)))

(1, 4, 138, 169, 141)
0 0.8375910612325261
segmentation_output/Brats19TrainingData/BraTS19_2013_19_1.nii.gz
(1, 4, 132, 175, 146)
1 0.6555020584746354
segmentation_output/Brats19TrainingData/BraTS19_2013_1_1.nii.gz
(1, 4, 130, 165, 141)
2 0.928977473758698
segmentation_output/Brats19TrainingData/BraTS19_2013_24_1.nii.gz
(1, 4, 128, 180, 141)
3 0.9455917394757745
segmentation_output/Brats19TrainingData/BraTS19_2013_26_1.nii.gz
(1, 4, 142, 166, 149)
4 0.9275800998228949
segmentation_output/Brats19TrainingData/BraTS19_2013_27_1.nii.gz
(1, 4, 133, 161, 149)
5 0.16364931343481356
segmentation_output/Brats19TrainingData/BraTS19_2013_28_1.nii.gz
(1, 4, 129, 179, 142)
6 0.9222455013608506
segmentation_output/Brats19TrainingData/BraTS19_2013_2_1.nii.gz
(1, 4, 132, 159, 137)
7 0.9517716747442111
segmentation_output/Brats19TrainingData/BraTS19_2013_7_1.nii.gz
(1, 4, 140, 187, 134)
8 0.9572199301808949
segmentation_output/Brats19TrainingData/BraTS19_CBICA_AAG_1.nii.gz
(1, 4, 140, 182, 135)
9 0.928

TypeError: Axis must be specified when shapes of a and weights differ.