optimizations and refactoring of core logic
snowmead committed Mar 22, 2024
1 parent 8e92dc7 commit a36930f
Showing 5 changed files with 338 additions and 89 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,7 +1,7 @@
[package]
publish = true
name = "llm-weaver"
version = "0.1.6"
version = "0.1.7"
edition = "2021"
description = "Manage long conversations with any LLM"
readme = "README.md"
157 changes: 80 additions & 77 deletions src/lib.rs
@@ -68,6 +68,9 @@ pub mod architecture;
pub mod storage;
pub mod types;

#[cfg(test)]
mod tests;

pub use storage::TapestryChestHandler;
use types::{LoomError, SummaryModelTokens, WeaveError, ASSISTANT_ROLE, SYSTEM_ROLE};

@@ -169,6 +172,7 @@ pub trait Llm<T: Config>:
/// Prompt LLM with the supplied messages and parameters.
async fn prompt(
&self,
is_summarizing: bool,
prompt_tokens: Self::Tokens,
msgs: Vec<Self::Request>,
params: &Self::Parameters,
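
The new is_summarizing flag lets an implementation distinguish summary-generation calls from ordinary prompting; elsewhere in the diff, generate_summary now passes true while weave passes false, and the explicit summary instruction message is dropped from generate_summary. A minimal, self-contained sketch of how an implementation might branch on the flag (the function and wording below are illustrative assumptions, not the crate's API):

// Illustrative only; not the crate's API. A `prompt` implementation could branch on the
// new `is_summarizing` flag to supply its own summarization instruction.
fn build_instruction(is_summarizing: bool, word_budget: usize) -> String {
    if is_summarizing {
        format!("Summarize the conversation so far in {word_budget} words or fewer.")
    } else {
        "Continue the conversation.".to_string()
    }
}

fn main() {
    println!("{}", build_instruction(true, 150));
    println!("{}", build_instruction(false, 0));
}
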
@@ -253,6 +257,10 @@ pub struct TapestryFragment<T: Config> {
}

impl<T: Config> TapestryFragment<T> {
fn new() -> Self {
Self::default()
}

/// Add a [`ContextMessage`] to the `context_messages` list.
///
/// Also increments the `context_tokens` by the number of tokens in the message.
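
As the doc comment above notes, adding messages to a fragment also advances its token count. A self-contained sketch of that bookkeeping, assuming plain String messages and a whitespace word count in place of the crate's message and tokenizer types:

// Sketch only; not the crate's types. A fragment starts empty, and extending it with
// messages also bumps the running token count.
#[derive(Default, Debug)]
struct Fragment {
    context_messages: Vec<String>,
    context_tokens: u64,
}

impl Fragment {
    fn new() -> Self {
        Self::default()
    }

    fn extend_messages(&mut self, msgs: Vec<String>) {
        for msg in &msgs {
            // A whitespace word count stands in for a real tokenizer here.
            self.context_tokens += msg.split_whitespace().count() as u64;
        }
        self.context_messages.extend(msgs);
    }
}

fn main() {
    let mut fragment = Fragment::new();
    fragment.extend_messages(vec!["instruction text".into(), "summary text".into()]);
    assert_eq!(fragment.context_tokens, 4);
    assert_eq!(fragment.context_messages.len(), 2);
}
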
@@ -304,81 +312,94 @@ pub trait Loom<T: Config> {
prompt_config: LlmConfig<T, T::PromptModel>,
summary_model_config: LlmConfig<T, T::SummaryModel>,
tapestry_id: TID,
system: String,
instructions: String,
msgs: Vec<ContextMessage<T>>,
) -> Result<<<T as Config>::PromptModel as Llm<T>>::Response> {
let system_ctx_msg = Self::build_context_message(SYSTEM_ROLE.into(), system, None);
let sys_req_msg: PromptModelRequest<T> = system_ctx_msg.clone().into();
) -> Result<(<<T as Config>::PromptModel as Llm<T>>::Response, u64, bool)> {
let instructions_ctx_msg =
Self::build_context_message(SYSTEM_ROLE.into(), instructions, None);
let instructions_req_msg: PromptModelRequest<T> = instructions_ctx_msg.clone().into();

// get latest tapestry fragment instance from storage
// Get current tapestry fragment to work with
let current_tapestry_fragment = T::Chest::get_tapestry_fragment(tapestry_id.clone(), None)
.await?
.unwrap_or_default();

// number of tokens available according to the configured model or custom max tokens
// Get max token limit which cannot be exceeded in a tapestry fragment
let max_tokens_limit = prompt_config.model.get_max_token_limit();

// allocate space for the system message
// Request messages which will be sent as a whole to the LLM
let mut req_msgs = VecPromptMsgsDeque::<T, T::PromptModel>::with_capacity(
current_tapestry_fragment.context_messages.len() + 1,
current_tapestry_fragment.context_messages.len() + 1, /* +1 for the instruction
* message to add */
);
req_msgs.push_front(sys_req_msg);

// Add instructions as the first message
req_msgs.push_front(instructions_req_msg);

// Convert and append all tapestry fragment messages to the request messages.
let mut ctx_msgs = VecDeque::from(
prompt_config
.model
.ctx_msgs_to_prompt_requests(&current_tapestry_fragment.context_messages),
);
req_msgs.append(&mut ctx_msgs);

// New messages are not added here yet; we first check whether they would push the
// tapestry fragment past the maximum token limit, which would require generating a
// summary and starting a new tapestry fragment.
//
// Either we are starting a new tapestry fragment with the instruction and summary messages
// or we are continuing the current tapestry fragment.
let msgs_tokens = Self::count_tokens_in_messages(msgs.iter());
let does_exceeding_max_token_limit = max_tokens_limit <=
current_tapestry_fragment.context_tokens.saturating_add(&msgs_tokens);

let (mut tapestry_fragment_to_persist, was_summary_generated) =
if does_exceeding_max_token_limit {
let summary =
Self::generate_summary(summary_model_config, &current_tapestry_fragment)
.await?;
let summary_ctx_msg = Self::build_context_message(
SYSTEM_ROLE.into(),
format!("\n\"\"\"\n {}", summary),
None,
);

// Generate summary and start new tapestry instance if context tokens would exceed maximum
// amount of allowed tokens.
//
// Either we are starting a new tapestry fragment with the summary and system message or we
// are continuing the current tapestry fragment.
let (mut tapestry_fragment_to_persist, was_summary_generated) = if max_tokens_limit <=
current_tapestry_fragment.context_tokens.saturating_add(&msgs_tokens)
{
let summary =
Self::generate_summary(summary_model_config, &current_tapestry_fragment).await?;
let summary_ctx_msg = Self::build_context_message(
SYSTEM_ROLE.into(),
format!("\n\"\"\"\n {}", summary),
None,
);

req_msgs.push_front(summary_ctx_msg.clone().into());

// keep system and summary messages
req_msgs.truncate(2);

let mut new_tapestry_fragment = TapestryFragment {
context_messages: vec![system_ctx_msg, summary_ctx_msg],
..Default::default()
};
// Truncate all tapestry fragment messages except for the instructions and add the
// summary
req_msgs.truncate(1);
req_msgs.extend(vec![summary_ctx_msg.clone().into()]);

new_tapestry_fragment.context_tokens =
Self::count_tokens_in_messages(new_tapestry_fragment.context_messages.iter());
// Create new tapestry fragment
let mut new_tapestry_fragment = TapestryFragment::new();
new_tapestry_fragment
.extend_messages(vec![instructions_ctx_msg, summary_ctx_msg])?;

(new_tapestry_fragment, true)
} else {
(current_tapestry_fragment, false)
};

(new_tapestry_fragment, true)
} else {
(current_tapestry_fragment, false)
};
// Add new messages to the request messages
req_msgs.extend(msgs.iter().map(|m| m.clone().into()).collect::<Vec<_>>());

// Tokens available for LLM response which would not exceed maximum token limit
let max_tokens = max_tokens_limit
.saturating_sub(&tapestry_fragment_to_persist.context_tokens)
.saturating_sub(&msgs_tokens);

// Execute prompt to LLM
let response = prompt_config
.model
.prompt(req_msgs.tokens, req_msgs.into_vec(), &prompt_config.params, max_tokens)
.prompt(false, req_msgs.tokens, req_msgs.into_vec(), &prompt_config.params, max_tokens)
.await
.map_err(|e| {
error!("Failed to prompt LLM: {}", e);
e
})?;

// Add new messages and response to the tapestry fragment which will be persisted in the
// database
if let Err(e) = tapestry_fragment_to_persist.extend_messages(
msgs.into_iter()
.chain(vec![Self::build_context_message(
@@ -393,10 +414,10 @@ pub trait Loom<T: Config> {
}
debug!("Saving tapestry fragment: {:?}", tapestry_fragment_to_persist);

// save tapestry fragment to storage
// when summarized, the tapestry_fragment will be saved under a new instance
T::Chest::save_tapestry_fragment(
tapestry_id,
// Save tapestry fragment to database
// When summarized, the tapestry_fragment will be saved under a new instance
let tapestry_fragment_id = T::Chest::save_tapestry_fragment(
&tapestry_id,
tapestry_fragment_to_persist,
was_summary_generated,
)
@@ -406,7 +427,7 @@ pub trait Loom<T: Config> {
e
})?;

Ok(response)
Ok((response, tapestry_fragment_id, was_summary_generated))
}
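
The weave flow above rests on two saturating-arithmetic checks: whether the incoming messages would push the current fragment to or past the model's token limit (triggering a summary and a new fragment), and how many tokens remain for the LLM response. A standalone sketch of that arithmetic, with u64 standing in for the generic token type and illustrative function names:

// Mirrors: max_tokens_limit <= context_tokens.saturating_add(&msgs_tokens)
fn needs_summary(max_tokens_limit: u64, fragment_tokens: u64, new_msg_tokens: u64) -> bool {
    max_tokens_limit <= fragment_tokens.saturating_add(new_msg_tokens)
}

// Mirrors the saturating subtractions used to size the LLM response.
fn response_budget(max_tokens_limit: u64, fragment_tokens: u64, new_msg_tokens: u64) -> u64 {
    max_tokens_limit
        .saturating_sub(fragment_tokens)
        .saturating_sub(new_msg_tokens)
}

fn main() {
    // Still under the limit: keep appending to the current fragment.
    assert!(!needs_summary(8_192, 6_000, 1_000));
    assert_eq!(response_budget(8_192, 6_000, 1_000), 1_192);

    // At or over the limit: summarize and start a new fragment.
    assert!(needs_summary(8_192, 7_500, 1_000));
}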

/// Generates the summary of the current [`TapestryFragment`] instance.
@@ -419,32 +440,18 @@ pub trait Loom<T: Config> {
summary_model_config: LlmConfig<T, T::SummaryModel>,
tapestry_fragment: &TapestryFragment<T>,
) -> Result<String> {
let summary_model_tokens =
T::convert_prompt_tokens_to_summary_model_tokens(tapestry_fragment.context_tokens);

let mut summary_generation_prompt =
VecPromptMsgsDeque::<T, T::SummaryModel>::new(summary_model_tokens);

let gen_summary_prompt = Self::build_context_message(
SYSTEM_ROLE.into(),
format!(
"Generate a summary of the entire adventure so far. Respond with {} words or less",
summary_model_config.model.convert_tokens_to_words(summary_model_tokens)
),
None,
)
.into();
let mut summary_generation_prompt = VecPromptMsgsDeque::<T, T::SummaryModel>::new();

summary_generation_prompt.extend(
summary_model_config
.model
.ctx_msgs_to_prompt_requests(tapestry_fragment.context_messages.as_slice()),
);
summary_generation_prompt.push(gen_summary_prompt);

let res = summary_model_config
.model
.prompt(
true,
summary_generation_prompt.tokens,
summary_generation_prompt.into_vec(),
&summary_model_config.params,
@@ -491,33 +498,26 @@ struct VecPromptMsgsDeque<T: Config, L: Llm<T>> {
}

impl<T: Config, L: Llm<T>> VecPromptMsgsDeque<T, L> {
fn new(tokens: L::Tokens) -> Self {
Self { tokens, inner: VecDeque::new() }
fn new() -> Self {
Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::new() }
}

fn with_capacity(capacity: usize) -> Self {
Self { tokens: L::Tokens::from_u8(0).unwrap(), inner: VecDeque::with_capacity(capacity) }
}

fn push(&mut self, msg_reqs: L::Request) {
let tokens = L::count_tokens(msg_reqs.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&tokens);
self.inner.push_back(msg_reqs);
}

fn push_front(&mut self, msg_reqs: L::Request) {
let tokens = L::count_tokens(msg_reqs.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&tokens);
self.inner.push_front(msg_reqs);
}

fn append(&mut self, msg_reqs: &mut VecDeque<L::Request>) {
msg_reqs.iter().for_each(|msg_req| {
let msg_tokens = L::count_tokens(msg_req.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&msg_tokens);
});
self.inner.append(msg_reqs);

for msg_req in msg_reqs {
let tokens = L::count_tokens(msg_req.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&tokens);
}
}

fn truncate(&mut self, len: usize) {
@@ -527,14 +527,17 @@ impl<T: Config, L: Llm<T>> VecPromptMsgsDeque<T, L> {
tokens = tokens.saturating_add(&msg_tokens);
}
self.inner.truncate(len);
self.tokens = tokens;
}

fn extend(&mut self, msg_reqs: Vec<L::Request>) {
let mut tokens = L::Tokens::from_u8(0).unwrap();
for msg_req in &msg_reqs {
let tokens = L::count_tokens(msg_req.to_string()).unwrap_or_default();
self.tokens = self.tokens.saturating_add(&tokens);
let msg_tokens = L::count_tokens(msg_req.to_string()).unwrap_or_default();
tokens = tokens.saturating_add(&msg_tokens);
}
self.inner.extend(msg_reqs);
self.tokens = tokens;
}

fn into_vec(self) -> Vec<L::Request> {
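
VecPromptMsgsDeque pairs the request queue with a running token total so the weave logic can budget against the model's limit without re-counting messages. A self-contained sketch of that pattern, with String requests and a whitespace "tokenizer" standing in for the crate's generic types:

use std::collections::VecDeque;

// Sketch only; not the crate's implementation.
struct PromptMsgs {
    tokens: u64,
    inner: VecDeque<String>,
}

impl PromptMsgs {
    fn new() -> Self {
        Self { tokens: 0, inner: VecDeque::new() }
    }

    fn count(msg: &str) -> u64 {
        msg.split_whitespace().count() as u64
    }

    fn push_front(&mut self, msg: String) {
        self.tokens += Self::count(&msg);
        self.inner.push_front(msg);
    }

    fn extend(&mut self, msgs: Vec<String>) {
        for msg in &msgs {
            self.tokens += Self::count(msg);
        }
        self.inner.extend(msgs);
    }
}

fn main() {
    let mut req_msgs = PromptMsgs::new();
    req_msgs.push_front("system instructions".to_string());
    req_msgs.extend(vec!["user message one".to_string()]);
    assert_eq!(req_msgs.tokens, 5);
    assert_eq!(req_msgs.inner.len(), 2);
}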
