Add question answering serving #157

Merged · 11 commits · Feb 6, 2023
65 changes: 64 additions & 1 deletion lib/bumblebee/text.ex
@@ -272,6 +272,68 @@ defmodule Bumblebee.Text do
@spec fill_mask(Bumblebee.model_info(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()
defdelegate fill_mask(model_info, tokenizer, opts \\ []), to: Bumblebee.Text.FillMask

@type question_answering_input :: %{question: String.t(), context: String.t()}

@type question_answering_output :: %{
results: list(question_answering_result())
}

@type question_answering_result :: %{
text: String.t(),
start: number(),
end: number(),
score: number()
}
@doc """
Builds serving for the question answering task.

The serving accepts `t:question_answering_input/0` and returns
`t:question_answering_output/0`. A list of inputs is also supported.

The question answering task finds the most probable answer to a
question within the given context text.

## Options

* `:compile` - compiles all computations for predefined input shapes
during serving initialization. Should be a keyword list with the
following keys:

* `:batch_size` - the maximum batch size of the input. Inputs
are optionally padded to always match this batch size

* `:sequence_length` - the maximum input sequence length. Input
sequences are always padded/truncated to match that length

It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
time.

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

## Examples

{:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})

serving = Bumblebee.Text.question_answering(roberta, tokenizer)

input = %{question: "What's my name?", context: "My name is Sarah and I live in London."}
Nx.Serving.run(serving, input)
#=> %{results: [%{end: 16, score: 0.81039959192276, start: 11, text: "Sarah"}]}

"""
@spec question_answering(
Bumblebee.model_info(),
Bumblebee.Tokenizer.t(),
keyword()
) :: Nx.Serving.t()
defdelegate question_answering(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.QuestionAnswering

@type zero_shot_classification_input :: String.t()
@type zero_shot_classification_output :: %{
predictions: list(zero_shot_classification_prediction())
@@ -282,7 +344,8 @@ defmodule Bumblebee.Text do
Builds serving for the zero-shot classification task.

The serving accepts `t:zero_shot_classification_input/0` and returns
`t:zero_shot_classification_output/0`.
`t:zero_shot_classification_output/0`. A list of inputs is also
supported.

The zero-shot task predicts zero-shot labels for a given sequence by
proposing each label as a premise-hypothesis pairing.
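
For reference, the `:compile` and `:defn_options` options documented above for the question answering serving can be combined roughly like this. This is a sketch, assuming the EXLA compiler is available in the project; the batch size and sequence length values are illustrative:

```elixir
{:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})

# Compile the computation ahead of time for a fixed batch size and
# sequence length, so no compilation happens at inference time.
serving =
  Bumblebee.Text.question_answering(roberta, tokenizer,
    compile: [batch_size: 1, sequence_length: 384],
    defn_options: [compiler: EXLA]
  )

input = %{question: "What's my name?", context: "My name is Sarah and I live in London."}
Nx.Serving.run(serving, input)
```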
103 changes: 103 additions & 0 deletions lib/bumblebee/text/question_answering.ex
@@ -0,0 +1,103 @@
defmodule Bumblebee.Text.QuestionAnswering do
@moduledoc false

alias Bumblebee.Utils
alias Bumblebee.Shared

def question_answering(model_info, tokenizer, opts \\ []) do
%{model: model, params: params, spec: spec} = model_info
Shared.validate_architecture!(spec, :for_question_answering)

opts = Keyword.validate!(opts, [:compile, defn_options: []])

compile = opts[:compile]
defn_options = opts[:defn_options]

batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

if compile != nil and (batch_size == nil or sequence_length == nil) do
raise ArgumentError,
"expected :compile to be a keyword list specifying :batch_size and :sequence_length, got: #{inspect(compile)}"
end

{_init_fun, predict_fun} = Axon.build(model)

scores_fun = fn params, input ->
outputs = predict_fun.(params, input)
start_scores = Axon.Activations.softmax(outputs.start_logits)
end_scores = Axon.Activations.softmax(outputs.end_logits)
%{start_scores: start_scores, end_scores: end_scores}
end

Nx.Serving.new(
fn ->
predict_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :s64),
"attention_mask" => Nx.template({batch_size, sequence_length}, :s64),
"token_type_ids" => Nx.template({batch_size, sequence_length}, :s64)
}

[params, inputs]
end)

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)

predict_fun.(params, inputs)
end
end,
batch_size: batch_size
)
|> Nx.Serving.client_preprocessing(fn raw_input ->
{raw_inputs, multi?} =
Shared.validate_serving_input!(raw_input, fn
%{question: question, context: context}
when is_binary(question) and is_binary(context) ->
{:ok, {question, context}}

other ->
{:error,
"expected input map with :question and :context keys, got: #{inspect(other)}"}
end)

all_inputs =
Bumblebee.apply_tokenizer(tokenizer, raw_inputs,
length: sequence_length,
return_token_type_ids: true,
return_offsets: true
)

inputs = Map.take(all_inputs, ["input_ids", "attention_mask", "token_type_ids"])
{Nx.Batch.concatenate([inputs]), {all_inputs, raw_inputs, multi?}}
end)
|> Nx.Serving.client_postprocessing(fn outputs, _metadata, {inputs, raw_inputs, multi?} ->
Enum.zip_with(
[raw_inputs, Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(outputs)],
fn [{_question_text, context_text}, inputs, outputs] ->
start_idx = outputs.start_scores |> Nx.argmax() |> Nx.to_number()
end_idx = outputs.end_scores |> Nx.argmax() |> Nx.to_number()

start = Nx.to_number(inputs["start_offsets"][start_idx])
ending = Nx.to_number(inputs["end_offsets"][end_idx])

score =
outputs.start_scores[start_idx]
|> Nx.multiply(outputs.end_scores[end_idx])
|> Nx.to_number()

answer_text = binary_part(context_text, start, ending - start)

results = [
%{text: answer_text, start: start, end: ending, score: score}
]

%{results: results}
end
)
|> Shared.normalize_output(multi?)
end)
end
end
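
To make the postprocessing above easier to follow, here is a minimal sketch of the span extraction on hand-written stand-in data. The offsets and scores are made up for illustration; in the serving they come from the tokenizer (via `return_offsets: true`) and from the softmaxed model logits:

```elixir
context = "My name is Sarah and I live in London."

# Hypothetical character offsets for each token of the context.
start_offsets = Nx.tensor([0, 3, 8, 11, 17])
end_offsets = Nx.tensor([2, 7, 10, 16, 20])

# Pretend the model put most probability mass on token 3 for both ends.
start_scores = Nx.tensor([0.01, 0.02, 0.05, 0.90, 0.02])
end_scores = Nx.tensor([0.02, 0.03, 0.04, 0.88, 0.03])

# Pick the most probable start and end tokens.
start_idx = start_scores |> Nx.argmax() |> Nx.to_number()
end_idx = end_scores |> Nx.argmax() |> Nx.to_number()

# Map token indices back to character positions in the context.
start = Nx.to_number(start_offsets[start_idx])
ending = Nx.to_number(end_offsets[end_idx])

# The answer score is the product of the start and end probabilities.
score =
  start_scores[start_idx]
  |> Nx.multiply(end_scores[end_idx])
  |> Nx.to_number()

binary_part(context, start, ending - start)
#=> "Sarah"
```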
12 changes: 2 additions & 10 deletions notebooks/examples.livemd
@@ -199,17 +199,9 @@ context_input =
question = Kino.Input.read(question_input)
context = Kino.Input.read(context_input)

-inputs = Bumblebee.apply_tokenizer(tokenizer, {question, context})
+serving = Bumblebee.Text.question_answering(roberta, tokenizer)

-outputs = Axon.predict(roberta.model, roberta.params, inputs)
-
-answer_start_index = outputs.start_logits |> Nx.argmax() |> Nx.to_number()
-answer_end_index = outputs.end_logits |> Nx.argmax() |> Nx.to_number()
-
-answer_tokens =
-  inputs["input_ids"][[0, answer_start_index..answer_end_index]] |> Nx.to_flat_list()
-
-Bumblebee.Tokenizer.decode(tokenizer, answer_tokens)
+Nx.Serving.run(serving, %{question: question, context: context})
```

## Final notes
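
Outside a notebook, the same serving would typically be started under the application's supervision tree and invoked with `Nx.Serving.batched_run/2`. A rough sketch follows; the process name and timeout are purely illustrative:

```elixir
# In the application supervision tree
children = [
  {Nx.Serving,
   serving: Bumblebee.Text.question_answering(roberta, tokenizer),
   name: MyApp.QuestionAnswering,
   batch_timeout: 100}
]

Supervisor.start_link(children, strategy: :one_for_one)

# Later, from any process in the application
Nx.Serving.batched_run(
  MyApp.QuestionAnswering,
  %{question: "Where do I live?", context: "My name is Clara and I live in Berkeley."}
)
```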
48 changes: 48 additions & 0 deletions test/bumblebee/text/question_answering_test.ex
@@ -0,0 +1,48 @@
defmodule Bumblebee.Text.QuestionAnsweringTest do
use ExUnit.Case, async: false

import Bumblebee.TestHelpers

@moduletag model_test_tags()

describe "integration" do
test "returns the most probable answer" do
{:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})

serving = Bumblebee.Text.question_answering(roberta, tokenizer)

input = %{question: "What's my name?", context: "My name is Sarah and I live in London."}

assert %{
results: [
%{
text: "Sarah",
start: 11,
end: 16,
score: score
}
]
} = Nx.Serving.run(serving, input)

assert_all_close(score, 0.8105)
end

test "supports multiple inputs" do
{:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})

serving = Bumblebee.Text.question_answering(roberta, tokenizer)

inputs = [
%{question: "What's my name?", context: "My name is Sarah and I live in London."},
%{question: "Where do I live?", context: "My name is Clara and I live in Berkeley."}
]

assert [
%{results: [%{text: "Sarah", start: 11, end: 16, score: _}]},
%{results: [%{text: "Berkeley", start: 31, end: 39, score: _}]}
] = Nx.Serving.run(serving, inputs)
end
end
end