When computing cross entropy of a tensor along an axis of length $a$, we calculate $1/a$ as an `f64` and later use it to renormalize the `mean()` values from the Hadamard product of `probs` and `target_probs`. This results in more instructions and an unnecessary loss of precision. The same pattern is repeated in `kl_div_with_logits_loss`.
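As a rough standalone sketch of the pattern being described (not the actual library code; the flattened `(|I|, a)` layout, the names, and the final negation are my assumptions):

```rust
/// Hypothetical, simplified version of the pattern described above:
/// `probs` and `target_probs` are flattened `(|I|, a)` tensors, where `a`
/// is the length of the last axis.
fn cross_entropy_via_mean(probs: &[f32], target_probs: &[f32], a: usize) -> f32 {
    // Reciprocal of the last-axis length, computed as an f64 scalar.
    let inv_a = 1.0_f64 / a as f64;
    // Mean of the Hadamard product over *all* elements.
    let mean = probs
        .iter()
        .zip(target_probs)
        .map(|(p, t)| p * t)
        .sum::<f32>()
        / probs.len() as f32;
    // Renormalize the mean by dividing by 1/a (i.e. multiplying by a again).
    -(mean / inv_a as f32)
}
```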
We may also be able to replace the call to `mean()` with a call to `sum()`, but that requires fiddling with the tensor dimensions. Using `mean()` and rescaling by the last axis size calculates a sum of the form $$\text{mean} \cdot |A| = \left(\dfrac{1}{|A||I|}\sum_{(a, i)\in A\times I}p_{a, i}\right)\cdot |A|,$$ whereas we could instead compute $$\text{sum} / |I| = \left(\sum_{(a, i)\in A\times I}p_{a, i}\right)\cdot |I|^{-1},$$ where $A$ is the last axis and $I$ ranges over the remaining axes.
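A quick numeric sanity check of that identity (the shape and values here are made up purely for illustration):

```rust
fn main() {
    // p has shape (|I|, |A|) = (2, 3): two rows, last axis of length 3.
    let p = [[0.1_f64, 0.2, 0.3], [0.4, 0.5, 0.6]];
    let (num_rows, last_axis) = (p.len() as f64, p[0].len() as f64);
    let total: f64 = p.iter().flatten().sum();

    let mean_rescaled = (total / (num_rows * last_axis)) * last_axis; // mean * |A|
    let sum_rescaled = total / num_rows; // sum / |I|

    // Both formulations give the same value; the sum-based one skips the
    // divide-then-multiply by |A|.
    assert!((mean_rescaled - sum_rescaled).abs() < 1e-12);
    println!("mean * |A| = {mean_rescaled}, sum / |I| = {sum_rescaled}");
}
```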
I couldn't immediately see how to compute $|I|$ (or its axes) from the generic parameter `S` using the existing trait bounds.
This was done specifically for `f16` support: notably, the maximum value an `f16` can store is 65504, and it has low precision for large values.
In the scalar version of div/mul, the operands are converted to the tensor's dtype before executing, so the `f64` would be converted to an `f16` before the operation actually runs. By using $1/a$, we have a better chance that an `f16` can store the actual value.
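To make that concrete, here is a small demonstration using the `half` crate (the crate choice and the example axis lengths are mine, not from this issue): the axis length itself quickly loses precision or overflows in `f16`, while its reciprocal remains storable.

```rust
use half::f16;

fn main() {
    for a in [4097.0_f64, 100_000.0] {
        // Converting the axis length itself to f16 loses precision (integers
        // above 2048 are no longer exactly representable) or overflows past
        // 65504 to infinity.
        let a_f16 = f16::from_f64(a);
        // Converting the reciprocal 1/a instead keeps the value representable
        // in f16 (for very long axes, possibly as a subnormal).
        let inv_f16 = f16::from_f64(1.0 / a);
        println!(
            "a = {a}: f16(a) = {}, f16(1/a) = {:e} (exact 1/a = {:e})",
            a_f16.to_f64(),
            inv_f16.to_f64(),
            1.0 / a
        );
    }
}
```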