feat: replace bce by focal loss in linknet loss #824

Merged
merged 9 commits into from
Feb 23, 2022

charlesmindee
Collaborator

Following a suggestion by @fg-mindee and @SiddhantBahuguna, this PR replaces the BCE loss with the focal loss in the LinkNet loss to increase recall on imbalanced classes.

Any feedback is welcome!

@charlesmindee charlesmindee added type: enhancement Improvement module: models Related to doctr.models framework: pytorch Related to PyTorch backend framework: tensorflow Related to TensorFlow backend topic: text detection Related to the task of text detection labels Feb 16, 2022
@charlesmindee charlesmindee added this to the 0.5.1 milestone Feb 16, 2022
@charlesmindee charlesmindee self-assigned this Feb 16, 2022
codecov bot commented Feb 16, 2022

Codecov Report

Merging #824 (972a2b6) into main (fae4923) will decrease coverage by 0.02%.
The diff coverage is 92.00%.

@@            Coverage Diff             @@
##             main     #824      +/-   ##
==========================================
- Coverage   96.01%   95.98%   -0.03%     
==========================================
  Files         131      131              
  Lines        5019     5033      +14     
==========================================
+ Hits         4819     4831      +12     
- Misses        200      202       +2     
Flag Coverage Δ
unittests 95.98% <92.00%> (-0.03%) ⬇️

Impacted Files Coverage Δ
doctr/models/detection/linknet/tensorflow.py 97.67% <91.66%> (-1.06%) ⬇️
doctr/models/detection/linknet/pytorch.py 97.97% <92.30%> (-0.94%) ⬇️
doctr/transforms/modules/base.py 94.59% <0.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Contributor

@fg-mindee fg-mindee left a comment

Thanks, I added a comment on the implementation! Have you tried to check if it improves training perf?

doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
Comment on lines 187 to 191
p_t = (seg_target[seg_mask] * pred_prob) + ((1 - seg_target[seg_mask]) * (1 - pred_prob))
# Compute alpha factor
alpha_factor = seg_target[seg_mask] * alpha + (1 - seg_target[seg_mask]) * (1 - alpha)
# compute the final loss
focal_loss = (alpha_factor * (1. - p_t) ** gamma * bce_loss[seg_mask]).mean()
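
Read per pixel, the quoted lines amount to the arithmetic below. This is only a framework-free sketch: `focal_term` is a hypothetical helper, and the `alpha=0.25`, `gamma=2.0` defaults are assumed for illustration rather than taken from this PR. It also assumes `pred_prob` lies strictly between 0 and 1.

```python
import math


def focal_term(pred_prob: float, target: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for a single pixel (hypothetical helper, assumed defaults)."""
    # Binary cross-entropy for this pixel
    bce = -(target * math.log(pred_prob) + (1.0 - target) * math.log(1.0 - pred_prob))
    # Probability the model assigns to the true class
    p_t = target * pred_prob + (1.0 - target) * (1.0 - pred_prob)
    # Class-balancing weight: alpha for positives, 1 - alpha for negatives
    alpha_factor = target * alpha + (1.0 - target) * (1.0 - alpha)
    # The modulating factor shrinks the loss of well-classified pixels
    return alpha_factor * (1.0 - p_t) ** gamma * bce


easy = focal_term(0.9, 1.0)  # confident, correct prediction on a positive pixel
hard = focal_term(0.1, 1.0)  # confident, wrong prediction on a positive pixel
```

With `gamma=0` the modulating factor disappears and the term reduces to an alpha-weighted BCE, which is a convenient sanity check.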
Contributor

I think we need to address the masking + reduction problem of the dice loss first: once the tensor is indexed with the mask, it is flattened to 1D and the class dimension is lost. So if one class has 10 times more positive-mask area than another, it will dominate the mean, which is a problem.

I'd suggest doing the following: changing

seg_target[mask].mean()

to

mask = mask.to(dtype=torch.float32)
# Average on N, H, W
class_loss = (seg_target * mask).sum((0, 2, 3)) / mask.sum((0, 2, 3))
loss = class_loss.mean()

or average it on H, W only before the final mean
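
The effect of the suggested reduction can be checked on a single row of toy values. The sketch below is framework-free and uses plain lists as a stand-in for tensors; the numbers are hypothetical.

```python
values = [2.0, 4.0, 6.0, 8.0]
mask = [True, False, True, False]

# Boolean indexing keeps only the unmasked elements (length == sum(mask))
selected = [v for v, m in zip(values, mask) if m]
mean_indexed = sum(selected) / len(selected)  # divides by 2

# Multiplying by the mask keeps the length and zeroes the masked elements
zeroed = [v * float(m) for v, m in zip(values, mask)]
mean_zeroed = sum(zeroed) / len(zeroed)  # divides by 4

# Masked sum divided by the mask count, as in the suggestion above
mean_rescaled = sum(zeroed) / sum(mask)  # divides by 2 again
```

On one row the rescaled version matches the indexed mean. The benefit of the masked-sum form is that it can be applied along chosen dimensions only, so each class keeps its own average.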

Collaborator Author

I am not sure I understand: do you want to remove every ...[mask] occurrence?

Contributor

So the difference is the following:

  • my_tensor[mask] is a 1D tensor whose number of elements equals the number of True values in the mask
  • my_tensor * mask.to(dtype=torch.float32) has the same shape as my_tensor; it only puts zeros on the elements that are masked out

Now if you perform a reduction operation like mean:

  • in the first case, you divide by the number of elements in my_tensor[mask] = mask.sum()
  • in the second one, you divide by the number of elements in my_tensor

And this extends to dimension-specific operations: if we index with the mask, we get a flattened, contiguous tensor and lose the separation of the dimensions. In the first case, this greatly increases the contribution of the classes with the largest amount of positive mask (it makes no difference if there is only a single class). And since we specifically want to help balance the less-frequent classes here, I suggest leveraging that second option with my suggestion above 👍
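
To make the class-contribution point concrete, here is a framework-free sketch with two hypothetical classes, one with ten unmasked pixels and one with a single unmasked pixel. The helper names and values are illustrative only, and the per-class version assumes every class has at least one unmasked pixel.

```python
# Per-pixel loss values and masks for two hypothetical classes
losses = {"class_a": [1.0] * 10, "class_b": [4.0] * 10}
masks = {"class_a": [True] * 10, "class_b": [True] + [False] * 9}


def flat_masked_mean(losses, masks):
    """Mean over all unmasked pixels, ignoring which class they belong to."""
    kept = [v for c in losses for v, m in zip(losses[c], masks[c]) if m]
    return sum(kept) / len(kept)


def per_class_masked_mean(losses, masks):
    """Masked mean inside each class first, then a plain mean over classes."""
    class_means = []
    for c in losses:
        total = sum(v for v, m in zip(losses[c], masks[c]) if m)
        class_means.append(total / sum(masks[c]))  # assumes sum(masks[c]) > 0
    return sum(class_means) / len(class_means)


flat = flat_masked_mean(losses, masks)            # 14 / 11, dominated by class_a
balanced = per_class_masked_mean(losses, masks)   # (1.0 + 4.0) / 2
```

In the flattened reduction the ten pixels of `class_a` outweigh the single pixel of `class_b`, while the per-class reduction gives both classes the same weight.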

Contributor

(and we need to do this for both the dice loss and the focal loss)

It will only make a difference in the multi-class case when the mask isn't entirely True, but that might be safer!

Either way, I think we should run a training with the configuration to make sure this yields a positive change 👍

Contributor

@fg-mindee fg-mindee left a comment

A few corrections in the loss computation and we'll be good to go!

doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
Contributor

@fg-mindee fg-mindee left a comment

Final cosmetic adjustments and we're good to go 👍

doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/tensorflow.py Outdated Show resolved Hide resolved
Contributor

@fg-mindee fg-mindee left a comment

My bad, a few fixes to do on my previous suggestion

doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/pytorch.py Outdated Show resolved Hide resolved
doctr/models/detection/linknet/tensorflow.py Outdated Show resolved Hide resolved
Contributor

@fg-mindee fg-mindee left a comment

Thanks!

@charlesmindee charlesmindee merged commit b06a27f into main Feb 23, 2022
@charlesmindee charlesmindee deleted the focal branch February 23, 2022 16:21
felixdittrich92 pushed a commit to felixdittrich92/doctr that referenced this pull request Feb 24, 2022
* feat: replace bce by focal loss in linknet loss

* fix: requested changes

* fix: mask reduction

* fix: mask reduction

* fix: loss reduction

* fix: final adjustements

* fix: final changes
felixdittrich92 added a commit to felixdittrich92/doctr that referenced this pull request Feb 24, 2022
charlesmindee added a commit that referenced this pull request Apr 5, 2022
* backup

* onnx classification

* fix: Fixed some ResNet architecture imprecisions (#828)

* feat: Added new resnets

* feat: Added ResNet101

* fix: Fixed ResNet31 & ResNet34 wide

* feat: Added new pretrained resnets

* style: Fixed isort

* fix: Fixed ResNet architectures

* refactor: Refactored LinkNet

* feat: Added more LinkNets

* fix: Fixed MAGResNet

* docs: Updated documentation

* refactor: Removed ResNet101

* fix: Fixed warning

* fix: Fixed a few bugs

* test: Updated unittests

* docs: Fixed docstrings

* update with new models

* feat: replace bce by focal loss in linknet loss (#824)

* feat: replace bce by focal loss in linknet loss

* fix: requested changes

* fix: mask reduction

* fix: mask reduction

* fix: loss reduction

* fix: final adjustements

* fix: final changes

* Revert "feat: replace bce by focal loss in linknet loss (#824)"

This reverts commit 6511183.

* Revert "fix: Fixed some ResNet architecture imprecisions (#828)"

This reverts commit 72e5e0d.

* happy codacy

* sapply suggestions

* fix-setup

* remove onnx from test req

* move onnx deps ftm to torch

* up

* up

* revert requirements

* fix

* update docstring

* up

Co-authored-by: F-G Fernandez <76527547+fg-mindee@users.noreply.github.com>
Co-authored-by: Charles Gaillard <charles@mindee.co>
felixdittrich92 added a commit to felixdittrich92/doctr that referenced this pull request Apr 5, 2022
* backup

* onnx classification

* fix: Fixed some ResNet architecture imprecisions (mindee#828)

* feat: Added new resnets

* feat: Added ResNet101

* fix: Fixed ResNet31 & ResNet34 wide

* feat: Added new pretrained resnets

* style: Fixed isort

* fix: Fixed ResNet architectures

* refactor: Refactored LinkNet

* feat: Added more LinkNets

* fix: Fixed MAGResNet

* docs: Updated documentation

* refactor: Removed ResNet101

* fix: Fixed warning

* fix: Fixed a few bugs

* test: Updated unittests

* docs: Fixed docstrings

* update with new models

* feat: replace bce by focal loss in linknet loss (mindee#824)

* feat: replace bce by focal loss in linknet loss

* fix: requested changes

* fix: mask reduction

* fix: mask reduction

* fix: loss reduction

* fix: final adjustements

* fix: final changes

* Revert "feat: replace bce by focal loss in linknet loss (mindee#824)"

This reverts commit 6511183.

* Revert "fix: Fixed some ResNet architecture imprecisions (mindee#828)"

This reverts commit 72e5e0d.

* happy codacy

* sapply suggestions

* fix-setup

* remove onnx from test req

* move onnx deps ftm to torch

* up

* up

* revert requirements

* fix

* update docstring

* up

Co-authored-by: F-G Fernandez <76527547+fg-mindee@users.noreply.github.com>
Co-authored-by: Charles Gaillard <charles@mindee.co>
felixdittrich92 added a commit to felixdittrich92/doctr that referenced this pull request Apr 7, 2022