/
axon.ex
4068 lines (3144 loc) · 116 KB
/
axon.ex
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
defmodule Axon do
@moduledoc """
A high-level interface for creating neural network models.
Axon is built entirely on top of Nx numerical definitions,
so every neural network can be JIT or AOT compiled using
any Nx compiler, or even transformed into high-level neural
network formats like TensorFlow Lite and
[ONNX](https://github.com/elixir-nx/axon_onnx).
For a more in-depth overview of Axon, refer to the [Guides](guides.html).
## Model Creation
All Axon models start with an input layer, optionally specifying
the expected shape of the input data:
input = Axon.input("input", shape: {nil, 784})
Notice you can specify some dimensions as `nil`, indicating
that the dimension size will be filled in at model runtime.
You can then compose inputs with other layers:
model =
input
|> Axon.dense(128, activation: :relu)
|> Axon.batch_norm()
|> Axon.dropout(rate: 0.8)
|> Axon.dense(64)
|> Axon.tanh()
|> Axon.dense(10)
|> Axon.activation(:softmax)
You can inspect the model for a nice summary:
IO.inspect(model)
#Axon<
inputs: %{"input" => {nil, 784}}
outputs: "softmax_0"
nodes: 9
>
Or use the `Axon.Display` module to see more in-depth summaries:
Axon.Display.as_table(model, Nx.template({1, 784}, :f32)) |> IO.puts
+----------------------------------------------------------------------------------------------------------------+
| Model |
+=======================================+=============+==============+===================+=======================+
| Layer | Input Shape | Output Shape | Options | Parameters |
+=======================================+=============+==============+===================+=======================+
| input ( input ) | [] | {1, 784} | shape: {nil, 784} | |
| | | | optional: false | |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_0 ( dense["input"] ) | [{1, 784}] | {1, 128} | | kernel: f32[784][128] |
| | | | | bias: f32[128] |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| relu_0 ( relu["dense_0"] ) | [{1, 128}] | {1, 128} | | |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| batch_norm_0 ( batch_norm["relu_0"] ) | [{1, 128}] | {1, 128} | epsilon: 1.0e-5 | gamma: f32[128] |
| | | | channel_index: 1 | beta: f32[128] |
| | | | momentum: 0.1 | mean: f32[128] |
| | | | | var: f32[128] |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dropout_0 ( dropout["batch_norm_0"] ) | [{1, 128}] | {1, 128} | rate: 0.8 | |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_1 ( dense["dropout_0"] ) | [{1, 128}] | {1, 64} | | kernel: f32[128][64] |
| | | | | bias: f32[64] |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| tanh_0 ( tanh["dense_1"] ) | [{1, 64}] | {1, 64} | | |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_2 ( dense["tanh_0"] ) | [{1, 64}] | {1, 10} | | kernel: f32[64][10] |
| | | | | bias: f32[10] |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| softmax_0 ( softmax["dense_2"] ) | [{1, 10}] | {1, 10} | | |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
### Multiple Inputs
Creating a model with multiple inputs is as easy as declaring an
additional input in your Axon graph. Every input layer present in
the final Axon graph will be required to be passed as input at the
time of model execution.
inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
# Both inputs will be used
model1 = Axon.add(inp1, inp2)
# Only inp2 will be used
model2 = Axon.add(inp2, inp2)
Axon graphs are immutable, which means composing and manipulating
an Axon graph creates an entirely new graph. Additionally, layer
names are lazily generated at model execution time. To avoid
non-deterministic input orderings and names, Axon requires each
input to have a unique binary identifier. You can then reference
inputs by name when passing to models at execution time:
inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model1 = Axon.add(inp1, inp2)
{init_fn, predict_fn} = Axon.build(model1)
params1 = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
# Inputs are referenced by name
predict_fn.(params1, %{"input_0" => x, "input_1" => y})
### Multiple Outputs
Nx offers robust [container](https://hexdocs.pm/nx/Nx.Container.html) support
which is extended to Axon. Axon allows you to wrap any valid Nx container
in a layer. Containers are most commonly used to structure outputs:
inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model = Axon.container(%{foo: inp1, bar: inp2})
Containers can be arbitrarily nested:
inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model = Axon.container({%{foo: {inp1, %{bar: inp2}}}})
You can even use custom structs which implement the container protocol:
inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model = Axon.container(%MyStruct{foo: inp1, bar: inp2})
### Custom Layers
If you find that Axon's built-in layers are insufficient for your needs,
you can create your own using the custom layer API. All of Axon's built-in
layers (aside from special ones such as `input`, `constant`, and `container`)
make use of this same API.
Axon layers are really just placeholders for Nx computations with trainable
parameters and possibly state. To define a custom layer, you just need to
define a `defn` implementation:
defn my_layer(x, weight, _opts \\\\ []) do
Nx.atan2(x, weight)
end
Notice the only stipulation is that your custom layer implementation must
accept at least 1 input and a list of options. At execution time, every
layer will be passed a `:mode` option which can be used to control behavior
at training and inference time.
Inputs to your custom layer can be either Axon graph inputs or trainable
parameters. You can pass Axon graph inputs as-is to a custom layer. To
declare trainable parameters, use `Axon.param/3`:
weight = Axon.param("weight", param_shape)
To create a custom layer, you "wrap" your implementation and inputs into
a layer using `Axon.layer`. You'll notice the API mirrors Elixir's `apply`:
def atan2_layer(%Axon{} = input) do
weight = Axon.param("weight", param_shape)
Axon.layer(&my_layer/3, [input, weight])
end
## Model Execution
Under the hood, Axon models are represented as Elixir structs. You
can initialize and apply models by building or compiling them with
`Axon.build/2` or `Axon.compile/4` and then calling the produced
initialization and predict functions:
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
predict_fn.(params, inputs)
You may either set the default JIT compiler or backend globally, or
pass a specific compiler to `Axon.build/2`:
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA, mode: :train)
params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
predict_fn.(params, inputs)
`predict_fn` by default runs in inference mode, which performs certain
optimizations and removes layers such as dropout layers. If constructing
a training step using `Axon.predict/4` or `Axon.build/2`, be sure to specify
`mode: :train`.
## Model Training
Combining the Axon model creation API with the optimization and training
APIs, you can create and train neural networks with ease:
model =
Axon.input("input_0", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)
|> Axon.layer_norm()
|> Axon.dropout()
|> Axon.dense(10, activation: :softmax)
IO.inspect model
model_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(learning_rate: 0.005))
|> Axon.Loop.run(train_data, epochs: 10, compiler: EXLA)
See `Polaris.Updates` and `Axon.Loop` for a more in-depth treatment of
model optimization and model training.
## Using with `Nx.Serving`
When deploying an `Axon` model to production, you usually want to batch
multiple prediction requests and run the inference for all of them at
once. Conveniently, `Nx` already has an abstraction for this task in the
form of `Nx.Serving`. Here's how you could define a serving for an `Axon`
model:
def build_serving() do
# Configuration
batch_size = 4
defn_options = [compiler: EXLA]
Nx.Serving.new(
# This function runs on the serving startup
fn ->
# Build the Axon model and load params (usually from file)
model = build_model()
params = load_params()
# Build the prediction defn function
{_init_fun, predict_fun} = Axon.build(model)
inputs_template = %{"pixel_values" => Nx.template({batch_size, 224, 224, 3}, :f32)}
template_args = [Nx.to_template(params), inputs_template]
# Compile the prediction function upfront for the configured batch_size
predict_fun = Nx.Defn.compile(predict_fun, template_args, defn_options)
# The returned function is called for every accumulated batch
fn inputs ->
inputs = Nx.Batch.pad(inputs, batch_size - inputs.size)
predict_fun.(params, inputs)
end
end,
batch_size: batch_size
)
end
Then you would start the serving server as part of your application's
supervision tree:
children = [
...,
{Nx.Serving, serving: build_serving(), name: MyApp.Serving, batch_timeout: 100}
]
With that in place, you can now ask serving for predictions all across
your application (controllers, live views, async jobs, etc.). Having a
tensor input you would do:
inputs = %{"pixel_values" => ...}
batch = Nx.Batch.concatenate([inputs])
result = Nx.Serving.batched_run(MyApp.Serving, batch)
Usually you also want to do pre/post-processing of the model input/output.
You could make those preparations directly before/after `Nx.Serving.batched_run/2`,
however you can also make use of `Nx.Serving.client_preprocessing/2` and
`Nx.Serving.client_postprocessing/2` to encapsulate that logic as part of
the serving.
"""
alias __MODULE__, as: Axon
alias Axon.Parameter
require Logger
# Version tag for Axon's serialization format.
@file_version 1

