Fix uniform returning end point for BFloat16 and Half (#96962)
Fixes #96947

If we generate `1.0 - float_eps`, the BFloat16 and Half constructors will round this to 1.0 which is outside of the half-open range. Instead, we delay the bounds change until after the value has been rounded.
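
A small illustration (not part of the commit) of the rounding problem: a float32 value just below 1.0 rounds back up to 1.0 when cast to BFloat16 or Half, because those types carry far fewer mantissa bits.

```python
import torch

eps = torch.finfo(torch.float32).eps
x = torch.tensor(1.0 - eps)            # float32 value just below 1.0, as in the commit message
print(x.to(torch.bfloat16).item())     # 1.0 -- rounds up to the end point
print(x.to(torch.float16).item())      # 1.0 -- same for Half
```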

Pull Request resolved: pytorch/pytorch#96962
Approved by: https://github.com/lezcano, https://github.com/ngimel
peterbell10 authored and cyyever committed Mar 27, 2023
1 parent df152f4 commit c23608d
Showing 2 changed files with 11 additions and 7 deletions.
16 changes: 10 additions & 6 deletions aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -3,6 +3,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
 #include <ATen/ExpandBase.h>
+#include <ATen/OpMathType.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <c10/util/Half.h>
@@ -458,19 +459,22 @@ void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen)
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
     auto from = static_cast<scalar_t>(from_);
     auto to = static_cast<scalar_t>(to_);
-    using accscalar_t = at::acc_type<scalar_t, true>;
-    auto range = static_cast<accscalar_t>(to-from);
+    using opmath_t = at::opmath_type<scalar_t>;
+    auto range = static_cast<opmath_t>(to-from);
     // define lambda to reverse bounds, multiply 'range' and add 'from_'
-    auto uniform_func = [range, from] __device__ (accscalar_t rand) {
+    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
+      // Compute output value before reversing the bounds
+      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
+      auto value = static_cast<scalar_t>(rand * range + from);
       // reverse the bounds of curand4 from (0, 1] to [0, 1)
       // Note that this method is from legacy THCTensorRandom and is likely to give
       // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
       // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
       // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
-      auto reverse_bound_rand = rand == static_cast<accscalar_t>(1.0) ? static_cast<accscalar_t>(0.0) : rand;
-      return static_cast<scalar_t>(reverse_bound_rand * range + from);
+      auto reverse_bound_value = value == to ? from : value;
+      return reverse_bound_value;
     };
-    uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, uniform_func);
+    uniform_and_transform<scalar_t, opmath_t, curand4_engine_calls>(iter, gen, uniform_func);
   });
 }
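
For context, a hedged repro sketch (not part of the commit) of the behavior this kernel change addresses; it assumes a CUDA device and uses a large sample count only so that the end point is very likely to appear when the bug is present.

```python
import torch

# uniform_(0, 1) should draw from the half-open range [0, 1); before this fix,
# BFloat16/Half on CUDA could round a draw up to exactly 1.0.
x = torch.empty(10_000_000, dtype=torch.bfloat16, device="cuda").uniform_(0, 1)
print((x == 1.0).any().item())  # typically True before this fix, False after
```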

2 changes: 1 addition & 1 deletion test/test_transformers.py
@@ -1703,7 +1703,7 @@ def test_flash_attention_vs_math_ref_grads(self, batch_size: int, seq_len_q: int
 
         # TODO: Investigate why grad_q needs larger tolerances
         grad_q_deviation = query_ref.grad - query_ref_lp.grad
-        grad_q_ref_atol = max(2 * torch.abs(grad_q_deviation).max().item(), default_atol[out.dtype])
+        grad_q_ref_atol = max(4 * torch.abs(grad_q_deviation).max().item(), default_atol[out.dtype])
         grad_q_ref_rtol = max(get_rtol(query_ref.grad, query_ref_lp.grad), default_rtol[out.dtype])
 
         grad_k_deviation = key_ref.grad - key_ref_lp.grad