Add T5 model (#159)
* Failing draft of T5

* Fix all tests with new transformer options

* Apply position bias at each layer

* Do not scale query

* Fix some minor implementation bugs

* Pass tests

* Add conditional generation head

* Fix gated act

* Apply suggestions from code review

Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>

* Fix tests

* Update test

* Account for attention projection size in the cache

* Fix relative attention bias in autoregression with cache

* Remove :output_norm

* Add tokenizer tests

* Pass :scale_query? to attention weights layer

* Simplify

* Refactor gated activation

* attention_projection_size -> attention_head_size

* Refactor relative attention bias

* Make layer norm option optional

* Fix albert tests

* Remove usage of deprecated power

---------

Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
seanmor5 and jonatanklosko committed Feb 20, 2023
1 parent 18936ab commit a2df872
Showing 24 changed files with 1,199 additions and 70 deletions.
3 changes: 3 additions & 0 deletions lib/bumblebee.ex
@@ -102,6 +102,8 @@ defmodule Bumblebee do
"RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
"RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
"RobertaModel" => {Bumblebee.Text.Roberta, :base},
"T5Model" => {Bumblebee.Text.T5, :base},
"T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation},
"ViTForImageClassification" => {Bumblebee.Vision.Vit, :for_image_classification},
"ViTForMaskedImageModeling" => {Bumblebee.Vision.Vit, :for_masked_image_modeling},
"ViTModel" => {Bumblebee.Vision.Vit, :base},
@@ -148,6 +150,7 @@ defmodule Bumblebee do
"layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
"mbart" => Bumblebee.Text.MbartTokenizer,
"roberta" => Bumblebee.Text.RobertaTokenizer,
"t5" => Bumblebee.Text.T5Tokenizer,
"whisper" => Bumblebee.Text.WhisperTokenizer,
"xlm-roberta" => Bumblebee.Text.XlmRobertaTokenizer
}
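With these two mappings in place, T5 checkpoints resolve to the new `Bumblebee.Text.T5` module and `Bumblebee.Text.T5Tokenizer`. A minimal loading sketch, assuming a standard T5 checkpoint ("t5-small" is only an example repository id):

```elixir
# Hedged sketch: load the model spec/params and the matching tokenizer.
# The :architecture override is explicit here rather than relying on the
# architecture name stored in the checkpoint config.
{:ok, model_info} =
  Bumblebee.load_model({:hf, "t5-small"}, architecture: :for_conditional_generation)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "t5-small"})

# The tokenizer output provides the encoder inputs for the T5 graph.
inputs = Bumblebee.apply_tokenizer(tokenizer, "translate English to German: How are you?")
```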
8 changes: 6 additions & 2 deletions lib/bumblebee/audio/whisper.ex
@@ -426,7 +426,9 @@ defmodule Bumblebee.Audio.Whisper do
dropout_rate: spec.dropout_rate,
attention_dropout_rate: spec.attention_dropout_rate,
key_use_bias: false,
layer_norm_epsilon: 1.0e-5,
layer_norm: [
epsilon: 1.0e-5
],
norm_placement: :first,
ffn: [
intermediate_size: spec.encoder_intermediate_size,
@@ -473,7 +475,9 @@ defmodule Bumblebee.Audio.Whisper do
dropout_rate: spec.dropout_rate,
attention_dropout_rate: spec.attention_dropout_rate,
key_use_bias: false,
layer_norm_epsilon: 1.0e-5,
layer_norm: [
epsilon: 1.0e-5
],
norm_placement: :first,
ffn: [
intermediate_size: spec.decoder_intermediate_size,
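The flat `layer_norm_epsilon` option becomes a nested `layer_norm` keyword list, which leaves room for alternative normalization (T5 uses RMS norm rather than standard layer norm). A hedged sketch of the two forms; only the keyword-list form appears in this diff, the function form is an assumption about how a custom norm could be plugged in:

```elixir
# Standard layer norm configured via options (as in the Whisper change above).
standard = [layer_norm: [epsilon: 1.0e-5]]

# Hypothetical: a custom normalization function, e.g. T5-style RMS norm.
t5_style = [
  layer_norm: fn hidden_state, name ->
    Bumblebee.Layers.rms_norm(hidden_state, epsilon: 1.0e-6, name: name)
  end
]
```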
3 changes: 3 additions & 0 deletions lib/bumblebee/diffusion/layers/unet.ex
@@ -316,6 +316,9 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
num_blocks: depth,
num_attention_heads: num_heads,
hidden_size: hidden_size,
layer_norm: [
epsilon: 1.0e-5
],
dropout_rate: dropout,
norm_placement: :first,
ffn: &ffn_geglu(&1, hidden_size, dropout: dropout, name: &2),
2 changes: 1 addition & 1 deletion lib/bumblebee/diffusion/vae_kl.ex
@@ -357,7 +357,7 @@ defmodule Bumblebee.Diffusion.VaeKl do
|> Axon.group_norm(32, epsilon: 1.0e-6, name: join(name, "norm"))
|> Axon.reshape({:batch, :auto, channels})

{hidden_state, _attention, _self_attention_cache} =
{hidden_state, _attention, _self_attention_cache, _position_bias} =
Layers.Transformer.multi_head_attention(hidden_state, hidden_state, hidden_state,
num_heads: num_heads,
hidden_size: channels,
163 changes: 157 additions & 6 deletions lib/bumblebee/layers.ex
@@ -71,21 +71,141 @@ defmodule Bumblebee.Layers do
end)
end

@doc """
Computes relative attention bias.
"""
def relative_attention_bias(query, key, attention_cache, offset, opts \\ []) do
opts =
Keyword.validate!(opts, [
:name,
bidirectional: true,
num_heads: 8,
num_buckets: 32,
max_distance: 128
])

name = opts[:name]

relative_position_buckets =
Axon.layer(
&compute_relative_position_buckets/4,
[query, key, Axon.optional(attention_cache)],
bidirectional: opts[:bidirectional],
num_buckets: opts[:num_buckets],
max_distance: opts[:max_distance]
)

bias =
relative_position_buckets
|> Axon.embedding(opts[:num_buckets], opts[:num_heads], name: name)
|> Axon.transpose([2, 0, 1])
|> Axon.nx(&Nx.new_axis(&1, 0))

Axon.layer(
fn bias, query, offset, _opts ->
case offset do
%Axon.None{} ->
bias

offset ->
mask_shift = Nx.as_type(offset, {:s, 64})
query_length = Nx.axis_size(query, 1)
Nx.slice_along_axis(bias, mask_shift, query_length, axis: 2)
end
end,
[bias, query, Axon.optional(offset)]
)
end

defnp compute_relative_position_buckets(query, key, attention_cache, opts \\ []) do
opts = keyword!(opts, mode: :train, bidirectional: true, num_buckets: 32, max_distance: 128)

{key_length, query_length} = key_query_lengths(query, key, attention_cache)

context_position = Nx.iota({query_length, 1})
memory_position = Nx.iota({1, key_length})
relative_position = memory_position - context_position

{num_buckets, relative_buckets, relative_position} =
bidirectional_buckets(relative_position, opts[:num_buckets], opts[:bidirectional])

max_exact = Nx.quotient(num_buckets, 2)
is_small = Nx.less(relative_position, max_exact)

relative_position_if_large =
max_exact +
Nx.log(relative_position / max_exact) / Nx.log(opts[:max_distance] / max_exact) *
(num_buckets - max_exact)

relative_position_if_large =
Nx.min(
relative_position_if_large,
Nx.broadcast(num_buckets - 1, Nx.shape(relative_position_if_large))
)
|> Nx.as_type(:s64)

relative_buckets + Nx.select(is_small, relative_position, relative_position_if_large)
end

deftransformp key_query_lengths(query, key, attention_cache) do
case attention_cache do
%Axon.None{} ->
{Nx.axis_size(key, 1), Nx.axis_size(query, 1)}

attention_cache ->
key_length = Nx.axis_size(attention_cache.key, 1)
{key_length, key_length}
end
end

deftransformp bidirectional_buckets(relative_position, num_buckets, bidirectional) do
relative_buckets = 0

if bidirectional do
num_buckets = div(num_buckets, 2)

relative_buckets =
Nx.add(relative_buckets, Nx.multiply(Nx.greater(relative_position, 0), num_buckets))

relative_position = Nx.abs(relative_position)
{num_buckets, relative_buckets, relative_position}
else
relative_position =
relative_position
|> Nx.min(Nx.broadcast(0, Nx.shape(relative_position)))
|> Nx.negate()

{num_buckets, relative_buckets, relative_position}
end
end
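With the defaults and `bidirectional: false`, relative distances 0 through 15 each get their own bucket (`max_exact = 16`) and larger distances share logarithmically spaced buckets up to `max_distance`. A usage sketch of the layer; the call site is illustrative rather than the exact one in the T5 block:

```elixir
# query, key, attention_cache and offset are Axon nodes, as in the layer above.
position_bias =
  Bumblebee.Layers.relative_attention_bias(query, key, attention_cache, offset,
    bidirectional: false,
    num_heads: 8,
    num_buckets: 32,
    max_distance: 128,
    name: "relative_attention_bias"
  )

# The result has shape {1, num_heads, query_length, key_length} and is added
# to the attention scores as the bias term.
```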

@doc """
Computes attention weights.

## Options

* `:scale_query?` - whether to scale the query. Defaults to `true`
"""
def attention_weights(query, key, bias) do
Axon.layer(&attention_weights_impl/4, [query, key, bias])
def attention_weights(query, key, bias, opts \\ []) do
Axon.layer(&attention_weights_impl/4, [query, key, bias], opts)
end

defnp attention_weights_impl(query, key, bias, _opts \\ []) do
defnp attention_weights_impl(query, key, bias, opts \\ []) do
opts = keyword!(opts, mode: :train, scale_query?: true)

key = Nx.transpose(key, axes: [0, 2, 1, 3])
query = Nx.transpose(query, axes: [0, 2, 1, 3])

depth = Nx.axis_size(query, -1)
scaled_query = query / Nx.sqrt(depth)
query =
if opts[:scale_query?] do
depth = Nx.axis_size(query, -1)
query / Nx.sqrt(depth)
else
query
end

weights = Nx.dot(scaled_query, [3], [0, 1], key, [3], [0, 1])
weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])
weights = weights + bias
Axon.Activations.softmax(weights, axis: -1)
end
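T5 attention does not scale the query by 1/sqrt(depth) (the scaling is effectively absorbed into the weight initialization), hence the new `:scale_query?` flag. A minimal sketch of building the weights node with scaling disabled; the input shapes are placeholders:

```elixir
# Placeholder Axon inputs shaped {batch, sequence, num_heads, head_size}.
query = Axon.input("query", shape: {nil, 10, 8, 64})
key = Axon.input("key", shape: {nil, 10, 8, 64})
bias = Axon.constant(Nx.tensor(0.0))

weights = Bumblebee.Layers.attention_weights(query, key, bias, scale_query?: false)
```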
@@ -797,4 +917,35 @@ defmodule Bumblebee.Layers do
op_name: :prepend_embedding
)
end

@doc """
Adds an RMS Normalization layer to the network.
"""
# TODO: Add to Axon
def rms_norm(input, opts \\ []) do
opts =
Keyword.validate!(opts, [:name, channel_index: -1, epsilon: 1.0e-6, initializer: :ones])

weight =
Axon.param("weight", &Axon.Shape.norm_param(&1, opts[:channel_index]),
initializer: opts[:initializer]
)

Axon.layer(&rms_norm_impl/3, [input, weight], name: opts[:name], epsilon: opts[:epsilon])
end

defnp rms_norm_impl(input, weight, opts \\ []) do
opts = keyword!(opts, epsilon: 1.0e-6, channel_index: -1, mode: :train)

variance =
input
|> Nx.pow(2)
|> Nx.mean(axes: [opts[:channel_index]], keep_axes: true)

x =
input
|> Nx.multiply(Nx.rsqrt(variance + opts[:epsilon]))

x * weight
end
end
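RMS norm only rescales by the root mean square of the activations, i.e. `x / sqrt(mean(x^2) + epsilon) * weight`, with no mean subtraction and no bias, matching T5's layer norm variant. A quick sketch of attaching it to a graph (the input shape and name are arbitrary):

```elixir
# Normalize the last axis of a {batch, sequence, 768} hidden state.
hidden_state = Axon.input("hidden_state", shape: {nil, nil, 768})
normalized = Bumblebee.Layers.rms_norm(hidden_state, epsilon: 1.0e-6, name: "output_norm")
```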
18 changes: 13 additions & 5 deletions lib/bumblebee/layers/decoder.ex
@@ -35,6 +35,9 @@ defmodule Bumblebee.Layers.Decoder do
* `:hidden_size` - the dimensionality of the hidden layers
* `:attention_head_size` - the size of the key, value, and query
projection per attention head
* `:decoder_num_blocks` - the number of Transformer blocks in the decoder
* `:decoder_num_attention_heads` - the number of decoder attention heads
@@ -53,16 +56,22 @@ defmodule Bumblebee.Layers.Decoder do
encoder_num_attention_heads = opts[:encoder_num_attention_heads]
encoder_sequence_length = opts[:encoder_sequence_length]

decoder_head_size =
opts[:attention_head_size] || div(hidden_size, decoder_num_attention_heads)

encoder_head_size =
opts[:attention_head_size] || div(hidden_size, encoder_num_attention_heads)

self_attention =
attention_cache(batch_size, max_length, hidden_size, decoder_num_attention_heads)
attention_cache(batch_size, max_length, decoder_num_attention_heads, decoder_head_size)

cross_attention =
if encoder_sequence_length do
attention_cache(
batch_size,
encoder_sequence_length,
hidden_size,
encoder_num_attention_heads
encoder_num_attention_heads,
encoder_head_size
)
else
%Axon.None{}
@@ -80,8 +89,7 @@ defmodule Bumblebee.Layers.Decoder do
%{blocks: blocks, offset: offset, attention_mask: attention_mask}
end

defp attention_cache(batch_size, sequence_length, hidden_size, num_heads) do
head_size = div(hidden_size, num_heads)
defp attention_cache(batch_size, sequence_length, num_heads, head_size) do
shape = {batch_size, sequence_length, num_heads, head_size}
zeros = Nx.broadcast(0.0, shape)
%{key: zeros, value: zeros}
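Previously the cache always derived the per-head size as `hidden_size / num_heads`, which is wrong for models like T5 whose key/value projection size is an independent hyperparameter. A small worked illustration with hypothetical numbers:

```elixir
hidden_size = 1024
num_attention_heads = 32
attention_head_size = 128

# Old behaviour: head size derived from the hidden size.
div(hidden_size, num_attention_heads)
#=> 32

# New behaviour: an explicit :attention_head_size wins, so each cache entry
# is shaped {batch_size, sequence_length, num_heads, head_size} with
# head_size = 128 for this configuration.
attention_head_size || div(hidden_size, num_attention_heads)
#=> 128
```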