@type t :: %__MODULE__{}

# `nodes` holds the graph as a map of node id => `%Axon.Node{}` and
# `output` is the id of the graph's output node (see `layer/3`).
defstruct nodes: nil, output: nil
@doc """
Custom Axon layer with given inputs.

Inputs may be other Axon layers or trainable parameters created
with `Axon.param`. At inference time, `op` will be applied with
inputs in specified order and an additional `opts` parameter which
specifies inference options. All options passed to layer are forwarded
to inference function except:

  * `:name` - layer name.

  * `:op_name` - layer operation for inspection and building parameter map.
    Defaults to `:custom`.

  * `:mode` - if the layer should run only on `:inference` or `:train`.
    Defaults to `:both`.

  * `:global_options` - a list of global option names that this layer
    supports. Global options passed to `build/2` will be forwarded to
    the layer, as long as they are declared.

Note this means your layer should not use these as input options,
as they will always be dropped during inference compilation.

Axon's compiler will additionally forward the following options to
every layer at inference time:

  * `:mode` - `:inference` or `:train`. To control layer behavior
    based on inference or train time.

`op` is either an atom (as used by Axon's built-in layers, e.g. `:dense`)
or a function of the form:

    fun = fn input, weight, bias, _opts ->
      input
      |> Nx.multiply(weight)
      |> Nx.add(bias)
    end
"""
@doc type: :special
def layer(op, inputs, opts \\ []) when (is_atom(op) or is_function(op)) and is_list(inputs) do
  # split_inputs/2 accumulates by prepending, so restore call order here.
  {inputs, params, args, updated_nodes} = split_inputs(op, inputs)
  inputs = Enum.reverse(inputs)
  params = Enum.reverse(params)
  args = Enum.reverse(args)

  # Strip the options consumed by Axon itself; whatever remains in `opts`
  # is stored on the node as its layer options.
  {mode, opts} = Keyword.pop(opts, :mode, :both)
  {name, opts} = Keyword.pop(opts, :name)
  {op_name, opts} = Keyword.pop(opts, :op_name, :custom)
  {global_options, opts} = Keyword.pop(opts, :global_options, [])

  name = name(op_name, name)
  id = System.unique_integer([:positive, :monotonic])

  # Keep this a direct, non-tail call: make_node/10 discards a fixed number
  # of stack frames when recording where the layer was created.
  axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, opts, global_options)

  %Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)}
