33 changes: 10 additions & 23 deletions examples/llm.rs
@@ -2,20 +2,17 @@ use blockless_sdk::*;

/// This example demonstrates how to use the Blockless SDK to interact with two different LLM models.
///
/// It sets up two instances of the BlocklessLlm struct:
/// - One for a large model (Llama-3.1-8B)
/// - One for a small model (SmolLM2-1.7B)
///
/// It sets up two instances of the BlocklessLlm struct.
/// Each model is configured with a system message that changes the assistant's name.
/// The example then sends chat requests to both models and prints their responses,
/// demonstrating how the same instance maintains state between requests.

fn main() {
// large model
let mut llm = BlocklessLlm::new("Llama-3.1-8B-Instruct-q4f32_1-MLC").unwrap();
let mut llm = BlocklessLlm::new(SupportedModels::Mistral7BInstructV03(None)).unwrap();

// small model
let mut llm_smol = BlocklessLlm::new("SmolLM2-1.7B-Instruct-q4f16_1-MLC").unwrap();
let mut llm_small = BlocklessLlm::new(SupportedModels::Llama321BInstruct(None)).unwrap();

let prompt = r#"
You are a helpful assistant.
@@ -30,38 +27,28 @@ fn main() {
.unwrap();

let response = llm.chat_request("What is your name?").unwrap();
println!("LLM Response: {}", response);
println!("llm Response: {}", response);

let prompt_smol = r#"
You are a helpful assistant.
First time I ask, your name will be daisy.
Second time I ask, your name will be hector.
"#;
llm_smol
llm_small
.set_options(LlmOptions {
system_message: prompt_smol.to_string(),
top_p: Some(0.5),
..Default::default()
})
.unwrap();

let response = llm_smol.chat_request("What is your name?").unwrap();
println!("LLM Response SmolLM: {}", response);
let response = llm_small.chat_request("What is your name?").unwrap();
println!("llm_small Response: {}", response);

let response = llm_smol.chat_request("What is your name?").unwrap();
println!("LLM Response SmolLM: {}", response);
let response = llm_small.chat_request("What is your name?").unwrap();
println!("llm_small Response: {}", response);

// test if same instance is used in host/runtime
let response = llm.chat_request("What is your name?").unwrap();
println!("LLM Response: {}", response);

// For streaming responses, you can use read_response_chunk
// let mut buf = [0u8; 4096];
// while let Ok(num) = llm.read_response_chunk(&mut buf) {
// if num == 0 {
// break;
// }
// let chunk = String::from_utf8_lossy(&buf[..num as usize]);
// println!("Chunk: {}", chunk);
// }
println!("llm Response: {}", response);
}
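
The example above selects models with SupportedModels variants directly. As a complementary sketch that is not part of this diff, a model identifier could also be parsed from a string at runtime via the FromStr implementation introduced in src/llm.rs below; the identifier string and the fallback choice here are illustrative.

use blockless_sdk::*;
use std::str::FromStr;

fn main() {
    // Parse a model alias at runtime; fall back to a known variant
    // if the string is not one of the supported aliases.
    let model = SupportedModels::from_str("Llama-3.2-1B-Instruct-q4f16_1")
        .unwrap_or(SupportedModels::Llama321BInstruct(None));

    let mut llm = BlocklessLlm::new(model).unwrap();
    let response = llm.chat_request("What is your name?").unwrap();
    println!("{}", response);
}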
149 changes: 108 additions & 41 deletions src/llm.rs
@@ -1,4 +1,5 @@
use json::JsonValue;
use std::{str::FromStr, string::ToString};

type Handle = u32;
type ExitCode = u8;
@@ -33,6 +34,90 @@ extern "C" {
fn llm_close(h: Handle) -> ExitCode;
}

#[derive(Debug, Clone)]
pub enum SupportedModels {
Llama321BInstruct(Option<String>),
Llama323BInstruct(Option<String>),
Mistral7BInstructV03(Option<String>),
Mixtral8x7BInstructV01(Option<String>),
Gemma22BInstruct(Option<String>),
Gemma27BInstruct(Option<String>),
Gemma29BInstruct(Option<String>),
}

impl FromStr for SupportedModels {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
// Llama 3.2 1B
"Llama-3.2-1B-Instruct" => Ok(SupportedModels::Llama321BInstruct(None)),
"Llama-3.2-1B-Instruct-Q6_K"
| "Llama-3.2-1B-Instruct_Q6_K"
| "Llama-3.2-1B-Instruct.Q6_K" => {
Ok(SupportedModels::Llama321BInstruct(Some("Q6_K".to_string())))
}
"Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => Ok(
SupportedModels::Llama321BInstruct(Some("q4f16_1".to_string())),
),

// Llama 3.2 3B
"Llama-3.2-3B-Instruct" => Ok(SupportedModels::Llama323BInstruct(None)),
"Llama-3.2-3B-Instruct-Q6_K"
| "Llama-3.2-3B-Instruct_Q6_K"
| "Llama-3.2-3B-Instruct.Q6_K" => {
Ok(SupportedModels::Llama323BInstruct(Some("Q6_K".to_string())))
}
"Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => Ok(
SupportedModels::Llama323BInstruct(Some("q4f16_1".to_string())),
),

// Mistral 7B
"Mistral-7B-Instruct-v0.3" => Ok(SupportedModels::Mistral7BInstructV03(None)),
"Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => Ok(
SupportedModels::Mistral7BInstructV03(Some("q4f16_1".to_string())),
),

// Mixtral 8x7B
"Mixtral-8x7B-Instruct-v0.1" => Ok(SupportedModels::Mixtral8x7BInstructV01(None)),
"Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => Ok(
SupportedModels::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())),
),

// Gemma models
"gemma-2-2b-it" => Ok(SupportedModels::Gemma22BInstruct(None)),
"gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => Ok(
SupportedModels::Gemma22BInstruct(Some("q4f16_1".to_string())),
),

"gemma-2-27b-it" => Ok(SupportedModels::Gemma27BInstruct(None)),
"gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => Ok(
SupportedModels::Gemma27BInstruct(Some("q4f16_1".to_string())),
),

"gemma-2-9b-it" => Ok(SupportedModels::Gemma29BInstruct(None)),
"gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => Ok(
SupportedModels::Gemma29BInstruct(Some("q4f16_1".to_string())),
),

_ => Err(format!("Unsupported model: {}", s)),
}
}
}

impl ToString for SupportedModels {
fn to_string(&self) -> String {
match self {
SupportedModels::Llama321BInstruct(_) => "Llama-3.2-1B-Instruct".to_string(),
SupportedModels::Llama323BInstruct(_) => "Llama-3.2-3B-Instruct".to_string(),
SupportedModels::Mistral7BInstructV03(_) => "Mistral-7B-Instruct-v0.3".to_string(),
SupportedModels::Mixtral8x7BInstructV01(_) => "Mixtral-8x7B-Instruct-v0.1".to_string(),
SupportedModels::Gemma22BInstruct(_) => "gemma-2-2b-it".to_string(),
SupportedModels::Gemma27BInstruct(_) => "gemma-2-27b-it".to_string(),
SupportedModels::Gemma29BInstruct(_) => "gemma-2-9b-it".to_string(),
}
}
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default)]
pub struct BlocklessLlm {
@@ -42,7 +127,7 @@ pub struct BlocklessLlm {
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmOptions {
pub system_message: String,
// pub max_tokens: u32,
@@ -52,18 +137,6 @@ pub struct LlmOptions {
// pub presence_penalty: f32,
}

impl Default for LlmOptions {
fn default() -> Self {
LlmOptions {
system_message: String::new(),
temperature: None,
top_p: None,
// frequency_penalty: 0.0,
// presence_penalty: 0.0,
}
}
}

impl LlmOptions {
pub fn new() -> Self {
Self::default()
@@ -99,12 +172,12 @@ impl TryFrom<Vec<u8>> for LlmOptions {
let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;

// Parse the JSON string
let json = json::parse(&json_str).map_err(|_| LlmErrorKind::OptionsNotSet)?;
let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;

// Extract system_message
let system_message = json["system_message"]
.as_str()
.ok_or(LlmErrorKind::OptionsNotSet)?
.ok_or(LlmErrorKind::ModelOptionsNotSet)?
.to_string();

Ok(LlmOptions {
@@ -116,9 +189,10 @@ impl TryFrom<Vec<u8>> for LlmOptions {
}

impl BlocklessLlm {
pub fn new(model_name: &str) -> Result<Self, LlmErrorKind> {
let mut llm = Self::default();
llm.set_model(model_name)?;
pub fn new(model: SupportedModels) -> Result<Self, LlmErrorKind> {
let model_name = model.to_string();
let mut llm: BlocklessLlm = Default::default();
llm.set_model(&model_name)?;
Ok(llm)
}

@@ -195,7 +269,7 @@ impl BlocklessLlm {
"Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
self.options, host_options
);
return Err(LlmErrorKind::OptionsNotSet);
return Err(LlmErrorKind::ModelOptionsNotSet);
}
Ok(())
}
@@ -224,21 +298,6 @@ impl BlocklessLlm {
let response_vec = buf[0..num_bytes as usize].to_vec();
String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
}

// TODO: response streaming - not yet supported
// - read next available chunks
// - block until chunk is read, repeat until no more chunks
// pub fn read_response_chunk(&self, buf: &mut [u8]) -> Result<u32, LlmErrorKind> {
// let mut num: u32 = 0;
// let rs = unsafe {
// llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num)
// };

// if rs != 0 {
// return Err(LlmErrorKind::from(rs));
// }
// Ok(num)
// }
}

impl Drop for BlocklessLlm {
@@ -252,19 +311,27 @@ impl Drop for BlocklessLlm {

#[derive(Debug)]
pub enum LlmErrorKind {
ModelNotSet,
OptionsNotSet,
Utf8Error,
Unknown(u8),
ModelNotSet, // 1
ModelNotSupported, // 2
ModelInitializationFailed, // 3
ModelCompletionFailed, // 4
ModelOptionsNotSet, // 5
ModelShutdownFailed, // 6
Utf8Error, // 7
RuntimeError, // 8
}

impl From<u8> for LlmErrorKind {
fn from(code: u8) -> Self {
match code {
1 => LlmErrorKind::ModelNotSet,
2 => LlmErrorKind::OptionsNotSet,
3 => LlmErrorKind::Utf8Error,
_ => LlmErrorKind::Unknown(code),
2 => LlmErrorKind::ModelNotSupported,
3 => LlmErrorKind::ModelInitializationFailed,
4 => LlmErrorKind::ModelCompletionFailed,
5 => LlmErrorKind::ModelOptionsNotSet,
6 => LlmErrorKind::ModelShutdownFailed,
7 => LlmErrorKind::Utf8Error,
_ => LlmErrorKind::RuntimeError,
}
}
}
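
For completeness, a minimal error-handling sketch against the expanded LlmErrorKind enum above, assuming blockless_sdk re-exports BlocklessLlm, SupportedModels, and LlmErrorKind as the example file does; the model and prompt are illustrative.

use blockless_sdk::*;

fn main() {
    match BlocklessLlm::new(SupportedModels::Gemma22BInstruct(None)) {
        Ok(mut llm) => match llm.chat_request("Hello!") {
            Ok(response) => println!("Response: {}", response),
            // Separate configuration problems from other runtime failures.
            Err(LlmErrorKind::ModelOptionsNotSet) => {
                eprintln!("options were not accepted by the host")
            }
            Err(err) => eprintln!("chat request failed: {:?}", err),
        },
        Err(LlmErrorKind::ModelNotSupported) => eprintln!("model not supported by this runtime"),
        Err(err) => eprintln!("failed to initialize model: {:?}", err),
    }
}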