Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,6 @@ derivative = "2.2.0"
async-lock = "3.4.0"
hex = "0.4.3"
pythonize = "0.23.0"
# TODO: Switch to a stable tag of mistralrs after a new release is tagged.
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git" }
schemars = "0.8.22"
131 changes: 131 additions & 0 deletions src/ops/functions/extract_by_mistral.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use std::sync::Arc;

use anyhow::anyhow;
use mistralrs::{self, TextMessageRole};
use serde::Serialize;

use crate::base::json_schema::ToJsonSchema;
use crate::ops::sdk::*;

/// Selects the model the executor loads through `mistralrs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralModelSpec {
// Identifier handed to `mistralrs::TextModelBuilder::new` when the executor is built.
model_id: String,
// In-situ quantization type handed to the builder's `with_isq`.
isq_type: mistralrs::IsqType,
}

/// User-facing configuration of the `ExtractByMistral` function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spec {
// Which model to load and how to quantize it.
model: MistralModelSpec,
// Declared output type: its JSON schema constrains the model's output, and it is
// also what `get_output_schema` reports for this function.
output_type: EnrichedValueType,
// Optional extra instructions appended to the built-in system prompt.
instructions: Option<String>,
}

/// Runtime state of the function: the loaded model plus a pre-built request template.
struct Executor {
// Model handle used to send chat requests.
model: mistralrs::Model,
// Output type the parsed JSON response is converted into.
output_type: EnrichedValueType,
// Request template (schema constraint, sampler and system message); it is cloned for
// every evaluation and the user message is added to the clone.
request_base: mistralrs::RequestBuilder,
}

/// Builds the system prompt sent to the model.
///
/// The result always starts with the fixed extraction prompt. When `instructions`
/// is `Some`, a blank line and the custom instructions are appended verbatim.
fn get_system_message(instructions: &Option<String>) -> String {
    const BASE_PROMPT: &str =
        "You are a helpful assistant that extracts structured information from text. \
         Your task is to analyze the input text and output valid JSON that matches the specified schema. \
         Be precise and only include information that is explicitly stated in the text. \
         Output only the JSON without any additional messages or explanations.";

    match instructions.as_deref() {
        // Separate the caller's instructions from the base prompt with one empty line.
        Some(extra) => [BASE_PROMPT, extra].join("\n\n"),
        None => BASE_PROMPT.to_owned(),
    }
}

impl Executor {
async fn new(spec: Spec) -> Result<Self> {
let model = mistralrs::TextModelBuilder::new(spec.model.model_id)
.with_isq(spec.model.isq_type)
.with_logging()
.with_paged_attn(|| mistralrs::PagedAttentionMetaBuilder::default().build())?
.build()
.await?;
let request_base = mistralrs::RequestBuilder::new()
.set_constraint(mistralrs::Constraint::JsonSchema(serde_json::to_value(
spec.output_type.to_json_schema(),
)?))
.set_deterministic_sampler()
.add_message(
TextMessageRole::System,
get_system_message(&spec.instructions),
);
Ok(Self {
model,
output_type: spec.output_type,
request_base,
})
}
}

#[async_trait]
impl SimpleFunctionExecutor for Executor {
    /// Reports behavior version `1` for this function.
    fn behavior_version(&self) -> Option<u32> {
        Some(1)
    }

    /// Opts this function into caching (the caching semantics are defined by the trait's
    /// consumer, not here).
    fn enable_cache(&self) -> bool {
        true
    }

    /// Sends the first input value as the user message and converts the model's JSON reply
    /// into a `Value` of the configured output type.
    ///
    /// # Errors
    /// Returns an error, instead of panicking, when `input` is empty or when the response
    /// contains no choices. Also fails if the input is not a string, the chat request
    /// fails, the chosen message has no content, the content is not valid JSON, or the
    /// JSON does not convert into the output type.
    async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
        // `first()` instead of `iter().next().unwrap()`: a missing argument becomes an
        // ordinary error rather than a panic inside the executor.
        let text = input
            .first()
            .ok_or_else(|| anyhow!("ExtractByMistral expects one input argument, got none"))?
            .as_str()?;
        let request = self
            .request_base
            .clone()
            .add_message(TextMessageRole::User, text);
        let response = self.model.send_chat_request(request).await?;
        // Checked access instead of `choices[0]`: an empty choice list must not panic.
        let response_text = response
            .choices
            .first()
            .ok_or_else(|| anyhow!("No choices in response"))?
            .message
            .content
            .as_ref()
            .ok_or_else(|| anyhow!("No content in response"))?;
        let json_value: serde_json::Value = serde_json::from_str(response_text)?;
        let value = Value::from_json(json_value, &self.output_type.typ)?;
        Ok(value)
    }
}

/// Factory for the `ExtractByMistral` function; it carries no state of its own.
pub struct Factory;

#[async_trait]
impl SimpleFunctionFactoryBase for Factory {
    type Spec = Spec;

    /// Name under which this function is exposed.
    fn name(&self) -> &str {
        "ExtractByMistral"
    }

    /// Checks that the single input argument is a string and reports the output type
    /// declared in `spec`.
    ///
    /// # Errors
    /// Fails when `expect_input_1` rejects `input_schema`, or when the input's type is
    /// anything other than `Str`.
    fn get_output_schema(
        &self,
        spec: &Spec,
        input_schema: &Vec<OpArgSchema>,
        _context: &FlowInstanceContext,
    ) -> Result<EnrichedValueType> {
        let input_type = &expect_input_1(input_schema)?.value_type.typ;
        if !matches!(input_type, ValueType::Basic(BasicValueType::Str)) {
            api_bail!("Expect String as input type, got {}", input_type);
        }
        Ok(spec.output_type.clone())
    }

    /// Builds the executor for `spec`; the input schema and context are not consulted.
    async fn build_executor(
        self: Arc<Self>,
        spec: Spec,
        _input_schema: Vec<OpArgSchema>,
        _context: Arc<FlowInstanceContext>,
    ) -> Result<Box<dyn SimpleFunctionExecutor>> {
        let executor = Executor::new(spec).await?;
        let boxed: Box<dyn SimpleFunctionExecutor> = Box::new(executor);
        Ok(boxed)
    }
}
1 change: 1 addition & 0 deletions src/ops/functions/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod extract_by_mistral;
pub mod split_recursively;
1 change: 1 addition & 0 deletions src/ops/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard};
fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
sources::local_file::Factory.register(registry)?;
functions::split_recursively::Factory.register(registry)?;
functions::extract_by_mistral::Factory.register(registry)?;
Arc::new(storages::postgres::Factory::default()).register(registry)?;

Ok(())
Expand Down
Loading