
Commit

feat: adding completion params passing down (#27)
louisgv committed Jun 11, 2023
1 parent 46a522a · commit 7adf294
Showing 5 changed files with 162 additions and 29 deletions.
7 changes: 1 addition & 6 deletions apps/desktop/src-tauri/src/inference_server.rs
@@ -3,10 +3,8 @@ use actix_web::dev::ServerHandle;
use actix_web::web::{Bytes, Json};

use actix_web::{get, post, App, HttpResponse, HttpServer, Responder};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use tauri::AppHandle;

@@ -23,10 +21,7 @@ use crate::inference_thread::{
use crate::model_pool::{self, spawn_pool};
use crate::model_stats;
use crate::path::get_app_dir_path_buf;
use llm::{Model, VocabularySource};

static _LOADED_MODELMAP: Lazy<Mutex<HashMap<String, Box<dyn Model>>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
use llm::VocabularySource;

#[derive(Default)]
pub struct State {
89 changes: 76 additions & 13 deletions apps/desktop/src-tauri/src/inference_thread.rs
@@ -3,8 +3,8 @@ use std::{convert::Infallible, sync::Arc};
use actix_web::web::Bytes;
use flume::Sender;
use llm::{
InferenceError, InferenceFeedback, InferenceParameters, Model, OutputRequest,
Prompt, TokenUtf8Buffer,
samplers, InferenceError, InferenceFeedback, InferenceParameters, Model,
OutputRequest, Prompt, TokenUtf8Buffer,
};
use parking_lot::{Mutex, RwLock};

@@ -14,14 +14,27 @@ use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use tokio::task::JoinHandle;

use crate::model_pool::{self, get_inference_params};
use crate::{
inference_thread::stop_handler::StopHandler,
model_pool::{self, get_n_threads},
};

mod stop_handler;

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionRequest {
prompt: String,
max_tokens: u32,
temperature: f64,
max_tokens: Option<usize>,
stream: bool,

pub seed: Option<u64>,

pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,

pub stop_sequences: Option<Vec<String>>,
}
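
A request body exercising the new optional fields might look like the minimal sketch below. Only the field names and types come from the struct above; the concrete values are illustrative, and the sketch assumes serde_json is available and that the test lives in the same module (several fields are not pub).

#[cfg(test)]
mod completion_request_example {
  use super::CompletionRequest;

  #[test]
  fn parses_optional_sampling_params() {
    // Illustrative payload; any of the Option fields may simply be omitted.
    let body = r#"{
      "prompt": "Tell me a joke",
      "stream": true,
      "max_tokens": 128,
      "seed": 7,
      "temperature": 0.8,
      "top_p": 0.95,
      "frequency_penalty": 0.5,
      "presence_penalty": 0.1,
      "stop_sequences": ["\n===\n"]
    }"#;
    let req: CompletionRequest = serde_json::from_str(body).unwrap();
    assert_eq!(req.max_tokens, Some(128));
    assert_eq!(req.seed, Some(7));
  }
}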

pub type ModelGuard = Arc<Mutex<Option<Box<dyn Model>>>>;
@@ -60,6 +73,32 @@ fn clean_prompt(s: &str) -> String {
.replace("\n<human>: ", "\n===\nhuman: ")
}

fn get_inference_params(
completion_request: &CompletionRequest,
) -> InferenceParameters {
let n_threads = get_n_threads();

let repeat_penalty =
(completion_request.frequency_penalty.unwrap_or(0.6) + 2.0) / 2.0;

let repetition_penalty_last_n = 256
+ (((completion_request.presence_penalty.unwrap_or(0.0) + 2.0) / 4.0
* 512.0) as usize);

InferenceParameters {
n_threads,
n_batch: n_threads,
sampler: Arc::new(samplers::TopPTopK {
temperature: completion_request.temperature.unwrap_or(1.0),
top_p: completion_request.top_p.unwrap_or(1.0),
repeat_penalty,
repetition_penalty_last_n,

..Default::default()
}),
}
}
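
The two OpenAI-style penalties, nominally in -2.0..=2.0, are rescaled onto the knobs that the llm crate's TopPTopK sampler expects: frequency_penalty maps linearly onto repeat_penalty in 0.0..=2.0, and presence_penalty widens the repetition lookback window from 256 up to 768 tokens. A minimal check of that arithmetic with the defaults used above (reading the -2.0..=2.0 input range off the formulas is an inference, not something the commit states):

#[test]
fn penalty_mapping_matches_formulas() {
  // Defaults applied when the request omits the fields.
  let frequency_penalty = 0.6_f32; // unwrap_or(0.6)
  let presence_penalty = 0.0_f32;  // unwrap_or(0.0)

  // -2.0..=2.0 -> 0.0..=2.0; the default lands at 1.3.
  let repeat_penalty = (frequency_penalty + 2.0) / 2.0;
  // -2.0..=2.0 -> a lookback of 256..=768 tokens; the default lands at 512.
  let repetition_penalty_last_n =
    256 + (((presence_penalty + 2.0) / 4.0 * 512.0) as usize);

  assert!((repeat_penalty - 1.3).abs() < 1e-6);
  assert_eq!(repetition_penalty_last_n, 512);
}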

// Perhaps might be better to clone the model for each thread...
pub fn start_inference(req: InferenceThreadRequest) -> Option<JoinHandle<()>> {
println!("Starting inference ...");
@@ -82,7 +121,7 @@ pub fn start_inference(req: InferenceThreadRequest) -> Option<JoinHandle<()>> {
let prompt = &raw_prompt;
let mut output_request = OutputRequest::default();

let inference_params = get_inference_params();
let inference_params = get_inference_params(&req.completion_request);

// Manual tokenization if needed
// let vocab = model.vocabulary();
Expand Down Expand Up @@ -118,6 +157,7 @@ pub fn start_inference(req: InferenceThreadRequest) -> Option<JoinHandle<()>> {

let handle =
spawn_inference_thread(req, inference_params, session, output_request);

Some(handle)
}

@@ -129,9 +169,12 @@ fn spawn_inference_thread(
) -> JoinHandle<()> {
println!("Spawning inference thread...");
let handle = actix_web::rt::task::spawn_blocking(move || {
let mut rng = ChaCha8Rng::seed_from_u64(42);
let mut tokens_processed = 0;
let maximum_token_count = usize::MAX;
let mut rng =
ChaCha8Rng::seed_from_u64(req.completion_request.seed.unwrap_or(420));

let maximum_token_count =
req.completion_request.max_tokens.unwrap_or(usize::MAX);

let mut token_utf8_buf = TokenUtf8Buffer::new();
let guard = req.model_guard.lock();

@@ -143,8 +186,21 @@ fn spawn_inference_thread(
}
};

let mut stop_handler = StopHandler::new(
model.as_ref(),
req
.completion_request
.stop_sequences
.as_ref()
.unwrap_or(&vec![]),
);

let mut tokens_processed = 0;

while tokens_processed < maximum_token_count {
if *Arc::clone(&req.abort_flag).read() {
if *Arc::clone(&req.abort_flag).read()
|| req.token_sender.is_disconnected()
{
break;
}

@@ -164,21 +220,28 @@ fn spawn_inference_thread(
}
};

if stop_handler.check(&token) {
break;
}

// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(&token) {
match req.token_sender.send(get_completion_resp(tokens)) {
Ok(_) => {}
Err(e) => {
println!("Error while sending token: {:?}", e);
Err(_) => {
break;
}
}
}

tokens_processed += 1;
}
// TODO: Might make this into a callback later, for now we just abuse the singleton

if !req.token_sender.is_disconnected() {
req.token_sender.send(Bytes::from("data: [DONE]")).unwrap();
}

// TODO: Might make this into a callback later, for now we just abuse the singleton
model_pool::return_model(Some(Arc::clone(&req.model_guard)));

// Run inference
76 changes: 76 additions & 0 deletions apps/desktop/src-tauri/src/inference_thread/stop_handler.rs
@@ -0,0 +1,76 @@
/// Based on https://github.com/LLukas22/llm-rs-python/blob/e2a6d459e39df9be14442676fa43397a3b8753a4/src/stopwords.rs#L33
/// Copyright (c) 2023 Lukas Kreussel
/// MIT License
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

#[derive(Clone, Serialize, Deserialize)]
struct Buffer<T> {
pub data: Vec<T>,
capacity: usize,
}

impl<T> Buffer<T> {
fn new(capacity: usize) -> Self {
Buffer {
data: Vec::with_capacity(capacity),
capacity,
}
}

fn push(&mut self, item: T) {
if self.data.len() == self.capacity {
self.data.remove(0);
}
self.data.push(item);
}
}

#[derive(Clone, Serialize, Deserialize)]
pub struct StopHandler {
pub stops: HashSet<Vec<u8>>,
buffer: Buffer<u8>,
capacity: usize,
}

impl StopHandler {
pub fn new(model: &dyn llm::Model, stops: &[String]) -> StopHandler {
let stop_tokens: HashSet<Vec<u8>> = stops
.iter()
.map(|word| {
model
.vocabulary()
.tokenize(word, false)
.unwrap()
.iter()
.flat_map(|(encoding, _)| encoding.to_owned())
.collect::<Vec<u8>>()
})
.collect();

let capacity = stop_tokens.iter().map(|v| v.len()).max().unwrap_or(0);

StopHandler {
stops: stop_tokens,
buffer: Buffer::new(capacity),
capacity,
}
}

pub fn check(&mut self, new_tokens: &Vec<u8>) -> bool {
if self.capacity == 0 {
return false;
}

for token in new_tokens {
self.buffer.push(token.to_owned());
for i in 0..self.buffer.data.len() {
let slice = self.buffer.data[i..].to_vec();
if self.stops.contains(&slice) {
return true;
}
}
}
false
}
}
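
StopHandler keeps a sliding window sized to the longest tokenized stop sequence and fires as soon as any suffix of that window equals one of the stops. Constructing the real type requires a loaded llm::Model to tokenize the stop strings, so the self-contained sketch below re-implements only the suffix-matching idea on raw bytes; it is an illustration of the logic, not code from the commit.

use std::collections::HashSet;

fn main() {
  // Stop set and window capacity, analogous to StopHandler::new.
  let stops: HashSet<Vec<u8>> = HashSet::from([b"\n===\n".to_vec()]);
  let capacity = stops.iter().map(|s| s.len()).max().unwrap_or(0);

  let mut window: Vec<u8> = Vec::with_capacity(capacity);
  let stream = b"hello world\n===\nnever reached";

  for &byte in stream {
    // Buffer::push above: evict the oldest byte once the window is full.
    if window.len() == capacity {
      window.remove(0);
    }
    window.push(byte);

    // StopHandler::check above: test every suffix of the window.
    if (0..window.len()).any(|i| stops.contains(&window[i..].to_vec())) {
      println!("stop sequence hit; generation would stop here");
      return;
    }
  }
  println!("no stop sequence found");
}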
13 changes: 3 additions & 10 deletions apps/desktop/src-tauri/src/model_pool.rs
@@ -3,10 +3,7 @@ use parking_lot::Mutex;

use std::{fs, path::PathBuf, sync::Arc};

use llm::{
load_progress_callback_stdout, InferenceParameters, ModelArchitecture,
VocabularySource,
};
use llm::{load_progress_callback_stdout, ModelArchitecture, VocabularySource};

use std::path::Path;

@@ -33,12 +30,8 @@ pub fn return_model(model: Option<ModelGuard>) {
models.push_back(model);
}

pub fn get_inference_params() -> InferenceParameters {
InferenceParameters {
// n_batch: 4,
n_threads: num_cpus::get_physical() / (*CONCURRENCY_COUNT.lock()),
..Default::default()
}
pub fn get_n_threads() -> usize {
num_cpus::get_physical() / (*CONCURRENCY_COUNT.lock())
}
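
Each loaded model now gets an even share of the machine's physical cores. A minimal sketch of the division with illustrative stand-in numbers (num_cpus::get_physical() and CONCURRENCY_COUNT are runtime values):

#[test]
fn thread_split_example() {
  let physical_cores = 8_usize;    // stand-in for num_cpus::get_physical()
  let concurrency_count = 2_usize; // stand-in for *CONCURRENCY_COUNT.lock()
  assert_eq!(physical_cores / concurrency_count, 4); // 4 threads per model
}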

pub async fn spawn_pool(
6 changes: 6 additions & 0 deletions apps/desktop/src/features/thread/process-sse-stream.ts
@@ -30,6 +30,12 @@ export async function processSseStream(
if (result?.startsWith(SSE_DATA_EVENT_PREFIX)) {
const eventData = result.slice(SSE_DATA_EVENT_PREFIX.length).trim()

if (eventData === "[DONE]") {
// Handle early termination here if needed. This is the final value event emitted by the server before closing the connection.

break
}

await onData(JSON.parse(eventData))
}
} catch (_) {}

1 comment on commit 7adf294


vercel bot commented on 7adf294, Jun 11, 2023
