Migrate index_add cpu from TH to ATen (pytorch#28421)
Summary:
Migrate index_add cpu from TH to ATen.

I couldn't find a replacement for get1d and set1d, so the element access is done with manual pointer arithmetic instead.
Pull Request resolved: pytorch#28421

Test Plan: existing tests

Differential Revision: D18060971

Pulled By: ggoossen

fbshipit-source-id: 413719990cdb2fe578964cde14e93577e48a4342
ggoossen authored and facebook-github-bot committed Nov 22, 2019
1 parent 183aa15 commit faacbfa
Showing 8 changed files with 91 additions and 68 deletions.
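
As background for the diffs below, here is a minimal sketch of the behaviour the migrated operator has to preserve (illustrative only, not part of the commit): index_add_ accumulates slices of source into self at the positions given by index along dim, and the new ATen checks require index to be an int64 vector.

# Illustrative sketch, not from this commit: reference semantics of index_add_.
import torch

dest = torch.zeros(5, 3)
source = torch.ones(3, 3)
index = torch.tensor([0, 4, 0])    # int64 indices; a repeated index accumulates

dest.index_add_(0, index, source)  # dest[index[i]] += source[i] along dim 0

# Check against an explicit loop.
expected = torch.zeros(5, 3)
for i in range(index.numel()):
    expected[index[i]] += source[i]
assert torch.equal(dest, expected)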
2 changes: 2 additions & 0 deletions aten/src/ATen/Declarations.cwrap
@@ -215,6 +215,8 @@
name: _th_index_add_
cname: indexAdd
variants: function
backends:
- CUDA
return: argument 0
arguments:
- THTensor* self
64 changes: 64 additions & 0 deletions aten/src/ATen/native/Indexing.cpp
@@ -55,6 +55,7 @@
#include <ATen/NativeFunctions.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/core/EnableNamedTensor.h>

#include <algorithm>
@@ -314,6 +315,69 @@ Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const
return self.clone(at::MemoryFormat::Preserve).index_copy_(dim, index, source);
}


Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  dim = maybe_wrap_dim(dim, self.dim());

  auto numel = index.numel();
  TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
  TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index");
  TORCH_CHECK(self.scalar_type() == source.scalar_type(),
              "index_add_(): self and source must have the same scalar type");
  TORCH_CHECK(dim == 0 || dim < source.dim(),
              "index_add_(): Indexing dim ", dim, " is out of bounds of tensor");
  TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)),
              "index_add_(): Number of indices should be equal to self.size(dim)");

  auto index_contig = index.contiguous();
  auto index_data = index_contig.data_ptr<int64_t>();

  if (self.dim() > 1) {
    // Equivalent to:
    // for (auto i = 0; i < numel; i++) {
    //   auto selfSlice = self.select(dim, index_data[i]);
    //   auto sourceSlice = source.select(dim, i);
    //   selfSlice.add_(sourceSlice);
    // }
    // But much faster as this reuses the iterator from add_
    if (numel == 0) {
      return self;
    }
    auto selfSlice = self.select(dim, 0);
    auto sourceSlice = source.select(dim, 0);
    auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type());
    auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
    auto self_dim_size = self.size(dim);
    auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice);

    for (auto i = 0; i < numel; i++) {
      auto self_i = index_data[i];
      TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
      auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
      auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
      iter.unsafe_replace_operand(0, self_data);
      iter.unsafe_replace_operand(1, self_data);
      iter.unsafe_replace_operand(2, source_data);
      add_stub(iter.device_type(), iter, 1);
    }
  }
  else {
    TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");

    AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&] {
      auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
      auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
      for (auto i = 0; i < numel; i++) {
        auto self_i = index_data[i];
        TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self");
        scalar_t *self_ip = self.data<scalar_t>() + self_i * self_stride;
        *self_ip += *(source.data<scalar_t>() + i * source_stride);
      }
    });
  }
  return self;
}

Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone(at::MemoryFormat::Preserve).index_add_(dim, index, source);
}
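
A note on the self.dim() > 1 fast path above: it is equivalent to the slice-wise loop in the comment, but it builds one TensorIterator for add_ and then only swaps the operand data pointers per index through unsafe_replace_operand. Below is an illustrative sketch (not part of the diff) of that reference loop using the public Python API; the helper name index_add_reference is made up for the example.

# Illustrative sketch mirroring the commented-out equivalent loop above.
import torch

def index_add_reference(dest, dim, index, source):
    for i in range(index.numel()):
        dest_slice = dest.select(dim, int(index[i]))
        source_slice = source.select(dim, i)
        dest_slice.add_(source_slice)  # in-place add through the selected view
    return dest

dest = torch.randn(4, 6)
src = torch.randn(4, 3)
idx = torch.tensor([5, 0, 5])  # a repeated index exercises accumulation

out_fast = dest.clone().index_add_(1, idx, src)
out_ref = index_add_reference(dest.clone(), 1, idx, src)
assert torch.allclose(out_fast, out_ref)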
3 changes: 1 addition & 2 deletions aten/src/ATen/native/TensorIterator.cpp
@@ -630,9 +630,8 @@ void TensorIterator::remove_operand(int arg) {
operands_.erase(operands_.begin() + arg);
}

void TensorIterator::replace_operand(int arg, void* data, IntArrayRef stride) {
void TensorIterator::unsafe_replace_operand(int arg, void* data) {
operands_[arg].data = data;
operands_[arg].stride_bytes = stride;
}

void TensorIterator::remove_dimension(int dim) {
6 changes: 4 additions & 2 deletions aten/src/ATen/native/TensorIterator.h
@@ -238,8 +238,10 @@ struct CAFFE2_API TensorIterator {
void narrow(int dim, int64_t start, int64_t size);
/// Narrows every dim after and including `start_dim` to size one.
void select_all_keeping_dim(int start_dim, IntArrayRef starts);
/// Replaces the data pointer and strides for the operand at index `arg`
void replace_operand(int arg, void* data, IntArrayRef stride);
/// Replaces the data pointer for the operand at index `arg`.
/// The new pointer should have the same sizes, strides and dtype as the
/// original
void unsafe_replace_operand(int arg, void* data);

/// Splits this TensorIterator into two iterators. Together they iterate over
/// the entire operation. Used by `with_32bit_indexing()`.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3751,7 +3751,7 @@
- func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
variants: method
dispatch:
CPU: legacy::cpu::_th_index_add_
CPU: index_add_cpu_
CUDA: legacy::cuda::_th_index_add_

- func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
41 changes: 0 additions & 41 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -723,47 +723,6 @@ void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTen

#if !defined(TH_REAL_IS_BOOL)

void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
{
ptrdiff_t i, numel;
THTensor *tSlice, *sSlice;
int64_t *index_data;

numel = THLongTensor_nElement(index);
THArgCheck(THTensor_nDimensionLegacyNoScalars(index) == 1, 3, "Index is supposed to be a vector");
THArgCheck(dim < THTensor_nDimensionLegacyNoScalars(src), 4,"Indexing dim %d is out of bounds of tensor", dim);
THArgCheck(numel == THTensor_sizeLegacyNoScalars(src, dim),4,"Number of indices should be equal to source:size(dim)");

index = THLongTensor_newContiguous(index);
index_data = THLongTensor_data(index);

if (tensor->dim() > 1)
{
tSlice = THTensor_(new)();
sSlice = THTensor_(new)();

for (i=0; i<numel; i++)
{
THTensor_(select)(tSlice, tensor, dim, index_data[i]);
THTensor_(select)(sSlice, src, dim, i);
THTensor_(cadd)(tSlice, tSlice, 1.0, sSlice);
}

c10::raw::intrusive_ptr::decref(tSlice);
c10::raw::intrusive_ptr::decref(sSlice);
}
else
{
for (i=0; i<numel; i++)
{
THTensor_(set1d)(tensor,
index_data[i],
THTensor_(get1d)(src,i) + THTensor_(get1d)(tensor,index_data[i]));
}
}
THLongTensor_free(index);
}

accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
{
#ifdef BUILD_NAMEDTENSOR
2 changes: 0 additions & 2 deletions aten/src/TH/generic/THTensorMath.h
@@ -107,8 +107,6 @@ TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);

#if !defined(TH_REAL_IS_BOOL) /* non bool only part */

TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);

TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);

TH_API void THTensor_(cinv)(THTensor *self, THTensor *src);
39 changes: 19 additions & 20 deletions test/test_torch.py
@@ -2621,26 +2621,25 @@ def checkPartialAssign(index):
reference[0.0, :, 0.0] = 1

    def test_index_add(self):
        num_copy, num_dest = 3, 3
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] += src[i]
        self.assertEqual(dest, dest2)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = dest2[idx[i]] + src[i]
        self.assertEqual(dest, dest2)

# add coverage for issue with atomic add that appeared only for
        for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
            for other_sizes in ((), (4, 5)):
                num_copy, num_dest = 3, 3
                dest = torch.randn(num_dest, *other_sizes)
                if not dest_contig:
                    dest = torch.testing.make_non_contiguous(dest)
                src = torch.randn(num_copy, *other_sizes)
                if not src_contig:
                    src = torch.testing.make_non_contiguous(src)
                idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
                if not index_contig:
                    idx = torch.testing.make_non_contiguous(idx)
                dest2 = dest.clone()
                dest.index_add_(0, idx, src)
                for i in range(idx.size(0)):
                    dest2[idx[i]] += src[i]
                self.assertEqual(dest, dest2)

    # add coverage for issue with atomic add that appeared only for
    # specific dtypes on cuda:
    # https://github.com/pytorch/pytorch/issues/29153
    def test_index_add_all_dtypes(self):