
New losses: focal loss & generalised dice loss #46

Merged 39 commits into master from cg/new-losses on Sep 30, 2019

Conversation

@charleygros (Member) commented Sep 9, 2019

We are facing a severe class-imbalance issue since the introduction of the MS lesion segmentation task. This PR allows the use of new loss functions:

Done:

@charleygros charleygros added the enhancement category: improves performance/results of an existing feature label Sep 9, 2019
@charleygros charleygros self-assigned this Sep 9, 2019
@charleygros (Member Author) commented Sep 10, 2019

Focal Loss implementation: works in log space to be numerically stable.

I first tried without the log-space formulation, but very easily got NaNs during training. Basically, the focal loss boosts the loss for cases where objects are not detected correctly, which helps avoid false-negative predictions (ref).

In addition, the model effectively learns the small objects, since the loss on these objects is very high.
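For reference, a minimal NumPy sketch of a log-space binary focal loss. This is an illustration of the idea, not the PR's actual PyTorch code: the use of `logaddexp` for a stable log-sigmoid and the per-voxel modulating factor are assumptions.

```python
import numpy as np

def focal_loss(logits, targets, gamma=2.0):
    """Binary focal loss computed in log space for numerical stability.

    Working with log-probabilities rather than raw probabilities avoids
    log(0) -> NaN when predictions saturate.
    """
    # log(sigmoid(x)) computed stably as -log(1 + exp(-x)) via logaddexp
    log_p = -np.logaddexp(0.0, -logits)    # log p
    log_1mp = -np.logaddexp(0.0, logits)   # log (1 - p)
    p = np.exp(log_p)

    # The modulating factor (1 - p_t)^gamma down-weights easy examples
    # and boosts the loss on misclassified (e.g. missed lesion) voxels.
    loss_pos = -((1.0 - p) ** gamma) * log_p   # where target == 1
    loss_neg = -(p ** gamma) * log_1mp         # where target == 0
    return np.mean(np.where(targets == 1, loss_pos, loss_neg))
```

With gamma=0 this reduces to plain binary cross-entropy; increasing gamma shifts the loss toward the hard, misclassified voxels.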

@charleygros (Member Author) commented Sep 10, 2019

Display the Dice loss results in the terminal while training with a new loss, as a reference / safety check, since we have a good idea of what the ideal Dice loss curve looks like:

Epoch 1 training loss: 0.3283.                                                                                                                 
        Dice training loss: -0.0047.
Epoch 1 validation loss: 0.2450.                                                                                                               
        Dice validation loss: -0.0073.                                                                                                         
Epoch 1 took 58.76 seconds.

@charleygros (Member Author) commented Sep 10, 2019

As the gamma parameter of the FocalLoss is quite important to tune correctly (see Fig. 1 of the article), I introduced the possibility of setting it in the config files:

"loss": {"name": "focal", "params": {"gamma": 0.4}}

And then, while running main.py:

Loss function: focal, with gamma=0.4.
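A hypothetical sketch of how such a config entry could be dispatched. The `get_loss` function and its return value are illustrative, not the actual `main.py` code; only the config keys (`"loss"`, `"name"`, `"params"`) come from the snippet above.

```python
def get_loss(config):
    """Pick the loss from a config dict like {"loss": {"name": ..., "params": {...}}}."""
    loss_cfg = config.get("loss", {"name": "dice", "params": {}})
    name = loss_cfg["name"]
    params = loss_cfg.get("params", {})
    if name == "focal":
        # gamma defaults to 2.0 here; the actual default is an assumption
        gamma = params.get("gamma", 2.0)
        print("Loss function: focal, with gamma={}.".format(gamma))
        return ("focal", {"gamma": gamma})
    # fallback: any other named loss, parameters passed through as-is
    print("Loss function: {}.".format(name))
    return (name, params)
```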

@charleygros (Member Author) commented

Generalised Dice Loss: adapted from the original paper (multi-class segmentation) to our task (i.e. binary segmentation).
I set a default value for the epsilon. We may want to tune it at some point, which could be done by exposing it as params in the config files.
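A rough NumPy sketch of what a binary Generalised Dice Loss could look like. The inverse-volume-squared class weighting follows the original paper; the exact epsilon value and placement here are guesses, not the PR's implementation.

```python
import numpy as np

def generalised_dice_loss(pred, target, eps=1e-5):
    """Binary Generalised Dice Loss.

    Each class (foreground, background) is weighted by the inverse square
    of its volume, so the rare lesion class contributes as much as the
    background. `eps` avoids division by zero for empty classes.
    """
    num, den = 0.0, 0.0
    # treat binary segmentation as two classes: foreground and background
    for cls_pred, cls_target in ((pred, target), (1.0 - pred, 1.0 - target)):
        w = 1.0 / (cls_target.sum() ** 2 + eps)
        num += w * (cls_pred * cls_target).sum()
        den += w * (cls_pred + cls_target).sum()
    return 1.0 - 2.0 * num / (den + eps)
```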

@charleygros (Member Author) commented Sep 17, 2019

Mixed Loss: a combination of the focal loss and the log of the dice loss.

The log is used here to boost the loss when objects are not detected correctly, i.e. when the dice is close or equal to zero. The log-dice term thus focuses more on the less accurately segmented labels.

To bring the two losses to a similar scale, a new hyperparameter is introduced: alpha. The terminal output helps us tune this parameter, e.g. with alpha=10:

Epoch 1 training loss: 11.9020.                                                                                                                
        Focal training loss: 0.6866.                                                                                                           
        Log Dice training loss: -5.0363.                                                                                                       
Epoch 1 validation loss: 10.0586.                                                                                                              
        Focal validation loss: 0.5996.                                                                                                         
        Log Dice validation loss: -4.0622.                                                                                                     
Epoch 1 took 159.52 seconds.
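A hypothetical sketch of the combination. The form `alpha * focal - log(dice)` is an assumption, not the PR's code, though it is consistent with the epoch-1 numbers above (11.9020 ≈ 10 * 0.6866 - (-5.0363)).

```python
import math

def mixed_loss(focal_val, dice_score, alpha=10.0, eps=1e-7):
    """Hypothetical mixed loss: alpha * focal - log(dice).

    log(dice) -> -inf as the dice score goes to zero, so samples where
    objects are missed entirely are penalised very heavily. `eps` guards
    against log(0); its value is an assumption.
    """
    log_dice = math.log(dice_score + eps)
    return alpha * focal_val - log_dice
```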

@charleygros (Member Author) commented

Soft Dice Loss: implemented to allow the use of a Dice-based loss during mixup experiments, where the target masks are not binary.
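A minimal sketch of a soft Dice loss that stays valid for non-binary (mixup) targets; the smoothing constant and its value are assumptions, not the PR's exact choice.

```python
import numpy as np

def soft_dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss: differentiable, defined on continuous masks.

    Because it never thresholds `pred` or `target`, it remains valid for
    mixup targets, which are convex combinations of two binary masks.
    """
    intersection = (pred * target).sum()
    return 1.0 - (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
```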

@charleygros (Member Author) commented

@olix86: PR ready! Could I ask you to review it? Thanks :)

@olix86 (Contributor) left a comment

Everything looks good to me. I just have a small doubt about the implementation of the new dice score; let me know what you think :)

(Inline review comment on ivadomed/utils.py: resolved)
@olix86 (Contributor) commented Sep 18, 2019

I've found a bug. The fix should be quite simple, so please wait before merging; I'll update this soon.

Edit: should be fixed by the following commit.

olix86 and others added 3 commits September 18, 2019 16:41
ValueError: Shape mismatch: im1 and im2 must have the same shape.
Co-Authored-By: olix86 <olix86@users.noreply.github.com>
@charleygros (Member Author) commented

Just adding to the config file the new subjects recently added by @alexfoias to sct_testing/large.

@jcohenadad (Member) commented

Maybe ignore the dice_loss values when there is no segmented object (instead of returning one)?
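A sketch of what this suggestion could look like; the function name and the `None` convention are illustrative. Skipping empty samples also avoids the 0/0 division behind the linked "RuntimeWarning: invalid value encountered in double_scalars".

```python
import numpy as np

def dice_score_or_none(pred, target):
    """Return None instead of a score when there is nothing to segment.

    When both prediction and ground truth are empty, the Dice score is
    undefined (0/0), so the sample is skipped when averaging over the
    epoch rather than counted as a perfect 1.0.
    """
    denom = pred.sum() + target.sum()
    if denom == 0:
        return None  # caller ignores this sample in the epoch average
    return 2.0 * (pred * target).sum() / denom
```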

@olix86 (Contributor) commented Sep 23, 2019

Looks like the last commit that added new contrasts broke something. I think it's because they were not added to contrast_dct.json, so it crashes when trying to use FiLM with contrast:

 File "/ivadomed/loader.py", line 229, in normalize_metadata
    generic_contrast = GENERIC_CONTRAST[subject["input_metadata"]["bids_metadata"]["contrast"]]
KeyError: 'acq-ax_T2star'

@olix86 (Contributor) commented Sep 23, 2019

Alright, this should fix it, but I'm not super knowledgeable about MRI contrasts, so somebody should probably check that it makes sense :)

@olix86 (Contributor) commented Sep 23, 2019

Is there a reason why the following line prevents using both FiLM and mixup at the same time?

mixup_bool = False if film_bool else bool(context["mixup_bool"])

(main.py, line 61)

@charleygros (Member Author) commented Sep 23, 2019

Is there a reason why there's the following line to prevent using both FiLM and mixup at the same time?

The reason was: if we use FiLM and MixUp at the same time, how do we apply mixup to the metadata (i.e. the FiLM input)? Say we do a mixup between a T1w and a T2star: ideally, we would also like to mix their metadata when feeding the FiLM generator. When I implemented it, I could not think of an appropriate way to do that, hence why we prevent using both FiLM and mixup at the same time.
But that's an open question, and I would be very happy to brainstorm about it!

@charleygros (Member Author) commented

Alright this should fix it but I'm not super knowledgeable in MRI contrasts so somebody should probably check if it makes sense :)

Checked! All good 👍

@charleygros (Member Author) commented

Maybe ignore the dice_loss values if there is no segmented object (instead of returning one)

Done by:

@charleygros (Member Author) commented

@olix86: That's ready! Could you please review when suits you?

@olix86 (Contributor) commented Sep 30, 2019

@charleygros looks good to me. I think the only thing missing was adding the new contrasts from your last commit to the JSON file (I found 4 and added them).

If that's ok with you, I think we're ready to merge 😄

@charleygros charleygros merged commit 908ff27 into master Sep 30, 2019
@olix86 olix86 deleted the cg/new-losses branch October 29, 2019 15:08
Labels
enhancement category: improves performance/results of an existing feature
Development

Successfully merging this pull request may close these issues.

RuntimeWarning: invalid value encountered in double_scalars