diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 3bb478258d..fece273800 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1141,8 +1141,7 @@ defmodule EXLA.Defn do end batch_size = tensor_rank - length(axes) - offset_size = indices_rank - length(axes) - offset_dims = count_up(batch_size, offset_size) + offset_dims = count_up(batch_size, index_vector_dim) Value.gather( tensor, diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 7f99c8bc6f..e917c68f72 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -225,6 +225,42 @@ defmodule EXLA.BackendTest do "1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i" end + test "gather vectorized regression" do + gradients = + Nx.tensor( + [ + [1.0, 1.0], + [-1.0, 1.0], + [1.0, -1.0], + [-1.0, -1.0] + ], + backend: EXLA.Backend + ) + + i = + Nx.tensor([[0, 2, 3, 2, 2, 2, 2, 1]], type: {:u, 16}, backend: EXLA.Backend) + |> Nx.vectorize([:x, :octaves]) + + result = Nx.gather(gradients, Nx.reshape(i, {1})) + + assert_equal( + result, + Nx.tensor([ + [ + [1.0, 1.0], + [1.0, -1.0], + [-1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0], + [-1.0, 1.0] + ] + ]) + |> Nx.vectorize([:x, :octaves]) + ) + end + describe "quantized types" do test "s2" do tensor = Nx.s2(-1)