Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not cast integer to float unintentionally #547

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -837,15 +837,12 @@ defmodule Axon.Compiler do
# parameter map, so we just need to extract them and then apply
# freezing and dtype policy
parameter_inputs =
Enum.map(layer_params, fn %{type: type, name: v, frozen: frz} ->
Enum.map(layer_params, fn %{name: v, frozen: frz} ->
param = params[name][v]

cond do
param != nil and should_cast?(type, compute) ->
safe_as_type(maybe_freeze(param, frz), compute)

param != nil ->
maybe_freeze(param, frz)
safe_as_type(maybe_freeze(param, frz), compute)

true ->
raise ArgumentError,
Expand Down Expand Up @@ -936,8 +933,11 @@ defmodule Axon.Compiler do
out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
%{stateful | output: out}

out ->
%Nx.Tensor{} = out ->
Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})

out ->
out
end
rescue
exception ->
Expand Down Expand Up @@ -1082,17 +1082,23 @@ defmodule Axon.Compiler do
none

%Nx.Tensor{} = tensor ->
Nx.as_type(tensor, type)
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
Nx.as_type(tensor, type)
else
tensor
end

container ->
deep_new(container, &Nx.as_type(&1, type))
deep_new(container, fn tensor ->
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
Nx.as_type(tensor, type)
else
tensor
end
end)
end
end

# Casting is only safe when neither side is an integer type; otherwise an
# integer tensor would be silently promoted to float (the bug this PR fixes).
defp should_cast?(lhs_type, rhs_type) do
  not (Nx.Type.integer?(lhs_type) or Nx.Type.integer?(rhs_type))
end

defp safe_shape(container_or_tensor) do
case container_or_tensor do
%Axon.None{} = none ->
Expand Down
2 changes: 1 addition & 1 deletion lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ defmodule Axon.Layers do
@doc type: :linear
defn embedding(input, kernel, _opts \\ []) do
assert_rank!("Axon.Layers.embedding", "kernel", kernel, 2)
Nx.take(kernel, Nx.as_type(input, {:s, 64}), axis: 0)
Nx.take(kernel, input, axis: 0)
end

## Shape
Expand Down
4 changes: 2 additions & 2 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ defmodule CompilerTest do
test "initializes in default case" do
model = Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding")

input = random({1, 1})
input = random({1, 1}) |> Nx.as_type(:s64)

assert {init_fn, _predict_fn} = Axon.build(model)
assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{})
Expand All @@ -615,7 +615,7 @@ defmodule CompilerTest do
Axon.input("input_0", shape: {nil, 1})
|> Axon.embedding(1, 1, name: "embedding", kernel_initializer: :zeros)

input = random({1, 1})
input = random({1, 1}) |> Nx.as_type(:s64)

assert {init_fn, _predict_fn} = Axon.build(model1)
assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{})
Expand Down
Loading