diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 4da5c7cf..7f34664f 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -837,15 +837,12 @@ defmodule Axon.Compiler do # parameter map, so we just need to extract them and then apply # freezing and dtype policy parameter_inputs = - Enum.map(layer_params, fn %{type: type, name: v, frozen: frz} -> + Enum.map(layer_params, fn %{name: v, frozen: frz} -> param = params[name][v] cond do - param != nil and should_cast?(type, compute) -> - safe_as_type(maybe_freeze(param, frz), compute) - param != nil -> - maybe_freeze(param, frz) + safe_as_type(maybe_freeze(param, frz), compute) true -> raise ArgumentError, @@ -936,8 +933,11 @@ defmodule Axon.Compiler do out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name}) %{stateful | output: out} - out -> + %Nx.Tensor{} = out -> Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name}) + + out -> + out end rescue exception -> @@ -1082,17 +1082,23 @@ defmodule Axon.Compiler do none %Nx.Tensor{} = tensor -> - Nx.as_type(tensor, type) + if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do + Nx.as_type(tensor, type) + else + tensor + end container -> - deep_new(container, &Nx.as_type(&1, type)) + deep_new(container, fn tensor -> + if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do + Nx.as_type(tensor, type) + else + tensor + end + end) end end - defp should_cast?(type1, type2) do - not Nx.Type.integer?(type1) and not Nx.Type.integer?(type2) - end - defp safe_shape(container_or_tensor) do case container_or_tensor do %Axon.None{} = none -> diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index 35999ed0..b8643192 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -1795,7 +1795,7 @@ defmodule Axon.Layers do @doc type: :linear defn embedding(input, kernel, _opts \\ []) do assert_rank!("Axon.Layers.embedding", "kernel", kernel, 2) - Nx.take(kernel, Nx.as_type(input, {:s, 64}), axis: 0) + Nx.take(kernel, input, axis: 0) end ## Shape diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index e16723d2..6f63e281 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -602,7 +602,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding") - input = random({1, 1}) + input = random({1, 1}) |> Nx.as_type(:s64) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{}) @@ -615,7 +615,7 @@ defmodule CompilerTest do Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding", kernel_initializer: :zeros) - input = random({1, 1}) + input = random({1, 1}) |> Nx.as_type(:s64) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{})