feat: add take back
polvalente committed Sep 20, 2021
1 parent e69ebbf commit eb2c21c
Showing 2 changed files with 196 additions and 1 deletion.
nx/lib/nx/defn/grad.ex (63 changes: 62 additions & 1 deletion)
@@ -535,7 +535,68 @@ defmodule Nx.Defn.Grad do
    updates = Nx.reshape(g, {num_elements})

    # The intuition for this grad is that for each index taken, we'll
    # add the corresponding result grad to the original tensor
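    #
    # Illustrative example (editor's sketch, not part of the original
    # comment): if index 2 is taken three times, position 2 of the input
    # accumulates the sum of the three corresponding output grads,
    # because Nx.scatter_add adds updates that land on repeated indices.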
    g =
      t
      |> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
      |> Nx.scatter_add(indices, updates)

    to_grad(t, g, cache)
  end

  defp grad(:take, [t, i, axis], _ans, g, cache) do
    axes_range = 0..(Nx.rank(t) - 1)//1

    indices_shape =
      axes_range
      |> Enum.flat_map(fn
        ^axis -> Tuple.to_list(i.shape)
        _ -> [1]
      end)
      |> List.to_tuple()
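
    # Editor's sketch (hypothetical values, not in the original commit):
    # for t of shape {3, 4}, axis = 1 and i of shape {2, 5}, axes_range
    # is 0..1 and indices_shape is {1, 2, 5}: the indexed axis is
    # replaced by the shape of i while every other axis collapses to 1.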

    idx_tiling =
      t.shape
      |> Tuple.to_list()
      |> Enum.with_index(fn
        _x, ^axis ->
          List.duplicate(1, Nx.rank(i))

        x, _ ->
          x
      end)
      |> List.flatten()
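
    # Continuing the sketch: with t.shape = {3, 4}, axis = 1 and
    # Nx.rank(i) = 2, idx_tiling is [3, 1, 1], so the indices reshaped
    # to {1, 2, 5} below are tiled to {3, 2, 5}, i.e. one copy of i for
    # each position along every non-indexed axis of t.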

    num_elements = Tuple.product(g.shape)

    indices_for_axis =
      i
      |> Nx.reshape(indices_shape)
      |> Nx.tile(idx_tiling)

    axis_offset = Nx.rank(i) - 1

    indices =
      axes_range
      |> Enum.map(fn
        ^axis ->
          indices_for_axis
          |> Nx.reshape({num_elements, 1})

        current when current < axis ->
          indices_for_axis
          |> Nx.iota(axis: current, backend: Nx.Defn.Expr)
          |> Nx.reshape({num_elements, 1})

        current when current > axis ->
          indices_for_axis
          |> Nx.iota(axis: current + axis_offset, backend: Nx.Defn.Expr)
          |> Nx.reshape({num_elements, 1})
      end)
      |> Nx.concatenate(axis: 1)
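
    # Editor's note: the concatenated result has shape
    # {num_elements, Nx.rank(t)}: each row is the full coordinate in t
    # that one element of the flattened output grad is accumulated into
    # ({30, 2} in the running sketch above).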

    updates = Nx.reshape(g, {num_elements})
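    # scatter_add pairs the n-th row of `indices` with the n-th element
    # of `updates`, so the output grad is flattened to match.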

    g =
      t
      |> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
nx/test/nx/defn/grad_test.exs (134 changes: 134 additions & 0 deletions)
@@ -2598,6 +2598,140 @@ defmodule Nx.Defn.GradTest do
    end
  end

describe "take" do
defn grad_sum_take(t, i, axis \\ 0) do
grad(
t,
fn t ->
t
|> Nx.take(i, axis: axis)
|> Nx.sum()
end
)
end

defn grad_sum_take_axis_1_power(t, i) do
grad(
t,
fn t ->
t
|> Nx.power(2)
|> Nx.take(i, axis: 1)
|> Nx.sum()
end
)
end

defn grad_sum_log_power_take_axis_1_cos(t, i) do
grad(
t,
fn t ->
t
|> Nx.cos()
|> Nx.take(i, axis: 1)
|> Nx.power(2)
|> Nx.log()
|> Nx.sum()
end
)
end
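
    # Editor's note (sketch): the expectations below can be checked by
    # hand. For Nx.sum composed with Nx.take, the grad of each input
    # entry is the number of times its index is taken along `axis`;
    # composing with Nx.power(2) scales that count by d(t^2)/dt = 2 * t.
    # E.g. t[0][2] = 2 is taken 6 times in the second test, giving
    # 2 * 2 * 6 = 24.0.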

test "computes gradient" do
assert Nx.tensor([
[2.0, 2.0, 2.0, 2.0],
[2.0, 2.0, 2.0, 2.0],
[6.0, 6.0, 6.0, 6.0]
]) ==
grad_sum_take(
Nx.tensor([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]
]),
Nx.tensor([
[0, 1, 2, 2, 2],
[0, 1, 2, 2, 2]
])
)

assert Nx.tensor([
[0.0, 4.0, 24.0, 0.0],
[16.0, 20.0, 72.0, 0.0],
[32.0, 36.0, 120.0, 0.0]
]) ==
grad_sum_take_axis_1_power(
Nx.tensor([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]
]),
Nx.tensor([
[0, 1, 2, 2, 2],
[0, 1, 2, 2, 2]
])
)

assert Nx.tensor([
[-0.0, -6.2296305, 26.220474, -0.0],
[-4.631285, 13.522059, 3.4920743, -0.0],
[27.198847, 1.8092626, -7.7803297, 0.0]
]) ==
grad_sum_log_power_take_axis_1_cos(
Nx.tensor([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]
]),
Nx.tensor([
[0, 1, 2, 2, 2],
[0, 1, 2, 2, 2]
])
)
end

test "works with more dimensions" do
assert Nx.tensor([
[3.0, 3.0, 3.0, 3.0],
[3.0, 3.0, 3.0, 3.0],
[3.0, 3.0, 3.0, 3.0]
]) ==
grad_sum_take(
Nx.tensor([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]
]),
Nx.tensor([
[
[[0], [1], [2]],
[[2], [1], [0]],
[[0], [1], [2]]
]
])
)

assert Nx.tensor([
[1.0, 2.0, 1.0, 0.0],
[1.0, 2.0, 1.0, 0.0],
[1.0, 2.0, 1.0, 0.0]
]) ==
grad_sum_take(
Nx.tensor([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]
]),
Nx.tensor([
[
[[0], [1]],
[[2], [1]]
]
]),
1
)
end
end
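
  # Editor's sketch (not part of this commit): the expected grads above
  # can be reproduced by counting index occurrences. For example, with
  # the default axis 0:
  #
  #     i = Nx.tensor([[0, 1, 2, 2, 2], [0, 1, 2, 2, 2]])
  #     # index 0 appears 2x, index 1 appears 2x, index 2 appears 6x,
  #     # so the rows of the grad are filled with 2.0, 2.0 and 6.0,
  #     # matching the first assertion in "computes gradient".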

describe "not implemented" do
defn grad_reduce(t), do: grad(t, &Nx.reduce(&1, 0, fn x, y -> x + y end))

