Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions onnxruntime/core/providers/cpu/tensor/scatter_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
element_counts[i] = input_strides[i];
}

int64_t err_indice = 0;
p.element_bytes = input_tensor->DataType()->Size();
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
Expand All @@ -150,13 +149,23 @@ Status ScatterNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) co
for (int64_t i = 0; i < offset_count; ++i) {
for (int64_t j = 0; j < last_indice_dimension; ++j) {
auto indice = *(indice_offset + i * last_indice_dimension + j);
if (indice < 0 || indice >= input_shape[j]) {
err_indice = indice;

if (indice >= 0) {
if (indice >= input_shape[j]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice);
}
} else {
if (indice < -input_shape[j]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", indice);
} else {
indice += input_shape[j];
}
}

p.element_offsets[i] += indice * element_counts[j];
}
}
return err_indice == 0 ? Status::OK() : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
return Status::OK();
}

Status ScatterND::Compute(OpKernelContext* context) const {
Expand Down
18 changes: 13 additions & 5 deletions onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,19 @@ __global__ void _ScatterNDKernel(
// This would have been an error in the CPU kernel, but throwing in the CUDA EP
// is hard. This is the approach taken by other frameworks for out of bound indices
// in their corresponding GPU backends as well.
if (index < 0)
index = 0;

else if (index >= dim_value)
index = dim_value - 1;
// index >= -dim_value && index < dim_value

if (index >= 0) {
if (index >= dim_value) {
index = dim_value - 1;
}
} else {
if (index < -dim_value) {
index = 0;
} else {
index += dim_value;
}
}

data_offset += (index * element_count_dim);
}
Expand Down
31 changes: 28 additions & 3 deletions onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64) {
test.Run();
}

TEST(ScatterNDOpTest, ScatterND_matrice_int64_int64_neg_indices) {
OpTester test("ScatterND", 11);
test.AddInput<int64_t> ("data", {2,2}, {1LL,1LL,2LL,2LL});
test.AddInput<int64_t> ("indices", {2,2}, {0LL,0LL,-1LL,-1LL});
test.AddInput<int64_t>("updates", {2}, {0LL,3LL});
test.AddOutput<int64_t>("output", {2,2}, {0LL,1LL,2LL,3LL});
test.Run();
}

TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) {
OpTester test1("ScatterND", 11);
test1.AddInput<std::string>("data", {2,2,2}, {"egg","dance","bob","air","smart","terry","laugh","kite"});
Expand All @@ -55,6 +64,22 @@ TEST(ScatterNDOpTest, ScatterND_matrice_string_int64) {
test2.Run();
}

TEST(ScatterNDOpTest, ScatterND_matrice_string_int64_neg_indices) {
OpTester test1("ScatterND", 11);
test1.AddInput<std::string>("data", {2,2,2}, {"egg","dance","bob","air","smart","terry","laugh","kite"});
test1.AddInput<int64_t>("indices", {2,1,2}, {0,-1,-1,0});
test1.AddInput<std::string>("updates", {2,1,2}, {"air","bob","terry","smart"});
test1.AddOutput<std::string>("output", {2,2,2}, {"egg","dance","air","bob","terry","smart","laugh","kite"});
test1.Run();

OpTester test2("ScatterND", 11);
test2.AddInput<std::string>("data", {3,3}, {"egg","","air","","terry","smart","laugh","","hop"});
test2.AddInput<int64_t>("indices", {3,2}, {-1,-2,1,0,0,-2});
test2.AddInput<std::string>("updates", {3}, {"kite","bob","dance"});
test2.AddOutput<std::string>("output", {3,3}, {"egg","dance","air","bob","terry","smart","laugh","kite","hop"});
test2.Run();
}

TEST(ScatterNDOpTest, ScatterND_slice_float_int64_t) {
OpTester test("ScatterND", 11);
test.AddInput<float>("data", {2,2}, {0.0f,0.1f,0.1f,0.1f});
Expand All @@ -76,14 +101,14 @@ TEST(ScatterNDOpTest, ScatterND_slice_double_int64_t) {
TEST(ScatterNDOpTest, ScatterND_3tensor_int64) {
OpTester test1("ScatterND", 11);
test1.AddInput<int64_t>("data", {2,2,2}, {0LL,1LL,1LL,1LL,1LL,1LL,6LL,7LL});
test1.AddInput<int64_t>("indices", {2,2}, {0LL,1LL,1LL,0LL});
test1.AddInput<int64_t>("indices", {2,2}, {0LL,1LL,-1LL,0LL});
test1.AddInput<int64_t>("updates", {2,2}, {2LL,3LL,4LL,5LL});
test1.AddOutput<int64_t>("output", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL});
test1.Run();

OpTester test2("ScatterND", 11);
test2.AddInput<int8_t>("data", {2,2,2}, {0,0,2,3,4,0,6,7});
test2.AddInput<int64_t>("indices", {2,3}, {0,0,1,1,0,1});
test2.AddInput<int64_t>("indices", {2,3}, {0,0,1,-1,0,-1});
test2.AddInput<int8_t>("updates", {2}, {1,5});
test2.AddOutput<int8_t>("output", {2,2,2}, {0,1,2,3,4,5,6,7});
test2.Run();
Expand Down Expand Up @@ -142,7 +167,7 @@ TEST(ScatterNDOpTest, ScatterND_batched_3tensor_int64) {

OpTester test2("ScatterND", 11);
test2.AddInput<uint32_t>("data", {2,2,2}, {0,0,2,0,4,0,0,7});
test2.AddInput<int64_t>("indices", {2,2,3}, {0,0,1,1,0,1,0,1,1,1,1,0});
test2.AddInput<int64_t>("indices", {2,2,3}, {0,0,-1,-1,0,-1,0,1,-1,1,-1,0});
test2.AddInput<uint32_t>("updates", {2,2}, {1,5,3,6});
test2.AddOutput<uint32_t>("output", {2,2,2}, {0,1,2,3,4,5,6,7});
test2.Run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,34 @@ def run_step(model, x):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)

@pytest.mark.parametrize("device", ['cpu', 'cuda'])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a C++ unit test as well.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]],
[[ 2, 3, 4, 4],[ 0, 1, 4, 4]]))
def test_scatternd_correctness(device, indices):
class NeuralNetScatterND(torch.nn.Module):
def __init__(self):
super(NeuralNetScatterND, self).__init__()

def forward(self, rerouted_output, dispatch_mask, expert_output):
rerouted_output[dispatch_mask] = expert_output
return rerouted_output

pt_model = NeuralNetScatterND().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))

def run_step(model, rerouted_output, dispatch_mask, expert_output):
prediction = model(rerouted_output, dispatch_mask, expert_output)
return prediction

rerouted_output = torch.tensor([[0.],[0.],[0.],[0.],[0.]], device=device)
dispatch_mask = torch.tensor(indices, device=device)
expert_output = torch.tensor([[[0.3817],[0.9625],[0.9625],[0.9625]],[[0.3817],[0.9625],[0.9625],[0.9625]]], device=device)

pt_prediction = run_step(pt_model, rerouted_output, dispatch_mask, expert_output)
ort_prediction = run_step(ort_model, rerouted_output, dispatch_mask, expert_output)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5)


@pytest.mark.parametrize("use_fp16", [False, True])
@pytest.mark.parametrize("input_requires_grad", [False, True])
def test_gradient_correctness_conv1d(use_fp16, input_requires_grad):
Expand Down