From 1857e4911fcd5fa7e8ecb455d988b851cce1a2ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Wed, 6 Mar 2024 20:14:31 +0100
Subject: [PATCH] Fix warnings (#560)

---
 lib/axon/defn.ex              |    3 +
 lib/axon/optimizers.ex        |  197 +--
 lib/axon/schedules.ex         |  193 +--
 lib/axon/updates.ex           |  955 +------------
 mix.exs                       |    4 +-
 mix.lock                      |   10 +-
 test/axon/compiler_test.exs   |    6 +-
 test/axon/optimizers_test.exs |  318 -----
 test/axon/schedules_test.exs  |  230 ----
 test/axon/updates_test.exs    | 2243 ---------------------------------
 10 files changed, 82 insertions(+), 4077 deletions(-)
 delete mode 100644 test/axon/optimizers_test.exs
 delete mode 100644 test/axon/schedules_test.exs
 delete mode 100644 test/axon/updates_test.exs

diff --git a/lib/axon/defn.ex b/lib/axon/defn.ex
index e7313ee4..970ae86b 100644
--- a/lib/axon/defn.ex
+++ b/lib/axon/defn.ex
@@ -22,4 +22,7 @@ defmodule Axon.Defn do
   @impl true
   def __partitions_options__(_), do: raise("not implemented")
+
+  @impl true
+  def __to_backend__(_), do: raise("not implemented")
 end

diff --git a/lib/axon/optimizers.ex b/lib/axon/optimizers.ex
index 8cadf3d6..c980ccd0 100644
--- a/lib/axon/optimizers.ex
+++ b/lib/axon/optimizers.ex
@@ -1,230 +1,53 @@
 defmodule Axon.Optimizers do
   @moduledoc false
-  alias Polaris.Updates
-
-  @doc """
-  Adabelief optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `0.0`
-    * `:eps_root` - numerical stability term. Defaults to `1.0e-16`
-
-  ## References
-
-    * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468)
-  """
   @deprecated "Use Polaris.Optimizers.adabelief/1 instead"
   def adabelief(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_belief(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adabelief([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  Adagrad optimizer.
-
-  ## Options
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-7`
-
-  ## References
-
-    * [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
-  """
   @deprecated "Use Polaris.Optimizers.adagrad/1 instead"
   def adagrad(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_rss(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adagrad([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  Adam optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `1.0e-15`
-
-  ## References
-
-    * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
-  """
   @deprecated "Use Polaris.Optimizers.adam/1 instead"
   def adam(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_adam(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adam([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  Adam with weight decay optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-    * `:decay` - weight decay. Defaults to `0.0`
-  """
   @deprecated "Use Polaris.Optimizers.adamw/1 instead"
   def adamw(learning_rate \\ 1.0e-3, opts \\ []) do
-    {decay, opts} = Keyword.pop(opts, :decay, 0.0)
-
-    Updates.scale_by_adam(opts)
-    |> Updates.add_decayed_weights(decay: decay)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.adamw([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  Lamb optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-    * `:decay` - weight decay. Defaults to `0.0`
-    * `:min_norm` - minimum norm value. Defaults to `0.0`
-
-  ## References
-
-    * [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)
-  """
   @deprecated "Use Polaris.Optimizers.lamb/1 instead"
   def lamb(learning_rate \\ 1.0e-2, opts \\ []) do
-    {decay, opts} = Keyword.pop(opts, :decay, 0.0)
-    {min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0)
-
-    Updates.scale_by_adam(opts)
-    |> Updates.add_decayed_weights(decay: decay)
-    |> Updates.scale_by_trust_ratio(min_norm: min_norm)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.lamb([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  Noisy SGD optimizer.
-
-  ## Options
-
-    * `:eta` - used to compute variance of noise distribution. Defaults to `0.1`
-    * `:gamma` - used to compute variance of noise distribution. Defaults to `0.55`
-  """
   @deprecated "Use Polaris.Optimizers.noisy_sgd/1 instead"
   def noisy_sgd(learning_rate \\ 1.0e-2, opts \\ []) do
-    scale_by_learning_rate(learning_rate)
-    |> Updates.add_noise(opts)
+    Polaris.Optimizers.noisy_sgd([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  Rectified Adam optimizer.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-    * `:threshold` - threshold term. Defaults to `5.0`
-
-  ## References
-
-    * [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf)
-  """
   @deprecated "Use Polaris.Optimizers.radam/1 instead"
   def radam(learning_rate \\ 1.0e-3, opts \\ []) do
-    Updates.scale_by_radam(opts)
-    |> scale_by_learning_rate(learning_rate)
+    Polaris.Optimizers.radam([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  RMSProp optimizer.
-
-  ## Options
-
-    * `:centered` - whether to scale by centered root of EMA of squares. Defaults to `false`
-    * `:momentum` - momentum term. If set, uses SGD with momentum and decay set
-      to value of this term.
-    * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false`
-    * `:initial_scale` - initial value of EMA. Defaults to `0.0`
-    * `:decay` - EMA decay rate. Defaults to `0.9`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-  """
   @deprecated "Use Polaris.Optimizers.rmsprop/1 instead"
   def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do
-    {centered, opts} = Keyword.pop(opts, :centered, false)
-    {nesterov?, opts} = Keyword.pop(opts, :nesterov, false)
-    {momentum, opts} = Keyword.pop(opts, :momentum, nil)
-
-    combinator =
-      if centered do
-        Updates.scale_by_stddev(opts)
-      else
-        Updates.scale_by_rms(opts)
-      end
-      |> scale_by_learning_rate(learning_rate)
-
-    if momentum,
-      do: Updates.trace(combinator, decay: momentum, nesterov: nesterov?),
-      else: combinator
+    Polaris.Optimizers.rmsprop([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  SGD optimizer.
-
-  ## Options
-
-    * `:momentum` - momentum term. If set, uses SGD with momentum and decay set
-      to value of this term.
-    * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false`
-  """
   @deprecated "Use Polaris.Optimizers.sgd/1 instead"
   def sgd(learning_rate \\ 1.0e-2, opts \\ []) do
-    momentum = opts[:momentum]
-    nesterov? = opts[:nesterov] || false
-
-    if momentum do
-      Updates.trace(decay: momentum, nesterov: nesterov?)
-      |> scale_by_learning_rate(learning_rate)
-    else
-      scale_by_learning_rate(learning_rate)
-    end
+    Polaris.Optimizers.sgd([learning_rate: learning_rate] ++ opts)
   end

-  @doc """
-  Yogi optimizer.
-
-  ## Options
-
-    * `:initial_accumulator_value` - initial value for first and second moment. Defaults to `0.0`
-    * `:b1` - first moment decay. Defaults to `0.9`
-    * `:b2` - second moment decay. Defaults to `0.999`
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-
-  ## References
-
-    * [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf)
-  """
   @deprecated "Use Polaris.Optimizers.yogi/1 instead"
   def yogi(learning_rate \\ 1.0e-2, opts \\ []) do
-    Updates.scale_by_yogi(opts)
-    |> scale_by_learning_rate(learning_rate)
-  end
-
-  ## Helpers
-
-  defp scale_by_learning_rate(combinator \\ Updates.identity(), lr)
-
-  defp scale_by_learning_rate(combinator, schedule) when is_function(schedule, 1) do
-    Updates.scale_by_schedule(combinator, fn count -> Nx.negate(schedule.(count)) end)
-  end
-
-  defp scale_by_learning_rate(combinator, lr) do
-    Updates.scale_by_state(combinator, -lr)
+    Polaris.Optimizers.yogi([learning_rate: learning_rate] ++ opts)
   end
 end
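The deprecated `Axon.Optimizers` entry points above become thin wrappers that fold the positional learning rate into a Polaris options list. A minimal sketch of the corresponding caller migration (the option values are illustrative, and `model` is assumed to be an Axon model):

    # Before (deprecated): the learning rate is a positional argument
    optimizer = Axon.Optimizers.adam(1.0e-3, b1: 0.95)

    # After: Polaris takes the learning rate as an option, just as the
    # wrappers above construct it
    optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-3, b1: 0.95)

    # Both forms return an {init_fn, update_fn} pair, so either can be
    # passed wherever Axon expects an optimizer:
    model |> Axon.Loop.trainer(:mean_squared_error, optimizer)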
diff --git a/lib/axon/schedules.ex b/lib/axon/schedules.ex
index 62f07e07..6a712f93 100644
--- a/lib/axon/schedules.ex
+++ b/lib/axon/schedules.ex
@@ -1,201 +1,18 @@
 defmodule Axon.Schedules do
   @moduledoc false
-  import Nx.Defn
-
-  @doc """
-  Linear decay schedule.
-
-  ## Options
-
-    * `:warmup` - scheduler warmup steps. Defaults to `0`
-
-    * `:steps` - total number of decay steps. Defaults to `1000`
-  """
   @deprecated "Use Polaris.Schedules.linear_decay/2 instead"
-  def linear_decay(init_value, opts \\ []) do
-    &apply_linear_decay(&1, [{:init_value, init_value} | opts])
-  end
-
-  defnp apply_linear_decay(step, opts \\ []) do
-    opts =
-      keyword!(opts,
-        init_value: 1.0e-2,
-        warmup: 0,
-        steps: 1000
-      )
-
-    if step < opts[:warmup] do
-      step / Nx.max(1, opts[:warmup])
-    else
-      Nx.max(0.0, (opts[:steps] - step) / Nx.max(1, opts[:steps] - opts[:warmup]))
-    end
-  end
-
-  @doc ~S"""
-  Exponential decay schedule.
-
-  $$\gamma(t) = \gamma_0 * r^{\frac{t}{k}}$$
-
-  ## Options
-
-    * `:decay_rate` - rate of decay. $r$ in above formulation.
-      Defaults to `0.95`
-
-    * `:transition_steps` - steps per transition. $k$ in above
-      formulation. Defaults to `10`
+  defdelegate linear_decay(init_value, opts \\ []), to: Polaris.Schedules

-    * `:transition_begin` - step to begin transition. Defaults to `0`
-
-    * `:staircase` - discretize outputs. Defaults to `false`
-
-  """
   @deprecated "Use Polaris.Schedules.exponential_decay/2 instead"
-  def exponential_decay(init_value, opts \\ []) do
-    &apply_exponential_decay(&1, [{:init_value, init_value} | opts])
-  end
-
-  defnp apply_exponential_decay(step, opts \\ []) do
-    opts =
-      keyword!(opts,
-        init_value: 1.0e-2,
-        decay_rate: 0.95,
-        transition_steps: 10,
-        transition_begin: 0,
-        staircase: false
-      )
-
-    init_value = opts[:init_value]
-    rate = opts[:decay_rate]
-    staircase? = opts[:staircase]
-    k = opts[:transition_steps]
-    start = opts[:transition_begin]
-
-    t = Nx.subtract(step, start)
-
-    p =
-      if staircase? do
-        t
-        |> Nx.divide(k)
-        |> Nx.floor()
-      else
-        t
-        |> Nx.divide(k)
-      end
-
-    decayed_value =
-      rate
-      |> Nx.pow(p)
-      |> Nx.multiply(init_value)
-
-    Nx.select(
-      Nx.less_equal(t, 0),
-      init_value,
-      decayed_value
-    )
-  end
-
-  @doc ~S"""
-  Cosine decay schedule.
-
-  $$\gamma(t) = \gamma_0 * \left(\frac{1}{2}(1 - \alpha)(1 + \cos\pi \frac{t}{k}) + \alpha\right)$$
+  defdelegate exponential_decay(init_value, opts \\ []), to: Polaris.Schedules

-  ## Options
-
-    * `:decay_steps` - number of steps to apply decay for.
-      $k$ in above formulation. Defaults to `10`
-
-    * `:alpha` - minimum value of multiplier adjusting learning rate.
-      $\alpha$ in above formulation. Defaults to `0.0`
-
-  ## References
-
-    * [SGDR: Stochastic Gradient Descent with Warm Restarts](https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
-
-  """
   @deprecated "Use Polaris.Schedules.cosine_decay/2 instead"
-  def cosine_decay(init_value, opts \\ []) do
-    &apply_cosine_decay(&1, [{:init_value, init_value} | opts])
-  end
+  defdelegate cosine_decay(init_value, opts \\ []), to: Polaris.Schedules

-  defnp apply_cosine_decay(step, opts \\ []) do
-    opts = keyword!(opts, init_value: 1.0e-2, decay_steps: 10, alpha: 0.0)
-    init_value = opts[:init_value]
-    decay_steps = opts[:decay_steps]
-    alpha = opts[:alpha]
-
-    step
-    |> Nx.min(decay_steps)
-    |> Nx.divide(decay_steps)
-    |> Nx.multiply(3.1415926535897932384626433832795028841971)
-    |> Nx.cos()
-    |> Nx.add(1)
-    |> Nx.divide(2)
-    |> Nx.multiply(1 - alpha)
-    |> Nx.add(alpha)
-    |> Nx.multiply(init_value)
-  end
-
-  @doc ~S"""
-  Constant schedule.
-
-  $$\gamma(t) = \gamma_0$$
-
-  """
   @deprecated "Use Polaris.Schedules.constant/2 instead"
-  def constant(init_value, opts \\ []) do
-    &apply_constant(&1, [{:init_value, init_value} | opts])
-  end
-
-  defnp apply_constant(_step, opts \\ []) do
-    opts = keyword!(opts, init_value: 0.01)
-    opts[:init_value]
-  end
+  defdelegate constant(init_value, opts \\ []), to: Polaris.Schedules

-  @doc ~S"""
-  Polynomial schedule.
-
-  $$\gamma(t) = (\gamma_0 - \gamma_n) * (1 - \frac{t}{k})^p$$
-
-  ## Options
-
-    * `:end_value` - end value of annealed scalar. $\gamma_n$ in above formulation.
-      Defaults to `1.0e-3`
-
-    * `:power` - power of polynomial. $p$ in above formulation. Defaults to `2`
-
-    * `:transition_steps` - number of steps over which annealing takes place.
-      $k$ in above formulation. Defaults to `10`
-
-  """
   @deprecated "Use Polaris.Schedules.polynomial_decay/2 instead"
-  def polynomial_decay(init_value, opts \\ []) do
-    &apply_polynomial_decay(&1, [{:init_value, init_value} | opts])
-  end
-
-  defnp apply_polynomial_decay(step, opts \\ []) do
-    opts =
-      keyword!(opts,
-        init_value: 1.0e-2,
-        end_value: 1.0e-3,
-        power: 2,
-        transition_steps: 10,
-        transition_begin: 0
-      )
-
-    init_value = opts[:init_value]
-    end_value = opts[:end_value]
-    start = opts[:transition_begin]
-    k = opts[:transition_steps]
-    p = opts[:power]
-
-    step
-    |> Nx.subtract(start)
-    |> Nx.clip(0, k)
-    |> Nx.divide(k)
-    |> Nx.negate()
-    |> Nx.add(1)
-    |> Nx.pow(p)
-    |> Nx.multiply(Nx.subtract(init_value, end_value))
-    |> Nx.add(end_value)
-  end
+  defdelegate polynomial_decay(init_value, opts \\ []), to: Polaris.Schedules
 end
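A schedule in both APIs is an arity-1 function from the step counter to the current value, which is why plain `defdelegate` suffices here. A short sketch of the contract, following the exponential decay formula `init_value * decay_rate^(t / transition_steps)` (and assuming Polaris accepts a schedule function in place of a fixed learning rate, as the old Axon API did):

    schedule =
      Polaris.Schedules.exponential_decay(1.0e-2, decay_rate: 0.9, transition_steps: 100)

    schedule.(0)    #=> 1.0e-2 (no decay before the first transition)
    schedule.(100)  #=> 9.0e-3 (one full transition: 1.0e-2 * 0.9)

    # A schedule can stand in for a fixed learning rate:
    Polaris.Optimizers.sgd(learning_rate: schedule)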
diff --git a/lib/axon/updates.ex b/lib/axon/updates.ex
index 06e0b144..d970e759 100644
--- a/lib/axon/updates.ex
+++ b/lib/axon/updates.ex
@@ -2,965 +2,120 @@ defmodule Axon.Updates do
   @moduledoc false

   import Nx.Defn
-  import Axon.Shared

-  @doc ~S"""
-  Scales input by a fixed step size.
-
-  $$f(x_i) = \alpha x_i$$
-  """
   @deprecated "Use Polaris.Updates.scale/2 instead"
-  def scale(combinator \\ identity(), step_size) do
-    stateless(combinator, &apply_scale(&1, &2, step_size))
-  end
-
-  defnp apply_scale(updates, _params, step) do
-    deep_new(updates, fn v -> Nx.multiply(v, step) end)
-  end
-
-  @doc ~S"""
-  Scales input by a tunable learning rate which can be
-  manipulated by external APIs such as Axon's Loop API.
+  defdelegate scale(combinator \\ identity(), step_size), to: Polaris.Updates

-  $$f(x_i) = \alpha x_i$$
-  """
   @deprecated "Use Polaris.Updates.scale_by_state/1 instead"
-  def scale_by_state(combinator_or_step)
-
-  def scale_by_state(step) when is_number(step) do
-    scale_by_state(identity(), step)
-  end
-
-  def scale_by_state({init_fn, apply_fn} = combinator, step)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_number(step) do
-    stateful(combinator, &init_scale_by_state(&1, init_scale: step), &apply_scale_by_state/3)
-  end
-
-  defnp init_scale_by_state(_params, opts \\ []) do
-    opts = keyword!(opts, [:init_scale])
-    %{scale: opts[:init_scale]}
-  end
-
-  defnp apply_scale_by_state(x, %{scale: scale} = state, params) do
-    {apply_scale(x, params, scale), state}
-  end
-
-  @doc """
-  Scales input according to Adam algorithm.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-
-    * `:b2` - second moment decay. Defaults to `0.999`
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
+  defdelegate scale_by_state(combinator_or_step), to: Polaris.Updates

-    * `:eps_root` - numerical stability term. Defaults to `1.0e-15`
+  @deprecated "Use Polaris.Updates.scale_by_state/2 instead"
+  defdelegate scale_by_state(combinator, step), to: Polaris.Updates

-  ## References
-
-    * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
-
-  """
   @deprecated "Use Polaris.Updates.scale_by_adam/1 instead"
-  def scale_by_adam(combinator_or_opts \\ [])
-
-  def scale_by_adam(opts) when is_list(opts) do
-    scale_by_adam(identity(), opts)
-  end
-
-  def scale_by_adam({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_adam(combinator, [])
-  end
-
-  def scale_by_adam({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateful(
-      combinator,
-      &init_scale_by_adam/1,
-      &apply_scale_by_adam(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_scale_by_adam(params) do
-    mus = zeros_like(params, type: :f32)
-    nus = zeros_like(params, type: :f32)
-    count = Nx.tensor(0)
-    %{mu: mus, nu: nus, count: count}
-  end
-
-  defnp apply_scale_by_adam(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do
-    opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-8, eps_root: 1.0e-15)
-    b1 = opts[:b1]
-    b2 = opts[:b2]
-    eps = opts[:eps]
-    eps_root = opts[:eps_root]
-
-    mu = update_moment(x, mu, b1, 1)
-    nu = update_moment(x, nu, b2, 2)
+  defdelegate scale_by_adam(combinator_or_opts \\ []), to: Polaris.Updates

-    mu_hat = bias_correction(mu, b1, count + 1)
-    nu_hat = bias_correction(nu, b2, count + 1)
-
-    x = deep_merge(mu_hat, nu_hat, fn z, t -> z / (Nx.sqrt(t + eps_root) + eps) end)
-    {x, %{mu: mu, nu: nu, count: count + 1}}
-  end
+  @deprecated "Use Polaris.Updates.scale_by_adam/2 instead"
+  defdelegate scale_by_adam(combinator, opts), to: Polaris.Updates

-  @doc """
-  Scales input by the root of all prior squared inputs.
-
-  ## Options
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-7`
-
-  """
   @deprecated "Use Polaris.Updates.scale_by_rss/1 instead"
-  def scale_by_rss(combinator_or_opts \\ [])
+  defdelegate scale_by_rss(combinator_or_opts \\ []), to: Polaris.Updates

-  def scale_by_rss(opts) when is_list(opts) do
-    scale_by_rss(identity(), opts)
-  end
-
-  def scale_by_rss({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_rss(combinator, [])
-  end
-
-  def scale_by_rss({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    {initial, opts} = Keyword.pop(opts, :initial_accumulator_value, 0.1)
-
-    stateful(
-      combinator,
-      &init_scale_by_rss(&1, initial),
-      &apply_scale_by_rss(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_scale_by_rss(params, value) do
-    sum_of_squares = fulls_like(params, value, type: :f32)
-    %{sum_of_squares: sum_of_squares}
-  end
-
-  defnp apply_scale_by_rss(x, %{sum_of_squares: sum_of_squares}, _params, opts \\ []) do
-    opts = keyword!(opts, eps: 1.0e-7)
-    eps = opts[:eps]
-
-    sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.pow(g, 2) + z end)
-
-    inv_sqrt_squares = deep_new(sum_of_squares, fn z -> Nx.rsqrt(z + eps) end)
-
-    inv_sqrt_x_square =
-      deep_merge(sum_of_squares, inv_sqrt_squares, fn z, t ->
-        Nx.select(Nx.greater(z, 0), t, 0.0)
-      end)
-
-    x = deep_merge(x, inv_sqrt_x_square, fn g, t -> g * t end)
-
-    {x, %{sum_of_squares: sum_of_squares}}
-  end
-
-  @doc """
-  Scales input by the root of the EMA of squared inputs.
-
-  ## Options
-
-    * `:decay` - EMA decay rate. Defaults to `0.9`.
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`.
-
-  ## References
-
-    * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+  @deprecated "Use Polaris.Updates.scale_by_rss/2 instead"
+  defdelegate scale_by_rss(combinator, opts), to: Polaris.Updates

-  """
   @deprecated "Use Polaris.Updates.scale_by_rms/1 instead"
-  def scale_by_rms(combinator_or_opts \\ [])
-
-  def scale_by_rms(opts) when is_list(opts) do
-    scale_by_rms(identity(), opts)
-  end
-
-  def scale_by_rms({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_rms(combinator, [])
-  end
-
-  def scale_by_rms({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    {initial, opts} = Keyword.pop(opts, :initial_scale, 0.0)
-
-    stateful(
-      combinator,
-      &init_scale_by_rms(&1, initial),
-      &apply_scale_by_rms(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_scale_by_rms(params, scale) do
-    nu = fulls_like(params, scale, type: :f32)
-    %{nu: nu}
-  end
-
-  defnp apply_scale_by_rms(x, %{nu: nu}, _params, opts \\ []) do
-    opts = keyword!(opts, decay: 0.9, eps: 1.0e-8)
-    decay = opts[:decay]
-    eps = opts[:eps]
-
-    nu = update_moment(x, nu, decay, 2)
-
-    x = deep_merge(x, nu, fn g, t -> Nx.rsqrt(t + eps) * g end)
-
-    {x, %{nu: nu}}
-  end
+  defdelegate scale_by_rms(combinator_or_opts \\ []), to: Polaris.Updates

-  @doc """
-  Scales input according to the AdaBelief algorithm.
+  @deprecated "Use Polaris.Updates.scale_by_rms/2 instead"
+  defdelegate scale_by_rms(combinator, opts), to: Polaris.Updates

-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`.
-
-    * `:b2` - second moment decay. Defaults to `0.999`.
-
-    * `:eps` - numerical stability term. Defaults to `0.0`.
-
-    * `:eps_root` - numerical stability term. Defaults to `1.0e-16`.
-
-  ## References
-
-    * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468)
-
-  """
   @deprecated "Use Polaris.Updates.scale_by_belief/1 instead"
-  def scale_by_belief(combinator_or_opts \\ [])
-
-  def scale_by_belief(opts) when is_list(opts) do
-    scale_by_belief(identity(), opts)
-  end
-
-  def scale_by_belief({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_belief(combinator, [])
-  end
-
-  def scale_by_belief({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateful(
-      combinator,
-      &init_scale_by_belief/1,
-      &apply_scale_by_belief(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_scale_by_belief(params) do
-    mus = zeros_like(params, type: :f32)
-    nus = zeros_like(params, type: :f32)
-    count = Nx.tensor(0)
-    %{mu: mus, nu: nus, count: count}
-  end
-
-  defnp apply_scale_by_belief(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do
-    opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 0.0, eps_root: 1.0e-16)
-    b1 = opts[:b1]
-    b2 = opts[:b2]
-    eps = opts[:eps]
-    eps_root = opts[:eps_root]
-
-    mu = update_moment(x, mu, b1, 1)
-    nu = update_moment(x, nu, b2, 2)
-
-    mu_hat = bias_correction(mu, b1, count + 1)
-    nu_hat = bias_correction(nu, b2, count + 1)
+  defdelegate scale_by_belief(combinator_or_opts \\ []), to: Polaris.Updates

-    x = deep_merge(mu_hat, nu_hat, fn z, t -> 1 / (Nx.sqrt(t + eps_root) + eps) * z end)
+  @deprecated "Use Polaris.Updates.scale_by_belief/2 instead"
+  defdelegate scale_by_belief(combinator, opts), to: Polaris.Updates

-    {x, %{mu: mu, nu: nu, count: count + 1}}
-  end
-
-  @doc """
-  Scales input by the root of the centered EMA of squared inputs.
-
-  ## Options
-
-    * `:decay` - EMA decay rate. Defaults to `0.9`.
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`.
-
-  ## References
-
-    * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
-
-  """
   @deprecated "Use Polaris.Updates.scale_by_stddev/1 instead"
-  def scale_by_stddev(combinator_or_opts \\ [])
-
-  def scale_by_stddev(opts) when is_list(opts) do
-    scale_by_stddev(identity(), opts)
-  end
-
-  def scale_by_stddev({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_stddev(combinator, [])
-  end
-
-  def scale_by_stddev({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    {initial, opts} = Keyword.pop(opts, :initial_scale, 0.0)
-
-    stateful(
-      combinator,
-      &init_scale_by_stddev(&1, initial),
-      &apply_scale_by_stddev(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_scale_by_stddev(params, value) do
-    mu = zeros_like(params, type: :f32)
-    nu = fulls_like(params, value, type: :f32)
-    %{mu: mu, nu: nu}
-  end
+  defdelegate scale_by_stddev(combinator_or_opts \\ []), to: Polaris.Updates

-  defnp apply_scale_by_stddev(x, %{mu: mu, nu: nu}, _params, opts \\ []) do
-    opts = keyword!(opts, decay: 0.9, eps: 1.0e-8)
-    decay = opts[:decay]
-    eps = opts[:eps]
+  @deprecated "Use Polaris.Updates.scale_by_stddev/2 instead"
+  defdelegate scale_by_stddev(combinator, opts), to: Polaris.Updates

-    mu = update_moment(x, mu, decay, 1)
-    nu = update_moment(x, nu, decay, 2)
-
-    mu_nu =
-      deep_merge(mu, nu, fn m, n ->
-        Nx.rsqrt(-Nx.pow(m, 2) + n + eps)
-      end)
-
-    x = deep_merge(x, mu_nu, fn g, mn -> g * mn end)
-
-    {x, %{mu: mu, nu: nu}}
-  end
-
-  @doc """
-  Scales input using the given schedule function.
-
-  This can be useful for implementing learning rate schedules.
-  The number of update iterations is tracked by an internal
-  counter. You might need to update the schedule to operate
-  on per-batch schedule rather than per-epoch.
-  """
   @deprecated "Use Polaris.Updates.scale_by_schedule/2 instead"
-  def scale_by_schedule(combinator \\ identity(), schedule_fn) when is_function(schedule_fn, 1) do
-    stateful(
-      combinator,
-      &init_scale_by_schedule/1,
-      &apply_scale_by_schedule(&1, &2, &3, schedule_fn)
-    )
-  end
-
-  defnp init_scale_by_schedule(_) do
-    %{count: Nx.tensor(0)}
-  end
-
-  defnp apply_scale_by_schedule(x, %{count: count}, _params, schedule_fn) do
-    step_size = schedule_fn.(count)
-
-    updates = deep_new(x, fn x -> x * step_size end)
-
-    {updates, %{count: count + 1}}
-  end
-
-  @doc """
-  Scale input according to the Rectified Adam algorithm.
-
-  ## Options
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-
-    * `:b2` - second moment decay. Defaults to `0.999`
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
+  defdelegate scale_by_schedule(combinator \\ identity(), schedule_fn), to: Polaris.Updates

-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-
-    * `:threshold` - threshold for variance. Defaults to `5.0`
-
-  ## References
-
-    * [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265)
-
-  """
   @deprecated "Use Polaris.Updates.scale_by_radam/1 instead"
-  def scale_by_radam(combinator_or_opts \\ [])
-
-  def scale_by_radam(opts) when is_list(opts) do
-    scale_by_radam(identity(), opts)
-  end
-
-  def scale_by_radam({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_radam(combinator, [])
-  end
-
-  def scale_by_radam({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateful(
-      combinator,
-      &init_scale_by_radam/1,
-      &apply_scale_by_radam(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_scale_by_radam(params) do
-    mu = zeros_like(params, type: :f32)
-    nu = zeros_like(params, type: :f32)
-    count = Nx.tensor(0)
-    %{mu: mu, nu: nu, count: count}
-  end
-
-  defnp apply_scale_by_radam(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do
-    opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-8, eps_root: 0.0, threshold: 5.0)
-    b1 = opts[:b1]
-    b2 = opts[:b2]
-    eps = opts[:eps]
-    eps_root = opts[:eps_root]
-    threshold = opts[:threshold]
-
-    ro_inf = 2.0 / (1 - b1) - 1
-
-    mu = update_moment(x, mu, b1, 1)
-    nu = update_moment(x, nu, b2, 2)
-    count_inc = count + 1
-
-    b2t = Nx.pow(b2, count_inc)
-    ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)
-
-    mu_hat = bias_correction(mu, b1, count + 1)
-    nu_hat = bias_correction(nu, b2, count + 1)
+  defdelegate scale_by_radam(combinator_or_opts \\ []), to: Polaris.Updates

-    x =
-      if Nx.all(Nx.greater_equal(ro, threshold)) do
-        radam_update(ro, ro_inf, mu_hat, nu_hat, eps_root, eps)
-      else
-        mu_hat
-      end
+  @deprecated "Use Polaris.Updates.scale_by_radam/2 instead"
+  defdelegate scale_by_radam(combinator, opts), to: Polaris.Updates

-    {x, %{mu: mu, nu: nu, count: count + 1}}
-  end
-
-  defnp radam_update(ro, ro_inf, mu, nu, eps_root, eps) do
-    r = Nx.sqrt((ro - 4) * (ro - 2) * ro_inf / ((ro_inf - 4) * (ro_inf - 2) * ro))
-
-    deep_merge(mu, nu, fn m, v ->
-      r * m / (Nx.sqrt(v + eps_root) + eps)
-    end)
-  end
-
-  @doc """
-  Trace inputs with past inputs.
-
-  ## Options
-
-    * `:decay` - decay rate for tracing past updates. Defaults
-      to `0.9`
-    * `:nesterov` - whether to use Nesterov momentum. Defaults
-      to `false`
-
-  """
   @deprecated "Use Polaris.Updates.trace/1 instead"
-  def trace(combinator_or_opts \\ [])
-
-  def trace(opts) when is_list(opts) do
-    trace(identity(), opts)
-  end
-
-  def trace({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    trace(combinator, [])
-  end
-
-  def trace({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateful(
-      combinator,
-      &init_trace/1,
-      &apply_trace(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_trace(params) do
-    trace = zeros_like(params, type: :f32)
-    %{trace: trace}
-  end
-
-  defnp apply_trace(x, %{trace: trace}, _params, opts \\ []) do
-    opts = keyword!(opts, decay: 0.9, nesterov: false)
-    decay = opts[:decay]
-
-    update_trace = deep_merge(x, trace, fn g, t -> t * decay + g end)
+  defdelegate trace(combinator_or_opts \\ []), to: Polaris.Updates

-    x =
-      if opts[:nesterov] do
-        deep_merge(x, update_trace, fn g, t -> t * decay + g end)
-      else
-        update_trace
-      end
+  @deprecated "Use Polaris.Updates.trace/2 instead"
+  defdelegate trace(combinator, opts), to: Polaris.Updates

-    {x, %{trace: update_trace}}
-  end
-
-  @doc """
-  Clips input between -delta and delta.
-
-  ## Options
-
-    * `:delta` - maximum absolute value of the input. Defaults
-      to `2.0`
-  """
   @deprecated "Use Polaris.Updates.clip/1 instead"
-  def clip(combinator_or_opts \\ [])
-
-  def clip(opts) when is_list(opts) do
-    clip(identity(), opts)
-  end
-
-  def clip({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    clip(combinator, [])
-  end
-
-  def clip({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateless(combinator, &apply_clip(&1, &2, opts))
-  end
-
-  defnp apply_clip(x, _params, opts \\ []) do
-    opts = keyword!(opts, delta: 2.0)
-    delta = opts[:delta]
-
-    deep_new(x, fn g -> Nx.clip(g, -delta, delta) end)
-  end
-
-  @doc """
-  Clips input using input global norm.
+  defdelegate clip(combinator_or_opts \\ []), to: Polaris.Updates

-  ## Options
+  @deprecated "Use Polaris.Updates.clip/2 instead"
+  defdelegate clip(combinator, opts), to: Polaris.Updates

-    * `:max_norm` - maximum norm value of input. Defaults to
-      `1.0`
-  """
   @deprecated "Use Polaris.Updates.clip_by_global_norm/1 instead"
-  def clip_by_global_norm(combinator_or_opts \\ [])
+  defdelegate clip_by_global_norm(combinator_or_opts \\ []), to: Polaris.Updates

-  def clip_by_global_norm(opts) when is_list(opts) do
-    clip_by_global_norm(identity(), opts)
-  end
-
-  def clip_by_global_norm({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    clip_by_global_norm(combinator, [])
-  end
+  @deprecated "Use Polaris.Updates.clip_by_global_norm/2 instead"
+  defdelegate clip_by_global_norm(combinator, opts), to: Polaris.Updates

-  def clip_by_global_norm({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateless(combinator, &apply_clip_by_global_norm(&1, &2, opts))
-  end
-
-  defnp apply_clip_by_global_norm(x, _params, opts \\ []) do
-    opts = keyword!(opts, max_norm: 1.0)
-    max_norm = opts[:max_norm]
-
-    sum_gs =
-      deep_reduce(x, Nx.tensor(0.0), fn leaf, acc ->
-        leaf
-        |> Nx.pow(2)
-        |> Nx.sum()
-        |> Nx.add(acc)
-      end)
-
-    g_norm = Nx.sqrt(sum_gs)
-
-    deep_new(x, fn z ->
-      Nx.select(Nx.less(g_norm, max_norm), z, z / g_norm * max_norm)
-    end)
-  end
-
-  @doc """
-  Centralizes input by shifting updates by their mean.
-  """
   @deprecated "Use Polaris.Updates.centralize/1 instead"
-  def centralize(combinator_or_opts \\ [])
-
-  def centralize(opts) when is_list(opts) do
-    centralize(identity(), opts)
-  end
-
-  def centralize({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    centralize(combinator, [])
-  end
-
-  def centralize({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateless(combinator, &apply_centralize(&1, &2, opts))
-  end
-
-  defnp apply_centralize(x, _params, _opts \\ []) do
-    deep_new(x, &centralize_for_rank/1)
-  end
-
-  deftransformp centralize_for_rank(input) do
-    if Nx.rank(input) > 1 do
-      input
-      |> Nx.subtract(Nx.mean(input, axes: tl(Nx.axes(input)), keep_axes: true))
-    else
-      input
-    end
-  end
-
-  @doc """
-  Adds decayed weights to updates.
-
-  Commonly used as a regularization strategy.
+  defdelegate centralize(combinator_or_opts \\ []), to: Polaris.Updates

-  ## Options
+  @deprecated "Use Polaris.Updates.centralize/2 instead"
+  defdelegate centralize(combinator, opts), to: Polaris.Updates

-    * `:decay` - Rate of decay. Defaults to `0.0`.
-  """
   @deprecated "Use Polaris.Updates.add_decayed_weights/1 instead"
-  def add_decayed_weights(combinator_or_opts \\ [])
+  defdelegate add_decayed_weights(combinator_or_opts \\ []), to: Polaris.Updates

-  def add_decayed_weights(opts) when is_list(opts) do
-    add_decayed_weights(identity(), opts)
-  end
-
-  def add_decayed_weights({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    add_decayed_weights(combinator, [])
-  end
-
-  def add_decayed_weights({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateless(combinator, fn updates, params ->
-      opts = Nx.Defn.Kernel.keyword!(opts, decay: 0.0)
-      # Decay can be a tensor, that's why we preprocess it before-hand
-      # and pass it as argument to defn instead of as an option.
-      apply_weight_decay(updates, params, opts[:decay])
-    end)
-  end
-
-  defnp apply_weight_decay(updates, params, decay) do
-    deep_merge(updates, params, fn g, p -> g + decay * p end)
-  end
-
-  @doc """
-  Scale by trust ratio.
-
-  ## Options
+  @deprecated "Use Polaris.Updates.add_decayed_weights/2 instead"
+  defdelegate add_decayed_weights(combinator, opts), to: Polaris.Updates

-    * `:min_norm` - Min norm to clip. Defaults to
-      `0.0`.
-
-    * `:trust_coefficient` - Trust coefficient. Defaults
-      to `1.0`.
-
-    * `:eps` - Numerical stability term. Defaults to `0.0`.
-  """
   @deprecated "Use Polaris.Updates.scale_by_trust_ratio/1 instead"
-  def scale_by_trust_ratio(combinator_or_opts \\ [])
-
-  def scale_by_trust_ratio(opts) when is_list(opts) do
-    scale_by_trust_ratio(identity(), opts)
-  end
-
-  def scale_by_trust_ratio({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_trust_ratio(combinator, [])
-  end
-
-  def scale_by_trust_ratio({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    stateless(combinator, fn update, params ->
-      opts = Nx.Defn.Kernel.keyword!(opts, min_norm: 0.0, trust_coefficient: 1.0, eps: 0.0)
-
-      apply_scale_by_trust_ratio(
-        update,
-        params,
-        opts[:min_norm],
-        opts[:trust_coefficient],
-        opts[:eps]
-      )
-    end)
-  end
+  defdelegate scale_by_trust_ratio(combinator_or_opts \\ []), to: Polaris.Updates

-  defnp apply_scale_by_trust_ratio(updates, params, min_norm, trust_coefficient, eps) do
-    deep_merge(updates, params, fn g, p ->
-      param_norm = safe_norm(p, min_norm)
-      update_norm = safe_norm(g, min_norm)
+  @deprecated "Use Polaris.Updates.scale_by_trust_ratio/2 instead"
+  defdelegate scale_by_trust_ratio(combinator, opts), to: Polaris.Updates

-      trust_ratio = trust_coefficient * param_norm / (update_norm + eps)
-
-      zero_norm = param_norm == 0.0 or update_norm == 0.0
-      safe_trust_ratio = Nx.select(zero_norm, 1, trust_ratio)
-      g * safe_trust_ratio
-    end)
-  end
-
-  @doc """
-  Adds random Gaussian noise to the input.
-
-  ## Options
-
-    * `:seed` - Random seed to use. Defaults to the
-      current system time.
-
-    * `:eta` - Controls amount of noise to add.
-      Defaults to `0.01`.
-
-    * `:gamma` - Controls amount of noise to add.
-      Defaults to `0.55`.
-  """
   @deprecated "Use Polaris.Updates.add_noise/1 instead"
-  def add_noise(combinator_or_opts \\ [])
-
-  def add_noise(opts) when is_list(opts) do
-    add_noise(identity(), opts)
-  end
+  defdelegate add_noise(combinator_or_opts \\ []), to: Polaris.Updates

-  def add_noise({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    add_noise(combinator, [])
-  end
-
-  def add_noise({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-    {seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
-    stateful(combinator, &init_add_noise(&1, seed: seed), &apply_add_noise(&1, &2, &3, opts))
-  end
-
-  defnp init_add_noise(_params, opts \\ []) do
-    %{count: Nx.tensor(0), key: Nx.Random.key(opts[:seed])}
-  end
+  @deprecated "Use Polaris.Updates.add_noise/2 instead"
+  defdelegate add_noise(combinator, opts), to: Polaris.Updates

-  defnp apply_add_noise(x, %{count: count, key: key}, _params, opts \\ []) do
-    opts = keyword!(opts, eta: 0.01, gamma: 0.55)
-    var = opts[:eta] / Nx.pow(count + 1, opts[:gamma])
-
-    {noise, key} =
-      deep_map_reduce(x, key, fn z, key ->
-        Nx.Random.normal(key, shape: Nx.shape(z), type: Nx.type(z))
-      end)
-
-    updates = deep_merge(x, noise, fn g, n -> g + var * n end)
-
-    {updates, %{count: count + 1, key: key}}
-  end
-
-  @doc """
-  Scale input according to the Yogi algorithm.
-
-  ## Options
-
-    * `:initial_accumulator_value` - Initial state accumulator value.
-
-    * `:b1` - first moment decay. Defaults to `0.9`
-
-    * `:b2` - second moment decay. Defaults to `0.999`
-
-    * `:eps` - numerical stability term. Defaults to `1.0e-8`
-
-    * `:eps_root` - numerical stability term. Defaults to `0.0`
-
-  ## References
-
-    * [Adaptive Methods for Nonconvex Optimization](https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf)
-  """
   @deprecated "Use Polaris.Updates.scale_by_yogi/1 instead"
-  def scale_by_yogi(combinator_or_opts \\ [])
-
-  def scale_by_yogi(opts) when is_list(opts) do
-    scale_by_yogi(identity(), opts)
-  end
-
-  def scale_by_yogi({init_fn, apply_fn} = combinator)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    scale_by_yogi(combinator, [])
-  end
-
-  def scale_by_yogi({init_fn, apply_fn} = combinator, opts)
-      when is_function(init_fn, 1) and is_function(apply_fn, 3) do
-    {initial, opts} = Keyword.pop(opts, :initial_accumulator_value, 1.0e-6)
-
-    stateful(
-      combinator,
-      &init_scale_by_yogi(&1, initial),
-      &apply_scale_by_yogi(&1, &2, &3, opts)
-    )
-  end
-
-  defnp init_scale_by_yogi(params, value) do
-    value = fulls_like(params, value, type: :f32)
-    mu = value
-    nu = value
-    count = Nx.tensor(0)
-    %{mu: mu, nu: nu, count: count}
-  end
-
-  defnp apply_scale_by_yogi(x, %{mu: mu, nu: nu, count: count}, _params, opts \\ []) do
-    opts = keyword!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-3, eps_root: 0.0)
-    b1 = opts[:b1]
-    b2 = opts[:b2]
-    eps = opts[:eps]
-    eps_root = opts[:eps_root]
-
-    mu = update_moment(x, mu, b1, 1)
-
-    nu =
-      deep_merge(x, nu, fn g, v ->
-        v - (1 - b2) * Nx.sign(v - Nx.pow(g, 2)) * Nx.pow(g, 2)
-      end)
-
-    mu_hat = bias_correction(mu, b1, count + 1)
-    nu_hat = bias_correction(nu, b2, count + 1)
+  defdelegate scale_by_yogi(combinator_or_opts \\ []), to: Polaris.Updates

-    updates = deep_merge(mu_hat, nu_hat, fn m, v -> m / (Nx.sqrt(v + eps_root) + eps) end)
+  @deprecated "Use Polaris.Updates.scale_by_yogi/2 instead"
+  defdelegate scale_by_yogi(combinator, opts), to: Polaris.Updates

-    {updates, %{mu: mu, nu: nu, count: count + 1}}
-  end
-
-  @doc """
-  Represents a stateless update.
-
-  Stateless updates do not depend on an update state and thus
-  only require an implementation of an update function.
-  """
   @deprecated "Use Polaris.Updates.stateless/2 instead"
-  def stateless({parent_init_fn, parent_apply_fn} \\ identity(), apply_fn) do
-    apply_fn = fn updates, state, params ->
-      {updates, state} = parent_apply_fn.(updates, state, params)
-      {apply_fn.(updates, params), state}
-    end
-
-    {parent_init_fn, apply_fn}
-  end
+  defdelegate stateless(parent_combinator \\ identity(), apply_fn), to: Polaris.Updates

-  @doc """
-  Returns the identity update.
+  @deprecated "Use Polaris.Updates.identity/0 instead"
+  defdelegate identity(), to: Polaris.Updates

-  This is often as the initial update in many functions in this module.
-  """
   @deprecated "Use Polaris.Updates.identity/1 instead"
-  def identity() do
-    {fn _params -> {} end, fn updates, state, _params -> {updates, state} end}
-  end
-
-  def identity(combinator) do
-    combinator
-  end
-
-  @doc """
-  Composes two updates. This is useful for extending optimizers
-  without having to reimplement them. For example, you can implement
-  gradient centralization:
+  defdelegate identity(combinator), to: Polaris.Updates

-      import Axon.Updates
-
-      Axon.Updates.compose(Axon.Updates.centralize(), Axon.Optimizers.rmsprop())
-
-  This is equivalent to:
-
-      Axon.Updates.centralize()
-      |> Axon.Updates.scale_by_rms()
-  """
   @deprecated "Use Polaris.Updates.compose/2 instead"
-  def compose({init_fn1, apply_fn1}, {init_fn2, apply_fn2}) do
-    init_fn = fn params ->
-      state = init_fn1.(params)
-      Tuple.insert_at(state, 0, init_fn2.(params))
-    end
-
-    apply_fn = fn updates, state, params ->
-      this_state = elem(state, 0)
-      other_state = Tuple.delete_at(state, 0)
-      {updates, new_other_state} = apply_fn1.(updates, other_state, params)
-      {updates, new_this_state} = apply_fn2.(updates, this_state, params)
-      {updates, Tuple.insert_at(new_other_state, 0, new_this_state)}
-    end
-
-    {init_fn, apply_fn}
-  end
+  defdelegate compose(combinator1, combinator2), to: Polaris.Updates

-  @doc """
-  Represents a stateful update.
-
-  Stateful updates require some update state, such as
-  momentum or RMS of previous updates. Therefore you must
-  implement some initialization function as well as an update
-  function.
-  """
   @deprecated "Use Polaris.Updates.stateful/3 instead"
-  def stateful({parent_init_fn, parent_apply_fn} \\ identity(), init_fn, apply_fn) do
-    init_fn = fn params ->
-      state = parent_init_fn.(params)
-      Tuple.insert_at(state, 0, init_fn.(params))
-    end
-
-    apply_fn = fn updates, state, params ->
-      this_state = elem(state, 0)
-      other_state = Tuple.delete_at(state, 0)
-      {updates, new_other_state} = parent_apply_fn.(updates, other_state, params)
-      {updates, new_this_state} = apply_fn.(updates, this_state, params)
-      {updates, Tuple.insert_at(new_other_state, 0, new_this_state)}
-    end
-
-    {init_fn, apply_fn}
-  end
+  defdelegate stateful(parent_combinator \\ identity(), init_fn, apply_fn), to: Polaris.Updates

-  @doc """
-  Applies updates to params and updates state parameters with
-  given state map.
-  """
+  @deprecated "Use Polaris.Updates.apply_updates/3 instead"
   defn apply_updates(params, updates, state \\ nil) do
-    new_params =
-      deep_merge(params, updates, fn x, u ->
-        Nx.add(x, Nx.as_type(u, Nx.type(x)))
-      end)
-
-    merge_state(new_params, state)
-  end
-
-  deftransformp merge_state(params, state) do
-    case {params, state} do
-      {params, nil} ->
-        params
-
-      {params, state} ->
-        merge_inner(params, state)
-    end
-  end
-
-  defp merge_inner(%Nx.Tensor{}, %Nx.Tensor{} = state) do
-    state
-  end
-
-  defp merge_inner(params, state) when is_map(params) and is_map(state) do
-    Map.merge(params, state, fn _, s1, s2 -> merge_inner(s1, s2) end)
-  end
-
-  ## Helpers
-
-  defnp update_moment(x, moment, decay, order) do
-    deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.pow(g, order) + decay * z end)
-  end
-
-  defnp bias_correction(moment, decay, count) do
-    deep_new(moment, fn z -> z / (1 - Nx.pow(decay, count)) end)
-  end
-
-  defnp safe_norm(g, min_norm) do
-    norm = Nx.LinAlg.norm(g)
-    z = Nx.select(Nx.less_equal(norm, min_norm), 1, g)
-    masked_norm = Nx.LinAlg.norm(z)
-    Nx.select(Nx.less_equal(norm, min_norm), min_norm, masked_norm)
+    Polaris.Updates.apply_updates(params, updates, state)
   end
 end
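Polaris keeps the `{init_fn, apply_fn}` combinator contract that the deleted code above implemented by hand, so update pipelines compose the same way after the migration. A brief sketch of that contract (parameter and gradient values are illustrative):

    # Build an optimizer by piping update transformations together:
    {init_fn, update_fn} =
      Polaris.Updates.clip_by_global_norm(max_norm: 1.0)
      |> Polaris.Updates.scale_by_adam()
      |> Polaris.Updates.scale(-1.0e-3)

    params = %{"w" => Nx.tensor([1.0, 2.0])}
    grads = %{"w" => Nx.tensor([0.1, 0.2])}

    # init_fn/1 builds the optimizer state; update_fn/3 transforms gradients:
    state = init_fn.(params)
    {updates, state} = update_fn.(grads, state, params)
    params = Polaris.Updates.apply_updates(params, updates)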
diff --git a/mix.exs b/mix.exs
index 663f1678..6a72dfb7 100644
--- a/mix.exs
+++ b/mix.exs
@@ -35,8 +35,8 @@ defmodule Axon.MixProject do
   # Run "mix help deps" to learn about dependencies.
   defp deps do
     [
-      {:exla, "~> 0.6.0", [only: :test] ++ exla_opts()},
-      {:torchx, "~> 0.6.0", [only: :test] ++ torchx_opts()},
+      {:exla, "~> 0.7.0", [only: :test] ++ exla_opts()},
+      {:torchx, "~> 0.7.0", [only: :test] ++ torchx_opts()},
       {:nx, "~> 0.6.0 or ~> 0.7.0", nx_opts()},
       {:ex_doc, "~> 0.23", only: :docs},
       {:table_rex, "~> 3.1.1", optional: true},

diff --git a/mix.lock b/mix.lock
index bc3be37b..a7ea4911 100644
--- a/mix.lock
+++ b/mix.lock
@@ -1,11 +1,9 @@
 %{
   "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
-  "dll_loader_helper": {:hex, :dll_loader_helper, "1.1.0", "e7d015e980942a0d67e306827ec907e7e853a21186bd92bb968d986698591a0f", [:mix], [{:dll_loader_helper_beam, "~> 1.1", [hex: :dll_loader_helper_beam, repo: "hexpm", optional: false]}], "hexpm", "2b6c11ee7bb48f6a132ce8f872202f9e828c019988da1e2d40ad41496195df0c"},
-  "dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.2.0", "557c43befb8e3b119b718da302adccde3bd855acdb999498a14a2a8d2814b8b9", [:rebar3], [], "hexpm", "a2115d4bf1cca488a7b33f3c648847f64019b32c0382d10286d84dd5c3cbc0e5"},
   "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
   "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"},
   "ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"},
-  "exla": {:hex, :exla, "0.6.4", "24a46884696c4904d7c8f87a41461a7460f5f118ca171062044d187b32fae279", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.4", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "09b3608b55941736f222388da7611f33fe4b0bb308119cdf2f32f50b924d0ad6"},
+  "exla": {:hex, :exla, "0.7.0", "27fac40a580f0d3816fe3bf35c50dfc2f99597d26ac7e2aca4a3c62b89bb427f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d3bfc622deb52cec95efc9d76063891afc7cd33e38eddbb01f3385c53e043c40"},
   "fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"},
   "kino": {:hex, :kino, "0.12.3", "a5f48a243c60a7ac18ba23869f697b1c775fc7794e8cd55dd248ba33c6fe9445", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "a6dfa3d54ba0edec9ca6e5940154916b381901001f171c85a2d8c67869dbc2d8"},
   "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.11", "d3c2a00b3685b95f91833920d06cc9b1fd7fb293a2663d89affe9aaec16a5b77", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "5ccd9148ce7cfcc95a137e12596cd8b95b371e9ea107e745bc262c39c5d8d48e"},
@@ -13,12 +11,12 @@
   "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
   "makeup_erlang": {:hex, :makeup_erlang, "0.1.4", "29563475afa9b8a2add1b7a9c8fb68d06ca7737648f28398e04461f008b69521", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f4ed47ecda66de70dd817698a703f8816daa91272e7e45812469498614ae8b29"},
   "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
-  "nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"},
+  "nx": {:hex, :nx, "0.7.0", "cec684cada356e9d268af01daa758882f7372aa952716dbe0369c657abb9e762", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "68edaa48a5841495ecab0dd4cf7b11b2fc0ad809754ae7f82d9c4090b91acf55"},
   "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
   "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
   "table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
   "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
-  "torchx": {:hex, :torchx, "0.6.4", "0251f153521db77fbe8ea786eafc594ee3e260539481d1b45680cc2061501f53", [:make, :mix], [{:dll_loader_helper, "~> 0.1 or ~> 1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "d161b34f04bf418d19e482c84dd1083561eb00538432a5d44f4e9916f2c54a3b"},
+  "torchx": {:hex, :torchx, "0.7.0", "c71fd603b0133ed8709450d82aa3434cbcf485a37c9a68e9ebcce86f5e4fb7f0", [:mix], [{:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "a324079c56bb67750b1da16f859d994982bb467020a8c2cba324639552f3adb8"},
   "vega_lite": {:hex, :vega_lite, "0.1.8", "7f6119126ecaf4bc2c1854084370d7091424f5cce4795fbac044eee9963f0752", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "6c8a9271f850612dd8a90de8d1ebd433590ed07ffef76fc2397c240dc04d3fdc"},
-  "xla": {:hex, :xla, "0.5.1", "8ba4c2c51c1a708ff54e9d4f88158c1a75b7f2cb3e5db02bf222b5b3852afffd", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "82a2490f6e9a76c8a29d1aedb47f07c59e3d5081095eac5a74db34d46c8212bc"},
+  "xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"},
 }

diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs
index 98a045ca..0e7b088a 100644
--- a/test/axon/compiler_test.exs
+++ b/test/axon/compiler_test.exs
@@ -5772,7 +5772,7 @@ defmodule CompilerTest do
       assert %{"custom" => %{"a" => a}} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

-      assert predict_fn.(params, x) == Nx.add(x, a)
+      assert_equal(predict_fn.(params, x), Nx.add(x, a))
     end

     test "supports composite/map parameter types" do
       assert %{"custom" => %{"composite" => %{"inner" => inner}}} =
                params = init_fn.(Nx.template({1, 1}, :f32), %{})

-      assert predict_fn.(params, x) == Nx.add(x, inner)
+      assert_equal(predict_fn.(params, x), Nx.add(x, inner))
     end

     test "inner params in composite parameters initialize to different values" do
       assert %{"custom" => %{"composite" => %{"inner_composite" => %{"a" => a}}}} =
                params = init_fn.(Nx.template({1, 1}, :f32), %{})

-      assert predict_fn.(params, x) == Nx.add(x, a)
+      assert_equal(predict_fn.(params, x), Nx.add(x, a))
     end
   end
with default options" do - optimizer = Axon.Optimizers.adagrad(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.adagrad(@learning_rate, eps: 1.0e-3) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.adagrad(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "adam" do - test "correctly optimizes simple loss with default options" do - optimizer = Axon.Optimizers.adam(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.adam(@learning_rate, b1: 0.95, b2: 0.99) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.adam(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "adamw" do - test "correctly optimizes simple loss with default options" do - optimizer = Axon.Optimizers.adamw(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.adamw(@learning_rate, decay: 0.9) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.adamw(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "lamb" do - test "correctly optimizes simple loss with default options" do - optimizer = Axon.Optimizers.lamb(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.lamb(@learning_rate, decay: 0.9, min_norm: 0.1) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.lamb(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} 
-> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "noisy_sgd" do - test "correctly optimizes simple loss with default options" do - optimizer = Axon.Optimizers.noisy_sgd(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.noisy_sgd(@learning_rate, eta: 0.2, gamma: 0.6) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.noisy_sgd(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "radam" do - test "correctly optimizes simple loss with default options" do - optimizer = Axon.Optimizers.radam(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.radam(@learning_rate, threshold: 2.0) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.radam(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "rmsprop" do - test "correctly optimizes simple loss default case" do - optimizer = Axon.Optimizers.rmsprop(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss centered case" do - optimizer = - Axon.Optimizers.rmsprop(@learning_rate, centered: true, initial_scale: 0.1, decay: 0.8) - - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss rms case" do - optimizer = Axon.Optimizers.rmsprop(@learning_rate, initial_scale: 0.1, decay: 0.8) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with momentum" do - optimizer = - Axon.Optimizers.rmsprop(@learning_rate, initial_scale: 0.1, decay: 0.8, momentum: 0.9) - - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.rmsprop(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> 
Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "sgd" do - test "correctly optimizes simple loss with default options" do - optimizer = Axon.Optimizers.sgd(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.sgd(@learning_rate, momentum: 0.9) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.sgd(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end - - describe "yogi" do - test "correctly optimizes simple loss with default options" do - optimizer = Axon.Optimizers.yogi(@learning_rate) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with custom options" do - optimizer = Axon.Optimizers.yogi(@learning_rate, initial_accumulator_value: 0.1) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor([1.0])} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - - test "correctly optimizes simple loss with schedule" do - optimizer = Axon.Optimizers.yogi(Axon.Schedules.constant(@learning_rate)) - loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end - num_steps = @iterations - x0 = %{"x0" => Nx.tensor(1.0)} - - check_optimizer!(optimizer, loss_fn, x0, num_steps) - end - end -end diff --git a/test/axon/schedules_test.exs b/test/axon/schedules_test.exs deleted file mode 100644 index a3b46bee..00000000 --- a/test/axon/schedules_test.exs +++ /dev/null @@ -1,230 +0,0 @@ -defmodule Axon.SchedulesTest do - use Axon.Case - doctest Axon.Schedules - - import Axon.Schedules - import Nx.Defn - - describe "exponential_decay" do - test "returns arity-1 function with defaults" do - fun = exponential_decay(1.0e-2) - assert is_function(fun, 1) - end - - test "returns arity-1 function with options" do - fun = exponential_decay(1.0e-3, decay_rate: 0.9) - assert is_function(fun, 1) - end - - test "can be called as anonymous function" do - fun = exponential_decay(1.0e-2) - assert_all_close(fun.(0), 1.0e-2) - - fun = exponential_decay(1.0e-3) - assert_all_close(fun.(0), 1.0e-3) - end - - test "can be called within JIT" do - fun = exponential_decay(1.0e-2) - assert_all_close(apply(jit(fun), [0]), 1.0e-2) - - fun = exponential_decay(1.0e-3) - assert_all_close(apply(jit(fun), [0]), 1.0e-3) - end - - test "matches optax values at different counts" do - fun1 = exponential_decay(1.0e-2, decay_rate: 0.9, transition_steps: 15) - - assert_all_close(fun1.(0), 1.0e-2) - assert_all_close(fun1.(25), 0.008389527) - assert_all_close(fun1.(50), 0.007038417) - assert_all_close(fun1.(1000), 8.902254e-06) - assert_all_close(fun1.(100_000), 0.0) - - fun2 = exponential_decay(1.0e-3, decay_rate: 0.99, transition_steps: 100) - - assert_all_close(fun2.(0), 1.0e-3) - assert_all_close(fun2.(25), 
0.0009974906) - assert_all_close(fun2.(50), 0.0009949874) - assert_all_close(fun2.(1000), 0.0009043822) - assert_all_close(fun2.(100_000), 4.3171664e-08) - - fun3 = - exponential_decay( - 1.0e-1, - decay_rate: 0.99, - transition_begin: 100, - transition_steps: 25 - ) - - assert_all_close(fun3.(0), 0.1) - assert_all_close(fun3.(25), 0.1) - assert_all_close(fun3.(50), 0.1) - assert_all_close(fun3.(1000), 0.069641344) - assert_all_close(fun3.(100_000), 3.6162157e-19) - end - end - - describe "cosine_decay" do - test "returns arity-1 function with defaults" do - fun = cosine_decay(1.0e-2) - assert is_function(fun, 1) - end - - test "returns arity-1 function with options" do - fun = cosine_decay(1.0e-3, decay_steps: 5) - assert is_function(fun, 1) - end - - test "can be called as anonymous function" do - fun = cosine_decay(1.0e-2) - assert_all_close(fun.(0), 1.0e-2) - - fun = cosine_decay(1.0e-3) - assert_all_close(fun.(0), 1.0e-3) - end - - test "can be called within JIT" do - fun = cosine_decay(1.0e-2) - assert_all_close(apply(jit(fun), [0]), 1.0e-2) - - fun = cosine_decay(1.0e-3) - assert_all_close(apply(jit(fun), [0]), 1.0e-3) - end - - test "matches optax values at different counts" do - fun1 = cosine_decay(1.0e-3, decay_steps: 10, alpha: 0.0) - - assert_all_close(fun1.(0), 0.001) - assert_all_close(fun1.(25), 0.0) - assert_all_close(fun1.(50), 0.00) - assert_all_close(fun1.(1000), 0.0) - assert_all_close(fun1.(100_000), 0.0) - - fun2 = cosine_decay(1.0e-2, decay_steps: 1000, alpha: 0.5) - - assert_all_close(fun2.(0), 0.01) - assert_all_close(fun2.(25), 0.009992293) - assert_all_close(fun2.(50), 0.0099692205) - assert_all_close(fun2.(1000), 0.005) - assert_all_close(fun2.(100_000), 0.005) - - fun3 = cosine_decay(1.0e-1, decay_steps: 1) - - assert_all_close(fun3.(0), 0.1) - assert_all_close(fun3.(25), 0.0) - assert_all_close(fun3.(50), 0.0) - assert_all_close(fun3.(1000), 0.0) - assert_all_close(fun3.(100_000), 0.0) - end - end - - describe "constant" do - test "returns arity-1 function with defaults" do - fun = constant(1.0e-2) - assert is_function(fun, 1) - end - - test "can be called as anonymous function" do - fun = constant(1.0e-2) - assert_all_close(fun.(0), 1.0e-2) - - fun = cosine_decay(1.0e-3) - assert_all_close(fun.(0), 1.0e-3) - end - - test "can be called within JIT" do - fun = constant(1.0e-2) - assert_all_close(apply(jit(fun), [0]), 1.0e-2) - - fun = constant(1.0e-3) - assert_all_close(apply(jit(fun), [0]), 1.0e-3) - end - - test "matches optax values at different counts" do - fun1 = constant(1.0e-3) - - assert_all_close(fun1.(0), 0.001) - assert_all_close(fun1.(25), 0.001) - assert_all_close(fun1.(50), 0.001) - assert_all_close(fun1.(1000), 0.001) - assert_all_close(fun1.(100_000), 0.001) - - fun2 = constant(1.0e-2) - - assert_all_close(fun2.(0), 0.01) - assert_all_close(fun2.(25), 0.01) - assert_all_close(fun2.(50), 0.01) - assert_all_close(fun2.(1000), 0.01) - assert_all_close(fun2.(100_000), 0.01) - - fun3 = constant(1.0e-1) - - assert_all_close(fun3.(0), 0.1) - assert_all_close(fun3.(25), 0.1) - assert_all_close(fun3.(50), 0.1) - assert_all_close(fun3.(1000), 0.1) - assert_all_close(fun3.(100_000), 0.1) - end - end - - describe "polynomial_decay" do - test "returns arity-1 function with defaults" do - fun = polynomial_decay(1.0e-2) - assert is_function(fun, 1) - end - - test "returns arity-1 function with options" do - fun = polynomial_decay(1.0e-3, end_value: 1.0e-4) - assert is_function(fun, 1) - end - - test "can be called as anonymous function" do - fun = 
polynomial_decay(1.0e-2) - assert_all_close(fun.(0), 1.0e-2) - - fun = polynomial_decay(1.0e-3) - assert_all_close(fun.(0), 1.0e-3) - end - - test "can be called within JIT" do - fun = polynomial_decay(1.0e-2) - assert_all_close(apply(jit(fun), [0]), 1.0e-2) - - fun = polynomial_decay(1.0e-3, end_value: 1.0e-4) - assert_all_close(apply(jit(fun), [0]), 1.0e-3) - end - - test "matches optax values at different counts" do - fun1 = polynomial_decay(1.0e-2, end_value: 1.0e-3, power: 3, transition_steps: 1000) - - assert_all_close(fun1.(0), 0.01) - assert_all_close(fun1.(25), 0.009341734) - assert_all_close(fun1.(50), 0.008716375) - assert_all_close(fun1.(1000), 0.001) - assert_all_close(fun1.(100_000), 0.001) - - fun2 = polynomial_decay(1.0e-3, end_value: 1.0e-4, transition_begin: 100, power: 2) - - assert_all_close(fun2.(0), 0.001) - assert_all_close(fun2.(25), 0.001) - assert_all_close(fun2.(50), 0.001) - assert_all_close(fun2.(1000), 0.0001) - assert_all_close(fun2.(100_000), 0.0001) - - fun3 = - polynomial_decay( - 1.0e-1, - end_value: 1.0e-3, - transition_steps: 10000, - power: 1.5 - ) - - assert_all_close(fun3.(0), 0.1) - assert_all_close(fun3.(25), 0.099628985) - assert_all_close(fun3.(50), 0.09925843) - assert_all_close(fun3.(1000), 0.08552768) - assert_all_close(fun3.(100_000), 0.001) - end - end -end diff --git a/test/axon/updates_test.exs b/test/axon/updates_test.exs deleted file mode 100644 index 3faca9a1..00000000 --- a/test/axon/updates_test.exs +++ /dev/null @@ -1,2243 +0,0 @@ -defmodule Axon.UpdatesTest do - use Axon.Case - doctest Axon.Updates - - import Axon.Updates - - describe "add_decayed_weights" do - test "constructs a stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = add_decayed_weights() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "constructs a stateless transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = add_decayed_weights(decay: 0.95) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - - assert {init_fn, update_fn} = - add_decayed_weights(decay: 0.95) |> add_decayed_weights(decay: 0.95) - - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> add_decayed_weights(decay: 0.95) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = add_decayed_weights(decay: 0.95) - params = %{a: Nx.tensor([0.18884168, 0.92323774, 0.4513516])} - updates = %{a: Nx.tensor([0.62370003, 0.86674502, 0.11204521])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.80309962, 1.74382088, 0.54082923]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert new_state == {} - assert_all_close(actual_a, expected_a) - end - - test "matches optax with nested container" do - assert 
{init_fn, update_fn} = add_decayed_weights(decay: 0.95) - - params = %{ - a: %{ - b: Nx.tensor([0.26106195, 0.52850289, 0.19788291]), - c: %{d: %{}, e: Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.83834362, 0.75873946, 0.54735649]), - c: %{d: %{}, e: Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([1.08635247, 1.26081721, 0.73534525]) - expected_e = Nx.tensor([[1.41295937, 1.15964536, 1.06867228]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = add_decayed_weights(decay: 0.95) - - params = { - { - Nx.tensor([0.26106195, 0.52850289, 0.19788291]), - {{}, Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} - } - } - - updates = { - { - Nx.tensor([0.83834362, 0.75873946, 0.54735649]), - {{}, Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([1.08635247, 1.26081721, 0.73534525]) - expected_e = Nx.tensor([[1.41295937, 1.15964536, 1.06867228]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - end - - describe "add_noise" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = add_noise() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {add_noise_state} = init_fn.(params) - assert %{count: count} = add_noise_state - assert_equal(count, Nx.tensor(0)) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = add_noise(gamma: 1.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {add_noise_state} = init_fn.(params) - assert %{count: count} = add_noise_state - assert_equal(count, Nx.tensor(0)) - end - end - - describe "clip" do - test "constructs a stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = clip() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "constructs a stateless transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = clip(delta: 1.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = clip(delta: 2.0) |> clip(delta: 2.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> clip(delta: 2.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end 
- - test "matches optax with simple container" do - assert {init_fn, update_fn} = clip(delta: 2.0) - params = %{a: Nx.tensor([0.74794595, 0.99105549, 0.5621627])} - updates = %{a: Nx.tensor([0.84208747, 0.69837738, 0.61840895])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.84208745, 0.6983774, 0.618409]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert new_state == {} - assert_all_close(actual_a, expected_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = clip(delta: 1.0) - - params = %{ - a: %{ - b: Nx.tensor([0.62866726, 0.04867021, 0.66160428]), - c: %{d: %{}, e: Nx.tensor([0.70566323, 0.52083707, 0.14541595])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.19084232, 0.09963277, 0.28141486]), - c: %{d: %{}, e: Nx.tensor([0.91124607, 0.2248316, 0.79530217])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.19084232, 0.09963277, 0.28141487]) - expected_e = Nx.tensor([0.91124606, 0.2248316, 0.79530215]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = clip(delta: 1.0) - - params = { - { - Nx.tensor([0.62866726, 0.04867021, 0.66160428]), - {{}, Nx.tensor([0.70566323, 0.52083707, 0.14541595])} - } - } - - updates = { - { - Nx.tensor([0.19084232, 0.09963277, 0.28141486]), - {{}, Nx.tensor([0.91124607, 0.2248316, 0.79530217])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.19084232, 0.09963277, 0.28141487]) - expected_e = Nx.tensor([0.91124606, 0.2248316, 0.79530215]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - end - - describe "clip_by_global_norm" do - test "constructs a stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = clip_by_global_norm() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "constructs a stateless transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = clip_by_global_norm(max_norm: 1.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - - assert {init_fn, update_fn} = - clip_by_global_norm(max_norm: 1.0) |> clip_by_global_norm(max_norm: 1.0) - - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> clip_by_global_norm(max_norm: 1.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = 
clip_by_global_norm(max_norm: 1.0) - params = %{a: Nx.tensor([0.72673265, 0.35788219, 0.75329067])} - updates = %{a: Nx.tensor([0.68235248, 0.56976359, 0.79599518])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.571844, 0.47748914, 0.667082]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert new_state == {} - assert_all_close(actual_a, expected_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = clip_by_global_norm(max_norm: 1.0) - - params = %{ - a: %{ - b: Nx.tensor([0.85107357, 0.67088125, 0.59811338]), - c: %{d: %{}, e: Nx.tensor([0.45385324, 0.05131562, 0.91526984])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.59629243, 0.86219328, 0.30155944]), - c: %{d: %{}, e: Nx.tensor([0.83792943, 0.22030587, 0.72606433])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.3795878, 0.54885495, 0.1919667]) - expected_e = Nx.tensor([0.53340906, 0.14024231, 0.462198]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = clip_by_global_norm(max_norm: 1.0) - - params = { - { - Nx.tensor([0.85107357, 0.67088125, 0.59811338]), - {{}, Nx.tensor([0.45385324, 0.05131562, 0.91526984])} - } - } - - updates = { - { - Nx.tensor([0.59629243, 0.86219328, 0.30155944]), - {{}, Nx.tensor([0.83792943, 0.22030587, 0.72606433])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.3795878, 0.54885495, 0.1919667]) - expected_e = Nx.tensor([0.53340906, 0.14024231, 0.462198]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - end - - describe "centralize" do - test "constructs a stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = centralize() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = centralize() |> centralize() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> centralize() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = centralize() - params = %{a: Nx.tensor([0.14574998, 0.53619206, 0.68726124])} - updates = %{a: Nx.tensor([0.05166196, 0.3979764, 0.84524461])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.05166196, 0.3979764, 0.84524461]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert new_state == {} - assert_all_close(actual_a, expected_a) - end - - test 
"matches optax with nested container" do - assert {init_fn, update_fn} = centralize() - - params = %{ - a: %{ - b: Nx.tensor([0.21855268, 0.21286796, 0.83114509]), - c: %{d: %{}, e: Nx.tensor([[0.26958357, 0.59519575, 0.87732692]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.41087112, 0.97778015, 0.51054674]), - c: %{d: %{}, e: Nx.tensor([[0.20577277, 0.95319838, 0.14168365]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.41087112, 0.97778015, 0.51054674]) - expected_e = Nx.tensor([[-0.22777883, 0.51964678, -0.29186795]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = centralize() - - params = { - { - Nx.tensor([0.21855268, 0.21286796, 0.83114509]), - {{}, Nx.tensor([[0.26958357, 0.59519575, 0.87732692]])} - } - } - - updates = { - { - Nx.tensor([0.41087112, 0.97778015, 0.51054674]), - {{}, Nx.tensor([[0.20577277, 0.95319838, 0.14168365]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.41087112, 0.97778015, 0.51054674]) - expected_e = Nx.tensor([[-0.22777883, 0.51964678, -0.29186795]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - end - - describe "identity" do - test "constructs a stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = identity() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = identity() |> identity() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> identity() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = identity() - params = %{a: Nx.tensor([0.18884168, 0.92323774, 0.4513516])} - updates = %{a: Nx.tensor([0.62370003, 0.86674502, 0.11204521])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.62370003, 0.86674502, 0.11204521]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert new_state == {} - assert_all_close(actual_a, expected_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = identity() - - params = %{ - a: %{ - b: Nx.tensor([0.26106195, 0.52850289, 0.19788291]), - c: %{d: %{}, e: Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.83834362, 0.75873946, 0.54735649]), - c: %{d: %{}, e: Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.83834362, 0.75873946, 
0.54735649]) - expected_e = Nx.tensor([[0.7384456, 0.76676084, 0.72992148]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = identity() - - params = { - { - Nx.tensor([0.26106195, 0.52850289, 0.19788291]), - {{}, Nx.tensor([[0.7100145, 0.41356265, 0.35657979]])} - } - } - - updates = { - { - Nx.tensor([0.83834362, 0.75873946, 0.54735649]), - {{}, Nx.tensor([[0.7384456, 0.76676084, 0.72992148]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.83834362, 0.75873946, 0.54735649]) - expected_e = Nx.tensor([[0.7384456, 0.76676084, 0.72992148]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - end - - describe "scale" do - test "constructs a stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale(1.0e-2) |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale(1.0e-2) - params = %{a: Nx.tensor([0.29887561, 0.70429164, 0.43314898])} - updates = %{a: Nx.tensor([0.2584395, 0.35890494, 0.84845509])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.00258439, 0.00358905, 0.00848455]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert new_state == {} - assert_all_close(actual_a, expected_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale(1.0e-2) - - params = %{ - a: %{ - b: Nx.tensor([0.58813851, 0.27981229, 0.17335737]), - c: %{d: %{}, e: Nx.tensor([0.21444265, 0.63923396, 0.12755156])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.48363215, 0.7147937, 0.32252682]), - c: %{d: %{}, e: Nx.tensor([0.09518468, 0.38613084, 0.20729078])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) - expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale(1.0e-2) - - params = { - { - Nx.tensor([0.58813851, 0.27981229, 
0.17335737]), - {{}, Nx.tensor([0.21444265, 0.63923396, 0.12755156])} - } - } - - updates = { - { - Nx.tensor([0.48363215, 0.7147937, 0.32252682]), - {{}, Nx.tensor([0.09518468, 0.38613084, 0.20729078])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) - expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - end - - describe "scale_by_state" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_state(1.0e-3) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {state} = init_fn.(params) - assert %{scale: scale} = state - assert_equal(scale, Nx.tensor(1.0e-3)) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_state(1.0e-3) |> scale_by_state(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {state_1, state_2} = init_fn.(params) - assert %{scale: scale_1} = state_1 - assert_equal(scale_1, Nx.tensor(1.0e-2)) - assert %{scale: scale_2} = state_2 - assert_equal(scale_2, Nx.tensor(1.0e-3)) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_state(1.0e-3) |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {state} = init_fn.(params) - assert %{scale: scale} = state - assert_equal(scale, Nx.tensor(1.0e-3)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_state(1.0e-2) - params = %{a: Nx.tensor([0.29887561, 0.70429164, 0.43314898])} - updates = %{a: Nx.tensor([0.2584395, 0.35890494, 0.84845509])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.00258439, 0.00358905, 0.00848455]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert {%{scale: scale}} = new_state - assert_all_close(actual_a, expected_a) - assert_all_close(scale, Nx.tensor(1.0e-2)) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_state(1.0e-2) - - params = %{ - a: %{ - b: Nx.tensor([0.58813851, 0.27981229, 0.17335737]), - c: %{d: %{}, e: Nx.tensor([0.21444265, 0.63923396, 0.12755156])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.48363215, 0.7147937, 0.32252682]), - c: %{d: %{}, e: Nx.tensor([0.09518468, 0.38613084, 0.20729078])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) - expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{scale: scale}} = new_state - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(scale, Nx.tensor(1.0e-2)) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_state(1.0e-2) - - params = { - { - Nx.tensor([0.58813851, 0.27981229, 0.17335737]), - {{}, Nx.tensor([0.21444265, 0.63923396, 0.12755156])} - } - } - - updates = { - { - Nx.tensor([0.48363215, 0.7147937, 0.32252682]), - {{}, Nx.tensor([0.09518468, 
0.38613084, 0.20729078])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.00483632, 0.00714794, 0.00322527]) - expected_e = Nx.tensor([0.00095185, 0.00386131, 0.00207291]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{scale: scale}} = new_state - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(scale, Nx.tensor(1.0e-2)) - end - end - - describe "scale_by_adam" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam(b1: 0.5) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> scale_by_adam() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state_1, adam_state_2} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_1 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_2 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_adam() - params = %{a: Nx.tensor([0.29236649, 0.26508023, 0.05959644])} - updates = %{a: Nx.tensor([0.01461005, 0.3796587, 0.76886989])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.99999267, 0.9999933, 0.9999933]) - expected_next_mu_a = Nx.tensor([0.00146101, 0.03796587, 0.07688699]) - expected_next_nu_a = Nx.tensor([2.1345357e-07, 1.4414072e-04, 5.9116090e-04]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - - assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = - new_state - - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_mu_a, expected_next_mu_a) - assert_all_close(actual_next_nu_a, expected_next_nu_a) - 
assert_equal(actual_next_count, expected_next_count) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_adam() - - params = %{ - a: %{ - b: Nx.tensor([0.16028131, 0.82155978, 0.67870557]), - c: %{d: %{}, e: Nx.tensor([[0.42164469, 0.59406027, 0.24703223]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.37850456, 0.80079877, 0.16309247]), - c: %{d: %{}, e: Nx.tensor([[0.29081831, 0.29872105, 0.48405271]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.9999934, 0.9999933, 0.99999315]) - expected_e = Nx.tensor([[0.9999933, 0.9999933, 0.9999933]]) - expected_next_mu_b = Nx.tensor([0.03785046, 0.08007988, 0.01630925]) - expected_next_mu_e = Nx.tensor([[0.02908183, 0.0298721, 0.04840527]]) - expected_next_nu_b = Nx.tensor([1.4326570e-04, 6.4127869e-04, 2.6599155e-05]) - expected_next_nu_e = Nx.tensor([[8.4575287e-05, 8.9234265e-05, 2.3430702e-04]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu - assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_adam() - - params = { - { - Nx.tensor([0.16028131, 0.82155978, 0.67870557]), - {{}, Nx.tensor([[0.42164469, 0.59406027, 0.24703223]])} - } - } - - updates = { - { - Nx.tensor([0.37850456, 0.80079877, 0.16309247]), - {{}, Nx.tensor([[0.29081831, 0.29872105, 0.48405271]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.9999934, 0.9999933, 0.99999315]) - expected_e = Nx.tensor([[0.9999933, 0.9999933, 0.9999933]]) - expected_next_mu_b = Nx.tensor([0.03785046, 0.08007988, 0.01630925]) - expected_next_mu_e = Nx.tensor([[0.02908183, 0.0298721, 0.04840527]]) - expected_next_nu_b = Nx.tensor([1.4326570e-04, 6.4127869e-04, 2.6599155e-05]) - expected_next_nu_e = Nx.tensor([[8.4575287e-05, 8.9234265e-05, 2.3430702e-04]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu - assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - end - - describe "scale_by_belief" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_belief() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {belief_state} = 
init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_belief(b1: 0.4) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {belief_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_belief() |> scale_by_belief() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {belief_state_1, belief_state_2} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state_1 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state_2 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_belief() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {belief_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = belief_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_belief() - params = %{a: Nx.tensor([0.35582285, 0.02904734, 0.8684706])} - updates = %{a: Nx.tensor([0.64641294, 0.19990149, 0.54263212])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.9999934, 0.99999326, 0.9999933]) - expected_next_mu_a = Nx.tensor([0.0646413, 0.01999015, 0.05426321]) - expected_next_nu_a = Nx.tensor([4.1784969e-04, 3.9960611e-05, 2.9444962e-04]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - - assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = - new_state - - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_mu_a, expected_next_mu_a) - assert_all_close(actual_next_nu_a, expected_next_nu_a) - assert_equal(actual_next_count, expected_next_count) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_belief() - - params = %{ - a: %{ - b: Nx.tensor([0.48266117, 0.21594939, 0.25310925]), - c: %{d: %{}, e: Nx.tensor([[0.08780911, 0.25273182, 0.02973737]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.15456417, 0.03338711, 0.47241908]), - c: %{d: %{}, e: Nx.tensor([[0.76352976, 0.86033023, 0.22758512]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.9999933, 0.9999933, 0.99999326]) - expected_e = Nx.tensor([[0.9999934, 0.99999326, 0.9999933]]) - expected_next_mu_b = Nx.tensor([0.01545642, 0.00333871, 0.04724191]) - expected_next_mu_e = Nx.tensor([[0.07635298, 0.08603302, 
0.02275851]]) - expected_next_nu_b = Nx.tensor([2.3890085e-05, 1.1146991e-06, 2.2317980e-04]) - expected_next_nu_e = Nx.tensor([[5.8297772e-04, 7.4016815e-04, 5.1794988e-05]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu - assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_belief() - - params = { - { - Nx.tensor([0.48266117, 0.21594939, 0.25310925]), - {{}, Nx.tensor([[0.08780911, 0.25273182, 0.02973737]])} - } - } - - updates = { - { - Nx.tensor([0.15456417, 0.03338711, 0.47241908]), - {{}, Nx.tensor([[0.76352976, 0.86033023, 0.22758512]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.9999933, 0.9999933, 0.99999326]) - expected_e = Nx.tensor([[0.9999934, 0.99999326, 0.9999933]]) - expected_next_mu_b = Nx.tensor([0.01545642, 0.00333871, 0.04724191]) - expected_next_mu_e = Nx.tensor([[0.07635298, 0.08603302, 0.02275851]]) - expected_next_nu_b = Nx.tensor([2.3890085e-05, 1.1146991e-06, 2.2317980e-04]) - expected_next_nu_e = Nx.tensor([[5.8297772e-04, 7.4016815e-04, 5.1794988e-05]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu - assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - end - - describe "scale_by_radam" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_radam() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_radam(b1: 0.5) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with itself" do - params = 
%{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_radam() |> scale_by_radam() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state_1, adam_state_2} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_1 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state_2 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_radam() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_radam() - params = %{a: Nx.tensor([0.71289699, 0.29554161, 0.50779425])} - updates = %{a: Nx.tensor([0.88675452, 0.21455035, 0.53581422])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.88675433, 0.2145503, 0.53581405]) - expected_next_mu_a = Nx.tensor([0.08867545, 0.02145503, 0.05358142]) - expected_next_nu_a = Nx.tensor([7.863336e-04, 4.603185e-05, 2.870969e-04]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - - assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = - new_state - - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_mu_a, expected_next_mu_a) - assert_all_close(actual_next_nu_a, expected_next_nu_a) - assert_equal(actual_next_count, expected_next_count) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_radam() - - params = %{ - a: %{ - b: Nx.tensor([0.72504156, 0.86982723, 0.58679938]), - c: %{d: %{}, e: Nx.tensor([[0.26001513, 0.62556789, 0.29528421]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.01536453, 0.61977439, 0.561842]), - c: %{d: %{}, e: Nx.tensor([[0.03755132, 0.80392208, 0.87391938]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.01536453, 0.6197742, 0.56184185]) - expected_e = Nx.tensor([[0.03755131, 0.8039219, 0.8739191]]) - expected_next_mu_b = Nx.tensor([0.00153645, 0.06197744, 0.0561842]) - expected_next_mu_e = Nx.tensor([[0.00375513, 0.0803922, 0.08739194]]) - expected_next_nu_b = Nx.tensor([2.3606893e-07, 3.8412030e-04, 3.1566643e-04]) - expected_next_nu_e = Nx.tensor([[1.4101014e-06, 6.4629072e-04, 7.6373509e-04]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu - assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, 
expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_radam() - - params = { - { - Nx.tensor([0.72504156, 0.86982723, 0.58679938]), - {{}, Nx.tensor([[0.26001513, 0.62556789, 0.29528421]])} - } - } - - updates = { - { - Nx.tensor([0.01536453, 0.61977439, 0.561842]), - {{}, Nx.tensor([[0.03755132, 0.80392208, 0.87391938]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.01536453, 0.6197742, 0.56184185]) - expected_e = Nx.tensor([[0.03755131, 0.8039219, 0.8739191]]) - expected_next_mu_b = Nx.tensor([0.00153645, 0.06197744, 0.0561842]) - expected_next_mu_e = Nx.tensor([[0.00375513, 0.0803922, 0.08739194]]) - expected_next_nu_b = Nx.tensor([2.3606893e-07, 3.8412030e-04, 3.1566643e-04]) - expected_next_nu_e = Nx.tensor([[1.4101014e-06, 6.4629072e-04, 7.6373509e-04]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu - assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - end - - describe "scale_by_rms" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rms() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rms_state} = init_fn.(params) - assert %{nu: %{a: nu_a}} = rms_state - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rms(initial_scale: 0.1) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rms_state} = init_fn.(params) - assert %{nu: %{a: nu_a}} = rms_state - assert_equal(nu_a, Nx.tensor([0.1, 0.1, 0.1])) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rms() |> scale_by_rms() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rms_state_1, rms_state_2} = init_fn.(params) - assert %{nu: %{a: nu_a}} = rms_state_1 - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert %{nu: %{a: nu_a}} = rms_state_2 - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rms() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rms_state} = init_fn.(params) - assert %{nu: %{a: nu_a}} = rms_state - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_rms() - params = %{a: Nx.tensor([0.77100057, 0.98078091, 0.78499164])} - updates = %{a: Nx.tensor([0.25156708, 0.30524656, 0.97350756])} - 
state = init_fn.(params) - - expected_a = Nx.tensor([3.162275, 3.162276, 3.1622777]) - expected_next_nu_a = Nx.tensor([0.0063286, 0.00931755, 0.0947717]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert {%{nu: %{a: actual_next_nu_a}}} = new_state - assert_all_close(actual_a, expected_a) - assert_all_close(expected_next_nu_a, actual_next_nu_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_rms() - - params = %{ - a: %{ - b: Nx.tensor([0.0553049, 0.21828064, 0.98751916]), - c: %{d: %{}, e: Nx.tensor([[0.17757973, 0.67966022, 0.19382288]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.61220327, 0.73535765, 0.42179138]), - c: %{d: %{}, e: Nx.tensor([[0.39331236, 0.27389305, 0.30131908]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([3.1622772, 3.1622772, 3.162277]) - expected_e = Nx.tensor([[3.1622767, 3.1622758, 3.162276]]) - expected_next_nu_b = Nx.tensor([0.03747929, 0.05407509, 0.0177908]) - expected_next_nu_e = Nx.tensor([[0.01546946, 0.00750174, 0.00907932]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{nu: new_nu}} = new_state - assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_rms() - - params = { - { - Nx.tensor([0.0553049, 0.21828064, 0.98751916]), - {{}, Nx.tensor([[0.17757973, 0.67966022, 0.19382288]])} - } - } - - updates = { - { - Nx.tensor([0.61220327, 0.73535765, 0.42179138]), - {{}, Nx.tensor([[0.39331236, 0.27389305, 0.30131908]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([3.1622772, 3.1622772, 3.162277]) - expected_e = Nx.tensor([[3.1622767, 3.1622758, 3.162276]]) - expected_next_nu_b = Nx.tensor([0.03747929, 0.05407509, 0.0177908]) - expected_next_nu_e = Nx.tensor([[0.01546946, 0.00750174, 0.00907932]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{nu: new_nu}} = new_state - assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - end - end - - describe "scale_by_rss" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rss() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rss_state} = init_fn.(params) - assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state - assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rss(initial_accumulator_value: 0.2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rss_state} = init_fn.(params) - assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state - assert_equal(sum_of_squares_a, Nx.tensor([0.2, 0.2, 0.2])) - end - - test "composes with itself" do - params = 
%{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rss() |> scale_by_rss() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rss_state_1, rss_state_2} = init_fn.(params) - assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state_1 - assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) - assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state_2 - assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_rss() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {rss_state} = init_fn.(params) - assert %{sum_of_squares: %{a: sum_of_squares_a}} = rss_state - assert_equal(sum_of_squares_a, Nx.tensor([0.1, 0.1, 0.1])) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_rss() - params = %{a: Nx.tensor([0.41327447, 0.06948837, 0.03234601])} - updates = %{a: Nx.tensor([0.2137085, 0.84399692, 0.63099467])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.55993116, 0.93642795, 0.89401275]) - expected_next_sum_of_squares_a = Nx.tensor([0.14567132, 0.81233084, 0.49815428]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert {%{sum_of_squares: %{a: actual_next_sum_of_squares_a}}} = new_state - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_sum_of_squares_a, expected_next_sum_of_squares_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_rss() - - params = %{ - a: %{ - b: Nx.tensor([0.92084601, 0.27218277, 0.56501597]), - c: %{d: %{}, e: Nx.tensor([[0.92937211, 0.44536295, 0.95296635]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.79292352, 0.11484326, 0.84693855]), - c: %{d: %{}, e: Nx.tensor([[0.13715272, 0.63276641, 0.5234425]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.92885643, 0.34135267, 0.9368279]) - expected_e = Nx.tensor([[0.39790204, 0.894515, 0.855929]]) - expected_next_sum_of_squares_b = Nx.tensor([0.72872776, 0.11318897, 0.8173049]) - expected_next_sum_of_squares_e = Nx.tensor([[0.11881087, 0.50039333, 0.37399206]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{sum_of_squares: new_sum_of_squares}} = new_state - - assert %{ - a: %{ - b: actual_next_sum_of_squares_b, - c: %{d: %{}, e: actual_next_sum_of_squares_e} - } - } = new_sum_of_squares - - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_sum_of_squares_b, expected_next_sum_of_squares_b) - assert_all_close(actual_next_sum_of_squares_e, expected_next_sum_of_squares_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_rss() - - params = { - { - Nx.tensor([0.92084601, 0.27218277, 0.56501597]), - {{}, Nx.tensor([[0.92937211, 0.44536295, 0.95296635]])} - } - } - - updates = { - { - Nx.tensor([0.79292352, 0.11484326, 0.84693855]), - {{}, Nx.tensor([[0.13715272, 0.63276641, 0.5234425]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.92885643, 0.34135267, 0.9368279]) - expected_e = Nx.tensor([[0.39790204, 0.894515, 0.855929]]) - expected_next_sum_of_squares_b = Nx.tensor([0.72872776, 0.11318897, 0.8173049]) - expected_next_sum_of_squares_e = 
Nx.tensor([[0.11881087, 0.50039333, 0.37399206]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{sum_of_squares: new_sum_of_squares}} = new_state - - assert { - { - actual_next_sum_of_squares_b, - {{}, actual_next_sum_of_squares_e} - } - } = new_sum_of_squares - - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_sum_of_squares_b, expected_next_sum_of_squares_b) - assert_all_close(actual_next_sum_of_squares_e, expected_next_sum_of_squares_e) - end - end - - describe "scale_by_schedule" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_schedule(Axon.Schedules.polynomial_decay(1.0e-2)) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {schedule_state} = init_fn.(params) - assert %{count: count} = schedule_state - assert_equal(count, Nx.tensor(0)) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - - assert {init_fn, update_fn} = - scale_by_schedule(Axon.Schedules.polynomial_decay(1.0e-2)) - |> scale_by_schedule(Axon.Schedules.polynomial_decay(1.0e-2)) - - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {schedule_state_2, schedule_state_1} = init_fn.(params) - assert %{count: count} = schedule_state_1 - assert_equal(count, Nx.tensor(0)) - assert %{count: count} = schedule_state_2 - assert_equal(count, Nx.tensor(0)) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - - assert {init_fn, update_fn} = - scale_by_schedule(Axon.Schedules.polynomial_decay(1.0e-2)) |> scale(1.0e-2) - - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {schedule_state} = init_fn.(params) - assert %{count: count} = schedule_state - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_schedule(Axon.Schedules.polynomial_decay(1.0e-2)) - params = %{a: Nx.tensor([0.77425031, 0.65418105, 0.86150202])} - updates = %{a: Nx.tensor([0.56082198, 0.94549107, 0.54412826])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.00560822, 0.00945491, 0.00544128]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert {%{count: actual_next_count}} = new_state - assert_all_close(actual_a, expected_a) - assert_equal(actual_next_count, expected_next_count) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_schedule(Axon.Schedules.polynomial_decay(1.0e-2)) - - params = %{ - a: %{ - b: Nx.tensor([0.3440084, 0.16096481, 0.43997161]), - c: %{d: %{}, e: Nx.tensor([[0.26168961, 0.40905451, 0.3061841]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.27159927, 0.37657519, 0.38219061]), - c: %{d: %{}, e: Nx.tensor([[0.9613661, 0.30215168, 0.24110271]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.00271599, 0.00376575, 0.00382191]) - expected_e = Nx.tensor([[0.00961366, 0.00302152, 0.00241103]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{count: actual_next_count}} = new_state - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, 
expected_e) - assert_equal(actual_next_count, expected_next_count) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_schedule(Axon.Schedules.polynomial_decay(1.0e-2)) - - params = { - { - Nx.tensor([0.3440084, 0.16096481, 0.43997161]), - {{}, Nx.tensor([[0.26168961, 0.40905451, 0.3061841]])} - } - } - - updates = { - { - Nx.tensor([0.27159927, 0.37657519, 0.38219061]), - {{}, Nx.tensor([[0.9613661, 0.30215168, 0.24110271]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.00271599, 0.00376575, 0.00382191]) - expected_e = Nx.tensor([[0.00961366, 0.00302152, 0.00241103]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{count: actual_next_count}} = new_state - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_equal(actual_next_count, expected_next_count) - end - end - - describe "scale_by_stddev" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_stddev() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {stddev_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_stddev(initial_scale: 0.5) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {stddev_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.5, 0.5, 0.5])) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - - assert {init_fn, update_fn} = - scale_by_stddev(initial_scale: 0.1) |> scale_by_stddev(initial_scale: 0.2) - - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {stddev_state_2, stddev_state_1} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state_1 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.1, 0.1, 0.1])) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state_2 - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.2, 0.2, 0.2])) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_stddev(initial_scale: 0.1) |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {stddev_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}} = stddev_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.1, 0.1, 0.1])) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_stddev() - params = %{a: Nx.tensor([0.98013234, 0.0653057, 0.39361905])} - updates = %{a: Nx.tensor([0.58050587, 0.04869076, 0.62340991])} - state = init_fn.(params) - - expected_a = Nx.tensor([3.3333325, 3.333255, 3.3333328]) - expected_next_mu_a = Nx.tensor([0.05805059, 0.00486908, 0.06234099]) - expected_next_nu_a = Nx.tensor([0.03369871, 0.00023708, 0.03886399]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = 
new_updates - assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}}} = new_state - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_mu_a, expected_next_mu_a) - assert_all_close(actual_next_nu_a, expected_next_nu_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_stddev() - - params = %{ - a: %{ - b: Nx.tensor([0.49792875, 0.04941673, 0.33815839]), - c: %{d: %{}, e: Nx.tensor([[0.70057761, 0.3689184, 0.36608007]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.54587409, 0.04849768, 0.23020724]), - c: %{d: %{}, e: Nx.tensor([[0.29348535, 0.79428645, 0.76129383]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([3.333333, 3.3332546, 3.33333]) - expected_e = Nx.tensor([[3.333331, 3.333333, 3.333333]]) - expected_next_mu_b = Nx.tensor([0.05458741, 0.00484977, 0.02302072]) - expected_next_mu_e = Nx.tensor([[0.02934854, 0.07942864, 0.07612938]]) - expected_next_nu_b = Nx.tensor([0.02979785, 0.0002352, 0.00529954]) - expected_next_nu_e = Nx.tensor([[0.00861336, 0.0630891, 0.05795683]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu}} = new_state - assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu - assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_stddev() - - params = { - { - Nx.tensor([0.49792875, 0.04941673, 0.33815839]), - {{}, Nx.tensor([[0.70057761, 0.3689184, 0.36608007]])} - } - } - - updates = { - { - Nx.tensor([0.54587409, 0.04849768, 0.23020724]), - {{}, Nx.tensor([[0.29348535, 0.79428645, 0.76129383]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([3.333333, 3.3332546, 3.33333]) - expected_e = Nx.tensor([[3.333331, 3.333333, 3.333333]]) - expected_next_mu_b = Nx.tensor([0.05458741, 0.00484977, 0.02302072]) - expected_next_mu_e = Nx.tensor([[0.02934854, 0.07942864, 0.07612938]]) - expected_next_nu_b = Nx.tensor([0.02979785, 0.0002352, 0.00529954]) - expected_next_nu_e = Nx.tensor([[0.00861336, 0.0630891, 0.05795683]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu}} = new_state - assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu - assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - end - end - - describe "scale_by_trust_ratio" do - test "constructs a stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_trust_ratio() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "constructs a stateless transformation with options" do - 
params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - - assert {init_fn, update_fn} = - scale_by_trust_ratio(min_norm: 1.0) |> scale_by_trust_ratio(min_norm: 1.0) - - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert init_fn.(params) == {} - end - - test "composes with stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_adam() |> scale_by_trust_ratio(min_norm: 1.0) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {adam_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = adam_state - assert_equal(mu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(nu_a, Nx.tensor([0.0, 0.0, 0.0])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) - params = %{a: Nx.tensor([0.07719177, 0.1812708, 0.94959977])} - updates = %{a: Nx.tensor([0.29626032, 0.328152, 0.20388144])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.29626033, 0.328152, 0.20388144]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert new_state == {} - assert_all_close(actual_a, expected_a) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) - - params = %{ - a: %{ - b: Nx.tensor([0.98282674, 0.34776357, 0.33319137]), - c: %{d: %{}, e: Nx.tensor([[0.95596768, 0.67948137, 0.05268411]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.53616958, 0.24854466, 0.26695091]), - c: %{d: %{}, e: Nx.tensor([[0.50354858, 0.91245821, 0.30518247]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.58683133, 0.27202922, 0.29217464]) - expected_e = Nx.tensor([[0.5443927, 0.98647004, 0.3299366]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_trust_ratio(min_norm: 1.0) - - params = { - { - Nx.tensor([0.98282674, 0.34776357, 0.33319137]), - {{}, Nx.tensor([[0.95596768, 0.67948137, 0.05268411]])} - } - } - - updates = { - { - Nx.tensor([0.53616958, 0.24854466, 0.26695091]), - {{}, Nx.tensor([[0.50354858, 0.91245821, 0.30518247]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.58683133, 0.27202922, 0.29217464]) - expected_e = Nx.tensor([[0.5443927, 0.98647004, 0.3299366]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert new_state == {} - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - end - end - - describe "scale_by_yogi" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_yogi() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {yogi_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state - assert_equal(mu_a, Nx.tensor([1.0e-6, 
1.0e-6, 1.0e-6])) - assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) - assert_equal(count, Nx.tensor(0)) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_yogi(initial_accumulator_value: 1.0e-4) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {yogi_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state - assert_equal(mu_a, Nx.tensor([1.0e-4, 1.0e-4, 1.0e-4])) - assert_equal(nu_a, Nx.tensor([1.0e-4, 1.0e-4, 1.0e-4])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_yogi() |> scale_by_yogi() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {yogi_state_1, yogi_state_2} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state_1 - assert_equal(mu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) - assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) - assert_equal(count, Nx.tensor(0)) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state_2 - assert_equal(mu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) - assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) - assert_equal(count, Nx.tensor(0)) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = scale_by_yogi() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {yogi_state} = init_fn.(params) - assert %{mu: %{a: mu_a}, nu: %{a: nu_a}, count: count} = yogi_state - assert_equal(mu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) - assert_equal(nu_a, Nx.tensor([1.0e-6, 1.0e-6, 1.0e-6])) - assert_equal(count, Nx.tensor(0)) - end - - test "matches optax with simple container" do - assert {init_fn, update_fn} = scale_by_yogi() - params = %{a: Nx.tensor([0.39152084, 0.86061072, 0.22693509])} - updates = %{a: Nx.tensor([0.10820288, 0.73034528, 0.6741126])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.95148116, 0.99770474, 0.9974302]) - expected_next_mu_a = Nx.tensor([0.01082119, 0.07303543, 0.06741216]) - expected_next_nu_a = Nx.tensor([1.2707865e-05, 5.3440424e-04, 4.5542780e-04]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - - assert {%{mu: %{a: actual_next_mu_a}, nu: %{a: actual_next_nu_a}, count: actual_next_count}} = - new_state - - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_mu_a, expected_next_mu_a) - assert_all_close(actual_next_nu_a, expected_next_nu_a) - assert_equal(actual_next_count, expected_next_count) - end - - test "matches optax with nested container" do - assert {init_fn, update_fn} = scale_by_yogi() - - params = %{ - a: %{ - b: Nx.tensor([0.87690482, 0.80993702, 0.87935556]), - c: %{d: %{}, e: Nx.tensor([[0.00528695, 0.06690531, 0.12589192]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.47019351, 0.72034131, 0.32043362]), - c: %{d: %{}, e: Nx.tensor([[0.84200356, 0.76360484, 0.55381714]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.99564576, 0.9976599, 0.99210596]) - expected_e = Nx.tensor([[0.9981149, 0.99784315, 0.9965868]]) - expected_next_mu_b = Nx.tensor([0.04702025, 0.07203503, 0.03204427]) - expected_next_mu_e = Nx.tensor([[0.08420125, 0.07636139, 0.05538262]]) - expected_next_nu_b = 
Nx.tensor([0.00022208, 0.00051989, 0.00010368]) - expected_next_nu_e = Nx.tensor([[0.00070997, 0.00058409, 0.00030771]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert %{a: %{b: actual_next_mu_b, c: %{d: %{}, e: actual_next_mu_e}}} = new_mu - assert %{a: %{b: actual_next_nu_b, c: %{d: %{}, e: actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - - test "supports generic container" do - assert {init_fn, update_fn} = scale_by_yogi() - - params = { - { - Nx.tensor([0.87690482, 0.80993702, 0.87935556]), - {{}, Nx.tensor([[0.00528695, 0.06690531, 0.12589192]])} - } - } - - updates = { - { - Nx.tensor([0.47019351, 0.72034131, 0.32043362]), - {{}, Nx.tensor([[0.84200356, 0.76360484, 0.55381714]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.99564576, 0.9976599, 0.99210596]) - expected_e = Nx.tensor([[0.9981149, 0.99784315, 0.9965868]]) - expected_next_mu_b = Nx.tensor([0.04702025, 0.07203503, 0.03204427]) - expected_next_mu_e = Nx.tensor([[0.08420125, 0.07636139, 0.05538262]]) - expected_next_nu_b = Nx.tensor([0.00022208, 0.00051989, 0.00010368]) - expected_next_nu_e = Nx.tensor([[0.00070997, 0.00058409, 0.00030771]]) - expected_next_count = Nx.tensor(1) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{mu: new_mu, nu: new_nu, count: actual_next_count}} = new_state - assert {{actual_next_mu_b, {{}, actual_next_mu_e}}} = new_mu - assert {{actual_next_nu_b, {{}, actual_next_nu_e}}} = new_nu - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_mu_b, expected_next_mu_b) - assert_all_close(actual_next_mu_e, expected_next_mu_e) - assert_all_close(actual_next_nu_b, expected_next_nu_b) - assert_all_close(actual_next_nu_e, expected_next_nu_e) - assert_equal(actual_next_count, expected_next_count) - end - end - - describe "trace" do - test "constructs a stateful transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = trace() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {trace_state} = init_fn.(params) - assert %{trace: %{a: trace_a}} = trace_state - assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "constructs a stateful transformation with options" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = trace(decay: 0.8) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {trace_state} = init_fn.(params) - assert %{trace: %{a: trace_a}} = trace_state - assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "composes with itself" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = trace() |> trace() - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {trace_state_2, trace_state_1} = init_fn.(params) - assert %{trace: %{a: trace_a}} = trace_state_1 - assert_equal(trace_a, Nx.tensor([0.0, 0.0, 
0.0])) - assert %{trace: %{a: trace_a}} = trace_state_2 - assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "composes with stateless transformation" do - params = %{a: Nx.tensor([1.0, 2.0, 3.0])} - assert {init_fn, update_fn} = trace() |> scale(1.0e-2) - assert is_function(init_fn, 1) - assert is_function(update_fn, 3) - assert {trace_state} = init_fn.(params) - assert %{trace: %{a: trace_a}} = trace_state - assert_equal(trace_a, Nx.tensor([0.0, 0.0, 0.0])) - end - - test "matches optax with simple container, nesterov: false" do - assert {init_fn, update_fn} = trace(nesterov: false) - params = %{a: Nx.tensor([0.54044065, 0.54168045, 0.14243068])} - updates = %{a: Nx.tensor([0.76976679, 0.19561062, 0.84724249])} - state = init_fn.(params) - - expected_a = Nx.tensor([0.7697668, 0.19561061, 0.8472425]) - expected_next_trace = Nx.tensor([0.7697668, 0.19561061, 0.8472425]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert {%{trace: %{a: actual_next_trace}}} = new_state - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_trace, expected_next_trace) - end - - test "matches optax with nested container, nesterov: false" do - assert {init_fn, update_fn} = trace(nesterov: false) - - params = %{ - a: %{ - b: Nx.tensor([0.23468207, 0.75940123, 0.06601013]), - c: %{d: %{}, e: Nx.tensor([[0.68877159, 0.84383744, 0.15230977]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.60272336, 0.42772071, 0.39653623]), - c: %{d: %{}, e: Nx.tensor([[0.25453278, 0.64759897, 0.71080799]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.60272336, 0.4277207, 0.39653623]) - expected_e = Nx.tensor([[0.25453278, 0.647599, 0.710808]]) - expected_next_trace_b = Nx.tensor([0.60272336, 0.4277207, 0.39653623]) - expected_next_trace_e = Nx.tensor([[0.25453278, 0.647599, 0.710808]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{trace: new_trace}} = new_state - assert %{a: %{b: actual_next_trace_b, c: %{d: %{}, e: actual_next_trace_e}}} = new_trace - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_trace_b, expected_next_trace_b) - assert_all_close(actual_next_trace_e, expected_next_trace_e) - end - - test "matches optax with simple container, nesterov: true" do - assert {init_fn, update_fn} = trace(nesterov: true) - params = %{a: Nx.tensor([0.05727068, 0.71336316, 0.52111667])} - updates = %{a: Nx.tensor([0.99510349, 0.38321624, 0.37485662])} - state = init_fn.(params) - - expected_a = Nx.tensor([1.8906965, 0.7281108, 0.7122276]) - expected_next_trace = Nx.tensor([0.9951035, 0.38321623, 0.37485662]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: actual_a} = new_updates - assert {%{trace: %{a: actual_next_trace}}} = new_state - assert_all_close(actual_a, expected_a) - assert_all_close(actual_next_trace, expected_next_trace) - end - - test "matches optax with nested container, nesterov: true" do - assert {init_fn, update_fn} = trace(nesterov: true) - - params = %{ - a: %{ - b: Nx.tensor([0.81068757, 0.89196671, 0.21672469]), - c: %{d: %{}, e: Nx.tensor([[0.9194404, 0.19829658, 0.96960522]])} - } - } - - updates = %{ - a: %{ - b: Nx.tensor([0.21182614, 0.29456406, 0.50427876]), - c: %{d: %{}, e: Nx.tensor([[0.26525984, 0.66349034, 0.11212149]])} - } - } - - state = init_fn.(params) - - expected_b = 
Nx.tensor([0.40246966, 0.55967176, 0.95812964]) - expected_e = Nx.tensor([[0.5039937, 1.2606317, 0.21303083]]) - expected_next_trace_b = Nx.tensor([0.21182615, 0.29456407, 0.5042788]) - expected_next_trace_e = Nx.tensor([[0.26525983, 0.66349036, 0.11212149]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert %{a: %{b: actual_b, c: %{d: %{}, e: actual_e}}} = new_updates - assert {%{trace: new_trace}} = new_state - assert %{a: %{b: actual_next_trace_b, c: %{d: %{}, e: actual_next_trace_e}}} = new_trace - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_trace_b, expected_next_trace_b) - assert_all_close(actual_next_trace_e, expected_next_trace_e) - end - - test "supports generic container" do - assert {init_fn, update_fn} = trace(nesterov: true) - - params = { - { - Nx.tensor([0.81068757, 0.89196671, 0.21672469]), - {{}, Nx.tensor([[0.9194404, 0.19829658, 0.96960522]])} - } - } - - updates = { - { - Nx.tensor([0.21182614, 0.29456406, 0.50427876]), - {{}, Nx.tensor([[0.26525984, 0.66349034, 0.11212149]])} - } - } - - state = init_fn.(params) - - expected_b = Nx.tensor([0.40246966, 0.55967176, 0.95812964]) - expected_e = Nx.tensor([[0.5039937, 1.2606317, 0.21303083]]) - expected_next_trace_b = Nx.tensor([0.21182615, 0.29456407, 0.5042788]) - expected_next_trace_e = Nx.tensor([[0.26525983, 0.66349036, 0.11212149]]) - - assert {new_updates, new_state} = update_fn.(updates, state, params) - assert {{actual_b, {{}, actual_e}}} = new_updates - assert {%{trace: new_trace}} = new_state - assert {{actual_next_trace_b, {{}, actual_next_trace_e}}} = new_trace - assert_all_close(actual_b, expected_b) - assert_all_close(actual_e, expected_e) - assert_all_close(actual_next_trace_b, expected_next_trace_b) - assert_all_close(actual_next_trace_e, expected_next_trace_e) - end - end -end