In [1]:
import torch
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.utilities.data_splitting import get_split_deterministic
from torch import optim

from brats_data_loader import BRATSDataLoader, get_list_of_patients, get_train_transform
from jonas_net import AlbuNet3D34
from train_test_function import ModelTrainer

In [2]:
# Where the preprocessed BraTS 2017 training cases live, and the patient list read from it.
data_dir = 'brats_data_preprocessed/Brats17TrainingData'
patients = get_list_of_patients(data_dir)

# Input modalities handed to BRATSDataLoader as `in_channels`.
in_channels = ['t1c', 't2', 'flair']
# Patch size shared by the loaders and the training transform.
patch_size = [24, 128, 128]
batch_size = 2  # 24 was noted as an alternative value

In [3]:
# num_splits=5 means 1/5th is validation data!
patients_train, patients_val = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)

In [4]:
# Train and validation loaders share every setting except the patient list.
loader_kwargs = dict(
    batch_size=batch_size,
    patch_size=patch_size,
    in_channels=in_channels,
)
train_dl = BRATSDataLoader(patients_train, **loader_kwargs)
val_dl = BRATSDataLoader(patients_val, **loader_kwargs)

In [5]:
# Training augmentation pipeline for the configured patch size. It is handed only to the
# training generator; the validation generator is built with `None` transforms.
tr_transforms = get_train_transform(patch_size)

In [6]:
# Wrap the loaders in MultiThreadedAugmenter so batches are produced by background worker
# processes. pin_memory stays off (original note: pinning is PyTorch-specific).
# TODO(review): seeds=None fixes no per-worker seeds, so augmentation is presumably not
# reproducible across runs; consider passing explicit seeds.
num_processes = 4

tr_gen = MultiThreadedAugmenter(train_dl, tr_transforms, num_processes=num_processes,
                                num_cached_per_queue=3,
                                seeds=None, pin_memory=False)
