Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 46 additions & 33 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,10 @@ pub struct Infer {
queue: Queue,
/// Shared state
shared: Arc<Shared>,
/// Chat template
chat_template: Option<ChatTemplate>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Chat template (template, bos_token, eos_token)
template: (
Option<Template<'static, 'static>>,
Option<String>,
Option<String>,
),
}

/// Infer shared state
Expand Down Expand Up @@ -88,32 +84,19 @@ impl Infer {
generation_health,
));

let chat_template = tokenizer_config
.chat_template
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));

// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

let template = tokenizer_config.chat_template.map(|t| {
let mut env = Box::new(Environment::new());
let template_str = t.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// leaking env and template_str as read-only, static resources for performance.
Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap()
});
let eos_token = tokenizer_config
.eos_token
.map_or_else(String::new, |t| t)
.into();
let bos_token = tokenizer_config
.bos_token
.map_or_else(String::new, |t| t)
.into();
Self {
validation,
queue,
shared,
chat_template,
limit_concurrent_requests: semaphore,
template: (template, bos_token, eos_token),
}
}

Expand Down Expand Up @@ -192,20 +175,14 @@ impl Infer {
/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
let (template, bos_token, eos_token) = &self.template;
template
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.render(ChatTemplateInputs {
messages,
eos_token: eos_token.as_deref(),
bos_token: bos_token.as_deref(),
add_generation_prompt: true,
})
.apply(messages)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}");
InferError::TemplateError(e)
e
})
}

Expand Down Expand Up @@ -329,6 +306,42 @@ impl Infer {
}
}

/// Compiled chat template bundled with the special tokens supplied when rendering it.
#[derive(Clone)]
struct ChatTemplate {
/// Compiled template. Both lifetimes are `'static` because `ChatTemplate::new` leaks the
/// environment and the template source that this value borrows from.
template: Template<'static, 'static>,
/// Optional BOS token, handed to the template as `bos_token` on every render.
bos_token: Option<String>,
/// Optional EOS token, handed to the template as `eos_token` on every render.
eos_token: Option<String>,
}

impl ChatTemplate {
    /// Compiles `template` and stores it together with the optional BOS/EOS tokens.
    ///
    /// The environment is given a `raise_exception` function before the template is compiled.
    ///
    /// # Panics
    ///
    /// Panics if `template_from_str` returns an error for the given template source.
    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
        let mut environment = Box::new(Environment::new());
        environment.add_function("raise_exception", raise_exception);

        // Both the environment and the template source are deliberately leaked: they live for
        // the rest of the process as read-only data, which lets the compiled template be
        // `'static` (done for performance).
        let source: &'static str = Box::leak(template.into_boxed_str());
        let compiled = Box::leak(environment)
            .template_from_str(source)
            .unwrap();

        Self {
            template: compiled,
            bos_token,
            eos_token,
        }
    }

    /// Renders the stored template over `messages`.
    ///
    /// The stored BOS/EOS tokens are passed along as borrowed strings and
    /// `add_generation_prompt` is always set to `true`.
    ///
    /// # Errors
    ///
    /// Returns `InferError::TemplateError` wrapping the error produced by the render call.
    fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
        let inputs = ChatTemplateInputs {
            messages,
            bos_token: self.bos_token.as_deref(),
            eos_token: self.eos_token.as_deref(),
            add_generation_prompt: true,
        };

        match self.template.render(inputs) {
            Ok(rendered) => Ok(rendered),
            Err(err) => Err(InferError::TemplateError(err)),
        }
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
Expand Down