Merge pull request #93 from edgenai/feat/oneshot-llm
One-shot LLM requests
pedro-devv committed Feb 26, 2024
2 parents a264b27 + c2412d9 commit 93d5581
Showing 4 changed files with 218 additions and 88 deletions.
20 changes: 17 additions & 3 deletions crates/edgen_core/src/llm.rs
@@ -10,13 +10,15 @@
* limitations under the License.
*/

use crate::BoxedFuture;
use core::time::Duration;
use std::path::Path;

use futures::Stream;
use serde::Serialize;
use std::path::Path;
use thiserror::Error;

use crate::BoxedFuture;

/// The context tag marking the start of generated dialogue.
pub const ASSISTANT_TAG: &str = "<|ASSISTANT|>";

@@ -42,10 +44,22 @@ pub enum LLMEndpointError {
#[derive(Debug, Clone)]
pub struct CompletionArgs {
pub prompt: String,
pub seed: u32,
pub one_shot: bool,
pub seed: Option<u32>,
pub frequency_penalty: f32,
}

impl Default for CompletionArgs {
fn default() -> Self {
Self {
prompt: "".to_string(),
one_shot: false,
seed: None,
frequency_penalty: 0.0,
}
}
}

/// A large language model endpoint, that is, an object that provides various ways to interact with
/// a large language model.
pub trait LLMEndpoint {
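For reference, a minimal sketch of how a caller could build one-shot arguments with the new `Default` implementation. The module path and prompt text are assumptions for illustration; only the struct and its fields come from this diff.

use edgen_core::llm::CompletionArgs; // assumed public path to the struct above

fn example_args() -> CompletionArgs {
    // Only the prompt and the one-shot flag are set explicitly; the remaining
    // fields fall back to the new Default impl (no fixed seed, zero penalty).
    CompletionArgs {
        prompt: "Summarize the release notes.".to_string(),
        one_shot: true,
        ..Default::default()
    }
}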
202 changes: 156 additions & 46 deletions crates/edgen_rt_llama_cpp/src/lib.rs
@@ -10,6 +10,7 @@
* limitations under the License.
*/

use std::mem::take;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
@@ -216,39 +217,76 @@ impl UnloadingModel {

/// Computes the full chat completions for the provided [`CompletionArgs`].
async fn chat_completions(&self, args: CompletionArgs) -> Result<String, LLMEndpointError> {
let (session, mut id, new_context) = self.take_chat_session(&args.prompt).await;
let (_model_signal, model_guard) = get_or_init_model(&self.model, &self.path).await?;

let (_session_signal, mut handle) = {
let (session_signal, mut session_guard) =
get_or_init_session(&session, model_guard.clone()).await?;
if args.one_shot {
info!("Allocating one-shot LLM session");
let mut params = SessionParams::default();
let threads = SETTINGS.read().await.read().await.auto_threads(false);

session_guard
.advance_context_async(new_context)
// TODO handle optional params
//params.seed = args.seed;
params.n_threads = threads;
params.n_threads_batch = threads;
params.n_ctx = CONTEXT_SIZE;

let mut session = model_guard
.create_session(params)
.map_err(move |e| LLMEndpointError::SessionCreationFailed(e.to_string()))?;

session
.advance_context_async(args.prompt)
.await
.map_err(move |e| LLMEndpointError::Advance(e.to_string()))?;
id.advance(new_context);

let sampler = StandardSampler::default();
let handle = session_guard.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);
let mut handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);

(session_signal, handle)
};
let mut res = String::default();
while let Some(token) = handle.next_token_async().await {
if token == model_guard.eos() {
break;
}

let mut res = String::default();
while let Some(token) = handle.next_token_async().await {
if token == model_guard.eos() {
break;
let piece = model_guard.token_to_piece(token);
res += &piece;
}

let piece = model_guard.token_to_piece(token);
res += &piece;
id.advance(&piece);
}
Ok(res)
} else {
let (session, mut id, new_context) = self.take_chat_session(&args.prompt).await;

self.sessions.insert(id, session);
let (_session_signal, mut handle) = {
let (session_signal, mut session_guard) =
get_or_init_session(&session, model_guard.clone()).await?;

Ok(res)
session_guard
.advance_context_async(new_context)
.await
.map_err(move |e| LLMEndpointError::Advance(e.to_string()))?;
id.advance(new_context);

let sampler = StandardSampler::default();
let handle = session_guard.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);

(session_signal, handle)
};

let mut res = String::default();
while let Some(token) = handle.next_token_async().await {
if token == model_guard.eos() {
break;
}

let piece = model_guard.token_to_piece(token);
res += &piece;
id.advance(&piece);
}

self.sessions.insert(id, session);

Ok(res)
}
}

/// Return a [`Box`]ed [`Stream`] of chat completions computed for the provided
@@ -257,24 +295,53 @@ impl UnloadingModel {
&self,
args: CompletionArgs,
) -> Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError> {
let (session, id, new_context) = self.take_chat_session(&args.prompt).await;
let (model_signal, model_guard) = get_or_init_model(&self.model, &self.path).await?;

let sampler = StandardSampler::default();
let tx = self.finished_tx.clone();

Ok(Box::new(
CompletionStream::new(
session,
id,
new_context,
model_guard.clone(),
model_signal,
sampler,
tx,
)
.await?,
))
if args.one_shot {
info!("Allocating one-shot LLM session");
let mut params = SessionParams::default();
let threads = SETTINGS.read().await.read().await.auto_threads(false);

// TODO handle optional params
//params.seed = args.seed;
params.n_threads = threads;
params.n_threads_batch = threads;
params.n_ctx = CONTEXT_SIZE;

let session = model_guard
.create_session(params)
.map_err(move |e| LLMEndpointError::SessionCreationFailed(e.to_string()))?;
let sampler = StandardSampler::default();

Ok(Box::new(
CompletionStream::new_oneshot(
session,
&args.prompt,
model_guard.clone(),
model_signal,
sampler,
)
.await?,
))
} else {
let (session, id, new_context) = self.take_chat_session(&args.prompt).await;

let sampler = StandardSampler::default();
let tx = self.finished_tx.clone();

Ok(Box::new(
CompletionStream::new(
session,
id,
new_context,
model_guard.clone(),
model_signal,
sampler,
tx,
)
.await?,
))
}
}
}
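A rough sketch of how the boxed stream could be consumed once it is handed back to a caller; the function name is a placeholder, and only the `Box<dyn Stream<Item = String> + Unpin + Send>` return type comes from this diff.

use futures::{Stream, StreamExt};

// Hypothetical consumer of the boxed stream returned above: each yielded
// item is one decoded piece of the completion, and the stream ends once
// the model emits its end-of-sequence token.
async fn print_completion(mut stream: Box<dyn Stream<Item = String> + Unpin + Send>) {
    while let Some(piece) = stream.next().await {
        print!("{piece}");
    }
}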

@@ -473,19 +540,19 @@ struct CompletionStream {
end_token: Token,

/// The session used for generating completions.
session: Option<Perishable<LlamaSession>>,
session: SessionOption,

/// The `session`'s id.
session_id: Option<SessionId>,

/// A sender used to send both `session` and `session_id` once generation is complete
finished_tx: UnboundedSender<(SessionId, Perishable<LlamaSession>)>,
finished_tx: Option<UnboundedSender<(SessionId, Perishable<LlamaSession>)>>,

/// The object signaling that `model` is currently active.
_model_signal: ActiveSignal,

/// The object signaling that `session` is currently active.
_session_signal: ActiveSignal,
_session_signal: Option<ActiveSignal>,
}

impl CompletionStream {
@@ -532,11 +599,40 @@ impl CompletionStream {
handle: Arc::new(Mutex::new(handle)),
next: None,
end_token,
session: Some(session),
session: SessionOption::Perishable(session),
session_id: Some(session_id),
finished_tx,
finished_tx: Some(finished_tx),
_model_signal: model_signal,
_session_signal: Some(session_signal),
})
}

async fn new_oneshot(
mut session: LlamaSession,
new_context: &str,
model: LlamaModel,
model_signal: ActiveSignal,
sampler: StandardSampler,
) -> Result<Self, LLMEndpointError> {
let model_clone = model.clone();
let end_token = model.eos();

session
.advance_context_async(new_context)
.await
.map_err(move |e| LLMEndpointError::Advance(e.to_string()))?;
let handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);

Ok(Self {
model: model_clone,
handle: Arc::new(Mutex::new(handle)),
next: None,
end_token,
session: SessionOption::OneShot(session),
session_id: None,
finished_tx: None,
_model_signal: model_signal,
_session_signal: session_signal,
_session_signal: None,
})
}

@@ -595,13 +691,27 @@ impl Stream for CompletionStream {
impl Drop for CompletionStream {
fn drop(&mut self) {
if let Some(id) = self.session_id.take() {
if let Some(session) = self.session.take() {
self.finished_tx
.send((id, session))
.unwrap_or_else(move |e| {
if let SessionOption::Perishable(session) = self.session.take() {
if let Some(channel) = self.finished_tx.take() {
channel.send((id, session)).unwrap_or_else(move |e| {
error!("Failed to send session to maintenance thread: {e}")
});
}
}
}
}
}

#[derive(Default)]
enum SessionOption {
OneShot(LlamaSession),
Perishable(Perishable<LlamaSession>),
#[default]
None,
}

impl SessionOption {
fn take(&mut self) -> Self {
take(self)
}
}
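The `take` helper mirrors `std::mem::take`: it swaps the stored value for the `#[default]` variant and hands back the old contents, which is what lets `Drop` move a `Perishable` session out of `&mut self`. A standalone sketch of the same pattern, using a hypothetical `Slot` enum rather than the types above:

use std::mem::take;

#[derive(Default, Debug, PartialEq)]
enum Slot {
    Value(u32),
    #[default]
    Empty,
}

fn main() {
    let mut slot = Slot::Value(7);
    // `take` returns the previous contents and leaves the Default variant
    // behind, so the caller owns the value without invalidating `slot`.
    let old = take(&mut slot);
    assert_eq!(old, Slot::Value(7));
    assert_eq!(slot, Slot::Empty);
}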
19 changes: 5 additions & 14 deletions crates/edgen_server/src/llm.rs
@@ -21,13 +21,10 @@ use crate::util::StoppingStream;

static ENDPOINT: Lazy<LlamaCppEndpoint> = Lazy::new(Default::default);

pub async fn chat_completion(model: Model, context: String) -> Result<String, LLMEndpointError> {
let args = CompletionArgs {
prompt: context,
seed: 0,
frequency_penalty: 0.0,
};

pub async fn chat_completion(
model: Model,
args: CompletionArgs,
) -> Result<String, LLMEndpointError> {
ENDPOINT
.chat_completions(
model
@@ -40,14 +37,8 @@ pub async fn chat_completion_stream(

pub async fn chat_completion_stream(
model: Model,
context: String,
args: CompletionArgs,
) -> Result<StoppingStream<Box<dyn Stream<Item = String> + Unpin + Send>>, LLMEndpointError> {
let args = CompletionArgs {
prompt: context,
seed: 0,
frequency_penalty: 0.0,
};

let stream = ENDPOINT
.stream_chat_completions(
model
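With the hard-coded argument construction removed, building `CompletionArgs` becomes the responsibility of the route handler. A hedged sketch of what such a call site could look like; the field values are illustrative, and only the new function signatures match this diff.

// Hypothetical call site in the server's chat completions handler.
// `model` and `context` stand in for values the handler already has.
let args = CompletionArgs {
    prompt: context,
    one_shot: true,
    seed: Some(42),
    frequency_penalty: 0.0,
};
let answer = chat_completion(model, args).await?;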