end
# Builds the `%Axon.Node{}` that layer/3 stores in the graph under `id`.
#
# `inputs` are the parent node ids, `params` the `%Axon.Parameter{}`s and
# `args` the per-argument tags (`:layer` / `:parameter`) produced by
# split_inputs/2 and put back into call order by layer/3. `layer_opts` are
# the caller's options left over after layer/3 popped its own keys.
defp make_node(id, op, name, op_name, mode, inputs, params, args, layer_opts, global_options) do
# Record the call stack at layer-creation time. The two innermost frames
# (bound below as `_process_info` and `_axon_layer`) are discarded,
# presumably so the stored stacktrace starts near the user's code.
# NOTE(review): this depends on the exact call depth between layer/3 and
# this function; adding a wrapper or inlining this helper changes which
# frames are dropped.
{:current_stacktrace, [_process_info, _axon_layer | stacktrace]} =
Process.info(self(), :current_stacktrace)
%Axon.Node{
id: id,
mode: mode,
name: name,
parent: inputs,
parameters: params,
args: args,
op: op,
# Policy from Axon.MixedPrecision.create_policy/0 with no arguments.
policy: Axon.MixedPrecision.create_policy(),
# Nodes start with no hooks attached.
hooks: [],
opts: layer_opts,
global_options: global_options,
op_name: op_name,
stacktrace: stacktrace
}
end
# Container layers take a single (possibly nested) container of Axon graphs.
# Each graph is replaced by its output id and all graph nodes are merged
# into one node map.
defp split_inputs(:container, [container]) do
  {container_ids, merged_nodes} =
    deep_map_reduce(container, %{}, fn %Axon{output: id, nodes: graph_nodes}, acc ->
      {id, Map.merge(graph_nodes, acc)}
    end)

  {[container_ids], [], [:layer], merged_nodes}
