From ff6e845b854d48e4356cd41d252bd72cd916037c Mon Sep 17 00:00:00 2001 From: Nicholas Date: Fri, 6 Jan 2023 17:05:04 -0500 Subject: [PATCH 1/3] Update dependencies to latest versions --- examples/generative/text_generator.exs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/generative/text_generator.exs b/examples/generative/text_generator.exs index 97a2f0994..4616455b7 100644 --- a/examples/generative/text_generator.exs +++ b/examples/generative/text_generator.exs @@ -1,9 +1,9 @@ # Based on https://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/ Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.2.1"}, - {:exla, "~> 0.2.2"}, - {:req, "~> 0.3.0"} + {:axon, "~> 0.3.0"}, + {:nx, "~> 0.4.1"}, + {:exla, "~> 0.4.1"}, + {:req, "~> 0.3.3"} ]) EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) From 02c4097f2f69ed29e14a947e41166a5a5a95466a Mon Sep 17 00:00:00 2001 From: Nicholas Date: Fri, 6 Jan 2023 17:07:52 -0500 Subject: [PATCH 2/3] The output sequence is now the first element of the return value of Axon.lstm/2 --- examples/generative/text_generator.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/generative/text_generator.exs b/examples/generative/text_generator.exs index 4616455b7..a0e0dd51c 100644 --- a/examples/generative/text_generator.exs +++ b/examples/generative/text_generator.exs @@ -18,7 +18,7 @@ defmodule TextGenerator do def build_model(characters_count) do Axon.input("input_chars", shape: {nil, @sequence_length, 1}) |> Axon.lstm(256) - |> then(fn {_, out} -> out end) + |> then(fn {out, _} -> out end) |> Axon.nx(fn t -> t[[0..-1//1, -1]] end) |> Axon.dropout(rate: 0.2) |> Axon.dense(characters_count, activation: :softmax) From 22c82d5535dba38f919c8c40cc3b7da8a2184e05 Mon Sep 17 00:00:00 2001 From: Nicholas Date: Fri, 6 Jan 2023 17:14:31 -0500 Subject: [PATCH 3/3] Update deprecated Nx function: to_batched_list/3 -> to_batched/3 --- examples/generative/text_generator.exs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/generative/text_generator.exs b/examples/generative/text_generator.exs index a0e0dd51c..98c193c5e 100644 --- a/examples/generative/text_generator.exs +++ b/examples/generative/text_generator.exs @@ -59,7 +59,7 @@ defmodule TextGenerator do |> Nx.tensor() |> Nx.divide(characters_count) |> Nx.reshape({:auto, @sequence_length, 1}) - |> Nx.to_batched_list(@batch_size) + |> Nx.to_batched(@batch_size) train_labels = text @@ -68,7 +68,7 @@ defmodule TextGenerator do |> Nx.tensor() |> Nx.reshape({:auto, 1}) |> Nx.equal(Nx.iota({characters_count})) - |> Nx.to_batched_list(@batch_size) + |> Nx.to_batched(@batch_size) {train_data, train_labels} end