From 9185dd347e025620e09a0cb539f173470a2184ea Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Thu, 19 Oct 2023 04:43:18 -0700 Subject: [PATCH 01/15] Add mistral --- lib/bumblebee.ex | 1 + lib/bumblebee/layers.ex | 10 + lib/bumblebee/layers/transformer.ex | 38 ++- lib/bumblebee/text/mistral.ex | 419 +++++++++++++++++++++++++++ test/bumblebee/text/mistral_test.exs | 34 +++ 5 files changed, 498 insertions(+), 4 deletions(-) create mode 100644 lib/bumblebee/text/mistral.ex create mode 100644 test/bumblebee/text/mistral_test.exs diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 72dc24da..b80204c8 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -148,6 +148,7 @@ defmodule Bumblebee do "MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering}, "MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification}, "MBartModel" => {Bumblebee.Text.Mbart, :base}, + "MistralModel" => {Bumblebee.Text.Mistral, :base}, "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification}, "ResNetModel" => {Bumblebee.Vision.ResNet, :base}, "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling}, diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index aa5d2051..33bfe611 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1012,4 +1012,14 @@ defmodule Bumblebee.Layers do x2 = x[[.., .., .., size..-1//1]] Nx.concatenate([-x2, x1], axis: -1) end + + @doc """ + Adds a repeat layer to the network. + """ + def repeat_interleave(x, opts \\ []) do + opts = Keyword.validate!(opts, [:name, :repeats]) + Axon.layer(fn x, opts -> + Bumblebee.Utils.Nx.repeat_interleave(x, opts[:repeats], axis: 1) + end, [x], opts) + end end diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 58480b05..4781c30b 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -42,6 +42,7 @@ defmodule Bumblebee.Layers.Transformer do block_opts_keys = [ :num_attention_heads, + :num_key_value_heads, :causal?, :hidden_size, :ffn, @@ -298,6 +299,7 @@ defmodule Bumblebee.Layers.Transformer do :num_attention_heads, :hidden_size, :ffn, + :num_key_value_heads, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: Layers.none(), @@ -323,6 +325,7 @@ defmodule Bumblebee.Layers.Transformer do name = opts[:name] num_attention_heads = opts[:num_attention_heads] + num_key_value_heads = opts[:num_key_value_heads] || num_attention_heads hidden_size = opts[:hidden_size] ffn = opts[:ffn] causal? = opts[:causal?] 
@@ -392,6 +395,7 @@ defmodule Bumblebee.Layers.Transformer do offset: offset, causal?: causal?, num_heads: num_attention_heads, + num_key_value_heads: num_key_value_heads, hidden_size: hidden_size, kernel_initializer: kernel_initializer, attention_head_size: attention_head_size, @@ -435,6 +439,7 @@ defmodule Bumblebee.Layers.Transformer do attention_cache: cross_attention_cache, offset: offset, num_heads: num_attention_heads, + num_key_value_heads: num_key_value_heads, hidden_size: hidden_size, kernel_initializer: kernel_initializer, attention_head_size: attention_head_size, @@ -716,6 +721,7 @@ defmodule Bumblebee.Layers.Transformer do :name, :num_heads, :hidden_size, + :num_key_value_heads, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: Layers.none(), @@ -740,6 +746,7 @@ defmodule Bumblebee.Layers.Transformer do name = opts[:name] num_heads = opts[:num_heads] + num_key_value_heads = opts[:num_key_value_heads] hidden_size = opts[:hidden_size] kernel_initializer = opts[:kernel_initializer] causal? = opts[:causal?] @@ -761,6 +768,13 @@ defmodule Bumblebee.Layers.Transformer do hidden_size end + inner_kv_size = + if num_heads == num_key_value_heads do + inner_size + else + div(hidden_size, num_heads) * num_key_value_heads + end + head_size = div(hidden_size, num_heads) query = @@ -774,21 +788,21 @@ defmodule Bumblebee.Layers.Transformer do key = key - |> Axon.dense(inner_size, + |> Axon.dense(inner_kv_size, kernel_initializer: kernel_initializer, name: join(name, "key"), use_bias: key_use_bias ) - |> Layers.split_heads(num_heads) + |> Layers.split_heads(num_key_value_heads) value = value - |> Axon.dense(inner_size, + |> Axon.dense(inner_kv_size, kernel_initializer: kernel_initializer, name: join(name, "value"), use_bias: value_use_bias ) - |> Layers.split_heads(num_heads) + |> Layers.split_heads(num_key_value_heads) {query, key} = case rotary_embedding do @@ -825,6 +839,14 @@ defmodule Bumblebee.Layers.Transformer do {query, key} end + {key, value} = + if num_key_value_heads == num_heads do + {key, value} + else + num_key_value_groups = div(num_heads, num_key_value_heads) + {repeat_states(key, num_key_value_groups), repeat_states(value, num_key_value_groups)} + end + {key, value, attention_cache} = Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset) @@ -882,6 +904,14 @@ defmodule Bumblebee.Layers.Transformer do {attention_output, attention_weights, attention_cache, attention_relative_bias} end + defp repeat_states(state, repeats) do + if repeats == 1 do + state + else + Layers.repeat_kv(state, repeats: repeats) + end + end + defp validate_required_keys!(opts, keys) do case keys -- Keyword.keys(opts) do [] -> :ok diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex new file mode 100644 index 00000000..19549300 --- /dev/null +++ b/lib/bumblebee/text/mistral.ex @@ -0,0 +1,419 @@ +defmodule Bumblebee.Text.Mistral do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 32000, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 131072, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. 
Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + hidden_size: [ + default: 4096, + doc: "the dimensionality of hidden layers" + ], + intermediate_size: [ + default: 14336, + doc: "the dimensionality of intermediate layers" + ], + num_blocks: [ + default: 32, + doc: "the number of Transformer blocks in the model" + ], + num_attention_heads: [ + default: 32, + doc: "the number of attention heads for each attention layer in the model" + ], + num_key_value_heads: [ + default: 8, + doc: """ + the number of key-value heads used to implement Grouped Query Attention. If + this value is set to the same as the number of attention heads, it will use + regular MHA. If it's set to 1, it will use MQA, otherwise it uses Grouped Query + Attention + """ + ] + activation: [ + default: :silu, + doc: "the activation function" + ], + layer_norm_epsilon: [ + default: 1.0e-12, + doc: "the epsilon used by RMS normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + rope_theta: [ + default: 10_000.0, + doc: "base period of RoPE embeddings" + ], + sliding_window: [ + default: 4096, + doc: "sliding window attention size" + ] + ] ++ + Shared.common_options([ + :output_hidden_states, + :output_attentions, + :num_labels, + :id_to_label + ]) ++ Shared.token_options(pad_token_id: 0) + + @moduledoc """ + Mistral model family. + + ## Architectures + + * `:base` - plain Mistral without any head on top + + * `:for_causal_language_modeling` - Mistral with a language modeling + head. The head returns logits for each token in the original + sequence + + * `:for_sequence_classification` - Mistral with a sequence + classification head. The head returns logits corresponding to + possible classes + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"attention_head_mask"` - `{encoder_num_blocks, encoder_num_attention_heads}` + + Mask to nullify selected heads of the self-attention blocks in + the encoder. + + * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}` + + Embedded representation of `"input_ids"`, which can be specified + for more control over how `"input_ids"` are embedded than the + model's internal embedding lookup. If `"input_embeddings"` are present, + then `"input_ids"` will be ignored. + + * `"cache"` + + A container with cached layer results used to speed up sequential + decoding (autoregression). With cache, certain hidden states are + taken from the cache, rather than recomputed on every decoding + pass. The cache should be treated as opaque and initialized with + `Bumblebee.Text.Generation.init_cache/4`. 
+ + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + @impl true + def config(spec, opts \\ []) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output" + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + 
epsilon: spec.layer_norm_epsilon + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + # TODO: Axon needs a way to specify ignoring pad tokens + # in gradient + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: :norm_first, + causal?: true, + rotary_embedding: [position_ids: position_ids, max_positions: spec.max_positions], + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + output_hidden_states: spec.output_hidden_states, + output_attentions: spec.output_attentions, + name: join(name, "blocks") + ) + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + # TODO: Tie lm-head to word embedding as a spec option + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", atom()}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.token_embedding" => "model.embed_tokens", + "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", + 
"decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.self_attention.rotary_embedding" => + "model.layers.{n}.self_attn.rotary_emb", + "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", + "output_norm" => "model.norm", + "language_modeling_head.output" => "lm_head", + "sequence_classification_head.output" => "score" + } + end + end +end diff --git a/test/bumblebee/text/mistral_test.exs b/test/bumblebee/text/mistral_test.exs new file mode 100644 index 00000000..4fd0b898 --- /dev/null +++ b/test/bumblebee/text/mistral_test.exs @@ -0,0 +1,34 @@ +defmodule Bumblebee.Text.MistralTest do + use ExUnit.Case, async: false + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + describe "integration" do + test "base model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "seanmor5/tiny-llama-test"}, architecture: :base) + + assert %Bumblebee.Text.Llama{architecture: :base} = spec + + input_ids = Nx.tensor([[1, 15043, 3186, 825, 29915, 29879, 701]]) + + inputs = %{ + "input_ids" => input_ids + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 7, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-0.4411, -1.9037, 0.9454], [0.8148, -1.4606, 0.0076], [0.9480, 0.6038, 0.1649]] + ]), + atol: 1.0e-2 + ) + end + end +end \ No newline at end of file From 27100b549dcf1b1bbaa4929d73e77088262786c5 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Thu, 19 Oct 2023 04:53:17 -0700 Subject: [PATCH 02/15] Pass mistral test --- lib/bumblebee.ex | 1 + lib/bumblebee/layers.ex | 11 ++++++++--- lib/bumblebee/layers/transformer.ex | 2 +- lib/bumblebee/shared.ex | 2 +- lib/bumblebee/text/mistral.ex | 5 +++-- lib/bumblebee/utils/tokenizers.ex | 15 +++++++-------- test/bumblebee/text/bert_tokenizer_test.exs | 3 +-- .../text/generation/logits_processing_test.exs | 4 +--- test/bumblebee/text/generation_test.exs | 15 +++++---------- test/bumblebee/text/mistral_test.exs | 16 ++++++++++------ 10 files changed, 38 insertions(+), 36 deletions(-) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index b80204c8..a3d23166 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -149,6 +149,7 @@ defmodule Bumblebee do "MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification}, "MBartModel" => {Bumblebee.Text.Mbart, :base}, "MistralModel" => {Bumblebee.Text.Mistral, :base}, + "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling}, "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification}, "ResNetModel" => {Bumblebee.Vision.ResNet, :base}, "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling}, diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 33bfe611..4eea647c 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1018,8 +1018,13 @@ defmodule Bumblebee.Layers do """ def repeat_interleave(x, opts \\ []) do opts = Keyword.validate!(opts, [:name, :repeats]) - Axon.layer(fn x, opts -> - 
Bumblebee.Utils.Nx.repeat_interleave(x, opts[:repeats], axis: 1) - end, [x], opts) + + Axon.layer( + fn x, opts -> + Bumblebee.Utils.Nx.repeat_interleave(x, opts[:repeats], axis: 2) + end, + [x], + opts + ) end end diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 4781c30b..a020a539 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -908,7 +908,7 @@ defmodule Bumblebee.Layers.Transformer do if repeats == 1 do state else - Layers.repeat_kv(state, repeats: repeats) + Layers.repeat_interleave(state, repeats: repeats) end end diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index dbd72869..0412bf51 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -288,7 +288,7 @@ defmodule Bumblebee.Shared do function(), keyword(), boolean(), - (-> list(Nx.Tensor.t())) + (() -> list(Nx.Tensor.t())) ) :: function() def compile_or_jit(fun, defn_options, compile?, template_fun) do if compile? do diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex index 19549300..7dcbdc62 100644 --- a/lib/bumblebee/text/mistral.ex +++ b/lib/bumblebee/text/mistral.ex @@ -11,7 +11,7 @@ defmodule Bumblebee.Text.Mistral do """ ], max_positions: [ - default: 131072, + default: 131_072, doc: """ the vocabulary size of the position embedding. This corresponds to the maximum sequence length that this model can process. Typically this is set to a large value just in case, @@ -42,7 +42,7 @@ defmodule Bumblebee.Text.Mistral do regular MHA. If it's set to 1, it will use MQA, otherwise it uses Grouped Query Attention """ - ] + ], activation: [ default: :silu, doc: "the activation function" @@ -385,6 +385,7 @@ defmodule Bumblebee.Text.Mistral do hidden_size: {"hidden_size", number()}, num_blocks: {"num_hidden_layers", number()}, num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, intermediate_size: {"intermediate_size", number()}, activation: {"hidden_act", atom()}, initializer_scale: {"initializer_range", number()}, diff --git a/lib/bumblebee/utils/tokenizers.ex b/lib/bumblebee/utils/tokenizers.ex index 5a0b723e..60d3c66e 100644 --- a/lib/bumblebee/utils/tokenizers.ex +++ b/lib/bumblebee/utils/tokenizers.ex @@ -49,14 +49,13 @@ defmodule Bumblebee.Utils.Tokenizers do encodings = Enum.map(encodings, fn encoding -> - transformations = - [ - Encoding.Transformation.pad(pad_length, - pad_id: pad_id, - pad_token: pad_token, - direction: opts[:pad_direction] - ) - ] + transformations = [ + Encoding.Transformation.pad(pad_length, + pad_id: pad_id, + pad_token: pad_token, + direction: opts[:pad_direction] + ) + ] transformations = transformations ++ diff --git a/test/bumblebee/text/bert_tokenizer_test.exs b/test/bumblebee/text/bert_tokenizer_test.exs index bf61e20b..248b1f46 100644 --- a/test/bumblebee/text/bert_tokenizer_test.exs +++ b/test/bumblebee/text/bert_tokenizer_test.exs @@ -74,8 +74,7 @@ defmodule Bumblebee.Text.BertTokenizerTest do test "encoding with multiple lengths" do assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-cased"}) - inputs = - Bumblebee.apply_tokenizer(tokenizer, "This is short.", length: [8, 16]) + inputs = Bumblebee.apply_tokenizer(tokenizer, "This is short.", length: [8, 16]) assert {1, 8} = Nx.shape(inputs["input_ids"]) diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index 0d8c945d..09c1bcd0 100644 --- 
a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -12,9 +12,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do context = context([1, 0, 0, 0]) assert_equal( - LogitsProcessing.suppressed_tokens_processor(logits, context, - suppressed_token_ids: [1, 3] - ), + LogitsProcessing.suppressed_tokens_processor(logits, context, suppressed_token_ids: [1, 3]), Nx.tensor([1.0, :neg_infinity, 3.0, :neg_infinity]) ) end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index edb974f5..5131031a 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -22,8 +22,7 @@ defmodule Bumblebee.Text.GenerationTest do generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article) end @@ -36,8 +35,7 @@ defmodule Bumblebee.Text.GenerationTest do generation_config = Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) # Without :no_repeat_ngram_length we get # %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]} @@ -57,8 +55,7 @@ defmodule Bumblebee.Text.GenerationTest do strategy: %{type: :multinomial_sampling} ) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0) # Note that this is just a snapshot test, we do not use any # reference value, because of PRNG difference @@ -81,8 +78,7 @@ defmodule Bumblebee.Text.GenerationTest do strategy: %{type: :contrastive_search, top_k: 4, alpha: 0.6} ) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} = Nx.Serving.run(serving, "I was going") @@ -104,8 +100,7 @@ defmodule Bumblebee.Text.GenerationTest do generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8) - serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, stream: true) + serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config, stream: true) stream = Nx.Serving.run(serving, article) assert Enum.to_list(stream) == ["PG&E", " scheduled", " the", " black"] diff --git a/test/bumblebee/text/mistral_test.exs b/test/bumblebee/text/mistral_test.exs index 4fd0b898..88144458 100644 --- a/test/bumblebee/text/mistral_test.exs +++ b/test/bumblebee/text/mistral_test.exs @@ -8,11 +8,11 @@ defmodule Bumblebee.Text.MistralTest do describe "integration" do test "base model" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "seanmor5/tiny-llama-test"}, architecture: :base) + Bumblebee.load_model({:hf, "echarlaix/tiny-random-mistral"}, architecture: :base) - assert %Bumblebee.Text.Llama{architecture: :base} = spec + assert %Bumblebee.Text.Mistral{architecture: :base} = spec - input_ids = Nx.tensor([[1, 15043, 3186, 825, 29915, 29879, 701]]) + input_ids = Nx.tensor([[1, 
6312, 28709, 1526, 28808]]) inputs = %{ "input_ids" => input_ids @@ -20,15 +20,19 @@ defmodule Bumblebee.Text.MistralTest do outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.hidden_state) == {1, 7, 32} + assert Nx.shape(outputs.hidden_state) == {1, 5, 32} assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ - [[-0.4411, -1.9037, 0.9454], [0.8148, -1.4606, 0.0076], [0.9480, 0.6038, 0.1649]] + [ + [-1.1513, -0.3565, -1.3482], + [0.5468, 0.5652, -0.4141], + [-1.2177, -0.7919, -0.7064] + ] ]), atol: 1.0e-2 ) end end -end \ No newline at end of file +end From 27b413c9dd6da576f8e4b2238273c755fb84c473 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 20 Oct 2023 08:57:04 -0700 Subject: [PATCH 03/15] Finish mistral --- lib/bumblebee.ex | 1 + test/bumblebee/text/mistral_test.exs | 49 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index a3d23166..125ece21 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -150,6 +150,7 @@ defmodule Bumblebee do "MBartModel" => {Bumblebee.Text.Mbart, :base}, "MistralModel" => {Bumblebee.Text.Mistral, :base}, "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling}, + "MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification}, "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification}, "ResNetModel" => {Bumblebee.Vision.ResNet, :base}, "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling}, diff --git a/test/bumblebee/text/mistral_test.exs b/test/bumblebee/text/mistral_test.exs index 88144458..78784f79 100644 --- a/test/bumblebee/text/mistral_test.exs +++ b/test/bumblebee/text/mistral_test.exs @@ -35,4 +35,53 @@ defmodule Bumblebee.Text.MistralTest do ) end end + + test "sequence classification model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "seanmor5/tiny-random-mistral-classification"}) + + assert %Bumblebee.Text.Mistral{architecture: :for_sequence_classification} = spec + input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) + + inputs = %{ + "input_ids" => input_ids + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([[0.0255, 0.0318]]), + atol: 1.0e-4 + ) + end + + test "causal language model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "echarlaix/tiny-random-mistral"}, + architecture: :for_causal_language_modeling + ) + + assert %Bumblebee.Text.Mistral{architecture: :for_causal_language_modeling} = spec + + input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) + + inputs = %{ + "input_ids" => input_ids + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 4, 32000} + + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [[0.1156, 0.0420, -0.0609], [0.0333, 0.0376, -0.0531], [-0.0507, -0.0097, -0.0039]] + ]), + atol: 1.0e-2 + ) + end end From 832beb6b08b0539420cd1b802f53f4e183322435 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 20 Oct 2023 09:00:18 -0700 Subject: [PATCH 04/15] Update lib/bumblebee/layers/transformer.ex --- lib/bumblebee/layers/transformer.ex | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index a020a539..1dfffca9 100644 --- a/lib/bumblebee/layers/transformer.ex +++ 
b/lib/bumblebee/layers/transformer.ex @@ -840,12 +840,9 @@ defmodule Bumblebee.Layers.Transformer do end {key, value} = - if num_key_value_heads == num_heads do - {key, value} - else - num_key_value_groups = div(num_heads, num_key_value_heads) - {repeat_states(key, num_key_value_groups), repeat_states(value, num_key_value_groups)} - end + num_key_value_groups = div(num_heads, num_key_value_heads) + key = repeat_states(key, num_key_value_groups) + value = repeat_states(value, num_key_value_groups) {key, value, attention_cache} = Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset) From 143f0d694b1a0d72f68c5c0e32fdee76984ec868 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 20 Oct 2023 10:11:52 -0700 Subject: [PATCH 05/15] Update tests --- lib/bumblebee/layers/transformer.ex | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 1dfffca9..c444fee9 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -746,7 +746,7 @@ defmodule Bumblebee.Layers.Transformer do name = opts[:name] num_heads = opts[:num_heads] - num_key_value_heads = opts[:num_key_value_heads] + num_key_value_heads = opts[:num_key_value_heads] || num_heads hidden_size = opts[:hidden_size] kernel_initializer = opts[:kernel_initializer] causal? = opts[:causal?] @@ -839,10 +839,9 @@ defmodule Bumblebee.Layers.Transformer do {query, key} end - {key, value} = - num_key_value_groups = div(num_heads, num_key_value_heads) - key = repeat_states(key, num_key_value_groups) - value = repeat_states(value, num_key_value_groups) + num_key_value_groups = div(num_heads, num_key_value_heads) + key = repeat_states(key, num_key_value_groups) + value = repeat_states(value, num_key_value_groups) {key, value, attention_cache} = Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset) From 6eb8bc7033de70b43281e498e2fc418dd44f78d8 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 20 Oct 2023 16:01:58 -0700 Subject: [PATCH 06/15] Update llama to use key-value heads --- lib/bumblebee/text/llama.ex | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex index 87ffc904..983d3727 100644 --- a/lib/bumblebee/text/llama.ex +++ b/lib/bumblebee/text/llama.ex @@ -34,6 +34,10 @@ defmodule Bumblebee.Text.Llama do default: 32, doc: "the number of attention heads for each attention layer in the model" ], + num_key_value_heads: [ + default: nil, + doc: "the number of key value heads for each attention layer in the model" + ] activation: [ default: :silu, doc: "the activation function" @@ -365,6 +369,7 @@ defmodule Bumblebee.Text.Llama do hidden_size: {"hidden_size", number()}, num_blocks: {"num_hidden_layers", number()}, num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, intermediate_size: {"intermediate_size", number()}, activation: {"hidden_act", atom()}, initializer_scale: {"initializer_range", number()}, From fc850b16abb8e4903f172348c32009df19a27539 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 20 Oct 2023 16:02:34 -0700 Subject: [PATCH 07/15] Fix error --- lib/bumblebee/text/llama.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex index 983d3727..eb4b9bac 100644 --- a/lib/bumblebee/text/llama.ex +++ b/lib/bumblebee/text/llama.ex @@ -37,7 +37,7 @@ defmodule 
Bumblebee.Text.Llama do num_key_value_heads: [ default: nil, doc: "the number of key value heads for each attention layer in the model" - ] + ], activation: [ default: :silu, doc: "the activation function" From a2b8229228d94e5f91fb4c5b7f8ac85e01c4f108 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 20 Oct 2023 16:03:29 -0700 Subject: [PATCH 08/15] Actually use kv heads --- lib/bumblebee/text/llama.ex | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex index eb4b9bac..e989d4b1 100644 --- a/lib/bumblebee/text/llama.ex +++ b/lib/bumblebee/text/llama.ex @@ -306,6 +306,7 @@ defmodule Bumblebee.Text.Llama do cache: cache, num_blocks: spec.num_blocks, num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, hidden_size: spec.hidden_size, kernel_initializer: kernel_initializer(spec), layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), From 17846e0f784388f776cf0e66d50e312fe16cb37b Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Sat, 21 Oct 2023 07:53:42 -0700 Subject: [PATCH 09/15] Fix formatting --- lib/bumblebee/shared.ex | 2 +- test/bumblebee/text/generation/logits_processing_test.exs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index 0412bf51..dbd72869 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -288,7 +288,7 @@ defmodule Bumblebee.Shared do function(), keyword(), boolean(), - (() -> list(Nx.Tensor.t())) + (-> list(Nx.Tensor.t())) ) :: function() def compile_or_jit(fun, defn_options, compile?, template_fun) do if compile? do diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index 09c1bcd0..0d8c945d 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -12,7 +12,9 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do context = context([1, 0, 0, 0]) assert_equal( - LogitsProcessing.suppressed_tokens_processor(logits, context, suppressed_token_ids: [1, 3]), + LogitsProcessing.suppressed_tokens_processor(logits, context, + suppressed_token_ids: [1, 3] + ), Nx.tensor([1.0, :neg_infinity, 3.0, :neg_infinity]) ) end From 3c2e56372be94dab350b0642897e1150530599d0 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Sat, 21 Oct 2023 07:55:35 -0700 Subject: [PATCH 10/15] Map mistral tokenizer to llama --- lib/bumblebee.ex | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 125ece21..a4247850 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -217,6 +217,7 @@ defmodule Bumblebee do "gpt2" => Bumblebee.Text.Gpt2Tokenizer, "layoutlm" => Bumblebee.Text.LayoutLmTokenizer, "llama" => Bumblebee.Text.LlamaTokenizer, + "mistral" => Bumblebee.Text.LlamaTokenizer, "mbart" => Bumblebee.Text.MbartTokenizer, "roberta" => Bumblebee.Text.RobertaTokenizer, "t5" => Bumblebee.Text.T5Tokenizer, From 354b9d601a8d0411d7c10383954c06069d90b171 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 23 Oct 2023 09:07:04 +0200 Subject: [PATCH 11/15] Indent --- test/bumblebee/text/mistral_test.exs | 72 ++++++++++++++-------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/test/bumblebee/text/mistral_test.exs b/test/bumblebee/text/mistral_test.exs index 78784f79..416c493a 100644 --- a/test/bumblebee/text/mistral_test.exs +++ 
b/test/bumblebee/text/mistral_test.exs @@ -34,54 +34,54 @@ defmodule Bumblebee.Text.MistralTest do atol: 1.0e-2 ) end - end - test "sequence classification model" do - assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "seanmor5/tiny-random-mistral-classification"}) + test "sequence classification model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "seanmor5/tiny-random-mistral-classification"}) - assert %Bumblebee.Text.Mistral{architecture: :for_sequence_classification} = spec - input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) + assert %Bumblebee.Text.Mistral{architecture: :for_sequence_classification} = spec + input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) - inputs = %{ - "input_ids" => input_ids - } + inputs = %{ + "input_ids" => input_ids + } - outputs = Axon.predict(model, params, inputs) + outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.logits) == {1, 2} + assert Nx.shape(outputs.logits) == {1, 2} - assert_all_close( - outputs.logits, - Nx.tensor([[0.0255, 0.0318]]), - atol: 1.0e-4 - ) - end + assert_all_close( + outputs.logits, + Nx.tensor([[0.0255, 0.0318]]), + atol: 1.0e-4 + ) + end - test "causal language model" do - assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "echarlaix/tiny-random-mistral"}, - architecture: :for_causal_language_modeling - ) + test "causal language model" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "echarlaix/tiny-random-mistral"}, + architecture: :for_causal_language_modeling + ) - assert %Bumblebee.Text.Mistral{architecture: :for_causal_language_modeling} = spec + assert %Bumblebee.Text.Mistral{architecture: :for_causal_language_modeling} = spec - input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) + input_ids = Nx.tensor([[1, 6312, 28709, 1526]]) - inputs = %{ - "input_ids" => input_ids - } + inputs = %{ + "input_ids" => input_ids + } - outputs = Axon.predict(model, params, inputs) + outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.logits) == {1, 4, 32000} + assert Nx.shape(outputs.logits) == {1, 4, 32000} - assert_all_close( - outputs.logits[[.., 1..3, 1..3]], - Nx.tensor([ - [[0.1156, 0.0420, -0.0609], [0.0333, 0.0376, -0.0531], [-0.0507, -0.0097, -0.0039]] - ]), - atol: 1.0e-2 - ) + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [[0.1156, 0.0420, -0.0609], [0.0333, 0.0376, -0.0531], [-0.0507, -0.0097, -0.0039]] + ]), + atol: 1.0e-2 + ) + end end end From c1fda9fd14540335cddd8ac665121242306f3ffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 23 Oct 2023 14:17:57 +0700 Subject: [PATCH 12/15] Use rotary embedding base --- lib/bumblebee/text/mistral.ex | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex index 7dcbdc62..4d8df50b 100644 --- a/lib/bumblebee/text/mistral.ex +++ b/lib/bumblebee/text/mistral.ex @@ -56,9 +56,9 @@ defmodule Bumblebee.Text.Mistral do doc: "the standard deviation of the normal initializer used for initializing kernel parameters" ], - rope_theta: [ - default: 10_000.0, - doc: "base period of RoPE embeddings" + rotary_embedding_base: [ + default: 10_000, + doc: "base for computing rotary embedding frequency" ], sliding_window: [ default: 4096, @@ -332,7 +332,11 @@ defmodule Bumblebee.Text.Mistral do ), block_type: :norm_first, causal?: true, - rotary_embedding: [position_ids: 
position_ids, max_positions: spec.max_positions], + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base + ], query_use_bias: false, key_use_bias: false, value_use_bias: false, @@ -388,6 +392,7 @@ defmodule Bumblebee.Text.Mistral do num_key_value_heads: {"num_key_value_heads", number()}, intermediate_size: {"intermediate_size", number()}, activation: {"hidden_act", atom()}, + rotary_embedding_base: {"rope_theta", number()}, initializer_scale: {"initializer_range", number()}, layer_norm_epsilon: {"rms_norm_eps", number()} ) ++ Shared.common_options_from_transformers(data, spec) From 2c1b300a5e041fcded1eb171343b96bdc24edf55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 23 Oct 2023 14:22:08 +0700 Subject: [PATCH 13/15] Remove unused :sliding_window --- lib/bumblebee/text/mistral.ex | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex index 4d8df50b..6f1a4b91 100644 --- a/lib/bumblebee/text/mistral.ex +++ b/lib/bumblebee/text/mistral.ex @@ -59,10 +59,6 @@ defmodule Bumblebee.Text.Mistral do rotary_embedding_base: [ default: 10_000, doc: "base for computing rotary embedding frequency" - ], - sliding_window: [ - default: 4096, - doc: "sliding window attention size" ] ] ++ Shared.common_options([ From d19e245e141adc6fea6e1ce4a951bf20f9cf6d5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 23 Oct 2023 15:07:39 +0700 Subject: [PATCH 14/15] Up --- lib/bumblebee/layers.ex | 14 +++++++++++--- lib/bumblebee/layers/transformer.ex | 24 ++++++------------------ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 4eea647c..5d2806f1 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1015,13 +1015,21 @@ defmodule Bumblebee.Layers do @doc """ Adds a repeat layer to the network. + + ## Options + + * `:name` - layer name + + * `:axis` - the axis to repeat along. 
Defaults to `-1` + """ - def repeat_interleave(x, opts \\ []) do - opts = Keyword.validate!(opts, [:name, :repeats]) + def repeat_interleave(x, times, opts \\ []) do + opts = Keyword.validate!(opts, [:name, axis: -1]) Axon.layer( fn x, opts -> - Bumblebee.Utils.Nx.repeat_interleave(x, opts[:repeats], axis: 2) + axis = Nx.axis_index(x, opts[:axis]) + Bumblebee.Utils.Nx.repeat_interleave(x, times, axis: axis) end, [x], opts diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index c444fee9..656feba5 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -761,21 +761,9 @@ defmodule Bumblebee.Layers.Transformer do attention_relative_bias = opts[:attention_relative_bias] - inner_size = - if attention_head_size = opts[:attention_head_size] do - num_heads * attention_head_size - else - hidden_size - end - - inner_kv_size = - if num_heads == num_key_value_heads do - inner_size - else - div(hidden_size, num_heads) * num_key_value_heads - end - - head_size = div(hidden_size, num_heads) + attention_head_size = opts[:attention_head_size] || div(hidden_size, num_heads) + inner_size = num_heads * attention_head_size + inner_kv_size = num_key_value_heads * attention_head_size query = query @@ -815,11 +803,11 @@ defmodule Bumblebee.Layers.Transformer do {position_ids, opts} = Keyword.pop(opts, :position_ids) {percentage, opts} = Keyword.pop(opts, :percentage) - size = trunc(head_size * percentage) + size = trunc(attention_head_size * percentage) rotary_opts = [name: join(name, "rotary_embedding")] ++ opts - if size == head_size do + if size == attention_head_size do Layers.rotary_embedding(query, key, position_ids, size, rotary_opts) else query_rotary = Axon.nx(query, & &1[[.., .., .., 0..(size - 1)//1]]) @@ -904,7 +892,7 @@ defmodule Bumblebee.Layers.Transformer do if repeats == 1 do state else - Layers.repeat_interleave(state, repeats: repeats) + Layers.repeat_interleave(state, repeats, axis: 2) end end From e3fc451a5cc4c82b2d637570db2e7bf35bcae686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 23 Oct 2023 10:24:06 +0200 Subject: [PATCH 15/15] Update lib/bumblebee/layers/transformer.ex --- lib/bumblebee/layers/transformer.ex | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 656feba5..2b59177c 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -888,12 +888,10 @@ defmodule Bumblebee.Layers.Transformer do {attention_output, attention_weights, attention_cache, attention_relative_bias} end - defp repeat_states(state, repeats) do - if repeats == 1 do - state - else - Layers.repeat_interleave(state, repeats, axis: 2) - end + defp repeat_states(state, 1), do: state + + defp repeat_states(state, times) do + Layers.repeat_interleave(state, times, axis: 2) end defp validate_required_keys!(opts, keys) do
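
Note on the grouped-query attention wiring above: when `num_key_value_heads` is smaller than `num_heads`, the key/value projections yield fewer heads than the queries, and `repeat_interleave` expands each key/value head `div(num_heads, num_key_value_heads)` times along the heads axis so the attention weights line up head-for-head. A minimal, self-contained Nx sketch of that expansion, assuming the `{batch, sequence_length, heads, head_size}` layout implied by the `axis: 2` call in `repeat_states/2`; the anonymous helper below only stands in for `Bumblebee.Utils.Nx.repeat_interleave/3` and is not the library implementation:

# Illustrative stand-in for the repeat-interleave used by grouped-query attention.
repeat_interleave = fn tensor, times, axis ->
  shape = Nx.shape(tensor)

  tensor
  # insert a singleton axis right after the heads axis
  |> Nx.new_axis(axis + 1)
  # broadcast it to `times` copies of each head
  |> Nx.broadcast(Tuple.insert_at(shape, axis + 1, times))
  # fold the copies back into the heads axis
  |> Nx.reshape(put_elem(shape, axis, elem(shape, axis) * times))
end

num_attention_heads = 8
num_key_value_heads = 2
num_key_value_groups = div(num_attention_heads, num_key_value_heads)

# key states in the {batch, sequence_length, heads, head_size} layout
key = Nx.iota({1, 3, num_key_value_heads, 4}, type: :f32)

expanded = repeat_interleave.(key, num_key_value_groups, 2)
Nx.shape(expanded)
#=> {1, 3, 8, 4}

With `num_key_value_heads == num_heads` the group count is 1 and the expansion is skipped, which is the regular multi-head case; with a single key/value head it degenerates to multi-query attention.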
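
Once the `MistralForCausalLM` mapping and the `"mistral" => Bumblebee.Text.LlamaTokenizer` entry from these patches are in place, a Mistral checkpoint can be dropped into the existing generation serving. A hedged end-to-end sketch, following the pattern already used in `generation_test.exs`; the checkpoint name is a placeholder assumption and `Bumblebee.load_generation_config/1` is assumed to be available for the repository you point it at:

# Assumed checkpoint; substitute any Mistral repository you actually use.
repo = {:hf, "mistralai/Mistral-7B-v0.1"}

{:ok, model_info} = Bumblebee.load_model(repo, architecture: :for_causal_language_modeling)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 16)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Elixir is a dynamic, functional language")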