-
Notifications
You must be signed in to change notification settings - Fork 14k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Add Cassandra ByteStore (#22064)
- Loading branch information
Showing
4 changed files
with
350 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
188 changes: 188 additions & 0 deletions
188
libs/community/langchain_community/storage/cassandra.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from asyncio import InvalidStateError, Task | ||
from typing import ( | ||
TYPE_CHECKING, | ||
AsyncIterator, | ||
Iterator, | ||
List, | ||
Optional, | ||
Sequence, | ||
Tuple, | ||
) | ||
|
||
from langchain_core.stores import ByteStore | ||
|
||
from langchain_community.utilities.cassandra import SetupMode, aexecute_cql | ||
|
||
if TYPE_CHECKING: | ||
from cassandra.cluster import Session | ||
from cassandra.query import PreparedStatement | ||
|
||
CREATE_TABLE_CQL_TEMPLATE = """ | ||
CREATE TABLE IF NOT EXISTS {keyspace}.{table} | ||
(row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id)); | ||
""" | ||
SELECT_TABLE_CQL_TEMPLATE = ( | ||
"""SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;""" | ||
) | ||
SELECT_ALL_TABLE_CQL_TEMPLATE = """SELECT row_id, body_blob FROM {keyspace}.{table};""" | ||
INSERT_TABLE_CQL_TEMPLATE = ( | ||
"""INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);""" | ||
) | ||
DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;""" | ||
|
||
|
||
class CassandraByteStore(ByteStore): | ||
def __init__( | ||
self, | ||
table: str, | ||
*, | ||
session: Optional[Session] = None, | ||
keyspace: Optional[str] = None, | ||
setup_mode: SetupMode = SetupMode.SYNC, | ||
) -> None: | ||
if not session or not keyspace: | ||
try: | ||
from cassio.config import check_resolve_keyspace, check_resolve_session | ||
|
||
self.keyspace = keyspace or check_resolve_keyspace(keyspace) | ||
self.session = session or check_resolve_session() | ||
except (ImportError, ModuleNotFoundError): | ||
raise ImportError( | ||
"Could not import a recent cassio package." | ||
"Please install it with `pip install --upgrade cassio`." | ||
) | ||
else: | ||
self.keyspace = keyspace | ||
self.session = session | ||
self.table = table | ||
self.select_statement = None | ||
self.insert_statement = None | ||
self.delete_statement = None | ||
|
||
create_cql = CREATE_TABLE_CQL_TEMPLATE.format( | ||
keyspace=self.keyspace, | ||
table=self.table, | ||
) | ||
self.db_setup_task: Optional[Task[None]] = None | ||
if setup_mode == SetupMode.ASYNC: | ||
self.db_setup_task = asyncio.create_task( | ||
aexecute_cql(self.session, create_cql) | ||
) | ||
else: | ||
self.session.execute(create_cql) | ||
|
||
def ensure_db_setup(self) -> None: | ||
if self.db_setup_task: | ||
try: | ||
self.db_setup_task.result() | ||
except InvalidStateError: | ||
raise ValueError( | ||
"Asynchronous setup of the DB not finished. " | ||
"NB: AstraDB components sync methods shouldn't be called from the " | ||
"event loop. Consider using their async equivalents." | ||
) | ||
|
||
async def aensure_db_setup(self) -> None: | ||
if self.db_setup_task: | ||
await self.db_setup_task | ||
|
||
def get_select_statement(self) -> PreparedStatement: | ||
if not self.select_statement: | ||
self.select_statement = self.session.prepare( | ||
SELECT_TABLE_CQL_TEMPLATE.format( | ||
keyspace=self.keyspace, table=self.table | ||
) | ||
) | ||
return self.select_statement | ||
|
||
def get_insert_statement(self) -> PreparedStatement: | ||
if not self.insert_statement: | ||
self.insert_statement = self.session.prepare( | ||
INSERT_TABLE_CQL_TEMPLATE.format( | ||
keyspace=self.keyspace, table=self.table | ||
) | ||
) | ||
return self.insert_statement | ||
|
||
def get_delete_statement(self) -> PreparedStatement: | ||
if not self.delete_statement: | ||
self.delete_statement = self.session.prepare( | ||
DELETE_TABLE_CQL_TEMPLATE.format( | ||
keyspace=self.keyspace, table=self.table | ||
) | ||
) | ||
return self.delete_statement | ||
|
||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: | ||
from cassandra.query import ValueSequence | ||
|
||
self.ensure_db_setup() | ||
docs_dict = {} | ||
for row in self.session.execute( | ||
self.get_select_statement(), [ValueSequence(keys)] | ||
): | ||
docs_dict[row.row_id] = row.body_blob | ||
return [docs_dict.get(key) for key in keys] | ||
|
||
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]: | ||
from cassandra.query import ValueSequence | ||
|
||
await self.aensure_db_setup() | ||
docs_dict = {} | ||
for row in await aexecute_cql( | ||
self.session, self.get_select_statement(), parameters=[ValueSequence(keys)] | ||
): | ||
docs_dict[row.row_id] = row.body_blob | ||
return [docs_dict.get(key) for key in keys] | ||
|
||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: | ||
self.ensure_db_setup() | ||
insert_statement = self.get_insert_statement() | ||
for k, v in key_value_pairs: | ||
self.session.execute(insert_statement, (k, v)) | ||
|
||
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: | ||
await self.aensure_db_setup() | ||
insert_statement = self.get_insert_statement() | ||
for k, v in key_value_pairs: | ||
await aexecute_cql(self.session, insert_statement, parameters=(k, v)) | ||
|
||
def mdelete(self, keys: Sequence[str]) -> None: | ||
from cassandra.query import ValueSequence | ||
|
||
self.ensure_db_setup() | ||
self.session.execute(self.get_delete_statement(), [ValueSequence(keys)]) | ||
|
||
async def amdelete(self, keys: Sequence[str]) -> None: | ||
from cassandra.query import ValueSequence | ||
|
||
await self.aensure_db_setup() | ||
await aexecute_cql( | ||
self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)] | ||
) | ||
|
||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: | ||
self.ensure_db_setup() | ||
for row in self.session.execute( | ||
SELECT_ALL_TABLE_CQL_TEMPLATE.format( | ||
keyspace=self.keyspace, table=self.table | ||
) | ||
): | ||
key = row.row_id | ||
if not prefix or key.startswith(prefix): | ||
yield key | ||
|
||
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: | ||
await self.aensure_db_setup() | ||
for row in await aexecute_cql( | ||
self.session, | ||
SELECT_ALL_TABLE_CQL_TEMPLATE.format( | ||
keyspace=self.keyspace, table=self.table | ||
), | ||
): | ||
key = row.row_id | ||
if not prefix or key.startswith(prefix): | ||
yield key |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
155 changes: 155 additions & 0 deletions
155
libs/community/tests/integration_tests/storage/test_cassandra.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
"""Implement integration tests for Cassandra storage.""" | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import pytest | ||
|
||
from langchain_community.storage.cassandra import CassandraByteStore | ||
from langchain_community.utilities.cassandra import SetupMode | ||
|
||
if TYPE_CHECKING: | ||
from cassandra.cluster import Session | ||
|
||
KEYSPACE = "storage_test_keyspace" | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def session() -> Session: | ||
from cassandra.cluster import Cluster | ||
|
||
cluster = Cluster() | ||
session = cluster.connect() | ||
session.execute( | ||
( | ||
f"CREATE KEYSPACE IF NOT EXISTS {KEYSPACE} " | ||
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}" | ||
) | ||
) | ||
return session | ||
|
||
|
||
def init_store(table_name: str, session: Session) -> CassandraByteStore: | ||
store = CassandraByteStore(table=table_name, keyspace=KEYSPACE, session=session) | ||
store.mset([("key1", b"value1"), ("key2", b"value2")]) | ||
return store | ||
|
||
|
||
async def init_async_store(table_name: str, session: Session) -> CassandraByteStore: | ||
store = CassandraByteStore( | ||
table=table_name, keyspace=KEYSPACE, session=session, setup_mode=SetupMode.ASYNC | ||
) | ||
await store.amset([("key1", b"value1"), ("key2", b"value2")]) | ||
return store | ||
|
||
|
||
def drop_table(table_name: str, session: Session) -> None: | ||
session.execute(f"DROP TABLE {KEYSPACE}.{table_name}") | ||
|
||
|
||
async def test_mget(session: Session) -> None: | ||
"""Test CassandraByteStore mget method.""" | ||
table_name = "lc_test_store_mget" | ||
try: | ||
store = init_store(table_name, session) | ||
assert store.mget(["key1", "key2"]) == [b"value1", b"value2"] | ||
assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"] | ||
finally: | ||
drop_table(table_name, session) | ||
|
||
|
||
async def test_amget(session: Session) -> None: | ||
"""Test CassandraByteStore amget method.""" | ||
table_name = "lc_test_store_amget" | ||
try: | ||
store = await init_async_store(table_name, session) | ||
assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"] | ||
finally: | ||
drop_table(table_name, session) | ||
|
||
|
||
def test_mset(session: Session) -> None: | ||
"""Test that multiple keys can be set with CassandraByteStore.""" | ||
table_name = "lc_test_store_mset" | ||
try: | ||
init_store(table_name, session) | ||
result = session.execute( | ||
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset " | ||
"WHERE row_id = 'key1';" | ||
).one() | ||
assert result.body_blob == b"value1" | ||
result = session.execute( | ||
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset " | ||
"WHERE row_id = 'key2';" | ||
).one() | ||
assert result.body_blob == b"value2" | ||
finally: | ||
drop_table(table_name, session) | ||
|
||
|
||
async def test_amset(session: Session) -> None: | ||
"""Test that multiple keys can be set with CassandraByteStore.""" | ||
table_name = "lc_test_store_amset" | ||
try: | ||
await init_async_store(table_name, session) | ||
result = session.execute( | ||
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset " | ||
"WHERE row_id = 'key1';" | ||
).one() | ||
assert result.body_blob == b"value1" | ||
result = session.execute( | ||
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset " | ||
"WHERE row_id = 'key2';" | ||
).one() | ||
assert result.body_blob == b"value2" | ||
finally: | ||
drop_table(table_name, session) | ||
|
||
|
||
def test_mdelete(session: Session) -> None: | ||
"""Test that deletion works as expected.""" | ||
table_name = "lc_test_store_mdelete" | ||
try: | ||
store = init_store(table_name, session) | ||
store.mdelete(["key1", "key2"]) | ||
result = store.mget(["key1", "key2"]) | ||
assert result == [None, None] | ||
finally: | ||
drop_table(table_name, session) | ||
|
||
|
||
async def test_amdelete(session: Session) -> None: | ||
"""Test that deletion works as expected.""" | ||
table_name = "lc_test_store_amdelete" | ||
try: | ||
store = await init_async_store(table_name, session) | ||
await store.amdelete(["key1", "key2"]) | ||
result = await store.amget(["key1", "key2"]) | ||
assert result == [None, None] | ||
finally: | ||
drop_table(table_name, session) | ||
|
||
|
||
def test_yield_keys(session: Session) -> None: | ||
table_name = "lc_test_store_yield_keys" | ||
try: | ||
store = init_store(table_name, session) | ||
assert set(store.yield_keys()) == {"key1", "key2"} | ||
assert set(store.yield_keys(prefix="key")) == {"key1", "key2"} | ||
assert set(store.yield_keys(prefix="lang")) == set() | ||
finally: | ||
drop_table(table_name, session) | ||
|
||
|
||
async def test_ayield_keys(session: Session) -> None: | ||
table_name = "lc_test_store_ayield_keys" | ||
try: | ||
store = await init_async_store(table_name, session) | ||
assert {key async for key in store.ayield_keys()} == {"key1", "key2"} | ||
assert {key async for key in store.ayield_keys(prefix="key")} == { | ||
"key1", | ||
"key2", | ||
} | ||
assert {key async for key in store.ayield_keys(prefix="lang")} == set() | ||
finally: | ||
drop_table(table_name, session) |