diff --git a/python/cocoindex/lib.py b/python/cocoindex/lib.py
index 56fcd118..832ccb77 100644
--- a/python/cocoindex/lib.py
+++ b/python/cocoindex/lib.py
@@ -1,7 +1,6 @@
 """
 Library level functions and states.
 """
-import asyncio
 import os
 import sys
 import functools
@@ -12,6 +11,7 @@
 
 from . import _engine
 from . import flow, query, cli
+from .convert import dump_engine_object
 
 
 def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
@@ -22,24 +22,32 @@
     else:
         target[name] = value
 
+@dataclass
+class DatabaseConnectionSpec:
+    uri: str
+    user: str | None = None
+    password: str | None = None
+
 @dataclass
 class Settings:
     """Settings for the cocoindex library."""
-    database_url: str
+    database: DatabaseConnectionSpec
 
     @classmethod
     def from_env(cls) -> Self:
         """Load settings from environment variables."""
-        kwargs: dict[str, str] = dict()
-        _load_field(kwargs, "database_url", "COCOINDEX_DATABASE_URL", required=True)
-
-        return cls(**kwargs)
+        db_kwargs: dict[str, str] = dict()
+        _load_field(db_kwargs, "uri", "COCOINDEX_DATABASE_URL", required=True)
+        _load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
+        _load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
+        database = DatabaseConnectionSpec(**db_kwargs)
+        return cls(database=database)
 
 
 def init(settings: Settings):
     """Initialize the cocoindex library."""
-    _engine.init(settings.__dict__)
+    _engine.init(dump_engine_object(settings))
 
 
 @dataclass
 class ServerSettings:
diff --git a/src/base/spec.rs b/src/base/spec.rs
index 3b63ab4b..09c229c6 100644
--- a/src/base/spec.rs
+++ b/src/base/spec.rs
@@ -322,12 +322,17 @@ impl<T> Clone for AuthEntryReference<T> {
     }
 }
 
+#[derive(Serialize, Deserialize)]
+struct UntypedAuthEntryReference<T> {
+    key: T,
+}
+
 impl<T> Serialize for AuthEntryReference<T> {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
         S: serde::Serializer,
     {
-        self.key.serialize(serializer)
+        UntypedAuthEntryReference { key: &self.key }.serialize(serializer)
     }
 }
 
@@ -336,8 +341,9 @@ impl<'de, T> Deserialize<'de> for AuthEntryReference<T> {
     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
         D: serde::Deserializer<'de>,
     {
-        Ok(Self {
-            key: String::deserialize(deserializer)?,
+        let untyped_ref = UntypedAuthEntryReference::<String>::deserialize(deserializer)?;
+        Ok(AuthEntryReference {
+            key: untyped_ref.key,
             _phantom: std::marker::PhantomData,
         })
     }
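Note on the `spec.rs` hunk above: the wrapper changes the serialized shape of `AuthEntryReference` from a bare JSON string to an object. A standalone sketch of the before/after (assumes only serde and serde_json; `my_db` is an invented key name, not from the PR):

```rust
// Sketch (not part of the diff): wire-format change for AuthEntryReference.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct UntypedAuthEntryReference<T> {
    key: T,
}

fn main() -> serde_json::Result<()> {
    // Old format: the key serialized directly, i.e. the JSON string "my_db".
    // New format: an object wrapping the key.
    let serialized = serde_json::to_string(&UntypedAuthEntryReference { key: "my_db" })?;
    assert_eq!(serialized, r#"{"key":"my_db"}"#);

    // Round-trip back into the typed form.
    let parsed: UntypedAuthEntryReference<String> = serde_json::from_str(&serialized)?;
    assert_eq!(parsed.key, "my_db");
    Ok(())
}
```

The object form also leaves room to add fields to a reference later without another breaking format change.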
diff --git a/src/execution/db_tracking_setup.rs b/src/execution/db_tracking_setup.rs
index 0a2a8289..89fdc873 100644
--- a/src/execution/db_tracking_setup.rs
+++ b/src/execution/db_tracking_setup.rs
@@ -157,7 +157,7 @@ impl ResourceSetupStatusCheck for TrackingTableSetupStatusCheck {
     }
 
     async fn apply_change(&self) -> Result<()> {
-        let pool = &get_lib_context()?.pool;
+        let pool = &get_lib_context()?.builtin_db_pool;
         if let Some(desired) = &self.desired_state {
             for lagacy_name in self.legacy_table_names.iter() {
                 let query = format!(
diff --git a/src/lib_context.rs b/src/lib_context.rs
index 5f8ebd25..76a56a6a 100644
--- a/src/lib_context.rs
+++ b/src/lib_context.rs
@@ -6,6 +6,7 @@ use crate::settings;
 use crate::setup;
 use crate::{builder::AnalyzedFlow, execution::query::SimpleSemanticsQueryHandler};
 use axum::http::StatusCode;
+use sqlx::postgres::PgConnectOptions;
 use sqlx::PgPool;
 use std::collections::BTreeMap;
 use tokio::runtime::Runtime;
@@ -61,8 +62,40 @@ impl FlowContext {
 
 static TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
 static AUTH_REGISTRY: LazyLock<Arc<AuthRegistry>> =
     LazyLock::new(|| Arc::new(AuthRegistry::new()));
 
+#[derive(Default)]
+pub struct DbPools {
+    pub pools: Mutex<HashMap<(String, Option<String>), Arc<OnceCell<PgPool>>>>,
+}
+
+impl DbPools {
+    pub async fn get_pool(&self, conn_spec: &settings::DatabaseConnectionSpec) -> Result<PgPool> {
+        let db_pool_cell = {
+            let key = (conn_spec.uri.clone(), conn_spec.user.clone());
+            let mut db_pools = self.pools.lock().unwrap();
+            db_pools.entry(key).or_default().clone()
+        };
+        let pool = db_pool_cell
+            .get_or_try_init(|| async move {
+                let mut pg_options: PgConnectOptions = conn_spec.uri.parse()?;
+                if let Some(user) = &conn_spec.user {
+                    pg_options = pg_options.username(user);
+                }
+                if let Some(password) = &conn_spec.password {
+                    pg_options = pg_options.password(password);
+                }
+                let pool = PgPool::connect_with(pg_options)
+                    .await
+                    .context("Failed to connect to database")?;
+                anyhow::Ok(pool)
+            })
+            .await?;
+        Ok(pool.clone())
+    }
+}
+
 pub struct LibContext {
-    pub pool: PgPool,
+    pub db_pools: DbPools,
+    pub builtin_db_pool: PgPool,
     pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
     pub all_setup_states: RwLock<setup::AllSetupState<setup::ExistingMode>>,
 }
@@ -100,13 +133,15 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
         pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
     });
 
+    let db_pools = DbPools::default();
     let (pool, all_setup_states) = get_runtime().block_on(async {
-        let pool = PgPool::connect(&settings.database_url).await?;
+        let pool = db_pools.get_pool(&settings.database).await?;
         let existing_ss = setup::get_existing_setup_state(&pool).await?;
         anyhow::Ok((pool, existing_ss))
     })?;
 
     Ok(LibContext {
-        pool,
+        db_pools,
+        builtin_db_pool: pool,
         all_setup_states: RwLock::new(all_setup_states),
         flows: Mutex::new(BTreeMap::new()),
    })
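The `DbPools` type above caches one pool per `(uri, user)` pair behind an `Arc<OnceCell<PgPool>>`. A reduced sketch of the same pattern (assumes tokio and anyhow; a `String` stands in for `PgPool`, and `PoolCache` is an invented name):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::OnceCell;

#[derive(Default)]
struct PoolCache {
    // String stands in for sqlx::PgPool.
    cells: Mutex<HashMap<String, Arc<OnceCell<String>>>>,
}

impl PoolCache {
    async fn get(&self, key: &str) -> anyhow::Result<String> {
        // Hold the lock only to fetch/insert the cell, never across an .await.
        let cell = self
            .cells
            .lock()
            .unwrap()
            .entry(key.to_string())
            .or_default()
            .clone();
        // The first caller runs the init future; concurrent callers await its result.
        let pool = cell
            .get_or_try_init(|| async { anyhow::Ok(format!("pool for {key}")) })
            .await?;
        Ok(pool.clone())
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let cache = PoolCache::default();
    let a = cache.get("postgres://localhost/app").await?;
    let b = cache.get("postgres://localhost/app").await?;
    assert_eq!(a, b); // second call reuses the cached cell
    Ok(())
}
```

The `Mutex` guards only the map lookup and is released before any await; `get_or_try_init` then guarantees a single connection attempt per database even under concurrent callers, and a failed attempt is not cached — the next caller retries.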
diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs
index a83ba8b7..cf2aa345 100644
--- a/src/ops/factory_bases.rs
+++ b/src/ops/factory_bases.rs
@@ -420,6 +420,11 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
         StorageFactoryBase::describe_resource(self, &key)
     }
 
+    fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value> {
+        let key: T::Key = serde_json::from_value(key.clone())?;
+        Ok(serde_json::to_value(key)?)
+    }
+
     fn check_state_compatibility(
         &self,
         desired_state: &serde_json::Value,
diff --git a/src/ops/interface.rs b/src/ops/interface.rs
index 4c26e431..3598a3a7 100644
--- a/src/ops/interface.rs
+++ b/src/ops/interface.rs
@@ -198,6 +198,10 @@ pub trait ExportTargetFactory: Send + Sync {
         auth_registry: &Arc<AuthRegistry>,
     ) -> Result<Box<dyn setup::ResourceSetupStatusCheck>>;
 
+    /// Normalize the key: the JSON form may drift after a code change (e.g. a new optional field, or different field ordering) even when the underlying value is unchanged.
+    /// This should always return the canonical serialized form.
+    fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value>;
+
     fn check_state_compatibility(
         &self,
         desired_state: &serde_json::Value,
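To make the `normalize_setup_key` contract concrete: the blanket implementation in `factory_bases.rs` round-trips the JSON through the factory's typed key. A sketch with an invented `TableKey` standing in for `Self::Key` (not the real cocoindex type):

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct TableKey {
    table_name: String,
    // A field added after keys were first persisted.
    #[serde(skip_serializing_if = "Option::is_none")]
    database: Option<String>,
}

// Mirrors the blanket impl: deserialize into the typed key, re-serialize.
fn normalize(key: serde_json::Value) -> serde_json::Result<serde_json::Value> {
    let typed: TableKey = serde_json::from_value(key)?;
    serde_json::to_value(typed)
}

fn main() -> serde_json::Result<()> {
    let old = serde_json::json!({"database": null, "table_name": "docs"});
    let new = serde_json::json!({"table_name": "docs"});
    // Different raw encodings of the same key collapse to one canonical form.
    assert_eq!(normalize(old)?, normalize(new)?);
    Ok(())
}
```

Both encodings denote the same key but compare unequal as raw `serde_json::Value`s; after the typed round-trip they are identical.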
diff --git a/src/ops/registration.rs b/src/ops/registration.rs
index 515a93f9..50c086d6 100644
--- a/src/ops/registration.rs
+++ b/src/ops/registration.rs
@@ -13,7 +13,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
     functions::split_recursively::Factory.register(registry)?;
     functions::extract_by_llm::Factory.register(registry)?;
 
-    Arc::new(storages::postgres::Factory::default()).register(registry)?;
+    storages::postgres::Factory::default().register(registry)?;
     Arc::new(storages::qdrant::Factory::default()).register(registry)?;
 
     storages::neo4j::Factory::new().register(registry)?;
diff --git a/src/ops/storages/postgres.rs b/src/ops/storages/postgres.rs
index d5254ed6..7d0818cf 100644
--- a/src/ops/storages/postgres.rs
+++ b/src/ops/storages/postgres.rs
@@ -2,13 +2,11 @@ use crate::prelude::*;
 
 use crate::base::spec::{self, *};
 use crate::ops::sdk::*;
-use crate::service::error::{shared_ok, SharedError, SharedResultExt};
+use crate::settings::DatabaseConnectionSpec;
 use crate::setup;
 use crate::utils::db::ValidIdentifier;
 use async_trait::async_trait;
 use bytes::Bytes;
-use derivative::Derivative;
-use futures::future::{BoxFuture, Shared};
 use futures::FutureExt;
 use indexmap::{IndexMap, IndexSet};
 use itertools::Itertools;
@@ -21,7 +19,7 @@ use uuid::Uuid;
 
 #[derive(Debug, Deserialize)]
 pub struct Spec {
-    database_url: Option<String>,
+    database: Option<AuthEntryReference<DatabaseConnectionSpec>>,
     table_name: Option<String>,
 }
 const BIND_LIMIT: usize = 65535;
@@ -264,7 +262,8 @@ fn from_pg_value(row: &PgRow, field_idx: usize, typ: &ValueType) -> Result<Value> {
 
 pub struct ExportContext {
-    database_url: Option<String>,
+    db_ref: Option<AuthEntryReference<DatabaseConnectionSpec>>,
+    db_pool: PgPool,
     table_name: ValidIdentifier,
     key_fields_schema: Vec<FieldSchema>,
     value_fields_schema: Vec<FieldSchema>,
@@ -277,7 +276,8 @@ pub struct ExportContext {
 impl ExportContext {
     fn new(
-        database_url: Option<String>,
+        db_ref: Option<AuthEntryReference<DatabaseConnectionSpec>>,
+        db_pool: PgPool,
         table_name: String,
         key_fields_schema: Vec<FieldSchema>,
         value_fields_schema: Vec<FieldSchema>,
@@ -305,7 +305,8 @@ impl ExportContext {
             .collect::<Vec<_>>();
         let table_name = ValidIdentifier::try_from(table_name)?;
         Ok(Self {
-            database_url,
+            db_ref,
+            db_pool,
             key_fields_schema,
             value_fields_schema,
             all_fields_comma_separated: all_fields
@@ -463,30 +464,21 @@ fn distance_to_similarity(metric: VectorSimilarityMetric, distance: f64) -> f64 {
     }
 }
 
-pub struct Factory {
-    db_pools:
-        Mutex<HashMap<Option<String>, Shared<BoxFuture<'static, SharedResult<PgPool>>>>>,
-}
-
-impl Default for Factory {
-    fn default() -> Self {
-        Self {
-            db_pools: Mutex::new(HashMap::new()),
-        }
-    }
-}
+#[derive(Default)]
+pub struct Factory {}
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub struct TableId {
-    database_url: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    database: Option<AuthEntryReference<DatabaseConnectionSpec>>,
     table_name: String,
 }
 
 impl std::fmt::Display for TableId {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{}", self.table_name)?;
-        if let Some(database_url) = &self.database_url {
-            write!(f, " (database: {})", database_url)?;
+        if let Some(database) = &self.database {
+            write!(f, " (database: {database})")?;
         }
         Ok(())
     }
@@ -582,12 +574,10 @@ impl TableSetupAction {
     }
 }
 
-#[derive(Derivative)]
-#[derivative(Debug)]
+#[derive(Debug)]
 pub struct SetupStatusCheck {
-    #[derivative(Debug = "ignore")]
-    factory: Arc<Factory>,
-    table_id: TableId,
+    db_pool: PgPool,
+    table_name: String,
 
     desired_state: Option<SetupState>,
     drop_existing: bool,
@@ -597,8 +587,8 @@
 
 impl SetupStatusCheck {
     fn new(
-        factory: Arc<Factory>,
-        table_id: TableId,
+        db_pool: PgPool,
+        table_name: String,
         desired_state: Option<SetupState>,
         existing: setup::CombinedState<SetupState>,
     ) -> Self {
@@ -676,8 +666,8 @@ impl SetupStatusCheck {
             && !existing.current.map(|s| s.uses_pgvector()).unwrap_or(false);
 
         Self {
-            factory,
-            table_id,
+            db_pool,
+            table_name,
             desired_state,
             drop_existing,
             create_pgvector_extension,
@@ -832,25 +822,21 @@ impl setup::ResourceSetupStatusCheck for SetupStatusCheck {
     }
 
     async fn apply_change(&self) -> Result<()> {
-        let db_pool = self
-            .factory
-            .get_db_pool(&self.table_id.database_url)
-            .await?;
-        let table_name = &self.table_id.table_name;
+        let table_name = &self.table_name;
         if self.drop_existing {
             sqlx::query(&format!("DROP TABLE IF EXISTS {table_name}"))
-                .execute(&db_pool)
+                .execute(&self.db_pool)
                 .await?;
         }
         if self.create_pgvector_extension {
             sqlx::query("CREATE EXTENSION IF NOT EXISTS vector;")
-                .execute(&db_pool)
+                .execute(&self.db_pool)
                 .await?;
         }
         if let Some(desired_table_setup) = &self.desired_table_setup {
             for index_name in desired_table_setup.indexes_to_delete.iter() {
                 let sql = format!("DROP INDEX IF EXISTS {}", index_name);
-                sqlx::query(&sql).execute(&db_pool).await?;
+                sqlx::query(&sql).execute(&self.db_pool).await?;
             }
             match &desired_table_setup.table_upsertion {
                 TableUpsertionAction::Create { keys, values } => {
                     let sql = format!(
                         "CREATE TABLE IF NOT EXISTS {table_name} ({}, PRIMARY KEY ({}))",
                         fields.join(", "),
                         keys.keys().join(", ")
                     );
-                    sqlx::query(&sql).execute(&db_pool).await?;
+                    sqlx::query(&sql).execute(&self.db_pool).await?;
                 }
                 TableUpsertionAction::Update {
                     columns_to_delete,
                     columns_to_upsert,
                 } => {
                     for column_name in columns_to_delete.iter() {
                         let sql = format!(
                             "ALTER TABLE {table_name} DROP COLUMN IF EXISTS {column_name}",
                         );
-                        sqlx::query(&sql).execute(&db_pool).await?;
+                        sqlx::query(&sql).execute(&self.db_pool).await?;
                     }
                     for (column_name, column_type) in columns_to_upsert.iter() {
                         let sql = format!(
                             "ALTER TABLE {table_name} DROP COLUMN IF EXISTS {column_name}, ADD COLUMN {column_name} {}",
                             to_column_type_sql(column_type)
                         );
-                        sqlx::query(&sql).execute(&db_pool).await?;
+                        sqlx::query(&sql).execute(&self.db_pool).await?;
                     }
                 }
             }
             for (index_name, index_spec) in desired_table_setup.indexes_to_create.iter() {
                 let sql = format!(
-                    "CREATE INDEX IF NOT EXISTS {} ON {} {}",
-                    index_name,
-                    self.table_id.table_name,
+                    "CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} {}",
                     to_index_spec_sql(index_spec)
                 );
-                sqlx::query(&sql).execute(&db_pool).await?;
+                sqlx::query(&sql).execute(&self.db_pool).await?;
             }
         }
         Ok(())
     }
 }
 
+async fn get_db_pool(
+    db_ref: Option<&spec::AuthEntryReference<DatabaseConnectionSpec>>,
+    auth_registry: &AuthRegistry,
+) -> Result<PgPool> {
+    let lib_context = get_lib_context()?;
+    let db_conn_spec = db_ref
+        .as_ref()
+        .map(|db_ref| auth_registry.get(db_ref))
+        .transpose()?;
+    let db_pool = match db_conn_spec {
+        Some(db_conn_spec) => lib_context.db_pools.get_pool(&db_conn_spec).await?,
+        None => lib_context.builtin_db_pool.clone(),
+    };
+    Ok(db_pool)
+}
+
 #[async_trait]
-impl StorageFactoryBase for Arc<Factory> {
+impl StorageFactoryBase for Factory {
     type Spec = Spec;
     type DeclarationSpec = ();
     type SetupState = SetupState;
@@ -927,7 +927,7 @@ impl StorageFactoryBase for Arc<Factory> {
             .into_iter()
             .map(|d| {
                 let table_id = TableId {
-                    database_url: d.spec.database_url.clone(),
+                    database: d.spec.database.clone(),
                     table_name: d
                         .spec
                         .table_name
@@ -940,16 +940,19 @@ impl StorageFactoryBase for Arc<Factory> {
                     &d.index_options,
                 );
                 let table_name = table_id.table_name.clone();
-                let export_context = Arc::new(ExportContext::new(
-                    d.spec.database_url.clone(),
-                    table_name,
-                    d.key_fields_schema,
-                    d.value_fields_schema,
-                )?);
-                let factory = self.clone();
+                let db_ref = d.spec.database;
+                let auth_registry = context.auth_registry.clone();
                 let executors = async move {
+                    let db_pool = get_db_pool(db_ref.as_ref(), &auth_registry).await?;
+                    let export_context = Arc::new(ExportContext::new(
+                        db_ref,
+                        db_pool.clone(),
+                        table_name,
+                        d.key_fields_schema,
+                        d.value_fields_schema,
+                    )?);
                     let query_target = Arc::new(PostgresQueryTarget {
-                        db_pool: factory.get_db_pool(&d.spec.database_url).await?,
+                        db_pool,
                         context: export_context.clone(),
                     });
                     Ok(TypedExportTargetExecutors {
@@ -972,9 +975,14 @@
         key: TableId,
         desired: Option<SetupState>,
         existing: setup::CombinedState<SetupState>,
-        _auth_registry: &Arc<AuthRegistry>,
+        auth_registry: &Arc<AuthRegistry>,
     ) -> Result<SetupStatusCheck> {
-        Ok(SetupStatusCheck::new(self.clone(), key, desired, existing))
+        Ok(SetupStatusCheck::new(
+            get_db_pool(key.database.as_ref(), auth_registry).await?,
+            key.table_name,
+            desired,
+            existing,
+        ))
     }
 
     fn check_state_compatibility(
@@ -1011,15 +1019,19 @@
         &self,
         mutations: Vec<ExportTargetMutationWithContext<'_, ExportContext>>,
     ) -> Result<()> {
-        let mut mut_groups_by_db_url = HashMap::new();
+        let mut mut_groups_by_db_ref = HashMap::new();
         for mutation in mutations.iter() {
-            mut_groups_by_db_url
-                .entry(mutation.export_context.database_url.clone())
+            mut_groups_by_db_ref
+                .entry(mutation.export_context.db_ref.clone())
                 .or_insert_with(Vec::new)
                 .push(mutation);
         }
-        for (db_url, mut_groups) in mut_groups_by_db_url.iter() {
-            let db_pool = self.get_db_pool(db_url).await?;
+        for mut_groups in mut_groups_by_db_ref.values() {
+            let db_pool = &mut_groups
+                .first()
+                .ok_or_else(|| anyhow!("empty group"))?
+                .export_context
+                .db_pool;
             let mut txn = db_pool.begin().await?;
             for mut_group in mut_groups.iter() {
                 mut_group
@@ -1038,29 +1050,3 @@
         Ok(())
     }
 }
-
-impl Factory {
-    async fn get_db_pool(&self, database_url: &Option<String>) -> Result<PgPool> {
-        let pool_fut = {
-            let mut db_pools = self.db_pools.lock().unwrap();
-            if let Some(shared_fut) = db_pools.get(database_url) {
-                shared_fut.clone()
-            } else {
-                let pool_fut = {
-                    let database_url = database_url.clone();
-                    async move {
-                        shared_ok(if let Some(database_url) = database_url {
-                            PgPool::connect(&database_url).await?
-                        } else {
-                            get_lib_context().map_err(SharedError::new)?.pool.clone()
-                        })
-                    }
-                };
-                let shared_fut = pool_fut.boxed().shared();
-                db_pools.insert(database_url.clone(), shared_fut.clone());
-                shared_fut
-            }
-        };
-        Ok(pool_fut.await.std_result()?)
-    }
-}
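The resolution rule `get_db_pool` implements is worth stating plainly: an explicit `database` auth reference on the target spec wins; otherwise the target falls back to the builtin pool opened by `create_lib_context`. A stand-in sketch (invented types and URIs, not the real cocoindex ones):

```rust
struct ConnSpec {
    uri: String,
}

// Explicit per-target connection beats the builtin default.
fn resolve<'a>(explicit: Option<&'a ConnSpec>, builtin: &'a ConnSpec) -> &'a ConnSpec {
    explicit.unwrap_or(builtin)
}

fn main() {
    let builtin = ConnSpec { uri: "postgres://localhost/cocoindex".into() };
    let analytics = ConnSpec { uri: "postgres://analytics.internal/metrics".into() };
    assert_eq!(resolve(Some(&analytics), &builtin).uri, "postgres://analytics.internal/metrics");
    assert_eq!(resolve(None, &builtin).uri, "postgres://localhost/cocoindex");
}
```

Everything else in `get_db_pool` — the auth-registry lookup and the `DbPools` cache — is plumbing around this one decision, and `apply_mutation` keeps one transaction per resolved database by grouping mutations on `db_ref`.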
diff --git a/src/py/mod.rs b/src/py/mod.rs
index 2c457924..3e21262f 100644
--- a/src/py/mod.rs
+++ b/src/py/mod.rs
@@ -126,7 +126,7 @@ impl FlowLiveUpdater {
         future_into_py(py, async move {
             let live_updater = execution::FlowLiveUpdater::start(
                 flow,
-                &get_lib_context().into_py_result()?.pool,
+                &get_lib_context().into_py_result()?.builtin_db_pool,
                 options.into_inner(),
             )
             .await
@@ -185,7 +185,7 @@ impl Flow {
                     &exec_plan,
                     &self.0.flow.data_schema,
                     options.into_inner(),
-                    &get_lib_context()?.pool,
+                    &get_lib_context()?.builtin_db_pool,
                 )
                 .await
             })
@@ -348,7 +348,7 @@ fn apply_setup_changes(py: Python<'_>, setup_status: &SetupStatusCheck) -> PyRes
         setup::apply_changes(
             &mut std::io::stdout(),
             &setup_status.0,
-            &get_lib_context()?.pool,
+            &get_lib_context()?.builtin_db_pool,
         )
         .await
     })
diff --git a/src/service/flows.rs b/src/service/flows.rs
index b65398af..18816bc2 100644
--- a/src/service/flows.rs
+++ b/src/service/flows.rs
@@ -155,7 +155,7 @@ pub async fn evaluate_data(
             enable_cache: true,
             evaluation_only: true,
         },
-        &lib_context.pool,
+        &lib_context.builtin_db_pool,
     )
     .await?
     .ok_or_else(|| api_error!("value not found for source at the specified key: {key:?}"))?;
@@ -173,7 +173,7 @@ pub async fn update(
     let flow_ctx = lib_context.get_flow_context(&flow_name)?;
     let mut live_updater = execution::FlowLiveUpdater::start(
         flow_ctx.clone(),
-        &lib_context.pool,
+        &lib_context.builtin_db_pool,
         execution::FlowLiveUpdaterOptions {
             live_mode: false,
             ..Default::default()
diff --git a/src/settings.rs b/src/settings.rs
index 2644f142..33cd90c2 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -1,6 +1,13 @@
 use serde::Deserialize;
 
+#[derive(Deserialize, Debug)]
+pub struct DatabaseConnectionSpec {
+    pub uri: String,
+    pub user: Option<String>,
+    pub password: Option<String>,
+}
+
 #[derive(Deserialize, Debug)]
 pub struct Settings {
-    pub database_url: String,
+    pub database: DatabaseConnectionSpec,
 }
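For reference, the JSON shape `src/settings.rs` above now expects from Python's `dump_engine_object(settings)`: a nested `database` object instead of the old flat `database_url`. A deserialization sketch (assumes plain serde/serde_json; the structs mirror the diff above and the URI is invented):

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct DatabaseConnectionSpec {
    uri: String,
    user: Option<String>,
    password: Option<String>,
}

#[derive(Deserialize, Debug)]
struct Settings {
    database: DatabaseConnectionSpec,
}

fn main() -> serde_json::Result<()> {
    // The kind of payload the Python side would hand to _engine.init.
    let s: Settings = serde_json::from_str(
        r#"{"database": {"uri": "postgres://localhost/cocoindex", "user": null, "password": null}}"#,
    )?;
    assert_eq!(s.database.uri, "postgres://localhost/cocoindex");
    assert!(s.database.user.is_none());
    assert!(s.database.password.is_none());
    Ok(())
}
```

This is why `init` switched from `settings.__dict__` (flat, one level) to `dump_engine_object(settings)`, which serializes the nested dataclass recursively.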
diff --git a/src/setup/db_metadata.rs b/src/setup/db_metadata.rs
index 83dc8614..6e64949f 100644
--- a/src/setup/db_metadata.rs
+++ b/src/setup/db_metadata.rs
@@ -329,7 +329,7 @@ impl ResourceSetupStatusCheck for MetadataTableSetup {
         if !self.metadata_table_missing {
             return Ok(());
         }
-        let pool = &get_lib_context()?.pool;
+        let pool = &get_lib_context()?.builtin_db_pool;
         let query_str = format!(
             "CREATE TABLE IF NOT EXISTS {SETUP_METADATA_TABLE_NAME} (
                 flow_name TEXT NOT NULL,
diff --git a/src/setup/driver.rs b/src/setup/driver.rs
index e2650426..f5e2b52e 100644
--- a/src/setup/driver.rs
+++ b/src/setup/driver.rs
@@ -109,9 +109,18 @@ pub async fn get_existing_setup_state(pool: &PgPool) -> Result {
+            let normalized_key = {
+                let registry = executor_factory_registry();
+                match registry.get(&target_type) {
+                    Some(ExecutorFactory::ExportTarget(factory)) => {
+                        factory.normalize_setup_key(metadata_record.key)?
+                    }
+                    _ => metadata_record.key.clone(),
+                }
+            };
             flow_ss.targets.insert(
                 super::ResourceIdentifier {
-                    key: metadata_record.key.clone(),
+                    key: normalized_key,
                     target_kind: target_type,
                 },
                 from_metadata_record(state, staging_changes)?,
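Why `driver.rs` normalizes stored target keys at load time: a key persisted by an older build need not be structurally identical to a freshly serialized one even when it names the same resource — after this PR, typically an absent optional `database` field vs. an explicit `null`. A minimal demonstration (assumes serde_json; field names follow the `TableId` example above):

```rust
fn main() {
    let stored = serde_json::json!({"table_name": "docs", "database": null});
    let fresh = serde_json::json!({"table_name": "docs"});
    // Same logical key, but unequal as raw values — hence normalize_setup_key
    // is applied to stored keys before they are used in ResourceIdentifier.
    assert_ne!(stored, fresh);
}
```

Without the normalization pass here, setup diffing would report a phantom change for every existing Postgres target the first time this branch runs against metadata written by an older build.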