diff --git a/lib/axon/activations.ex b/lib/axon/activations.ex index 1efd56ec2..38b6ee207 100644 --- a/lib/axon/activations.ex +++ b/lib/axon/activations.ex @@ -85,18 +85,16 @@ defmodule Axon.Activations do """ defn celu(x, opts \\ []) do opts = keyword!(opts, alpha: 1.0) - - transform( - opts[:alpha], - fn x -> - if Elixir.Kernel.==(x, 0), - do: raise(ArgumentError, ":alpha must be non-zero in CELU activation") - end - ) + validate_celu_alpha!(opts[:alpha]) Nx.select(Nx.greater(x, 0), x, opts[:alpha] * Nx.expm1(x / opts[:alpha])) end + deftransformp validate_celu_alpha!(alpha) do + if alpha == 0, + do: raise(ArgumentError, ":alpha must be non-zero in CELU activation") + end + @doc ~S""" Exponential linear unit activation. @@ -376,7 +374,7 @@ defmodule Axon.Activations do """ defn log_sumexp(x, opts \\ []) do opts = keyword!(opts, axis: -1) - axes = transform(opts[:axis], &List.wrap/1) + axes = wrap(opts[:axis]) # This is a scaling term designed to prevent over/under flow when x is very # large. Consider cases where the intermediate value e^x with large positive @@ -457,12 +455,6 @@ defmodule Axon.Activations do defn log_softmax(x, opts \\ []) do opts = keyword!(opts, axis: -1) - transform({x, opts}, fn {x, opts} -> - if Elixir.Kernel.<=(Nx.rank(x), opts[:axis]) do - raise ArgumentError, "log_softmax axis must be within rank of tensor" - end - end) - shifted = x - stop_grad(Nx.reduce_max(x, axes: [opts[:axis]], keep_axes: true)) shifted @@ -594,7 +586,7 @@ defmodule Axon.Activations do defn sigmoid(x) do # Cache logits so they are available in certain calculations, # e.g. binary_cross_entropy and categorical_cross_entropy - transform(Nx.sigmoid(x), &Nx.Defn.Expr.metadata(&1, %{logits: x})) + cache_logits(x, Nx.sigmoid(x)) end @doc ~S""" @@ -708,13 +700,7 @@ defmodule Axon.Activations do """ defn softmax(x, opts \\ []) do opts = keyword!(opts, axis: -1) - axes = transform(opts[:axis], &List.wrap/1) - - transform({x, axes}, fn {x, axes} -> - Enum.each(axes, fn axis -> - Nx.Shape.normalize_axis(Nx.shape(x), axis, Nx.names(x)) - end) - end) + axes = wrap(opts[:axis]) # This is a scaling term designed to prevent over/under flow when x is very # large. Consider cases where the intermediate value e^x with large positive @@ -745,7 +731,7 @@ defmodule Axon.Activations do # Cache logits so they are available in certain calculations, # e.g. 
binary_cross_entropy and categorical_cross_entropy - transform(res, &Nx.Defn.Expr.metadata(&1, %{logits: x})) + cache_logits(x, res) end @doc ~S""" @@ -837,4 +823,12 @@ defmodule Axon.Activations do """ defn tanh(x), do: Nx.tanh(x) + + ## Helpers + + deftransformp cache_logits(input, output) do + Nx.Defn.Expr.metadata(output, %{logits: input}) + end + + deftransformp wrap(axis), do: List.wrap(axis) end diff --git a/lib/axon/initializers.ex b/lib/axon/initializers.ex index 46041904d..a947f7a1d 100644 --- a/lib/axon/initializers.ex +++ b/lib/axon/initializers.ex @@ -684,33 +684,20 @@ defmodule Axon.Initializers do assert_min_rank!("Axon.Initializers.orthogonal", "input_shape", shape, 2) - {{m, n}, random_seed} = - transform({key, shape, distribution, type}, fn {key, shape, distribution, type} -> - flat_shape = - if tuple_size(shape) > 2 do - tuple_list = shape |> Tuple.to_list() |> Enum.reverse() - n = hd(tuple_list) - m = Enum.reduce(tl(tuple_list), 1, &(&1 * &2)) - {m, n} - else - shape - end - - out = - case distribution do - :uniform -> - Nx.Random.uniform_split(key, 0.0, 1.0, shape: flat_shape, type: type) - - :normal -> - Nx.Random.normal_split(key, 0.0, 1.0, shape: flat_shape, type: type) - - dist -> - raise ArgumentError, - "invalid distribution #{inspect(dist)} passed to orthogonal/1" - end - - {flat_shape, out} - end) + {m, n} = get_flat_shape(shape) + + random_seed = + case distribution do + :uniform -> + Nx.Random.uniform_split(key, 0.0, 1.0, shape: {m, n}, type: type) + + :normal -> + Nx.Random.normal_split(key, 0.0, 1.0, shape: {m, n}, type: type) + + dist -> + raise ArgumentError, + "invalid distribution #{inspect(dist)} passed to orthogonal/1" + end {q, _r} = Nx.LinAlg.qr(random_seed, mode: :complete) @@ -722,6 +709,17 @@ defmodule Axon.Initializers do rand end + deftransformp get_flat_shape(shape) do + if tuple_size(shape) > 2 do + tuple_list = shape |> Tuple.to_list() |> Enum.reverse() + n = hd(tuple_list) + m = Enum.reduce(tl(tuple_list), 1, &(&1 * &2)) + {m, n} + else + shape + end + end + # Variance scaling branches defnp var_normal(key, variance, opts \\ []) do diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index 4162d6000..9d6169d36 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -188,8 +188,8 @@ defmodule Axon.Layers do assert_equal_rank!("Axon.Layers.bilinear", "input1", input1, "input2", input2) assert_rank!("Axon.Layers.bilinear", "kernel", kernel, 3) - inp1_axes = transform(Nx.rank(input1), fn rank -> [rank - 1] end) - inp2_axes = transform(Nx.rank(input2), fn rank -> [rank - 1] end) + inp1_axes = input1 |> last_axis() |> list_wrap() + inp2_axes = input2 |> last_axis() |> list_wrap() input1 |> Nx.dot(inp1_axes, [], kernel, [1], []) @@ -197,6 +197,8 @@ defmodule Axon.Layers do |> Nx.add(bias) end + deftransformp last_axis(input), do: Nx.rank(input) - 1 + ## Convolutional @doc """ @@ -355,29 +357,8 @@ defmodule Axon.Layers do mode: :inference ) - bias_reshape = - transform( - {Nx.shape(bias), Nx.rank(input) - 2, opts[:channels]}, - fn {bias_shape, rank, channels} -> - Axon.Shape.conv_bias_reshape(bias_shape, rank, channels) - end - ) - - {permutations, kernel_permutation} = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - perm = Enum.to_list(0..(rank - 1)) - {perm, perm} - - {rank, :last} -> - spatial = Enum.to_list(1..(rank - 2)//1) - perm = [0, rank - 1 | spatial] - kernel_perm = [rank - 1, rank - 2] ++ Enum.to_list(0..(rank - 3)//1) - {perm, kernel_perm} - - {_rank, invalid} -> - raise ArgumentError, "invalid channel 
configuration, #{inspect(invalid)}" - end) + bias_reshape = Axon.Shape.conv_bias_reshape(input, bias, opts[:channels]) + {permutations, kernel_permutation} = Axon.Shape.conv_permutations(input, opts[:channels]) input |> Nx.conv(kernel, @@ -491,24 +472,18 @@ defmodule Axon.Layers do mode: :inference ) - strides = - transform( - {Nx.rank(input), opts[:strides]}, - fn - {_, [_ | _] = strides} -> strides - {rank, strides} -> List.duplicate(strides, rank - 2) - end - ) + strides = Axon.Shape.conv_transpose_strides(input, opts[:strides]) padding = - transform( - {Nx.shape(kernel), opts[:kernel_dilation], strides, opts[:padding], opts[:channels]}, - fn {shape, k_dilation, strides, padding, channels} -> - Axon.Shape.conv_transpose_padding(shape, k_dilation, strides, padding, channels) - end + Axon.Shape.conv_transpose_padding( + kernel, + opts[:kernel_dilation], + strides, + opts[:padding], + opts[:channels] ) - ones = transform(Nx.rank(input), &List.duplicate(1, &1 - 2)) + ones = list_duplicate(1, Nx.rank(input) - 2) conv(input, kernel, bias, strides: ones, @@ -597,14 +572,8 @@ defmodule Axon.Layers do mode: :inference ) - num_groups = - transform({Nx.shape(input), opts[:channels]}, fn - {shape, :first} -> - elem(shape, 1) - - {shape, :last} -> - elem(shape, tuple_size(shape) - 1) - end) + channel_index = channel_index_transform(input, opts[:channels]) + num_groups = Nx.axis_size(input, channel_index) conv(input, kernel, bias, strides: opts[:strides], @@ -616,6 +585,9 @@ defmodule Axon.Layers do ) end + deftransformp channel_index_transform(_input, :first), do: 1 + deftransformp channel_index_transform(input, :last), do: Nx.rank(input) - 1 + @doc """ Functional implementation of a 2-dimensional separable depthwise convolution. @@ -816,47 +788,13 @@ defmodule Axon.Layers do ] ) - window_dimensions = - transform( - {Nx.rank(input), opts[:kernel_size], opts[:channels]}, - fn {rank, kernel_size, channels} -> - Axon.Shape.pool_window_size(kernel_size, rank - 2, channels) - end - ) + window_dimensions = Axon.Shape.pool_window_size(input, opts[:kernel_size], opts[:channels]) strides = - transform( - {Nx.rank(input), opts[:strides], window_dimensions, opts[:channels]}, - fn - {_, nil, dims, _} -> Tuple.to_list(dims) - {_, [_ | _] = strides, _, :first} -> [1, 1 | strides] - {_, [_ | _] = strides, _, :last} -> [1 | strides] ++ [1] - {rank, strides, _, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] - {rank, strides, _, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] - end - ) - - dilations = - transform( - {Nx.rank(input), opts[:window_dilations], opts[:channels]}, - fn - {_, [_ | _] = dilations, :first} -> [1, 1 | dilations] - {rank, dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] - {_, [_ | _] = dilations, :last} -> [1 | dilations] ++ [1] - {rank, dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] - end - ) + Axon.Shape.pool_window_strides(input, opts[:strides], window_dimensions, opts[:channels]) - padding = - transform( - {opts[:padding], opts[:channels]}, - fn - {:same, _} -> :same - {:valid, _} -> :valid - {padding, :first} -> [{0, 0}, {0, 0} | padding] - {padding, :last} -> [{0, 0} | padding] ++ [{0, 0}] - end - ) + dilations = Axon.Shape.pool_window_dilations(input, opts[:window_dilations], opts[:channels]) + padding = Axon.Shape.pool_window_padding(opts[:padding], opts[:channels]) input |> Nx.window_max(window_dimensions, @@ -915,51 +853,13 @@ defmodule Axon.Layers do ] ) - window_dimensions = - transform( - {Nx.rank(input), opts[:kernel_size], 
opts[:channels]}, - fn {rank, kernel_size, channels} -> - Axon.Shape.pool_window_size(kernel_size, rank - 2, channels) - end - ) + window_dimensions = Axon.Shape.pool_window_size(input, opts[:kernel_size], opts[:channels]) strides = - transform( - {Nx.rank(input), opts[:strides], window_dimensions, opts[:channels]}, - fn - {_, nil, dims, _} -> Tuple.to_list(dims) - {_, [_ | _] = strides, _, :first} -> [1, 1 | strides] - {_, [_ | _] = strides, _, :last} -> [1 | strides] ++ [1] - {rank, strides, _, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] - {rank, strides, _, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] - end - ) - - dilations = - transform( - {Nx.rank(input), opts[:window_dilations], opts[:channels]}, - fn - {_, [_ | _] = dilations, :first} -> [1, 1 | dilations] - {rank, dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] - {_, [_ | _] = dilations, :last} -> [1 | dilations] ++ [1] - {rank, dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] - end - ) - - padding = - transform( - opts[:padding], - fn - :same -> - :same - - :valid -> - :valid + Axon.Shape.pool_window_strides(input, opts[:strides], window_dimensions, opts[:channels]) - padding -> - [{0, 0}, {0, 0} | padding] - end - ) + dilations = Axon.Shape.pool_window_dilations(input, opts[:window_dilations], opts[:channels]) + padding = Axon.Shape.pool_window_padding(opts[:padding], opts[:channels]) input |> Nx.window_mean(window_dimensions, @@ -1041,51 +941,13 @@ defmodule Axon.Layers do ] ) - window_dimensions = - transform( - {Nx.rank(input), opts[:kernel_size], opts[:channels]}, - fn {rank, kernel_size, channels} -> - Axon.Shape.pool_window_size(kernel_size, rank - 2, channels) - end - ) + window_dimensions = Axon.Shape.pool_window_size(input, opts[:kernel_size], opts[:channels]) strides = - transform( - {Nx.rank(input), opts[:strides], window_dimensions, opts[:channels]}, - fn - {_, nil, dims, _} -> Tuple.to_list(dims) - {_, [_ | _] = strides, _, :first} -> [1, 1 | strides] - {_, [_ | _] = strides, _, :last} -> [1 | strides] ++ [1] - {rank, strides, _, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] - {rank, strides, _, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] - end - ) - - dilations = - transform( - {Nx.rank(input), opts[:window_dilations], opts[:channels]}, - fn - {_, [_ | _] = dilations, :first} -> [1, 1 | dilations] - {rank, dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] - {_, [_ | _] = dilations, :last} -> [1 | dilations] ++ [1] - {rank, dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] - end - ) + Axon.Shape.pool_window_strides(input, opts[:strides], window_dimensions, opts[:channels]) - padding = - transform( - opts[:padding], - fn - :same -> - :same - - :valid -> - :valid - - padding -> - [{0, 0}, {0, 0} | padding] - end - ) + dilations = Axon.Shape.pool_window_dilations(input, opts[:window_dilations], opts[:channels]) + padding = Axon.Shape.pool_window_padding(opts[:padding], opts[:channels]) norm = opts[:norm] @@ -1129,26 +991,11 @@ defmodule Axon.Layers do opts = keyword!(opts, [:output_size, channels: :last, mode: :inference]) - output_size = - transform({Nx.shape(input), opts[:output_size], opts[:channels]}, fn {shape, size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, size, channels) - end) - - window_strides = - transform( - {Nx.shape(input), Nx.rank(input), output_size, opts[:channels]}, - fn {shape, rank, output_size, channels} -> - Axon.Shape.adaptive_pool_window_strides(shape, 
output_size, rank - 2, channels) - end - ) + output_size = Axon.Shape.adaptive_pool_output_size(input, opts[:output_size], opts[:channels]) + window_strides = Axon.Shape.adaptive_pool_window_strides(input, output_size, opts[:channels]) window_dimensions = - transform( - {Nx.shape(input), Nx.rank(input), window_strides, output_size, opts[:channels]}, - fn {shape, rank, strides, output_size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, strides, output_size, rank - 2, channels) - end - ) + Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels]) Nx.window_mean(input, window_dimensions, padding: :valid, strides: window_strides) end @@ -1180,26 +1027,11 @@ defmodule Axon.Layers do opts = keyword!(opts, [:output_size, channels: :last, mode: :inference]) - output_size = - transform({Nx.shape(input), opts[:output_size], opts[:channels]}, fn {shape, size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, size, channels) - end) - - window_strides = - transform( - {Nx.shape(input), Nx.rank(input), output_size, opts[:channels]}, - fn {shape, rank, output_size, channels} -> - Axon.Shape.adaptive_pool_window_strides(shape, output_size, rank - 2, channels) - end - ) + output_size = Axon.Shape.adaptive_pool_output_size(input, opts[:output_size], opts[:channels]) + window_strides = Axon.Shape.adaptive_pool_window_strides(input, output_size, opts[:channels]) window_dimensions = - transform( - {Nx.shape(input), Nx.rank(input), window_strides, output_size, opts[:channels]}, - fn {shape, rank, strides, output_size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, strides, output_size, rank - 2, channels) - end - ) + Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels]) Nx.window_max(input, window_dimensions, padding: :valid, strides: window_strides) end @@ -1239,26 +1071,11 @@ defmodule Axon.Layers do norm = opts[:norm] - output_size = - transform({Nx.shape(input), opts[:output_size], opts[:channels]}, fn {shape, size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, size, channels) - end) - - window_strides = - transform( - {Nx.shape(input), Nx.rank(input), output_size, opts[:channels]}, - fn {shape, rank, output_size, channels} -> - Axon.Shape.adaptive_pool_window_strides(shape, output_size, rank - 2, channels) - end - ) + output_size = Axon.Shape.adaptive_pool_output_size(input, opts[:output_size], opts[:channels]) + window_strides = Axon.Shape.adaptive_pool_window_strides(input, output_size, opts[:channels]) window_dimensions = - transform( - {Nx.shape(input), Nx.rank(input), window_strides, output_size, opts[:channels]}, - fn {shape, rank, strides, output_size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, strides, output_size, rank - 2, channels) - end - ) + Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels]) input |> Nx.power(norm) @@ -1302,56 +1119,22 @@ defmodule Axon.Layers do defn batch_norm(input, gamma, beta, ra_mean, ra_var, opts \\ []) do opts = keyword!(opts, epsilon: 1.0e-5, channel_index: -1, momentum: 0.1, mode: :inference) - training? 
= - transform(opts[:mode], fn - :inference -> false - :train -> true - end) + axes = Axon.Shape.batch_norm_axes(input, opts[:channel_index]) - {axes, channel_index} = - transform({input, opts[:channel_index]}, fn {input, channel} -> - axes = Nx.axes(input) - axis = Nx.Shape.normalize_axis(Nx.shape(input), channel, Nx.names(input)) - {Axon.Shape.batch_norm_axes(axes, axis), axis} - end) + num_channels = Nx.axis_size(input, opts[:channel_index]) - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - elem(Nx.shape(inp), channel_idx) - end) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) - {gamma, beta, ra_mean, ra_var} = - transform( - {gamma, beta, ra_mean, ra_var, Nx.rank(input), num_channels, channel_index}, - fn {g, b, m, v, rank, num_channels, channel_idx} -> - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(channel_idx, num_channels) - - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape), Nx.reshape(m, new_shape), - Nx.reshape(v, new_shape)} - end - ) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) + ra_mean = Nx.reshape(ra_mean, parameter_shape) + ra_var = Nx.reshape(ra_var, parameter_shape) - transform( - {input, gamma, beta, ra_mean, ra_var, axes, opts[:epsilon], opts[:momentum], training?}, - fn - {x, g, b, m, v, axes, eps, alpha, true} -> - {new_mean, new_var} = mean_and_variance(x, axes: axes) - out = normalize(x, new_mean, new_var, g, b, epsilon: eps) - ra_mean = update_ema(new_mean, m, alpha) - ra_var = update_ema(new_var, v, alpha) - - %Axon.StatefulOutput{ - output: out, - state: %{"mean" => ra_mean, "var" => ra_var} - } - - {x, g, b, m, v, _, eps, _, _} -> - normalize(x, m, v, g, b, epsilon: eps) - end + stateful_normalization_mode_transform(input, gamma, beta, ra_mean, ra_var, + axes: axes, + epsilon: opts[:epsilon], + momentum: opts[:momentum], + mode: opts[:mode] ) end @@ -1383,32 +1166,12 @@ defmodule Axon.Layers do opts = keyword!(opts, epsilon: 1.0e-5, channel_index: -1, mode: :inference) axes = opts[:channel_index] - channel_index = opts[:channel_index] - - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - names = List.duplicate(nil, Nx.rank(inp)) - axis = Nx.Shape.normalize_axis(Nx.shape(inp), channel_idx, names) - elem(Nx.shape(inp), axis) - end) - - {gamma, beta} = - transform({gamma, beta, input, Nx.rank(input), num_channels, channel_index}, fn {g, b, - input, - rank, - num_channels, - channel_idx} -> - names = List.duplicate(nil, rank) - axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_idx, names) + num_channels = Nx.axis_size(input, opts[:channel_index]) - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(axis, num_channels) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape)} - end) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) {mean, var} = mean_and_variance(input, axes: [axes]) normalize(input, mean, var, gamma, beta, epsilon: opts[:epsilon]) @@ -1444,49 +1207,17 @@ defmodule Axon.Layers do defn group_norm(input, gamma, beta, opts \\ []) do opts = keyword!(opts, [:num_groups, epsilon: 1.0e-5, channel_index: -1, mode: :inference]) - channel_axis = - transform({Nx.shape(input), opts[:channel_index]}, fn - {shape, channel_index} -> - names = List.duplicate(nil, Nx.rank(shape)) - Nx.Shape.normalize_axis(shape, channel_index, names) - end) - - 
group_shape = - transform({Nx.shape(input), opts[:num_groups], channel_axis}, fn - {shape, groups, channel_axis} -> - Axon.Shape.group_norm_shape(shape, groups, channel_axis) - end) - - channel_index = opts[:channel_index] - - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - names = List.duplicate(nil, Nx.rank(inp)) - axis = Nx.Shape.normalize_axis(Nx.shape(inp), channel_idx, names) - elem(Nx.shape(inp), axis) - end) + group_shape = Axon.Shape.group_norm_shape(input, opts[:num_groups], opts[:channel_index]) + num_channels = Nx.axis_size(input, opts[:channel_index]) - {gamma, beta} = - transform({gamma, beta, input, Nx.rank(input), num_channels, channel_index}, fn - {g, b, inp, rank, num_channels, channel_idx} -> - names = List.duplicate(nil, Nx.rank(inp)) - axis = Nx.Shape.normalize_axis(Nx.shape(inp), channel_idx, names) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(axis, num_channels) - - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape)} - end) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) x = Nx.reshape(input, group_shape) - axes = - transform({x, channel_axis}, fn {x, channel_axis} -> - Axon.Shape.group_norm_axes(Nx.rank(x), channel_axis) - end) + axes = Axon.Shape.group_norm_axes(x, opts[:channel_index]) {mean, var} = mean_and_variance(x, axes: axes) x = (x - mean) * Nx.rsqrt(var + opts[:epsilon]) @@ -1527,59 +1258,63 @@ defmodule Axon.Layers do defn instance_norm(input, gamma, beta, ra_mean, ra_var, opts \\ []) do opts = keyword!(opts, epsilon: 1.0e-5, channel_index: -1, momentum: 0.1, mode: :inference) - training? = - transform(opts[:mode], fn - :inference -> false - :train -> true - end) + axes = Axon.Shape.instance_norm_axes(input, opts[:channel_index]) + num_channels = Nx.axis_size(input, opts[:channel_index]) - {axes, channel_index} = - transform({input, opts[:channel_index]}, fn {input, channel} -> - axes = Nx.axes(input) - axis = Nx.Shape.normalize_axis(Nx.shape(input), channel, Nx.names(input)) - {Axon.Shape.instance_norm_axes(axes, axis), axis} - end) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - elem(Nx.shape(inp), channel_idx) - end) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) + ra_mean = Nx.reshape(ra_mean, parameter_shape) + ra_var = Nx.reshape(ra_var, parameter_shape) - {gamma, beta, ra_mean, ra_var} = - transform( - {gamma, beta, ra_mean, ra_var, Nx.rank(input), num_channels, channel_index}, - fn {g, b, m, v, rank, num_channels, channel_idx} -> - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(channel_idx, num_channels) - - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape), Nx.reshape(m, new_shape), - Nx.reshape(v, new_shape)} - end - ) + stateful_normalization_mode_transform(input, gamma, beta, ra_mean, ra_var, + axes: axes, + epsilon: opts[:epsilon], + momentum: opts[:momentum], + mode: opts[:mode] + ) + end - transform( - {input, gamma, beta, ra_mean, ra_var, axes, opts[:epsilon], opts[:momentum], training?}, - fn - {x, g, b, m, v, axes, eps, alpha, true} -> - {new_mean, new_var} = mean_and_variance(x, axes: axes) - out = normalize(x, new_mean, new_var, g, b, epsilon: eps) - ra_mean = update_ema(new_mean, m, alpha) - ra_var = update_ema(new_var, v, alpha) - - %Axon.StatefulOutput{ - 
output: out, - state: %{"mean" => ra_mean, "var" => ra_var} - } - - {x, g, b, m, v, _, eps, _, _} -> - normalize(x, m, v, g, b, epsilon: eps) - end + deftransformp norm_parameter_reshape(input, num_channels, channel_index) do + 1 + |> List.duplicate(Nx.rank(input)) + |> List.to_tuple() + |> put_elem( + Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)), + num_channels ) end + deftransformp stateful_normalization_mode_transform( + input, + gamma, + beta, + ra_mean, + ra_var, + opts \\ [] + ) do + eps = opts[:epsilon] + alpha = opts[:momentum] + axes = opts[:axes] + + case opts[:mode] do + :train -> + {new_mean, new_var} = mean_and_variance(input, axes: axes) + out = normalize(input, new_mean, new_var, gamma, beta, epsilon: eps) + ra_mean = update_ema(new_mean, ra_mean, alpha) + ra_var = update_ema(new_var, ra_var, alpha) + + %Axon.StatefulOutput{ + output: out, + state: %{"mean" => ra_mean, "var" => ra_var} + } + + :inference -> + normalize(input, ra_mean, ra_var, gamma, beta, epsilon: eps) + end + end + ## Stochastic @doc ~S""" @@ -1660,10 +1395,7 @@ defmodule Axon.Layers do opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference) - noise_shape = - transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} -> - Axon.Shape.spatial_dropout_noise_shape(shape, channels) - end) + noise_shape = Axon.Shape.spatial_dropout_noise_shape(input, opts[:channels]) dropout(input, key, rate: opts[:rate], @@ -1738,10 +1470,7 @@ defmodule Axon.Layers do opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference) - noise_shape = - transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} -> - Axon.Shape.spatial_dropout_noise_shape(shape, channels) - end) + noise_shape = Axon.Shape.spatial_dropout_noise_shape(input, opts[:channels]) keep_prob = 1 - opts[:rate] @@ -1808,14 +1537,7 @@ defmodule Axon.Layers do opts = keyword!(opts, channels: :last, keep_axes: false, mode: :inference) - all_but_batch_and_feature = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - for i <- 2..(rank - 1), do: i - - {rank, :last} -> - for i <- 1..(rank - 2), do: i - end) + all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels]) Nx.mean(input, axes: all_but_batch_and_feature, keep_axes: opts[:keep_axes]) end @@ -1872,14 +1594,7 @@ defmodule Axon.Layers do opts = keyword!(opts, keep_axes: false, channels: :last, mode: :inference) - all_but_batch_and_feature = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - for i <- 2..(rank - 1), do: i - - {rank, :last} -> - for i <- 1..(rank - 2), do: i - end) + all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels]) Nx.reduce_max(input, axes: all_but_batch_and_feature, keep_axes: opts[:keep_axes]) end @@ -1943,14 +1658,7 @@ defmodule Axon.Layers do norm = opts[:norm] - all_but_batch_and_feature = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - for i <- 2..(rank - 1), do: i - - {rank, :last} -> - for i <- 1..(rank - 2), do: i - end) + all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels]) input |> Nx.power(norm) @@ -2031,32 +1739,29 @@ defmodule Axon.Layers do > """ @doc type: :shape - defn flatten(x, _opts \\ []) do - new_shape = transform(Nx.shape(x), &Axon.Shape.flatten/1) + deftransform flatten(input, _opts \\ []) do + shape = Nx.shape(input) + out_units = Nx.size(Tuple.delete_at(shape, 0)) + out_shape = {elem(shape, 0), out_units} - Nx.reshape(x, new_shape) + Nx.reshape(input, out_shape) end @doc 
false # Internal version of Nx.reshape for constructing reshape layers # without worrying about a batch dimension - defn reshape(x, opts \\ []) do - opts = keyword!(opts, [:shape, mode: :inference]) - - transform({opts[:shape], x}, fn {shape, x} -> - batch_size = Nx.axis_size(x, 0) - - new_shape = - shape - |> Tuple.to_list() - |> Enum.map(fn - :batch -> batch_size - val -> val - end) - |> List.to_tuple() - - Nx.reshape(x, new_shape) + deftransform reshape(x, opts \\ []) do + opts = Keyword.validate!(opts, [:shape, mode: :inference]) + batch_size = Nx.axis_size(x, 0) + + opts[:shape] + |> Tuple.to_list() + |> Enum.map(fn + :batch -> batch_size + val -> val end) + |> List.to_tuple() + |> then(&Nx.reshape(x, &1)) end @doc false # Internal version of Nx.pad for constructing pad layers without # worrying about batch or channel dimensions defn pad(x, opts \\ []) do opts = keyword!(opts, [:padding_config, :value, :channels, mode: :inference]) + config = padding_config_transform(opts[:padding_config], opts[:channels]) - config = - transform({opts[:padding_config], opts[:channels]}, fn - {config, :first} -> - [{0, 0, 0}, {0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] + Nx.pad(x, Nx.as_type(opts[:value], Nx.type(x)), config) + end - {config, :last} -> - [{0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] ++ [{0, 0, 0}] - end) + deftransform padding_config_transform(config, channels) do + case channels do + :first -> + [{0, 0, 0}, {0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] - Nx.pad(x, Nx.as_type(opts[:value], Nx.type(x)), config) + :last -> + [{0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] ++ [{0, 0, 0}] + end end @doc false # Internal version of Nx.transpose for constructing a transpose layer # without worrying about a batch dimension - defn transpose(x, opts \\ []) do - opts = keyword!(opts, [:axes, mode: :inference]) - - axes = - transform({Nx.shape(x), opts[:axes]}, fn - {shape, nil} -> - Nx.axes(shape) |> Enum.reverse() - - {_, axes} -> - axes - end) - + deftransform transpose(x, opts \\ []) do + opts = Keyword.validate!(opts, [:axes, mode: :inference]) + axes = opts[:axes] || Enum.reverse(Nx.axes(x)) Nx.transpose(x, axes: axes) end @@ -2102,20 +1800,7 @@ defmodule Axon.Layers do opts = keyword!(opts, [:cond, mode: :inference]) cond_expr = opts[:cond].(cond_input_expr) - transform(cond_expr, fn cond_expr -> - cond_rank = Nx.rank(cond_expr) - cond_type = Nx.type(cond_expr) - - unless Elixir.Kernel.and( - Elixir.Kernel.==(cond_rank, 0), - Elixir.Kernel.==(cond_type, {:u, 8}) - ) do - raise ArgumentError, - "cond_fn must return a scalar-boolean tensor" <> - " got result with rank #{inspect(cond_rank)} and" <> - " type #{inspect(cond_type)}" - end - end) + validate_cond_predicate!(cond_expr) if cond_expr do on_true_expr @@ -2124,6 +1809,21 @@ end end + deftransformp validate_cond_predicate!(cond_expr) do + cond_rank = Nx.rank(cond_expr) + cond_type = Nx.type(cond_expr) + + unless Elixir.Kernel.and( + Elixir.Kernel.==(cond_rank, 0), + Elixir.Kernel.==(cond_type, {:u, 8}) + ) do + raise ArgumentError, + "cond_fn must return a scalar-boolean tensor" <> - " got result with rank #{inspect(cond_rank)} and" <> + " type #{inspect(cond_type)}" + end + end + @doc false # Internal helper for constructing bias layers without defn bias(input, bias, _opts \\ []) do @@ -2167,67 +1867,70 @@ defmodule Axon.Layers do ** (ArgumentError) expected :method to be either of :nearest, :bilinear, :bicubic, :lanczos3, :lanczos5, got: :foo """ @doc type: :shape - defn
resize(input, opts \\ []) do + deftransform resize(input, opts \\ []) do assert_rank!("Axon.Layers.resize", "input", input, 4) opts = - keyword!(opts, [ + Keyword.validate!(opts, [ :size, method: :nearest, channels: :last, mode: :inference ]) - transform({input, opts}, fn {input, opts} -> - {spatial_axes, out_shape} = - input - |> spatial_axes_with_sizes(opts) - |> Enum.reject(fn {_axis, size, out_size} -> Elixir.Kernel.==(size, out_size) end) - |> Enum.map_reduce(Nx.shape(input), fn {axis, _size, out_size}, out_shape -> - {axis, put_elem(out_shape, axis, out_size)} - end) + {spatial_axes, out_shape} = + input + |> spatial_axes_with_sizes(opts) + |> Enum.reject(fn {_axis, size, out_size} -> Elixir.Kernel.==(size, out_size) end) + |> Enum.map_reduce(Nx.shape(input), fn {axis, _size, out_size}, out_shape -> + {axis, put_elem(out_shape, axis, out_size)} + end) - resized_input = - case opts[:method] do - :nearest -> - resize_nearest(input, out_shape, spatial_axes) + resized_input = + case opts[:method] do + :nearest -> + resize_nearest(input, out_shape, spatial_axes) - :bilinear -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1) + :bilinear -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1) - :bicubic -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1) + :bicubic -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1) - :lanczos3 -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1)) + :lanczos3 -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1)) - :lanczos5 -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1)) + :lanczos5 -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1)) - method -> - raise ArgumentError, - "expected :method to be either of :nearest, :bilinear, :bicubic, " <> - ":lanczos3, :lanczos5, got: #{inspect(method)}" - end + method -> + raise ArgumentError, + "expected :method to be either of :nearest, :bilinear, :bicubic, " <> + ":lanczos3, :lanczos5, got: #{inspect(method)}" + end - cast_to(resized_input, input) - end) + cast_to(resized_input, input) + end + + deftransformp spatial_axes_with_sizes(input, opts \\ []) do + {height_axis, width_axis} = spatial_axes(input, channels: opts[:channels]) + {height, width} = size(input, channels: opts[:channels]) + {out_height, out_width} = opts[:size] + [{height_axis, height, out_height}, {width_axis, width, out_width}] end - defnp spatial_axes(input, opts \\ []) do + deftransformp spatial_axes(input, opts \\ []) do channels = opts[:channels] - transform({input, channels}, fn {input, channels} -> - axes = - case channels do - :first -> [-2, -1] - :last -> [-3, -2] - end + axes = + case channels do + :first -> [-2, -1] + :last -> [-3, -2] + end - axes - |> Enum.map(&Nx.axis_index(input, &1)) - |> List.to_tuple() - end) + axes + |> Enum.map(&Nx.axis_index(input, &1)) + |> List.to_tuple() end defnp cast_to(left, right) do @@ -2236,56 +1939,58 @@ defmodule Axon.Layers do |> Nx.reshape(left, names: Nx.names(right)) end - defnp resize_nearest(input, out_shape, spatial_axes) do - transform({input, out_shape, spatial_axes}, fn {input, out_shape, spatial_axes} -> - singular_shape = List.duplicate(1, Nx.rank(input)) |> List.to_tuple() + deftransformp resize_nearest(input, out_shape, spatial_axes) do + singular_shape = List.duplicate(1, Nx.rank(input)) |> List.to_tuple() - for axis <- spatial_axes, reduce: input do - input 
-> - input_shape = Nx.shape(input) - input_size = elem(input_shape, axis) - output_size = elem(out_shape, axis) - inv_scale = input_size / output_size - offset = (Nx.iota({output_size}) + 0.5) * inv_scale - offset = offset |> Nx.floor() |> Nx.as_type({:s, 32}) + for axis <- spatial_axes, reduce: input do + input -> + input_shape = Nx.shape(input) + input_size = elem(input_shape, axis) + output_size = elem(out_shape, axis) + inv_scale = input_size / output_size + offset = Nx.iota({output_size}) |> Nx.add(0.5) |> Nx.multiply(inv_scale) + offset = offset |> Nx.floor() |> Nx.as_type({:s, 32}) - offset = - offset - |> Nx.reshape(put_elem(singular_shape, axis, output_size)) - |> Nx.broadcast(put_elem(input_shape, axis, output_size)) + offset = + offset + |> Nx.reshape(put_elem(singular_shape, axis, output_size)) + |> Nx.broadcast(put_elem(input_shape, axis, output_size)) - Nx.take_along_axis(input, offset, axis: axis) - end - end) + Nx.take_along_axis(input, offset, axis: axis) + end end @f32_eps :math.pow(2, -23) - defnp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do - transform({input, out_shape, spatial_axes}, fn {input, out_shape, spatial_axes} -> - for axis <- spatial_axes, reduce: input do - input -> - input_shape = Nx.shape(input) - input_size = elem(input_shape, axis) - output_size = elem(out_shape, axis) + deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do + for axis <- spatial_axes, reduce: input do + input -> + input_shape = Nx.shape(input) + input_size = elem(input_shape, axis) + output_size = elem(out_shape, axis) - inv_scale = input_size / output_size - kernel_scale = Nx.max(1, inv_scale) + inv_scale = input_size / output_size + kernel_scale = Nx.max(1, inv_scale) - sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5 - x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale - weights = kernel_fun.(x) + sample_f = + Nx.add(Nx.iota({1, output_size}), 0.5) |> Nx.multiply(inv_scale) |> Nx.subtract(0.5) - weights_sum = Nx.sum(weights, axes: [0], keep_axes: true) + x = Nx.abs(Nx.subtract(sample_f, Nx.iota({input_size, 1}))) |> Nx.divide(kernel_scale) + weights = kernel_fun.(x) - weights = - Nx.select(Nx.abs(weights) > 1000 * @f32_eps, safe_divide(weights, weights_sum), 0) + weights_sum = Nx.sum(weights, axes: [0], keep_axes: true) - input = Nx.dot(input, [axis], weights, [0]) - # The transformed axis is moved to the end, so we transpose back - reorder_axis(input, -1, axis) - end - end) + weights = + Nx.select( + Nx.greater(Nx.abs(weights), 1000 * @f32_eps), + safe_divide(weights, weights_sum), + 0 + ) + + input = Nx.dot(input, [axis], weights, [0]) + # The transformed axis is moved to the end, so we transpose back + reorder_axis(input, -1, axis) + end end defnp fill_linear_kernel(x) do @@ -2311,20 +2016,11 @@ defmodule Axon.Layers do x / Nx.select(y != 0, y, 1) end - defnp reorder_axis(tensor, axis, target_axis) do - transform({tensor, axis, target_axis}, fn {tensor, axis, target_axis} -> - axes = Nx.axes(tensor) - {source_axis, axes} = List.pop_at(axes, axis) - axes = List.insert_at(axes, target_axis, source_axis) - Nx.transpose(tensor, axes: axes) - end) - end - - defnp spatial_axes_with_sizes(input, opts \\ []) do - {height_axis, width_axis} = spatial_axes(input, channels: opts[:channels]) - {height, width} = size(input, channels: opts[:channels]) - {out_height, out_width} = opts[:size] - [{height_axis, height, out_height}, {width_axis, width, out_width}] + deftransformp reorder_axis(tensor, axis, target_axis) do + axes =
Nx.axes(tensor) + {source_axis, axes} = List.pop_at(axes, axis) + axes = List.insert_at(axes, target_axis, source_axis) + Nx.transpose(tensor, axes: axes) end defnp size(input, opts \\ []) do @@ -2341,23 +2037,16 @@ defmodule Axon.Layers do for activation <- @activation_layers do @doc false - defn unquote(activation)(input, _opts \\ []) do - transform(input, fn inp -> - Elixir.Kernel.apply(Axon.Activations, unquote(activation), [inp]) - end) + deftransform unquote(activation)(input, _opts \\ []) do + apply(Axon.Activations, unquote(activation), [input]) end end @activation_layers_with_opts [:celu, :elu, :hard_sigmoid, :hard_silu, :leaky_relu] ++ [:log_sumexp, :log_softmax, :selu, :softmax] for activation <- @activation_layers_with_opts do - defn unquote(activation)(input, opts \\ []) do - transform(input, fn inp -> - Elixir.Kernel.apply(Axon.Activations, unquote(activation), [ - inp, - Keyword.delete(opts, :mode) - ]) - end) + deftransform unquote(activation)(input, opts \\ []) do + apply(Axon.Activations, unquote(activation), [input, Keyword.delete(opts, :mode)]) end end @@ -2367,26 +2056,22 @@ defmodule Axon.Layers do @element_wise_layers [:add, :subtract, :multiply] for op <- @element_wise_layers do - defn unquote(op)(inputs, _opts \\ []) do - transform(inputs, fn inputs -> - [first | rest] = Tuple.to_list(inputs) + deftransform unquote(op)(inputs, _opts \\ []) do + [first | rest] = Tuple.to_list(inputs) - Enum.reduce(rest, first, fn next, acc -> - apply(Nx, unquote(op), [acc, next]) - end) + Enum.reduce(rest, first, fn next, acc -> + apply(Nx, unquote(op), [acc, next]) end) end end @doc false - defn concatenate(inputs, opts \\ []) do - opts = keyword!(opts, axis: -1, mode: :inference) + deftransform concatenate(inputs, opts \\ []) do + opts = Keyword.validate!(opts, axis: -1, mode: :inference) - transform(inputs, fn inputs -> - inputs - |> Tuple.to_list() - |> Nx.concatenate(axis: opts[:axis]) - end) + inputs + |> Tuple.to_list() + |> Nx.concatenate(axis: opts[:axis]) end ## Recurrent @@ -2500,46 +2185,40 @@ defmodule Axon.Layers do rank_up({new_h, {new_c, new_h}}) end - defnp split_gates(gates) do - transform(gates, fn gates -> - channels = elem(Nx.shape(gates), 1) - split_every = div(channels, 4) + deftransformp split_gates(gates) do + channels = elem(Nx.shape(gates), 1) + split_every = div(channels, 4) - split_dims = - for i <- 0..3 do - {i * split_every, split_every} - end + split_dims = + for i <- 0..3 do + {i * split_every, split_every} + end - split_dims - |> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end) - |> List.to_tuple() - end) + split_dims + |> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end) + |> List.to_tuple() end - defnp rank_down(rnn_data) do - transform(rnn_data, fn {input, {cell, hidden}} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - Nx.squeeze(tensor, axes: [1]) - end + deftransformp rank_down({input, {cell, hidden}}) do + [cell, hidden, input] = + for tensor <- [cell, hidden, input] do + Nx.squeeze(tensor, axes: [1]) + end - {input, {cell, hidden}} - end) + {input, {cell, hidden}} end - defnp rank_up(rnn_data) do - transform(rnn_data, fn {input, {cell, hidden}} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - new_shape = - Nx.shape(tensor) - |> Tuple.insert_at(1, 1) + deftransformp rank_up({input, {cell, hidden}}) do + [cell, hidden, input] = + for tensor <- [cell, hidden, input] do + new_shape = + Nx.shape(tensor) + |> Tuple.insert_at(1, 1) - 
Nx.reshape(tensor, new_shape) - end + Nx.reshape(tensor, new_shape) + end - {input, {cell, hidden}} - end) + {input, {cell, hidden}} end @doc """ @@ -2554,17 +2233,18 @@ defmodule Axon.Layers do may be more efficient for long sequences. """ defn dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - time_steps = transform(Nx.shape(input_sequence), &elem(&1, 1)) - - feature_dims = transform(Nx.rank(input_sequence), &List.duplicate(0, &1 - 2)) + time_steps = Nx.axis_size(input_sequence, 1) + feature_dims = list_duplicate(0, Nx.rank(input_sequence) - 2) initial_shape = - transform({cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, fn - {cell_fn, inp, carry, inp_kernel, hid_kernel, bias} -> - seq = Nx.slice_along_axis(inp, 0, 1, axis: 1) - {seq, _} = cell_fn.(seq, carry, inp_kernel, hid_kernel, bias) - put_elem(Nx.shape(seq), 1, elem(Nx.shape(inp), 1)) - end) + unroll_initial_shape_transform( + cell_fn, + input_sequence, + carry, + input_kernel, + recurrent_kernel, + bias + ) init_sequence = Nx.broadcast(0.0, initial_shape) i = Nx.tensor(0) @@ -2573,7 +2253,7 @@ defmodule Axon.Layers do while {i, carry, init_sequence, input_sequence, input_kernel, recurrent_kernel, bias}, Nx.less(i, time_steps) do sequence = Nx.slice_along_axis(input_sequence, i, 1, axis: 1) - indices = transform({feature_dims, i}, fn {feature_dims, i} -> [0, i] ++ feature_dims end) + indices = compute_indices(i, feature_dims) {output, carry} = cell_fn.(sequence, carry, input_kernel, recurrent_kernel, bias) update_sequence = Nx.put_slice(init_sequence, indices, output) {i + 1, carry, update_sequence, input_sequence, input_kernel, recurrent_kernel, bias} @@ -2582,6 +2262,16 @@ defmodule Axon.Layers do {output, carry} end + deftransformp compute_indices(i, feature_dims) do + [0, i] ++ feature_dims + end + + deftransformp unroll_initial_shape_transform(cell_fn, inp, carry, inp_kernel, hid_kernel, bias) do + seq = Nx.slice_along_axis(inp, 0, 1, axis: 1) + {seq, _} = cell_fn.(seq, carry, inp_kernel, hid_kernel, bias) + put_elem(Nx.shape(seq), 1, elem(Nx.shape(inp), 1)) + end + @doc """ Statically unrolls an RNN. 
@@ -2693,17 +2383,7 @@ defmodule Axon.Layers do assert_min_rank!("Axon.Layers.split", "input", input, 2) opts = keyword!(opts, [:index, :splits, axis: -1, mode: :train]) - shape = Nx.shape(input) - - {offset, size} = - transform( - {shape, opts[:index], opts[:splits], opts[:axis]}, - fn {shape, idx, splits, axis} -> - slice_size = Axon.Shape.split(shape, splits, axis) - offset = idx * slice_size - {offset, slice_size} - end - ) + {offset, size} = Axon.Shape.split(input, opts[:index], opts[:splits], opts[:axis]) Nx.slice_along_axis(input, offset, size, axis: opts[:axis]) end diff --git a/lib/axon/loss_scale.ex b/lib/axon/loss_scale.ex index 116d12920..48f32d49f 100644 --- a/lib/axon/loss_scale.ex +++ b/lib/axon/loss_scale.ex @@ -44,19 +44,12 @@ defmodule Axon.LossScale do end defnp scale_static(value, %{loss_scale: loss_scale}) do - transform({value, loss_scale}, fn {value, loss_scale} -> - deep_new(value, fn x -> x * loss_scale end) - end) + deep_new(value, fn x -> x * loss_scale end) end defnp unscale_static(value, %{loss_scale: loss_scale} = state) do inv_loss_scale = 1 / loss_scale - - unscaled = - transform({value, inv_loss_scale}, fn {value, inv_loss_scale} -> - deep_new(value, fn x -> x * inv_loss_scale end) - end) - + unscaled = deep_new(value, fn x -> x * inv_loss_scale end) {unscaled, state} end @@ -81,19 +74,12 @@ defmodule Axon.LossScale do end defnp scale_dynamic(value, %{loss_scale: loss_scale}) do - transform({value, loss_scale}, fn {value, loss_scale} -> - deep_new(value, fn x -> x * loss_scale end) - end) + deep_new(value, fn x -> x * loss_scale end) end defnp unscale_dynamic(value, %{loss_scale: loss_scale} = state, opts \\ []) do inv_loss_scale = 1 / loss_scale - - unscaled = - transform({value, inv_loss_scale}, fn {value, inv_loss_scale} -> - deep_new(value, fn x -> x * inv_loss_scale end) - end) - + unscaled = deep_new(value, fn x -> x * inv_loss_scale end) {unscaled, adjust_dynamic(value, state, opts)} end @@ -101,12 +87,10 @@ defmodule Axon.LossScale do opts = keyword!(opts, period: 2_000, factor: 2, min_loss_scale: 1) grads_are_finite = - transform(grads, fn grads -> - deep_reduce(grads, Nx.tensor(1), fn x, acc -> - x - |> is_finite() - |> Nx.logical_and(acc) - end) + deep_reduce(grads, Nx.tensor(1), fn x, acc -> + x + |> is_finite() + |> Nx.logical_and(acc) end) new_loss_scale = diff --git a/lib/axon/losses.ex b/lib/axon/losses.ex index 34cdcaf12..4c4b7181b 100644 --- a/lib/axon/losses.ex +++ b/lib/axon/losses.ex @@ -134,20 +134,7 @@ defmodule Axon.Losses do # altogether if necessary. If either of them is set, then we need to set # both and perform this whole thing. If neither is set, we set this to # nil and then avoid the weighted avg later on. - weights = - transform({y_true, opts[:positive_weight], opts[:negative_weight]}, fn - {_, nil, nil} -> - nil - - {y_true, pos, nil} -> - Nx.take(Nx.tensor([1.0, pos], backend: Nx.Defn.Expr), y_true) - - {y_true, nil, neg} -> - Nx.take(Nx.tensor([neg, 1.0], backend: Nx.Defn.Expr), y_true) - - {y_true, pos, neg} -> - Nx.take(Nx.tensor([neg, pos], backend: Nx.Defn.Expr), y_true) - end) + weights = get_weights(y_true, opts[:positive_weight], opts[:negative_weight]) # Merge types before computing loss to prevent under/overflow. This # can especially happen when targets are encoded as u8 tensors. 
We @@ -207,6 +194,22 @@ defmodule Axon.Losses do reduction(possibly_weighted_avg_loss, opts[:reduction]) end + deftransformp get_weights(y_true, pos, neg) do + case {y_true, pos, neg} do + {_, nil, nil} -> + nil + + {y_true, pos, nil} -> + Nx.take(Nx.tensor([1.0, pos], backend: Nx.Defn.Expr), y_true) + + {y_true, nil, neg} -> + Nx.take(Nx.tensor([neg, 1.0], backend: Nx.Defn.Expr), y_true) + + {y_true, pos, neg} -> + Nx.take(Nx.tensor([neg, pos], backend: Nx.Defn.Expr), y_true) + end + end + defnp sigmoid_cross_entropy_from_logits(y_true, y_pred) do log_p = Axon.Activations.log_sigmoid(y_pred) log_not_p = Axon.Activations.log_sigmoid(-y_pred) @@ -895,14 +898,7 @@ defmodule Axon.Losses do n12 = Nx.max(w1 * w2, eps) loss = w12 / n12 - transform( - {opts[:reduction], loss}, - fn - {:mean, loss} -> Nx.mean(loss) - {:sum, loss} -> Nx.sum(loss) - {:none, loss} -> loss - end - ) + reduction(loss, opts[:reduction]) end @doc ~S""" @@ -1019,14 +1015,11 @@ defmodule Axon.Losses do {Nx.put_slice(loss, [b], Nx.reshape(loss_b, {1})), b + 1, y_true, s_true, y_pred} end - transform( - {opts[:reduction], loss}, - fn - {:mean, loss} -> Nx.divide(loss, l_true) |> Nx.mean() - {:sum, loss} -> Nx.sum(loss) - {:none, loss} -> loss - end - ) + case opts[:reduction] do + :mean -> Nx.divide(loss, l_true) |> Nx.mean() + :sum -> Nx.sum(loss) + :none -> loss + end end defnp get_limits(y_true, s_max, t_max) do diff --git a/lib/axon/metrics.ex b/lib/axon/metrics.ex index 3df313556..cf1753e15 100644 --- a/lib/axon/metrics.ex +++ b/lib/axon/metrics.ex @@ -416,14 +416,7 @@ defmodule Axon.Metrics do defn top_k_categorical_accuracy(y_true, y_pred, opts \\ []) do opts = keyword!(opts, k: 5, sparse: false) - y_true = - transform(y_true, fn y_true -> - if opts[:sparse] do - y_true - else - top_k_index_transform(y_true) - end - end) + y_true = if opts[:sparse], do: y_true, else: top_k_index_transform(y_true) cond do Nx.rank(y_pred) == 2 -> @@ -449,7 +442,7 @@ defmodule Axon.Metrics do end end - defnp(top_k_index_transform(y_true), do: Nx.argmax(y_true, axis: -1, keep_axis: true)) + defnp top_k_index_transform(y_true), do: Nx.argmax(y_true, axis: -1, keep_axis: true) # Combinators diff --git a/lib/axon/recurrent.ex b/lib/axon/recurrent.ex deleted file mode 100644 index 7a0444668..000000000 --- a/lib/axon/recurrent.ex +++ /dev/null @@ -1,233 +0,0 @@ -defmodule Axon.Recurrent do - @moduledoc false - - import Nx.Defn - import Axon.Layers - - @doc """ - GRU Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - GRU-based RNN. More memory efficient than traditional LSTM. - - ## References - - * [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](https://arxiv.org/pdf/1412.3555v1.pdf) - """ - @deprecated "Use Axon.Layers.gru_cell/7 instead" - defn gru_cell( - input, - carry, - input_kernel, - hidden_kernel, - bias, - gate_fn \\ &sigmoid/1, - activation_fn \\ &tanh/1 - ) do - {hidden} = carry - {wir, wiz, win} = input_kernel - {whr, whz, whn} = hidden_kernel - {br, bz, bin, bhn} = bias - - r = gate_fn.(dense(input, wir, br) + dense(hidden, whr, 0)) - z = gate_fn.(dense(input, wiz, bz) + dense(hidden, whz, 0)) - n = activation_fn.(dense(input, win, bin) + r * dense(hidden, whn, bhn)) - - new_h = (1.0 - z) * n + z * hidden - - {{new_h}, new_h} - end - - @doc """ - LSTM Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - LSTM-based RNN. More memory efficient than traditional LSTM. 
- - ## References - - * [Long Short-Term Memory](http://www.bioinf.jku.at/publications/older/2604.pdf) - """ - @deprecated "Use Axon.Layers.lstm_cell/7 instead" - defn lstm_cell( - input, - carry, - input_kernel, - hidden_kernel, - bias, - gate_fn \\ &sigmoid/1, - activation_fn \\ &tanh/1 - ) do - {cell, hidden} = carry - {wii, wif, wig, wio} = input_kernel - {whi, whf, whg, who} = hidden_kernel - - {bi, bf, bg, bo} = bias - - i = gate_fn.(dense(input, wii, bi) + dense(hidden, whi, 0)) - f = gate_fn.(dense(input, wif, bf) + dense(hidden, whf, 0)) - g = activation_fn.(dense(input, wig, bg) + dense(hidden, whg, 0)) - o = gate_fn.(dense(input, wio, bo) + dense(hidden, who, 0)) - - new_c = f * cell + i * g - new_h = o * activation_fn.(new_c) - - {{new_c, new_h}, new_h} - end - - @doc """ - ConvLSTM Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - ConvLSTM-based RNN. More memory efficient than traditional LSTM. - - ## Options - - * `:strides` - convolution strides. Defaults to `1`. - - * `:padding` - convolution padding. Defaults to `:same`. - - ## References - - * [Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://arxiv.org/abs/1506.04214) - """ - @deprecated "Use Axon.Layers.conv_lstm_cell/6 instead" - defn conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ []) do - opts = keyword!(opts, strides: 1, padding: :same) - - {ih} = input_kernel - {hh} = hidden_kernel - {bi} = bias - - {{cell, hidden}, input} = rank_down({carry, input}) - - gates = - Nx.add( - conv(input, ih, bi, strides: opts[:strides], padding: opts[:padding]), - conv(hidden, hh, 0, strides: opts[:strides], padding: opts[:padding]) - ) - - {i, g, f, o} = split_gates(gates) - - f = sigmoid(f + 1) - new_c = f * cell + sigmoid(i) * tanh(g) - new_h = sigmoid(o) * tanh(new_c) - - rank_up({{new_c, new_h}, new_h}) - end - - defnp split_gates(gates) do - transform(gates, fn gates -> - channels = elem(Nx.shape(gates), 1) - split_every = div(channels, 4) - - split_dims = - for i <- 0..3 do - {i * split_every, split_every} - end - - split_dims - |> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end) - |> List.to_tuple() - end) - end - - defnp rank_down(rnn_data) do - transform(rnn_data, fn {{cell, hidden}, input} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - Nx.squeeze(tensor, axes: [1]) - end - - {{cell, hidden}, input} - end) - end - - defnp rank_up(rnn_data) do - transform(rnn_data, fn {{cell, hidden}, input} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - new_shape = - Nx.shape(tensor) - |> Tuple.insert_at(1, 1) - - Nx.reshape(tensor, new_shape) - end - - {{cell, hidden}, input} - end) - end - - @doc """ - Dynamically unrolls an RNN. - - Unrolls implement a `scan` operation which applies a - transformation on the leading axis of `input_sequence` carrying - some state. In this instance `cell_fn` is an RNN cell function - such as `lstm_cell` or `gru_cell`. - - This function will make use of an `defn` while-loop such and thus - may be more efficient for long sequences. 
- """ - @deprecated "Use Axon.Layers.dynamic_unroll/6 instead" - defn dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - time_steps = transform(Nx.shape(input_sequence), &elem(&1, 1)) - - feature_dims = transform(Nx.rank(input_sequence), &List.duplicate(0, &1 - 2)) - - initial_shape = - transform({cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, fn - {cell_fn, inp, carry, inp_kernel, hid_kernel, bias} -> - seq = Nx.slice_along_axis(inp, 0, 1, axis: 1) - {_, seq} = cell_fn.(seq, carry, inp_kernel, hid_kernel, bias) - put_elem(Nx.shape(seq), 1, elem(Nx.shape(inp), 1)) - end) - - init_sequence = Nx.broadcast(0.0, initial_shape) - i = Nx.tensor(0) - - {_, carry, output, _, _, _, _} = - while {i, carry, init_sequence, input_sequence, input_kernel, recurrent_kernel, bias}, - Nx.less(i, time_steps) do - sequence = Nx.slice_along_axis(input_sequence, i, 1, axis: 1) - indices = transform({feature_dims, i}, fn {feature_dims, i} -> [0, i] ++ feature_dims end) - {carry, output} = cell_fn.(sequence, carry, input_kernel, recurrent_kernel, bias) - update_sequence = Nx.put_slice(init_sequence, indices, output) - {i + 1, carry, update_sequence, input_sequence, input_kernel, recurrent_kernel, bias} - end - - {carry, output} - end - - @doc """ - Statically unrolls an RNN. - - Unrolls implement a `scan` operation which applies a - transformation on the leading axis of `input_sequence` carrying - some state. In this instance `cell_fn` is an RNN cell function - such as `lstm_cell` or `gru_cell`. - - This function inlines the unrolling of the sequence such that - the entire operation appears as a part of the compilation graph. - This makes it suitable for shorter sequences. - """ - @deprecated "Use Axon.Layers.static_unroll/6 instead" - defn static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - transform( - {cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, - fn {cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias} -> - time_steps = elem(Nx.shape(input_sequence), 1) - - {carry, outputs} = - for t <- 0..(time_steps - 1), reduce: {carry, []} do - {carry, outputs} -> - input = Nx.slice_along_axis(input_sequence, t, 1, axis: 1) - {carry, output} = cell_fn.(input, carry, input_kernel, recurrent_kernel, bias) - {carry, [output | outputs]} - end - - {carry, Nx.concatenate(Enum.reverse(outputs), axis: 1)} - end - ) - end -end diff --git a/lib/axon/shape.ex b/lib/axon/shape.ex index 5e7314b93..7c5860de1 100644 --- a/lib/axon/shape.ex +++ b/lib/axon/shape.ex @@ -1,6 +1,8 @@ defmodule Axon.Shape do @moduledoc false + import Nx.Defn + # Collection of shape calculations for calculating the # output and trainable parameter shapes for high-level # layers. @@ -319,8 +321,11 @@ defmodule Axon.Shape do the input bias shape is a vector, otherwise we'll just attempt to let it broadcast itself. """ - def conv_bias_reshape(input_shape, spatial_rank, channels) do - case input_shape do + deftransform conv_bias_reshape(input, bias, channels) do + bias_shape = Nx.shape(bias) + spatial_rank = Nx.rank(input) - 2 + + case bias_shape do {} -> {} @@ -338,11 +343,51 @@ defmodule Axon.Shape do end end + @doc """ + Calculates the permutation options to pass to convolution + based on channels configuration. + + It returns both the input/output permutation and the kernel + permutation. 
+ """ + deftransform conv_permutations(input, channels) do + rank = Nx.rank(input) + + case channels do + :first -> + perm = Enum.to_list(0..(rank - 1)) + {perm, perm} + + :last -> + spatial = Enum.to_list(1..(rank - 2)//1) + perm = [0, rank - 1 | spatial] + kernel_perm = [rank - 1, rank - 2] ++ Enum.to_list(0..(rank - 3)//1) + {perm, kernel_perm} + + invalid -> + raise ArgumentError, "invalid channel configuration, #{inspect(invalid)}" + end + end + + @doc """ + Calculates strides for transposed convolution. + """ + deftransform conv_transpose_strides(input, strides) do + rank = Nx.rank(input) - 2 + + case strides do + [_ | _] = strides -> strides + strides -> List.duplicate(strides, rank) + end + end + @doc """ Calculates the padding needed for a transposed convolution. """ - def conv_transpose_padding(kernel_shape, kernel_dilation, strides, padding, channels) - when padding in [:valid, :same] do + deftransform conv_transpose_padding(kernel, kernel_dilation, strides, padding, channels) + when padding in [:valid, :same] do + kernel_shape = Nx.shape(kernel) + kernel_spatial_dims = case channels do :first -> @@ -395,7 +440,7 @@ defmodule Axon.Shape do end end - def conv_transpose_padding(_, _, _, padding, _), do: padding + deftransform conv_transpose_padding(_, _, _, padding, _), do: padding @doc """ Calculates the shape of a depthwise convolution kernel given the @@ -632,7 +677,9 @@ defmodule Axon.Shape do across batch or channel dimensions, so we just specify a size of `1` for each of those. """ - def pool_window_size(window, spatial_rank, channels) do + deftransform pool_window_size(input, window, channels) do + spatial_rank = Nx.rank(input) - 2 + spatial_dims = case window do x when is_integer(x) -> @@ -655,20 +702,70 @@ defmodule Axon.Shape do end @doc """ - Computes the window size from the given parent shape. + Calculates the window strides of a pooling operation. """ - def adaptive_pool_window_size(parent_shape, nil, channels) do + deftransform pool_window_strides(input, strides, window_dimensions, channels) do + rank = Nx.rank(input) + + case {strides, channels} do + {nil, _} -> Tuple.to_list(window_dimensions) + {[_ | _] = strides, :first} -> [1, 1 | strides] + {[_ | _] = strides, :last} -> [1 | strides] ++ [1] + {strides, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] + {strides, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] + end + end + + @doc """ + Calculates window dilations of a pooling operation. + """ + deftransform pool_window_dilations(input, window_dilations, channels) do + rank = Nx.rank(input) + + case {window_dilations, channels} do + {nil, _} -> List.duplicate(1, rank) + {[_ | _] = dilations, :first} -> [1, 1 | dilations] + {[_ | _] = dilations, :last} -> [1 | dilations] ++ [1] + {dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] + {dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] + end + end + + @doc """ + Calculates padding of a pooling operation based on input padding + and channels configuration. + """ + deftransform pool_window_padding(padding, channels) do + case {padding, channels} do + {:same, _} -> :same + {:valid, _} -> :valid + {padding, :first} -> [{0, 0}, {0, 0} | padding] + {padding, :last} -> [{0, 0} | padding] ++ [{0, 0}] + end + end + + @doc """ + Computes the adaptive pooling output size from the given parent + shape, output shape and channels configuration. 
+ """ + deftransform adaptive_pool_output_size(input, nil, channels) do + parent_shape = Nx.shape(input) + case channels do :first -> - parent_shape |> Tuple.delete_at(0) |> Tuple.delete_at(0) + parent_shape + |> Tuple.delete_at(0) + |> Tuple.delete_at(0) :last -> - parent_shape |> Tuple.delete_at(tuple_size(parent_shape) - 1) |> Tuple.delete_at(0) + parent_shape + |> Tuple.delete_at(tuple_size(parent_shape) - 1) + |> Tuple.delete_at(0) end end - def adaptive_pool_window_size(parent_shape, output_size, _channels) do - inner_rank = Nx.rank(parent_shape) - 2 + deftransform adaptive_pool_output_size(input, output_size, _channels) do + inner_rank = Nx.rank(input) - 2 tuple_or_duplicate(:output_size, output_size, inner_rank) end @@ -684,7 +781,10 @@ defmodule Axon.Shape do This preserves the size of the channel/batch dimension. """ - def adaptive_pool_window_strides(input_shape, output_spatial, spatial_rank, channels) do + deftransform adaptive_pool_window_strides(input, output_spatial, channels) do + input_shape = Nx.shape(input) + spatial_rank = Nx.rank(input) - 2 + idx = if channels == :first do 1 @@ -733,13 +833,15 @@ defmodule Axon.Shape do This preserves the size of the channel/batch dimension. """ - def adaptive_pool_window_size( - input_shape, - stride, - output_spatial, - spatial_rank, - channels - ) do + deftransform adaptive_pool_window_size( + input, + stride, + output_spatial, + channels + ) do + input_shape = Nx.shape(input) + spatial_rank = Nx.rank(input) - 2 + strides = case channels do :first -> @@ -813,16 +915,22 @@ defmodule Axon.Shape do @doc """ Calculates the reduction axes for batch normalization. """ - def batch_norm_axes(axes, channel_index) do - axes - |> Enum.filter(&(&1 != channel_index)) + deftransform batch_norm_axes(input, channel_index) do + axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + + input + |> Nx.axes() + |> Enum.filter(&(&1 != axis)) end @doc """ Calculates the reduction axes for instance normalization. """ - def instance_norm_axes(axes, channel_index) do - reduction_axes = axes -- [0, channel_index] + deftransform instance_norm_axes(input, channel_index) do + axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + axes = Nx.axes(input) + + reduction_axes = axes -- [0, axis] if reduction_axes == [] do raise ArgumentError, "rank of input shape must be at least 3" @@ -834,14 +942,17 @@ defmodule Axon.Shape do @doc """ Calculates the reduction axes for group normalization. """ - def group_norm_axes(rank, channel_index) do - Enum.to_list(1..(rank - 1)) -- [channel_index] + deftransform group_norm_axes(input, channel_index) do + Enum.to_list(1..(Nx.rank(input) - 1)) -- [channel_index] end @doc """ Calculates the reshape for group normalization. """ - def group_norm_shape(shape, num_groups, channel_index) do + deftransform group_norm_shape(input, num_groups, channel_index) do + shape = Nx.shape(input) + channel_index = Nx.Shape.normalize_axis(shape, channel_index, Nx.names(input)) + channels = elem(shape, channel_index) group_size = div(channels, num_groups) @@ -850,42 +961,25 @@ defmodule Axon.Shape do |> Tuple.insert_at(channel_index + 1, group_size) end - @doc """ - Calculates the shape after a flatten layer, which - flattens the non-minibatch dimensions into a single - dimension. 
- - ## Examples - - iex> Axon.Shape.flatten({nil, 1, 28, 28}) - {nil, 784} - - iex> Axon.Shape.flatten({32, 128}) - {32, 128} - - iex> Axon.Shape.flatten({nil, 10, 10}) - {nil, 100} - """ - def flatten(shape) do - out_units = Nx.size(Tuple.delete_at(shape, 0)) - - {elem(shape, 0), out_units} - end - @doc """ Computes split sizes for the given splits. """ - def split(shape, n, axis) do + deftransform split(input, index, splits, axis) do + shape = Nx.shape(input) + nil_names = List.duplicate(nil, Nx.rank(shape)) axis = Nx.Shape.normalize_axis(shape, axis, nil_names) - unless rem(elem(shape, axis), n) == 0 do + unless rem(elem(shape, axis), splits) == 0 do raise ArgumentError, - "unable to create #{n} even splits along axis #{axis}" <> + "unable to create #{splits} even splits along axis #{axis}" <> " of size #{elem(shape, axis)}" end - div(elem(shape, axis), n) + slice_size = div(elem(shape, axis), splits) + + offset = index * slice_size + {offset, slice_size} end @doc """ @@ -898,13 +992,15 @@ defmodule Axon.Shape do ## Examples - iex> Axon.Shape.spatial_dropout_noise_shape({nil, 3, 28, 28}, :first) - {nil, 1, 28, 28} + iex> Axon.Shape.spatial_dropout_noise_shape({1, 3, 28, 28}, :first) + {1, 1, 28, 28} - iex> Axon.Shape.spatial_dropout_noise_shape({nil, 28, 28, 3}, :last) - {nil, 28, 28, 1} + iex> Axon.Shape.spatial_dropout_noise_shape({1, 28, 28, 3}, :last) + {1, 28, 28, 1} """ - def spatial_dropout_noise_shape(input_shape, channels) do + deftransform spatial_dropout_noise_shape(input, channels) do + input_shape = Nx.shape(input) + if channels == :first do :erlang.setelement(2, input_shape, 1) else @@ -972,6 +1068,22 @@ defmodule Axon.Shape do {elem(shape, 0), 1, units} end + @doc """ + Returns the reduction axes for a global pooling operation + based on the input rank and channels configuration. + """ + deftransform global_pool_axes(input, channels) do + rank = Nx.rank(input) + + case channels do + :last -> + Enum.to_list(1..(rank - 2)) + + :first -> + Enum.to_list(2..(rank - 1)) + end + end + defp tuple_or_duplicate(key, tuple_or_integer, rank) do cond do is_tuple(tuple_or_integer) -> diff --git a/lib/axon/shared.ex b/lib/axon/shared.ex index 910a45bda..fa7108f7e 100644 --- a/lib/axon/shared.ex +++ b/lib/axon/shared.ex @@ -11,111 +11,87 @@ defmodule Axon.Shared do @doc """ Asserts `lhs` has same shape as `rhs`. """ - defn assert_shape!(caller, lhs_name, lhs, rhs_name, rhs) do - transform( - {lhs, rhs}, - fn {lhs, rhs} -> - lhs = Nx.shape(lhs) - rhs = Nx.shape(rhs) - - unless Elixir.Kernel.==(lhs, rhs) do - raise ArgumentError, - "#{caller}: expected input shapes #{lhs_name} and #{rhs_name}" <> - " to be equal, got #{inspect(lhs)} != #{inspect(rhs)}" - end - end - ) + deftransform assert_shape!(caller, lhs_name, lhs, rhs_name, rhs) do + lhs = Nx.shape(lhs) + rhs = Nx.shape(rhs) + + unless lhs == rhs do + raise ArgumentError, + "#{caller}: expected input shapes #{lhs_name} and #{rhs_name}" <> + " to be equal, got #{inspect(lhs)} != #{inspect(rhs)}" + end end @doc """ Asserts all shapes are equal. """ - defn assert_shape!(caller, shape_names, shapes) do - transform(shapes, fn [shape | shapes] -> - equal? = - Enum.all?(shapes, fn cur_shape -> - Elixir.Kernel.==(Nx.shape(cur_shape), Nx.shape(shape)) - end) - - unless equal? do - raise ArgumentError, - "#{caller}: expected all input shapes #{inspect(shape_names)}" <> - " to be equal, got #{inspect(shapes)}" - end - end) + deftransform assert_shape!(caller, shape_names, [shape | shapes]) do + equal? 
= + Enum.all?(shapes, fn cur_shape -> + Nx.shape(cur_shape) == Nx.shape(shape) + end) + + unless equal? do + raise ArgumentError, + "#{caller}: expected all input shapes #{inspect(shape_names)}" <> + " to be equal, got #{inspect(shapes)}" + end end @doc """ Asserts `inp` has explicit rank `rank`. """ - defn assert_rank!(caller, inp_name, inp, rank) do - transform( - {inp, rank}, - fn {x, y} -> - x = Nx.rank(x) - - unless Elixir.Kernel.==(x, y) do - raise ArgumentError, - "#{caller}: expected #{inp_name} to have rank equal to #{y}," <> - " got #{x} != #{y}" - end - end - ) + deftransform assert_rank!(caller, inp_name, inp, rank) do + x = Nx.rank(inp) + + unless x == rank do + raise ArgumentError, + "#{caller}: expected #{inp_name} to have rank equal to #{rank}," <> + " got #{x} != #{rank}" + end end @doc """ Asserts `lhs` has same rank as `rhs`. """ - defn assert_equal_rank!(caller, lhs_name, lhs, rhs_name, rhs) do - transform( - {lhs, rhs}, - fn {x, y} -> - x = if is_integer(x), do: x, else: Nx.rank(x) - y = if is_integer(y), do: y, else: Nx.rank(y) - - unless Elixir.Kernel.>=(x, y) do - raise ArgumentError, - "#{caller}: expected #{lhs_name} and #{rhs_name} ranks to be equal" <> - " got #{x} != #{y}" - end - end - ) + deftransform assert_equal_rank!(caller, lhs_name, lhs, rhs_name, rhs) do + x = if is_integer(lhs), do: lhs, else: Nx.rank(lhs) + y = if is_integer(rhs), do: rhs, else: Nx.rank(rhs) + + unless x >= y do + raise ArgumentError, + "#{caller}: expected #{lhs_name} and #{rhs_name} ranks to be equal" <> + " got #{x} != #{y}" + end end @doc """ Asserts all ranks are equal. """ - defn assert_equal_rank!(caller, rank_names, ranks) do - transform(ranks, fn [rank | ranks] -> - equal? = - Enum.all?(ranks, fn cur_rank -> - Elixir.Kernel.==(Nx.rank(cur_rank), Nx.rank(rank)) - end) - - unless equal? do - raise ArgumentError, - "#{caller}: expected all input ranks #{inspect(rank_names)}" <> - " to be equal, got #{inspect(ranks)}" - end - end) + deftransform assert_equal_rank!(caller, rank_names, [rank | ranks]) do + equal? = + Enum.all?(ranks, fn cur_rank -> + Nx.rank(cur_rank) == Nx.rank(rank) + end) + + unless equal? do + raise ArgumentError, + "#{caller}: expected all input ranks #{inspect(rank_names)}" <> + " to be equal, got #{inspect(ranks)}" + end end @doc """ Asserts `lhs` has at least rank `rhs`. 
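+
+  As an illustrative sketch of how callers use this check (the caller
+  and input labels below are made up):
+
+      # rank 3 >= 3, so the check passes
+      assert_min_rank!("Axon.Layers.conv", "input", Nx.iota({1, 1, 2}), 3)
+
+      # rank 2 < 3, so this raises an ArgumentError
+      assert_min_rank!("Axon.Layers.conv", "input", Nx.iota({1, 2}), 3)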
""" - defn assert_min_rank!(caller, name, lhs, rhs) do - transform( - {lhs, rhs}, - fn {x, y} -> - x = if is_integer(x), do: x, else: Nx.rank(x) - y = if is_integer(y), do: y, else: Nx.rank(y) - - unless Elixir.Kernel.>=(x, y) do - raise ArgumentError, - "#{caller}: expected #{name} shape to have at least rank #{y}, got rank #{x}" - end - end - ) + deftransform assert_min_rank!(caller, name, lhs, rhs) do + x = if is_integer(lhs), do: lhs, else: Nx.rank(lhs) + y = if is_integer(rhs), do: rhs, else: Nx.rank(rhs) + + unless x >= y do + raise ArgumentError, + "#{caller}: expected #{name} shape to have at least rank #{y}, got rank #{x}" + end end @doc """ @@ -252,18 +228,17 @@ defmodule Axon.Shared do end end - ## Numerical Helpers + ## List transforms in defn - # TODO: These should be contained somewhere else, like another library + deftransform list_duplicate(value, size) do + List.duplicate(value, size) + end - defn logsumexp(x, opts \\ []) do - opts = keyword!(opts, axes: [], keep_axes: false) + deftransform list_wrap(value), do: List.wrap(value) - x - |> Nx.exp() - |> Nx.sum(opts) - |> Nx.log() - end + ## Numerical Helpers + + # TODO: These should be contained somewhere else, like another library defn xlogy(x, y) do x_ok = Nx.not_equal(x, 0.0) diff --git a/lib/axon/updates.ex b/lib/axon/updates.ex index ac5bffd0c..5ca661a2e 100644 --- a/lib/axon/updates.ex +++ b/lib/axon/updates.ex @@ -43,12 +43,7 @@ defmodule Axon.Updates do end defnp apply_scale(x, _params, step) do - transform( - {x, step}, - fn {updates, step} -> - deep_new(updates, fn x -> Nx.multiply(x, step) end) - end - ) + deep_new(updates, fn x -> Nx.multiply(x, step) end) end Notice how the function given to `stateless/2` is defined within `defn`. @@ -68,9 +63,7 @@ defmodule Axon.Updates do defnp apply_my_update(updates, state) do new_state = deep_new(state, fn v -> Nx.add(v, 0.01) end) - updates = transform({updates, new_state}, fn {updates, state} -> - deep_merge(updates, state, fn g, z -> Nx.multiply(g, z) end) - end) + updates = deep_merge(updates, state, fn g, z -> Nx.multiply(g, z) end) {updates, %{state: new_state}} end @@ -525,10 +518,8 @@ defmodule Axon.Updates do defnp radam_update(ro, ro_inf, mu, nu, eps_root, eps) do r = Nx.sqrt((ro - 4) * (ro - 2) * ro_inf / ((ro_inf - 4) * (ro_inf - 2) * ro)) - transform({r, mu, nu, eps_root, eps}, fn {r, mu, nu, eps_root, eps} -> - deep_merge(mu, nu, fn m, v -> - r * m / (Nx.sqrt(v + eps_root) + eps) - end) + deep_merge(mu, nu, fn m, v -> + r * m / (Nx.sqrt(v + eps_root) + eps) end) end @@ -678,16 +669,16 @@ defmodule Axon.Updates do end defnp apply_centralize(x, _params, _opts \\ []) do - transform(x, fn x -> - deep_new(x, fn z -> - if Elixir.Kernel.>(Nx.rank(z), 1) do - axes = tl(Nx.axes(z)) - z - Nx.mean(z, axes: axes, keep_axes: true) - else - z - end - end) - end) + deep_new(x, ¢ralize_for_rank/1) + end + + deftransformp centralize_for_rank(input) do + if Nx.rank(input) > 1 do + input + |> Nx.subtract(Nx.mean(input, axes: tl(Nx.axes(input)), keep_axes: true)) + else + input + end end @doc """ diff --git a/test/axon/activations_test.exs b/test/axon/activations_test.exs index 47e0cc7c8..2f6634cb3 100644 --- a/test/axon/activations_test.exs +++ b/test/axon/activations_test.exs @@ -700,7 +700,7 @@ defmodule Axon.ActivationsTest do describe "log_softmax" do test "raises on bad axis" do - assert_raise ArgumentError, ~r/log_softmax axis must be within rank of tensor/, fn -> + assert_raise ArgumentError, "given axis (2) invalid for shape with rank 2", fn -> 
Axon.Activations.log_softmax(Nx.iota({1, 3}), axis: 2) end end @@ -1143,6 +1143,22 @@ defmodule Axon.ActivationsTest do actual = apply(jit(fn x -> grad(x, &Nx.sum(Axon.Activations.sigmoid(&1))) end), [a]) assert_all_close(expected, actual) end + + defn cache_test_sigmoid(x) do + x + |> Axon.Activations.sigmoid() + |> get_cached() + end + + deftransformp get_cached(res) do + %{data: %{args: [_, %{logits: inp}]}} = res + inp + end + + test "caches input logits" do + {a, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 10}) + assert_all_close(cache_test_sigmoid(a), a) + end end describe "silu" do @@ -1348,6 +1364,17 @@ defmodule Axon.ActivationsTest do actual = apply(jit(fn x -> grad(x, &Nx.sum(Axon.Activations.softmax(&1))) end), [a]) assert_all_close(expected, actual, atol: 1.0e-7) end + + defn cache_test_softmax(x) do + x + |> Axon.Activations.softmax() + |> get_cached() + end + + test "caches input logits" do + {a, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 10}) + assert_all_close(cache_test_softmax(a), a) + end end describe "softplus" do diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index cc1fb9fc6..e2d5e1366 100644 --- a/test/axon/layers_test.exs +++ b/test/axon/layers_test.exs @@ -198,6 +198,19 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with strides" do + input = Nx.random_uniform({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + kernel = Nx.random_uniform({3, 1, 4, 4}) + t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) + bias = Nx.tensor(0.0) + + first = Axon.Layers.conv(input, kernel, bias, channels: :first, strides: [1, 2]) + last = Axon.Layers.conv(t_input, t_kernel, bias, channels: :last, strides: [1, 2]) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) kernel = Nx.iota({2, 1, 1}) @@ -237,6 +250,19 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels first same as channels last with strides" do + input = Nx.random_uniform({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + kernel = Nx.random_uniform({3, 1, 4, 4}) + t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) + bias = Nx.tensor(0.0) + + first = Axon.Layers.conv_transpose(input, kernel, bias, channels: :first, strides: [1, 2]) + last = Axon.Layers.conv_transpose(t_input, t_kernel, bias, channels: :last, strides: [1, 2]) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "correct valid padding, no strides" do inp = Nx.iota({1, 1, 4}, type: {:f, 32}) kernel = Nx.iota({3, 1, 2}, type: {:f, 32}) @@ -579,6 +605,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = Nx.random_uniform({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + + first = + Axon.Layers.max_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.max_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -622,6 +669,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as 
channels first with custom padding" do + input = Nx.random_uniform({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + + first = + Axon.Layers.max_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.max_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -665,6 +733,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = Nx.random_uniform({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + + first = + Axon.Layers.max_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.max_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1})