Skip to content

Commit

Permalink
Use model state struct instead of parameters (#553)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored May 10, 2024
1 parent 19803a0 commit 1ccbeba
Show file tree
Hide file tree
Showing 12 changed files with 1,731 additions and 1,935 deletions.
7 changes: 2 additions & 5 deletions examples/vision/mnist.exs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
Mix.install([
{:axon, "~> 0.5"},
{:axon, path: "~/projects/axon"},
{:polaris, "~> 0.1"},
{:exla, "~> 0.5"},
{:nx, "~> 0.5"},
{:exla, ">= 0.0.0"},
{:scidata, "~> 0.1"}
])

defmodule Mnist do
require Axon

@batch_size 32
@image_side_pixels 28
@channel_value_max 255
Expand Down
229 changes: 26 additions & 203 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,6 @@ defmodule Axon do

require Logger

# Axon serialization version
@file_version 1

@type t :: %__MODULE__{}

defstruct [
Expand Down Expand Up @@ -417,16 +414,18 @@ defmodule Axon do
}
end

def param(name, shape, opts) when is_tuple(shape) or is_function(shape) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32})
def param(name, shape, opts) when is_binary(name) and (is_tuple(shape) or is_function(shape)) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
initializer = validate_initializer!(opts[:initializer])
type = opts[:type] || {:f, 32}
kind = opts[:kind] || :parameter

%Axon.Parameter{
name: name,
shape: shape,
type: type,
initializer: initializer
initializer: initializer,
kind: kind
}
end

