From a1eb81716e9eadf42e25bb76660d9b000ac6c483 Mon Sep 17 00:00:00 2001 From: LJ Date: Wed, 16 Apr 2025 19:24:24 -0700 Subject: [PATCH] Add a `ParseJson` function to parse text into JSON. --- python/cocoindex/functions.py | 3 + src/ops/functions/mod.rs | 1 + src/ops/functions/parse_json.rs | 104 ++++++++++++++++++++++++++++++++ src/ops/registration.rs | 1 + 4 files changed, 109 insertions(+) create mode 100644 src/ops/functions/parse_json.rs diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py index fa21b91a..e0e77457 100644 --- a/python/cocoindex/functions.py +++ b/python/cocoindex/functions.py @@ -5,6 +5,9 @@ from .typing import Float32, Vector, TypeAttr from . import op, llm +class ParseJson(op.FunctionSpec): + """Parse a text into a JSON object.""" + class SplitRecursively(op.FunctionSpec): """Split a document (in string) recursively.""" diff --git a/src/ops/functions/mod.rs b/src/ops/functions/mod.rs index 254c2998..dc6fe15a 100644 --- a/src/ops/functions/mod.rs +++ b/src/ops/functions/mod.rs @@ -1,2 +1,3 @@ pub mod extract_by_llm; +pub mod parse_json; pub mod split_recursively; diff --git a/src/ops/functions/parse_json.rs b/src/ops/functions/parse_json.rs new file mode 100644 index 00000000..b014a52d --- /dev/null +++ b/src/ops/functions/parse_json.rs @@ -0,0 +1,104 @@ +use crate::ops::sdk::*; +use anyhow::Result; +use std::collections::HashMap; +use std::sync::{Arc, LazyLock}; +use unicase::UniCase; + +pub struct Args { + text: ResolvedOpArg, + language: Option, +} + +type ParseFn = fn(&str) -> Result; +struct LanguageConfig { + parse_fn: ParseFn, +} + +fn add_language<'a>( + output: &'a mut HashMap, Arc>, + name: &'static str, + aliases: impl IntoIterator, + parse_fn: ParseFn, +) { + let lang_config = Arc::new(LanguageConfig { parse_fn }); + for name in std::iter::once(name).chain(aliases.into_iter()) { + if output.insert(name.into(), lang_config.clone()).is_some() { + panic!("Language `{name}` already exists"); + } + } +} + +fn parse_json(text: &str) -> Result { + Ok(serde_json::from_str(text)?) +} + +static PARSE_FN_BY_LANG: LazyLock, Arc>> = + LazyLock::new(|| { + let mut map = HashMap::new(); + add_language(&mut map, "json", [".json"], parse_json); + map + }); + +struct Executor { + args: Args, +} + +#[async_trait] +impl SimpleFunctionExecutor for Executor { + async fn evaluate(&self, input: Vec) -> Result { + let text = self.args.text.value(&input)?.as_str()?; + let lang_config = { + let language = self.args.language.value(&input)?; + language + .optional() + .map(|v| anyhow::Ok(v.as_str()?.as_ref())) + .transpose()? + .and_then(|lang| PARSE_FN_BY_LANG.get(&UniCase::new(lang))) + }; + let parse_fn = lang_config.map(|c| c.parse_fn).unwrap_or(parse_json); + let parsed_value = parse_fn(text)?; + Ok(value::Value::Basic(value::BasicValue::Json(Arc::new( + parsed_value, + )))) + } +} + +pub struct Factory; + +#[async_trait] +impl SimpleFunctionFactoryBase for Factory { + type Spec = EmptySpec; + type ResolvedArgs = Args; + + fn name(&self) -> &str { + "ParseJson" + } + + fn resolve_schema( + &self, + _spec: &EmptySpec, + args_resolver: &mut OpArgsResolver<'_>, + _context: &FlowInstanceContext, + ) -> Result<(Args, EnrichedValueType)> { + let args = Args { + text: args_resolver + .next_arg("text")? + .expect_type(&ValueType::Basic(BasicValueType::Str))?, + language: args_resolver + .next_optional_arg("language")? + .expect_type(&ValueType::Basic(BasicValueType::Str))?, + }; + + let output_schema = make_output_type(BasicValueType::Json); + Ok((args, output_schema)) + } + + async fn build_executor( + self: Arc, + _spec: EmptySpec, + args: Args, + _context: Arc, + ) -> Result> { + Ok(Box::new(Executor { args })) + } +} diff --git a/src/ops/registration.rs b/src/ops/registration.rs index a6195ea0..515a93f9 100644 --- a/src/ops/registration.rs +++ b/src/ops/registration.rs @@ -9,6 +9,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result sources::local_file::Factory.register(registry)?; sources::google_drive::Factory.register(registry)?; + functions::parse_json::Factory.register(registry)?; functions::split_recursively::Factory.register(registry)?; functions::extract_by_llm::Factory.register(registry)?;