Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Require Erlang/OTP 24 and Elixir v1.12 (#416)
* Require latest Erlang and Elixir

* Add support for half-precision floats

* Run the formatter
  • Loading branch information
josevalim committed May 19, 2021
1 parent 0fb0080 commit c1e1508
Show file tree
Hide file tree
Showing 20 changed files with 144 additions and 239 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Expand Up @@ -13,8 +13,8 @@ jobs:
working-directory: "nx"
strategy:
matrix:
elixir: [1.11.4]
otp: [23.3.4]
elixir: ["1.12.0-rc.1"]
otp: ["24.0"]
steps:
- uses: actions/checkout@v2
- uses: erlef/setup-beam@v1
Expand Down
21 changes: 6 additions & 15 deletions exla/lib/exla/defn.ex
Expand Up @@ -134,11 +134,10 @@ defmodule EXLA.Defn do
defp to_root_computation(key, expr, pos_shapes, options) do
builder = EXLA.Builder.new(inspect(key))

# TODO: Use Enum.with_index on Elixir v1.12
params =
for {{pos, shape}, i} <- Enum.with_index(pos_shapes) do
Enum.with_index(pos_shapes, fn {pos, shape}, i ->
{pos, EXLA.Op.parameter(builder, i, shape, "p#{i}")}
end
end)

state = %{
precision: Keyword.get(options, :precision, :default),
Expand Down Expand Up @@ -710,12 +709,7 @@ defmodule EXLA.Defn do
all_static? = Enum.all?(start_indices, &is_integer/1)

if all_static? do
# TODO: Use Enum.zip_with on Elixir v1.12
limit_indices =
start_indices
|> Enum.zip(lengths)
|> Enum.map(fn {i, len} -> i + len end)

limit_indices = Enum.zip_with(start_indices, lengths, fn i, len -> i + len end)
EXLA.Op.slice(tensor, start_indices, limit_indices, strides)
else
zeros = List.duplicate(0, tuple_size(ans.shape))
Expand Down Expand Up @@ -812,12 +806,11 @@ defmodule EXLA.Defn do
defp to_computation(name, args, state, fun) do
subbuilder = subbuilder(state.builder, Atom.to_string(name))

# TODO: Use Enum.with_index on Elixir v1.12
arg_params =
for {arg, i} <- Enum.with_index(args) do
Enum.with_index(args, fn arg, i ->
fun_shape = computation_arg_shape(arg)
{arg, EXLA.Op.parameter(subbuilder, i, fun_shape, "p#{i}")}
end
end)

{_, params} = Enum.reduce(arg_params, {0, []}, &computation_arg_param(&1, &2))
state = %{state | builder: subbuilder, params: Map.new(params)}
Expand All @@ -840,11 +833,9 @@ defmodule EXLA.Defn do
end

defp computation_arg_param({tuple, param}, counter_acc) do
# TODO: Use Enum.with_index on Elixir v1.12
tuple
|> Tuple.to_list()
|> Enum.with_index()
|> Enum.map(fn {arg, i} -> {arg, EXLA.Op.get_tuple_element(param, i)} end)
|> Enum.with_index(fn arg, i -> {arg, EXLA.Op.get_tuple_element(param, i)} end)
|> Enum.reduce(counter_acc, &computation_arg_param/2)
end

Expand Down
9 changes: 3 additions & 6 deletions exla/lib/exla/executable.ex
Expand Up @@ -133,14 +133,11 @@ defmodule EXLA.Executable do
defp decompose_output(data, shape, client) do
%Shape{dtype: {:t, shapes}} = shape

# TODO: Use Enum.zip_with on Elixir v1.12
data
|> Enum.zip(shapes)
|> Enum.map(fn
{buf, subshape} when is_reference(buf) ->
Enum.zip_with(data, shapes, fn
buf, subshape when is_reference(buf) ->
Buffer.buffer({buf, client.name}, subshape)

{buf, subshape} ->
buf, subshape ->
Buffer.buffer(buf, subshape)
end)
end
Expand Down
2 changes: 1 addition & 1 deletion exla/mix.exs
Expand Up @@ -9,7 +9,7 @@ defmodule EXLA.MixProject do
app: :exla,
name: "EXLA",
version: @version,
elixir: "~> 1.11",
elixir: "~> 1.12-dev",
deps: deps(),
docs: docs(),
compilers: [:elixir_make] ++ Mix.compilers(),
Expand Down
44 changes: 27 additions & 17 deletions nx/lib/nx.ex
Expand Up @@ -180,11 +180,13 @@ defmodule Nx do
]
>
Negative positions in ranges read from the back. The right-side of
the range must be equal or greater than the left-side:
Ranges can receive negative positions and they will read from
the back. In such cases, the range step must be explicitly given
and the right-side of the range must be equal or greater than
the left-side:
iex> t = Nx.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
iex> t[1..-2]
iex> t[1..-2//1]
#Nx.Tensor<
s64[2][2]
[
Expand All @@ -198,7 +200,7 @@ defmodule Nx do
axes with ranges, it is often desired to use a list:
iex> t = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
iex> t[[1..-2, 1..2]]
iex> t[[1..-2//1, 1..2]]
#Nx.Tensor<
s64[2][2]
[
Expand All @@ -210,14 +212,14 @@ defmodule Nx do
You can mix both ranges and integers in the list too:
iex> t = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
iex> t[[1..-2, 2]]
iex> t[[1..-2//1, 2]]
#Nx.Tensor<
s64[2]
[6, 9]
>
If the list has less elements than axes, the remaining dimensions
are returned in full (equivalent to 0..-1):
are returned in full:
iex> t = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
iex> t[[1..2]]
Expand Down Expand Up @@ -352,7 +354,7 @@ defmodule Nx do
[44, 45, 46]
>
Mixed types get the highest precision type:
Mixed types give higher priority to floats:
iex> Nx.tensor([1, 2, 3.0])
#Nx.Tensor<
Expand Down Expand Up @@ -398,9 +400,24 @@ defmodule Nx do
]
>
Besides single-precision (32 bits), floats can also have
half-precision (16) or double-precision (64):
iex> Nx.tensor([1, 2, 3], type: {:f, 16})
#Nx.Tensor<
f16[3]
[1.0, 2.0, 3.0]
>
iex> Nx.tensor([1, 2, 3], type: {:f, 64})
#Nx.Tensor<
f64[3]
[1.0, 2.0, 3.0]
>
Brain-floating points are also supported, although they are
emulated in Elixir and therefore perform slower without a
compilation backend:
native backend:
iex> Nx.tensor([1, 2, 3], type: {:bf, 16})
#Nx.Tensor<
Expand Down Expand Up @@ -2319,7 +2336,7 @@ defmodule Nx do
"""
@doc type: :shape
def size(shape) when is_tuple(shape), do: tuple_product(shape, tuple_size(shape))
def size(shape) when is_tuple(shape), do: Tuple.product(shape)
def size(tensor), do: size(shape(tensor))

@doc """
Expand Down Expand Up @@ -2363,10 +2380,6 @@ defmodule Nx do
defp count_up(0, _n), do: []
defp count_up(i, n), do: [n | count_up(i - 1, n + 1)]

# TODO: Use Tuple.product on Elixir v1.12
defp tuple_product(_tuple, 0), do: 1
defp tuple_product(tuple, i), do: :erlang.element(i, tuple) * tuple_product(tuple, i - 1)

## Backend API

@backend_key {Nx, :default_backend}
Expand Down Expand Up @@ -7992,13 +8005,10 @@ defmodule Nx do
defp to_indices(start_indices) do
all_static? = Enum.all?(start_indices, &is_integer/1)

# TODO: Use Enum.with_index/3 in Elixir v1.12
if all_static? do
start_indices
else
start_indices
|> Enum.with_index()
|> Enum.map(fn {index, i} ->
Enum.with_index(start_indices, fn index, i ->
%T{shape: idx_shape, type: idx_type} = t = to_tensor(index)

unless idx_shape == {} do
Expand Down
55 changes: 23 additions & 32 deletions nx/lib/nx/backend.ex
Expand Up @@ -203,46 +203,37 @@ defmodule Nx.Backend do
<<x::float-little-32>> = <<0::16, bf16::binary>>
Float.to_string(x)
end

defp inspect_float(data, 32) do
case data do
<<0xFF800000::32-native>> -> "-Inf"
<<0x7F800000::32-native>> -> "Inf"
<<_::16, 1::1, _::7, _sign::1, 0x7F::7>> -> "NaN"
<<x::float-32-native>> -> Float.to_string(x)
end
end

defp inspect_float(data, 64) do
case data do
<<0x7FF0000000000000::64-native>> -> "Inf"
<<0xFFF0000000000000::64-native>> -> "-Inf"
<<_::48, 0xF::4, _::4, _sign::1, 0x7F::7>> -> "NaN"
<<x::float-64-native>> -> Float.to_string(x)
end
end
else
defp inspect_bf16(bf16) do
<<x::float-big-32>> = <<bf16::binary, 0::16>>
Float.to_string(x)
end
end

defp inspect_float(data, 32) do
case data do
<<0xFF800000::32-native>> -> "-Inf"
<<0x7F800000::32-native>> -> "Inf"
<<_sign::1, 0x7F::7, 1::1, _::7, _::16>> -> "NaN"
<<x::float-32-native>> -> Float.to_string(x)
end
defp inspect_float(data, 16) do
case data do
<<0xFC00::16-native>> -> "-Inf"
<<0x7C00::16-native>> -> "Inf"
<<x::float-16-native>> -> Float.to_string(x)
_ -> "NaN"
end
end

defp inspect_float(data, 64) do
case data do
<<0x7FF0000000000000::64-native>> -> "Inf"
<<0xFFF0000000000000::64-native>> -> "-Inf"
<<_sign::1, 0x7F::7, 0xF::4, _::4, _::48>> -> "NaN"
<<x::float-64-native>> -> Float.to_string(x)
end
defp inspect_float(data, 32) do
case data do
<<0xFF800000::32-native>> -> "-Inf"
<<0x7F800000::32-native>> -> "Inf"
<<x::float-32-native>> -> Float.to_string(x)
_ -> "NaN"
end
end

defp inspect_float(data, 64) do
case data do
<<0x7FF0000000000000::64-native>> -> "Inf"
<<0xFFF0000000000000::64-native>> -> "-Inf"
<<x::float-64-native>> -> Float.to_string(x)
_ -> "NaN"
end
end
end
32 changes: 8 additions & 24 deletions nx/lib/nx/binary_backend.ex
Expand Up @@ -688,20 +688,9 @@ defmodule Nx.BinaryBackend do
defp element_remainder(_, a, b) when is_integer(a) and is_integer(b), do: rem(a, b)
defp element_remainder(_, a, b), do: :math.fmod(a, b)

defp element_power({type, _}, a, b) when type in [:s, :u], do: integer_pow(a, b)
defp element_power({type, _}, a, b) when type in [:s, :u], do: Integer.pow(a, b)
defp element_power(_, a, b), do: :math.pow(a, b)

# TODO: Use Integer.pow on Elixir v1.12
defp integer_pow(base, exponent) when is_integer(base) and is_integer(exponent) do
if exponent < 0, do: :erlang.error(:badarith, [base, exponent])
guarded_pow(base, exponent)
end

defp guarded_pow(_, 0), do: 1
defp guarded_pow(b, 1), do: b
defp guarded_pow(b, e) when (e &&& 1) == 0, do: guarded_pow(b * b, e >>> 1)
defp guarded_pow(b, e), do: b * guarded_pow(b * b, e >>> 1)

defp element_bitwise_and(_, a, b), do: :erlang.band(a, b)
defp element_bitwise_or(_, a, b), do: :erlang.bor(a, b)
defp element_bitwise_xor(_, a, b), do: :erlang.bxor(a, b)
Expand Down Expand Up @@ -950,11 +939,10 @@ defmodule Nx.BinaryBackend do
# as a list because it is handled in the Nx module before
# lowering to the implementation; however, the padding
# configuration only accounts for spatial dims
# TODO: Use Enum.zip_with on Elixir v1.12
spatial_padding_config =
padding
|> Enum.zip(input_dilation)
|> Enum.map(fn {{lo, hi}, dilation} -> {lo, hi, dilation - 1} end)
Enum.zip_with(padding, input_dilation, fn {lo, hi}, dilation ->
{lo, hi, dilation - 1}
end)

padding_config = [
{0, 0, 0},
Expand Down Expand Up @@ -1623,11 +1611,10 @@ defmodule Nx.BinaryBackend do

# The weighted shape is altered such that we traverse
# with respect to the stride for each dimension
# TODO: Use Enum.zip_with on Elixir v1.12
weighted_shape =
weighted_shape
|> Enum.zip(strides)
|> Enum.map(fn {{d, dim_size}, s} -> {d, dim_size + (s - 1) * dim_size} end)
Enum.zip_with(weighted_shape, strides, fn {d, dim_size}, s ->
{d, dim_size + (s - 1) * dim_size}
end)

input_data = to_binary(tensor)

Expand All @@ -1639,10 +1626,7 @@ defmodule Nx.BinaryBackend do
end

defp clamp_indices(start_indices, shape, lengths) do
# TODO: Use Enum.zip_with on Elixir v1.12
[Tuple.to_list(shape), start_indices, lengths]
|> Enum.zip()
|> Enum.map(fn {dim_size, idx, len} ->
Enum.zip_with([Tuple.to_list(shape), start_indices, lengths], fn [dim_size, idx, len] ->
idx = to_scalar(idx)
min(max(idx, 0), dim_size - len)
end)
Expand Down
13 changes: 4 additions & 9 deletions nx/lib/nx/binary_backend/matrix.ex
Expand Up @@ -403,15 +403,12 @@ defmodule Nx.BinaryBackend.Matrix do

# This function also sorts singular values from highest to lowest,
# as this can be convenient.

# TODO: Use Enum.zip_with on Elixir v1.12
s
|> Enum.zip(transpose_matrix(v))
|> Enum.map(fn
{singular_value, row} when singular_value < 0 ->
|> Enum.zip_with(transpose_matrix(v), fn
singular_value, row when singular_value < 0 ->
{-singular_value, Enum.map(row, &(&1 * -1))}

{singular_value, row} ->
singular_value, row ->
{singular_value, row}
end)
|> Enum.sort_by(fn {s, _} -> s end, &>=/2)
Expand Down Expand Up @@ -645,9 +642,7 @@ defmodule Nx.BinaryBackend.Matrix do
end

defp transpose_matrix(m) do
m
|> Enum.zip()
|> Enum.map(&Tuple.to_list/1)
Enum.zip_with(m, & &1)
end

defp matrix_to_binary([r | _] = m, type) when is_list(r) do
Expand Down

0 comments on commit c1e1508

Please sign in to comment.