# Validation applies no transforms (None), so fewer worker processes are needed.
val_gen = MultiThreadedAugmenter(val_dl, None,
                                 num_processes=max(1, num_processes // 2),
                                 num_cached_per_queue=1,
                                 seeds=None,
                                 pin_memory=False)

In [7]:
# (Re)start both generators, training first, before the training loop consumes them.
for generator in (tr_gen, val_gen):
    generator.restart()

## Start Training

In [8]:
def dice(outputs, targets, threshold=0.0, foreground_label=1, smooth=1e-15):
    """Hard (non-differentiable) Dice score between thresholded predictions and a target mask.

    Args:
        outputs: raw network outputs, binarized with ``outputs > threshold``. With the default
            threshold of 0 this equals ``sigmoid(outputs) > 0.5`` if ``outputs`` are logits
            (assumption; confirm the network emits logits).
        targets: label tensor. Elements equal to ``foreground_label`` count as foreground and
            every other value counts as background.
        threshold: decision threshold applied to ``outputs``.
        foreground_label: label value treated as foreground in ``targets``.
        smooth: small constant that keeps the ratio defined when both masks are empty.

    Returns:
        0-dim tensor in [0, 1], pooled over every element of the batch (not averaged per
        sample). Two empty masks score 1.0.
    """
    pred_fg = (outputs > threshold).float()
    true_fg = (targets == foreground_label).float()

    intersection = (pred_fg * true_fg).sum()
    denominator = (pred_fg + true_fg).sum()

    # Add `smooth` once to numerator and denominator. The previous form
    # 2 * (I + smooth) / (U + smooth) returned 2.0 for two empty masks.
    return (2 * intersection + smooth) / (denominator + smooth)

In [9]:
class SimpleDiceLoss:
    """Differentiable (soft) Dice loss: ``1 - dice`` computed on sigmoid probabilities.

    Args:
        smooth: small constant that keeps the ratio defined when both masks are empty.
        foreground_label: label value treated as foreground in ``targets``.
    """

    def __init__(self, smooth=1e-15, foreground_label=1):
        self.smooth = smooth
        self.foreground_label = foreground_label

    def __call__(self, outputs, targets):
        """Return ``1 - soft_dice(sigmoid(outputs), targets == foreground_label)``.

        ``outputs`` go through ``torch.sigmoid``, so they are expected to be logits (assumption;
        confirm against the network head). Overlap is pooled over the whole batch. The result
        lies in [0, 1], and two empty masks give 0.
        """
        probs = torch.sigmoid(outputs)
        true_fg = (targets == self.foreground_label).float()

        intersection = (probs * true_fg).sum()
        denominator = (probs + true_fg).sum()

        # `smooth` is added once on each side. The previous form 2 * (I + s) / (U + s) could
        # exceed 1, and so give a negative loss, when prediction and target are near-empty.
        soft_dice = (2 * intersection + self.smooth) / (denominator + self.smooth)
        return 1 - soft_dice

In [10]:
# 3D AlbuNet variant from jonas_net. pretrained=True presumably loads pretrained encoder
# weights, and is_deconv=True presumably selects transposed-convolution upsampling; confirm.
net_3d = AlbuNet3D34(is_deconv=True, pretrained=True)

In [None]:
# Learning-rate / weight-decay reference points:
#   - Wang: lr 1e-3, weight decay 1e-7
#   - Isensee: lr 5e-4 (1e-4 * 5), decayed by 0.985 every epoch, weight decay 1e-5
#   - original AlbuNet: lr from 1e-3 to 1e-4
# Earlier note: "before we went from 1e-2 to 1e-1" (unclear whether a schedule or a change).
# NOTE(review): this optimizer is never passed to ModelTrainer below, which receives its own
# lr=0.001. This cell also has no execution count (In [None]), so it did not run in the
# recorded session. Either wire it into the trainer or delete it; its lr of 1e-2 also
# disagrees with the trainer's.
optimizer = optim.Adam(
    params=net_3d.parameters(),
    lr=0.01,
    weight_decay=1e-06,
)

In [16]:
# Trainer wiring: run name, network, train/val generators, loss (soft Dice) and metric (hard Dice).
# NOTE(review): lr=0.001 is passed here independently of the `optimizer` cell; ModelTrainer
# presumably builds its own optimizer, so confirm.
model_trainer = ModelTrainer(
    'jonas_net_3d',
    net_3d,
    tr_gen,
    val_gen,
    SimpleDiceLoss(),
    dice,
    lr=0.001,
    epochs=10,
    num_batches_per_epoch=10,
    num_validation_batches_per_epoch=3,
)

In [17]:
# Run training. The recorded output shows one validation pass before epoch 1, then per-epoch
# train/validation "Avg. Loss" and "Avg. Metric" lines.
# NOTE(review): the final number in the output (2214.5...) is presumably the total run time in
# seconds; confirm against ModelTrainer.run. The loss stays around 0.9-1.0 across all 10 epochs,
# so this run shows no clear learning progress. "Avg. Metric" is presumably the `dice` score
# printed with a "%" suffix; confirm the formatting in ModelTrainer.
model_trainer.run()

[Val] Avg. Loss: 0.95, Avg. Metric: 0.05%

# Epoch 1 #

[Train] Avg. Loss: 0.97, Avg. Metric: 0.06%
[Val] Avg. Loss: 0.96, Avg. Metric: 0.22%

# Epoch 2 #

[Train] Avg. Loss: 0.94, Avg. Metric: 0.20%
[Val] Avg. Loss: 1.00, Avg. Metric: 0.00%

# Epoch 3 #

[Train] Avg. Loss: 0.98, Avg. Metric: 0.08%
[Val] Avg. Loss: 0.89, Avg. Metric: 0.30%

# Epoch 4 #

[Train] Avg. Loss: 0.97, Avg. Metric: 0.11%
[Val] Avg. Loss: 0.96, Avg. Metric: 0.08%

# Epoch 5 #

[Train] Avg. Loss: 0.99, Avg. Metric: 0.06%
[Val] Avg. Loss: 0.94, Avg. Metric: 0.15%

# Epoch 6 #

[Train] Avg. Loss: 0.95, Avg. Metric: 0.18%
[Val] Avg. Loss: 0.99, Avg. Metric: 0.03%

# Epoch 7 #

[Train] Avg. Loss: 0.97, Avg. Metric: 0.10%
[Val] Avg. Loss: 0.99, Avg. Metric: 0.11%

# Epoch 8 #

[Train] Avg. Loss: 0.90, Avg. Metric: 0.29%
[Val] Avg. Loss: 0.97, Avg. Metric: 0.15%

# Epoch 9 #

[Train] Avg. Loss: 0.97, Avg. Metric: 0.06%
[Val] Avg. Loss: 0.97, Avg. Metric: 0.08%

# Epoch 10 #

[Train] Avg. Loss: 0.92, Avg. Metric: 0.21%

2214.5091140270233