Fix undefined behavior detected by clang-12 (pytorch#106354)
Adding a non-zero offset to a null pointer is undefined behavior, and relying on the compiler to tolerate it is a bad habit. This change removes several such cases.
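As a reminder, here is a minimal hypothetical C++ sketch of the pattern (not code from this PR): forming `base + n` is already undefined when `base` is null, even if the result is never dereferenced, so the fixes either guard the addition or return early.

```cpp
#include <cstddef>

// Hypothetical helper mirroring the guard used in the fixes below:
// `base + n` is undefined when `base == nullptr` and `n != 0`,
// even though the resulting pointer is never dereferenced.
template <typename T>
T* offset_or_null(T* base, std::ptrdiff_t n) {
  return base ? base + n : nullptr;  // instead of the UB-prone `return base + n;`
}
```

The individual fixes: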

- When `lapackEig` is called to estimate a workspace size, do not add the matrix size to the `W` pointer.
- When `unpack_pivots_cpu_kernel` is called with a zero `dim_size`, exit early.
- When `topk_impl_loop` is called with `k` equal to zero, exit right away, as the output tensors are empty anyway.
- Suppress the pointer-overflow check when a non-zero storage offset is added in `TensorImpl::data_ptr_impl_impl`, which can happen if a tensor is created as `torch.empty(3)[4:]` (see the libtorch sketch after this list).
- In `s_addmm_out_sparse_dense_worker`, do not call `axpy` over an empty vector.
- In `_sparse_binary_op_intersection_kernel_impl`, skip computing `ptr_indices_dim` when `sparse_dim` is zero.
- Exit the `grid_sample` forward/backward kernels early if either `input` or `grid` is an empty tensor.
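For the `TensorImpl` case, a hedged libtorch sketch (for illustration only, not part of the patch) of how a zero-element tensor can still carry a non-zero storage offset — the offset that `data_ptr_impl_impl` adds to the storage's base pointer:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Slicing past the end produces a zero-element view; per the commit
  // description it may still report a non-zero storage offset.
  auto t = torch::empty({3}).slice(/*dim=*/0, /*start=*/4);
  std::cout << "numel=" << t.numel()
            << " storage_offset=" << t.storage_offset() << "\n";
}
```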

Found by UBSan in clang-12.

Before the change, the UBSan report looked as follows:
```
 ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-12/bin/llvm-symbolizer UBSAN_OPTIONS=print_stacktrace=1 LD_PRELOAD=/usr/lib/llvm-12/lib/clang/12.0.1/lib/linux/libclang_rt.asan-x86_64.so python test_fx_experimental.py -v -k test_normalize_operator_exhaustive_linalg_eig_cpu_float32
Test results will be stored in test-reports/python-unittest/test_fx_experimental

Running tests...
----------------------------------------------------------------------
  test_normalize_operator_exhaustive_linalg_eig_cpu_float32 (__main__.TestNormalizeOperatorsCPU) ... /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:111: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:112: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:118: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:119: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
/var/lib/jenkins/workspace/aten/src/ATen/native/BatchLinearAlgebra.cpp:937:17: runtime error: applying non-zero offset 20 to null pointer
    #0 0x7f2025794888 in void at::native::lapackEig<float, float>(char, char, int, float*, int, float*, float*, int, float*, int, float*, int, float*, int*) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9945888)
    #1 0x7f20257da256 in void at::native::(anonymous namespace)::apply_linalg_eig<float>(at::Tensor&, at::Tensor&, at::Tensor&, at::Tensor&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x998b256)
    #2 0x7f20257d902d in at::native::(anonymous namespace)::linalg_eig_kernel(at::Tensor&, at::Tensor&, at::Tensor&, at::Tensor const&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x998a02d)
    #3 0x7f20257b5b3d in at::native::linalg_eig_out_info(at::Tensor const&, at::Tensor&, at::Tensor&, at::Tensor&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9966b3d)
    #4 0x7f20257b4770 in at::native::linalg_eig_out(at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9965770)
    #5 0x7f20280710e6 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor&, at::Tensor&> (at::Tensor const&, at::Tensor&, at::Tensor&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU_out_linalg_eig_out(at::Tensor const&, at::Tensor&, at::Tensor&))>, std::tuple<at::Tensor&, at::Tensor&>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor&, at::Tensor&> >, std::tuple<at::Tensor&, at::Tensor&> (at::Tensor const&, at::Tensor&, at::Tensor&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xc2220e6)
    #6 0x7f202727a045 in at::_ops::linalg_eig_out::call(at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb42b045)
    #7 0x7f20257b7e29 in at::native::linalg_eig(at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9968e29)
    #8 0x7f2028070bf0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU__linalg_eig(at::Tensor const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&> >, std::tuple<at::Tensor, at::Tensor> (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xc221bf0)
    #9 0x7f2026b1f787 in std::tuple<at::Tensor, at::Tensor> c10::Dispatcher::redispatch<std::tuple<at::Tensor, at::Tensor>, at::Tensor const&>(c10::TypedOperatorHandle<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&)> const&, c10::DispatchKeySet, at::Tensor const&) const (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xacd0787)
    #10 0x7f20273230a7 in at::_ops::linalg_eig::redispatch(c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb4d40a7)
    #11 0x7f202c3cc32d in torch::autograd::VariableType::(anonymous namespace)::linalg_eig(c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x1057d32d)
    #12 0x7f202c3cba96 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&), &(torch::autograd::VariableType::(anonymous namespace)::linalg_eig(c10::DispatchKeySet, at::Tensor const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, std::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x1057ca96)
    #13 0x7f20272798e0 in at::_ops::linalg_eig::call(at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb42a8e0)
    #14 0x7f2043d97ae3 in torch::autograd::THPVariable_linalg_eig(_object*, _object*, _object*) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_python.so+0x23feae3)
    #15 0x5072d6 in cfunction_call /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543:19
    ...

SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/native/BatchLinearAlgebra.cpp:937:17 in
```

Pull Request resolved: pytorch#106354
Approved by: https://github.com/huydhn, https://github.com/lezcano
malfet authored and pytorchmergebot committed Aug 3, 2023
1 parent 6e2a284 commit 97396cd
Showing 12 changed files with 57 additions and 21 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -925,7 +925,7 @@ template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int
// lapack [sd]geev wants to separate output arrays: wr and wi for the real
// and imaginary parts
double *wr = w;
double *wi = w + n;
double *wi = w ? w + n : nullptr;
(void)rwork; // unused
dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
}
@@ -934,7 +934,7 @@ template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int ld
// lapack [sd]geev wants to separate output arrays: wr and wi for the real
// and imaginary parts
float *wr = w;
float *wi = w + n;
float *wi = w ? w + n : nullptr;
(void)rwork; // unused
sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
}
9 changes: 4 additions & 5 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -16,7 +16,7 @@
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#endif
namespace at { namespace native {
namespace at::native {

namespace {
/*
@@ -1102,15 +1102,14 @@ void svd_kernel(const Tensor& A,
}

void unpack_pivots_cpu_kernel(TensorIterator& iter, const int64_t dim_size, const int64_t max_pivot) {
if (iter.numel() == 0) {
if (iter.numel() == 0 || dim_size == 0) {
return;
}
auto loop = [&](char* const* const data, const int64_t* const strides, const int64_t nelems) {
auto* perm_ptr = data[0];
const auto* pivots_ptr = data[1];

for (const auto elem : c10::irange(nelems)) {
(void)elem; //Suppress unused variable warning
for (C10_UNUSED const auto elem : c10::irange(nelems)) {
// WARNING: linalg.lu_factor returns int32 pivots,
// this behavior could change in the future.
const auto perm_data = reinterpret_cast<int64_t*>(perm_ptr);
@@ -1224,4 +1223,4 @@ REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
}} // namespace at::native
} // namespace at::native
14 changes: 14 additions & 0 deletions aten/src/ATen/native/GridSampler.cpp
@@ -57,6 +57,9 @@ namespace {
int64_t out_H = grid.size(2);
int64_t out_W = grid.size(3);
auto output = at::empty({N, C, out_D, out_H, out_W}, input.options());
if (output.numel() == 0) {
return output;
}
int64_t inp_sN = input.stride(0);
int64_t inp_sC = input.stride(1);
int64_t inp_sD = input.stride(2);
@@ -219,6 +222,10 @@ namespace {
}
})();
auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
if (grid.numel() == 0 || input.numel() == 0) {
grad_grid.zero_();
return std::make_tuple(grad_input, grad_grid);
}
// If interpolation mode is Nearest, then grad_grid is not filled in the
// loop below.
if (interpolation_mode == GridSamplerInterpolation::Nearest) {
@@ -567,6 +574,9 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
int64_t out_H = grid.size(1);
int64_t out_W = grid.size(2);
auto output = at::empty({N, C, out_H, out_W}, input.options());
if (output.numel() == 0) {
return output;
}
int64_t inp_sN = input.stride(0);
int64_t inp_sC = input.stride(1);
int64_t inp_sH = input.stride(2);
@@ -715,6 +725,10 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,

auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
if (grid.numel() == 0 || input.numel() == 0) {
grad_grid.zero_();
return std::make_tuple(grad_input, grad_grid);
}
// If interpolation mode is Nearest, then grad_grid is not filled in the
// loop below.
if (interpolation_mode == GridSamplerInterpolation::Nearest) {
9 changes: 5 additions & 4 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -857,7 +857,9 @@ TORCH_IMPL_FUNC(index_copy_out)
// Not calling into index_reduce_func_impl because of a different dtype dispatch
TORCH_IMPL_FUNC(index_add_cpu_out)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
if (!result.is_same(self)) result.copy_(self);
if (!result.is_same(self)) {
result.copy_(self);
}
auto numel = index.numel();

auto index_contig = index.contiguous();
@@ -870,7 +872,7 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
// selfSlice.add_(sourceSlice);
// }
// But much faster as this reuses the iterator from add_
if (numel == 0) {
if (numel == 0 || self.numel() == 0) {
return;
}

@@ -945,8 +947,7 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
add_stub(iter.device_type(), iter, alpha);
}
});
}
else {
} else {
TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")");

// explicitly capture all required variables to work around windows build
5 changes: 5 additions & 0 deletions aten/src/ATen/native/TopKImpl.h
@@ -22,6 +22,11 @@ void topk_impl_loop(
const bool sorted,
char** data, const int64_t* strides, const int64_t n) {

// If k is zero, then output values and indices are empty tensors
// So iterating over other dims is pointless
if (k == 0) {
return;
}
using elem_t = std::pair<accscalar_t, int64_t>;
std::vector<elem_t> queue(dim_size);
for (const auto i : c10::irange(n)) {
9 changes: 8 additions & 1 deletion aten/src/ATen/native/cpu/GridSamplerKernel.cpp
@@ -1157,13 +1157,16 @@ void grid_sampler_2d_cpu_kernel_impl(
auto spatial_size = H * W;
auto grain_size = spatial_size == 0 ? (N + 1)
: at::divup(at::internal::GRAIN_SIZE, spatial_size * 4 /* 2d * 2 tensors*/);
if (output.numel() == 0) {
return;
}

#define HANDLE_CASE(interp, padding, align_corners) \
case padding: { \
ApplyGridSample<scalar_t, 2, interp, padding, align_corners> \
grid_sample(inp_acc); \
parallel_for(0, N, grain_size, [&](int64_t begin, int64_t end) { \
for (const auto n : c10::irange(begin, end)) { \
for (const auto n : c10::irange(begin, end)) { \
auto out_slice = out_acc[n]; \
auto inp_slice = inp_acc[n]; \
grid_sample_2d_grid_slice_iterator( \
@@ -1220,6 +1223,10 @@ void grid_sampler_2d_backward_cpu_kernel_impl(
int64_t padding_mode,
bool align_corners,
std::array<bool,2> output_mask) {
if (grad_output_.numel() == 0) {
grad_grid.zero_();
return;
}
// grad_output should be contiguous most of time. Ensuring that it is
// contiguous can greatly simplify this code.
auto grad_output = grad_output_.contiguous();
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/SparseFactories.cpp
@@ -17,7 +17,7 @@ void _spdiags_kernel_cpu(
TensorBase& values,
TensorBase& indices) {
auto* row_index_write_ptr = indices.data_ptr<int64_t>();
auto* col_index_write_ptr = row_index_write_ptr + indices.stride(0);
auto* col_index_write_ptr = row_index_write_ptr ? row_index_write_ptr + indices.stride(0) : nullptr;
const int64_t diagonals_index_stride = diagonals.stride(0);
const int64_t diagonals_read_stride = diagonals.stride(1);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/group_norm_kernel.cpp
@@ -675,7 +675,7 @@ void GroupNormInputBackward(
const int64_t g = i % G;
const T_ACC* ds_ptr = ds + i * D;
const T_ACC* db_ptr = db + i * D;
const PT* gamma_ptr = gamma + g * D;
const PT* gamma_ptr = !gamma_null ? gamma + g * D : nullptr;
CalcDsDb(ds_ptr, db_ptr, gamma_null, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
T_ACC ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), T_ACC(0));
T_ACC db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), T_ACC(0));
11 changes: 5 additions & 6 deletions aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
@@ -276,7 +276,7 @@ void _sparse_binary_op_intersection_kernel_impl(
KernelLauncher::launch(iter,
// NOTE: capture by value required by CUDA
[=] FUNCAPI (index_t nnz_idx) -> int64_t {
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
const auto* RESTRICT ptr_indices_dim = ptr_indices ? ptr_indices + nnz_idx * indices_nnz_stride : nullptr;
int64_t hash = 0;
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
const auto dim_hash_coeff = hash_coeffs[dim];
@@ -299,8 +299,7 @@
// NOTE: argsort.dtype == nnz_arange.dtype
const auto argsort = nnz_arange.narrow(-1, 0, probably_coalesced._nnz());
return std::make_tuple(probably_coalesced_indices_hash, argsort);
}
else {
} else {
// NOTE: we want argsort.dtype == nnz_arange.dtype,
// but sort() produces indices of type int64_t,
// so we convert to nnz_arange.dtype to avoid issues
@@ -360,12 +359,12 @@ void _sparse_binary_op_intersection_kernel_impl(
KernelLauncher::launch(iter,
// NOTE: capture by value required by CUDA
[=] FUNCAPI (index_t nnz_idx) -> index_t {
// Compute hash value
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
int64_t hash = 0;
if (hash_ptr) {
hash = hash_ptr[nnz_idx];
} else {
} else if (sparse_dim) {
// Compute hash value
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
const auto dim_hash_coeff = hash_coeffs[dim];
const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
4 changes: 4 additions & 0 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -1222,6 +1222,10 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j,
int64_t row = indices_accessor[0][i];
int64_t col = indices_accessor[1][i];
if (col >= 0 && col < dim_j && row >= 0 && row < dim_i) {
// AXPY call is no-op over an empty vector
if (dim_k == 0) {
continue;
}
at::native::cpublas::axpy<scalar_t>(dim_k,
cast_alpha * val,
dense_ptr + col * dense_stride0, dense_stride1,
6 changes: 5 additions & 1 deletion c10/core/TensorImpl.h
@@ -1550,7 +1550,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// Shared implementation of mutable_data_ptr_impl() and the future
// mutable_data_ptr_impl().
template <typename T, typename Func>
T* data_ptr_impl_impl(const Func& get_data) const {
__ubsan_ignore_pointer_overflow__ T* data_ptr_impl_impl(
const Func& get_data) const {
if (C10_UNLIKELY(!has_storage())) {
throw_data_ptr_access_error();
}
@@ -1560,6 +1561,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
"Caffe2 uses a lazy allocation, so you will need to call "
"mutable_data() or raw_mutable_data() to actually allocate memory.");
// Caller does the type check.
// Note: storage_offset_ can be non-null even for zero-elements tensors
// (for example if created as `torch.empty(5)[10:]`) that triggers
// applying non-zero offset to null pointer in UBSan
return get_data() + storage_offset_;
}

3 changes: 3 additions & 0 deletions c10/macros/Macros.h
@@ -30,11 +30,14 @@
#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined")))
#define __ubsan_ignore_signed_int_overflow__ \
__attribute__((no_sanitize("signed-integer-overflow")))
#define __ubsan_ignore_pointer_overflow__ \
__attribute__((no_sanitize("pointer-overflow")))
#define __ubsan_ignore_function__ __attribute__((no_sanitize("function")))
#else
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_undefined__
#define __ubsan_ignore_signed_int_overflow__
#define __ubsan_ignore_pointer_overflow__
#define __ubsan_ignore_function__
#endif

