This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Batch-dependent behavior in SoftmaxCEMaskedLoss? #238

Open
JulianSlzr opened this issue Jul 31, 2018 · 1 comment

@JulianSlzr
Hi all,

Consider the following (placed in scripts/nmt/):

import mxnet as mx
import loss

if __name__ == '__main__':
    BATCH_SIZE = 3
    MAX_SEQ_LEN = 100
    VOCAB_SIZE = 26
    # Each time step has a uniform distribution over the vocabulary
    uniform_tensor = 1./VOCAB_SIZE * mx.nd.ones(shape=(BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE))
    # Sequences have different lengths
    valid_lens = mx.nd.array((1, 10, 100))

    # Use a distinct name to avoid shadowing the imported `loss` module
    loss_fn = loss.SoftmaxCEMaskedLoss(sparse_label=False, from_logits=False)
    ce_loss = loss_fn(uniform_tensor, uniform_tensor, valid_lens)
    print(ce_loss)

This outputs:

[0.03258096 0.32580966 3.2580965 ]
<NDArray 3 @cpu(0)>

However, all three values should be 3.2580965 (i.e. ln(26)), since that is the average CE over the valid timesteps of each sequence.

The problem is that the averaging step of the CE loss is not aware of valid_length, so the zeros produced by masking are included in the mean: each sequence's loss is summed over its valid timesteps but divided by the full padded length. Compare with, e.g., Sockeye's implementation.
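Here is a minimal sketch of the intended computation, written with plain MXNet ops rather than GluonNLP internals (the per-position CE and the mask are rebuilt by hand, so this illustrates the arithmetic, not the class itself):

import mxnet as mx

BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE = 3, 100, 26
pred = 1. / VOCAB_SIZE * mx.nd.ones((BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE))
label = pred  # dense labels, same uniform distribution
valid_lens = mx.nd.array((1, 10, 100))

# Per-position CE. The softmax of a uniform row is again uniform, so every
# valid position contributes -sum_v (1/V) * log(1/V) = log(V) = log(26).
per_pos_ce = -(label * mx.nd.log_softmax(pred, axis=-1)).sum(axis=2)  # (batch, seq_len)

# Hand-built mask: 1 for valid positions, 0 for padding.
positions = mx.nd.arange(MAX_SEQ_LEN).reshape((1, -1))
mask = mx.nd.broadcast_lesser(positions, valid_lens.reshape((-1, 1)))

# Averaging over the valid positions only gives ln(26) for every sequence:
print((per_pos_ce * mask).sum(axis=1) / valid_lens)    # [3.2581 3.2581 3.2581]
# Averaging over the full padded length, as the current loss effectively does,
# reproduces the output above:
print((per_pos_ce * mask).sum(axis=1) / MAX_SEQ_LEN)   # [0.0326 0.3258 3.2581]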

One workaround is to scale the sample weights so that the subsequent mean compensates for the zeros, for example by adding the following line before the call to super in SoftmaxCEMaskedLoss:

sample_weight = sample_weight * (sample_weight.shape[1] / F.reshape(valid_length, shape=(-1, 1, 1)))
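As a quick numerical check of this reweighting, done outside the class on the same toy values (the per-position CE and the mask are built by hand as in the sketch above, so no GluonNLP internals are assumed):

import mxnet as mx

MAX_SEQ_LEN = 100
valid_lens = mx.nd.array((1, 10, 100))
# Per-position CE for the uniform toy example: ln(26) at every position.
per_pos_ce = 3.2580965 * mx.nd.ones((3, MAX_SEQ_LEN))
positions = mx.nd.arange(MAX_SEQ_LEN).reshape((1, -1))
mask = mx.nd.broadcast_lesser(positions, valid_lens.reshape((-1, 1)))

# Mean over the full padded length reproduces the values reported above.
print((per_pos_ce * mask).mean(axis=1))                          # [0.0326 0.3258 3.2581]

# Scaling each weight by MAX_SEQ_LEN / valid_len, mirroring the proposed change,
# makes the padded-length mean coincide with the mean over valid positions.
reweighted_mask = mask * (MAX_SEQ_LEN / valid_lens.reshape((-1, 1)))
print((per_pos_ce * reweighted_mask).mean(axis=1))               # [3.2581 3.2581 3.2581]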

I have not evaluated how this affects current MT models. This causes short sequences to have smaller per-token loss, weighting long sequences more.

However, perhaps we should edit SoftmaxCEMaskedLoss for correctness (or someone can point out if I'm mistaken); I can submit a PR. The CE loss of a sequence should not depend on the batch it is in.

@szhengac
Member

szhengac commented Jul 31, 2018

@JulianSlzr, in the current nmt scripts, the token loss is averaged over the total number of tokens in the mini-batch. You can think of it as concatenating all the sequences in the mini-batch into a single sequence and taking the average. You can find it in the following code:
loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
This ensures that all the tokens are equally weighted. In fact, we are minimizing the following objective:
-\sum_{i=1}^{N}\sum_{t=1}^{T_i}\log P(y_t^i \mid y_1^i, \ldots, y_{t-1}^i, X^i, \Theta)
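A small numeric sketch of this equivalence (the per-token CE values below are made up, the mask is built by hand, and the -1 offsets for the shifted target in the quoted line are dropped for simplicity): rescaling the per-sequence losses as above and averaging over the batch gives the same number as a single token-level mean over the concatenated mini-batch.

import mxnet as mx

MAX_SEQ_LEN = 100
valid = mx.nd.array((1, 10, 100))
# Made-up per-token CE values for a padded batch of 3 target sequences.
ce = mx.nd.random.uniform(shape=(3, MAX_SEQ_LEN))
positions = mx.nd.arange(MAX_SEQ_LEN).reshape((1, -1))
mask = mx.nd.broadcast_lesser(positions, valid.reshape((-1, 1)))

# Per-sequence loss as reported in the issue: zeros included in the mean over the padded length.
seq_loss = (ce * mask).mean(axis=1)

# Rescale as in the quoted line (without the -1 offsets), then average over the batch.
batch_loss = (seq_loss * MAX_SEQ_LEN / valid.mean()).mean()

# Token-level average over the whole mini-batch ("concatenate and average").
token_avg = (ce * mask).sum() / valid.sum()

print(batch_loss.asscalar(), token_avg.asscalar())   # the two numbers agree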
But sure, what you propose can be useful in some scenarios. A PR is always welcome.
