From 7c36e062061161a778ce207b2a7564b778e1af7e Mon Sep 17 00:00:00 2001
From: Benjamin Philip
Date: Mon, 13 May 2024 14:53:03 +0530
Subject: [PATCH] Make take_along_axis an optional callback

Instead of requiring every backend to implement take_along_axis, define
it as an optional Nx.Backend callback: Nx.take_along_axis/3 now lowers
to gather/2 through Nx.Shared.optional/4, and the axis travels as a
keyword list to match the other optional callbacks. Backends with a
native kernel, such as Torchx, keep their own implementation, while
EXLA, the binary backend, and the defn expression and grad machinery
drop their dedicated take_along_axis code.

Closes #1440.
---
 exla/lib/exla/backend.ex     |  1 -
 exla/lib/exla/defn.ex        | 37 ------------------------------
 nx/lib/nx.ex                 | 25 ++++++++++++++++-----
 nx/lib/nx/backend.ex         |  5 +++--
 nx/lib/nx/binary_backend.ex  | 37 ------------------------------
 nx/lib/nx/defn/expr.ex       |  6 ------
 nx/lib/nx/defn/grad.ex       | 42 +-----------------------------------
 torchx/lib/torchx/backend.ex |  4 ++--
 8 files changed, 26 insertions(+), 131 deletions(-)

diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex
index 460ceea900..581ef613c5 100644
--- a/exla/lib/exla/backend.ex
+++ b/exla/lib/exla/backend.ex
@@ -325,7 +325,6 @@ defmodule EXLA.Backend do
       {:reverse, [:tensor, :axes], [:tensor]},
       {:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]},
       {:clip, [:tensor, :min, :max], [:tensor, :min, :max]},
-      {:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]},
       {:gather, [:input, :indices, :opts], [:input, :indices]},
       {:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]},
       {:conv, [:tensor, :kernel, :opts], [:tensor, :kernel]},
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 943fcb9f29..310fa40463 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -1259,43 +1259,6 @@ defmodule EXLA.Defn do
     Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
   end
 
-  defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], ans, state) do
-    %{shape: indices_shape} = indices_typespec = Value.get_typespec(indices)
-    indices_rank = tuple_size(indices_shape)
-
-    axes_range = 0..(indices_rank - 1)//1
-
-    index_vector_dim = indices_rank
-    slice_sizes = List.duplicate(1, indices_rank)
-    offset_dims = []
-    collapsed_slice_dims = Enum.to_list(axes_range)
-    start_index_map = Enum.to_list(axes_range)
-
-    new_axis_typespec = Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, 1))
-
-    full_indices_typespec =
-      Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, indices_rank))
-
-    full_indices =
-      axes_range
-      |> Enum.map(fn
-        ^axis -> Value.reshape(indices, new_axis_typespec)
-        axis -> Value.iota(state.builder, axis, new_axis_typespec)
-      end)
-      |> Value.concatenate(indices_rank, full_indices_typespec)
-
-    Value.gather(
-      tensor,
-      full_indices,
-      index_vector_dim,
-      slice_sizes,
-      offset_dims,
-      collapsed_slice_dims,
-      start_index_map,
-      expr_to_typespec(ans)
-    )
-  end
-
   defp to_operator(:gather, [%Value{} = tensor, indices, opts], ans, _state) do
     axes = Keyword.fetch!(opts, :axes)
     tensor_shape = op_shape(tensor)
diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index add9fcd605..d134a69093 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -14142,7 +14142,7 @@ defmodule Nx do
         tensor
         |> gather(gather_indices, axes: [axis])
         |> transpose(axes: transpose_axes)
-        |> reshape(inner_shape, names: inner_names)
+        |> rename(inner_names)
       end)
     end
   end
@@ -14302,17 +14302,32 @@ defmodule Nx do
     end
 
     opts = keyword!(opts, axis: 0)
-
     tensor = devectorize(tensor, keep_names: false)
     indices = devectorize(indices, keep_names: false)
-
     offset = length(vectorized_axes)
     axis = Nx.Shape.normalize_axis(tensor.shape, opts[:axis], tensor.names, offset)
-
     shape = Nx.Shape.take_along_axis(tensor.shape, indices.shape, axis)
+    out = %{tensor | shape: shape}
 
-    result = impl!(tensor).take_along_axis(%{tensor | shape: shape}, tensor, indices, axis)
+    result =
+      Nx.Shared.optional(:take_along_axis, [tensor, indices, [axis: axis]], out, fn
+        tensor, indices, _opts ->
+          axes_range = axes(indices)
+          new_axis_shape = Tuple.append(shape(indices), 1)
+
+          full_indices =
+            axes_range
+            |> Enum.map(fn
+              ^axis -> reshape(indices, new_axis_shape)
+              axis -> iota(new_axis_shape, axis: axis)
+            end)
+            |> concatenate(axis: rank(indices))
+
+          tensor
+          |> gather(full_indices)
+          |> rename(tensor.names)
+      end)
 
     vectorize(result, vectorized_axes)
   end
 
diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex
index a3d9676ac8..0c2b573a99 100644
--- a/nx/lib/nx/backend.ex
+++ b/nx/lib/nx/backend.ex
@@ -73,7 +73,6 @@ defmodule Nx.Backend do
   @callback clip(out :: tensor, tensor, min :: tensor, max :: tensor) :: tensor
   @callback slice(out :: tensor, tensor, list, list, list) :: tensor
   @callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
-  @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
   @callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
   @callback concatenate(out :: tensor, tensor, axis) :: tensor
   @callback select(out :: tensor, tensor, tensor, tensor) :: tensor
@@ -159,6 +158,7 @@ defmodule Nx.Backend do
   @callback all_close(out :: tensor, tensor, tensor, keyword) :: tensor
   @callback top_k(out :: tensor, tensor, keyword) :: tensor
   @callback take(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
+  @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
 
   @optional_callbacks [
     optional: 3,
@@ -178,7 +178,8 @@ defmodule Nx.Backend do
     qr: 3,
     cholesky: 2,
     eigh: 3,
-    take: 4
+    take: 4,
+    take_along_axis: 4
   ]
 
   ## Inspect implementation
diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex
index a4f5de664a..f24c5e2dda 100644
--- a/nx/lib/nx/binary_backend.ex
+++ b/nx/lib/nx/binary_backend.ex
@@ -1939,43 +1939,6 @@ defmodule Nx.BinaryBackend do
     from_binary(out, data)
   end
 
-  @impl true
-  def take_along_axis(
-        %T{type: output_type} = output,
-        %T{shape: t_shape, type: {_, t_size} = t_type} = tensor,
-        %T{shape: idx_shape, type: {_, idx_size} = idx_type} = indices,
-        axis
-      ) do
-    permutation =
-      tensor
-      |> Nx.axes()
-      |> List.delete(axis)
-      |> List.insert_at(Nx.rank(tensor) - 1, axis)
-
-    inverse_permutation = inverse_permutation(permutation)
-    shape_list = Tuple.to_list(output.shape)
-    permuted_shape = permutation |> Enum.map(&Enum.at(shape_list, &1)) |> List.to_tuple()
-
-    t_view = tensor |> to_binary() |> aggregate_axes([axis], t_shape, t_size)
-    idx_view = indices |> to_binary() |> aggregate_axes([axis], idx_shape, idx_size)
-
-    [t_view, idx_view]
-    |> Enum.zip_with(fn [data_bin, idx_bin] ->
-      data = binary_to_list(data_bin, t_type)
-
-      binary_to_binary(idx_bin, idx_type, output_type, fn idx ->
-        if idx < 0 or idx >= elem(tensor.shape, axis) do
-          raise ArgumentError,
-                "index #{idx} is out of bounds for axis #{axis} in shape #{inspect(tensor.shape)}"
-        end
-
-        Enum.at(data, idx)
-      end)
-    end)
-    |> then(&from_binary(%{output | shape: permuted_shape}, &1))
-    |> then(&transpose(output, &1, inverse_permutation))
-  end
-
   @impl true
   def gather(out, tensor, indices, opts) do
     axes = opts[:axes]
diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex
index b5fb56bc72..f534c04bee 100644
--- a/nx/lib/nx/defn/expr.ex
+++ b/nx/lib/nx/defn/expr.ex
@@ -1183,12 +1183,6 @@ defmodule Nx.Defn.Expr do
     expr(out, context, :put_slice, [tensor, start, slice])
   end
 
-  @impl true
-  def take_along_axis(out, tensor, indices, axis) do
-    {[tensor, indices], context} = to_exprs([tensor, indices])
-    expr(out, context, :take_along_axis, [tensor, indices, axis])
-  end
-
   @impl true
   def gather(out, tensor, indices, opts) do
     {[tensor, indices], context} = to_exprs([tensor, indices])
diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index c04ed72679..7fab315a70 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -150,9 +150,6 @@ defmodule Nx.Defn.Grad do
   defp reduce_args(:put_slice, %{data: %{args: [arg, _, update | _]}}, acc, fun),
     do: fun.(arg, fun.(update, acc))
 
-  defp reduce_args(:take_along_axis, %{data: %{args: [arg | _]}}, acc, fun),
-    do: fun.(arg, acc)
-
   defp reduce_args(:gather, %{data: %{args: [arg | _]}}, acc, fun),
     do: fun.(arg, acc)
 
@@ -663,44 +660,6 @@ defmodule Nx.Defn.Grad do
     [{t, g}]
   end
 
-  defp grad(:take_along_axis, [t, i, axis], _ans, g) do
-    num_elements = i |> Nx.shape() |> Tuple.product()
-
-    # Convert `i`, the take_along_axis indices, to a list of
-    # fully qualified (i.e. [0, 2, 1] for a {_, _, _}-shaped tensor)
-    # indices
-
-    indices =
-      0..(Nx.rank(g) - 1)//1
-      |> Enum.map(fn
-        # For the axis of interest, we'll use the actual take_along_axis indices
-        ^axis ->
-          Nx.reshape(i, {num_elements, 1})
-
-        axis ->
-          i
-          |> Nx.shape()
-          |> Nx.iota(axis: axis)
-          |> Nx.reshape({num_elements, 1})
-      end)
-      |> Nx.concatenate(axis: 1)
-
-    # Since g is produced through the given indices,
-    # we can reshape g to be a {num_elements} shaped tensor
-    # which will directly correspond to each of the reshaped
-    # indices above
-    updates = Nx.reshape(g, {num_elements})
-
-    # The intuition for this grad is that for each index taken, we'll
-    # add the corresponding result grad to the original
-    g =
-      t
-      |> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
-      |> Nx.indexed_add(indices, updates)
-
-    [{t, g}]
-  end
-
   defp grad(:gather, [t, i, opts], _ans, g) do
     i_axes = opts[:axes]
     i_shape = i.shape
@@ -714,6 +673,7 @@ defmodule Nx.Defn.Grad do
 
     g =
       0
+      |> Nx.as_type(t.type)
       |> Nx.broadcast(t_shape)
       |> Nx.indexed_add(indices, updates, opts)
 
diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex
index 80b5127454..3c771f6fa7 100644
--- a/torchx/lib/torchx/backend.ex
+++ b/torchx/lib/torchx/backend.ex
@@ -425,12 +425,12 @@ defmodule Torchx.Backend do
   end
 
   @impl true
-  def take_along_axis(out, tensor, idx, axis) do
+  def take_along_axis(out, tensor, idx, opts) do
     idx_tx = idx |> from_nx() |> Torchx.to_type(:long)
 
     tensor
     |> from_nx()
-    |> Torchx.gather(idx_tx, axis)
+    |> Torchx.gather(idx_tx, opts[:axis])
     |> to_nx(out)
   end
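
Reviewer note (not part of the commit): a standalone sketch of what the
gather-based fallback above computes, using only the public Nx API. The
input values and variable names are illustrative assumptions, and the
Nx. prefixes are added because this runs outside the Nx module, where
the patch can call axes/1, iota/2, and friends bare.

    t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
    i = Nx.tensor([[2, 0, 1], [1, 1, 0]])
    axis = 1

    # Build one coordinate column per axis: the given indices along `axis`,
    # an iota along every other axis, concatenated into {2, 3, 2} coordinates,
    # so gather/2 reads t[r][i[r][c]] for every output position {r, c}.
    new_axis_shape = Tuple.append(Nx.shape(i), 1)

    full_indices =
      i
      |> Nx.axes()
      |> Enum.map(fn
        ^axis -> Nx.reshape(i, new_axis_shape)
        ax -> Nx.iota(new_axis_shape, axis: ax)
      end)
      |> Nx.concatenate(axis: Nx.rank(i))

    Nx.gather(t, full_indices)
    #=> [[3, 1, 2], [5, 5, 4]], the same as Nx.take_along_axis(t, i, axis: 1)

Torchx never reaches this path because it still exports take_along_axis/4
(now receiving opts instead of a bare axis), while EXLA and the binary
backend take the fallback and serve it through their native gather
implementations, which is why their handwritten lowerings could be removed.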