From 2068e13fd7a7f7a88b41148433cd884ad78c7f9e Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Mon, 28 Jun 2021 20:35:37 -0400 Subject: [PATCH] First commit --- .formatter.exs | 4 + .gitignore | 26 + README.md | 21 + lib/axon_onnx.ex | 6 + lib/axon_onnx/serialize.ex | 1060 ++++++++++++++++++++++++++++++++++++ mix.exs | 29 + mix.lock | 6 + priv/onnx.proto | 743 +++++++++++++++++++++++++ test/axon_onnx_test.exs | 8 + test/test_helper.exs | 1 + 10 files changed, 1904 insertions(+) create mode 100644 .formatter.exs create mode 100644 .gitignore create mode 100644 README.md create mode 100644 lib/axon_onnx.ex create mode 100644 lib/axon_onnx/serialize.ex create mode 100644 mix.exs create mode 100644 mix.lock create mode 100644 priv/onnx.proto create mode 100644 test/axon_onnx_test.exs create mode 100644 test/test_helper.exs diff --git a/.formatter.exs b/.formatter.exs new file mode 100644 index 0000000..d2cda26 --- /dev/null +++ b/.formatter.exs @@ -0,0 +1,4 @@ +# Used by "mix format" +[ + inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] +] diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6adf2b3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# The directory Mix will write compiled artifacts to. +/_build/ + +# If you run "mix test --cover", coverage assets end up here. +/cover/ + +# The directory Mix downloads your dependencies sources to. +/deps/ + +# Where third-party dependencies like ExDoc output generated docs. +/doc/ + +# Ignore .fetch files in case you like to edit your project deps locally. +/.fetch + +# If the VM crashes, it generates a dump, let's ignore it too. +erl_crash.dump + +# Also ignore archive artifacts (built via "mix archive.build"). +*.ez + +# Ignore package tarball (built via "mix hex.build"). +axon_onnx-*.tar + +# Temporary files, for example, from tests. +/tmp/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..fc37f9d --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# AxonOnnx + +**TODO: Add description** + +## Installation + +If [available in Hex](https://hex.pm/docs/publish), the package can be installed +by adding `axon_onnx` to your list of dependencies in `mix.exs`: + +```elixir +def deps do + [ + {:axon_onnx, "~> 0.1.0"} + ] +end +``` + +Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc) +and published on [HexDocs](https://hexdocs.pm). Once published, the docs can +be found at [https://hexdocs.pm/axon_onnx](https://hexdocs.pm/axon_onnx). + diff --git a/lib/axon_onnx.ex b/lib/axon_onnx.ex new file mode 100644 index 0000000..57f3d58 --- /dev/null +++ b/lib/axon_onnx.ex @@ -0,0 +1,6 @@ +defmodule AxonOnnx do + use Protox, + files: [ + :filename.join([:code.priv_dir(:axon_onnx), "onnx.proto"]) + ] +end diff --git a/lib/axon_onnx/serialize.ex b/lib/axon_onnx/serialize.ex new file mode 100644 index 0000000..d60d5a8 --- /dev/null +++ b/lib/axon_onnx/serialize.ex @@ -0,0 +1,1060 @@ +defmodule AxonOnnx.Serialize do + alias Onnx.ModelProto, as: Model + alias Onnx.GraphProto, as: Graph + alias Onnx.ValueInfoProto, as: ValueInfo + alias Onnx.AttributeProto, as: Attribute + alias Onnx.NodeProto, as: Node + alias Onnx.TypeProto, as: Type + alias Onnx.TensorProto, as: T + alias Onnx.TypeProto.Tensor, as: Placeholder + alias Onnx.TensorShapeProto, as: Shape + alias Onnx.TensorShapeProto.Dimension, as: Dimension + + # TODO(seanmor5): Currently we do a lot of potentially expensive operations + # eagerly (especially when manipulating parameters), we can potentially make + # them part of the model or alternatively return an initialization function + # which can be JIT-compiled. + + # TODO(seanmor5): The current approach builds a lot of intermediate graphs, + # instead we should only keep graphs which are specified as outputs and override + # all other graphs so they are GC'ed + + # TODO(seanmor5): Some operations occur strictly on parameters (e.g. reshape, unsqueeze, + # etc.), so we need to change all of these cases to handle instances where the only + # input is a parameter which is an Nx expression rather than a model + + # TODO(seanmor5): Because some operations act on parameter inputs which don't have a + # parameterized equivalenet operation in Axon (e.g. add, multiply, etc.), we need + # a way to implement them that still builds an Axon model but preserves the parameters + + def __import__(file, opts \\ []) do + file + |> File.read!() + |> Model.decode!() + |> to_axon(opts) + end + + defp to_axon(%Model{graph: %Graph{node: nodes} = graph}, opts) do + dimensions = opts[:dimensions] || %{} + + params = get_params(graph) + inputs = get_inputs(graph, params, dimensions) + outputs = get_outputs(graph) + {nodes, params} = get_nodes(nodes, inputs, params, %{}) + {hd(Enum.map(outputs, fn name -> nodes[name] end)), params} + end + + defp get_inputs(%Graph{input: inputs}, params, dimensions) do + Enum.reduce(inputs, %{}, fn %ValueInfo{name: name, type: %Type{value: value}}, acc -> + if Map.has_key?(params, name) do + acc + else + case value do + {:tensor_type, %Placeholder{} = tensor} -> + input_shape = shape!(tensor, dimensions) + + input_shape = + if tuple_size(input_shape) == 1, + do: Tuple.insert_at(input_shape, 0, nil), + else: input_shape + + Map.put(acc, name, Axon.input(input_shape)) + + _ -> + raise ArgumentError, "unsupported input type" + end + end + end) + end + + defp get_params(%Graph{initializer: initializer}) do + Enum.reduce(initializer, %{}, fn %T{name: name} = tensor, params -> + Map.put(params, name, tensor!(tensor)) + end) + end + + defp get_outputs(%Graph{output: outputs}) do + Enum.map(outputs, fn %ValueInfo{name: name} -> name end) + end + + defp get_nodes(pruned_nodes, inp, params, used_params) do + Enum.reduce(pruned_nodes, {inp, used_params}, fn %Node{op_type: op_type} = op_node, + {axon, used_params} -> + case op_type do + "Abs" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.abs/1) + + "Acos" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.acos/1) + + "Acosh" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.acosh/1) + + "Add" -> + to_axon_binary_op(op_node, axon, params, used_params, :add) + + "ArgMax" -> + to_axon_reduction(op_node, axon, params, used_params, &Nx.argmax/2) + + "ArgMin" -> + to_axon_reduction(op_node, axon, params, used_params, &Nx.argmin/2) + + "Asin" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.asin/1) + + "Asinh" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.asinh/1) + + "Atan" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.atan/1) + + "Atanh" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.atanh/1) + + "AveragePool" -> + # TODO(seanmor5): unify window pooling + raise "unsupported op AveragePool" + + "BatchNormalization" -> + to_axon_batch_norm(op_node, axon, params, used_params) + + "BitShift" -> + raise "unsupported op BitShift" + + "Cast" -> + raise "unsupported op Cast" + + "Ceil" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.ceil/1) + + "Celu" -> + # TODO(seanmor5): alpha attr + to_axon_activation(op_node, axon, params, used_params, :celu) + + "Clip" -> + to_axon_clip(op_node, axon, params, used_params) + + "Compress" -> + raise "unsupported op Compress" + + "Concat" -> + raise "unsupported op Concat" + + "ConcatFromSequence" -> + raise "unsupported op ConcatFromSequence" + + "Constant" -> + to_axon_constant(op_node, axon, params, used_params) + + "ConstantOfShape" -> + # TODO(seanmor5): unify with to_axon_constant + raise "unsupported op ConstantOfShape" + + "Conv" -> + to_axon_conv(op_node, axon, params, used_params) + + "ConvInteger" -> + # TODO(seanmor5): unify conv + raise "unsupported op ConvInteger" + + "ConvTranspose" -> + # TODO(seanmor5): conv transpose + raise "unsupported op ConvTranspose" + + "Cos" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.cos/1) + + "Cosh" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.cosh/1) + + "CumSum" -> + raise "unsupported op CumSum" + + "DepthToSpace" -> + raise "unsupported op DepthToSpace" + + "DequantizeLinear" -> + raise "unsupported op DequantizeLinear" + + "Det" -> + raise "unsupported op Det" + + "Div" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> Nx.divide(x, y) end) + + "Dropout" -> + raise "unsupported op Dropout" + + "Einsum" -> + raise "unsupported op Einsum" + + "Elu" -> + # TODO(seanmor5): alpha attr + to_axon_activation(op_node, axon, params, used_params, :elu) + + "Equal" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> Nx.equal(x, y) end) + + "Erf" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.erf/1) + + "Exp" -> + to_axon_activation(op_node, axon, params, used_params, :exp) + + "Expand" -> + raise "unsupported op Expand" + + "EyeLike" -> + raise "unsuporrted op EyeLike" + + "Flatten" -> + to_axon_flatten(op_node, axon, params, used_params) + + "Floor" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.floor/1) + + "GRU" -> + raise "unsupported op GRU" + + "Gather" -> + raise "unsupported op Gather" + + "GatherElements" -> + raise "unsupported op GatherElements" + + "GatherND" -> + raise "unsupported op GatherND" + + "Gemm" -> + to_axon_dense(op_node, axon, params, used_params) + + "GlobalAveragePool" -> + to_axon_global_pool(op_node, axon, params, used_params) + + "GlobalLpPool" -> + to_axon_global_pool(op_node, axon, params, used_params) + + "GlobalMaxPool" -> + to_axon_global_pool(op_node, axon, params, used_params) + + "Greater" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> Nx.greater(x, y) end) + + "GreaterOrEqual" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> + Nx.greater_equal(x, y) + end) + + "HardSigmoid" -> + # TODO(seanmor5): alpha, beta attrs + to_axon_activation(op_node, axon, params, used_params, :hard_sigmoid) + + "Hardmax" -> + raise "unsupported op Hardmax" + + "Identity" -> + to_axon_nx(op_node, axon, params, used_params, & &1) + + "Less" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> Nx.less(x, y) end) + + "LessOrEqual" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> + Nx.less_equal(x, y) + end) + + "Log" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.log/1) + + "MatMul" -> + to_axon_dense(op_node, axon, params, used_params) + + "Mod" -> + # TODO(seanmor5): Support fmod option + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> + Nx.remainder(x, y) + end) + + "Mul" -> + to_axon_binary_op(op_node, axon, params, used_params, :multiply) + + "Neg" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.negate/1) + + "Not" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.logical_not/1) + + "Or" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> + Nx.logical_or(x, y) + end) + + "Pow" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> Nx.power(x, y) end) + + "ReduceMax" -> + to_axon_reduction(op_node, axon, params, used_params, &Nx.reduce_max/2) + + "ReduceMin" -> + to_axon_reduction(op_node, axon, params, used_params, &Nx.reduce_min/2) + + "ReduceProd" -> + to_axon_reduction(op_node, axon, params, used_params, &Nx.product/2) + + "Relu" -> + to_axon_activation(op_node, axon, params, used_params, :relu) + + "Reshape" -> + to_axon_reshape(op_node, axon, params, used_params) + + "Round" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.round/1) + + "Selu" -> + # TODO(seanmor5): alpha, gamma attrs + to_axon_activation(op_node, axon, params, used_params, :selu) + + "Shape" -> + to_axon_nx(op_node, axon, params, used_params, fn x -> + x + |> Nx.shape() + |> Tuple.to_list() + |> Nx.tensor(backend: Nx.Defn.Expr) + end) + + "Sigmoid" -> + to_axon_activation(op_node, axon, params, used_params, :sigmoid) + + "Sign" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.sign/1) + + "Sin" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.sin/1) + + "Sinh" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.sinh/1) + + "Size" -> + to_axon_nx(op_node, axon, params, used_params, fn x -> + x + |> Nx.size() + |> Nx.tensor(backend: Nx.Defn.Expr) + end) + + "Softmax" -> + # TODO(seanmor5): axis attr + to_axon_activation(op_node, axon, params, used_params, :softmax) + + "Softplus" -> + to_axon_activation(op_node, axon, params, used_params, :softplus) + + "Softsign" -> + to_axon_activation(op_node, axon, params, used_params, :softsign) + + "Sqrt" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.sqrt/1) + + "Sub" -> + to_axon_binary_op(op_node, axon, params, used_params, :subtract) + + "Tan" -> + to_axon_nx(op_node, axon, params, used_params, &Nx.tan/1) + + "Tanh" -> + to_axon_activation(op_node, axon, params, used_params, :tanh) + + "Transpose" -> + to_axon_transpose(op_node, axon, params, used_params) + + "Unsqueeze" -> + to_axon_unsqueeze(op_node, axon, params, used_params) + + "Xor" -> + to_axon_binary_op(op_node, axon, params, used_params, fn {x, y} -> + Nx.logical_xor(x, y) + end) + + "MaxPool" -> + to_axon_max_pool(op_node, axon, params, used_params) + + "Pad" -> + to_axon_pad(op_node, axon, params, used_params) + + op -> + raise "unsupported #{op} op in graph" + end + end) + end + + # Builds a generic Nx layer by applying the given operation + # to the input. Most of these functions are generic element-wise + # operations such as Abs, Acos, etc. + # + # TODO(seanmor5): Replace with Axon.layer when we have better shape + # inference + defp to_axon_nx(%Node{input: [input], output: [output_name]}, axon, params, used_params, fun) do + axon_input = input_or_param!(input, params, axon, used_params) + updated_axon = Map.put(axon, output_name, Axon.nx(axon_input, fun, name: output_name)) + {updated_axon, used_params} + end + + # Builds a generic Nx layer by applying the given reduction operation + # to the input. + # + # TODO(seanmor5): Replace with Axon.layer when we have better shape + # inference + defp to_axon_reduction( + %Node{input: [input], attribute: attrs, output: [output_name]}, + axon, + params, + used_params, + reduce_fun + ) do + reduce_options = options!(attrs) + + axes = reduce_options["axes"] + keepdims = reduce_options["keepdims"] + keep_axes = if keepdims == 1, do: true, else: false + + axon_input = input_or_param!(input, params, axon, used_params) + + updated_axon = + Map.put( + axon, + output_name, + Axon.nx(axon_input, reduce_fun, + name: output_name, + opts: [axes: axes, keep_axes: keep_axes] + ) + ) + + {updated_axon, used_params} + end + + # Builds an Axon dense layer from an ONNX MatMul or GEMM Node. MatMul + # nodes do not account for bias (they're treated as a separate operation + # in the graph). GEMM Nodes are a bit more in-depth. + # + # TODO(seanmor5): Handle alpha, beta attrs + defp to_axon_dense( + %Node{op_type: op_type, input: inputs, output: [output_name], attribute: attrs}, + axon, + params, + used_params + ) do + [input, weight | maybe_bias] = inputs + + input = input_or_param!(input, params, axon, used_params) + weight = input_or_param!(weight, params, axon, used_params) + + case op_type do + "MatMul" -> + {_, units} = Nx.shape(weight) + + updated_axon = + Map.put( + axon, + output_name, + Axon.dense(input, units, use_bias: false, name: output_name) + ) + + updated_params = Map.put(used_params, output_name <> "_kernel", weight) + {updated_axon, updated_params} + + "Gemm" -> + dense_options = options!(attrs) + + # TODO(seanmor5): Handle alpha, beta + _alpha = dense_options["alpha"] + _beta = dense_options["beta"] + + trans_a = dense_options["transA"] + trans_b = dense_options["transB"] + + input = + if trans_a == 1 do + Nx.transpose(input) + else + input + end + + weight = + if trans_b == 1 do + Nx.transpose(weight) + else + weight + end + + {_, units} = Nx.shape(weight) + + updated_axon = + Map.put( + axon, + output_name, + Axon.dense(input, units, use_bias: maybe_bias != [], name: output_name) + ) + + updated_params = + if maybe_bias == [] do + Map.put(used_params, output_name <> "_kernel", weight) + else + [bias] = maybe_bias + bias = input_or_param!(bias, params, axon, used_params) + + used_params + |> Map.put(output_name <> "_kernel", weight) + |> Map.put(output_name <> "_bias", bias) + end + + {updated_axon, updated_params} + end + end + + # Builds an Axon layer from an element-wise binary operation. Binary + # op is either an atom representing one of Axon's legitimate Binary op + # layers, or a function to be used in a custom layer. + # + # TODO(seanmor5): Verify broadcasting semantics + defp to_axon_binary_op( + %Node{input: [x, y], output: [output_name]}, + axon, + params, + used_params, + binary_op + ) do + inp1 = input_or_param!(x, params, axon, used_params) + inp2 = input_or_param!(y, params, axon, used_params) + + updated_axon = + case binary_op do + op when is_atom(op) -> + Map.put(axon, output_name, apply(Axon, op, [inp1, inp2, [name: output_name]])) + + fun when is_function(fun, 2) -> + # TODO(seanmor5): Use Axon.layer when shape inference improves + Map.put(axon, output_name, Axon.nx({inp1, inp2}, fun, name: output_name)) + end + + {updated_axon, used_params} + end + + defp to_axon_max_pool( + %Node{op_type: "MaxPool", input: [inp], attribute: attrs, output: [output_name]}, + axon, + params, + used_params + ) do + max_pool_options = options!(attrs) + + kernel_shape = max_pool_options["kernel_shape"] + strides = max_pool_options["strides"] + pads = max_pool_options["pads"] + auto_pad = max_pool_options["auto_pad"] + + kernel_size = List.to_tuple(kernel_shape) + padding_config = padding!(auto_pad, pads) + + inp = input_or_param!(inp, params, axon, used_params) + + updated_axon = + Map.put( + axon, + output_name, + Axon.max_pool(inp, + kernel_size: kernel_size, + strides: strides, + padding: padding_config, + name: output_name + ) + ) + + {updated_axon, used_params} + end + + defp to_axon_conv(%Node{op_type: "Conv"} = conv_node, axon, params, used_params) do + %{attribute: attrs, input: input, output: [output_name]} = conv_node + + conv_options = options!(attrs) + + auto_pad = conv_options["auto_pad"] + dilations = conv_options["dilations"] + group = conv_options["group"] + kernel_shape = conv_options["kernel_shape"] + pads = conv_options["pads"] + strides = conv_options["strides"] + + padding_config = padding!(auto_pad, pads) + kernel_size = List.to_tuple(kernel_shape) + + [inp, kernel | maybe_bias] = input + + axon_inp = input_or_param!(inp, params, axon, used_params) + + # Parameters can either be embedded in the graph as constants or + # passed as parameters + {axon_kernel, units} = + cond do + Map.has_key?(params, kernel) -> + kernel = params[kernel] + {kernel, elem(Nx.shape(kernel), 0)} + + Map.has_key?(axon, kernel) -> + %{output_shape: shape} = kernel = axon[kernel] + {kernel, elem(shape, 1)} + + true -> + raise "unable to find kernel for conv" + end + + updated_params = Map.put(used_params, output_name <> "_kernel", axon_kernel) + + updated_params = + if maybe_bias == [] do + updated_params + else + [bias] = maybe_bias + axon_bias = params[bias] + Map.put(updated_params, output_name <> "_bias", axon_bias) + end + + updated_axon = + Map.put( + axon, + output_name, + Axon.conv(axon_inp, units, + kernel_size: kernel_size, + feature_group_size: group, + padding: padding_config, + strides: strides, + use_bias: maybe_bias != [], + name: output_name + ) + ) + + {updated_axon, updated_params} + end + + defp to_axon_pad(%Node{op_type: "Pad", input: inputs, output: [output_name], attribute: attrs}, axon, params, used_params) do + pad_options = options!(attrs) + + case pad_options["mode"] do + "constant" -> + :ok + + nil -> + :ok + + mode -> + raise "unsupported padding mode #{inspect(mode)}" + end + + [data, pads | maybe_constant] = inputs + + inp = input_or_param!(data, params, axon, used_params) + # TODO(seanmor5): Pads should probably be scrubbed from the graph + # and parameters + pads = input_or_param!(pads, params, axon, used_params) + + padding_config = + pads.ints + |> Enum.chunk_every(2) + |> Enum.zip() + + constant_value = + case maybe_constant do + [] -> + 0 + + [value] -> + tensor!(value) + end + + updated_axon = Map.put(axon, output_name, Axon.pad(inp, padding_config, constant_value, name: output_name)) + + {updated_axon, used_params} + end + + # TODO(seanmor5): Mean and variance + defp to_axon_batch_norm( + %Node{ + op_type: "BatchNormalization", + input: [inp, gamma, beta, _mean, _var], + output: [output_name] + }, + axon, + params, + used_params + ) do + inp = input_or_param!(inp, params, axon, used_params) + + gamma = input_or_param!(gamma, params, axon, used_params) + beta = input_or_param!(beta, params, axon, used_params) + + updated_axon = Map.put(axon, output_name, Axon.batch_norm(inp, name: output_name)) + + updated_params = + used_params + |> Map.put(output_name <> "_gamma", gamma) + |> Map.put(output_name <> "_beta", beta) + + {updated_axon, updated_params} + end + + # Builds an axon activation layer with the given activation function. + # `activation` must be a legitimate Axon activation. `activation` functions + # are all element-wise with 1 input. Optionally has activation options. + # + # TODO(seanmor5): Handle activation options + defp to_axon_activation( + %Node{input: [inp], output: [output_name]}, + axon, + params, + used_params, + activation + ) do + inp = input_or_param!(inp, params, axon, used_params) + {Map.put(axon, output_name, Axon.activation(inp, activation, name: output_name)), used_params} + end + + defp to_axon_global_pool( + %Node{op_type: op_type, attribute: attrs, input: [inp], output: [output_name]}, + axon, + params, + used_params + ) do + inp = input_or_param!(inp, params, axon, used_params) + + updated_axon = + case op_type do + "GlobalAveragePool" -> + Map.put(axon, output_name, Axon.global_average_pool(inp, name: output_name)) + + "GlobalMaxPool" -> + Map.put(axon, output_name, Axon.global_max_pool(inp, name: output_name)) + + "GlobalLpPool" -> + lp_pool_options = options!(attrs) + + Map.put( + axon, + output_name, + Axon.global_lp_pool(inp, norm: lp_pool_options["p"], name: output_name) + ) + end + + {updated_axon, used_params} + end + + # Builds an Axon layer which returns a constant with the given + # value. Constants are embedded in custom layers which just yield + # the value of the constant here. They are not treated as parameters + defp to_axon_constant( + %Node{op_type: "Constant", attribute: attrs, output: [output_name]}, + axon, + _, + used_params + ) do + constant_options = options!(attrs) + + fun = + fn _ -> + cond do + constant_options["sparse_value"] -> + raise ArgumentError, "sparse tensors are not supported" + + constant_options["value"] -> + tensor!(constant_options["value"]) + + constant_options["value_float"] -> + Nx.tensor(constant_options["value_float"], type: {:f, 32}, backend: Nx.Defn.Expr) + + constant_options["value_floats"] -> + Nx.tensor(constant_options["value_floats"], type: {:f, 32}, backend: Nx.Defn.Expr) + + constant_options["value_int"] -> + Nx.tensor(constant_options["value_int"], type: {:s, 64}, backend: Nx.Defn.Expr) + + constant_options["value_ints"] -> + Nx.tensor(constant_options["value_ints"], type: {:s, 64}, backend: Nx.Defn.Expr) + + constant_options["value_string"] or constant_options["value_strings"] -> + raise ArgumentError, "string tensors are not supported" + + true -> + raise ArgumentError, "invalid constant tensor type" + end + end + + # TODO(seanmor5): Use Axon.layer when shape inference is supported + # TODO(seanmor5): Should layer support blank inputs for constants? + updated_axon = Map.put(axon, output_name, Axon.nx(Axon.input({nil, 1}), fun, name: output_name)) + + {updated_axon, used_params} + end + + defp to_axon_reshape( + %Node{op_type: "Reshape", input: [inp], attribute: attrs, output: [output_name]}, + axon, + params, + used_params + ) do + reshape_options = options!(attrs) + + inp = input_or_param!(inp, params, axon, used_params) + + new_shape = + reshape_options["shape"] + |> List.to_tuple() + + {Map.put(axon, output_name, Axon.reshape(inp, new_shape, name: output_name)), used_params} + end + + defp to_axon_flatten( + %Node{op_type: "Flatten", input: [inp], output: [output_name]}, + axon, + params, + used_params + ) do + inp = input_or_param!(inp, params, axon, used_params) + + {Map.put(axon, output_name, Axon.flatten(inp, name: output_name)), used_params} + end + + # Builds an Axon transpose layer. Transpose is given by + # the perm option in Node attribute. + defp to_axon_transpose( + %Node{op_type: "Transpose", input: [input], attribute: attrs, output: [output_name]}, + axon, + params, + used_params + ) do + inp = input_or_param!(input, params, axon, used_params) + + transpose_options = options!(attrs) + + permutation = transpose_options["perm"] + + updated_axon = + Map.put(axon, output_name, Axon.transpose(inp, permutation: permutation, name: output_name)) + + {updated_axon, used_params} + end + + # Builds an unsqueeze layer using a custom Nx layer with the given + # input and axes. + # + # TODO(seanmor5): Use Axon.layer + defp to_axon_unsqueeze(%Node{op_type: "Unsqueeze", input: [input], attribute: attrs, output: [output_name]}, axon, params, used_params) do + unsqueeze_options = options!(attrs) + + inp = input_or_param!(input, params, axon, used_params) + + axes = unsqueeze_options["axes"] + + fun = fn input -> + Enum.reduce(axes, input, fn axis, x -> Nx.new_axis(x, axis) end) + end + + case inp do + %Nx.Tensor{} = tensor -> + updated_params = Map.put(used_params, output_name, fun.(tensor)) + {axon, updated_params} + + %Axon{} = model -> + updated_axon = Map.put(axon, output_name, Axon.nx(model, fun, name: output_name)) + {updated_axon, used_params} + end + end + + defp to_axon_clip(%Node{op_type: "Clip", input: [input], attribute: attrs, output: [output_name]}, axon, params, used_params) do + clip_options = options!(attrs) + + inp = input_or_param!(input, params, axon, used_params) + + min = clip_options["min"] + max = clip_options["max"] + + fun = fn input -> + Nx.clip(input, min, max) + end + + updated_axon = Map.put(axon, output_name, Axon.nx(inp, fun, name: output_name)) + + {updated_axon, used_params} + end + + # TODO(seanmor5): Handle segments + defp tensor!(%T{data_type: dtype, dims: dims} = tensor) do + shape = List.to_tuple(dims) + + case dtype do + 1 -> + to_nx_tensor(tensor.float_data, tensor.raw_data, {:f, 32}, shape) + + 2 -> + to_nx_tensor(tensor.int32_data, tensor.raw_data, {:u, 8}, shape) + + 3 -> + to_nx_tensor(tensor.int32_data, tensor.raw_data, {:s, 8}, shape) + + 4 -> + to_nx_tensor(tensor.int32_data, tensor.raw_data, {:u, 16}, shape) + + 5 -> + to_nx_tensor(tensor.int32_data, tensor.raw_data, {:s, 16}, shape) + + 6 -> + to_nx_tensor(tensor.int32_data, tensor.raw_data, {:s, 32}, shape) + + 7 -> + to_nx_tensor(tensor.int64_data, tensor.raw_data, {:s, 64}, shape) + + 8 -> + raise "unsupported Nx tensor type: string" + + 9 -> + to_nx_tensor(tensor.int32_data, tensor.raw_data, {:u, 8}, shape) + + 10 -> + to_nx_tensor(tensor.int32_data, tensor.raw_data, {:f, 16}, shape) + + 11 -> + to_nx_tensor(tensor.double_data, tensor.raw_data, {:f, 64}, shape) + + 12 -> + to_nx_tensor(tensor.uint64_data, tensor.raw_data, {:u, 32}, shape) + + 13 -> + to_nx_tensor(tensor.uint64_data, tensor.raw_data, {:u, 64}, shape) + + 14 -> + # TODO(seanmor5): When complex is supported, tensor.float_data + raise "unsupported Nx tensor type: C64" + + 15 -> + # TODO(seanmor5): When complex is supported, tensor.double_data + raise "unsupported Nx tensor type: C128" + + 16 -> + to_nx_tensor([], tensor.raw_data, {:bf, 16}, shape) + end + end + + defp to_nx_tensor([], <<>>, _, _) do + raise "unsupported empty Nx tensor" + end + + defp to_nx_tensor([], raw, type, shape) do + raw + |> Nx.from_binary(type, backend: Nx.Defn.Expr) + |> Nx.reshape(shape) + end + + defp to_nx_tensor(data, _, type, shape) do + data + |> Nx.tensor(type: type, backend: Nx.Defn.Expr) + |> Nx.reshape(shape) + end + + defp input_or_param!(name, params, axon, used_params) do + cond do + Map.has_key?(params, name) -> + params[name] + + Map.has_key?(axon, name) -> + axon[name] + + Map.has_key?(used_params, name) -> + used_params[name] + + true -> + raise "unable to find value with name #{inspect(name)} in" <> + " parameters or model" + end + end + + defp padding!(auto_pad, pads) do + case auto_pad do + val when val == "NOTSET" or val == nil -> + pads + |> Enum.chunk_every(2) + |> Enum.zip() + + val when val == "SAME_UPPER" or val == "SAME_LOWER" -> + :same + + "VALID" -> + :valid + end + end + + defp options!(attrs) when is_list(attrs) do + Enum.reduce(attrs, %{}, fn %Attribute{type: type, name: name} = attr, options -> + case type do + :FLOAT -> + Map.put(options, name, attr.f) + + :INT -> + Map.put(options, name, attr.i) + + :STRING -> + Map.put(options, name, attr.s) + + :TENSOR -> + Map.put(options, name, attr.t) + + :GRAPH -> + Map.put(options, name, attr.g) + + :SPARSE_TENSOR -> + Map.put(options, name, attr.sparse_tensor) + + :TYPE_PROTO -> + Map.put(options, name, attr.tp) + + :FLOATS -> + Map.put(options, name, attr.floats) + + :INTS -> + Map.put(options, name, attr.ints) + + :STRINGS -> + Map.put(options, name, attr.strings) + + :TENSORS -> + Map.put(options, name, attr.tensors) + + :GRAPHS -> + Map.put(options, name, attr.graphs) + + :SPARSE_TENSORS -> + Map.put(options, name, attr.sparse_tensors) + + :TYPE_PROTOS -> + Map.put(options, name, attr.type_protos) + end + end) + end + + defp shape!(%Placeholder{shape: %Shape{dim: dims}}, dim_params) do + dims + |> Enum.map(fn %Dimension{value: value} -> + case value do + {:dim_value, val} -> + val + + {:dim_param, key} -> + unless Map.has_key?(dim_params, key) do + raise "dimension #{inspect(key)} not found in provided dimensions," <> + " you must specify unknown dimension shapes at import time" + end + + dim_params[key] + + _ -> + raise ArgumentError, "unsupported dimension type" + end + end) + |> List.to_tuple() + end +end diff --git a/mix.exs b/mix.exs new file mode 100644 index 0000000..756e3bc --- /dev/null +++ b/mix.exs @@ -0,0 +1,29 @@ +defmodule AxonOnnx.MixProject do + use Mix.Project + + def project do + [ + app: :axon_onnx, + version: "0.1.0", + elixir: "~> 1.12", + start_permanent: Mix.env() == :prod, + deps: deps() + ] + end + + # Run "mix help compile.app" to learn about applications. + def application do + [ + extra_applications: [:logger] + ] + end + + # Run "mix help deps" to learn about dependencies. + defp deps do + [ + {:protox, "~> 1.4.0"}, + {:nx, "~> 0.1.0-dev", path: "../exla/nx", override: true}, + {:axon, path: "../axon"} + ] + end +end diff --git a/mix.lock b/mix.lock new file mode 100644 index 0000000..fa40219 --- /dev/null +++ b/mix.lock @@ -0,0 +1,6 @@ +%{ + "axon": {:git, "https://github.com/elixir-nx/axon.git", "d5d1e019352c1aa4435f393bda1f0cd6f89197e8", []}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "939167274adf6f9776ca583518e4b1a7966f65ef", [sparse: "nx"]}, + "protox": {:hex, :protox, "1.4.0", "2fc940fd8b10c07b935cff4a447b0b57b45d0cb747a9c66df3757e0850d53a82", [:mix], [], "hexpm", "871f3038939c947159fc00e3b79127ea57c896204448bb5517fd666a4cca5313"}, + "table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"}, +} diff --git a/priv/onnx.proto b/priv/onnx.proto new file mode 100644 index 0000000..fd36a08 --- /dev/null +++ b/priv/onnx.proto @@ -0,0 +1,743 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// SPDX-License-Identifier: Apache-2.0 + + +syntax = "proto3"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION_2019_9_19 = 0x0000000000000006; + + // IR VERSION 7 published on May 8, 2020 + // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add a list to promote inference graph's initializers to global and + // mutable variables. Global variables are visible in all graphs of the + // stored models. + // - Add message TrainingInfoProto to store initialization + // method and training algorithm. The execution of TrainingInfoProto + // can modify the values of mutable variables. + // - Implicitly add inference graph into each TrainingInfoProto's algorithm. + IR_VERSION_2020_5_8 = 0x0000000000000007; + + // IR VERSION 8 published on + // Introduce TypeProto.SparseTensor + // Introduce TypeProto.Optional + IR_VERSION = 0x0000000000000008; + +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + TYPE_PROTO = 13; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + TYPE_PROTOS = 14; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field heuristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accommodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + TypeProto tp = 14; // type proto + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors + repeated TypeProto type_protos = 15;// list of type protos +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Training information +// TrainingInfoProto stores information for training a model. +// In particular, this defines two functionalities: an initialization-step +// and a training-algorithm-step. Initialization resets the model +// back to its original state as if no training has been performed. +// Training algorithm improves the model based on input data. +// +// The semantics of the initialization-step is that the initializers +// in ModelProto.graph and in TrainingInfoProto.algorithm are first +// initialized as specified by the initializers in the graph, and then +// updated by the "initialization_binding" in every instance in +// ModelProto.training_info. +// +// The field "algorithm" defines a computation graph which represents a +// training algorithm's step. After the execution of a +// TrainingInfoProto.algorithm, the initializers specified by "update_binding" +// may be immediately updated. If the targeted training algorithm contains +// consecutive update steps (such as block coordinate descent methods), +// the user needs to create a TrainingInfoProto for each step. +message TrainingInfoProto { + // This field describes a graph to compute the initial tensors + // upon starting the training process. Initialization graph has no input + // and can have multiple outputs. Usually, trainable tensors in neural + // networks are randomly initialized. To achieve that, for each tensor, + // the user can put a random number operator such as RandomNormal or + // RandomUniform in TrainingInfoProto.initialization.node and assign its + // random output to the specific tensor using "initialization_binding". + // This graph can also set the initializers in "algorithm" in the same + // TrainingInfoProto; a use case is resetting the number of training + // iteration to zero. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Thus, no initializer would be changed by default. + GraphProto initialization = 1; + + // This field represents a training algorithm step. Given required inputs, + // it computes outputs to update initializers in its own or inference graph's + // initializer lists. In general, this field contains loss node, gradient node, + // optimizer node, increment of iteration count. + // + // An execution of the training algorithm step is performed by executing the + // graph obtained by combining the inference graph (namely "ModelProto.graph") + // and the "algorithm" graph. That is, the actual the actual + // input/initializer/output/node/value_info/sparse_initializer list of + // the training graph is the concatenation of + // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" + // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" + // in that order. This combined graph must satisfy the normal ONNX conditions. + // Now, let's provide a visualization of graph combination for clarity. + // Let the inference graph (i.e., "ModelProto.graph") be + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d + // and the "algorithm" graph be + // tensor_d -> Add -> tensor_e + // The combination process results + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e + // + // Notice that an input of a node in the "algorithm" graph may reference the + // output of a node in the inference graph (but not the other way round). Also, inference + // node cannot reference inputs of "algorithm". With these restrictions, inference graph + // can always be run independently without training information. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Evaluating the default training step never + // update any initializers. + GraphProto algorithm = 2; + + // This field specifies the bindings from the outputs of "initialization" to + // some initializers in "ModelProto.graph.initializer" and + // the "algorithm.initializer" in the same TrainingInfoProto. + // See "update_binding" below for details. + // + // By default, this field is empty and no initializer would be changed + // by the execution of "initialization". + repeated StringStringEntryProto initialization_binding = 3; + + // Gradient-based training is usually an iterative procedure. In one gradient + // descent iteration, we apply + // + // x = x - r * g + // + // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is + // gradient of "x" with respect to a chosen loss. To avoid adding assignments + // into the training graph, we split the update equation into + // + // y = x - r * g + // x = y + // + // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To + // tell that "y" should be assigned to "x", the field "update_binding" may + // contain a key-value pair of strings, "x" (key of StringStringEntryProto) + // and "y" (value of StringStringEntryProto). + // For a neural network with multiple trainable (mutable) tensors, there can + // be multiple key-value pairs in "update_binding". + // + // The initializers appears as keys in "update_binding" are considered + // mutable variables. This implies some behaviors + // as described below. + // + // 1. We have only unique keys in all "update_binding"s so that two + // variables may not have the same name. This ensures that one + // variable is assigned up to once. + // 2. The keys must appear in names of "ModelProto.graph.initializer" or + // "TrainingInfoProto.algorithm.initializer". + // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". + // 4. Mutable variables are initialized to the value specified by the + // corresponding initializer, and then potentially updated by + // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. + // + // This field usually contains names of trainable tensors + // (in ModelProto.graph), optimizer states such as momentums in advanced + // stochastic gradient methods (in TrainingInfoProto.graph), + // and number of training iterations (in TrainingInfoProto.graph). + // + // By default, this field is empty and no initializer would be changed + // by the execution of "algorithm". + repeated StringStringEntryProto update_binding = 4; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto's. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; + + // Training-specific information. Sequentially executing all stored + // `TrainingInfoProto.algorithm`s and assigning their outputs following + // the corresponding `TrainingInfoProto.update_binding`s is one training + // iteration. Similarly, to initialize the model + // (as if training hasn't happened), the user should sequentially execute + // all stored `TrainingInfoProto.initialization`s and assigns their outputs + // using `TrainingInfoProto.initialization_binding`s. + // + // If this field is empty, the training behavior of the model is undefined. + repeated TrainingInfoProto training_info = 20; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name. + // The name MUST be unique across both initializer and sparse_initializer, + // but the name MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + // values must have a non-empty name present which serves as a name for SparseTensorProto + // when used in sparse_initializer list. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + // wrapper for Tensor, Sequence, or Map + message Optional { + // The type and optional shape of the element wrapped. + // This field MUST be present for this version of the IR. + // Possible values correspond to OptionalProto.DataType enum + TypeProto elem_type = 1; + }; + + + message SparseTensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + // The type of an optional. + Optional optional_type = 9; + + + // Type of the sparse tensor + SparseTensor sparse_tensor_type = 8; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} + + +// For using protobuf-lite +option optimize_for = LITE_RUNTIME; diff --git a/test/axon_onnx_test.exs b/test/axon_onnx_test.exs new file mode 100644 index 0000000..1be84ec --- /dev/null +++ b/test/axon_onnx_test.exs @@ -0,0 +1,8 @@ +defmodule AxonOnnxTest do + use ExUnit.Case + doctest AxonOnnx + + test "greets the world" do + assert AxonOnnx.hello() == :world + end +end diff --git a/test/test_helper.exs b/test/test_helper.exs new file mode 100644 index 0000000..869559e --- /dev/null +++ b/test/test_helper.exs @@ -0,0 +1 @@ +ExUnit.start()