Commit 0d2d1ce

Merge pull request #12 from edgenai/feat/non-stream-completions
feat: non-streaming chat completions
pedro-devv committed Feb 2, 2024
2 parents dd6cda2 + c2757ac commit 0d2d1ce
Showing 3 changed files with 73 additions and 18 deletions.

crates/edgen_core/src/llm.rs (2 changes: 1 addition & 1 deletion)

@@ -44,7 +44,7 @@ pub struct CompletionArgs {
     pub frequency_penalty: f32,
 }
 
-/// A large language language model endpoint, that is, an object that provides various ways to interact with a large
+/// A large language model endpoint, that is, an object that provides various ways to interact with a large
 /// language model.
 pub trait LLMEndpoint {
     /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually contain the prompt
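
The trait fragment above is cut off by the diff view; as a rough sketch of the shape its doc comment describes, an endpoint would expose one boxed-future method per delivery mode (method names and signatures here are illustrative guesses, not the crate's actual API):

    use std::future::Future;
    use std::pin::Pin;

    use futures::Stream;

    // Boxed futures keep the trait usable as a `dyn` trait object.
    type BoxedFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
    type TokenStream = Pin<Box<dyn Stream<Item = String> + Send>>;

    pub trait LlmEndpointSketch {
        /// Resolve the whole completion at once (the non-streaming path
        /// this pull request adds).
        fn chat_completion(&self, prompt: String) -> BoxedFuture<String>;

        /// Yield completion fragments as the model generates them.
        fn chat_completion_stream(&self, prompt: String) -> BoxedFuture<TokenStream>;
    }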

crates/edgen_server/src/openai_shim.rs (87 changes: 71 additions & 16 deletions)

@@ -17,7 +17,6 @@
 
 use std::borrow::Cow;
 use std::collections::HashMap;
-use std::convert::Infallible;
 use std::fmt::{Display, Formatter};
 
 use axum::http::StatusCode;
@@ -26,10 +25,8 @@ use axum::response::{IntoResponse, Response, Sse};
 use axum::Json;
 use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
 use derive_more::{Deref, DerefMut, From};
-use edgen_core::settings::SETTINGS;
-use edgen_core::settings::{get_audio_transcriptions_model_dir, get_chat_completions_model_dir};
 use either::Either;
-use futures::StreamExt;
+use futures::{Stream, StreamExt, TryStream};
 use serde_derive::{Deserialize, Serialize};
 use thiserror::Error;
 use time::OffsetDateTime;
@@ -38,6 +35,9 @@ use tracing::error;
 use utoipa::ToSchema;
 use uuid::Uuid;
 
+use edgen_core::settings::SETTINGS;
+use edgen_core::settings::{get_audio_transcriptions_model_dir, get_chat_completions_model_dir};
+
 use crate::model::{Model, ModelKind};
 use crate::whisper::WhisperEndpointError;

@@ -521,6 +521,29 @@ impl IntoResponse for ChatCompletionError {
     }
 }
 
+/// The return type of [`chat_completions`]. Contains either a [`Stream`] of [`Event`]s or a [`Json`]
+/// of a [`ChatCompletion`].
+enum ChatCompletionResponse<'a, S>
+where
+    S: TryStream<Ok = Event> + Send + 'static,
+{
+    Stream(Sse<S>),
+    Full(Json<ChatCompletion<'a>>),
+}
+
+impl<'a, S, E> IntoResponse for ChatCompletionResponse<'a, S>
+where
+    S: Stream<Item = Result<Event, E>> + Send + 'static,
+    E: Into<axum::BoxError>,
+{
+    fn into_response(self) -> Response {
+        match self {
+            ChatCompletionResponse::Stream(stream) => stream.into_response(),
+            ChatCompletionResponse::Full(full) => full.into_response(),
+        }
+    }
+}
+
 /// POST `/v1/chat/completions`: generate chat completions for the provided context, optionally
 /// streaming those completions in real-time.
 ///
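
This two-variant wrapper is what lets a single handler answer with either shape. A minimal standalone sketch of the same pattern (assuming axum and futures; none of these names are from this repository) shows that the `IntoResponse` dispatch is all the glue required:

    use std::convert::Infallible;

    use axum::response::{sse::Event, IntoResponse, Response, Sse};
    use axum::Json;
    use futures::{stream, Stream};

    // Either an SSE stream or a single JSON body, behind one response type.
    enum EitherResponse<S> {
        Stream(Sse<S>),
        Full(Json<&'static str>),
    }

    impl<S> IntoResponse for EitherResponse<S>
    where
        S: Stream<Item = Result<Event, Infallible>> + Send + 'static,
    {
        fn into_response(self) -> Response {
            // Both variants already know how to become a `Response`.
            match self {
                EitherResponse::Stream(sse) => sse.into_response(),
                EitherResponse::Full(json) => json.into_response(),
            }
        }
    }

    // The handler picks the variant at runtime, e.g. from a `stream` flag.
    fn respond(streaming: bool) -> impl IntoResponse {
        if streaming {
            let events = stream::iter([Ok(Event::default().data("hello"))]);
            EitherResponse::Stream(Sse::new(events))
        } else {
            EitherResponse::Full(Json("hello"))
        }
    }

In the commit itself the stream's error type is left generic (`E: Into<axum::BoxError>`), apparently because `Event::json_data` returns a `Result` whose error can now flow into the response directly; that is what makes the old `.expect(...)` and `Infallible` plumbing removable below.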
@@ -590,12 +613,18 @@ pub async fn chat_completions(
 
     let untokenized_context = format!("{}<|ASSISTANT|>", req.messages);
 
-    let completions_stream = crate::llm::chat_completion_stream(model, untokenized_context)
-        .await?
-        .map(|chunk| {
-            let fp = format!("edgen-{}", cargo_crate_version!());
-            Event::default()
-                .json_data(ChatCompletionChunk {
+    let stream_response = if let Some(stream) = req.stream {
+        stream
+    } else {
+        false
+    };
+
+    let fp = format!("edgen-{}", cargo_crate_version!());
+    let response = if stream_response {
+        let completions_stream = crate::llm::chat_completion_stream(model, untokenized_context)
+            .await?
+            .map(move |chunk| {
+                Event::default().json_data(ChatCompletionChunk {
                     id: Uuid::new_v4().to_string().into(),
                     choices: tiny_vec![ChatCompletionChunkChoice {
                         index: 0,
@@ -610,11 +639,37 @@
                     system_fingerprint: Cow::Borrowed(&fp), // use macro for version
                     object: Cow::Borrowed("text_completion"),
                 })
-                .expect("Could not serialize JSON; this should never happen")
-        })
-        .map(Ok::<Event, Infallible>);
-
-    Ok(Sse::new(completions_stream))
+            });
+
+        ChatCompletionResponse::Stream(Sse::new(completions_stream))
+    } else {
+        let content_str = crate::llm::chat_completion(model, untokenized_context).await?;
+        let response = ChatCompletion {
+            id: Uuid::new_v4().to_string().into(),
+            choices: vec![ChatCompletionChoice {
+                message: ChatMessage::Assistant {
+                    content: Some(Cow::Owned(content_str)),
+                    name: None,
+                    tool_calls: None,
+                },
+                finish_reason: None,
+                index: 0,
+            }],
+            created: OffsetDateTime::now_utc().unix_timestamp(),
+            model: Cow::Borrowed("main"),
+            object: Cow::Borrowed("text_completion"),
+            system_fingerprint: Cow::Owned(fp), // use macro for version
+            usage: ChatCompletionUsage {
+                completion_tokens: 0,
+                prompt_tokens: 0,
+                total_tokens: 0,
+            },
+        };
+
+        ChatCompletionResponse::Full(Json(response))
+    };
+
+    Ok(response)
 }
 
 /// A request to transcribe an audio file into text in either the specified language, or whichever
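
An aside on the `stream_response` binding introduced above: the `if let`/`else` only defaults a missing flag to `false`, so (assuming `req.stream` is an `Option<bool>`) an equivalent idiomatic one-liner would be:

    let stream_response = req.stream.unwrap_or(false);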
@@ -662,7 +717,7 @@ pub struct CreateTranscriptionRequest {
 ///
 /// See [the original OpenAI API specification][openai], which this endpoint is compatible with.
 ///
-/// [openai]: https://platform.openai.com/docs/api-reference/auddio/createTranscription
+/// [openai]: https://platform.openai.com/docs/api-reference/audio/createTranscription
 ///
 /// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`WhisperEndpointError`]
 /// to the peer.
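
With this change the endpoint honors the OpenAI-style `stream` flag end to end. A client-side sketch of the new non-streaming path follows; the port, model name, and choice of reqwest/serde_json/tokio are illustrative assumptions, not details from this repository:

    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let body = json!({
            "model": "main",
            "messages": [{ "role": "user", "content": "Hello!" }],
            // Set to true (previously the only behavior) to receive SSE
            // chunks instead of one complete JSON body.
            "stream": false
        });

        // Assumed local server address; reqwest's `json` feature is required.
        let completion = reqwest::Client::new()
            .post("http://localhost:33322/v1/chat/completions")
            .json(&body)
            .send()
            .await?
            .text()
            .await?;

        println!("{completion}");
        Ok(())
    }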

crates/edgen_server/src/util/stopping_stream.rs (2 changes: 1 addition & 1 deletion)

@@ -28,7 +28,7 @@ pub struct StoppingStream<T> {
     /// The stop words (phrases) that this stream should stop at.
     ///
     /// These are never emitted downstream, and the stream will yield with `Pending` until it
-    /// is is impossible for any stop word to be generated.
+    /// is impossible for any stop word to be generated.
     stop_words: Vec<String>,
 
     /// If this stream is uncertain whether it's collecting a stop word, this buffer contains
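
The holding-back behavior this doc comment describes can be pictured with a small predicate; the following is an illustrative sketch, not the crate's implementation:

    /// Illustrative only: `true` while some stop word still begins with the
    /// buffered text, i.e. while the stream must keep the buffer withheld
    /// (yielding `Pending`) rather than emit it downstream.
    fn could_become_stop_word(buffer: &str, stop_words: &[String]) -> bool {
        stop_words.iter().any(|word| word.starts_with(buffer))
    }

For example, with the stop word "<|USER|>", a buffered "<|US" stays withheld, while a buffered "<|Ux" can no longer match and is safe to flush to the consumer.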
