diff --git a/Cargo.toml b/Cargo.toml index 9a7bc252..b065d3e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ name = "cocoindex_engine" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.23.5" } +pyo3 = { version = "0.23.5", features = ["chrono"] } anyhow = { version = "1.0.97", features = ["std"] } async-trait = "0.1.88" axum = "0.7.9" @@ -25,7 +25,12 @@ log = "0.4.26" regex = "1.11.1" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" -sqlx = { version = "0.8.3", features = ["chrono", "postgres", "runtime-tokio", "uuid"] } +sqlx = { version = "0.8.3", features = [ + "chrono", + "postgres", + "runtime-tokio", + "uuid", +] } tokio = { version = "1.44.1", features = [ "macros", "rt-multi-thread", diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 2a17b702..c7e50ae6 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -1,6 +1,7 @@ import typing import collections import dataclasses +import datetime import types import inspect import uuid @@ -26,6 +27,8 @@ def __init__(self, key: str, value: Any): Float64 = Annotated[float, TypeKind('Float64')] Range = Annotated[tuple[int, int], TypeKind('Range')] Json = Annotated[Any, TypeKind('Json')] +LocalDateTime = Annotated[datetime.datetime, TypeKind('LocalDateTime')] +OffsetDateTime = Annotated[datetime.datetime, TypeKind('OffsetDateTime')] COLLECTION_TYPES = ('Table', 'List') @@ -133,6 +136,12 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: kind = 'Float64' elif t is uuid.UUID: kind = 'Uuid' + elif t is datetime.date: + kind = 'Date' + elif t is datetime.time: + kind = 'Time' + elif t is datetime.datetime: + kind = 'OffsetDateTime' else: raise ValueError(f"type unsupported yet: {t}") diff --git a/src/base/json_schema.rs b/src/base/json_schema.rs index 2a1124f7..66b83466 100644 --- a/src/base/json_schema.rs +++ b/src/base/json_schema.rs @@ -8,6 +8,9 @@ pub struct ToJsonSchemaOptions { /// Use union type (with `null`) for optional fields instead. /// Models like OpenAI will reject the schema if a field is not required. pub fields_always_required: bool, + + /// If true, the JSON schema supports the `format` keyword. + pub supports_format: bool, } pub trait ToJsonSchema { @@ -49,15 +52,51 @@ impl ToJsonSchema for schema::BasicValueType { max_items: Some(2), ..Default::default() })); - schema - .metadata - .get_or_insert_with(Default::default) - .description = + schema.metadata.get_or_insert_default().description = Some("A range, start pos (inclusive), end pos (exclusive).".to_string()); } schema::BasicValueType::Uuid => { schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); - schema.format = Some("uuid".to_string()); + if options.supports_format { + schema.format = Some("uuid".to_string()); + } else { + schema.metadata.get_or_insert_default().description = + Some("A UUID, e.g. 123e4567-e89b-12d3-a456-426614174000".to_string()); + } + } + schema::BasicValueType::Date => { + schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); + if options.supports_format { + schema.format = Some("date".to_string()); + } else { + schema.metadata.get_or_insert_default().description = + Some("A date, e.g. 2025-03-27".to_string()); + } + } + schema::BasicValueType::Time => { + schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); + if options.supports_format { + schema.format = Some("time".to_string()); + } else { + schema.metadata.get_or_insert_default().description = + Some("A time, e.g. 13:32:12".to_string()); + } + } + schema::BasicValueType::LocalDateTime => { + schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); + if options.supports_format { + schema.format = Some("date-time".to_string()); + } + schema.metadata.get_or_insert_default().description = + Some("Date time without timezone offset, e.g. 2025-03-27T13:32:12".to_string()); + } + schema::BasicValueType::OffsetDateTime => { + schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String))); + if options.supports_format { + schema.format = Some("date-time".to_string()); + } + schema.metadata.get_or_insert_default().description = + Some("Date time with timezone offset in RFC3339, e.g. 2025-03-27T13:32:12Z, 2025-03-27T07:32:12.313-06:00".to_string()); } schema::BasicValueType::Json => { // Can be any value. No type constraint. diff --git a/src/base/schema.rs b/src/base/schema.rs index 24537a19..c1fb24c5 100644 --- a/src/base/schema.rs +++ b/src/base/schema.rs @@ -38,6 +38,18 @@ pub enum BasicValueType { /// A UUID. Uuid, + /// Date (without time within the current day). + Date, + + /// Time of the day. + Time, + + /// Local date and time, without timezone. + LocalDateTime, + + /// Date and time with timezone. + OffsetDateTime, + /// A JSON value. Json, @@ -56,6 +68,10 @@ impl std::fmt::Display for BasicValueType { BasicValueType::Float64 => write!(f, "float64"), BasicValueType::Range => write!(f, "range"), BasicValueType::Uuid => write!(f, "uuid"), + BasicValueType::Date => write!(f, "date"), + BasicValueType::Time => write!(f, "time"), + BasicValueType::LocalDateTime => write!(f, "local_datetime"), + BasicValueType::OffsetDateTime => write!(f, "offset_datetime"), BasicValueType::Json => write!(f, "json"), BasicValueType::Vector(s) => write!( f, diff --git a/src/base/value.rs b/src/base/value.rs index f4cdaa69..f4149627 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -1,14 +1,16 @@ use crate::{api_bail, api_error}; use super::schema::*; -use anyhow::Result; +use anyhow::{Context, Result}; use base64::prelude::*; +use chrono::Offset; +use log::warn; use serde::{ de::{SeqAccess, Visitor}, ser::{SerializeMap, SerializeSeq, SerializeTuple}, Deserialize, Serialize, }; -use std::{collections::BTreeMap, ops::Deref, sync::Arc}; +use std::{collections::BTreeMap, ops::Deref, str::FromStr, sync::Arc}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct RangeValue { @@ -77,6 +79,7 @@ pub enum KeyValue { Int64(i64), Range(RangeValue), Uuid(uuid::Uuid), + Date(chrono::NaiveDate), Struct(Vec), } @@ -122,6 +125,18 @@ impl From for KeyValue { } } +impl From for KeyValue { + fn from(value: uuid::Uuid) -> Self { + KeyValue::Uuid(value) + } +} + +impl From for KeyValue { + fn from(value: chrono::NaiveDate) -> Self { + KeyValue::Date(value) + } +} + impl From> for KeyValue { fn from(value: Vec) -> Self { KeyValue::Struct(value) @@ -143,6 +158,7 @@ impl std::fmt::Display for KeyValue { KeyValue::Int64(v) => write!(f, "{}", v), KeyValue::Range(v) => write!(f, "[{}, {})", v.start, v.end), KeyValue::Uuid(v) => write!(f, "{}", v), + KeyValue::Date(v) => write!(f, "{}", v), KeyValue::Struct(v) => { write!( f, @@ -172,17 +188,19 @@ impl KeyValue { KeyValue::Bytes(Arc::from(BASE64_STANDARD.decode(v)?)) } BasicValueType::Str { .. } => KeyValue::Str(Arc::from(v)), - BasicValueType::Bool => KeyValue::Bool(v.parse::()?), - BasicValueType::Int64 => KeyValue::Int64(v.parse::()?), + BasicValueType::Bool => KeyValue::Bool(v.parse()?), + BasicValueType::Int64 => KeyValue::Int64(v.parse()?), BasicValueType::Range => { let v2 = values_iter .next() .ok_or_else(|| api_error!("Key parts less than expected"))?; KeyValue::Range(RangeValue { - start: v.parse::()?, - end: v2.parse::()?, + start: v.parse()?, + end: v2.parse()?, }) } + BasicValueType::Uuid => KeyValue::Uuid(v.parse()?), + BasicValueType::Date => KeyValue::Date(v.parse()?), schema => api_bail!("Invalid key type {schema}"), } } @@ -208,6 +226,7 @@ impl KeyValue { output.push(v.end.to_string()); } KeyValue::Uuid(v) => output.push(v.to_string()), + KeyValue::Date(v) => output.push(v.to_string()), KeyValue::Struct(v) => { for part in v { part.parts_to_strs(output); @@ -239,6 +258,7 @@ impl KeyValue { KeyValue::Int64(_) => "int64", KeyValue::Range { .. } => "range", KeyValue::Uuid(_) => "uuid", + KeyValue::Date(_) => "date", KeyValue::Struct(_) => "struct", } } @@ -278,6 +298,20 @@ impl KeyValue { } } + pub fn uuid_value(&self) -> Result { + match self { + KeyValue::Uuid(v) => Ok(*v), + _ => anyhow::bail!("expected uuid value, but got {}", self.kind_str()), + } + } + + pub fn date_value(&self) -> Result { + match self { + KeyValue::Date(v) => Ok(*v), + _ => anyhow::bail!("expected date value, but got {}", self.kind_str()), + } + } + pub fn struct_value(&self) -> Result<&Vec> { match self { KeyValue::Struct(v) => Ok(v), @@ -304,6 +338,10 @@ pub enum BasicValue { Float64(f64), Range(RangeValue), Uuid(uuid::Uuid), + Date(chrono::NaiveDate), + Time(chrono::NaiveTime), + LocalDateTime(chrono::NaiveDateTime), + OffsetDateTime(chrono::DateTime), Json(Arc), Vector(Arc<[BasicValue]>), } @@ -356,6 +394,36 @@ impl From for BasicValue { } } +impl From for BasicValue { + fn from(value: uuid::Uuid) -> Self { + BasicValue::Uuid(value) + } +} + +impl From for BasicValue { + fn from(value: chrono::NaiveDate) -> Self { + BasicValue::Date(value) + } +} + +impl From for BasicValue { + fn from(value: chrono::NaiveTime) -> Self { + BasicValue::Time(value) + } +} + +impl From for BasicValue { + fn from(value: chrono::NaiveDateTime) -> Self { + BasicValue::LocalDateTime(value) + } +} + +impl From> for BasicValue { + fn from(value: chrono::DateTime) -> Self { + BasicValue::OffsetDateTime(value) + } +} + impl From for BasicValue { fn from(value: serde_json::Value) -> Self { BasicValue::Json(Arc::from(value)) @@ -379,8 +447,12 @@ impl BasicValue { BasicValue::Int64(v) => KeyValue::Int64(v), BasicValue::Range(v) => KeyValue::Range(v), BasicValue::Uuid(v) => KeyValue::Uuid(v), + BasicValue::Date(v) => KeyValue::Date(v), BasicValue::Float32(_) | BasicValue::Float64(_) + | BasicValue::Time(_) + | BasicValue::LocalDateTime(_) + | BasicValue::OffsetDateTime(_) | BasicValue::Json(_) | BasicValue::Vector(_) => api_bail!("invalid key value type"), }; @@ -395,8 +467,12 @@ impl BasicValue { BasicValue::Int64(v) => KeyValue::Int64(*v), BasicValue::Range(v) => KeyValue::Range(*v), BasicValue::Uuid(v) => KeyValue::Uuid(*v), + BasicValue::Date(v) => KeyValue::Date(*v), BasicValue::Float32(_) | BasicValue::Float64(_) + | BasicValue::Time(_) + | BasicValue::LocalDateTime(_) + | BasicValue::OffsetDateTime(_) | BasicValue::Json(_) | BasicValue::Vector(_) => api_bail!("invalid key value type"), }; @@ -413,6 +489,10 @@ impl BasicValue { BasicValue::Float64(_) => "float64", BasicValue::Range(_) => "range", BasicValue::Uuid(_) => "uuid", + BasicValue::Date(_) => "date", + BasicValue::Time(_) => "time", + BasicValue::LocalDateTime(_) => "local_datetime", + BasicValue::OffsetDateTime(_) => "offset_datetime", BasicValue::Json(_) => "json", BasicValue::Vector(_) => "vector", } @@ -445,6 +525,7 @@ impl From for Value { KeyValue::Int64(v) => Value::Basic(BasicValue::Int64(v)), KeyValue::Range(v) => Value::Basic(BasicValue::Range(v)), KeyValue::Uuid(v) => Value::Basic(BasicValue::Uuid(v)), + KeyValue::Date(v) => Value::Basic(BasicValue::Date(v)), KeyValue::Struct(v) => Value::Struct(FieldValues { fields: v.into_iter().map(Value::from).collect(), }), @@ -744,6 +825,12 @@ impl serde::Serialize for BasicValue { BasicValue::Float64(v) => serializer.serialize_f64(*v), BasicValue::Range(v) => v.serialize(serializer), BasicValue::Uuid(v) => serializer.serialize_str(&v.to_string()), + BasicValue::Date(v) => serializer.serialize_str(&v.to_string()), + BasicValue::Time(v) => serializer.serialize_str(&v.to_string()), + BasicValue::LocalDateTime(v) => serializer.serialize_str(&v.to_string()), + BasicValue::OffsetDateTime(v) => { + serializer.serialize_str(&v.to_rfc3339_opts(chrono::SecondsFormat::AutoSi, true)) + } BasicValue::Json(v) => v.serialize(serializer), BasicValue::Vector(v) => v.serialize(serializer), } @@ -774,8 +861,27 @@ impl BasicValue { .ok_or_else(|| anyhow::anyhow!("invalid fp64 value {v}"))?, ), (v, BasicValueType::Range) => BasicValue::Range(serde_json::from_value(v)?), - (serde_json::Value::String(v), BasicValueType::Uuid) => { - BasicValue::Uuid(uuid::Uuid::parse_str(v.as_str())?) + (serde_json::Value::String(v), BasicValueType::Uuid) => BasicValue::Uuid(v.parse()?), + (serde_json::Value::String(v), BasicValueType::Date) => BasicValue::Date(v.parse()?), + (serde_json::Value::String(v), BasicValueType::Time) => BasicValue::Time(v.parse()?), + (serde_json::Value::String(v), BasicValueType::LocalDateTime) => { + BasicValue::LocalDateTime(v.parse()?) + } + (serde_json::Value::String(v), BasicValueType::OffsetDateTime) => { + match chrono::DateTime::parse_from_rfc3339(&v) { + Ok(dt) => BasicValue::OffsetDateTime(dt), + Err(e) => { + if let Ok(dt) = v.parse::() { + warn!("Datetime without timezone offset, assuming UTC"); + BasicValue::OffsetDateTime(chrono::DateTime::from_naive_utc_and_offset( + dt, + chrono::Utc.fix(), + )) + } else { + Err(e)? + } + } + } } (v, BasicValueType::Json) => BasicValue::Json(Arc::from(v)), ( diff --git a/src/execution/query.rs b/src/execution/query.rs index 82031c93..ed5e8fc6 100644 --- a/src/execution/query.rs +++ b/src/execution/query.rs @@ -88,6 +88,10 @@ impl SimpleSemanticsQueryHandler { | value::BasicValue::Bool(_) | value::BasicValue::Range(_) | value::BasicValue::Uuid(_) + | value::BasicValue::Date(_) + | value::BasicValue::Time(_) + | value::BasicValue::LocalDateTime(_) + | value::BasicValue::OffsetDateTime(_) | value::BasicValue::Json(_) | value::BasicValue::Vector(_) => { bail!("Query results is not a vector of number") diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 39cc71a8..be7cde23 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -54,9 +54,15 @@ pub trait LlmGenerationClient: Send + Sync { false } + /// If true, the LLM supports the `format` keyword in the JSON schema. + fn json_schema_supports_format(&self) -> bool { + true + } + fn to_json_schema_options(&self) -> ToJsonSchemaOptions { ToJsonSchemaOptions { fields_always_required: self.json_schema_fields_always_required(), + supports_format: self.json_schema_supports_format(), } } } diff --git a/src/llm/openai.rs b/src/llm/openai.rs index 47c82932..ca16dede 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -101,4 +101,8 @@ impl LlmGenerationClient for Client { fn json_schema_fields_always_required(&self) -> bool { true } + + fn json_schema_supports_format(&self) -> bool { + false + } } diff --git a/src/ops/storages/postgres.rs b/src/ops/storages/postgres.rs index efdab32d..7eb8e412 100644 --- a/src/ops/storages/postgres.rs +++ b/src/ops/storages/postgres.rs @@ -82,6 +82,9 @@ fn bind_key_field<'arg>( KeyValue::Uuid(v) => { builder.push_bind(v); } + KeyValue::Date(v) => { + builder.push_bind(v); + } KeyValue::Struct(fields) => { builder.push_bind(sqlx::types::Json(fields)); } @@ -123,6 +126,18 @@ fn bind_value_field<'arg>( BasicValue::Uuid(v) => { builder.push_bind(v); } + BasicValue::Date(v) => { + builder.push_bind(v); + } + BasicValue::Time(v) => { + builder.push_bind(v); + } + BasicValue::LocalDateTime(v) => { + builder.push_bind(v); + } + BasicValue::OffsetDateTime(v) => { + builder.push_bind(v); + } BasicValue::Json(v) => { builder.push_bind(sqlx::types::Json(&**v)); } @@ -196,6 +211,18 @@ fn from_pg_value(row: &PgRow, field_idx: usize, typ: &ValueType) -> Result row .try_get::, _>(field_idx)? .map(BasicValue::Uuid), + BasicValueType::Date => row + .try_get::, _>(field_idx)? + .map(BasicValue::Date), + BasicValueType::Time => row + .try_get::, _>(field_idx)? + .map(BasicValue::Time), + BasicValueType::LocalDateTime => row + .try_get::, _>(field_idx)? + .map(BasicValue::LocalDateTime), + BasicValueType::OffsetDateTime => row + .try_get::>, _>(field_idx)? + .map(BasicValue::OffsetDateTime), BasicValueType::Json => row .try_get::, _>(field_idx)? .map(|v| BasicValue::Json(Arc::from(v))), @@ -666,6 +693,10 @@ fn to_column_type_sql(column_type: &ValueType) -> Cow<'static, str> { BasicValueType::Float64 => "double precision".into(), BasicValueType::Range => "int8range".into(), BasicValueType::Uuid => "uuid".into(), + BasicValueType::Date => "date".into(), + BasicValueType::Time => "time".into(), + BasicValueType::LocalDateTime => "timestamp".into(), + BasicValueType::OffsetDateTime => "timestamp with time zone".into(), BasicValueType::Json => "jsonb".into(), BasicValueType::Vector(vec_schema) => { if convertible_to_pgvector(vec_schema) { diff --git a/src/py/convert.rs b/src/py/convert.rs index 02bf59db..45f882b7 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -65,6 +65,10 @@ fn basic_value_to_py_object<'py>( value::BasicValue::Float64(v) => v.into_bound_py_any(py)?, value::BasicValue::Range(v) => pythonize(py, v).into_py_result()?, value::BasicValue::Uuid(v) => v.as_bytes().into_bound_py_any(py)?, + value::BasicValue::Date(v) => v.into_bound_py_any(py)?, + value::BasicValue::Time(v) => v.into_bound_py_any(py)?, + value::BasicValue::LocalDateTime(v) => v.into_bound_py_any(py)?, + value::BasicValue::OffsetDateTime(v) => v.into_bound_py_any(py)?, value::BasicValue::Json(v) => pythonize(py, v).into_py_result()?, value::BasicValue::Vector(v) => v .iter() @@ -130,6 +134,14 @@ fn basic_value_from_py_object<'py>( schema::BasicValueType::Uuid => { value::BasicValue::Uuid(uuid::Uuid::from_bytes(v.extract::()?)) } + schema::BasicValueType::Date => value::BasicValue::Date(v.extract::()?), + schema::BasicValueType::Time => value::BasicValue::Time(v.extract::()?), + schema::BasicValueType::LocalDateTime => { + value::BasicValue::LocalDateTime(v.extract::()?) + } + schema::BasicValueType::OffsetDateTime => { + value::BasicValue::OffsetDateTime(v.extract::>()?) + } schema::BasicValueType::Json => { value::BasicValue::Json(Arc::from(depythonize::(v)?)) }