67 changes: 11 additions & 56 deletions src/ops/functions/split_by_separators.rs
@@ -2,24 +2,22 @@ use anyhow::{Context, Result};
use regex::Regex;
use std::sync::Arc;

use crate::base::field_attrs;
use crate::ops::registry::ExecutorFactoryRegistry;
use crate::ops::shared::split::{Position, set_output_positions};
use crate::ops::shared::split::{Position, make_common_chunk_schema, set_output_positions};
use crate::{fields_value, ops::sdk::*};

#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
enum KeepSep {
NONE,
LEFT,
RIGHT,
Left,
Right,
}

#[derive(Serialize, Deserialize)]
struct Spec {
// Python SDK provides defaults/values.
separators_regex: Vec<String>,
keep_separator: KeepSep,
keep_separator: Option<KeepSep>,
include_empty: bool,
trim: bool,
}
@@ -90,13 +88,13 @@ impl SimpleFunctionExecutor for Executor {
let mut start = 0usize;
for m in re.find_iter(full_text) {
let end = match self.spec.keep_separator {
KeepSep::LEFT => m.end(),
KeepSep::NONE | KeepSep::RIGHT => m.start(),
Some(KeepSep::Left) => m.end(),
Some(KeepSep::Right) | None => m.start(),
};
add_range(start, end);
start = match self.spec.keep_separator {
KeepSep::RIGHT => m.start(),
KeepSep::NONE | KeepSep::LEFT => m.end(),
Some(KeepSep::Right) => m.start(),
_ => m.end(),
};
}
add_range(start, full_text.len());
@@ -154,50 +152,7 @@ impl SimpleFunctionFactoryBase for Factory {
.required()?,
};

// start/end structs exactly like SplitRecursively
let pos_struct = schema::ValueType::Struct(schema::StructSchema {
fields: Arc::new(vec![
schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)),
schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)),
schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)),
]),
description: None,
});

let mut struct_schema = StructSchema::default();
let mut sb = StructSchemaBuilder::new(&mut struct_schema);
sb.add_field(FieldSchema::new(
"location",
make_output_type(BasicValueType::Range),
));
sb.add_field(FieldSchema::new(
"text",
make_output_type(BasicValueType::Str),
));
sb.add_field(FieldSchema::new(
"start",
schema::EnrichedValueType {
typ: pos_struct.clone(),
nullable: false,
attrs: Default::default(),
},
));
sb.add_field(FieldSchema::new(
"end",
schema::EnrichedValueType {
typ: pos_struct,
nullable: false,
attrs: Default::default(),
},
));
let output_schema = make_output_type(TableSchema::new(
TableKind::KTable(KTableInfo { num_key_parts: 1 }),
struct_schema,
))
.with_attr(
field_attrs::CHUNK_BASE_TEXT,
serde_json::to_value(args_resolver.get_analyze_value(&args.text))?,
);
let output_schema = make_common_chunk_schema(args_resolver, &args.text)?;
Ok((args, output_schema))
}

@@ -224,7 +179,7 @@ mod tests {
async fn test_split_by_separators_paragraphs() {
let spec = Spec {
separators_regex: vec![r"\n\n+".to_string()],
keep_separator: KeepSep::NONE,
keep_separator: None,
include_empty: false,
trim: true,
};
Expand Down Expand Up @@ -268,7 +223,7 @@ mod tests {
async fn test_split_by_separators_keep_right() {
let spec = Spec {
separators_regex: vec![r"\.".to_string()],
keep_separator: KeepSep::RIGHT,
keep_separator: Some(KeepSep::Right),
include_empty: false,
trim: true,
};
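Note on the reworked `keep_separator` semantics: `Some(Left)` keeps each separator match at the end of the preceding chunk, `Some(Right)` keeps it at the start of the following chunk, and `None` drops it. Below is a minimal standalone sketch of that boundary rule (a hypothetical helper, not the crate's executor; it assumes only the `regex` crate and hard-codes `trim = true`, `include_empty = false`):

```rust
use regex::Regex;

// Illustrative only: mirrors the boundary selection in Executor::evaluate above.
#[derive(Clone, Copy)]
enum Keep {
    Left,
    Right,
}

fn split_with(text: &str, sep: &str, keep: Option<Keep>) -> Vec<String> {
    let re = Regex::new(sep).unwrap();
    let mut chunks = Vec::new();
    let mut start = 0usize;
    for m in re.find_iter(text) {
        // End of the current chunk: include the separator only for `Left`.
        let end = match keep {
            Some(Keep::Left) => m.end(),
            Some(Keep::Right) | None => m.start(),
        };
        chunks.push(text[start..end].trim().to_string());
        // Start of the next chunk: include the separator only for `Right`.
        start = match keep {
            Some(Keep::Right) => m.start(),
            _ => m.end(),
        };
    }
    chunks.push(text[start..].trim().to_string());
    chunks.retain(|c| !c.is_empty()); // include_empty = false
    chunks
}

fn main() {
    assert_eq!(split_with("a. b. c", r"\.", None), ["a", "b", "c"]);
    assert_eq!(split_with("a. b. c", r"\.", Some(Keep::Left)), ["a.", "b.", "c"]);
    assert_eq!(split_with("a. b. c", r"\.", Some(Keep::Right)), ["a", ". b", ". c"]);
}
```

The last assertion matches the expectation exercised by `test_split_by_separators_keep_right` above.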
129 changes: 4 additions & 125 deletions src/ops/functions/split_recursively.rs
@@ -6,8 +6,7 @@ use std::sync::LazyLock;
use std::{collections::HashMap, sync::Arc};
use unicase::UniCase;

use crate::base::field_attrs;
use crate::ops::registry::ExecutorFactoryRegistry;
use crate::ops::shared::split::{Position, set_output_positions};
use crate::{fields_value, ops::sdk::*};

#[derive(Serialize, Deserialize)]
@@ -479,36 +478,6 @@ impl<'s> AtomChunksCollector<'s> {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct OutputPosition {
char_offset: usize,
line: u32,
column: u32,
}

impl OutputPosition {
fn into_output(self) -> value::Value {
value::Value::Struct(fields_value!(
self.char_offset as i64,
self.line as i64,
self.column as i64
))
}
}
struct Position {
byte_offset: usize,
output: Option<OutputPosition>,
}

impl Position {
fn new(byte_offset: usize) -> Self {
Self {
byte_offset,
output: None,
}
}
}

struct ChunkOutput<'s> {
start_pos: Position,
end_pos: Position,
@@ -826,55 +795,6 @@ impl Executor {
}
}

fn set_output_positions<'a>(text: &str, positions: impl Iterator<Item = &'a mut Position>) {
let mut positions = positions.collect::<Vec<_>>();
positions.sort_by_key(|o| o.byte_offset);

let mut positions_iter = positions.iter_mut();
let Some(mut next_position) = positions_iter.next() else {
return;
};

let mut char_offset = 0;
let mut line = 1;
let mut column = 1;
for (byte_offset, ch) in text.char_indices() {
while next_position.byte_offset == byte_offset {
next_position.output = Some(OutputPosition {
char_offset,
line,
column,
});
if let Some(position) = positions_iter.next() {
next_position = position;
} else {
return;
}
}
char_offset += 1;
if ch == '\n' {
line += 1;
column = 1;
} else {
column += 1;
}
}

// Offsets after the last char.
loop {
next_position.output = Some(OutputPosition {
char_offset,
line,
column,
});
if let Some(position) = positions_iter.next() {
next_position = position;
} else {
return;
}
}
}

#[async_trait]
impl SimpleFunctionExecutor for Executor {
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
@@ -997,49 +917,8 @@ impl SimpleFunctionFactoryBase for Factory {
.optional(),
};

let pos_struct = schema::ValueType::Struct(schema::StructSchema {
fields: Arc::new(vec![
schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)),
schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)),
schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)),
]),
description: None,
});

let mut struct_schema = StructSchema::default();
let mut schema_builder = StructSchemaBuilder::new(&mut struct_schema);
schema_builder.add_field(FieldSchema::new(
"location",
make_output_type(BasicValueType::Range),
));
schema_builder.add_field(FieldSchema::new(
"text",
make_output_type(BasicValueType::Str),
));
schema_builder.add_field(FieldSchema::new(
"start",
schema::EnrichedValueType {
typ: pos_struct.clone(),
nullable: false,
attrs: Default::default(),
},
));
schema_builder.add_field(FieldSchema::new(
"end",
schema::EnrichedValueType {
typ: pos_struct,
nullable: false,
attrs: Default::default(),
},
));
let output_schema = make_output_type(TableSchema::new(
TableKind::KTable(KTableInfo { num_key_parts: 1 }),
struct_schema,
))
.with_attr(
field_attrs::CHUNK_BASE_TEXT,
serde_json::to_value(args_resolver.get_analyze_value(&args.text))?,
);
let output_schema =
crate::ops::shared::split::make_common_chunk_schema(args_resolver, &args.text)?;
Ok((args, output_schema))
}

@@ -1060,7 +939,7 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
#[cfg(test)]
mod tests {
use super::*;
use crate::ops::functions::test_utils::test_flow_function;
use crate::ops::{functions::test_utils::test_flow_function, shared::split::OutputPosition};

// Helper function to assert chunk text and its consistency with the range within the original text.
fn assert_chunk_text_consistency(
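The `Position`/`OutputPosition` types and `set_output_positions` deleted above now live in `crate::ops::shared::split`. As a rough illustration of what that pass computes (a hypothetical standalone function, not the shared implementation, which sorts all requested positions and walks the text once), each byte offset resolves to a character offset plus 1-based line and column:

```rust
// Hypothetical standalone sketch of the line/column mapping idea.
fn locate(text: &str, byte_offset: usize) -> (usize, u32, u32) {
    let (mut char_offset, mut line, mut column) = (0usize, 1u32, 1u32);
    for (off, ch) in text.char_indices() {
        if off >= byte_offset {
            break;
        }
        char_offset += 1;
        if ch == '\n' {
            line += 1;
            column = 1;
        } else {
            column += 1;
        }
    }
    (char_offset, line, column)
}

fn main() {
    let text = "ab\ncd";
    assert_eq!(locate(text, 0), (0, 1, 1)); // start of "a"
    assert_eq!(locate(text, 3), (3, 2, 1)); // start of "c": line 2, column 1
    assert_eq!(locate(text, text.len()), (5, 2, 3)); // one past the last char
}
```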
63 changes: 62 additions & 1 deletion src/ops/shared/split.rs
@@ -1,4 +1,13 @@
use crate::{fields_value, ops::sdk::value};
use crate::{
base::field_attrs,
fields_value,
ops::sdk::value,
ops::sdk::{
BasicValueType, EnrichedValueType, FieldSchema, KTableInfo, OpArgsResolver, StructSchema,
StructSchemaBuilder, TableKind, TableSchema, make_output_type, schema,
},
};
use anyhow::Result;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutputPosition {
@@ -79,3 +88,55 @@ pub fn set_output_positions<'a>(text: &str, positions: impl Iterator<Item = &'a
}
}
}

/// Build the common chunk output schema used by splitters.
/// Fields: `location: Range`, `text: Str`, `start: {offset,line,column}`, `end: {offset,line,column}`.
pub fn make_common_chunk_schema<'a>(
args_resolver: &OpArgsResolver<'a>,
text_arg: &crate::ops::sdk::ResolvedOpArg,
) -> Result<EnrichedValueType> {
let pos_struct = schema::ValueType::Struct(schema::StructSchema {
fields: std::sync::Arc::new(vec![
schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)),
schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)),
schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)),
]),
description: None,
});

let mut struct_schema = StructSchema::default();
let mut sb = StructSchemaBuilder::new(&mut struct_schema);
sb.add_field(FieldSchema::new(
"location",
make_output_type(BasicValueType::Range),
));
sb.add_field(FieldSchema::new(
"text",
make_output_type(BasicValueType::Str),
));
sb.add_field(FieldSchema::new(
"start",
schema::EnrichedValueType {
typ: pos_struct.clone(),
nullable: false,
attrs: Default::default(),
},
));
sb.add_field(FieldSchema::new(
"end",
schema::EnrichedValueType {
typ: pos_struct,
nullable: false,
attrs: Default::default(),
},
));
let output_schema = make_output_type(TableSchema::new(
TableKind::KTable(KTableInfo { num_key_parts: 1 }),
struct_schema,
))
.with_attr(
field_attrs::CHUNK_BASE_TEXT,
serde_json::to_value(args_resolver.get_analyze_value(text_arg))?,
);
Ok(output_schema)
}