Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/docs/targets/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ The spec takes the following fields:

* `table_name` (`str`, optional): The name of the table to store to. If unspecified, will use the table name `[${AppNamespace}__]${FlowName}__${TargetName}`, e.g. `DemoFlow__doc_embeddings` or `Staging__DemoFlow__doc_embeddings`.

* `schema` (`str`, optional): The PostgreSQL schema to create the table in. If unspecified, the table will be created in the default schema (usually `public`). When specified, `table_name` must also be explicitly specified. CocoIndex will automatically create the schema if it doesn't exist.

## Example
<ExampleButton
href="https://github.com/cocoindex-io/cocoindex/tree/main/examples/text_embedding"
Expand Down
1 change: 1 addition & 0 deletions python/cocoindex/targets/_engine_builtin_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Postgres(op.TargetSpec):

database: AuthEntryReference[DatabaseConnectionSpec] | None = None
table_name: str | None = None
schema: str | None = None


@dataclass
Expand Down
48 changes: 41 additions & 7 deletions src/ops/targets/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::ops::Bound;
pub struct Spec {
database: Option<spec::AuthEntryReference<DatabaseConnectionSpec>>,
table_name: Option<String>,
schema: Option<String>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think at least the field also needs to be added to this dataclass in Python SDK:

class Postgres(op.TargetSpec):
"""Target powered by Postgres and pgvector."""
database: AuthEntryReference[DatabaseConnectionSpec] | None = None
table_name: str | None = None

So that users will be able to use it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay will do

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    database: AuthEntryReference[DatabaseConnectionSpec] | None = None
    table_name: str | None = None
    schema: str | None = None

}
const BIND_LIMIT: usize = 65535;

Expand Down Expand Up @@ -143,10 +144,12 @@ impl ExportContext {
fn new(
db_ref: Option<spec::AuthEntryReference<DatabaseConnectionSpec>>,
db_pool: PgPool,
table_name: String,
table_id: &TableId,
key_fields_schema: Box<[FieldSchema]>,
value_fields_schema: Vec<FieldSchema>,
) -> Result<Self> {
let table_name = qualified_table_name(table_id);

let key_fields = key_fields_schema
.iter()
.map(|f| format!("\"{}\"", f.name))
Expand Down Expand Up @@ -255,12 +258,18 @@ pub struct Factory {}
pub struct TableId {
#[serde(skip_serializing_if = "Option::is_none")]
database: Option<spec::AuthEntryReference<DatabaseConnectionSpec>>,
#[serde(skip_serializing_if = "Option::is_none")]
schema: Option<String>,
table_name: String,
}

impl std::fmt::Display for TableId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.table_name)?;
if let Some(schema) = &self.schema {
write!(f, "{}.{}", schema, self.table_name)?;
} else {
write!(f, "{}", self.table_name)?;
}
if let Some(database) = &self.database {
write!(f, " (database: {database})")?;
}
Expand Down Expand Up @@ -345,6 +354,13 @@ fn to_column_type_sql(column_type: &ValueType) -> String {
}
}

fn qualified_table_name(table_id: &TableId) -> String {
match &table_id.schema {
Some(schema) => format!("\"{}\".{}", schema, table_id.table_name),
None => table_id.table_name.clone(),
}
}

impl<'a> From<&'a SetupState> for Cow<'a, TableColumnsSchema<String>> {
fn from(val: &'a SetupState) -> Self {
Cow::Owned(TableColumnsSchema {
Expand Down Expand Up @@ -554,7 +570,9 @@ impl setup::ResourceSetupChange for SetupChange {
}

impl SetupChange {
async fn apply_change(&self, db_pool: &PgPool, table_name: &str) -> Result<()> {
async fn apply_change(&self, db_pool: &PgPool, table_id: &TableId) -> Result<()> {
let table_name = qualified_table_name(table_id);

if self.actions.table_action.drop_existing {
sqlx::query(&format!("DROP TABLE IF EXISTS {table_name}"))
.execute(db_pool)
Expand All @@ -572,6 +590,12 @@ impl SetupChange {
if let Some(table_upsertion) = &self.actions.table_action.table_upsertion {
match table_upsertion {
TableUpsertionAction::Create { keys, values } => {
// Create schema if specified
if let Some(schema) = &table_id.schema {
let sql = format!("CREATE SCHEMA IF NOT EXISTS \"{}\"", schema);
sqlx::query(&sql).execute(db_pool).await?;
}

let mut fields = (keys
.iter()
.map(|(name, typ)| format!("\"{name}\" {typ} NOT NULL")))
Expand Down Expand Up @@ -638,8 +662,18 @@ impl TargetFactoryBase for Factory {
let data_coll_output = data_collections
.into_iter()
.map(|d| {
// Validate: if schema is specified, table_name must be explicit
if d.spec.schema.is_some() && d.spec.table_name.is_none() {
bail!(
"Postgres target '{}': when 'schema' is specified, 'table_name' must also be explicitly provided. \
Auto-generated table names are not supported with custom schemas",
d.name
);
}

let table_id = TableId {
database: d.spec.database.clone(),
schema: d.spec.schema.clone(),
table_name: d.spec.table_name.unwrap_or_else(|| {
utils::db::sanitize_identifier(&format!(
"{}__{}",
Expand All @@ -653,15 +687,15 @@ impl TargetFactoryBase for Factory {
&d.value_fields_schema,
&d.index_options,
);
let table_name = table_id.table_name.clone();
let table_id_clone = table_id.clone();
let db_ref = d.spec.database;
let auth_registry = context.auth_registry.clone();
let export_context = Box::pin(async move {
let db_pool = get_db_pool(db_ref.as_ref(), &auth_registry).await?;
let export_context = Arc::new(ExportContext::new(
db_ref,
db_pool.clone(),
table_name,
&table_id_clone,
d.key_fields_schema,
d.value_fields_schema,
)?);
Expand Down Expand Up @@ -699,7 +733,7 @@ impl TargetFactoryBase for Factory {
}

fn describe_resource(&self, key: &TableId) -> Result<String> {
Ok(format!("Postgres table {}", key.table_name))
Ok(format!("Postgres table {}", key))
}

async fn apply_mutation(
Expand Down Expand Up @@ -746,7 +780,7 @@ impl TargetFactoryBase for Factory {
let db_pool = get_db_pool(change.key.database.as_ref(), &context.auth_registry).await?;
change
.setup_change
.apply_change(&db_pool, &change.key.table_name)
.apply_change(&db_pool, &change.key)
.await?;
}
Ok(())
Expand Down