Do not cast integer to float unintentionally (#547)
seanmor5 committed Nov 14, 2023
1 parent f51442e commit 2a434f3
Showing 3 changed files with 21 additions and 15 deletions.
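
At its core, the fix changes the dtype-casting rule used when the compiler applies a compute/output policy: a cast now happens only when neither the tensor's current type nor the target type is an integer type. A minimal sketch of that rule, lifted from the compiler.ex diff below (the maybe_cast name is illustrative, not part of the codebase):

maybe_cast = fn %Nx.Tensor{} = tensor, type ->
  # Cast only when neither side of the cast is an integer type.
  if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
    Nx.as_type(tensor, type)
  else
    tensor
  end
end

# Integer tensors (e.g. embedding indices) are left alone...
maybe_cast.(Nx.tensor([1, 2, 3], type: {:s, 64}), {:f, 32})
#=> still {:s, 64}

# ...while float tensors are cast as before.
maybe_cast.(Nx.tensor([1.0, 2.0], type: {:f, 32}), {:f, 16})
#=> {:f, 16}
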
30 changes: 18 additions & 12 deletions lib/axon/compiler.ex
@@ -837,15 +837,12 @@ defmodule Axon.Compiler do
       # parameter map, so we just need to extract them and then apply
       # freezing and dtype policy
       parameter_inputs =
-        Enum.map(layer_params, fn %{type: type, name: v, frozen: frz} ->
+        Enum.map(layer_params, fn %{name: v, frozen: frz} ->
           param = params[name][v]
 
           cond do
-            param != nil and should_cast?(type, compute) ->
-              safe_as_type(maybe_freeze(param, frz), compute)
-
             param != nil ->
-              maybe_freeze(param, frz)
+              safe_as_type(maybe_freeze(param, frz), compute)
 
             true ->
               raise ArgumentError,
@@ -936,8 +933,11 @@
             out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
             %{stateful | output: out}
 
-          out ->
+          %Nx.Tensor{} = out ->
             Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
+
+          out ->
+            out
         end
       rescue
         exception ->
@@ -1082,17 +1082,23 @@
         none
 
       %Nx.Tensor{} = tensor ->
-        Nx.as_type(tensor, type)
+        if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
+          Nx.as_type(tensor, type)
+        else
+          tensor
+        end
 
       container ->
-        deep_new(container, &Nx.as_type(&1, type))
+        deep_new(container, fn tensor ->
+          if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
+            Nx.as_type(tensor, type)
+          else
+            tensor
+          end
+        end)
     end
   end
 
-  defp should_cast?(type1, type2) do
-    not Nx.Type.integer?(type1) and not Nx.Type.integer?(type2)
-  end
-
   defp safe_shape(container_or_tensor) do
     case container_or_tensor do
       %Axon.None{} = none ->
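
The container -> branch of safe_as_type/2 applies the same per-leaf check through the private deep_new/2 helper. As a rough stand-in, assuming the leaves sit in a plain parameter map rather than an arbitrary Nx container (the map contents here are invented for illustration):

compute = {:f, 16}

params = %{
  "kernel" => Nx.tensor([1.0, 2.0], type: {:f, 32}),
  "indices" => Nx.tensor([0, 1], type: {:s, 64})
}

Map.new(params, fn {name, tensor} ->
  tensor =
    # Same guard as safe_as_type/2: skip the cast if either type is integer.
    if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(compute) do
      Nx.as_type(tensor, compute)
    else
      tensor
    end

  {name, tensor}
end)
# "kernel" is cast to {:f, 16}; "indices" stays {:s, 64} rather than being
# promoted to float by the policy.
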
2 changes: 1 addition & 1 deletion lib/axon/layers.ex
@@ -1795,7 +1795,7 @@ defmodule Axon.Layers do
   @doc type: :linear
   defn embedding(input, kernel, _opts \\ []) do
     assert_rank!("Axon.Layers.embedding", "kernel", kernel, 2)
-    Nx.take(kernel, Nx.as_type(input, {:s, 64}), axis: 0)
+    Nx.take(kernel, input, axis: 0)
   end
 
   ## Shape
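
With the explicit Nx.as_type(input, {:s, 64}) removed, Axon.Layers.embedding/3 hands its input to Nx.take/3 unchanged, so callers are responsible for supplying integer indices; float inputs are no longer silently coerced to {:s, 64} behind the scenes. A small usage sketch (shapes and values chosen arbitrarily):

kernel = Nx.iota({5, 3}, type: {:f, 32})
indices = Nx.tensor([[0, 2]], type: {:s, 64})

# Gathers rows 0 and 2 of the kernel; the result has shape {1, 2, 3}.
Axon.Layers.embedding(indices, kernel)
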
4 changes: 2 additions & 2 deletions test/axon/compiler_test.exs
@@ -602,7 +602,7 @@ defmodule CompilerTest do
     test "initializes in default case" do
       model = Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding")
 
-      input = random({1, 1})
+      input = random({1, 1}) |> Nx.as_type(:s64)
 
       assert {init_fn, _predict_fn} = Axon.build(model)
       assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{})
@@ -615,7 +615,7 @@
         Axon.input("input_0", shape: {nil, 1})
         |> Axon.embedding(1, 1, name: "embedding", kernel_initializer: :zeros)
 
-      input = random({1, 1})
+      input = random({1, 1}) |> Nx.as_type(:s64)
 
       assert {init_fn, _predict_fn} = Axon.build(model1)
       assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{})
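
The test updates follow from the layer change above: because embedding no longer casts its input, the random template is converted to :s64 before it reaches init_fn. Outside the test suite the same applies to any input template handed to a model that starts with an embedding, for example with Nx.Random in place of the tests' local random/1 helper (variable names are illustrative):

key = Nx.Random.key(0)
{template, _key} = Nx.Random.uniform(key, shape: {1, 1})
input = Nx.as_type(template, :s64)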