Unnecessary loss of precision when computing loss functions #872

Closed
ariasanovsky opened this issue Oct 3, 2023 · 2 comments
ariasanovsky commented Oct 3, 2023

When computing cross entropy,

pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    (probs * target_probs).mean().negate() / inv_last_axis_numel
}

of a tensor along an axis of length $a$, we calculate $1/a$ as an f64 and later use it to renormalize the mean() of the Hadamard product of probs and target_probs. This results in more instructions and an unnecessary loss of precision.

This could be replaced by

pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    logits: Tensor<S, E, D, T>,
    target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
    let last_axis_numel = <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
    let probs = logits.log_softmax::<S::LastAxis>();
    (probs * target_probs).mean().negate() * last_axis_numel
}

This pattern is repeated in kl_div_with_logits_loss.
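
As a standalone illustration of the precision point (plain f32, independent of dfdx; the variable names below are only for exposition), dividing by a precomputed reciprocal rounds twice (once when forming 1/a, once in the division), while multiplying by a rounds once:

fn main() {
    let last_axis_numel = 3.0_f32; // any non-power-of-two axis size
    let inv_last_axis_numel = 1.0 / last_axis_numel; // rounded once here

    // Stand-in for the negated mean of probs * target_probs.
    let mean = 0.1_f32;

    let via_div = mean / inv_last_axis_numel; // current form: rounds again
    let via_mul = mean * last_axis_numel;     // proposed form: one rounding

    println!("mean / (1/a) = {via_div:.9}");
    println!("mean * a     = {via_mul:.9}");
    // The two results may differ in the last bits; the multiply also avoids
    // computing the reciprocal on the host before the scalar op runs.
}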

ariasanovsky (Author) commented:
We may also be able to replace the call to mean() with a call to sum(), but it requires fiddling with the tensor dimensions. Using mean() and rescaling by the last-axis size computes $$\text{mean} \cdot \text{last axis numel} = \left(\dfrac{1}{|A||I|}\sum_{(a, i)\in A\times I}p_{a, i}\right)\cdot |A|,$$ whereas we could instead compute $$\text{sum} / |I| = \left(\sum_{(a, i)\in A\times I}p_{a, i}\right)\cdot |I|^{-1},$$ where $A$ is the last axis and $I$ ranges over the remaining axes.

I couldn't immediately see how to compute $|I|$ from the generic parameter S or its last axis using the existing trait bounds.
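
As a quick numeric check of the identity above (plain Rust, no dfdx types; the 2 x 3 array and names are purely illustrative), mean() times the last-axis size agrees with sum() divided by $|I|$:

fn main() {
    // A 2 x 3 "tensor": |I| = 2 rows, |A| = 3 columns along the last axis.
    let p: [[f64; 3]; 2] = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]];
    let (i_numel, last_axis_numel) = (2.0_f64, 3.0_f64);

    let sum: f64 = p.iter().flatten().sum();
    let mean = sum / (i_numel * last_axis_numel);

    // Current approach: take the mean over everything, then rescale by |A|.
    let via_mean = mean * last_axis_numel;
    // Alternative: sum everything, then divide by |I|.
    let via_sum = sum / i_numel;

    println!("mean * |A| = {via_mean}");
    println!("sum  / |I| = {via_sum}");
    // Both equal (1/|I|) * sum of all entries, up to floating-point rounding.
}

Here $|I|$ is just the total element count divided by the last-axis numel, which is the quantity the existing trait bounds don't seem to expose directly.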

coreylowman (Owner) commented:
This was done specifically for f16 support: the maximum value an f16 can store is 65504, and it has low precision for large values.

In the scalar version of div/mul, the operands are converted to the tensor's dtype before executing, so the f64 would be converted to an f16 before the op actually runs. By using 1 / a, we have a better chance that f16 can store the actual value.
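
A minimal sketch of that constraint, assuming the external half crate (which dfdx uses for its f16 dtype); the axis size of 100000 is hypothetical:

use half::f16;

fn main() {
    let last_axis_numel = 100_000.0_f64; // hypothetical large last axis

    // Converting the size itself overflows f16 (max finite value 65504),
    // so it saturates to infinity under the usual rounding rules...
    let numel_as_f16 = f16::from_f64(last_axis_numel);
    // ...while its reciprocal (1e-5) is still representable, albeit as a
    // low-precision subnormal.
    let inv_as_f16 = f16::from_f64(1.0 / last_axis_numel);

    println!("numel as f16     = {}", numel_as_f16.to_f64());
    println!("1 / numel as f16 = {}", inv_as_f16.to_f64());
}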