Expand Down Expand Up @@ -586,7 +585,7 @@ defmodule Axon do
iex> inp1 = Axon.input("input_0", shape: {nil, 1})
iex> inp2 = Axon.input("input_1", shape: {nil, 2})
iex> model = Axon.container(%{a: inp1, b: inp2})
iex> %{a: a, b: b} = Axon.predict(model, %{}, %{
iex> %{a: a, b: b} = Axon.predict(model, Axon.ModelState.empty(), %{
...> "input_0" => Nx.tensor([[1.0]]),
...> "input_1" => Nx.tensor([[1.0, 2.0]])
...> })
Expand Down Expand Up @@ -667,42 +666,6 @@ defmodule Axon do
end
end

@doc """
Wraps an Axon model into a namespace.
A namespace is a part of an Axon model which is meant to
be a self-contained collection of Axon layers. Namespaces
are guaranteed to always generate with the same internal
layer names and can be re-used universally across models.
Namespaces are most useful for containing large collections
of layers and offering a straightforward means for accessing
the parameters of individual model components. A common application
of namespaces is to use them in with a pre-trained model for
fine-tuning:
{base, resnet_params} = resnet()
base = base |> Axon.namespace("resnet")
model = base |> Axon.dense(1)
{init_fn, predict_fn} = Axon.build(model)
init_fn.(Nx.template({1, 3, 224, 224}, {:f, 32}), %{"resnset" => resnet_params})
Notice you can use `init_fn` in conjunction with namespaces
to specify which portion of a model you'd like to initialize
from a fixed starting point.
Namespaces have fixed names, which means it's easy to run into namespace
collisions. Re-using namespaces, re-using inner parts of a namespace,
and attempting to share layers between namespaces are still sharp
edges in namespace usage.
"""
@doc type: :special
def namespace(%Axon{} = axon, name) when is_binary(name) do
layer(:namespace, [axon], name: name)
end

@doc """
Returns a function which represents a self-contained re-usable block
of operations in a neural network. All parameters in the block are
Expand Down Expand Up @@ -1561,7 +1524,8 @@ defmodule Axon do
key_state =
param("key", fn _ -> {2} end,
type: {:u, 32},
initializer: fn _, _ -> Nx.Random.key(seed) end
initializer: fn _, _ -> Nx.Random.key(seed) end,
kind: :state
)

layer(dropout, [x, key_state],
Expand Down Expand Up @@ -1867,8 +1831,8 @@ defmodule Axon do
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])

mean = param("mean", mean_shape, initializer: :zeros)
var = param("var", var_shape, initializer: :ones)
mean = param("mean", mean_shape, initializer: :zeros, kind: :state)
var = param("var", var_shape, initializer: :ones, kind: :state)

layer(
norm,
Expand Down Expand Up @@ -3006,14 +2970,16 @@ defmodule Axon do
key_state =
param("key", fn _ -> {2} end,
type: {:u, 32},
initializer: fn _, _ -> Nx.Random.key(seed) end
initializer: fn _, _ -> Nx.Random.key(seed) end,
kind: :state
)

name =
case parent_name do
nil ->
fn _, op_counts ->
"lstm_#{op_counts[rnn_type]}_#{state_name}_hidden_state"
count = op_counts[rnn_type] || 0
"#{Atom.to_string(rnn_type)}_#{count}_#{state_name}_hidden_state"
end

parent_name when is_binary(parent_name) ->
Expand Down Expand Up @@ -3042,9 +3008,16 @@ defmodule Axon do

arity == 3 ->
fun =
fn inputs, key, _opts ->
fn inputs, key, opts ->
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)
initializer.(shape, {:f, 32}, key)
keys = Nx.Random.split(key)
out = initializer.(shape, {:f, 32}, keys[1])

if opts[:mode] == :train do
%Axon.StatefulOutput{output: out, state: %{"key" => keys[0]}}
else
out
end
end

{fun, [x, key_state]}
Expand Down Expand Up @@ -3168,6 +3141,7 @@ defmodule Axon do
the update process.
"""
@doc type: :model
@deprecated "Use Axon.ModelState.freeze/2 instead"
def freeze(model, fun_or_predicate \\ :all) do
freeze(model, fun_or_predicate, true)
end
Expand Down Expand Up @@ -3240,6 +3214,7 @@ defmodule Axon do
the update process.
"""
@doc type: :model
@deprecated "Use Axon.ModelState.freeze/2 instead"
def unfreeze(model, fun_or_predicate \\ :all) do
freeze(model, fun_or_predicate, false)
end
Expand Down Expand Up @@ -3410,7 +3385,7 @@ defmodule Axon do
out =
Nx.Defn.jit(
fn inputs ->
forward_fn.(init_fn.(inputs, %{}), inputs)
forward_fn.(init_fn.(inputs, Axon.ModelState.empty()), inputs)
end,
compiler: Axon.Defn
).(inputs)
Expand Down Expand Up @@ -3864,158 +3839,6 @@ defmodule Axon do
end
end

# Serialization

@doc """
Serializes a model and its parameters for persisting
models to disk or elsewhere.
Model and parameters are serialized as a tuple, where the
model is converted to a recursive map to ensure compatibility
with future Axon versions and the parameters are serialized
using `Nx.serialize/2`. There is some additional metadata included
such as current serialization version for compatibility.
Serialization `opts` are forwarded to `Nx.serialize/2` and
`:erlang.term_to_binary/2` for controlling compression options.
## Examples
iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
iex> {init_fn, _} = Axon.build(model)
iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
iex> serialized = Axon.serialize(model, params)
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
iex> {_, predict_fn} = Axon.build(saved_model)
iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
#Nx.Tensor<
f32[1][1]
[
[0.0]
]
>
"""
@doc type: :model
def serialize(%Axon{output: id, nodes: nodes}, params, opts \\ []) do
Logger.warning(
"Attempting to serialize an Axon model. Serialization is discouraged" <>
" and will be deprecated, then removed in future releases. You should" <>
" keep your model definitions as code and serialize your parameters using" <>
" `Nx.serialize/2`."
)

nodes =
Map.new(nodes, fn {k, %{op: op, op_name: op_name} = v} ->
validate_serialized_op!(op_name, op)
node_meta = Map.from_struct(v)
{k, Map.put(node_meta, :node, :node)}
end)

model_meta = %{output: id, nodes: nodes, axon: :axon}
params = Nx.serialize(params, opts)
:erlang.term_to_binary({@file_version, model_meta, params}, opts)
end

# TODO: Raise on next release
defp validate_serialized_op!(op_name, op) when is_function(op) do
fun_info = Function.info(op)

case fun_info[:type] do
:local ->
Logger.warning(
"Attempting to serialize anonymous function in #{inspect(op_name)} layer," <>
" this will result in errors during deserialization between" <>
" different processes, and will be unsupported in a future" <>
" release. You should instead use a fully-qualified MFA function" <>
" such as &Axon.Layers.dense/3"
)

{:type, :external} ->
:ok
end
end

defp validate_serialized_op!(_name, op) when is_atom(op), do: :ok

@doc """
Deserializes serialized model and parameters into a `{model, params}`
tuple.
It is the opposite of `Axon.serialize/3`.
## Examples
iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
iex> {init_fn, _} = Axon.build(model)
iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
iex> serialized = Axon.serialize(model, params)
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
iex> {_, predict_fn} = Axon.build(saved_model)
iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
#Nx.Tensor<
f32[1][1]
[
[0.0]
]
>
"""
@doc type: :model
def deserialize(serialized, opts \\ []) do
Logger.warning(
"Attempting to deserialize a serialized Axon model. Deserialization" <>
" is discouraged and will be deprecated, then removed in future" <>
" releases. You should keep your model definitions as code and" <>
" serialize your parameters using `Nx.serialize/2`."
)

{1, model_meta, serialized_params} = :erlang.binary_to_term(serialized, opts)
%{nodes: nodes, output: id} = model_meta

nodes =
Map.new(nodes, fn {k, %{op_name: op_name, op: op} = v} ->
validate_deserialized_op!(op_name, op)

node_struct =
v
|> Map.delete(:node)
|> then(&struct(Axon.Node, &1))

{k, node_struct}
end)

model = %Axon{output: id, nodes: nodes}
params = Nx.deserialize(serialized_params, opts)
{model, params}
end

# TODO: Raise on next release
defp validate_deserialized_op!(op_name, op) when is_function(op) do
fun_info = Function.info(op)

case fun_info[:type] do
:local ->
Logger.warning(
"Attempting to deserialize anonymous function in #{inspect(op_name)} layer," <>
" this will result in errors during deserialization between" <>
" different processes, and will be unsupported in a future" <>
" release"
)

:external ->
unless function_exported?(fun_info[:module], fun_info[:name], fun_info[:arity]) do
Logger.warning(
"Attempting to deserialize model which depends on function" <>
" #{inspect(op)} in layer #{inspect(op_name)} which does not exist in" <>
" the current environment, check your dependencies"
)
end
end
end

defp validate_deserialized_op!(op, _op_name) when is_atom(op), do: :ok

## Helpers

@valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++
Expand Down
Loading

0 comments on commit 1ccbeba

Please sign in to comment.