diff --git a/examples/generative/text_generator.exs b/examples/generative/text_generator.exs index 97a2f0994..98c193c5e 100644 --- a/examples/generative/text_generator.exs +++ b/examples/generative/text_generator.exs @@ -1,9 +1,9 @@ # Based on https://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/ Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.2.1"}, - {:exla, "~> 0.2.2"}, - {:req, "~> 0.3.0"} + {:axon, "~> 0.3.0"}, + {:nx, "~> 0.4.1"}, + {:exla, "~> 0.4.1"}, + {:req, "~> 0.3.3"} ]) EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) @@ -18,7 +18,7 @@ defmodule TextGenerator do def build_model(characters_count) do Axon.input("input_chars", shape: {nil, @sequence_length, 1}) |> Axon.lstm(256) - |> then(fn {_, out} -> out end) + |> then(fn {out, _} -> out end) |> Axon.nx(fn t -> t[[0..-1//1, -1]] end) |> Axon.dropout(rate: 0.2) |> Axon.dense(characters_count, activation: :softmax) @@ -59,7 +59,7 @@ defmodule TextGenerator do |> Nx.tensor() |> Nx.divide(characters_count) |> Nx.reshape({:auto, @sequence_length, 1}) - |> Nx.to_batched_list(@batch_size) + |> Nx.to_batched(@batch_size) train_labels = text @@ -68,7 +68,7 @@ defmodule TextGenerator do |> Nx.tensor() |> Nx.reshape({:auto, 1}) |> Nx.equal(Nx.iota({characters_count})) - |> Nx.to_batched_list(@batch_size) + |> Nx.to_batched(@batch_size) {train_data, train_labels} end