Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions python/cocoindex/lib.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
Library level functions and states.
"""
import asyncio
import os
import sys
import functools
Expand All @@ -12,6 +11,7 @@

from . import _engine
from . import flow, query, cli
from .convert import dump_engine_object


def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
Expand All @@ -22,24 +22,32 @@ def _load_field(target: dict[str, str], name: str, env_name: str, required: bool
else:
target[name] = value

# Connection parameters for a database: a URI plus an optional user and password.
# `Settings.from_env` fills these from COCOINDEX_DATABASE_URL (required),
# COCOINDEX_DATABASE_USER and COCOINDEX_DATABASE_PASSWORD.
@dataclass
class DatabaseConnectionSpec:
# Connection URI; the only field without a default.
uri: str
# Optional user name; defaults to None.
user: str | None = None
# Optional password; defaults to None.
password: str | None = None

@dataclass
class Settings:
"""Settings for the cocoindex library."""
database_url: str
database: DatabaseConnectionSpec

@classmethod
def from_env(cls) -> Self:
"""Load settings from environment variables."""

kwargs: dict[str, str] = dict()
_load_field(kwargs, "database_url", "COCOINDEX_DATABASE_URL", required=True)

return cls(**kwargs)
db_kwargs: dict[str, str] = dict()
_load_field(db_kwargs, "uri", "COCOINDEX_DATABASE_URL", required=True)
_load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
_load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
database = DatabaseConnectionSpec(**db_kwargs)
return cls(database=database)


def init(settings: Settings):
"""Initialize the cocoindex library."""
_engine.init(settings.__dict__)
_engine.init(dump_engine_object(settings))

@dataclass
class ServerSettings:
Expand Down
12 changes: 9 additions & 3 deletions src/base/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,17 @@ impl<T> Clone for AuthEntryReference<T> {
}
}

/// Serialized shape of an `AuthEntryReference`: a struct with a single `key` field.
///
/// `AuthEntryReference`'s `Serialize` impl wraps `&self.key` in this struct, and its
/// `Deserialize` impl reads `UntypedAuthEntryReference<String>` and takes the `key` out,
/// so the type parameter of the reference itself never takes part in (de)serialization.
#[derive(Serialize, Deserialize)]
struct UntypedAuthEntryReference<T> {
    /// The auth entry key (a borrowed key when serializing, a `String` when deserializing).
    key: T,
}

impl<T> Serialize for AuthEntryReference<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.key.serialize(serializer)
UntypedAuthEntryReference { key: &self.key }.serialize(serializer)
}
}

Expand All @@ -336,8 +341,9 @@ impl<'de, T> Deserialize<'de> for AuthEntryReference<T> {
where
D: serde::Deserializer<'de>,
{
Ok(Self {
key: String::deserialize(deserializer)?,
let untyped_ref = UntypedAuthEntryReference::<String>::deserialize(deserializer)?;
Ok(AuthEntryReference {
key: untyped_ref.key,
_phantom: std::marker::PhantomData,
})
}
Expand Down
2 changes: 1 addition & 1 deletion src/execution/db_tracking_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl ResourceSetupStatusCheck for TrackingTableSetupStatusCheck {
}

async fn apply_change(&self) -> Result<()> {
let pool = &get_lib_context()?.pool;
let pool = &get_lib_context()?.builtin_db_pool;
if let Some(desired) = &self.desired_state {
for lagacy_name in self.legacy_table_names.iter() {
let query = format!(
Expand Down
41 changes: 38 additions & 3 deletions src/lib_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::settings;
use crate::setup;
use crate::{builder::AnalyzedFlow, execution::query::SimpleSemanticsQueryHandler};
use axum::http::StatusCode;
use sqlx::postgres::PgConnectOptions;
use sqlx::PgPool;
use std::collections::BTreeMap;
use tokio::runtime::Runtime;
Expand Down Expand Up @@ -61,8 +62,40 @@ impl FlowContext {
static TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
static AUTH_REGISTRY: LazyLock<Arc<AuthRegistry>> = LazyLock::new(|| Arc::new(AuthRegistry::new()));

/// Registry of Postgres connection pools, shared across the library.
///
/// `get_pool` looks entries up by a `(uri, user)` pair taken from the connection spec.
#[derive(Default)]
pub struct DbPools {
    /// Map from `(uri, user)` to a once-initialized pool cell. The cell is wrapped in an
    /// `Arc` so it can be cloned out of the map and initialized after the mutex guard has
    /// been dropped.
    pub pools: Mutex<HashMap<(String, Option<String>), Arc<tokio::sync::OnceCell<PgPool>>>>,
}

impl DbPools {
    /// Returns the connection pool for `conn_spec`, connecting on first use.
    ///
    /// Pools are looked up by `(uri, user)`; the password is applied when connecting but is
    /// not part of the lookup key. The returned `PgPool` is a clone of the stored pool.
    ///
    /// # Errors
    /// Fails if the URI cannot be parsed into `PgConnectOptions` or if connecting fails
    /// (reported with the context "Failed to connect to database").
    ///
    /// # Panics
    /// Panics if locking `pools` returns an error (the `unwrap` on `lock()`).
    pub async fn get_pool(&self, conn_spec: &settings::DatabaseConnectionSpec) -> Result<PgPool> {
        let cache_key = (conn_spec.uri.clone(), conn_spec.user.clone());
        // Keep the guard inside this block so it is dropped before any `.await` below.
        let cell = {
            let mut registry = self.pools.lock().unwrap();
            registry.entry(cache_key).or_default().clone()
        };

        // Builds the connect options from the spec and opens the pool.
        let connect = || async move {
            let base: PgConnectOptions = conn_spec.uri.parse()?;
            let base = match &conn_spec.user {
                Some(user) => base.username(user),
                None => base,
            };
            let options = match &conn_spec.password {
                Some(password) => base.password(password),
                None => base,
            };
            let connected = PgPool::connect_with(options)
                .await
                .context("Failed to connect to database")?;
            anyhow::Ok(connected)
        };

        let pool = cell.get_or_try_init(connect).await?;
        Ok(pool.clone())
    }
}

pub struct LibContext {
pub pool: PgPool,
pub db_pools: DbPools,
pub builtin_db_pool: PgPool,
pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
pub all_setup_states: RwLock<setup::AllSetupState<setup::ExistingMode>>,
}
Expand Down Expand Up @@ -100,13 +133,15 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
});

let db_pools = DbPools::default();
let (pool, all_setup_states) = get_runtime().block_on(async {
let pool = PgPool::connect(&settings.database_url).await?;
let pool = db_pools.get_pool(&settings.database).await?;
let existing_ss = setup::get_existing_setup_state(&pool).await?;
anyhow::Ok((pool, existing_ss))
})?;
Ok(LibContext {
pool,
db_pools,
builtin_db_pool: pool,
all_setup_states: RwLock::new(all_setup_states),
flows: Mutex::new(BTreeMap::new()),
})
Expand Down
5 changes: 5 additions & 0 deletions src/ops/factory_bases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,11 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
StorageFactoryBase::describe_resource(self, &key)
}

/// Normalizes a setup key by round-tripping it through the typed key `T::Key`.
///
/// The JSON value is deserialized into `T::Key` and then serialized back, so the result
/// is whatever `T::Key`'s `Serialize` impl currently produces for that key.
///
/// # Errors
/// Returns an error if `key` cannot be deserialized into `T::Key`, or if the typed key
/// cannot be serialized back to JSON.
fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value> {
    // `key` is owned, so it is moved into `from_value` instead of being deep-cloned.
    let typed_key: T::Key = serde_json::from_value(key)?;
    Ok(serde_json::to_value(typed_key)?)
}

fn check_state_compatibility(
&self,
desired_state: &serde_json::Value,
Expand Down
4 changes: 4 additions & 0 deletions src/ops/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ pub trait ExportTargetFactory: Send + Sync {
auth_registry: &Arc<AuthRegistry>,
) -> Result<Box<dyn setup::ResourceSetupStatusCheck>>;

/// Normalizes a setup key into its canonical serialized form.
///
/// The JSON representation of a key may change after a code change (e.g. a new optional
/// field is added, or fields are reordered) even though the underlying value has not
/// changed. Implementations should always return the canonical serialized form.
fn normalize_setup_key(&self, key: serde_json::Value) -> Result<serde_json::Value>;

fn check_state_compatibility(
&self,
desired_state: &serde_json::Value,
Expand Down
2 changes: 1 addition & 1 deletion src/ops/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result
functions::split_recursively::Factory.register(registry)?;
functions::extract_by_llm::Factory.register(registry)?;

Arc::new(storages::postgres::Factory::default()).register(registry)?;
storages::postgres::Factory::default().register(registry)?;
Arc::new(storages::qdrant::Factory::default()).register(registry)?;

storages::neo4j::Factory::new().register(registry)?;
Expand Down
Loading