diff --git a/src/ops/functions/split_by_separators.rs b/src/ops/functions/split_by_separators.rs index 695443b3..280d70cd 100644 --- a/src/ops/functions/split_by_separators.rs +++ b/src/ops/functions/split_by_separators.rs @@ -2,24 +2,22 @@ use anyhow::{Context, Result}; use regex::Regex; use std::sync::Arc; -use crate::base::field_attrs; use crate::ops::registry::ExecutorFactoryRegistry; -use crate::ops::shared::split::{Position, set_output_positions}; +use crate::ops::shared::split::{Position, make_common_chunk_schema, set_output_positions}; use crate::{fields_value, ops::sdk::*}; #[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq)] #[serde(rename_all = "UPPERCASE")] enum KeepSep { - NONE, - LEFT, - RIGHT, + Left, + Right, } #[derive(Serialize, Deserialize)] struct Spec { // Python SDK provides defaults/values. separators_regex: Vec, - keep_separator: KeepSep, + keep_separator: Option, include_empty: bool, trim: bool, } @@ -90,13 +88,13 @@ impl SimpleFunctionExecutor for Executor { let mut start = 0usize; for m in re.find_iter(full_text) { let end = match self.spec.keep_separator { - KeepSep::LEFT => m.end(), - KeepSep::NONE | KeepSep::RIGHT => m.start(), + Some(KeepSep::Left) => m.end(), + Some(KeepSep::Right) | None => m.start(), }; add_range(start, end); start = match self.spec.keep_separator { - KeepSep::RIGHT => m.start(), - KeepSep::NONE | KeepSep::LEFT => m.end(), + Some(KeepSep::Right) => m.start(), + _ => m.end(), }; } add_range(start, full_text.len()); @@ -154,50 +152,7 @@ impl SimpleFunctionFactoryBase for Factory { .required()?, }; - // start/end structs exactly like SplitRecursively - let pos_struct = schema::ValueType::Struct(schema::StructSchema { - fields: Arc::new(vec![ - schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)), - schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)), - schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)), - ]), - description: None, - }); - - let mut struct_schema = StructSchema::default(); - let mut sb = StructSchemaBuilder::new(&mut struct_schema); - sb.add_field(FieldSchema::new( - "location", - make_output_type(BasicValueType::Range), - )); - sb.add_field(FieldSchema::new( - "text", - make_output_type(BasicValueType::Str), - )); - sb.add_field(FieldSchema::new( - "start", - schema::EnrichedValueType { - typ: pos_struct.clone(), - nullable: false, - attrs: Default::default(), - }, - )); - sb.add_field(FieldSchema::new( - "end", - schema::EnrichedValueType { - typ: pos_struct, - nullable: false, - attrs: Default::default(), - }, - )); - let output_schema = make_output_type(TableSchema::new( - TableKind::KTable(KTableInfo { num_key_parts: 1 }), - struct_schema, - )) - .with_attr( - field_attrs::CHUNK_BASE_TEXT, - serde_json::to_value(args_resolver.get_analyze_value(&args.text))?, - ); + let output_schema = make_common_chunk_schema(args_resolver, &args.text)?; Ok((args, output_schema)) } @@ -224,7 +179,7 @@ mod tests { async fn test_split_by_separators_paragraphs() { let spec = Spec { separators_regex: vec![r"\n\n+".to_string()], - keep_separator: KeepSep::NONE, + keep_separator: None, include_empty: false, trim: true, }; @@ -268,7 +223,7 @@ mod tests { async fn test_split_by_separators_keep_right() { let spec = Spec { separators_regex: vec![r"\.".to_string()], - keep_separator: KeepSep::RIGHT, + keep_separator: Some(KeepSep::Right), include_empty: false, trim: true, }; diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs index 7f13c597..5386bb96 100644 --- a/src/ops/functions/split_recursively.rs +++ b/src/ops/functions/split_recursively.rs @@ -6,8 +6,7 @@ use std::sync::LazyLock; use std::{collections::HashMap, sync::Arc}; use unicase::UniCase; -use crate::base::field_attrs; -use crate::ops::registry::ExecutorFactoryRegistry; +use crate::ops::shared::split::{Position, set_output_positions}; use crate::{fields_value, ops::sdk::*}; #[derive(Serialize, Deserialize)] @@ -479,36 +478,6 @@ impl<'s> AtomChunksCollector<'s> { } } -#[derive(Debug, Clone, PartialEq, Eq)] -struct OutputPosition { - char_offset: usize, - line: u32, - column: u32, -} - -impl OutputPosition { - fn into_output(self) -> value::Value { - value::Value::Struct(fields_value!( - self.char_offset as i64, - self.line as i64, - self.column as i64 - )) - } -} -struct Position { - byte_offset: usize, - output: Option, -} - -impl Position { - fn new(byte_offset: usize) -> Self { - Self { - byte_offset, - output: None, - } - } -} - struct ChunkOutput<'s> { start_pos: Position, end_pos: Position, @@ -826,55 +795,6 @@ impl Executor { } } -fn set_output_positions<'a>(text: &str, positions: impl Iterator) { - let mut positions = positions.collect::>(); - positions.sort_by_key(|o| o.byte_offset); - - let mut positions_iter = positions.iter_mut(); - let Some(mut next_position) = positions_iter.next() else { - return; - }; - - let mut char_offset = 0; - let mut line = 1; - let mut column = 1; - for (byte_offset, ch) in text.char_indices() { - while next_position.byte_offset == byte_offset { - next_position.output = Some(OutputPosition { - char_offset, - line, - column, - }); - if let Some(position) = positions_iter.next() { - next_position = position; - } else { - return; - } - } - char_offset += 1; - if ch == '\n' { - line += 1; - column = 1; - } else { - column += 1; - } - } - - // Offsets after the last char. - loop { - next_position.output = Some(OutputPosition { - char_offset, - line, - column, - }); - if let Some(position) = positions_iter.next() { - next_position = position; - } else { - return; - } - } -} - #[async_trait] impl SimpleFunctionExecutor for Executor { async fn evaluate(&self, input: Vec) -> Result { @@ -997,49 +917,8 @@ impl SimpleFunctionFactoryBase for Factory { .optional(), }; - let pos_struct = schema::ValueType::Struct(schema::StructSchema { - fields: Arc::new(vec![ - schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)), - schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)), - schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)), - ]), - description: None, - }); - - let mut struct_schema = StructSchema::default(); - let mut schema_builder = StructSchemaBuilder::new(&mut struct_schema); - schema_builder.add_field(FieldSchema::new( - "location", - make_output_type(BasicValueType::Range), - )); - schema_builder.add_field(FieldSchema::new( - "text", - make_output_type(BasicValueType::Str), - )); - schema_builder.add_field(FieldSchema::new( - "start", - schema::EnrichedValueType { - typ: pos_struct.clone(), - nullable: false, - attrs: Default::default(), - }, - )); - schema_builder.add_field(FieldSchema::new( - "end", - schema::EnrichedValueType { - typ: pos_struct, - nullable: false, - attrs: Default::default(), - }, - )); - let output_schema = make_output_type(TableSchema::new( - TableKind::KTable(KTableInfo { num_key_parts: 1 }), - struct_schema, - )) - .with_attr( - field_attrs::CHUNK_BASE_TEXT, - serde_json::to_value(args_resolver.get_analyze_value(&args.text))?, - ); + let output_schema = + crate::ops::shared::split::make_common_chunk_schema(args_resolver, &args.text)?; Ok((args, output_schema)) } @@ -1060,7 +939,7 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> { #[cfg(test)] mod tests { use super::*; - use crate::ops::functions::test_utils::test_flow_function; + use crate::ops::{functions::test_utils::test_flow_function, shared::split::OutputPosition}; // Helper function to assert chunk text and its consistency with the range within the original text. fn assert_chunk_text_consistency( diff --git a/src/ops/shared/split.rs b/src/ops/shared/split.rs index 641c5fde..9cfde66a 100644 --- a/src/ops/shared/split.rs +++ b/src/ops/shared/split.rs @@ -1,4 +1,13 @@ -use crate::{fields_value, ops::sdk::value}; +use crate::{ + base::field_attrs, + fields_value, + ops::sdk::value, + ops::sdk::{ + BasicValueType, EnrichedValueType, FieldSchema, KTableInfo, OpArgsResolver, StructSchema, + StructSchemaBuilder, TableKind, TableSchema, make_output_type, schema, + }, +}; +use anyhow::Result; #[derive(Debug, Clone, PartialEq, Eq)] pub struct OutputPosition { @@ -79,3 +88,55 @@ pub fn set_output_positions<'a>(text: &str, positions: impl Iterator( + args_resolver: &OpArgsResolver<'a>, + text_arg: &crate::ops::sdk::ResolvedOpArg, +) -> Result { + let pos_struct = schema::ValueType::Struct(schema::StructSchema { + fields: std::sync::Arc::new(vec![ + schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)), + schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)), + schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)), + ]), + description: None, + }); + + let mut struct_schema = StructSchema::default(); + let mut sb = StructSchemaBuilder::new(&mut struct_schema); + sb.add_field(FieldSchema::new( + "location", + make_output_type(BasicValueType::Range), + )); + sb.add_field(FieldSchema::new( + "text", + make_output_type(BasicValueType::Str), + )); + sb.add_field(FieldSchema::new( + "start", + schema::EnrichedValueType { + typ: pos_struct.clone(), + nullable: false, + attrs: Default::default(), + }, + )); + sb.add_field(FieldSchema::new( + "end", + schema::EnrichedValueType { + typ: pos_struct, + nullable: false, + attrs: Default::default(), + }, + )); + let output_schema = make_output_type(TableSchema::new( + TableKind::KTable(KTableInfo { num_key_parts: 1 }), + struct_schema, + )) + .with_attr( + field_attrs::CHUNK_BASE_TEXT, + serde_json::to_value(args_resolver.get_analyze_value(text_arg))?, + ); + Ok(output_schema) +}