diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex
index 27adc251..5fd341d4 100644
--- a/lib/bumblebee/text.ex
+++ b/lib/bumblebee/text.ex
@@ -272,7 +272,67 @@ defmodule Bumblebee.Text do
   @spec fill_mask(Bumblebee.model_info(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()
   defdelegate fill_mask(model_info, tokenizer, opts \\ []), to: Bumblebee.Text.FillMask
 
+  @type question_answering_input :: %{question: String.t(), context: String.t()}
+
+  @type question_answering_output :: %{
+          results: list(question_answering_result())
+        }
+
+  @type question_answering_result :: %{
+          text: String.t(),
+          start: number(),
+          end: number(),
+          score: number()
+        }
+
+  @doc """
+  Builds serving for the question answering task.
+
+  The serving accepts `t:question_answering_input/0` and returns
+  `t:question_answering_output/0`. A list of inputs is also supported.
+
+  The question answering task finds the most probable answer to a
+  question within the given context text.
+
+  ## Options
+
+    * `:compile` - compiles all computations for predefined input shapes
+      during serving initialization. Should be a keyword list with the
+      following keys:
+
+      * `:batch_size` - the maximum batch size of the input. Inputs
+        are optionally padded to always match this batch size
+
+      * `:sequence_length` - the maximum input sequence length. Input
+        sequences are always padded/truncated to match that length
+
+      It is advised to set this option in production and also configure
+      a defn compiler using `:defn_options` to maximally reduce inference
+      time.
+
+    * `:defn_options` - the options for JIT compilation. Defaults to `[]`
+
+  ## Examples
+
+      {:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})
+
+      serving = Bumblebee.Text.question_answering(roberta, tokenizer)
+
+      input = %{question: "What's my name?", context: "My name is Sarah and I live in London."}
+      Nx.Serving.run(serving, input)
+      #=> %{results: [%{end: 16, score: 0.81039959192276, start: 11, text: "Sarah"}]}
+
+  """
+  @spec question_answering(
+          Bumblebee.model_info(),
+          Bumblebee.Tokenizer.t(),
+          keyword()
+        ) :: Nx.Serving.t()
+  defdelegate question_answering(model_info, tokenizer, opts \\ []),
+    to: Bumblebee.Text.QuestionAnswering
+
   @type zero_shot_classification_input :: String.t()
 
   @type zero_shot_classification_output :: %{
           predictions: list(zero_shot_classification_prediction())
@@ -282,7 +342,8 @@ defmodule Bumblebee.Text do
   Builds serving for the zero-shot classification task.
 
   The serving accepts `t:zero_shot_classification_input/0` and returns
-  `t:zero_shot_classification_output/0`.
+  `t:zero_shot_classification_output/0`. A list of inputs is also
+  supported.
 
   The zero-shot task predicts zero-shot labels for a given sequence
   by proposing each label as a premise-hypothesis pairing.
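The docs above advise setting `:compile` and a defn compiler in production; the implementation and tests below only exercise the defaults. A minimal configuration sketch, outside the patch itself: the EXLA compiler and the batch/sequence sizes are illustrative assumptions, not part of this change.

```elixir
# Illustrative only: EXLA and the chosen sizes are assumptions, not part of the patch.
{:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})

# Compiling for fixed shapes up front avoids per-request JIT compilation.
serving =
  Bumblebee.Text.question_answering(roberta, tokenizer,
    compile: [batch_size: 8, sequence_length: 384],
    defn_options: [compiler: EXLA]
  )

Nx.Serving.run(serving, %{
  question: "What's my name?",
  context: "My name is Sarah and I live in London."
})
```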
diff --git a/lib/bumblebee/text/question_answering.ex b/lib/bumblebee/text/question_answering.ex
new file mode 100644
index 00000000..6047a567
--- /dev/null
+++ b/lib/bumblebee/text/question_answering.ex
@@ -0,0 +1,103 @@
+defmodule Bumblebee.Text.QuestionAnswering do
+  @moduledoc false
+
+  alias Bumblebee.Utils
+  alias Bumblebee.Shared
+
+  def question_answering(model_info, tokenizer, opts \\ []) do
+    %{model: model, params: params, spec: spec} = model_info
+    Shared.validate_architecture!(spec, :for_question_answering)
+
+    opts = Keyword.validate!(opts, [:compile, defn_options: []])
+
+    compile = opts[:compile]
+    defn_options = opts[:defn_options]
+
+    batch_size = compile[:batch_size]
+    sequence_length = compile[:sequence_length]
+
+    if compile != nil and (batch_size == nil or sequence_length == nil) do
+      raise ArgumentError,
+            "expected :compile to be a keyword list specifying :batch_size and :sequence_length, got: #{inspect(compile)}"
+    end
+
+    {_init_fun, predict_fun} = Axon.build(model)
+
+    scores_fun = fn params, input ->
+      outputs = predict_fun.(params, input)
+      start_scores = Axon.Activations.softmax(outputs.start_logits)
+      end_scores = Axon.Activations.softmax(outputs.end_logits)
+      %{start_scores: start_scores, end_scores: end_scores}
+    end
+
+    Nx.Serving.new(
+      fn ->
+        predict_fun =
+          Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
+            inputs = %{
+              "input_ids" => Nx.template({batch_size, sequence_length}, :s64),
+              "attention_mask" => Nx.template({batch_size, sequence_length}, :s64),
+              "token_type_ids" => Nx.template({batch_size, sequence_length}, :s64)
+            }
+
+            [params, inputs]
+          end)
+
+        fn inputs ->
+          inputs = Shared.maybe_pad(inputs, batch_size)
+
+          predict_fun.(params, inputs)
+        end
+      end,
+      batch_size: batch_size
+    )
+    |> Nx.Serving.client_preprocessing(fn raw_input ->
+      {raw_inputs, multi?} =
+        Shared.validate_serving_input!(raw_input, fn
+          %{question: question, context: context}
+          when is_binary(question) and is_binary(context) ->
+            {:ok, {question, context}}
+
+          other ->
+            {:error,
+             "expected input map with :question and :context keys, got: #{inspect(other)}"}
+        end)
+
+      all_inputs =
+        Bumblebee.apply_tokenizer(tokenizer, raw_inputs,
+          length: sequence_length,
+          return_token_type_ids: true,
+          return_offsets: true
+        )
+
+      inputs = Map.take(all_inputs, ["input_ids", "attention_mask", "token_type_ids"])
+      {Nx.Batch.concatenate([inputs]), {all_inputs, raw_inputs, multi?}}
+    end)
+    |> Nx.Serving.client_postprocessing(fn outputs, _metadata, {inputs, raw_inputs, multi?} ->
+      Enum.zip_with(
+        [raw_inputs, Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(outputs)],
+        fn [{_question_text, context_text}, inputs, outputs] ->
+          start_idx = outputs.start_scores |> Nx.argmax() |> Nx.to_number()
+          end_idx = outputs.end_scores |> Nx.argmax() |> Nx.to_number()
+
+          start = Nx.to_number(inputs["start_offsets"][start_idx])
+          ending = Nx.to_number(inputs["end_offsets"][end_idx])
+
+          score =
+            outputs.start_scores[start_idx]
+            |> Nx.multiply(outputs.end_scores[end_idx])
+            |> Nx.to_number()
+
+          answer_text = binary_part(context_text, start, ending - start)
+
+          results = [
+            %{text: answer_text, start: start, end: ending, score: score}
+          ]
+
+          %{results: results}
+        end
+      )
+      |> Shared.normalize_output(multi?)
+    end)
+  end
+end
diff --git a/notebooks/examples.livemd b/notebooks/examples.livemd
index 41a097fd..0572a062 100644
--- a/notebooks/examples.livemd
+++ b/notebooks/examples.livemd
@@ -199,17 +199,9 @@ context_input =
 question = Kino.Input.read(question_input)
 context = Kino.Input.read(context_input)
 
-inputs = Bumblebee.apply_tokenizer(tokenizer, {question, context})
+serving = Bumblebee.Text.question_answering(roberta, tokenizer)
 
-outputs = Axon.predict(roberta.model, roberta.params, inputs)
-
-answer_start_index = outputs.start_logits |> Nx.argmax() |> Nx.to_number()
-answer_end_index = outputs.end_logits |> Nx.argmax() |> Nx.to_number()
-
-answer_tokens =
-  inputs["input_ids"][[0, answer_start_index..answer_end_index]] |> Nx.to_flat_list()
-
-Bumblebee.Tokenizer.decode(tokenizer, answer_tokens)
+Nx.Serving.run(serving, %{question: question, context: context})
 ```
 
 ## Final notes
diff --git a/test/bumblebee/text/question_answering_test.ex b/test/bumblebee/text/question_answering_test.ex
new file mode 100644
index 00000000..7b128539
--- /dev/null
+++ b/test/bumblebee/text/question_answering_test.ex
@@ -0,0 +1,48 @@
+defmodule Bumblebee.Text.QuestionAnsweringTest do
+  use ExUnit.Case, async: false
+
+  import Bumblebee.TestHelpers
+
+  @moduletag model_test_tags()
+
+  describe "integration" do
+    test "returns the most probable answer" do
+      {:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})
+
+      serving = Bumblebee.Text.question_answering(roberta, tokenizer)
+
+      input = %{question: "What's my name?", context: "My name is Sarah and I live in London."}
+
+      assert %{
+               results: [
+                 %{
+                   text: "Sarah",
+                   start: 11,
+                   end: 16,
+                   score: score
+                 }
+               ]
+             } = Nx.Serving.run(serving, input)
+
+      assert_all_close(score, 0.8105)
+    end
+
+    test "supports multiple inputs" do
+      {:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})
+
+      serving = Bumblebee.Text.question_answering(roberta, tokenizer)
+
+      inputs = [
+        %{question: "What's my name?", context: "My name is Sarah and I live in London."},
+        %{question: "Where do I live?", context: "My name is Clara and I live in Berkeley."}
+      ]
+
+      assert [
+               %{results: [%{text: "Sarah", start: 11, end: 16, score: _}]},
+               %{results: [%{text: "Berkeley", start: 31, end: 39, score: _}]}
+             ] = Nx.Serving.run(serving, inputs)
+    end
+  end
+end
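The tests above call `Nx.Serving.run/2` inline. In an application, the same serving can instead run as a named, supervised process and be called with `Nx.Serving.batched_run/2` so concurrent requests share batches. A sketch, assuming `Nx.Serving`'s standard child spec; the `Bumblebee.QAServing` name is hypothetical.

```elixir
# Sketch: the question-answering serving as a supervised process.
# Bumblebee.QAServing is a hypothetical process name.
serving = Bumblebee.Text.question_answering(roberta, tokenizer)

children = [
  {Nx.Serving, serving: serving, name: Bumblebee.QAServing, batch_timeout: 100}
]

{:ok, _pid} = Supervisor.start_link(children, strategy: :one_for_one)

# Requests from any process are batched together by the serving process.
Nx.Serving.batched_run(Bumblebee.QAServing, %{
  question: "Where do I live?",
  context: "My name is Clara and I live in Berkeley."
})
```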