Skip to content

Commit

Permalink
refactor: use binary-matching syntax for scatter-add impl (#464)
Browse files Browse the repository at this point in the history
* refactor: use binary-matching syntax for scatter-add impl

* fix: type coercion in scatter-add
  • Loading branch information
polvalente committed Sep 20, 2021
1 parent 432bc74 commit b09ee45
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions nx/lib/nx/binary_backend.ex
Expand Up @@ -1581,14 +1581,12 @@ defmodule Nx.BinaryBackend do
indices_bin_list =
indices |> to_binary() |> aggregate_axes([1], indices_shape, elem(indices.type, 1))

target_byte_size = div(target_size, 8)

offsets_list =
match_types [indices.type] do
for idx_bin <- indices_bin_list do
idx = for <<match!(x, 0) <- idx_bin>>, do: read!(x, 0)
offset = index_to_binary_offset(idx, shape)
offset * target_byte_size
offset * target_size
end
end

Expand All @@ -1604,7 +1602,7 @@ defmodule Nx.BinaryBackend do
|> Enum.sort_by(fn {off, _} -> off end)
|> Enum.map_reduce(0, fn {next_offset, upds}, previous_offset ->
{{
previous_offset + target_byte_size,
previous_offset + target_size,
next_offset,
Enum.sum(upds)
}, next_offset}
Expand All @@ -1616,29 +1614,33 @@ defmodule Nx.BinaryBackend do
for {previous, current, update} <- offsets_with_updates, reduce: {<<>>, target_binary} do
{traversed, to_traverse} ->
before_slice_size = max(current - previous, 0)
before_offset = binary_part(to_traverse, 0, before_slice_size)

# this can be a list of binaries because we are accumulating an iodata list
before_offset =
match_types [target.type] do
for <<match!(x, 0) <- before_offset>>, do: scalar_to_binary(read!(x, 0), out.type)
end

element = binary_part(to_traverse, before_slice_size, target_byte_size)
match_types [target.type, out.type] do
<<before_offset::bitstring-size(before_slice_size), match!(element, 0),
to_traverse::bitstring>> = to_traverse

total_size = byte_size(to_traverse)
# this can be a list of binaries because we are accumulating an iodata list
before_offset =
if target.type == out.type do
before_offset
else
for <<match!(x, 0) <- before_offset>>, do: scalar_to_binary(read!(x, 0), out.type)
end

to_traverse =
binary_part(
to_traverse,
before_slice_size + target_byte_size,
total_size - (before_slice_size + target_byte_size)
)
updated_element = <<write!(read!(element, 0) + update, 1)>>

updated_element =
scalar_to_binary(binary_to_number(element, target.type) + update, out.type)
{[traversed | [before_offset, updated_element]], to_traverse}
end
end

{[traversed | [before_offset, updated_element]], to_traverse}
# this can be a list of binaries because we are accumulating an iodata list
tail =
match_types [target.type] do
if target.type == out.type do
tail
else
for <<match!(x, 0) <- tail>>, do: scalar_to_binary(read!(x, 0), out.type)
end
end

from_binary(out, IO.iodata_to_binary([result, tail]))
Expand Down

0 comments on commit b09ee45

Please sign in to comment.