end

# Splits layer inputs into parent ids, parameters, per-argument kind tags
# and the merged node map. Lists are built by prepending, so callers get
# them in reverse order (layer/3 reverses them back).
defp split_inputs(_op, inputs) do
  for input <- inputs, reduce: {[], [], [], %{}} do
    {layer_ids, layer_params, kinds, merged_nodes} ->
      case input do
        %Axon{output: layer_input, nodes: graph_nodes} ->
          {[layer_input | layer_ids], layer_params, [:layer | kinds],
           Map.merge(graph_nodes, merged_nodes)}

        %Parameter{} = param ->
          {layer_ids, [param | layer_params], [:parameter | kinds], merged_nodes}

        invalid ->
          raise ArgumentError, "invalid input given to layer: #{inspect(invalid)}"
      end
  end
end
@doc """
Trainable Axon parameter used to create custom layers.

Parameters are specified in usages of `Axon.layer` and will
be automatically initialized and used in subsequent applications
of Axon models.

You may specify the parameter shape as either a static shape or
as a function of the inputs to the given layer. If you specify the
parameter shape as a function, it will be given the shapes of the
layer's inputs, one argument per input. For example, `Axon.dense/3`
declares its kernel as:

    Axon.param("kernel", &Axon.Shape.dense_kernel(&1, units))

Passing `{:map, params}` with a non-empty list of inner parameters
creates a composite parameter. Options have no effect on composite
parameters; pass them to the inner parameters instead.

## Options

  * `:initializer` - parameter initializer. Defaults to `:glorot_uniform`.

  * `:type` - parameter type. Defaults to `{:f, 32}`.
"""
@doc type: :special
def param(name, shape, opts \\ [])

# Composite parameter. This clause must stay before the generic one below,
# because `{:map, ...}` is itself a tuple and would otherwise match there.
def param(name, {:map, [_ | _] = inner_params}, opts) do
  maybe_warn_on_param_opts(opts)

  %Axon.Parameter{
    name: name,
    type: :map,
    children: inner_params
  }
end

def param(name, shape, opts) when is_tuple(shape) or is_function(shape) do
  opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32})
  initializer = validate_initializer!(opts[:initializer])
  # Also covers an explicit `type: nil`, which Keyword.validate!/2 keeps.
  type = opts[:type] || {:f, 32}

  %Axon.Parameter{
    name: name,
    shape: shape,
    type: type,
    initializer: initializer
  }
end
# Logs a warning when `:initializer` or `:type` is passed for a composite
# (`{:map, ...}`) parameter, since those options are ignored there.
#
# `opts` is a keyword list, whose entries are `{key, value}` tuples, so the
# keys must be looked up with Keyword.has_key?/2. A bare `:key in opts`
# membership test never matches a tuple and would never warn.
defp maybe_warn_on_param_opts(opts) do
  if Keyword.has_key?(opts, :initializer) or Keyword.has_key?(opts, :type) do
    Logger.warning(
      "Passing options to a composite parameter has no effect. Pass them to inner parameters instead"
    )
  end
