diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index 65b8466284..8a7f978ba6 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -535,7 +535,68 @@ defmodule Nx.Defn.Grad do
     updates = Nx.reshape(g, {num_elements})
 
     # The intuition for this grad is that for each index taken, we'll
-    # add the corresponding result grad to the original
+    # add the corresponding result grad to the original
+    g =
+      t
+      |> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
+      |> Nx.scatter_add(indices, updates)
+
+    to_grad(t, g, cache)
+  end
+
+  defp grad(:take, [t, i, axis], _ans, g, cache) do
+    axes_range = 0..(Nx.rank(t) - 1)//1
+
+    indices_shape =
+      axes_range
+      |> Enum.flat_map(fn
+        ^axis -> Tuple.to_list(i.shape)
+        _ -> [1]
+      end)
+      |> List.to_tuple()
+
+    idx_tiling =
+      t.shape
+      |> Tuple.to_list()
+      |> Enum.with_index(fn
+        _x, ^axis ->
+          List.duplicate(1, Nx.rank(i))
+
+        x, _ ->
+          x
+      end)
+      |> List.flatten()
+
+    num_elements = Tuple.product(g.shape)
+
+    indices_for_axis =
+      i
+      |> Nx.reshape(indices_shape)
+      |> Nx.tile(idx_tiling)
+
+    axis_offset = Nx.rank(i) - 1
+
+    indices =
+      axes_range
+      |> Enum.map(fn
+        ^axis ->
+          indices_for_axis
+          |> Nx.reshape({num_elements, 1})
+
+        current when current < axis ->
+          indices_for_axis
+          |> Nx.iota(axis: current, backend: Nx.Defn.Expr)
+          |> Nx.reshape({num_elements, 1})
+
+        current when current > axis ->
+          indices_for_axis
+          |> Nx.iota(axis: current + axis_offset, backend: Nx.Defn.Expr)
+          |> Nx.reshape({num_elements, 1})
+      end)
+      |> Nx.concatenate(axis: 1)
+
+    updates = Nx.reshape(g, {num_elements})
+
     g =
       t
       |> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs
index 5fd174e66b..7d31c69651 100644
--- a/nx/test/nx/defn/grad_test.exs
+++ b/nx/test/nx/defn/grad_test.exs
@@ -2598,6 +2598,140 @@ defmodule Nx.Defn.GradTest do
     end
   end
 
+  describe "take" do
+    defn grad_sum_take(t, i, axis \\ 0) do
+      grad(
+        t,
+        fn t ->
+          t
+          |> Nx.take(i, axis: axis)
+          |> Nx.sum()
+        end
+      )
+    end
+
+    defn grad_sum_take_axis_1_power(t, i) do
+      grad(
+        t,
+        fn t ->
+          t
+          |> Nx.power(2)
+          |> Nx.take(i, axis: 1)
+          |> Nx.sum()
+        end
+      )
+    end
+
+    defn grad_sum_log_power_take_axis_1_cos(t, i) do
+      grad(
+        t,
+        fn t ->
+          t
+          |> Nx.cos()
+          |> Nx.take(i, axis: 1)
+          |> Nx.power(2)
+          |> Nx.log()
+          |> Nx.sum()
+        end
+      )
+    end
+
+    test "computes gradient" do
+      assert Nx.tensor([
+               [2.0, 2.0, 2.0, 2.0],
+               [2.0, 2.0, 2.0, 2.0],
+               [6.0, 6.0, 6.0, 6.0]
+             ]) ==
+               grad_sum_take(
+                 Nx.tensor([
+                   [0, 1, 2, 3],
+                   [4, 5, 6, 7],
+                   [8, 9, 10, 11]
+                 ]),
+                 Nx.tensor([
+                   [0, 1, 2, 2, 2],
+                   [0, 1, 2, 2, 2]
+                 ])
+               )
+
+      assert Nx.tensor([
+               [0.0, 4.0, 24.0, 0.0],
+               [16.0, 20.0, 72.0, 0.0],
+               [32.0, 36.0, 120.0, 0.0]
+             ]) ==
+               grad_sum_take_axis_1_power(
+                 Nx.tensor([
+                   [0, 1, 2, 3],
+                   [4, 5, 6, 7],
+                   [8, 9, 10, 11]
+                 ]),
+                 Nx.tensor([
+                   [0, 1, 2, 2, 2],
+                   [0, 1, 2, 2, 2]
+                 ])
+               )
+
+      assert Nx.tensor([
+               [-0.0, -6.2296305, 26.220474, -0.0],
+               [-4.631285, 13.522059, 3.4920743, -0.0],
+               [27.198847, 1.8092626, -7.7803297, 0.0]
+             ]) ==
+               grad_sum_log_power_take_axis_1_cos(
+                 Nx.tensor([
+                   [0, 1, 2, 3],
+                   [4, 5, 6, 7],
+                   [8, 9, 10, 11]
+                 ]),
+                 Nx.tensor([
+                   [0, 1, 2, 2, 2],
+                   [0, 1, 2, 2, 2]
+                 ])
+               )
+    end
+
+    test "works with more dimensions" do
+      assert Nx.tensor([
+               [3.0, 3.0, 3.0, 3.0],
+               [3.0, 3.0, 3.0, 3.0],
+               [3.0, 3.0, 3.0, 3.0]
+             ]) ==
+               grad_sum_take(
+                 Nx.tensor([
+                   [0, 1, 2, 3],
+                   [4, 5, 6, 7],
+                   [8, 9, 10, 11]
+                 ]),
+                 Nx.tensor([
+                   [
+                     [[0], [1], [2]],
+                     [[2], [1], [0]],
+                     [[0], [1], [2]]
+                   ]
+                 ])
+               )
+
+      assert Nx.tensor([
+               [1.0, 2.0, 1.0, 0.0],
+               [1.0, 2.0, 1.0, 0.0],
+               [1.0, 2.0, 1.0, 0.0]
+             ]) ==
+               grad_sum_take(
+                 Nx.tensor([
+                   [0, 1, 2, 3],
+                   [4, 5, 6, 7],
+                   [8, 9, 10, 11]
+                 ]),
+                 Nx.tensor([
+                   [
+                     [[0], [1]],
+                     [[2], [1]]
+                   ]
+                 ]),
+                 1
+               )
+    end
+  end
+
   describe "not implemented" do
     defn grad_reduce(t), do: grad(t, &Nx.reduce(&1, 0, fn x, y -> x + y end))
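
Not part of the patch: a hand-specialized sketch of the scatter-index construction that the new grad(:take, ...) clause performs, fixed to t of shape {3, 4}, i of shape {2, 5} and axis: 1 so the intermediate shapes are concrete. Plain Nx calls stand in for the Expr-based ones used inside defn, Nx.scatter_add/3 is used as the diff uses it, and all tensors, shapes, and variable names here are illustrative only.

# Hand-specialized version of the scatter-index construction above for
# t: {3, 4}, i: {2, 5}, axis: 1. Values and shapes are illustrative only.
t = Nx.iota({3, 4})
i = Nx.tensor([[0, 1, 2, 2, 2], [0, 1, 2, 2, 2]])

# Upstream gradient of Nx.sum/1 over the take result, which has shape {3, 2, 5}.
g = Nx.broadcast(1.0, {3, 2, 5})
num_elements = 3 * 2 * 5

# indices_shape {1, 2, 5} and idx_tiling [3, 1, 1]: place i along the taken
# axis and repeat it across every other axis of the result.
indices_for_axis =
  i
  |> Nx.reshape({1, 2, 5})
  |> Nx.tile([3, 1, 1])

# Pair every taken element with the coordinate in t it came from: an iota
# supplies the axis-0 coordinate, indices_for_axis the taken axis-1 coordinate.
rows = Nx.iota({3, 2, 5}, axis: 0) |> Nx.reshape({num_elements, 1})
cols = Nx.reshape(indices_for_axis, {num_elements, 1})
indices = Nx.concatenate([rows, cols], axis: 1)

# Scatter the result gradient back onto a zero tensor shaped like t.
updates = Nx.reshape(g, {num_elements})
Nx.scatter_add(Nx.broadcast(0.0, Nx.shape(t)), indices, updates)
# => each row is [2.0, 2.0, 6.0, 0.0]: columns 0 and 1 are taken twice,
#    column 2 six times, column 3 never.

The concatenated indices map every element of the take result back to the element of t it was gathered from, so Nx.scatter_add accumulates each result gradient onto its source, which is the intuition stated in the comment in the clause above.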
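
Also not part of the patch: a minimal usage sketch showing how the new rule is exercised from user code, in the same spirit as the tests. The module name TakeGradExample and the literal tensors are made up for illustration.

defmodule TakeGradExample do
  import Nx.Defn

  # Gradient of sum(take(t, i, axis: 1)) with respect to t: each element of t
  # receives 1.0 for every time its column index appears in i.
  defn grad_sum_take(t, i) do
    grad(t, fn t ->
      t
      |> Nx.take(i, axis: 1)
      |> Nx.sum()
    end)
  end
end

t = Nx.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
i = Nx.tensor([0, 2, 2])
TakeGradExample.grad_sum_take(t, i)
# => [[1.0, 0.0, 2.0, 0.0],
#     [1.0, 0.0, 2.0, 0.0]]  (column 0 taken once, column 2 taken twice)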