lib/bumblebee.ex (6 additions, 0 deletions)

@@ -188,6 +188,11 @@ defmodule Bumblebee do
"RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
"RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
"RobertaModel" => {Bumblebee.Text.Roberta, :base},
"SmolLM3Model" => {Bumblebee.Text.SmolLM3, :base},
"SmolLM3ForCausalLM" => {Bumblebee.Text.SmolLM3, :for_causal_language_modeling},
"SmolLM3ForQuestionAnswering" => {Bumblebee.Text.SmolLM3, :for_question_answering},
"SmolLM3ForSequenceClassification" => {Bumblebee.Text.SmolLM3, :for_sequence_classification},
"SmolLM3ForTokenClassification" => {Bumblebee.Text.SmolLM3, :for_token_classification},
"SwinModel" => {Bumblebee.Vision.Swin, :base},
"SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification},
"T5Model" => {Bumblebee.Text.T5, :base},
@@ -254,6 +259,7 @@ defmodule Bumblebee do
"phi" => :code_gen,
"phi3" => :llama,
"roberta" => :roberta,
"smollm3" => :smollm3,
"t5" => :t5,
"whisper" => :whisper,
"xlm-roberta" => :xlm_roberta,
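
With the architecture mappings above in place, a SmolLM3 checkpoint loads through the usual Bumblebee entry points. A minimal sketch; the `HuggingFaceTB/SmolLM3-3B` repository name is an assumption and does not appear in this diff:

# Sketch only: the repository name is assumed, not taken from this diff.
repo = {:hf, "HuggingFaceTB/SmolLM3-3B"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Elixir is")
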
lib/bumblebee/layers/transformer.ex (15 additions, 2 deletions)

@@ -21,6 +21,10 @@ defmodule Bumblebee.Layers.Transformer do
is configured, this option controls whether the bias from the
first block is used for all other blocks. Defaults to `false`

* `:rotary_embedding` - configuration of rotary embedding. Can be:
        - a keyword list (applied to all blocks)
        - a function that takes the block index and returns the configuration

* `:name` - the prefix for layer names

For all other options (including required options) see `block/2`.
@@ -49,8 +53,7 @@ defmodule Bumblebee.Layers.Transformer do
:layer_norm,
:block_type,
:attention_window_size,
-      :scale_attention_weights,
-      :rotary_embedding
+      :scale_attention_weights
]

opts =
@@ -60,6 +63,7 @@
[
:name,
:num_blocks,
:rotary_embedding,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: nil,
@@ -80,6 +84,7 @@
cross_attention_mask = opts[:cross_attention_mask]
cross_attention_head_mask = opts[:cross_attention_head_mask]
cache = opts[:cache]
rotary_embedding = opts[:rotary_embedding]

block_opts = Keyword.take(opts, block_opts_keys)

@@ -109,6 +114,13 @@
opts[:attention_relative_bias] || Layers.none()
end

block_rotary_embedding =
case rotary_embedding do
nil -> nil
fun when is_function(fun, 1) -> fun.(idx)
config when is_list(config) -> config
end

{hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
block(
state.hidden_state,
@@ -121,6 +133,7 @@
cross_attention_head_mask: block_cross_attention_head_mask,
block_cache: block_cache,
offset: offset,
rotary_embedding: block_rotary_embedding,
name: join(name, idx)
] ++ block_opts
)
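
The function form of `:rotary_embedding` is what lets selected blocks skip rotary embedding entirely, which is presumably what SmolLM3's NoPE layers rely on. A hedged sketch of a caller; `spec`, `position_ids`, and `hidden_state` are assumed bindings, the keyword keys mirror the rotary configuration used by existing decoders, and the every-fourth-block rule is purely illustrative:

# Illustrative sketch only: `spec`, `position_ids`, and `hidden_state` come from
# the surrounding (hypothetical) model code, and the rem/2 rule is an example,
# not SmolLM3's actual layer layout.
rotary_embedding = fn block_idx ->
  if rem(block_idx, 4) == 3 do
    # Assuming a nil config leaves this block without rotary embedding,
    # matching the block-level default.
    nil
  else
    [
      position_ids: position_ids,
      max_positions: spec.max_positions,
      base: spec.rotary_embedding_base
    ]
  end
end

Bumblebee.Layers.Transformer.blocks(hidden_state,
  num_blocks: spec.num_blocks,
  rotary_embedding: rotary_embedding,
  name: "decoder.blocks"
  # ...plus the required attention/feed-forward options documented for block/2
)
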
lib/bumblebee/text/pre_trained_tokenizer.ex (6 additions, 0 deletions)

@@ -211,6 +211,12 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
mask: "<mask>"
}
},
smollm3: %{
special_tokens: %{
eos: "<|im_end|>",
pad: "<|im_end|>"
}
},
t5: %{
special_tokens: %{
bos: "<s>",
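
The new `:smollm3` tokenizer type can be checked the same way as the other presets. A quick sketch, again assuming the `HuggingFaceTB/SmolLM3-3B` repository:

# Sketch only: the repository name is assumed.
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "HuggingFaceTB/SmolLM3-3B"})

# Both :eos and :pad map to "<|im_end|>" for this type, so padded batches
# reuse the end-of-sequence token rather than a dedicated pad token.
inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello", "SmolLM3 says hello"])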