end
@doc """
Adds an input layer to the network.

Input layers specify a model's inputs. Input layers are
always the root layers of the neural network.

You must specify the input layer's name, which will be used
to uniquely identify it in the case of multiple inputs.

## Options

  * `:shape` - the expected input shape, use `nil` for dimensions
    of a dynamic size.

  * `:optional` - if `true`, the input may be omitted when using
    the model. This needs to be handled in one of the subsequent
    layers. See `optional/2` for more details.
"""
@doc type: :special
def input(name, opts \\ [])

def input(name, opts) when is_binary(name) and is_list(opts) do
  validated = Keyword.validate!(opts, [:shape, optional: false])
  requested_shape = validated[:shape]

  # A missing shape is stored as-is; a given one goes through Axon.Shape.input/1.
  shape =
    if requested_shape,
      do: Axon.Shape.input(requested_shape),
      else: requested_shape

  layer(:input, [],
    name: name,
    shape: shape,
    op_name: :input,
    optional: validated[:optional]
  )
end
@doc """
Wraps an Axon model in an optional node.

By default, when an optional input is missing, all subsequent layers
are nullified. For example, consider this model:

    values = Axon.input("values")
    mask = Axon.input("mask", optional: true)

    model =
      values
      |> Axon.dense(10)
      |> Axon.multiply(mask)
      |> Axon.dense(1)
      |> Axon.sigmoid()

In case the mask is not provided, the input node will resolve to
`%Axon.None{}` and so will all the layers that depend on it. By
using `optional/2` a layer may opt-in to receive `%Axon.None{}`.
To fix our example, we could define a custom layer to apply the
mask only when present:

    def apply_optional_mask(%Axon{} = x, %Axon{} = mask) do
      Axon.layer(
        fn x, mask, _opts ->
          case mask do
            %Axon.None{} -> x
            mask -> Nx.multiply(x, mask)
          end
        end,
        [x, Axon.optional(mask)]
      )
    end

    # ...

    model =
      values
      |> Axon.dense(10)
      |> apply_optional_mask(mask)
      |> Axon.dense(1)
      |> Axon.sigmoid()

## Options

  * `:name` - layer name.
"""
@doc type: :special
def optional(%Axon{} = x, opts \\ []) do
  name =
    opts
    |> Keyword.validate!([:name])
    |> Keyword.get(:name)

  layer(:optional, [x], op_name: :optional, name: name)
end
@doc """
Adds a constant layer to the network.

Constant layers encapsulate Nx tensors in an Axon layer for ease
of use with other Axon layers. They can be used interchangeably
with other Axon layers:

    inp = Axon.input("input", shape: {nil, 32})
    my_constant = Axon.constant(Nx.iota({1, 32}))
    model = Axon.add(inp, my_constant)

Plain numbers are also accepted and are converted with `Nx.tensor/1`.

Constant layers will be cast according to the mixed precision policy.
If it's important for your constant to retain its type during
the computation, you will need to set the mixed precision policy to
ignore constant layers.

## Options

  * `:name` - layer name.
"""
@doc type: :special
def constant(tensor, opts \\ [])

def constant(%Nx.Tensor{} = tensor, opts) do
  opts = Keyword.validate!(opts, [:name])
  layer(:constant, [], name: opts[:name], value: tensor, op_name: :constant)
end

# Numbers are wrapped into a tensor and handled by the clause above.
def constant(number, opts) when is_number(number) do
  constant(Nx.tensor(number), opts)
end

# Anything else is rejected. Numbers never reach this clause, so the message
# no longer suggests wrapping them with Nx.tensor/2.
def constant(value, _opts) do
  raise ArgumentError,
        "value passed to constant must be an Nx tensor or a number" <>
          " but got #{inspect(value)}"
