Add mistral #264

Merged · 15 commits · Oct 23, 2023
4 changes: 4 additions & 0 deletions lib/bumblebee.ex
@@ -148,6 +148,9 @@ defmodule Bumblebee do
"MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering},
"MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification},
"MBartModel" => {Bumblebee.Text.Mbart, :base},
"MistralModel" => {Bumblebee.Text.Mistral, :base},
"MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
"MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
"ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
"ResNetModel" => {Bumblebee.Vision.ResNet, :base},
"RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -214,6 +217,7 @@ defmodule Bumblebee do
"gpt2" => Bumblebee.Text.Gpt2Tokenizer,
"layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
"llama" => Bumblebee.Text.LlamaTokenizer,
"mistral" => Bumblebee.Text.LlamaTokenizer,
"mbart" => Bumblebee.Text.MbartTokenizer,
"roberta" => Bumblebee.Text.RobertaTokenizer,
"t5" => Bumblebee.Text.T5Tokenizer,
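Since Mistral's tokenizer is Llama-compatible, the `"mistral"` tokenizer type is mapped onto the existing `Bumblebee.Text.LlamaTokenizer` rather than a new module. A minimal usage sketch (the checkpoint name is an assumption, not something this diff pins down):

```elixir
# The "mistral" tokenizer type resolves to Bumblebee.Text.LlamaTokenizer
# per the mapping above; the repository name below is illustrative.
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "mistralai/Mistral-7B-v0.1"})

# Tokenize a single prompt into model inputs.
inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello from Bumblebee!"])
```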
23 changes: 23 additions & 0 deletions lib/bumblebee/layers.ex
@@ -1012,4 +1012,27 @@ defmodule Bumblebee.Layers do
x2 = x[[.., .., .., size..-1//1]]
Nx.concatenate([-x2, x1], axis: -1)
end

@doc """
Adds a repeat layer to the network.

## Options

* `:name` - layer name

* `:axis` - the axis to repeat along. Defaults to `-1`

"""
def repeat_interleave(x, times, opts \\ []) do
opts = Keyword.validate!(opts, [:name, axis: -1])

Axon.layer(
fn x, opts ->
axis = Nx.axis_index(x, opts[:axis])
Bumblebee.Utils.Nx.repeat_interleave(x, times, axis: axis)
end,
[x],
opts
)
end
end
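For reference, "interleaved" repetition keeps the copies of each slice adjacent, matching how the query heads are grouped in the attention change below, rather than cycling through the whole axis the way `Nx.tile/2` alone would. A minimal sketch of the same semantics in plain Nx (1-D input, `axis: 0` assumed):

```elixir
x = Nx.tensor([1, 2, 3])

x
|> Nx.new_axis(1)    # {3, 1}
|> Nx.tile([1, 2])   # {3, 2} => [[1, 1], [2, 2], [3, 3]]
|> Nx.reshape({6})   # [1, 1, 2, 2, 3, 3] rather than [1, 2, 3, 1, 2, 3]
```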
40 changes: 26 additions & 14 deletions lib/bumblebee/layers/transformer.ex
@@ -42,6 +42,7 @@ defmodule Bumblebee.Layers.Transformer do

block_opts_keys = [
:num_attention_heads,
:num_key_value_heads,
:causal?,
:hidden_size,
:ffn,
@@ -298,6 +299,7 @@
:num_attention_heads,
:hidden_size,
:ffn,
:num_key_value_heads,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: Layers.none(),
@@ -323,6 +325,7 @@

name = opts[:name]
num_attention_heads = opts[:num_attention_heads]
num_key_value_heads = opts[:num_key_value_heads] || num_attention_heads
hidden_size = opts[:hidden_size]
ffn = opts[:ffn]
causal? = opts[:causal?]
@@ -392,6 +395,7 @@
offset: offset,
causal?: causal?,
num_heads: num_attention_heads,
num_key_value_heads: num_key_value_heads,
hidden_size: hidden_size,
kernel_initializer: kernel_initializer,
attention_head_size: attention_head_size,
@@ -435,6 +439,7 @@
attention_cache: cross_attention_cache,
offset: offset,
num_heads: num_attention_heads,
num_key_value_heads: num_key_value_heads,
hidden_size: hidden_size,
kernel_initializer: kernel_initializer,
attention_head_size: attention_head_size,
@@ -716,6 +721,7 @@
:name,
:num_heads,
:hidden_size,
:num_key_value_heads,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: Layers.none(),
@@ -740,6 +746,7 @@

name = opts[:name]
num_heads = opts[:num_heads]
num_key_value_heads = opts[:num_key_value_heads] || num_heads
hidden_size = opts[:hidden_size]
kernel_initializer = opts[:kernel_initializer]
causal? = opts[:causal?]
@@ -754,14 +761,9 @@

attention_relative_bias = opts[:attention_relative_bias]

inner_size =
if attention_head_size = opts[:attention_head_size] do
num_heads * attention_head_size
else
hidden_size
end

head_size = div(hidden_size, num_heads)
attention_head_size = opts[:attention_head_size] || div(hidden_size, num_heads)
inner_size = num_heads * attention_head_size
inner_kv_size = num_key_value_heads * attention_head_size

query =
query
@@ -774,21 +776,21 @@

key =
key
|> Axon.dense(inner_size,
|> Axon.dense(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "key"),
use_bias: key_use_bias
)
|> Layers.split_heads(num_heads)
|> Layers.split_heads(num_key_value_heads)

value =
value
|> Axon.dense(inner_size,
|> Axon.dense(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "value"),
use_bias: value_use_bias
)
|> Layers.split_heads(num_heads)
|> Layers.split_heads(num_key_value_heads)

{query, key} =
case rotary_embedding do
@@ -801,11 +803,11 @@
{position_ids, opts} = Keyword.pop(opts, :position_ids)
{percentage, opts} = Keyword.pop(opts, :percentage)

size = trunc(head_size * percentage)
size = trunc(attention_head_size * percentage)

rotary_opts = [name: join(name, "rotary_embedding")] ++ opts

if size == head_size do
if size == attention_head_size do
Layers.rotary_embedding(query, key, position_ids, size, rotary_opts)
else
query_rotary = Axon.nx(query, & &1[[.., .., .., 0..(size - 1)//1]])
@@ -825,6 +827,10 @@
{query, key}
end

num_key_value_groups = div(num_heads, num_key_value_heads)
key = repeat_states(key, num_key_value_groups)
value = repeat_states(value, num_key_value_groups)

{key, value, attention_cache} =
Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset)

@@ -882,6 +888,12 @@
{attention_output, attention_weights, attention_cache, attention_relative_bias}
end

defp repeat_states(state, 1), do: state

defp repeat_states(state, times) do
Layers.repeat_interleave(state, times, axis: 2)
end

defp validate_required_keys!(opts, keys) do
case keys -- Keyword.keys(opts) do
[] -> :ok
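Taken together, this implements grouped-query attention: the key/value projections emit only `num_key_value_heads` heads, and each of them is then shared by `div(num_heads, num_key_value_heads)` query heads by repeating it along axis 2 (the head axis after `split_heads`), so the attention score and output computation stays untouched. A shape walk-through using Mistral 7B's published values (32 query heads, 8 key-value heads, head size 128); the tensors below are illustrative only:

```elixir
num_heads = 32
num_key_value_heads = 8
num_key_value_groups = div(num_heads, num_key_value_heads)  # => 4

# After split_heads the key states carry one entry per KV head:
# {batch, seq, num_key_value_heads, head_size}.
key = Nx.broadcast(0.0, {1, 10, num_key_value_heads, 128})

# Repeating each KV head 4 times along the head axis restores the full
# head count, so downstream attention sees the usual {1, 10, 32, 128}.
key
|> Bumblebee.Utils.Nx.repeat_interleave(num_key_value_groups, axis: 2)
|> Nx.shape()
# => {1, 10, 32, 128}
```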
6 changes: 6 additions & 0 deletions lib/bumblebee/text/llama.ex
@@ -34,6 +34,10 @@ defmodule Bumblebee.Text.Llama do
default: 32,
doc: "the number of attention heads for each attention layer in the model"
],
num_key_value_heads: [
default: nil,
doc: "the number of key value heads for each attention layer in the model"
],
activation: [
default: :silu,
doc: "the activation function"
@@ -302,6 +306,7 @@
cache: cache,
num_blocks: spec.num_blocks,
num_attention_heads: spec.num_attention_heads,
num_key_value_heads: spec.num_key_value_heads,
hidden_size: spec.hidden_size,
kernel_initializer: kernel_initializer(spec),
layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon),
@@ -365,6 +370,7 @@
hidden_size: {"hidden_size", number()},
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
num_key_value_heads: {"num_key_value_heads", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", atom()},
initializer_scale: {"initializer_range", number()},
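With `num_key_value_heads` read from the upstream config (and falling back to the full head count when `nil`, so existing multi-head checkpoints load unchanged), generation with a Mistral checkpoint should go through the usual serving pipeline. A sketch, assuming the `mistralai/Mistral-7B-v0.1` repository:

```elixir
repo = {:hf, "mistralai/Mistral-7B-v0.1"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Build a text-generation serving and run a single prompt.
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Elixir is")
```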