New losses: focal loss & generalised dice loss #46

Merged: 39 commits, Sep 30, 2019
Changes shown from 26 commits.

Commits
56f6f6e
ivadomed/losses.py: add focal loss
charleygros Sep 9, 2019
e57ab02
README.md: add about focal loss
charleygros Sep 9, 2019
b5cced0
ivadomed/main.py: add focal loss
charleygros Sep 9, 2019
bc6f13e
ivadomed/main.py: add Dice score in terminal output
charleygros Sep 9, 2019
f12e502
ivadomed/losses.py: add focal loss
charleygros Sep 9, 2019
37a6deb
solve conflicts
charleygros Sep 10, 2019
646e922
ivadomed/main.py: add dice loss while using focal loss
charleygros Sep 10, 2019
ba77fad
solve conflicts
charleygros Sep 10, 2019
1efbc3b
configs: loss as a dict
charleygros Sep 10, 2019
c61a11a
README.md: doc about loss dict
charleygros Sep 10, 2019
a78ab78
ivadomed/main.py: use loss_param
charleygros Sep 10, 2019
bb258dc
ivadomed/losses.py: implement gdl
charleygros Sep 11, 2019
6ccdb6a
README.md: add gdl in doc
charleygros Sep 11, 2019
7853878
ivadomed/main.py: add gdl as loss
charleygros Sep 11, 2019
835bf0d
Merge branch 'master' of https://github.com/neuropoly/ivado-medical-i…
charleygros Sep 16, 2019
b69bd85
ivadomed/losses.py: implement dice loss
charleygros Sep 16, 2019
4833bf2
ivadomed/main.py: add new dice loss
charleygros Sep 16, 2019
5ec8bf1
ivadomed/losses.py: add mixed loss
charleygros Sep 16, 2019
8c267d3
ivadomed/losses.py: rename MixedLoss to FocalDiceLoss
charleygros Sep 16, 2019
e68bea5
README.md: doc about focal_dice
charleygros Sep 16, 2019
21a5ab1
ivadomed/main.py: cleanup
charleygros Sep 16, 2019
134cff9
ivadomed/main.py: display log dice
charleygros Sep 17, 2019
5554966
ivadomed/utils.py: implement dice_score with epsilon
charleygros Sep 18, 2019
3cc81c5
ivadomed/main.py: dice loss negative
charleygros Sep 18, 2019
3faf1b5
ivadomed/main.py: display exp log dice
charleygros Sep 18, 2019
92f2c5f
README.md: soft dice loss for mixup
charleygros Sep 18, 2019
298d25c
Fix metrics error in test command
olix86 Sep 18, 2019
dda78c0
ivadomed/utils: add eps to the numerator of dice score
charleygros Sep 18, 2019
c3e6159
config/config_sctTesting.json: add new subjects 19-09-19
charleygros Sep 19, 2019
4e8e801
added new contrasts to contrast_dct.json
olix86 Sep 23, 2019
1b1b60d
remove trailing comma
olix86 Sep 23, 2019
d982b88
ivadomed/utils.py: dice_score return None if im_sum empty
charleygros Sep 27, 2019
680dc94
ivadomed/utils.py: get_results overidding
charleygros Sep 28, 2019
58ae224
ivadomed/main.py: call IvadoMetricManager
charleygros Sep 28, 2019
0b9a68d
ivadomed/utils.py: override __call__
charleygros Sep 28, 2019
f5275aa
ivadomed/utils.py: return None if all np.nan
charleygros Sep 30, 2019
52da767
cleanup
charleygros Sep 30, 2019
bf27284
config/config_sctTesting.json: update contrast list
charleygros Sep 30, 2019
dc90860
Add new contrasts to contrast_dct
olix86 Sep 30, 2019
2 changes: 1 addition & 1 deletion README.md
@@ -47,7 +47,7 @@ Please find below a description of each parameter:
- `batch_norm_momentum`: float (e.g. 0.1).
- `num_epochs`: int.
- `initial_lr`: initial learning rate.
- `loss`: choice between 'dice' and 'cross_entropy'. Note: Please use 'cross_entropy' when comparing `Unet` vs. `MixedUp-Unet`.
- `loss`: dictionary with a key `'name'` for the choice between `'dice'`, `'focal'`, `'focal_dice'`, `'gdl'` and `'cross_entropy'`, and an optional key `'params'` (e.g. `{"name": "focal", "params": {"gamma": 0.5}}`).
- `log_directory`: folder name where log files are saved.
- `film_layers`: indicates on which layer(s) of the U-net you want to apply a FiLM modulation: list of 8 elements (because Unet has 8 layers), set to 0 for no FiLM modulation, set 1 otherwise. Note: When running `Unet` or `MixedUp-Unet`, please fill this list with zeros only.
- `mixup_bool`: indicates if mixup is applied to the training data (choice: `false` or `true`). Note: Please use `false` when comparing `Unet` vs. `FiLMed-Unet`.
4 changes: 2 additions & 2 deletions config/config.json
@@ -13,9 +13,9 @@
    "batch_norm_momentum": 0.1,
    "num_epochs": 100,
    "initial_lr": 0.001,
    "loss": "dice",
    "loss": {"name": "dice"},
    "log_directory": "log_sc",
    "film_layers": [1, 0, 0, 0, 0, 0, 0, 0],
    "film_layers": [0, 0, 0, 0, 0, 0, 0, 0],
    "mixup_bool": false,
    "mixup_alpha": 2,
    "metadata": "contrast",
2 changes: 1 addition & 1 deletion config/config_gm.json
@@ -13,7 +13,7 @@
    "batch_norm_momentum": 0.1,
    "num_epochs": 500,
    "initial_lr": 0.001,
    "loss": "dice",
    "loss": {"name": "dice"},
    "log_directory": "log_gm",
    "film_layers": [0, 0, 0, 0, 0, 0, 0, 0],
    "mixup_bool": false,
10 changes: 5 additions & 5 deletions config/config_sctTesting.json
@@ -8,16 +8,16 @@
    "contrast_balance": {},
    "contrast_test": ["T2w", "T1w", "acq-axmid00100_T2w", "acq-axbottom_T2w", "acq-sagtsp_T2w", "acq-sagstir_T2w", "acq-axtop_T2w", "acq-ax00014_T2w", "acq-axtsp_T2w", "acq-axtop00100_T2w", "acq-axcsp_T2w", "acq-sup_T2star", "acq-ax_T2w", "acq-sagcsp_T2w", "acq-axmid00005_T2w", "acq-sagstirtsp_T2w", "acq-ax00012_T2w", "acq-axlow_T2w", "T2star", "acq-sagstircsp_T2w", "acq-sag_T2w", "acq-sag00015_T2w", "acq-inf_T2star"],
    "center_test": [],
    "batch_size": 18,
    "batch_size": 32,
    "dropout_rate": 0.3,
    "batch_norm_momentum": 0.1,
    "num_epochs": 100,
    "initial_lr": 0.001,
    "loss": "dice",
    "initial_lr": 0.0001,
    "loss": {"name": "focal_dice", "params": {"gamma": 0.4, "alpha": 10}},
    "log_directory": "log_large",
    "film_layers": [1, 0, 0, 0, 0, 0, 0, 0],
    "film_layers": [0, 0, 0, 0, 0, 0, 0, 0],
    "mixup_bool": false,
    "mixup_alpha": 2,
    "metadata": "contrast",
    "metadata": "without",
    "debugging": false
}
10 changes: 5 additions & 5 deletions config/config_small.json
@@ -2,7 +2,7 @@
    "command": "train",
    "gpu": 4,
    "gt_suffix": "_seg-manual",
    "bids_path_train": "../duke/projects/ivado-medical-imaging/spineGeneric_201907041011/result/",
    "bids_path": "../duke/projects/ivado-medical-imaging/spineGeneric_201907041011/result/",
    "random_seed": 1313,
    "contrast_train_validation": ["T1w", "T2w", "T2star", "acq-MToff_MTS", "acq-MTon_MTS", "acq-T1w_MTS"],
    "contrast_balance": {},
@@ -13,11 +13,11 @@
    "batch_norm_momentum": 0.1,
    "num_epochs": 100,
    "initial_lr": 0.001,
    "loss": "dice",
    "loss": {"name": "dice"},
    "log_directory": "log_sc_small",
    "film_layers": [1, 0, 0, 0, 0, 0, 0, 0],
    "mixup_bool": false,
    "film_layers": [0, 0, 0, 0, 0, 0, 0, 0],
    "mixup_bool": true,
    "mixup_alpha": 2,
    "metadata": "contrast",
    "metadata": "without",
    "debugging": false
}
82 changes: 82 additions & 0 deletions ivadomed/losses.py
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def dice_loss(input, target):
    # Note: returns the Dice *coefficient* (higher is better); callers
    # negate it to use it as a loss.
    # input = torch.sigmoid(input)
    smooth = 1.0

    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))


class FocalLoss(nn.Module):
    """
    Focal Loss: https://arxiv.org/abs/1708.02002
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

        # Numerically stable BCE-with-logits: max(x, 0) - x * y + log(1 + exp(-|x|))
        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()

        # This gives us the log sigmoid of 1-p if y is 0 and of p if y is 1
        invprobs = F.logsigmoid(-input * (target * 2 - 1))
        loss = (invprobs * self.gamma).exp() * loss

        # Note: works in log space to be numerically stable (ie to avoid NaNs when training).
        return loss.mean()


class FocalDiceLoss(nn.Module):
    """
    Motivated by https://arxiv.org/pdf/1809.00076.pdf
    :param alpha: weight to bring the dice and focal losses to a similar scale.
    :param gamma: gamma value used in the focal loss.
    """
    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        loss = self.alpha * self.focal(input, target) - torch.log(dice_loss(input, target))
        return loss.mean()


class GeneralizedDiceLoss(nn.Module):
    """
    Generalized Dice Loss: https://arxiv.org/pdf/1707.03237
    """
    def __init__(self, epsilon=1e-5):
        super(GeneralizedDiceLoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

        input = input.view(-1)
        target = target.view(-1)

        target = target.float()
        target_sum = target.sum(-1)
        # Weights inversely proportional to the squared target volume (Sudre et al., 2017)
        class_weights = nn.Parameter(1. / (target_sum * target_sum).clamp(min=self.epsilon))

        intersect = (input * target).sum(-1) * class_weights
        intersect = intersect.sum()

        denominator = ((input + target).sum(-1) * class_weights).sum()

        return 1. - 2. * intersect / denominator.clamp(min=self.epsilon)

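For reference, a minimal usage sketch of the new module (a sanity check, not part of the PR; tensor shapes are illustrative, and the gamma/alpha values mirror the config_sctTesting.json settings above):

import torch

from ivadomed import losses

# Toy predictions/ground truth; in cmd_train these come from the model and the loader.
preds = torch.rand(2, 1, 64, 64)
target = (torch.rand(2, 1, 64, 64) > 0.5).float()

print(-losses.dice_loss(preds, target))                          # negated Dice coefficient, as used in cmd_train
print(losses.FocalLoss(gamma=0.4)(preds, target))                # stable focal BCE, computed in log space
print(losses.FocalDiceLoss(gamma=0.4, alpha=10)(preds, target))  # alpha * focal - log(dice)
print(losses.GeneralizedDiceLoss()(preds, target))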
76 changes: 58 additions & 18 deletions ivadomed/main.py
@@ -5,6 +5,8 @@
import shutil
import random
import joblib
from math import exp
import numpy as np

import torch
import torch.nn as nn
@@ -16,7 +18,6 @@

from medicaltorch import transforms as mt_transforms
from medicaltorch import datasets as mt_datasets
from medicaltorch import losses as mt_losses
from medicaltorch import filters as mt_filters
from medicaltorch import metrics as mt_metrics

@@ -26,10 +27,9 @@

from ivadomed import loader as loader
from ivadomed import models
from ivadomed import losses
from ivadomed.utils import *

import numpy as np

cudnn.benchmark = True


@@ -176,24 +176,36 @@ def cmd_train(context):
    var_contrast_list = []

    # Loss
    if context["loss"] in ["dice", "cross_entropy"]:
        if context["loss"] == "cross_entropy":
    if context["loss"]["name"] in ["dice", "cross_entropy", "focal", "gdl", "focal_dice"]:
        if context["loss"]["name"] == "cross_entropy":
            loss_fct = nn.BCELoss()
        elif context["loss"]["name"] == "focal":
            loss_fct = losses.FocalLoss(gamma=context["loss"]["params"]["gamma"])
            print("\nLoss function: {}, with gamma={}.\n".format(context["loss"]["name"], context["loss"]["params"]["gamma"]))
        elif context["loss"]["name"] == "gdl":
            loss_fct = losses.GeneralizedDiceLoss()
        elif context["loss"]["name"] == "focal_dice":
            loss_fct = losses.FocalDiceLoss(gamma=context["loss"]["params"]["gamma"], alpha=context["loss"]["params"]["alpha"])
            print("\nLoss function: {}, with gamma={} and alpha={}.\n".format(context["loss"]["name"], context["loss"]["params"]["gamma"], context["loss"]["params"]["alpha"]))
            focal_loss_fct = losses.FocalLoss(gamma=context["loss"]["params"]["gamma"])  # for tuning alpha

        if not context["loss"]["name"].startswith("focal"):
            print("\nLoss function: {}.\n".format(context["loss"]["name"]))

    else:
        print("Unknown Loss function, please choose between 'dice' or 'cross_entropy'")
        print("Unknown Loss function, please choose between 'dice', 'focal', 'focal_dice', 'gdl' or 'cross_entropy'")
        exit()

    # Training loop -----------------------------------------------------------
    best_validation_loss = float("inf")
    bce_loss = nn.BCELoss()
    for epoch in tqdm(range(1, num_epochs+1), desc="Training"):
        start_time = time.time()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        train_loss_total, dice_train_loss_total, focal_train_loss_total = 0.0, 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]
@@ -207,7 +219,6 @@ def cmd_train(context):
                    mixup_folder = os.path.join(context["log_directory"], 'mixup')
                    if not os.path.isdir(mixup_folder):
                        os.makedirs(mixup_folder)
                    print(lambda_tensor.data.numpy()[0])
                    random_idx = np.random.randint(0, input_samples.size()[0])
                    val_gt = np.unique(gt_samples.data.numpy()[random_idx,0,:,:])
                    mixup_fname_pref = os.path.join(mixup_folder, str(i).zfill(3)+'_'+str(lambda_tensor.data.numpy()[0])+'_'+str(random_idx).zfill(3)+'.png')
@@ -234,10 +245,15 @@ def cmd_train(context):
            else:
                preds = model(var_input)

            if context["loss"] == "dice":
                loss = mt_losses.dice_loss(preds, var_gt)
            if context["loss"]["name"] == "dice":
                loss = - losses.dice_loss(preds, var_gt)
            else:
                loss = loss_fct(preds, var_gt)
                if context["loss"]["name"] == "focal_dice":
                    focal_train_loss_total += focal_loss_fct(preds, var_gt).item()
                    dice_train_loss_total += torch.log(losses.dice_loss(preds, var_gt)).item()
                else:
                    dice_train_loss_total += losses.dice_loss(preds, var_gt).item()
            train_loss_total += loss.item()

            optimizer.zero_grad()
@@ -267,13 +283,23 @@ def cmd_train(context):
        train_loss_total_avg = train_loss_total / num_steps

        tqdm.write(f"Epoch {epoch} training loss: {train_loss_total_avg:.4f}.")
        if context["loss"]["name"] == 'focal_dice':
            focal_train_loss_total_avg = focal_train_loss_total / num_steps
            log_dice_train_loss_total_avg = dice_train_loss_total / num_steps
            dice_train_loss_total_avg = exp(log_dice_train_loss_total_avg)
            tqdm.write(f"\tFocal training loss: {focal_train_loss_total_avg:.4f}.")
            tqdm.write(f"\tLog Dice training loss: {log_dice_train_loss_total_avg:.4f}.")
            tqdm.write(f"\tDice training loss: {dice_train_loss_total_avg:.4f}.")
        elif context["loss"]["name"] != 'dice':
            dice_train_loss_total_avg = dice_train_loss_total / num_steps
            tqdm.write(f"\tDice training loss: {dice_train_loss_total_avg:.4f}.")

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        val_loss_total, dice_val_loss_total, focal_val_loss_total = 0.0, 0.0, 0.0
        num_steps = 0

        metric_fns = [mt_metrics.dice_score,
        metric_fns = [dice_score,  # from ivadomed/utils.py
                      mt_metrics.hausdorff_score,
                      mt_metrics.precision_score,
                      mt_metrics.recall_score,
@@ -304,11 +330,15 @@ def cmd_train(context):
            else:
                preds = model(var_input)

            # loss = mt_losses.dice_loss(preds, var_gt)
            if context["loss"] == "dice":
                loss = mt_losses.dice_loss(preds, var_gt)
            if context["loss"]["name"] == "dice":
                loss = - losses.dice_loss(preds, var_gt)
            else:
                loss = loss_fct(preds, var_gt)
                if context["loss"]["name"] == "focal_dice":
                    focal_val_loss_total += focal_loss_fct(preds, var_gt).item()
                    dice_val_loss_total += torch.log(losses.dice_loss(preds, var_gt)).item()
                else:
                    dice_val_loss_total += losses.dice_loss(preds, var_gt).item()
            val_loss_total += loss.item()

            # Metrics computation
@@ -369,6 +399,16 @@ def cmd_train(context):
        }, epoch)

        tqdm.write(f"Epoch {epoch} validation loss: {val_loss_total_avg:.4f}.")
        if context["loss"]["name"] == 'focal_dice':
            focal_val_loss_total_avg = focal_val_loss_total / num_steps
            log_dice_val_loss_total_avg = dice_val_loss_total / num_steps
            dice_val_loss_total_avg = exp(log_dice_val_loss_total_avg)
            tqdm.write(f"\tFocal validation loss: {focal_val_loss_total_avg:.4f}.")
            tqdm.write(f"\tLog Dice validation loss: {log_dice_val_loss_total_avg:.4f}.")
            tqdm.write(f"\tDice validation loss: {dice_val_loss_total_avg:.4f}.")
        elif context["loss"]["name"] != 'dice':
            dice_val_loss_total_avg = dice_val_loss_total / num_steps
            tqdm.write(f"\tDice validation loss: {dice_val_loss_total_avg:.4f}.")

        end_time = time.time()
        total_time = end_time - start_time
@@ -465,8 +505,8 @@ def cmd_test(context):
    model.cuda()
    model.eval()

    metric_fns = [mt_metrics.dice_score,
                  # mt_metrics.hausdorff_score,
    metric_fns = [dice_score,  # from ivadomed/utils.py
                  mt_metrics.hausdorff_score,
                  mt_metrics.precision_score,
                  mt_metrics.recall_score,
                  mt_metrics.specificity_score,
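A note on the focal_dice bookkeeping in cmd_train above: the code accumulates torch.log(dice) per batch and exponentiates the epoch average, so the displayed "Dice training/validation loss" is the geometric rather than arithmetic mean of the per-batch Dice coefficients, consistent with the -log(dice) term inside FocalDiceLoss. A toy illustration (made-up values):

from math import exp, log

batch_dice = [0.80, 0.85, 0.90]           # hypothetical per-batch Dice coefficients
log_dice_avg = sum(log(d) for d in batch_dice) / len(batch_dice)
print(exp(log_dice_avg))                  # geometric mean, ~0.849
print(sum(batch_dice) / len(batch_dice))  # arithmetic mean, 0.85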
17 changes: 17 additions & 0 deletions ivadomed/utils.py
@@ -7,6 +7,23 @@
from medicaltorch import filters as mt_filters
from medicaltorch import transforms as mt_transforms


def dice_score(im1, im2, eps=1.0):
    """
    Computes the Dice coefficient between im1 and im2.
    """
    im1 = np.asarray(im1).astype(np.bool)
    im2 = np.asarray(im2).astype(np.bool)

    if im1.shape != im2.shape:
        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")

    im_sum = im1.sum() + im2.sum() + eps

    intersection = np.logical_and(im1, im2)
    return 2. * intersection.sum() / im_sum


def mixup(data, targets, alpha):
    """Compute the mixup data.
    Return mixed inputs and targets, lambda.
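And a quick numeric check of dice_score as it stands at this point in the diff (eps=1.0 appears only in the denominator here; later commits in this PR move eps to the numerator and return None for empty masks). Values are illustrative:

import numpy as np

from ivadomed.utils import dice_score

im1 = np.zeros((4, 4)); im1[1:3, 1:3] = 1
im2 = np.zeros((4, 4)); im2[1:3, 1:3] = 1

print(dice_score(im1, im2))         # 2*4 / (4 + 4 + 1) = 0.888... with the default eps=1.0
print(dice_score(im1, im2, eps=0))  # exactly 1.0 for identical non-empty masks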