diff --git a/examples/generative/fashionmnist_autoencoder.exs b/examples/generative/fashionmnist_autoencoder.exs index d6db4cab8..9923dc809 100644 --- a/examples/generative/fashionmnist_autoencoder.exs +++ b/examples/generative/fashionmnist_autoencoder.exs @@ -1,16 +1,15 @@ Mix.install([ - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"}, - {:scidata, "~> 0.1.6"} + {:axon, "~> 0.3.0"}, + {:exla, "~> 0.4.1"}, + {:nx, "~> 0.4.1"}, + {:scidata, "~> 0.1.9"} ]) # Configure default platform with accelerator precedence as tpu > cuda > rocm > host EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) -defmodule Fashionmist do +defmodule FashionMNIST do require Axon - alias Axon.Loop.State defmodule Autoencoder do defp encoder(x, latent_dim) do @@ -22,7 +21,7 @@ defmodule Fashionmist do defp decoder(x) do x |> Axon.dense(784, activation: :sigmoid) - |> Axon.reshape({1, 28, 28}) + |> Axon.reshape({:batch, 1, 28, 28}) end def build_model(input_shape, latent_dim) do @@ -37,7 +36,7 @@ defmodule Fashionmist do |> Nx.from_binary(type) |> Nx.reshape({elem(shape, 0), 1, 28, 28}) |> Nx.divide(255.0) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) end defp train_model(model, train_images, epochs) do @@ -58,7 +57,7 @@ defmodule Fashionmist do sample_image = train_images - |> hd() + |> Enum.fetch!(0) |> Nx.slice_along_axis(0, 1) |> Nx.reshape({1, 1, 28, 28}) @@ -71,4 +70,4 @@ defmodule Fashionmist do end end -Fashionmist.run() +FashionMNIST.run()