diff --git a/libs/community/langchain_community/document_loaders/cassandra.py b/libs/community/langchain_community/document_loaders/cassandra.py index 78c598ffa9ab36..8c31df029e6ed0 100644 --- a/libs/community/langchain_community/document_loaders/cassandra.py +++ b/libs/community/langchain_community/document_loaders/cassandra.py @@ -14,7 +14,7 @@ from langchain_core.documents import Document from langchain_community.document_loaders.base import BaseLoader -from langchain_community.utilities.cassandra import wrapped_response_future +from langchain_community.utilities.cassandra import aexecute_cql _NOT_SET = object() @@ -118,11 +118,7 @@ def lazy_load(self) -> Iterator[Document]: ) async def alazy_load(self) -> AsyncIterator[Document]: - for row in await wrapped_response_future( - self.session.execute_async, - self.query, - **self.query_kwargs, - ): + for row in await aexecute_cql(self.session, self.query, **self.query_kwargs): metadata = self.metadata.copy() metadata.update(self.metadata_mapper(row)) yield Document( diff --git a/libs/community/langchain_community/storage/cassandra.py b/libs/community/langchain_community/storage/cassandra.py new file mode 100644 index 00000000000000..280ce5b3a5a422 --- /dev/null +++ b/libs/community/langchain_community/storage/cassandra.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import asyncio +from asyncio import InvalidStateError, Task +from typing import ( + TYPE_CHECKING, + AsyncIterator, + Iterator, + List, + Optional, + Sequence, + Tuple, +) + +from langchain_core.stores import ByteStore + +from langchain_community.utilities.cassandra import SetupMode, aexecute_cql + +if TYPE_CHECKING: + from cassandra.cluster import Session + from cassandra.query import PreparedStatement + +CREATE_TABLE_CQL_TEMPLATE = """ + CREATE TABLE IF NOT EXISTS {keyspace}.{table} + (row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id)); +""" +SELECT_TABLE_CQL_TEMPLATE = ( + """SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;""" +) +SELECT_ALL_TABLE_CQL_TEMPLATE = """SELECT row_id, body_blob FROM {keyspace}.{table};""" +INSERT_TABLE_CQL_TEMPLATE = ( + """INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);""" +) +DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;""" + + +class CassandraByteStore(ByteStore): + def __init__( + self, + table: str, + *, + session: Optional[Session] = None, + keyspace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + ) -> None: + if not session or not keyspace: + try: + from cassio.config import check_resolve_keyspace, check_resolve_session + + self.keyspace = keyspace or check_resolve_keyspace(keyspace) + self.session = session or check_resolve_session() + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent cassio package." + "Please install it with `pip install --upgrade cassio`." + ) + else: + self.keyspace = keyspace + self.session = session + self.table = table + self.select_statement = None + self.insert_statement = None + self.delete_statement = None + + create_cql = CREATE_TABLE_CQL_TEMPLATE.format( + keyspace=self.keyspace, + table=self.table, + ) + self.db_setup_task: Optional[Task[None]] = None + if setup_mode == SetupMode.ASYNC: + self.db_setup_task = asyncio.create_task( + aexecute_cql(self.session, create_cql) + ) + else: + self.session.execute(create_cql) + + def ensure_db_setup(self) -> None: + if self.db_setup_task: + try: + self.db_setup_task.result() + except InvalidStateError: + raise ValueError( + "Asynchronous setup of the DB not finished. " + "NB: AstraDB components sync methods shouldn't be called from the " + "event loop. Consider using their async equivalents." + ) + + async def aensure_db_setup(self) -> None: + if self.db_setup_task: + await self.db_setup_task + + def get_select_statement(self) -> PreparedStatement: + if not self.select_statement: + self.select_statement = self.session.prepare( + SELECT_TABLE_CQL_TEMPLATE.format( + keyspace=self.keyspace, table=self.table + ) + ) + return self.select_statement + + def get_insert_statement(self) -> PreparedStatement: + if not self.insert_statement: + self.insert_statement = self.session.prepare( + INSERT_TABLE_CQL_TEMPLATE.format( + keyspace=self.keyspace, table=self.table + ) + ) + return self.insert_statement + + def get_delete_statement(self) -> PreparedStatement: + if not self.delete_statement: + self.delete_statement = self.session.prepare( + DELETE_TABLE_CQL_TEMPLATE.format( + keyspace=self.keyspace, table=self.table + ) + ) + return self.delete_statement + + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + from cassandra.query import ValueSequence + + self.ensure_db_setup() + docs_dict = {} + for row in self.session.execute( + self.get_select_statement(), [ValueSequence(keys)] + ): + docs_dict[row.row_id] = row.body_blob + return [docs_dict.get(key) for key in keys] + + async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + from cassandra.query import ValueSequence + + await self.aensure_db_setup() + docs_dict = {} + for row in await aexecute_cql( + self.session, self.get_select_statement(), parameters=[ValueSequence(keys)] + ): + docs_dict[row.row_id] = row.body_blob + return [docs_dict.get(key) for key in keys] + + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + self.ensure_db_setup() + insert_statement = self.get_insert_statement() + for k, v in key_value_pairs: + self.session.execute(insert_statement, (k, v)) + + async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + await self.aensure_db_setup() + insert_statement = self.get_insert_statement() + for k, v in key_value_pairs: + await aexecute_cql(self.session, insert_statement, parameters=(k, v)) + + def mdelete(self, keys: Sequence[str]) -> None: + from cassandra.query import ValueSequence + + self.ensure_db_setup() + self.session.execute(self.get_delete_statement(), [ValueSequence(keys)]) + + async def amdelete(self, keys: Sequence[str]) -> None: + from cassandra.query import ValueSequence + + await self.aensure_db_setup() + await aexecute_cql( + self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)] + ) + + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: + self.ensure_db_setup() + for row in self.session.execute( + SELECT_ALL_TABLE_CQL_TEMPLATE.format( + keyspace=self.keyspace, table=self.table + ) + ): + key = row.row_id + if not prefix or key.startswith(prefix): + yield key + + async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: + await self.aensure_db_setup() + for row in await aexecute_cql( + self.session, + SELECT_ALL_TABLE_CQL_TEMPLATE.format( + keyspace=self.keyspace, table=self.table + ), + ): + key = row.row_id + if not prefix or key.startswith(prefix): + yield key diff --git a/libs/community/langchain_community/utilities/cassandra.py b/libs/community/langchain_community/utilities/cassandra.py index cd588508965618..52b0963c896fa5 100644 --- a/libs/community/langchain_community/utilities/cassandra.py +++ b/libs/community/langchain_community/utilities/cassandra.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: - from cassandra.cluster import ResponseFuture + from cassandra.cluster import ResponseFuture, Session async def wrapped_response_future( @@ -35,6 +35,10 @@ def error_handler(exc: BaseException) -> None: return await asyncio_future +async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any: + return await wrapped_response_future(session.execute_async, query, **kwargs) + + class SetupMode(Enum): SYNC = 1 ASYNC = 2 diff --git a/libs/community/tests/integration_tests/storage/test_cassandra.py b/libs/community/tests/integration_tests/storage/test_cassandra.py new file mode 100644 index 00000000000000..88f240ed791714 --- /dev/null +++ b/libs/community/tests/integration_tests/storage/test_cassandra.py @@ -0,0 +1,155 @@ +"""Implement integration tests for Cassandra storage.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from langchain_community.storage.cassandra import CassandraByteStore +from langchain_community.utilities.cassandra import SetupMode + +if TYPE_CHECKING: + from cassandra.cluster import Session + +KEYSPACE = "storage_test_keyspace" + + +@pytest.fixture(scope="session") +def session() -> Session: + from cassandra.cluster import Cluster + + cluster = Cluster() + session = cluster.connect() + session.execute( + ( + f"CREATE KEYSPACE IF NOT EXISTS {KEYSPACE} " + f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}" + ) + ) + return session + + +def init_store(table_name: str, session: Session) -> CassandraByteStore: + store = CassandraByteStore(table=table_name, keyspace=KEYSPACE, session=session) + store.mset([("key1", b"value1"), ("key2", b"value2")]) + return store + + +async def init_async_store(table_name: str, session: Session) -> CassandraByteStore: + store = CassandraByteStore( + table=table_name, keyspace=KEYSPACE, session=session, setup_mode=SetupMode.ASYNC + ) + await store.amset([("key1", b"value1"), ("key2", b"value2")]) + return store + + +def drop_table(table_name: str, session: Session) -> None: + session.execute(f"DROP TABLE {KEYSPACE}.{table_name}") + + +async def test_mget(session: Session) -> None: + """Test CassandraByteStore mget method.""" + table_name = "lc_test_store_mget" + try: + store = init_store(table_name, session) + assert store.mget(["key1", "key2"]) == [b"value1", b"value2"] + assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"] + finally: + drop_table(table_name, session) + + +async def test_amget(session: Session) -> None: + """Test CassandraByteStore amget method.""" + table_name = "lc_test_store_amget" + try: + store = await init_async_store(table_name, session) + assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"] + finally: + drop_table(table_name, session) + + +def test_mset(session: Session) -> None: + """Test that multiple keys can be set with CassandraByteStore.""" + table_name = "lc_test_store_mset" + try: + init_store(table_name, session) + result = session.execute( + "SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset " + "WHERE row_id = 'key1';" + ).one() + assert result.body_blob == b"value1" + result = session.execute( + "SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset " + "WHERE row_id = 'key2';" + ).one() + assert result.body_blob == b"value2" + finally: + drop_table(table_name, session) + + +async def test_amset(session: Session) -> None: + """Test that multiple keys can be set with CassandraByteStore.""" + table_name = "lc_test_store_amset" + try: + await init_async_store(table_name, session) + result = session.execute( + "SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset " + "WHERE row_id = 'key1';" + ).one() + assert result.body_blob == b"value1" + result = session.execute( + "SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset " + "WHERE row_id = 'key2';" + ).one() + assert result.body_blob == b"value2" + finally: + drop_table(table_name, session) + + +def test_mdelete(session: Session) -> None: + """Test that deletion works as expected.""" + table_name = "lc_test_store_mdelete" + try: + store = init_store(table_name, session) + store.mdelete(["key1", "key2"]) + result = store.mget(["key1", "key2"]) + assert result == [None, None] + finally: + drop_table(table_name, session) + + +async def test_amdelete(session: Session) -> None: + """Test that deletion works as expected.""" + table_name = "lc_test_store_amdelete" + try: + store = await init_async_store(table_name, session) + await store.amdelete(["key1", "key2"]) + result = await store.amget(["key1", "key2"]) + assert result == [None, None] + finally: + drop_table(table_name, session) + + +def test_yield_keys(session: Session) -> None: + table_name = "lc_test_store_yield_keys" + try: + store = init_store(table_name, session) + assert set(store.yield_keys()) == {"key1", "key2"} + assert set(store.yield_keys(prefix="key")) == {"key1", "key2"} + assert set(store.yield_keys(prefix="lang")) == set() + finally: + drop_table(table_name, session) + + +async def test_ayield_keys(session: Session) -> None: + table_name = "lc_test_store_ayield_keys" + try: + store = await init_async_store(table_name, session) + assert {key async for key in store.ayield_keys()} == {"key1", "key2"} + assert {key async for key in store.ayield_keys(prefix="key")} == { + "key1", + "key2", + } + assert {key async for key in store.ayield_keys(prefix="lang")} == set() + finally: + drop_table(table_name, session)