
Add ignore_class to sparse crossentropy and IoU #16712

Merged
merged 5 commits into keras-team:master from lucasdavid:ignore_index
Jul 18, 2022

Conversation

lucasdavid
Contributor

@lucasdavid lucasdavid commented Jun 23, 2022

Fix #6118
Fix #5911
Relates to keras-team/tf-keras#617

Summary

  • Add the ignore_class: Optional[int] parameter to the following functions/constructors:
    • backend.sparse_categorical_crossentropy
    • losses.sparse_categorical_crossentropy
    • metrics.SparseCategoricalCrossentropy
    • metrics._IoUBase
    • metrics.IoU
    • metrics.MeanIoU
    • metrics.OneHotIoU
    • metrics.OneHotMeanIoU
  • Add sparse_y_true: bool and sparse_y_pred: bool parameters to the _IoUBase, IoU, and MeanIoU metric classes.
  • Add sparse_y_pred: bool to the OneHotIoU and OneHotMeanIoU metric classes, and refactor these classes to reuse more of the base class.
  • Refactor: extract a code section replicated across backend.categorical_crossentropy, backend.sparse_categorical_crossentropy, and backend.binary_crossentropy into a single function named _get_logits.

Goals

  1. ignore_class: In segmentation problems, some pixels in segmentation maps might not represent valid categorical labels. Examples:

    • object boundaries are marked with a void category, as annotators disagree on which label to assign
    • small maps are padded with the void class to conform to the sizes of larger ones after Dataset#padded_batch
    • specific objects are out of the problem's context, such as the hood of a car captured by a static camera
    • pseudo-labels (originating from weakly supervised strategies) might contain pixels/regions where the label is uncertain

    It's common to assign the label -1 or 255 to such pixels and ignore them during training. This PR implements this feature by masking the target and output signals and computing the metrics only over the valid pixels. Moreover, it mirrors PyTorch's CrossEntropyLoss(ignore_index=-100).

  2. sparse_y_pred: IoU and MeanIoU assume both target and output are sparse signals, where categories are represented as natural integers. Conversely, OneHotIoU and OneHotMeanIoU assume both are probability distribution vectors. This is far from what I believe to be the most obvious case: sparse segmentation labels and dense output vectors:

    >>> from tensorflow.keras import Sequential
    >>> from tensorflow.keras.applications import ResNet50V2
    >>> from tensorflow.keras.layers import Conv2D
    >>> classes = 20
    >>> model = Sequential([
    ...     ResNet50V2(input_shape=[512, 512, 3], include_top=False, pooling=None, weights=None),
    ...     Conv2D(classes, kernel_size=1, activation='softmax', name='predictions')
    ... ])
    >>> print(model.output.shape)
    (None, 16, 16, 20)

    So now IoU can easily be used like this:

    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=[
      MeanIoU(classes, sparse_y_pred=False, ignore_class=-1)
    ])

Limitations

Currently, backend.sparse_categorical_crossentropy only reduces the dimension containing the logits, and the result is reshaped to the original output shape (except for the last axis) when that information is available.
However, when a pixel is not valid, its associated cross-entropy value is not available, and the reshape cannot occur without creating a ragged tensor. Therefore, when ignore_class is not None (and only then), I opted to sum all cross-entropy values over the axes range(1, output_rank-1) and divide by the number of valid pixels (similar to what PyTorch does). In this case, the output tensor will have shape=[output_shape[0]]=[batch_size]. An alternative would be to return a flattened array containing only the valid entries, though the batch information would be lost and users applying per-sample operations to these loss values would have difficulties.

Update: there are no visible limitations now. backend.sparse_categorical_crossentropy sets the _keras_mask property on the loss tensor, which is then used during the reduction procedure to mask out invalid pixels.

@lucasdavid lucasdavid force-pushed the ignore_index branch 5 times, most recently from f1cdbfb to 69eddd1 Compare June 24, 2022 02:33
@lucasdavid lucasdavid changed the title Add ignore_index crossentropy and IoU Add ignore_index to sparse crossentropy and IoU Jun 24, 2022
@gbaned gbaned requested a review from fchollet June 27, 2022 12:13
@google-ml-butler google-ml-butler bot added the keras-team-review-pending Pending review by a Keras team member. label Jun 27, 2022
Member

@fchollet fchollet left a comment


Thanks for the PR. Wouldn't ignore_class be a more explicit / precise name for this argument?

keras/backend.py (outdated review comments, resolved)
@fchollet fchollet removed the keras-team-review-pending Pending review by a Keras team member. label Jun 29, 2022
@lucasdavid
Contributor Author

lucasdavid commented Jun 30, 2022

Wouldn't ignore_class be a more explicit / precise name for this argument?

I named it after PyTorch's, hoping that people would make the association transparently. Furthermore, I think Caffe called it ignore_label.
All are good choices, in my opinion. Let me know which one I should keep.

The new argument should be at the end of the signature.

Should I use the same parameter order in the modules (losses, metrics)?

This description is not understandable for someone who does not already know what the argument does.

I fixed it. Let me know if it still needs improvement.

@lucasdavid lucasdavid force-pushed the ignore_index branch 4 times, most recently from 49dce10 to b3ab864 Compare June 30, 2022 13:41
@gbaned gbaned requested a review from fchollet June 30, 2022 14:11
@google-ml-butler google-ml-butler bot added the keras-team-review-pending Pending review by a Keras team member. label Jun 30, 2022
@fchollet
Member

fchollet commented Jul 2, 2022

I named it after torch's, hopping that people would make the association transparently. Furthermore, I think caffe called it ignore_label.

ignore_index is confusing, because an "index" is typically an integer you use for indexing an array, something like x[4], etc. If you ask someone, "I want to ignore index 2 in my crossentropy loss", is that readily understandable? I'm not so sure. It seems far too imprecise; you don't know what "index" refers to in the context of a crossentropy loss.

ignore_label is slightly incorrect because a "label" is a class annotation for a specific sample (an instance of a class). We're looking to ignore an entire class here. People often get class and label confused. It starts mattering in multi-label tasks.

ignore_class is precise (it tells you what you want to ignore: a class from your class set) and accurate (you're talking about a class in general, not a label attached to a sample).

Member

@fchollet fchollet left a comment


Thanks for the update!

keras/backend.py (outdated review comments, resolved)
keras/losses.py (outdated review comments, resolved)
@lucasdavid lucasdavid changed the title Add ignore_index to sparse crossentropy and IoU Add ignore_class to sparse crossentropy and IoU Jul 3, 2022
@lucasdavid
Contributor Author

lucasdavid commented Jul 5, 2022

Some pixels in the segmentation map were ignored, which means cross-entropy values weren't computed for them (I only pass valid pixels to tf.nn.sparse_softmax_cross_entropy_with_logits, and then tf.scatter_nd them into a new tensor with the original shape). Thus the non-ignored pixels have correct cross-entropy values and the ignored ones have zeros.
Now we have to average them considering only the valid pixels. Leaving the reduction to losses_utils.compute_weighted_loss (which uses tf.reduce_mean or tf.reduce_sum) would be incorrect, because it would also throw these zeros into the mix and artificially reduce the loss value.

Here is an example with a batch of one sample: a (2, 2) segmentation map containing up to three classes:

y_true = [
  [[ 0,  2],
   [-1, -1]]]
y_pred = [
  [[[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]],
   [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]]]

If we were to ignore -1, then backend.sparse_categorical_crossentropy proceeds as follows:

valid_mask = [[               # L5621
  [True, True],
  [False, False],
]]
target = [0, 2]               # L5622 (select valid pixel labels)
output = tf.math.log(y_pred)  # L5585, not from_logits
output = [
  [ 0.  ,    -inf,    -inf],
  [ -inf, -0.6931, -0.6931]]  # L5623 (select probabilities associated with valid pixels)
res = [0., 0.6931]  # L5631 (tf.nn.softmax_crossentropy...)
res = [
  [[ 0., 0.6931],
   [ 0., 0.    ]]]  # L5639 (reconstruct samples with tf.scatter_nd)
res = [0.3466]      # L5647 (average amongst valid pixels -- the two in the top row)

res will contain exactly one value per sample (so sample_weight still works), while all invalid pixels are ignored when computing the loss value for each sample. Does it make sense?
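
For completeness, here is a quick sketch of checking this end to end through the public loss API (assuming a build that includes this change; the expected values are the ones from the trace above):

```python
import numpy as np
import tensorflow as tf

y_true = np.array([[[0, 2],
                    [-1, -1]]])                              # shape (1, 2, 2)
y_pred = np.array([[[[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]],
                    [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]]])    # shape (1, 2, 2, 3)

# Per-pixel values; ignored pixels come back as zeros (and are flagged
# via the _keras_mask attribute for the later reduction step).
print(tf.keras.losses.sparse_categorical_crossentropy(
    y_true, y_pred, ignore_class=-1))
# ~ [[[0., 0.6931], [0., 0.]]]

# The Loss class should then average over the valid pixels only:
# 0.6931 / 2 ~= 0.3466.
print(tf.keras.losses.SparseCategoricalCrossentropy(ignore_class=-1)(y_true, y_pred))
```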


For reference, PyTorch works similarly (implementation in aten/src/ATen/native/LossNLL.cpp):

ignore_index (int, optional) – Specifies a target value that is ignored and does not contribute to the input gradient. When size_average is True, the loss is averaged over non-ignored targets. Note that ignore_index is only applicable when the target contains class indices.

This "issue" isn't as apparent because the low level loss function is also responsible for doing the reduction:

>>> y_true = torch.tensor([
...  [[ 0,  2],
...   [-1, -1]]])
>>> y_pred = torch.from_numpy(np.asarray([
... [[1.0, 0.0, 0.0], [0.0, 0.5, .5]],
... [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]]).transpose((2, 0, 1))[np.newaxis, ...])
>>>
>>> loss=torch.nn.CrossEntropyLoss(ignore_index=-1)
>>> loss(torch.log(y_pred), y_true)
tensor(0.3466, dtype=torch.float64)
>>> loss=torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
>>> loss(torch.log(y_pred), y_true)
tensor([[[-0.0000, 0.6931],
         [ 0.0000, 0.0000]]], dtype=torch.float64)

@lucasdavid
Contributor Author

lucasdavid commented Jul 7, 2022

Loss reduction needs to be controllable by the user when writing custom distributed training loops.

I understand that the batch (first axis) should not be reduced in distributed environments. However, considering that we don't split a single sample across multiple replicas and that the default reduction is called SUM_OVER_BATCH_SIZE, I don't quite see the problem in reducing the remaining (within-sample) axes.


I implemented it as you asked. I reused the existing get_mask and apply_mask functions from the compile_utils module (moving them to losses_utils to avoid a circular import). Let me know if more test cases are required.

I admit that sparse_categorical_crossentropy is more consistent now, but code had to be added to at least three different classes in order to achieve the same result (losses.Loss, metrics.MeanMetricWrapper, and metrics.SumOverBatchSizeMetricWrapper). More importantly, sparse_categorical_crossentropy(ignore_class=...) now returns 0s that should be ignored, so it's one more thing for people with custom training loops to worry about.

We probably need to write an example in the docstring of how to handle these in custom training loops, right?
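
Something like the following, perhaps: a minimal sketch (not part of this PR) of how a custom training loop could consume those zeros, assuming the per-element loss tensor carries the boolean _keras_mask described above:

```python
import tensorflow as tf

def masked_scce(y_true, y_pred, ignore_class=-1):
    # Per-element cross-entropy; entries at ignored positions come back as zeros.
    values = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred, ignore_class=ignore_class)
    mask = getattr(values, "_keras_mask", None)
    if mask is None:
        return tf.reduce_mean(values)
    mask = tf.cast(mask, values.dtype)
    # Average over valid entries only, so the zeros don't dilute the loss.
    return tf.reduce_sum(values * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)
```

A custom train step would then call something like masked_scce instead of taking a plain reduce_mean over the raw per-element values.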

Member

@fchollet fchollet left a comment


Thanks for the updates. The changes all look good to me!

keras/utils/losses_utils.py (outdated review comments, resolved)
keras/metrics/metrics.py (outdated review comments, resolved)
Member

@fchollet fchollet left a comment


Thanks for the answers!

@gbaned gbaned requested a review from fchollet July 13, 2022 08:14
Member

@fchollet fchollet left a comment


Excellent work, thank you for the contribution. LGTM

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jul 13, 2022
@fchollet fchollet removed the keras-team-review-pending Pending review by a Keras team member. label Jul 13, 2022
@copybara-service copybara-service bot merged commit f61f853 into keras-team:master Jul 18, 2022
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 18, 2022
Imported from GitHub PR keras-team/keras#16712

Properly fixes #6118 and #5911.

#### Summary
* Add the `ignore_index: Optional[int]` parameter to the following functions/constructors:
  - `backend.sparse_categorical_crossentropy`
  - `lossses.sparse_categorical_crossentropy`
  - `metrics.SparseCategoricalCrossentropy`
  - `metrics._IoUBase`
  - `metrics.IoU`
  - `metrics.MeanIoU`
  - `metrics.OneHotIoU`
  - `metrics.OneHotMeanIoU`
* Add `sparse_labels: bool` and `sparse_preds: bool` parameters in `_IoUBase`, `IoU`, `MeanIoU` metric classes.
* Add `sparse_preds:bool` to the `OneHotIoU` and `OneHotMeanIoU` metric classes.
* Refactor: A replicated code section shared among `backend.categorical_crossentropy`, `backend.sparse_categorical_crossentropy`, and `backend.binary_crossentropy` into a single function named `_get_logits`.

#### Goals
1. **ignore_index**: In segmentation problems, some pixels in segmentation maps might not represent valid categorical labels. Examples:
   - object boundaries are marked with void category, as the annotators disagree on which label to attribute
   - small maps are padded with the *void* class to conform with the sizes of larger ones after `Dataset#padded_batch`
   - specific objects out of the context of the problem, such as the hood of a car being captured by a static camera
   - pseudo-labels (originated from weakly supervised strategies) might contain pixels/regions where label is uncertain

   It's common to attribute the label `-1` or `255` and ignore these pixels during training. This PR implements this feature by masking the target and the output signals, only computing the metrics over the valid pixels. Moreover, it mirrors PyTorch's [CrossEntropyLoss(ignore_index=-100)](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).

2. **sparse_preds**: `IoU` and `MeanIoU` assumes both `target` and `output` are sparse signals, where categories are represented as natural integers. Conversely, `OneHotIoU` and `OneHotMeanIoU` assume both are probability distribution vectors. This is far from what I believe to be the most obvious case: sparse segmentation labels and dense output vectors:
   ```py
   >>> classes = 20
   >>> model = Sequential([
   >>>    ResNet50V2(input_shape=[512, 512, 3], include_top=False, pooling=None, weights=None),
   >>>    Conv2D(classes, kernel_size=1, activation='softmax', name='predictions')
   >>> ])
   >>> print(model.output.shape)
   (None, 16, 16, 20)
   ```

   So now IoU can be easily used as this:
   ```py
   model.compile(opt='sgd', loss='sparse_categorical_crossentropy', metrics=[
     MeanIoU(classes, sparse_preds=False, ignore_index=-1)
   ])
   ```

#### Limitations
Currently, `backend.sparse_categorical_crossentropy` only reduces the dimension containing the logits, and the result is reshaped into the original output shape (except for the last axis) if the information is available.
However, when a pixel is not valid, its associated cross-entropy value is not available and reshape cannot occur without creating a ragged tensor. Therefore, when `ignore_index is not None` (and only then), I opted to sum all cross-entropy values over the axes `range(1, output_rank-1)` and divide by the number of valid pixels (similar to what pytorch does). In this case, the output tensor will have `shape=[output_shape[0]]=[batch_size]`.

An alternative would be to return a flatten array containing only valid entries, though the batch information would be lost and the user would have difficulties if they had per-sample operations being applied to these loss values.
Copybara import of the project:

--
b7f02816b5320855ae528971766fdcaad7134a9b by lucasdavid <lucasolivdavid@gmail.com>:

Add ignore_index crossentropy and IoU

--
1589a843bac4390c8377db05cbd6ae650b6210cc by lucasdavid <lucasolivdavid@gmail.com>:

Remove duplicate convert_to_tensor

--
70f7fb6789e1a7e030737a46847b24b892965e4e by lucasdavid <lucasolivdavid@gmail.com>:

Rename ignore_index to ignore_label, update docs

--
db9f76ac8d1945630061582b03381939349bb59a by lucasdavid <lucasolivdavid@gmail.com>:

Implement masked loss reduction

--
4f1308112f4188c4e14fdf3a59af8fe5f30db61f by lucasdavid <lucasolivdavid@gmail.com>:

Update docs

Merging this change closes #16712

PiperOrigin-RevId: 461661376
@lucasdavid lucasdavid deleted the ignore_index branch July 18, 2022 20:44
copybara-service bot pushed a commit that referenced this pull request Jul 30, 2022
PiperOrigin-RevId: 463204427
copybara-service bot pushed a commit that referenced this pull request Jul 30, 2022
PiperOrigin-RevId: 463204427
@lucasdavid
Contributor Author

@fchollet is the PR being rolled back? Did I break something?

@visionscaper

@lucasdavid @fchollet Keras and TF are not in sync with respect to this new feature:

  • The TF v2.10.0 documentation for SparseCategoricalCrossentropy mentions the ignore_class parameter
  • tf.keras.losses.SparseCategoricalCrossentropy code (v2.10.0) does not support the ignore_class parameter
  • keras.losses.SparseCategoricalCrossentropy code (v2.10.0) does support the ignore_class parameter

So, although the documentation mentions the parameter, TensorFlow doesn't support it yet, while Keras does.

@lucasdavid
Contributor Author

lucasdavid commented Sep 22, 2022

@visionscaper I thought these two were synchronized automatically... Maybe this has something to do with PR #16851?
I think it might be best to create a new issue to increase visibility of this problem.

A suspicious thing happened during this PR:

  1. Copybara tests were passing (http://cl/460977716);
  2. @fchollet asked me to remove some unnecessary code/comments and to change some names;
  3. I made the changes and force-pushed;
  4. Copybara tests failed (http://cl/460977716), but the PR was merged anyway.

I don't have the necessary access permissions to see what failed in Copybara's logs, but all test cases pass on my machine and in the GPU and CPU CIs.

@lucasdavid
Contributor Author

@visionscaper, I believe it's now working in tf-nightly:

import numpy as np
import tensorflow as tf
print(tf.__version__)
print(tf.keras.losses.sparse_categorical_crossentropy(
  np.random.randint(-1, 10, size=[40, 1]),
  np.random.randn(40, 10),
  ignore_class=-1
))
print(tf.keras.losses.SparseCategoricalCrossentropy(ignore_class=-1)(
  np.random.randint(-1, 10, size=[40, 1]),
  np.random.randn(40, 10)
))
2.11.0-dev20221011

<tf.Tensor: shape=(40,), dtype=float64, numpy=
array([17.46086028,  1.64064914, 16.93660172,  2.25964914,  3.02902916,
        2.38092942, 16.97059782, 16.5225091 ,  0.        , 17.21726741,
        1.47735013, 16.65561627, 16.58555793,  0.        ,  1.94407448,
       17.10451514,  2.04504283, 16.77500514,  2.10482156, 16.75909515,
        1.45206898,  1.36361875, 16.52289682, 16.88043939, 17.54349066,
       17.05594301,  2.36618914,  1.87394029, 17.44958865,  1.42225717,
       17.27105659,  0.        , 17.47401625, 17.47470669, 17.42965694,
       17.33150152,  1.03149344,  1.56646177, 16.54211899, 17.32527626])>

<tf.Tensor: shape=(), dtype=float64, numpy=10.997069327410369>

@visionscaper

Thanks for letting me know @lucasdavid!

@svobora

svobora commented Oct 14, 2022

Well, the ignore labels work, but it is extremely slow.

@lucasdavid
Contributor Author

I used a very similar version of this one to train on Pascal VOC 2012 and I didn't see any performance issues.
Can you give us more information about your problem domain, data, and model? Or perhaps share a code snippet that illustrates the performance degradation?

@miticollo

I noticed that Accuracy metrics don't have an arg like ignore_class. Is there a particular reason?

Background

I used SparseCategoricalCrossentropy with ignore_class=0, and now I would like to do the same with SparseCategoricalAccuracy to be consistent, but this argument doesn't exist. I could create a callable wrapper around sparse_categorical_accuracy to apply a mask, but I'm not sure if this is a good idea: I would get an improvement in accuracy, but forcing the match might feel like I'm cheating.

@lucasdavid
Contributor Author

@miticollo I don't think there's a reason for it. Maybe people just did not hit this use case enough times to come around and implement it. In any case, a wrapper should be fine.
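
For what it's worth, a rough sketch of such a wrapper (hypothetical helper name; it assumes integer labels without a trailing channel axis):

```python
import tensorflow as tf

def masked_sparse_categorical_accuracy(ignore_class=0):
    # Accuracy computed only over entries whose label differs from ignore_class.
    def metric(y_true, y_pred):
        y_true = tf.cast(y_true, tf.int64)
        matches = tf.cast(tf.equal(y_true, tf.argmax(y_pred, axis=-1)), tf.float32)
        valid = tf.cast(tf.not_equal(y_true, ignore_class), tf.float32)
        return tf.reduce_sum(matches * valid) / tf.maximum(tf.reduce_sum(valid), 1.0)
    metric.__name__ = "masked_sparse_categorical_accuracy"
    return metric

# e.g. model.compile(..., metrics=[masked_sparse_categorical_accuracy(ignore_class=0)])
```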
