diff --git a/lib/axon.ex b/lib/axon.ex
index 50cad497..c1f15e7b 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -1582,6 +1582,39 @@ defmodule Axon do
     layer(pool, [x], opts)
   end
 
+  @doc """
+  Adds a blur pooling layer to the network.
+
+  See `Axon.Layers.blur_pool/2` for more details.
+
+  ## Options
+
+    * `:name` - layer name.
+
+    * `:channels` - channels location. One of `:first` or `:last`.
+      Defaults to `:last`.
+  """
+  def blur_pool(%Axon{} = x, opts \\ []) do
+    opts =
+      Keyword.validate!(opts, [
+        :name,
+        channels: :last
+      ])
+
+    channels = opts[:channels]
+    name = opts[:name]
+
+    opts = [
+      name: name,
+      channels: channels,
+      op_name: :blur_pool
+    ]
+
+    layer(:blur_pool, [x], opts)
+  end
+
   ## Adaptive Pooling
 
   @adaptive_pooling_layers [
diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex
index b2a95051..b4497f41 100644
--- a/lib/axon/layers.ex
+++ b/lib/axon/layers.ex
@@ -961,6 +961,77 @@ defmodule Axon.Layers do
     |> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
   end
 
+  @doc """
+  Functional implementation of a 2-dimensional blur pooling layer.
+
+  Blur pooling applies a spatial low-pass filter to the input. It is
+  often applied before pooling and convolutional layers as a way to
+  increase model accuracy without much additional computation cost.
+
+  The implementation follows the blur pooling layers in
+  [MosaicML](https://github.com/mosaicml/composer/blob/dev/composer/algorithms/blurpool/blurpool_layers.py).
+
+  ## Options
+
+    * `:channels` - channels location. One of `:first` or `:last`.
+      Defaults to `:last`.
+  """
+  @doc type: :pooling
+  defn blur_pool(input, opts \\ []) do
+    assert_rank!("blur_pool", "input", input, 4)
+    opts = keyword!(opts, channels: :last, mode: :train)
+
+    # Normalized 3x3 binomial blur kernel in {out, in, h, w} layout
+    filter =
+      Nx.tensor([
+        [
+          [
+            [1, 2, 1],
+            [2, 4, 2],
+            [1, 2, 1]
+          ]
+        ]
+      ]) / 16.0
+
+    output_channels =
+      case opts[:channels] do
+        :last ->
+          Nx.axis_size(input, 3)
+
+        :first ->
+          Nx.axis_size(input, 1)
+      end
+
+    # Compute padding from the base {out, in, h, w} filter before any
+    # layout transpose, so the spatial dims are read correctly
+    padding = padding_for_filter(filter)
+    filter = compute_filter(filter, opts[:channels], output_channels)
+
+    conv(input, filter,
+      padding: padding,
+      feature_group_size: output_channels,
+      channels: opts[:channels]
+    )
+  end
+
+  deftransformp compute_filter(filter, :first, out_channels) do
+    filter_shape = put_elem(Nx.shape(filter), 0, out_channels)
+    Nx.broadcast(filter, filter_shape)
+  end
+
+  deftransformp compute_filter(filter, :last, out_channels) do
+    filter_shape = put_elem(Nx.shape(filter), 0, out_channels)
+    # {out, in, h, w} -> {h, w, in, out}, the kernel layout `conv/4`
+    # expects for channels-last inputs
+    filter_permutation = [2, 3, 1, 0]
+    filter |> Nx.broadcast(filter_shape) |> Nx.transpose(axes: filter_permutation)
+  end
+
+  deftransformp padding_for_filter(filter) do
+    {_, _, h, w} = Nx.shape(filter)
+
+    cond do
+      rem(h, 2) == 0 ->
+        raise ArgumentError, "filter height must be odd"
+
+      rem(w, 2) == 0 ->
+        raise ArgumentError, "filter width must be odd"
+
+      true ->
+        :ok
+    end
+
+    [{div(h, 2), div(h, 2)}, {div(w, 2), div(w, 2)}]
+  end
+
   @doc """
   Functional implementation of general dimensional adaptive average
   pooling.
diff --git a/lib/axon/losses.ex b/lib/axon/losses.ex
index 8d692abe..2c433a43 100644
--- a/lib/axon/losses.ex
+++ b/lib/axon/losses.ex
@@ -1177,6 +1177,53 @@ defmodule Axon.Losses do
     t0_prob
   end
 
+  ## Modifiers
+
+  @doc """
+  Modifies the given loss function to smooth labels prior
+  to calculating loss.
+
+  See `apply_label_smoothing/3` for details.
+
+  ## Options
+
+    * `:smoothing` - smoothing factor. Defaults to `0.1`.
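+
+  ## Examples
+
+  For example, to wrap `categorical_cross_entropy/2` so that targets
+  are smoothed before the loss is computed (assuming `y_true` and
+  `y_pred` are rank-2 label and prediction tensors):
+
+      loss =
+        Axon.Losses.label_smoothing(&Axon.Losses.categorical_cross_entropy/2,
+          smoothing: 0.1
+        )
+
+      loss.(y_true, y_pred)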
+  """
+  def label_smoothing(loss_fun, opts \\ []) when is_function(loss_fun, 2) do
+    opts = Keyword.validate!(opts, smoothing: 0.1)
+
+    fn y_true, y_pred ->
+      smoothed = apply_label_smoothing(y_true, y_pred, smoothing: opts[:smoothing])
+      loss_fun.(smoothed, y_pred)
+    end
+  end
+
+  @doc """
+  Applies label smoothing to the given labels.
+
+  Label smoothing is a regularization technique which shrinks
+  targets toward a uniform distribution. It can improve model
+  generalization.
+
+  ## Options
+
+    * `:smoothing` - smoothing factor. Defaults to `0.1`.
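+
+  ## Examples
+
+  With `smoothing: 0.1` over 6 classes, the hot label shrinks to
+  roughly `0.9167` and every other label rises to roughly `0.0167`
+  (the same values asserted in the test suite):
+
+      y_true = Nx.tensor([[0, 1, 0, 0, 0, 0]])
+      y_pred = Nx.tensor([[0.5, 0.1, 0.1, 0.0, 0.2, 0.1]])
+      Axon.Losses.apply_label_smoothing(y_true, y_pred, smoothing: 0.1)
+      #=> approximately [[0.0167, 0.9167, 0.0167, 0.0167, 0.0167, 0.0167]]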
+
+  ## References
+
+    * [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567)
+  """
+  defn apply_label_smoothing(y_true, y_pred, opts \\ []) do
+    assert_min_rank!("apply_label_smoothing", "y_true", y_true, 2)
+    assert_min_rank!("apply_label_smoothing", "y_pred", y_pred, 2)
+
+    opts = keyword!(opts, smoothing: 0.1)
+    n_classes = Nx.axis_size(y_pred, 1)
+    y_true * (1 - opts[:smoothing]) + opts[:smoothing] / n_classes
+  end
+
+  ## Helpers
+
   defnp reduction(loss, reduction \\ :none) do
     case reduction do
       :mean -> Nx.mean(loss)
diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs
index 4dc6d22e..e49574a5 100644
--- a/test/axon/compiler_test.exs
+++ b/test/axon/compiler_test.exs
@@ -813,6 +813,59 @@ defmodule CompilerTest do
     # end
   end
 
+  describe "blur_pool" do
+    test "initializes with no params" do
+      model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})])
+
+      input = random({1, 32, 32, 1})
+
+      assert {init_fn, _predict_fn} = Axon.build(model)
+      assert %{} = init_fn.(input, %{})
+    end
+
+    test "computes forward pass with default options" do
+      model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 8, 4, 1})])
+      input = random({1, 8, 4, 1}, type: {:f, 32})
+
+      assert {_, predict_fn} = Axon.build(model)
+
+      assert_equal(
+        predict_fn.(%{}, input),
+        apply(Axon.Layers, :blur_pool, [input])
+      )
+    end
+
+    test "computes forward pass with output policy" do
+      model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})])
+      policy = AMP.create_policy(output: {:bf, 16})
+      mp_model = AMP.apply_policy(model, policy)
+
+      input = random({1, 32, 32, 1})
+
+      assert {init_fn, predict_fn} = Axon.build(mp_model)
+
+      assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 32, 32, 1}))) ==
+               {:bf, 16}
+    end
+
+    test "computes forward pass with channels last" do
+      model =
+        apply(Axon, :blur_pool, [
+          Axon.input("input", shape: {nil, 32, 32, 1}),
+          [channels: :last]
+        ])
+
+      input = random({1, 32, 32, 1})
+
+      assert {_, predict_fn} = Axon.build(model)
+
+      assert_equal(
+        predict_fn.(%{}, input),
+        apply(Axon.Layers, :blur_pool, [input, [channels: :last]])
+      )
+    end
+  end
+
   @adaptive_pooling_layers [:adaptive_avg_pool, :adaptive_max_pool, :adaptive_lp_pool]
 
   describe "adaptive pooling" do
diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs
index feeb1e73..a9c874d4 100644
--- a/test/axon/loop_test.exs
+++ b/test/axon/loop_test.exs
@@ -376,6 +376,14 @@ defmodule Axon.LoopTest do
           }
         end
       )
+      |> Loop.handle_event(
+        :completed,
+        fn %State{step_state: %{counter: counter}} = state ->
+          assert 4 = counter
+
+          {:continue, state}
+        end
+      )
       |> Loop.run(
         [{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
         %{},
@@ -408,6 +416,14 @@ defmodule Axon.LoopTest do
          }
        end
      )
+      |> Loop.handle_event(
+        :completed,
+        fn %State{step_state: %{counter: counter}} = state ->
+          assert {{4}, 4} = counter
+
+          {:continue, state}
+        end
+      )
       |> Loop.run(
         [{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
         %{},
diff --git a/test/axon/losses_test.exs b/test/axon/losses_test.exs
index 922986c3..6f396cca 100644
--- a/test/axon/losses_test.exs
+++ b/test/axon/losses_test.exs
@@ -284,4 +284,25 @@ defmodule Axon.LossesTest do
       )
     end
   end
+
+  describe "apply_label_smoothing" do
+    test "correctly smooths labels" do
+      y_true = Nx.tensor([[0, 1, 0, 0, 0, 0]])
+      y_pred = Nx.tensor([[0.5, 0.1, 0.1, 0.0, 0.2, 0.1]])
+
+      assert_all_close(
+        Axon.Losses.apply_label_smoothing(y_true, y_pred, smoothing: 0.1),
+        Nx.tensor([[0.0167, 0.9167, 0.0167, 0.0167, 0.0167, 0.0167]]),
+        atol: 1.0e-3
+      )
+    end
+  end
+
+  describe "label_smoothing" do
+    test "returns an arity-2 function from a loss function" do
+      loss = &Axon.Losses.categorical_cross_entropy/2
+      smooth_loss = Axon.Losses.label_smoothing(loss, smoothing: 0.1)
+      assert is_function(smooth_loss, 2)
+    end
+  end
 end