Add some features from Mosaic ML (#474)
seanmor5 committed Jun 2, 2023
1 parent 006650c commit c19ff70
Showing 6 changed files with 241 additions and 0 deletions.
33 changes: 33 additions & 0 deletions lib/axon.ex
@@ -1582,6 +1582,39 @@ defmodule Axon do
layer(pool, [x], opts)
end

@doc """
Adds a blur pooling layer to the network.
See `Axon.Layers.blur_pool/2` for more details.
## Options
* `:name` - layer name.
* `:strides` - stride during convolution. Defaults to `1`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:last`.
"""
def blur_pool(%Axon{} = x, opts \\ []) do
opts =
Keyword.validate!(opts, [
:name,
channels: :last
])

channels = opts[:channels]
name = opts[:name]

opts = [
name: name,
channels: channels,
op_name: :blur_pool
]

layer(:blur_pool, [x], opts)
end

## Adaptive Pooling

@adaptive_pooling_layers [
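For context, a minimal usage sketch of the new layer (not part of the diff; it assumes Axon's standard `build/2` flow, and the input shape is illustrative):

    # Blur pool before downsampling, as in the MosaicML recipe
    model =
      Axon.input("images", shape: {nil, 32, 32, 3})
      |> Axon.blur_pool(name: "blur")
      |> Axon.max_pool(kernel_size: 2)

    {init_fn, predict_fn} = Axon.build(model)
    params = init_fn.(Nx.template({1, 32, 32, 3}, :f32), %{})
    predict_fn.(params, Nx.broadcast(0.5, {1, 32, 32, 3}))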
71 changes: 71 additions & 0 deletions lib/axon/layers.ex
@@ -961,6 +961,77 @@ defmodule Axon.Layers do
|> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
end

@doc """
Functional implementation of a 2-dimensional blur pooling layer.
Blur pooling applies a spatial low-pass filter to the input. It is
often applied before pooling and convolutional layers as a way to
increase model accuracy without much additional computation cost.
The blur pooling implementation follows from [MosaicML](https://github.com/mosaicml/composer/blob/dev/composer/algorithms/blurpool/blurpool_layers.py).
"""
@doc type: :pooling
defn blur_pool(input, opts \\ []) do
assert_rank!("blur_pool", "input", input, 4)
opts = keyword!(opts, channels: :last, mode: :train)

filter =
Nx.tensor([
[
[
[1, 2, 1],
[2, 4, 2],
[1, 2, 1]
]
]
]) * 1 / 16.0

output_channels =
case opts[:channels] do
:last ->
Nx.axis_size(input, 3)

:first ->
Nx.axis_size(input, 1)
end

filter = compute_filter(filter, opts[:channels], output_channels)

conv(input, filter,
padding: padding_for_filter(filter),
feature_group_size: output_channels,
channels: opts[:channels]
)
end

# Broadcasts the base {1, 1, h, w} filter to one filter per channel
deftransformp compute_filter(filter, :first, out_channels) do
  filter_shape = put_elem(Nx.shape(filter), 0, out_channels)
  Nx.broadcast(filter, filter_shape)
end

deftransformp compute_filter(filter, :last, out_channels) do
  filter_shape = put_elem(Nx.shape(filter), 0, out_channels)
  filter_permutation = [3, 2, 0, 1]
  filter |> Nx.broadcast(filter_shape) |> Nx.transpose(axes: filter_permutation)
end

# Computes "same" padding for an OIHW filter, i.e. {_, _, h, w};
# requires odd spatial dimensions so padding is symmetric
deftransformp padding_for_filter(filter) do
  {_, _, h, w} = Nx.shape(filter)

  cond do
    rem(h, 2) == 0 ->
      raise ArgumentError, "filter height must be odd"

    rem(w, 2) == 0 ->
      raise ArgumentError, "filter width must be odd"

    true ->
      :ok
  end

  [{div(h, 2), div(h, 2)}, {div(w, 2), div(w, 2)}]
end

@doc """
Functional implementation of general dimensional adaptive average
pooling.
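Worth noting about the implementation above: the 3x3 kernel is the outer product of the binomial `[1, 2, 1]` with itself, normalized to sum to 1, and `feature_group_size: output_channels` makes the convolution depthwise. A standalone sketch in plain `Nx` (illustrative, channels-first; not part of the commit):

    # The blur kernel is binomial: outer([1, 2, 1], [1, 2, 1]) / 16
    kernel =
      Nx.outer(Nx.tensor([1.0, 2.0, 1.0]), Nx.tensor([1.0, 2.0, 1.0]))
      |> Nx.divide(16)
    # => [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16, same as the filter above

    channels = 2
    filter = Nx.broadcast(Nx.reshape(kernel, {1, 1, 3, 3}), {channels, 1, 3, 3})
    input = Nx.iota({1, channels, 8, 8}, type: :f32)

    # feature_group_size: channels => depthwise convolution;
    # {1, 1} padding keeps the 8x8 spatial dimensions
    Nx.conv(input, filter, padding: [{1, 1}, {1, 1}], feature_group_size: channels)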
47 changes: 47 additions & 0 deletions lib/axon/losses.ex
@@ -1177,6 +1177,53 @@ defmodule Axon.Losses do
t0_prob
end

## Modifiers

@doc """
Modifies the given loss function to smooth labels prior
to calculating loss.
See `apply_label_smoothing/2` for details.
## Options
* `:smoothing` - smoothing factor. Defaults to 0.1
"""
def label_smoothing(loss_fun, opts \\ []) when is_function(loss_fun, 2) do
opts = Keyword.validate!(opts, smoothing: 0.1)

fn y_true, y_pred ->
smoothed = apply_label_smoothing(y_true, y_pred, smoothing: opts[:smoothing])
loss_fun.(smoothed, y_pred)
end
end

@doc """
Applies label smoothing to the given labels.
Label smoothing is a regularization technique which shrink targets
towards a uniform distribution. Label smoothing can improve model
generalization.
## Options
* `:smoothing` - smoothing factor. Defaults to 0.1
## References
* [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567)
"""
defn apply_label_smoothing(y_true, y_pred, opts \\ []) do
assert_min_rank!("apply_label_smoothing", "y_true", y_true, 2)
assert_min_rank!("apply_label_smoothing", "y_pred", y_pred, 2)

opts = keyword!(opts, smoothing: 0.1)
n_classes = Nx.axis_size(y_pred, 1)
y_true * (1 - opts[:smoothing]) + opts[:smoothing] / n_classes
end

## Helpers

defnp reduction(loss, reduction \\ :none) do
case reduction do
:mean -> Nx.mean(loss)
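To see how the two pieces compose, here is a sketch of wrapping a loss with `label_smoothing` (values are illustrative; it assumes `categorical_cross_entropy/3` with its `:reduction` option):

    smooth_cce =
      Axon.Losses.label_smoothing(
        &Axon.Losses.categorical_cross_entropy(&1, &2, reduction: :mean),
        smoothing: 0.1
      )

    y_true = Nx.tensor([[0, 1, 0, 0, 0, 0]])
    y_pred = Nx.tensor([[0.05, 0.7, 0.05, 0.05, 0.1, 0.05]])

    # With smoothing 0.1 over 6 classes, targets become
    # 1 * 0.9 + 0.1/6 ~= 0.9167 and 0 * 0.9 + 0.1/6 ~= 0.0167,
    # which is exactly what the losses test below checks
    smooth_cce.(y_true, y_pred)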
53 changes: 53 additions & 0 deletions test/axon/compiler_test.exs
@@ -813,6 +813,59 @@ defmodule CompilerTest do
# end
end

describe "blur_pool" do
test "initializes with no params" do
model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})])

input = random({1, 32, 32, 1})

assert {init_fn, _predict_fn} = Axon.build(model)
assert %{} = init_fn.(input, %{})
end

test "computes forward pass with default options" do
model2 = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 8, 4, 1})])
input2 = random({1, 8, 4, 1}, type: {:f, 32})

assert {_, predict_fn} = Axon.build(model2)

assert_equal(
predict_fn.(%{}, input2),
apply(Axon.Layers, :blur_pool, [input2])
)
end

test "computes forward pass with output policy" do
model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})])
policy = AMP.create_policy(output: {:bf, 16})
mp_model = AMP.apply_policy(model, policy)

input = random({1, 32, 32, 1})

assert {init_fn, predict_fn} = Axon.build(mp_model)

assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 32, 32, 1}))) ==
{:bf, 16}
end

test "computes forward pass with channels last" do
model =
apply(Axon, :blur_pool, [
Axon.input("input", shape: {nil, 32, 32, 1}),
[channels: :last]
])

inp = random({1, 32, 32, 1})

assert {_, predict_fn} = Axon.build(model)

assert_equal(
predict_fn.(%{}, inp),
apply(Axon.Layers, :blur_pool, [inp, [channels: :last]])
)
end
end

@adaptive_pooling_layers [:adaptive_avg_pool, :adaptive_max_pool, :adaptive_lp_pool]

describe "adaptive pooling" do
16 changes: 16 additions & 0 deletions test/axon/loop_test.exs
@@ -376,6 +376,14 @@ defmodule Axon.LoopTest do
}
end
)
|> Loop.handle_event(
  :completed,
  fn %State{step_state: %{counter: counter}} = state ->
    assert 4 = counter

    {:continue, state}
  end
)
|> Loop.run(
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
%{},
@@ -408,6 +416,14 @@
}
end
)
|> Loop.handle_event(
  :completed,
  fn %State{step_state: %{counter: counter}} = state ->
    assert {{4}, 4} = counter

    {:continue, state}
  end
)
|> Loop.run(
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
%{},
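The two handlers added above follow `Axon.Loop`'s event-handler contract: a `:completed` handler receives the final `%Axon.Loop.State{}` and returns the (possibly updated) state tagged with a control atom, `{:continue, state}` here. A minimal sketch outside the test suite (the `loop` variable is hypothetical):

    loop
    |> Axon.Loop.handle_event(:completed, fn %Axon.Loop.State{} = state ->
      IO.inspect(state.epoch, label: "epochs completed")
      {:continue, state}
    end)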
21 changes: 21 additions & 0 deletions test/axon/losses_test.exs
@@ -284,4 +284,25 @@ defmodule Axon.LossesTest do
)
end
end

describe "apply_label_smoothing" do
test "correctly smooths labels" do
y_true = Nx.tensor([[0, 1, 0, 0, 0, 0]])
y_pred = Nx.tensor([[0.5, 0.1, 0.1, 0.0, 0.2, 0.1]])

assert_all_close(
Axon.Losses.apply_label_smoothing(y_true, y_pred, smoothing: 0.1),
Nx.tensor([[0.0167, 0.9167, 0.0167, 0.0167, 0.0167, 0.0167]]),
atol: 1.0e-3
)
end
end

describe "label_smoothing" do
test "returns an arity-2 function from loss function" do
loss = &Axon.Losses.categorical_cross_entropy/2
smooth_loss = Axon.Losses.label_smoothing(loss, smoothing: 0.1)
assert is_function(smooth_loss, 2)
end
end
end
