diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 51f2330f..0e858bc9 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -188,6 +188,11 @@ defmodule Bumblebee do
     "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
     "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
     "RobertaModel" => {Bumblebee.Text.Roberta, :base},
+    "SmolLM3Model" => {Bumblebee.Text.SmolLM3, :base},
+    "SmolLM3ForCausalLM" => {Bumblebee.Text.SmolLM3, :for_causal_language_modeling},
+    "SmolLM3ForQuestionAnswering" => {Bumblebee.Text.SmolLM3, :for_question_answering},
+    "SmolLM3ForSequenceClassification" => {Bumblebee.Text.SmolLM3, :for_sequence_classification},
+    "SmolLM3ForTokenClassification" => {Bumblebee.Text.SmolLM3, :for_token_classification},
     "SwinModel" => {Bumblebee.Vision.Swin, :base},
     "SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification},
     "T5Model" => {Bumblebee.Text.T5, :base},
@@ -254,6 +259,7 @@ defmodule Bumblebee do
     "phi" => :code_gen,
     "phi3" => :llama,
     "roberta" => :roberta,
+    "smollm3" => :smollm3,
     "t5" => :t5,
     "whisper" => :whisper,
     "xlm-roberta" => :xlm_roberta,
diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex
index 6cf93cd6..59ad9595 100644
--- a/lib/bumblebee/layers/transformer.ex
+++ b/lib/bumblebee/layers/transformer.ex
@@ -21,6 +21,10 @@ defmodule Bumblebee.Layers.Transformer do
       is configured, this option controls whether the bias from the first
       block is used for all other blocks. Defaults to `false`
 
+    * `:rotary_embedding` - configuration of rotary embedding. Can be:
+      - a keyword list (applied to all blocks)
+      - a function that takes the block index and returns the configuration
+
     * `:name` - the prefix for layer names
 
   For all other options (including required options) see `block/2`.
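The per-block form of the `:rotary_embedding` option is easiest to see with a small sketch. This is not part of the diff; the keys mirror the keyword-list form (a real configuration would also carry `:position_ids` and, optionally, `:scaling_strategy`), and `rotary_embedding_enabled` is a hypothetical list of booleans:

```elixir
# Per-block rotary embedding: pass a function of the block index instead of
# a plain keyword list. Returning nil disables rotary embedding for that block.
rotary_embedding_enabled = [true, true, true, false]
rotary_embedding_config = [max_positions: 65_536, base: 5_000_000]

rotary_embedding = fn block_idx ->
  if Enum.at(rotary_embedding_enabled, block_idx), do: rotary_embedding_config
end
```

This is the shape of value that `Bumblebee.Text.SmolLM3` builds in its decoder further down in this diff.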
@@ -49,8 +53,7 @@ defmodule Bumblebee.Layers.Transformer do
       :layer_norm,
       :block_type,
       :attention_window_size,
-      :scale_attention_weights,
-      :rotary_embedding
+      :scale_attention_weights
     ]
 
     opts =
@@ -60,6 +63,7 @@ defmodule Bumblebee.Layers.Transformer do
         [
           :name,
           :num_blocks,
+          :rotary_embedding,
           attention_mask: Layers.none(),
           attention_head_mask: Layers.none(),
           attention_relative_bias: nil,
@@ -80,6 +84,7 @@ defmodule Bumblebee.Layers.Transformer do
     cross_attention_mask = opts[:cross_attention_mask]
     cross_attention_head_mask = opts[:cross_attention_head_mask]
     cache = opts[:cache]
+    rotary_embedding = opts[:rotary_embedding]
 
     block_opts = Keyword.take(opts, block_opts_keys)
 
@@ -109,6 +114,13 @@ defmodule Bumblebee.Layers.Transformer do
             opts[:attention_relative_bias] || Layers.none()
           end
 
+        block_rotary_embedding =
+          case rotary_embedding do
+            nil -> nil
+            fun when is_function(fun, 1) -> fun.(idx)
+            config when is_list(config) -> config
+          end
+
         {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
           block(
             state.hidden_state,
@@ -121,6 +133,7 @@ defmodule Bumblebee.Layers.Transformer do
             cross_attention_head_mask: block_cross_attention_head_mask,
             block_cache: block_cache,
             offset: offset,
+            rotary_embedding: block_rotary_embedding,
             name: join(name, idx)
           ] ++ block_opts
         )
diff --git a/lib/bumblebee/text/pre_trained_tokenizer.ex b/lib/bumblebee/text/pre_trained_tokenizer.ex
index 59ab3468..599ac647 100644
--- a/lib/bumblebee/text/pre_trained_tokenizer.ex
+++ b/lib/bumblebee/text/pre_trained_tokenizer.ex
@@ -211,6 +211,12 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
         mask: "<mask>"
       }
     },
+    smollm3: %{
+      special_tokens: %{
+        eos: "<|im_end|>",
+        pad: "<|im_end|>"
+      }
+    },
     t5: %{
       special_tokens: %{
         bos: "<s>",
diff --git a/lib/bumblebee/text/smollm3.ex b/lib/bumblebee/text/smollm3.ex
new file mode 100644
index 00000000..873abb4c
--- /dev/null
+++ b/lib/bumblebee/text/smollm3.ex
@@ -0,0 +1,614 @@
+defmodule Bumblebee.Text.SmolLM3 do
+  alias Bumblebee.Shared
+
+  options =
+    [
+      vocab_size: [
+        default: 128_256,
+        doc: """
+        the vocabulary size of the token embedding. This corresponds to the number of distinct
+        tokens that can be represented in model input and output
+        """
+      ],
+      max_positions: [
+        default: 65536,
+        doc: """
+        the vocabulary size of the position embedding. This corresponds to the maximum sequence
+        length that this model can process. Typically this is set to a large value just in case,
+        such as 512, 1024 or 2048.
+        """
+      ],
+      hidden_size: [
+        default: 4096,
+        doc: "the dimensionality of hidden layers"
+      ],
+      intermediate_size: [
+        default: 11008,
+        doc: "the dimensionality of intermediate layers"
+      ],
+      attention_head_size: [
+        default: nil,
+        doc: """
+        the size of the key, value, and query projection per attention head.
+        Defaults to `div(hidden_size, num_attention_heads)`
+        """
+      ],
+      num_blocks: [
+        default: 32,
+        doc: "the number of Transformer blocks in the model"
+      ],
+      num_attention_heads: [
+        default: 32,
+        doc: "the number of attention heads for each attention layer in the model"
+      ],
+      num_key_value_heads: [
+        default: 4,
+        doc: "the number of key value heads for each attention layer in the model"
+      ],
+      activation: [
+        default: :silu,
+        doc: "the activation function"
+      ],
+      rotary_embedding_base: [
+        default: 5_000_000,
+        doc: "base for computing rotary embedding frequency"
+      ],
+      rotary_embedding_scaling_strategy: [
+        default: nil,
+        doc: """
+        scaling configuration for rotary embedding. Currently the supported values are:
+
+          * `%{type: :linear, factor: number()}`
+
+          * `%{type: :dynamic, factor: number()}`
+
+          * `%{type: :llama3, factor: number(), low_frequency_factor: number(), high_frequency_factor: number(), original_max_positions: pos_integer()}`
+
+        For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
+        """
+      ],
+      rotary_embedding_enabled: [
+        default: nil,
+        doc: """
+        a list of booleans specifying, for each block index, whether rotary embeddings are
+        enabled for that block. Defaults to `nil`, which enables rotary embeddings for all blocks
+        """
+      ],
+      layer_norm_epsilon: [
+        default: 1.0e-12,
+        doc: "the epsilon used by RMS normalization layers"
+      ],
+      initializer_scale: [
+        default: 0.02,
+        doc:
+          "the standard deviation of the normal initializer used for initializing kernel parameters"
+      ],
+      tie_word_embeddings: [
+        default: true,
+        doc: "whether to tie input and output embedding weights"
+      ]
+    ] ++
+      Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0)
+
+  @moduledoc """
+  SmolLM3 is a 3B parameter language model designed to push the boundaries of small models.
+  It supports dual mode reasoning, 6 languages and long context. SmolLM3 is a fully open model
+  that offers strong performance at the 3B–4B scale.
+
+  Key features:
+
+    * Instruct model optimized for hybrid reasoning
+    * Fully open model: open weights + full training details including public data mixture and training configs
+    * Long context: trained on 64k context and supports up to 128k tokens using YARN extrapolation (not implemented in `bumblebee`)
+    * Multilingual: 6 languages natively supported (English, French, Spanish, German, Italian, and Portuguese)
+
+  For best results, follow the [chat template](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja).
+  To disable reasoning, append `\\n\\n` to the prompt.
+
+  For more details see: https://huggingface.co/HuggingFaceTB/SmolLM3-3B
+
+  ## Architectures
+
+    * `:base` - plain SmolLM3 without any head on top
+
+    * `:for_causal_language_modeling` - SmolLM3 with a language modeling
+      head. The head returns logits for each token in the original
+      sequence
+
+    * `:for_sequence_classification` - SmolLM3 with a sequence
+      classification head. The head returns logits corresponding to
+      possible classes
+
+    * `:for_token_classification` - SmolLM3 with a token classification
+      head. The head returns logits for each token in the original
+      sequence
+
+    * `:for_question_answering` - SmolLM3 with a span classification head.
+      The head returns logits for the span start and end positions
+
+  ## Inputs
+
+    * `"input_ids"` - `{batch_size, sequence_length}`
+
+      Indices of input sequence tokens in the vocabulary.
+
+    * `"attention_mask"` - `{batch_size, sequence_length}`
+
+      Mask indicating which tokens to attend to. This is used to ignore
+      padding tokens, which are added when processing a batch of sequences
+      with different length.
+
+    * `"position_ids"` - `{batch_size, sequence_length}`
+
+      Indices of positions of each input sequence tokens in the position
+      embeddings.
+
+    * `"attention_head_mask"` - `{num_blocks, num_attention_heads}`
+
+      Mask to nullify selected heads of the self-attention blocks in
+      the model.
+ + * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}` + + Embedded representation of `"input_ids"`, which can be specified + for more control over how `"input_ids"` are embedded than the + model's internal embedding lookup. If `"input_embeddings"` are present, + then `"input_ids"` will be ignored. + + * `"cache"` + + A container with cached layer results used to speed up sequential + decoding (autoregression). With cache, certain hidden states are + taken from the cache, rather than recomputed on every decoding + pass. The cache should be treated as opaque and initialized with + `Bumblebee.Text.Generation.init_cache/4`. + + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification, + :for_token_classification, + :for_question_answering + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output", + use_bias: false + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_token_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + outputs.hidden_state + |> Axon.dropout( + rate: 0.1, + name: 
"token_classification_head.dropout" + ) + |> Axon.dense(spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "token_classification_head.output" + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_question_answering} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, 2, + kernel_initializer: kernel_initializer(spec), + name: "question_answering_head.output" + ) + + {start_logits, end_logits} = Layers.split_pair(logits) + + Layers.output(%{ + start_logits: start_logits, + end_logits: end_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + epsilon: spec.layer_norm_epsilon + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + rotary_embedding_config = [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + scaling_strategy: spec.rotary_embedding_scaling_strategy + ] + + rotary_embedding = + case opts[:rotary_embedding_enabled] do + nil -> + rotary_embedding_config + + rotary_embedding_enabled -> + fn layer_index -> + if Enum.at(rotary_embedding_enabled, layer_index) do + rotary_embedding_config + else + nil + end + end + end + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + attention_head_size: spec.attention_head_size, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), + ffn: + &gated_ffn(&1, 
+          name: &2,
+          activation: spec.activation
+        ),
+      block_type: :norm_first,
+      causal: true,
+      rotary_embedding: rotary_embedding,
+      query_use_bias: false,
+      key_use_bias: false,
+      value_use_bias: false,
+      output_use_bias: false,
+      name: join(name, "blocks")
+    )
+  end
+
+  defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do
+    name = opts[:name]
+    activation = opts[:activation]
+
+    intermediate =
+      Axon.dense(hidden_state, intermediate_size,
+        name: join(name, "intermediate"),
+        use_bias: false
+      )
+
+    gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false)
+
+    hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation))
+
+    Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false)
+  end
+
+  defp language_modeling_head(hidden_state, spec, opts) do
+    name = opts[:name]
+
+    # TODO: Tie lm-head to word embedding as a spec option
+    Layers.dense_transposed(hidden_state, spec.vocab_size,
+      kernel_initializer: kernel_initializer(spec),
+      name: join(name, "output")
+    )
+  end
+
+  defp kernel_initializer(spec) do
+    Axon.Initializers.normal(scale: spec.initializer_scale)
+  end
+
+  defimpl Bumblebee.HuggingFace.Transformers.Config do
+    def load(spec, data) do
+      import Shared.Converters
+
+      scaling_strategy_converter = fn name, value ->
+        # "type" has been renamed to "rope_type"
+        value =
+          case Map.pop(value, "type") do
+            {nil, value} -> value
+            {type, value} -> Map.put(value, "rope_type", type)
+          end
+
+        case value do
+          %{"rope_type" => "linear", "factor" => factor} when is_number(factor) ->
+            {:ok, %{type: :linear, factor: factor}}
+
+          %{"rope_type" => "dynamic", "factor" => factor} when is_number(factor) ->
+            {:ok, %{type: :dynamic, factor: factor}}
+
+          %{
+            "rope_type" => "llama3",
+            "factor" => factor,
+            "low_freq_factor" => low_frequency_factor,
+            "high_freq_factor" => high_frequency_factor,
+            "original_max_position_embeddings" => original_max_positions
+          }
+          when is_number(factor) and is_number(low_frequency_factor) and
+                 is_number(high_frequency_factor) and
+                 is_number(original_max_positions) ->
+            {:ok,
+             %{
+               type: :llama3,
+               factor: factor,
+               low_frequency_factor: low_frequency_factor,
+               high_frequency_factor: high_frequency_factor,
+               original_max_positions: original_max_positions
+             }}
+
+          _other ->
+            {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"}
+        end
+      end
+
+      rotary_embedding_enabled_converter = fn name, value ->
+        case value do
+          no_rope_layers when is_list(no_rope_layers) ->
+            {:ok, Enum.map(no_rope_layers, &(&1 == 1))}
+
+          _other ->
+            {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"}
+        end
+      end
+
+      opts =
+        convert!(data,
+          vocab_size: {"vocab_size", number()},
+          tie_word_embeddings: {"tie_word_embeddings", boolean()},
+          max_positions: {"max_position_embeddings", number()},
+          hidden_size: {"hidden_size", number()},
+          num_blocks: {"num_hidden_layers", number()},
+          num_attention_heads: {"num_attention_heads", number()},
+          num_key_value_heads: {"num_key_value_heads", number()},
+          attention_head_size: {"head_dim", number()},
+          intermediate_size: {"intermediate_size", number()},
+          activation: {"hidden_act", activation()},
+          rotary_embedding_base: {"rope_theta", number()},
+          rotary_embedding_scaling_strategy:
+            {"rope_scaling", optional(scaling_strategy_converter)},
+          rotary_embedding_enabled:
+            {"no_rope_layers", optional(rotary_embedding_enabled_converter)},
+          initializer_scale: {"initializer_range", number()},
+          layer_norm_epsilon: {"rms_norm_eps", number()}
+        ) ++ Shared.common_options_from_transformers(data, spec)
+
+      @for.config(spec, opts)
+    end
+  end
+
+  defimpl Bumblebee.HuggingFace.Transformers.Model do
+    def params_mapping(spec) do
+      base_mapping = %{
+        "embedder.token_embedding" => "model.embed_tokens",
+        "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj",
+        "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj",
+        "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj",
+        "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj",
+        "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm",
+        "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj",
+        "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj",
+        "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj",
+        "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm",
+        "output_norm" => "model.norm",
+        "language_modeling_head.output" =>
+          if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"),
+        "sequence_classification_head.output" => "score",
+        "token_classification_head.output" => "score",
+        "question_answering_head.output" => "qa_outputs"
+      }
+
+      rotary_mapping =
+        case spec.rotary_embedding_enabled do
+          nil ->
+            []
+
+          rotary_embedding_enabled ->
+            for {rope, index} <- Enum.with_index(rotary_embedding_enabled), rope do
+              {"decoder.blocks.#{index}.self_attention.rotary_embedding",
+               "model.layers.#{index}.self_attn.rotary_emb"}
+            end
+        end
+
+      mapping = Map.merge(base_mapping, Map.new(rotary_mapping))
+
+      case spec do
+        %{architecture: :for_question_answering} ->
+          for {key, value} <- mapping, into: %{} do
+            {key, String.replace_leading(value, "model.", "transformer.")}
+          end
+
+        _else ->
+          mapping
+      end
+    end
+  end
+end
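For reference, this is how the `"no_rope_layers"` entry from the Hugging Face configuration ends up as `rotary_embedding_enabled`, following the converter above (the example values are hypothetical): a `1` marks a block that uses rotary embedding, a `0` marks a NoPE block.

```elixir
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]

rotary_embedding_enabled = Enum.map(no_rope_layers, &(&1 == 1))
#=> [true, true, true, false, true, true, true, false]
```

The decoder then wraps this list in the per-block function shown earlier, so rotary embedding is skipped for the `false` entries.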
{"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + base_mapping = %{ + "embedder.token_embedding" => "model.embed_tokens", + "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", + "output_norm" => "model.norm", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), + "sequence_classification_head.output" => "score", + "token_classification_head.output" => "score", + "question_answering_head.output" => "qa_outputs" + } + + rotary_mapping = + case spec.rotary_embedding_enabled do + nil -> + [] + + rotary_embedding_enabled -> + Enum.with_index(rotary_embedding_enabled, fn rope, index -> + if rope do + {"decoder.blocks.#{index}.self_attention.rotary_embedding", + "model.layers.#{index}.self_attn.rotary_emb"} + end + end) + end + + mapping = Map.merge(base_mapping, Map.new(rotary_mapping)) + + case spec do + %{architecture: :for_question_answering} -> + for {key, value} <- mapping, into: %{} do + {key, String.replace_leading(value, "model.", "transformer.")} + end + + _else -> + mapping + end + end + end +end diff --git a/mix.exs b/mix.exs index 23e8f8e8..63804d50 100644 --- a/mix.exs +++ b/mix.exs @@ -102,6 +102,7 @@ defmodule Bumblebee.MixProject do Bumblebee.Text.Phi, Bumblebee.Text.Phi3, Bumblebee.Text.Roberta, + Bumblebee.Text.SmolLM3, Bumblebee.Text.T5, Bumblebee.Vision.BlipVision, Bumblebee.Vision.ClipVision, diff --git a/test/bumblebee/text/pre_trained_tokenizer_test.exs b/test/bumblebee/text/pre_trained_tokenizer_test.exs index 19537fdc..ab0e891c 100644 --- a/test/bumblebee/text/pre_trained_tokenizer_test.exs +++ b/test/bumblebee/text/pre_trained_tokenizer_test.exs @@ -416,6 +416,36 @@ defmodule Bumblebee.Text.PreTrainedTokenizerTest do ) end + test ":smollm3" do + assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "HuggingFaceTB/SmolLM3-3B"}) + + assert %Bumblebee.Text.PreTrainedTokenizer{type: :smollm3} = tokenizer + + inputs = + Bumblebee.apply_tokenizer(tokenizer, [ + "Test sentence with .", + {"Question?", "Answer"} + ]) + + assert_equal( + inputs["input_ids"], + Nx.tensor([ + [2323, 11914, 449, 366, 11508, 14611], + [14924, 30, 16533, 128_012, 128_012, 128_012] + ]) + ) + + assert_equal( + inputs["attention_mask"], + Nx.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]) + ) + + assert_equal( + inputs["token_type_ids"], + Nx.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]]) + ) + end + test ":t5" do assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-t5/t5-small"}) diff --git a/test/bumblebee/text/smollm3_test.exs b/test/bumblebee/text/smollm3_test.exs new file mode 100644 index 
diff --git a/test/bumblebee/text/smollm3_test.exs b/test/bumblebee/text/smollm3_test.exs
new file mode 100644
index 00000000..cde4f5de
--- /dev/null
+++ b/test/bumblebee/text/smollm3_test.exs
@@ -0,0 +1,163 @@
+defmodule Bumblebee.Text.SmolLM3Test do
+  use ExUnit.Case, async: true
+
+  import Bumblebee.TestHelpers
+
+  @moduletag model_test_tags()
+
+  test ":base" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-SmolLM3Model"},
+               architecture: :base
+             )
+
+    assert %Bumblebee.Text.SmolLM3{architecture: :base} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
+
+    assert_all_close(
+      outputs.hidden_state[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [
+          [0.2562, -0.4248, -0.1371],
+          [-0.8060, -0.1415, 0.3646],
+          [-0.4071, -1.0187, -1.1379]
+        ]
+      ])
+    )
+  end
+
+  test ":for_question_answering" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "bumblebee-testing/tiny-random-SmolLM3ForQuestionAnswering"},
+               architecture: :for_question_answering
+             )
+
+    assert %Bumblebee.Text.SmolLM3{architecture: :for_question_answering} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.end_logits) == {1, 10}
+
+    assert_all_close(
+      outputs.end_logits,
+      Nx.tensor([
+        [
+          0.1937,
+          -0.0345,
+          0.0913,
+          -0.0821,
+          -0.0658,
+          -0.0438,
+          -0.0525,
+          -0.0771,
+          -0.1270,
+          -0.1270
+        ]
+      ])
+    )
+  end
+
+  test ":for_sequence_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "bumblebee-testing/tiny-random-SmolLM3ForSequenceClassification"},
+               architecture: :for_sequence_classification
+             )
+
+    assert %Bumblebee.Text.SmolLM3{architecture: :for_sequence_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 2}
+
+    assert_all_close(
+      outputs.logits,
+      Nx.tensor([[-0.0567, 0.0249]])
+    )
+  end
+
+  test ":for_causal_language_modeling" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-SmolLM3ForCausalLM"},
+               architecture: :for_causal_language_modeling
+             )
+
+    assert %Bumblebee.Text.SmolLM3{architecture: :for_causal_language_modeling} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 10, 1024}
+
+    assert_all_close(
+      outputs.logits[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [
+          [-0.0438, 0.2976, 0.1326],
+          [0.0285, 0.0493, 0.0535],
+          [0.0457, 0.2303, 0.0854]
+        ]
+      ])
+    )
+  end
+
+  test ":for_token_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "bumblebee-testing/tiny-random-SmolLM3ForTokenClassification"},
+               architecture: :for_token_classification
+             )
+
+    assert %Bumblebee.Text.SmolLM3{architecture: :for_token_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([ + [ + [-0.0053, -0.0636], + [-0.1258, 0.2581], + [-0.0500, 0.0485], + [-0.1136, -0.0659], + [0.0423, 0.1303], + [0.0800, 0.0743], + [-0.1378, 0.0709], + [-0.0322, 0.1488], + [-0.0916, -0.0296], + [-0.0917, -0.0293] + ] + ]) + ) + end +end