feat(router): add /tokenize route (#139)
OlivierDehaene committed Jan 26, 2024
1 parent 3d9cc20 commit 6395a7a
Showing 9 changed files with 556 additions and 38 deletions.
18 changes: 17 additions & 1 deletion core/src/infer.rs
@@ -1,5 +1,5 @@
use crate::queue::{Entry, Metadata, NextBatch, Queue};
use crate::tokenization::{EncodingInput, Tokenization};
use crate::tokenization::{EncodingInput, RawEncoding, Tokenization};
use crate::TextEmbeddingsError;
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -58,6 +58,22 @@ impl Infer {
}
}

#[instrument(skip(self))]
pub async fn tokenize<I: Into<EncodingInput> + std::fmt::Debug>(
&self,
inputs: I,
add_special_tokens: bool,
) -> Result<RawEncoding, TextEmbeddingsError> {
self.tokenization
.tokenize(inputs.into(), add_special_tokens)
.await
.map_err(|err| {
metrics::increment_counter!("te_request_failure", "err" => "tokenization");
tracing::error!("{err}");
err
})
}

#[instrument(skip(self))]
pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
// Limit concurrent requests by acquiring a permit from the semaphore
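The new `Infer::tokenize` method returns the raw `tokenizers::Encoding` (re-exported as `RawEncoding`) rather than the validated, model-ready encoding used by the queue. Below is a minimal, hypothetical sketch of how a caller — such as the `/tokenize` route handler added by this commit but not shown in this view — might consume that `RawEncoding`; the `infer` handle, input string, and printed output are illustrative assumptions, not the actual handler code.

// Hypothetical sketch: consuming the RawEncoding returned by Infer::tokenize.
// `infer` is an Infer handle; names below are illustrative only.
async fn tokenize_example(infer: &Infer) -> Result<(), TextEmbeddingsError> {
    let encoding = infer.tokenize("test sentence".to_string(), true).await?;

    // RawEncoding is tokenizers::Encoding, so the usual accessors are available.
    for ((&id, token), &special) in encoding
        .get_ids()
        .iter()
        .zip(encoding.get_tokens())
        .zip(encoding.get_special_tokens_mask())
    {
        println!("{id}\t{token}\tspecial={}", special == 1);
    }
    Ok(())
}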
4 changes: 2 additions & 2 deletions core/src/queue.rs
@@ -1,5 +1,5 @@
use crate::infer::InferResponse;
use crate::tokenization::Encoding;
use crate::tokenization::ValidEncoding;
use std::cmp::max;
use std::collections::VecDeque;
use std::time::{Duration, Instant};
@@ -11,7 +11,7 @@ use tracing::{instrument, Span};
#[derive(Debug)]
pub struct Entry {
/// Payload
pub encoding: Encoding,
pub encoding: ValidEncoding,
/// Entry metadata
pub metadata: Metadata,
}
136 changes: 104 additions & 32 deletions core/src/tokenization.rs
@@ -1,6 +1,7 @@
/// Payload tokenization logic
use crate::TextEmbeddingsError;
use tokenizers::tokenizer::Tokenizer;
pub use tokenizers::Encoding as RawEncoding;
use tokenizers::{EncodeInput, TruncationDirection, TruncationParams, TruncationStrategy};
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};
@@ -63,7 +64,7 @@ impl Tokenization {
&self,
inputs: EncodingInput,
truncate: bool,
) -> Result<Encoding, TextEmbeddingsError> {
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
return Err(TextEmbeddingsError::Validation(
@@ -76,7 +77,43 @@
// Send request to the background validation task
// Unwrap is safe here
self.sender
.send((inputs, truncate, response_sender, Span::current()))
.send(TokenizerRequest::Encode(
inputs,
truncate,
response_sender,
Span::current(),
))
.expect("Tokenization background task dropped the receiver. This is a bug.");

// Await on response channel
// Unwrap is safe here
response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
}

#[instrument(skip_all)]
pub async fn tokenize(
&self,
inputs: EncodingInput,
add_special_tokens: bool,
) -> Result<RawEncoding, TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
return Err(TextEmbeddingsError::Validation(
"`inputs` cannot be empty".to_string(),
));
}

// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
self.sender
.send(TokenizerRequest::Tokenize(
inputs,
add_special_tokens,
response_sender,
Span::current(),
))
.expect("Tokenization background task dropped the receiver. This is a bug.");

// Await on response channel
@@ -93,31 +130,65 @@ fn tokenizer_worker(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
// Loop over requests
while let Some((inputs, truncate, response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| {
if !response_tx.is_closed() {
// It's possible that the user dropped its request resulting in a send error.
// We just discard the error
let _ = response_tx.send(encode_input(
inputs,
truncate,
max_input_length,
position_offset,
&mut tokenizer,
));
while let Some(request) = receiver.blocking_recv() {
match request {
TokenizerRequest::Encode(inputs, truncate, response_tx, parent_span) => {
parent_span.in_scope(|| {
if !response_tx.is_closed() {
// It's possible that the user dropped its request resulting in a send error.
// We just discard the error
let _ = response_tx.send(encode_input(
inputs,
truncate,
max_input_length,
position_offset,
&mut tokenizer,
));
}
})
}
TokenizerRequest::Tokenize(inputs, add_special_tokens, response_tx, parent_span) => {
parent_span.in_scope(|| {
if !response_tx.is_closed() {
// It's possible that the user dropped its request resulting in a send error.
// We just discard the error
let _ = response_tx.send(tokenize_input(
inputs,
add_special_tokens,
None,
&mut tokenizer,
));
}
})
}
})
}
}
}

fn tokenize_input(
inputs: EncodingInput,
add_special_tokens: bool,
truncate_params: Option<TruncationParams>,
tokenizer: &mut Tokenizer,
) -> Result<RawEncoding, TextEmbeddingsError> {
let inputs: EncodeInput = match inputs {
EncodingInput::Single(s) => s.into(),
EncodingInput::Dual(s1, s2) => (s1, s2).into(),
};

Ok(tokenizer
.with_truncation(truncate_params)?
.encode(inputs, add_special_tokens)?)
}

/// Get input length and optionally truncate it
fn encode_input(
inputs: EncodingInput,
truncate: bool,
max_input_length: usize,
position_offset: usize,
tokenizer: &mut Tokenizer,
) -> Result<Encoding, TextEmbeddingsError> {
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Default truncation params
let truncate_params = truncate.then_some(TruncationParams {
direction: TruncationDirection::Right,
@@ -126,14 +197,7 @@ fn encode_input(
stride: 0,
});

let inputs: EncodeInput = match inputs {
EncodingInput::Single(s) => s.into(),
EncodingInput::Dual(s1, s2) => (s1, s2).into(),
};

let encoding = tokenizer
.with_truncation(truncate_params)?
.encode(inputs, true)?;
let encoding = tokenize_input(inputs, true, truncate_params, tokenizer)?;
let seq_len = encoding.len();

if seq_len > max_input_length {
@@ -144,7 +208,7 @@

metrics::histogram!("te_request_input_length", seq_len as f64);

Ok(Encoding {
Ok(ValidEncoding {
input_ids: encoding.get_ids().to_vec(),
token_type_ids: encoding.get_type_ids().to_vec(),
position_ids: (position_offset as u32..(seq_len + position_offset) as u32)
Expand All @@ -153,7 +217,7 @@ fn encode_input(
}

#[derive(Debug)]
pub struct Encoding {
pub struct ValidEncoding {
pub input_ids: Vec<u32>,
pub token_type_ids: Vec<u32>,
pub position_ids: Vec<u32>,
@@ -186,9 +250,17 @@ impl From<(String, String)> for EncodingInput {
}
}

type TokenizerRequest = (
EncodingInput,
bool,
oneshot::Sender<Result<Encoding, TextEmbeddingsError>>,
Span,
);
enum TokenizerRequest {
Encode(
EncodingInput,
bool,
oneshot::Sender<Result<ValidEncoding, TextEmbeddingsError>>,
Span,
),
Tokenize(
EncodingInput,
bool,
oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
Span,
),
}
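The refactor extracts a shared `tokenize_input` helper: the existing embedding path (`encode_input`) forwards its optional `TruncationParams` and hard-codes `add_special_tokens = true`, while the new `/tokenize` path passes `None` (no truncation) and the caller-supplied `add_special_tokens`. The following is a hypothetical sketch, written as if inside core/src/tokenization.rs, showing the new path's call shape; the input string and printing are illustrative assumptions.

// Hypothetical sketch (as if written inside core/src/tokenization.rs).
fn tokenize_example(tokenizer: &mut Tokenizer) -> Result<(), TextEmbeddingsError> {
    // New /tokenize path: no truncation, special tokens controlled by the request.
    let raw = tokenize_input(
        EncodingInput::Single("test sentence".to_string()),
        true, // add_special_tokens
        None, // no truncation for raw tokenization
        tokenizer,
    )?;
    println!("{} tokens: {:?}", raw.len(), raw.get_tokens());

    // The embed path (encode_input) calls the same helper with
    // Some(truncate_params) and add_special_tokens fixed to true.
    Ok(())
}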
129 changes: 129 additions & 0 deletions docs/openapi.json
@@ -436,6 +436,52 @@
}
}
}
},
"/tokenize": {
"post": {
"tags": [
"Text Embeddings Inference"
],
"summary": "Tokenize inputs",
"description": "Tokenize inputs",
"operationId": "tokenize",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/TokenizeRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Tokenized ids",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/TokenizeResponse"
}
}
}
},
"422": {
"description": "Tokenization error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OpenAICompatErrorResponse"
},
"example": {
"message": "Tokenization error",
"type": "tokenizer"
}
}
}
}
}
}
}
},
"components": {
@@ -660,6 +706,17 @@
"$ref": "#/components/schemas/EmbeddingModel"
}
}
},
{
"type": "object",
"required": [
"reranker"
],
"properties": {
"reranker": {
"$ref": "#/components/schemas/ClassifierModel"
}
}
}
]
},
@@ -953,6 +1010,78 @@
"items": {
"$ref": "#/components/schemas/Rank"
}
},
"SimpleToken": {
"type": "object",
"required": [
"id",
"text",
"special"
],
"properties": {
"id": {
"type": "integer",
"format": "int32",
"example": 0,
"minimum": 0
},
"special": {
"type": "boolean",
"example": "false"
},
"start": {
"type": "integer",
"example": 0,
"nullable": true,
"minimum": 0
},
"stop": {
"type": "integer",
"example": 2,
"nullable": true,
"minimum": 0
},
"text": {
"type": "string",
"example": "test"
}
}
},
"TokenizeRequest": {
"type": "object",
"required": [
"inputs"
],
"properties": {
"add_special_tokens": {
"type": "boolean",
"default": "true",
"example": "true"
},
"inputs": {
"$ref": "#/components/schemas/Input"
}
}
},
"TokenizeResponse": {
"type": "array",
"items": {
"type": "array",
"items": {
"$ref": "#/components/schemas/SimpleToken"
}
},
"example": [
[
{
"id": 0,
"special": false,
"start": 0,
"stop": 2,
"text": "test"
}
]
]
}
}
},
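For reference, a hypothetical client-side sketch of calling the new route described by `TokenizeRequest`/`TokenizeResponse` above. The base URL and port are assumptions (the server's bind address is not part of this diff), and the example assumes the reqwest crate (with its json feature), serde_json, and tokio — none of which are implied by the commit itself.

// Hypothetical client sketch for POST /tokenize; URL and payload are illustrative.
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let tokens: Value = client
        .post("http://localhost:8080/tokenize")
        .json(&json!({ "inputs": "test sentence", "add_special_tokens": true }))
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;

    // Expected shape per TokenizeResponse: an array of arrays of SimpleToken objects.
    println!("{}", serde_json::to_string_pretty(&tokens)?);
    Ok(())
}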