end
@doc """
Adds a container layer to the network.

In certain cases you may want your model to have multiple
outputs. In order to make this work, you must "join" the
outputs into an Axon layer using this function for use in
initialization and inference later on.

The given container can be any valid Axon Nx container.

## Options

  * `:name` - layer name.

## Examples

    iex> inp1 = Axon.input("input_0", shape: {nil, 1})
    iex> inp2 = Axon.input("input_1", shape: {nil, 2})
    iex> model = Axon.container(%{a: inp1, b: inp2})
    iex> %{a: a, b: b} = Axon.predict(model, %{}, %{
    ...>    "input_0" => Nx.tensor([[1.0]]),
    ...>    "input_1" => Nx.tensor([[1.0, 2.0]])
    ...> })
    iex> a
    #Nx.Tensor<
      f32[1][1]
      [
        [1.0]
      ]
    >
    iex> b
    #Nx.Tensor<
      f32[1][2]
      [
        [1.0, 2.0]
      ]
    >
"""
@doc type: :special
def container(container, opts \\ []) do
  name =
    opts
    |> Keyword.validate!([:name])
    |> Keyword.get(:name)

  layer(:container, [container], op_name: :container, name: name)
end
# TODO: This should not be duplicated
#
# Applies `fun` to a tensor or number directly; otherwise walks the
# container and applies `fun` to every Axon graph (or `%{axon: :axon}` map)
# found inside it, recursing into nested containers.
defp deep_new(%Nx.Tensor{} = x, fun), do: fun.(x)
defp deep_new(x, fun) when is_number(x), do: fun.(x)

defp deep_new(map, fun) do
  {rebuilt, :ok} = Nx.Container.traverse(map, :ok, &recur_traverse(&1, &2, fun))
  rebuilt
end

# Visits one item of a container during deep_new/2. The accumulator is
# always `:ok`; only the rebuilt item matters.
defp recur_traverse(%Axon{} = graph, :ok, fun), do: {fun.(graph), :ok}
defp recur_traverse(%{axon: :axon} = marker, :ok, fun), do: {fun.(marker), :ok}
defp recur_traverse(container, :ok, fun), do: {deep_new(container, fun), :ok}
# Merges two containers with matching structure item by item, calling
# `fun.(left, right)` on each pair of tensors or Axon graphs and recursing
# into nested containers. Raises ArgumentError when the structures do not
# line up.
defp deep_merge(left, right, fun) do
  case Nx.Container.traverse(left, leaves(right), &recur_merge(&1, &2, fun)) do
    {merged, []} ->
      merged

    # `right` had more items than `left`.
    {_merged, _leftover} ->
      raise ArgumentError,
            "unable to merge arguments with incompatible" <>
              " structure"
  end
end

# Collects the top-level items of `container` in traversal order; nested
# containers are handled by the recursion in recur_merge/3.
defp leaves(container) do
  container
  |> Nx.Container.reduce([], fn x, acc -> [x | acc] end)
  |> Enum.reverse()
end

