Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions python/cocoindex/targets/lancedb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import datetime
import dataclasses
import threading
import uuid
import weakref
import datetime

from typing import Any

import lancedb # type: ignore
Expand All @@ -21,7 +24,6 @@

@dataclasses.dataclass
class DatabaseOptions:
    """Options supplied when opening a LanceDB connection."""

    # Forwarded as the `read_consistency_interval` argument of
    # `lancedb.connect_async` by `_open_db`.
    # NOTE(review): `connect_async` in this file takes the interval as its own
    # keyword argument instead of reading it from here — confirm whether this
    # field is still meant to exist.
    read_consistency_interval: datetime.timedelta | None = None
    # Forwarded as the `storage_options` argument of `lancedb.connect_async`.
    storage_options: dict[str, Any] | None = None


Expand All @@ -45,17 +47,33 @@ class _TableKey:
table_name: str


async def _open_db(
    db_uri: str, db_options: DatabaseOptions | None
) -> lancedb.AsyncConnection:
    """
    Open a new LanceDB connection for `db_uri`.

    Args:
        db_uri: URI of the LanceDB database to connect to.
        db_options: Connection options; defaults to `DatabaseOptions()` when
            None is given.

    Returns:
        The connection returned by `lancedb.connect_async`. A fresh connection
        is opened on every call.
    """
    db_options = db_options or DatabaseOptions()
    # TODO: reuse cached connections
    return await lancedb.connect_async(
        db_uri,
        read_consistency_interval=db_options.read_consistency_interval,
        storage_options=db_options.storage_options,
    )


# Connection cache keyed by database URI. Values are held weakly, so an entry
# disappears once nothing else references its connection.
_DbConnectionsLock = threading.Lock()
_DbConnections: weakref.WeakValueDictionary[str, lancedb.AsyncConnection] = (
    weakref.WeakValueDictionary()
)
async def connect_async(
    db_uri: str,
    *,
    db_options: DatabaseOptions | None = None,
    read_consistency_interval: datetime.timedelta | None = None,
) -> lancedb.AsyncConnection:
    """
    Helper function to connect to a LanceDB database.
    It will reuse the connection if it already exists.
    The connection will be shared with the target used by cocoindex, so it achieves strong consistency.

    Args:
        db_uri: URI of the LanceDB database; also the key of the connection cache.
        db_options: Options for a newly opened connection; only its
            `storage_options` is used. Ignored when a cached connection exists.
        read_consistency_interval: Passed to `lancedb.connect_async` for a
            newly opened connection. Ignored when a cached connection exists.

    Returns:
        The cached connection for `db_uri`, or a newly opened one.
    """
    # Fast path: the lock only guards the dictionary lookup.
    with _DbConnectionsLock:
        conn = _DbConnections.get(db_uri)
    if conn is not None:
        return conn

    # `_DbConnectionsLock` is a blocking thread lock, so it must not be held
    # across an `await`: another coroutine on the same event loop trying to
    # take it would block the loop thread and this coroutine could never
    # resume to release it.
    options = db_options or DatabaseOptions()
    new_conn = await lancedb.connect_async(
        db_uri,
        storage_options=options.storage_options,
        read_consistency_interval=read_consistency_interval,
    )

    # Another caller may have connected while we were awaiting. Keep whichever
    # connection was stored first so every caller shares the same object.
    with _DbConnectionsLock:
        conn = _DbConnections.get(db_uri)
        if conn is None:
            _DbConnections[db_uri] = conn = new_conn
    return conn


def make_pa_schema(
Expand Down Expand Up @@ -262,7 +280,7 @@ async def apply_setup_change(
latest_state = current or previous
if not latest_state:
return
db_conn = await _open_db(key.db_uri, latest_state.db_options)
db_conn = await connect_async(key.db_uri, db_options=latest_state.db_options)

reuse_table = (
previous is not None
Expand Down Expand Up @@ -291,7 +309,7 @@ async def prepare(
spec: LanceDB,
setup_state: _State,
) -> _MutateContext:
db_conn = await _open_db(spec.db_uri, spec.db_options)
db_conn = await connect_async(spec.db_uri, db_options=spec.db_options)
table = await db_conn.open_table(spec.table_name)
return _MutateContext(
table=table,
Expand Down
Loading