From d10fc7ce9a2e51b31dd4b68d652d92daf8ed5c1b Mon Sep 17 00:00:00 2001 From: LJ Date: Mon, 10 Mar 2025 00:24:40 -0700 Subject: [PATCH] Implement function `ExtractByMistral` to extract structured data from LLM --- Cargo.toml | 2 + src/ops/functions/extract_by_mistral.rs | 131 ++++++++++++++++++++++++ src/ops/functions/mod.rs | 1 + src/ops/registration.rs | 1 + 4 files changed, 135 insertions(+) create mode 100644 src/ops/functions/extract_by_mistral.rs diff --git a/Cargo.toml b/Cargo.toml index 823514f9e..8a902365f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,4 +40,6 @@ derivative = "2.2.0" async-lock = "3.4.0" hex = "0.4.3" pythonize = "0.23.0" +# TODO: Switch to a stable tag of mistralrs after a new release is tagged. +mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git" } schemars = "0.8.22" diff --git a/src/ops/functions/extract_by_mistral.rs b/src/ops/functions/extract_by_mistral.rs new file mode 100644 index 000000000..df100610d --- /dev/null +++ b/src/ops/functions/extract_by_mistral.rs @@ -0,0 +1,131 @@ +use std::sync::Arc; + +use anyhow::anyhow; +use mistralrs::{self, TextMessageRole}; +use serde::Serialize; + +use crate::base::json_schema::ToJsonSchema; +use crate::ops::sdk::*; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MistralModelSpec { + model_id: String, + isq_type: mistralrs::IsqType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Spec { + model: MistralModelSpec, + output_type: EnrichedValueType, + instructions: Option, +} + +struct Executor { + model: mistralrs::Model, + output_type: EnrichedValueType, + request_base: mistralrs::RequestBuilder, +} + +fn get_system_message(instructions: &Option) -> String { + let mut message = + "You are a helpful assistant that extracts structured information from text. \ +Your task is to analyze the input text and output valid JSON that matches the specified schema. \ +Be precise and only include information that is explicitly stated in the text. \ +Output only the JSON without any additional messages or explanations." + .to_string(); + + if let Some(custom_instructions) = instructions { + message.push_str("\n\n"); + message.push_str(custom_instructions); + } + + message +} + +impl Executor { + async fn new(spec: Spec) -> Result { + let model = mistralrs::TextModelBuilder::new(spec.model.model_id) + .with_isq(spec.model.isq_type) + .with_logging() + .with_paged_attn(|| mistralrs::PagedAttentionMetaBuilder::default().build())? + .build() + .await?; + let request_base = mistralrs::RequestBuilder::new() + .set_constraint(mistralrs::Constraint::JsonSchema(serde_json::to_value( + spec.output_type.to_json_schema(), + )?)) + .set_deterministic_sampler() + .add_message( + TextMessageRole::System, + get_system_message(&spec.instructions), + ); + Ok(Self { + model, + output_type: spec.output_type, + request_base, + }) + } +} + +#[async_trait] +impl SimpleFunctionExecutor for Executor { + fn behavior_version(&self) -> Option { + Some(1) + } + + fn enable_cache(&self) -> bool { + true + } + + async fn evaluate(&self, input: Vec) -> Result { + let text = input.iter().next().unwrap().as_str()?; + let request = self + .request_base + .clone() + .add_message(TextMessageRole::User, text); + let response = self.model.send_chat_request(request).await?; + let response_text = response.choices[0] + .message + .content + .as_ref() + .ok_or_else(|| anyhow!("No content in response"))?; + let json_value: serde_json::Value = serde_json::from_str(response_text)?; + let value = Value::from_json(json_value, &self.output_type.typ)?; + Ok(value) + } +} + +pub struct Factory; + +#[async_trait] +impl SimpleFunctionFactoryBase for Factory { + type Spec = Spec; + + fn name(&self) -> &str { + "ExtractByMistral" + } + + fn get_output_schema( + &self, + spec: &Spec, + input_schema: &Vec, + _context: &FlowInstanceContext, + ) -> Result { + match &expect_input_1(input_schema)?.value_type.typ { + ValueType::Basic(BasicValueType::Str) => {} + t => { + api_bail!("Expect String as input type, got {}", t) + } + } + Ok(spec.output_type.clone()) + } + + async fn build_executor( + self: Arc, + spec: Spec, + _input_schema: Vec, + _context: Arc, + ) -> Result> { + Ok(Box::new(Executor::new(spec).await?)) + } +} diff --git a/src/ops/functions/mod.rs b/src/ops/functions/mod.rs index e0620fd03..ac82c6cf9 100644 --- a/src/ops/functions/mod.rs +++ b/src/ops/functions/mod.rs @@ -1 +1,2 @@ +pub mod extract_by_mistral; pub mod split_recursively; diff --git a/src/ops/registration.rs b/src/ops/registration.rs index b4b67a1e4..33fb803d6 100644 --- a/src/ops/registration.rs +++ b/src/ops/registration.rs @@ -8,6 +8,7 @@ use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard}; fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> { sources::local_file::Factory.register(registry)?; functions::split_recursively::Factory.register(registry)?; + functions::extract_by_mistral::Factory.register(registry)?; Arc::new(storages::postgres::Factory::default()).register(registry)?; Ok(())