# Pairs one item of `left` with the next pending item of `right`.
defp recur_merge(left, [right | right_leaves], fun) do
  case {left, right} do
    {%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
      {fun.(left, right), right_leaves}

    {%Axon{} = left, %Axon{} = right} ->
      {fun.(left, right), right_leaves}

    {left, right} ->
      {deep_merge(left, right, fun), right_leaves}
  end
end

# `left` has more items than `right`: report the same ArgumentError as
# deep_merge/3 instead of crashing with a FunctionClauseError.
defp recur_merge(_left, [], _fun) do
  raise ArgumentError,
        "unable to merge arguments with incompatible" <>
          " structure"
end
@doc """
Wraps an Axon model into a namespace.

A namespace is a part of an Axon model which is meant to
be a self-contained collection of Axon layers. Namespaces
are guaranteed to always generate with the same internal
layer names and can be re-used universally across models.

Namespaces are most useful for containing large collections
of layers and offering a straightforward means for accessing
the parameters of individual model components. A common application
of namespaces is to use them with a pre-trained model for
fine-tuning:

    {base, resnet_params} = resnet()
    base = base |> Axon.namespace("resnet")

    model = base |> Axon.dense(1)
    {init_fn, predict_fn} = Axon.build(model)

    init_fn.(Nx.template({1, 3, 224, 224}, {:f, 32}), %{"resnet" => resnet_params})

Notice you can use `init_fn` in conjunction with namespaces
to specify which portion of a model you'd like to initialize
from a fixed starting point. The key in the initial parameter map
must match the namespace name.

Namespaces have fixed names, which means it's easy to run into namespace
collisions. Re-using namespaces, re-using inner parts of a namespace,
and attempting to share layers between namespaces are still sharp
edges in namespace usage.
"""
@doc type: :special
def namespace(%Axon{} = axon, name) when is_binary(name) do
  layer(:namespace, [axon], name: name)
end
@doc """
Returns a function which represents a self-contained re-usable block
of operations in a neural network. All parameters in the block are
shared between every usage of the block.

This returns an arity-1 function which accepts a list of inputs which
are forwarded to `fun`. This is most often used in situations where
you wish to re-use parameters in a block:

    reused_dense = Axon.block(&Axon.dense(&1, 32))

Every time `reused_dense` is invoked, it re-uses the same parameters:

    input = Axon.input("features")
    # unique parameters
    x1 = Axon.dense(input, 32)
    # unique parameters
    x2 = reused_dense.(x1)
    # parameters shared
    x3 = reused_dense.(x2)

Subgraphs in blocks can be arbitrarily complex:

    reused_block = Axon.block(fn x ->
      x
      |> Axon.dense(32)
      |> Axon.dense(64)
      |> Axon.dense(32)
    end)

Blocks can also have multiple inputs; you can invoke a block with multiple
inputs by passing a list of arguments:

    reused_block = Axon.block(fn x, y, z ->
      x = Axon.dense(x, 32)
      y = Axon.dense(y, 32)
      z = Axon.dense(z, 32)

      Axon.add([x, y, z])
    end)

    # invoke with a list
    reused_block.([x, y, z])

Blocks prefix subgraph parameters with their name and a dot. As with other
Axon layers, if a name is not explicitly provided, one will be dynamically
generated.
"""
@doc type: :special
def block(fun, opts \\ []) when is_function(fun) do
  name = Keyword.validate!(opts, [:name])[:name]

  # Generated once here and captured by the returned closure, so every
  # invocation produces a layer tagged with the same block id.
  block_id = System.unique_integer([:positive, :monotonic])

  fn inputs ->
    layer(:block, List.wrap(inputs),
      name: name,
      op_name: :block,
      block_fun: fun,
      block_id: block_id
    )
  end
end
@doc """
Adds a dense layer to the network.

The dense layer implements:

    output = activation(dot(input, kernel) + bias)

where `activation` is given by the `:activation` option and both
`kernel` and `bias` are layer parameters. `units` specifies the
number of output units.

Compiles to `Axon.Layers.dense/4`.

## Options

  * `:name` - layer name.

  * `:kernel_initializer` - initializer for `kernel` weights.
    Defaults to `:glorot_uniform`.

  * `:bias_initializer` - initializer for `bias` weights. Defaults
    to `:zeros`.

  * `:activation` - element-wise activation function.

  * `:use_bias` - whether the layer should add bias to the output.
    Defaults to `true`.
"""
@doc type: :linear
def dense(%Axon{} = x, units, opts \\ [])
    when is_integer(units) and units > 0 do
  opts =
    Keyword.validate!(opts, [
      :name,
      :activation,
      kernel_initializer: :glorot_uniform,
      bias_initializer: :zeros,
      use_bias: true
    ])

  kernel =
    param("kernel", &Axon.Shape.dense_kernel(&1, units),
      initializer: opts[:kernel_initializer]
    )

  # The op is `:dense` either way; only the parameter list depends on
  # whether a bias is requested.
  weights =
    if opts[:use_bias] do
      bias =
        param("bias", &Axon.Shape.dense_bias(&1, units),
          initializer: opts[:bias_initializer]
        )

      [kernel, bias]
    else
      [kernel]
    end

  node = layer(:dense, [x | weights], name: opts[:name], op_name: :dense)

  if act = opts[:activation], do: activation(node, act), else: node
end
@doc """
Adds a bilinear layer to the network.

The bilinear layer implements:

    output = activation(dot(dot(input1, kernel), input2) + bias)

where `activation` is given by the `:activation` option and both
`kernel` and `bias` are layer parameters. `units` specifies the
number of output units.

All dimensions but the last of `input1` and `input2` must match. The
batch sizes of both inputs must also match or at least one must be `nil`.
Inferred output batch size coerces to the strictest input batch size.

Compiles to `Axon.Layers.bilinear/5`.

## Options

  * `:name` - layer name.

  * `:kernel_initializer` - initializer for `kernel` weights.
    Defaults to `:glorot_uniform`.

  * `:bias_initializer` - initializer for `bias` weights. Defaults
    to `:zeros`.

  * `:activation` - element-wise activation function.

  * `:use_bias` - whether the layer should add bias to the output.
    Defaults to `true`.
"""
@doc type: :linear
def bilinear(
      %Axon{} = input1,
      %Axon{} = input2,
      units,
      opts \\ []
    )
    when is_integer(units) and units > 0 do
  opts =
    Keyword.validate!(opts, [
      :name,
      :activation,
      kernel_initializer: :glorot_uniform,
      bias_initializer: :zeros,
      use_bias: true
    ])

  kernel =
    param("kernel", &Axon.Shape.bilinear_kernel(&1, &2, units),
      initializer: opts[:kernel_initializer]
    )

  # The op is `:bilinear` either way; only the parameter list depends on
  # whether a bias is requested.
  weights =
    if opts[:use_bias] do
      bias =
        param("bias", &Axon.Shape.bilinear_bias(&1, &2, units),
          initializer: opts[:bias_initializer]
        )

      [kernel, bias]
    else
      [kernel]
    end

  node =
    layer(:bilinear, [input1, input2 | weights],
      name: opts[:name],
      op_name: :bilinear
    )

  if act = opts[:activation], do: activation(node, act), else: node
end
@doc """
Adds a convolution layer to the network.
The convolution layer implements a general dimensional
convolutional layer - which convolves a kernel over the input
to produce an output.
Compiles to `Axon.Layers.conv/4`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
to `1`.
* `:strides` - stride during convolution. Defaults to `1`.
* `:padding` - padding to the spatial dimensions of the input.
Defaults to `:valid`.
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
* `:feature_group_size` - feature group size for convolution. Defaults
to `1`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:last`.
"""
@doc type: :convolution
def conv(%Axon{} = x, units, opts \\ [])
when is_integer(units) and units > 0 do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true,
kernel_size: 1,
strides: 1,
padding: :valid,
input_dilation: 1,
kernel_dilation: 1,
channels: :last,
feature_group_size: 1
])
kernel_size = opts[:kernel_size]
strides = opts[:strides]
padding = opts[:padding]
input_dilation = opts[:input_dilation]
kernel_dilation = opts[:kernel_dilation]
channels = opts[:channels]
feature_group_size = opts[:feature_group_size]
kernel_shape = &Axon.Shape.conv_kernel(&1, units, kernel_size, channels, feature_group_size)
bias_shape = &Axon.Shape.conv_bias(&1, units, kernel_size, channels, feature_group_size)
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
{inputs, op} =
if opts[:use_bias] do
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], :conv}
else
{[x, kernel], :conv}
end
node =
layer(op, inputs,
name: opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
kernel_dilation: kernel_dilation,
feature_group_size: feature_group_size,
channels: channels,
op_name: :conv
)
if activation = opts[:activation] do
activation(node, activation)