Add question answering serving #157

Merged: 11 commits, Feb 6, 2023
Changes from 8 commits
27 changes: 27 additions & 0 deletions lib/bumblebee/text.ex
@@ -272,6 +272,33 @@ defmodule Bumblebee.Text do
  @spec fill_mask(Bumblebee.model_info(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()
  defdelegate fill_mask(model_info, tokenizer, opts \\ []), to: Bumblebee.Text.FillMask

  @type question_answering_input :: %{question: String.t(), context: String.t()}

  @type question_answering_output :: %{
          results: list(question_answering_result())
        }

  @type question_answering_result :: %{
          text: String.t(),
          start: number(),
          end: number(),
          score: number()
        }
@doc """

Builds serving for the question answering task.

The serving accepts `t:question_answering_input/0` and returns `t:question_answering_ouput/0`.

The question answering task predicts an answer for a question given some context.

"""
  @spec question_answering(
          Bumblebee.model_info(),
          Bumblebee.Tokenizer.t(),
          keyword()
        ) :: Nx.Serving.t()
  defdelegate question_answering(model_info, tokenizer, opts \\ []),
    to: Bumblebee.Text.QuestionAnswering

  @type zero_shot_classification_input :: String.t()
  @type zero_shot_classification_output :: %{
          predictions: list(zero_shot_classification_prediction())
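The new delegate is the whole public surface of this PR. A minimal usage sketch, reusing the same checkpoints as the integration test below (any `:for_question_answering` checkpoint should work):

```elixir
# Build the serving once, then run it with a question/context map.
{:ok, roberta} =
  Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"},
    architecture: :for_question_answering
  )

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})

serving = Bumblebee.Text.question_answering(roberta, tokenizer)

Nx.Serving.run(serving, %{
  question: "What is my name",
  context: "My name is blackeuler"
})
#=> %{results: [%{text: " blackeuler", start: 11, end: 21, score: _}]}
```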
116 changes: 116 additions & 0 deletions lib/bumblebee/text/question_answering.ex
@@ -0,0 +1,116 @@
defmodule Bumblebee.Text.QuestionAnswering do
  @moduledoc false

  alias Bumblebee.Shared
  alias Bumblebee.Utils

  def question_answering(model_info, tokenizer, opts \\ []) do
    %{model: model, params: params, spec: spec} = model_info

    Shared.validate_architecture!(spec, :for_question_answering)

    opts =
      Keyword.validate!(opts, [
        :compile,
        doc_stride: 128,
        top_k: 1,
        defn_options: []
      ])

    top_k = opts[:top_k]
    compile = opts[:compile]
    defn_options = opts[:defn_options]
    doc_stride = opts[:doc_stride]

    batch_size = compile[:batch_size]
    sequence_length = compile[:sequence_length]

    if compile != nil and (batch_size == nil or sequence_length == nil) do
      raise ArgumentError,
            "expected :compile to be a keyword list specifying :batch_size and :sequence_length, got: #{inspect(compile)}"
    end
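    # When :compile is given, both :batch_size and :sequence_length must be
    # known up front (e.g. compile: [batch_size: 1, sequence_length: 32], as in
    # the integration test) so the computation can be compiled ahead of time.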

    {_init_fun, predict_fun} = Axon.build(model)

    scores_fun = fn params, input ->
      predict_fun.(params, input)
    end

    Nx.Serving.new(
      fn ->
        predict_fun =
          Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
            inputs = %{
              "input_ids" => Nx.template({batch_size, sequence_length}, :s64),
              "attention_mask" => Nx.template({batch_size, sequence_length}, :s64)
            }

            [params, inputs]
          end)

        fn inputs ->
          inputs = Shared.maybe_pad(inputs, batch_size)
          predict_fun.(params, inputs)
        end
      end,
      batch_size: batch_size
    )
    |> Nx.Serving.client_preprocessing(fn input ->
      {inputs, multi?} =
        Shared.validate_serving_input!(input, fn
          %{question: question, context: context}
          when is_binary(question) and is_binary(context) ->
            {:ok, {question, context}}

          other ->
            {:error,
             "expected input map with :question and :context keys, got: #{inspect(other)}"}
        end)

      all_inputs =
        Bumblebee.apply_tokenizer(tokenizer, inputs,
          length: sequence_length,
          return_special_tokens_mask: true,
          return_offsets: true
        )

      inputs = Map.take(all_inputs, ["input_ids", "attention_mask"])
      {Nx.Batch.concatenate([inputs]), {all_inputs, multi?}}
    end)
    |> Nx.Serving.client_postprocessing(fn outputs, _metadata, {inputs, _multi?} ->
      %{
        results:
          Enum.zip_with(
            Utils.Nx.batch_to_list(inputs),
            Utils.Nx.batch_to_list(outputs),
            fn inputs, outputs ->
              start_scores = Axon.Activations.softmax(outputs.start_logits)
              end_scores = Axon.Activations.softmax(outputs.end_logits)

              answer_start_index = start_scores |> Nx.argmax() |> Nx.to_number()
              answer_end_index = end_scores |> Nx.argmax() |> Nx.to_number()

              start = inputs["start_offsets"][answer_start_index] |> Nx.to_number()
              ending = inputs["end_offsets"][answer_end_index] |> Nx.to_number()
              answer_tokens = inputs["input_ids"][answer_start_index..answer_end_index]

              answer = Bumblebee.Tokenizer.decode(tokenizer, answer_tokens)

              # The answer score is the joint probability of the selected
              # start and end positions, not the product of their indices
              score =
                Nx.to_number(start_scores[answer_start_index]) *
                  Nx.to_number(end_scores[answer_end_index])

              %{
                text: answer,
                start: start,
                end: ending,
                score: score
              }
            end
          )
      }
    end)
  end
end
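The span selection above reduces to an argmax over each of the start/end probability distributions. A standalone toy sketch of that step, with made-up logit values standing in for `outputs.start_logits` / `outputs.end_logits` of a single sequence:

```elixir
# Toy span selection: argmax picks the token indices bounding the answer,
# and the score is the joint probability of those two boundary choices.
start_logits = Nx.tensor([0.1, 4.0, 0.2, 0.1])
end_logits = Nx.tensor([0.1, 0.3, 3.5, 0.2])

start_scores = Axon.Activations.softmax(start_logits)
end_scores = Axon.Activations.softmax(end_logits)

start_index = start_scores |> Nx.argmax() |> Nx.to_number()
end_index = end_scores |> Nx.argmax() |> Nx.to_number()

score =
  Nx.to_number(start_scores[start_index]) * Nx.to_number(end_scores[end_index])

{start_index, end_index, score}
#=> {1, 2, score}
```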
14 changes: 4 additions & 10 deletions notebooks/examples.livemd
@@ -181,6 +181,8 @@ Another text-related task is question answering, where the objective is to retri
:ok
```



```elixir
question_input =
  Kino.Input.text("Question",

@@ -199,17 +201,9 @@ context_input =
 question = Kino.Input.read(question_input)
 context = Kino.Input.read(context_input)

-inputs = Bumblebee.apply_tokenizer(tokenizer, {question, context})
-
-outputs = Axon.predict(roberta.model, roberta.params, inputs)
-
-answer_start_index = outputs.start_logits |> Nx.argmax() |> Nx.to_number()
-answer_end_index = outputs.end_logits |> Nx.argmax() |> Nx.to_number()
-
-answer_tokens =
-  inputs["input_ids"][[0, answer_start_index..answer_end_index]] |> Nx.to_flat_list()
+serving = Bumblebee.Text.question_answering(roberta, tokenizer)

-Bumblebee.Tokenizer.decode(tokenizer, answer_tokens)
+Nx.Serving.run(serving, %{question: question, context: context})
```

## Final notes
40 changes: 40 additions & 0 deletions test/bumblebee/text/question_answering_test.ex
@@ -0,0 +1,40 @@
defmodule Bumblebee.Text.QuestionAnsweringTest do
  use ExUnit.Case, async: false

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  describe "integration" do
    test "returns top scored labels" do
      {:ok, roberta} =
        Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"},
          architecture: :for_question_answering
        )

      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "roberta-base"})

      serving =
        Bumblebee.Text.question_answering(roberta, tokenizer,
          compile: [batch_size: 1, sequence_length: 32],
          defn_options: [compiler: EXLA]
        )

      text_and_context = %{
        question: "What is my name",
        context: "My name is blackeuler"
      }

      assert %{
               results: [
                 %{
                   text: " blackeuler",
                   start: 11,
                   end: 21,
                   score: _score
                 }
               ]
             } = Nx.Serving.run(serving, text_and_context)
    end
  end
end