Skip to content

Commit

Permalink
community[minor]: Add Cassandra ByteStore (#22064)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed May 23, 2024
1 parent fea6b99 commit 74947ec
Show file tree
Hide file tree
Showing 4 changed files with 350 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.cassandra import wrapped_response_future
from langchain_community.utilities.cassandra import aexecute_cql

_NOT_SET = object()

Expand Down Expand Up @@ -118,11 +118,7 @@ def lazy_load(self) -> Iterator[Document]:
)

async def alazy_load(self) -> AsyncIterator[Document]:
for row in await wrapped_response_future(
self.session.execute_async,
self.query,
**self.query_kwargs,
):
for row in await aexecute_cql(self.session, self.query, **self.query_kwargs):
metadata = self.metadata.copy()
metadata.update(self.metadata_mapper(row))
yield Document(
Expand Down
188 changes: 188 additions & 0 deletions libs/community/langchain_community/storage/cassandra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

import asyncio
from asyncio import InvalidStateError, Task
from typing import (
TYPE_CHECKING,
AsyncIterator,
Iterator,
List,
Optional,
Sequence,
Tuple,
)

from langchain_core.stores import ByteStore

from langchain_community.utilities.cassandra import SetupMode, aexecute_cql

if TYPE_CHECKING:
from cassandra.cluster import Session
from cassandra.query import PreparedStatement

CREATE_TABLE_CQL_TEMPLATE = """
CREATE TABLE IF NOT EXISTS {keyspace}.{table}
(row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));
"""
SELECT_TABLE_CQL_TEMPLATE = (
"""SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;"""
)
SELECT_ALL_TABLE_CQL_TEMPLATE = """SELECT row_id, body_blob FROM {keyspace}.{table};"""
INSERT_TABLE_CQL_TEMPLATE = (
"""INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);"""
)
DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;"""


class CassandraByteStore(ByteStore):
def __init__(
self,
table: str,
*,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
if not session or not keyspace:
try:
from cassio.config import check_resolve_keyspace, check_resolve_session

self.keyspace = keyspace or check_resolve_keyspace(keyspace)
self.session = session or check_resolve_session()
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent cassio package."
"Please install it with `pip install --upgrade cassio`."
)
else:
self.keyspace = keyspace
self.session = session
self.table = table
self.select_statement = None
self.insert_statement = None
self.delete_statement = None

create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace,
table=self.table,
)
self.db_setup_task: Optional[Task[None]] = None
if setup_mode == SetupMode.ASYNC:
self.db_setup_task = asyncio.create_task(
aexecute_cql(self.session, create_cql)
)
else:
self.session.execute(create_cql)

def ensure_db_setup(self) -> None:
if self.db_setup_task:
try:
self.db_setup_task.result()
except InvalidStateError:
raise ValueError(
"Asynchronous setup of the DB not finished. "
"NB: AstraDB components sync methods shouldn't be called from the "
"event loop. Consider using their async equivalents."
)

async def aensure_db_setup(self) -> None:
if self.db_setup_task:
await self.db_setup_task

def get_select_statement(self) -> PreparedStatement:
if not self.select_statement:
self.select_statement = self.session.prepare(
SELECT_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
)
return self.select_statement

def get_insert_statement(self) -> PreparedStatement:
if not self.insert_statement:
self.insert_statement = self.session.prepare(
INSERT_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
)
return self.insert_statement

def get_delete_statement(self) -> PreparedStatement:
if not self.delete_statement:
self.delete_statement = self.session.prepare(
DELETE_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
)
return self.delete_statement

def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
from cassandra.query import ValueSequence

self.ensure_db_setup()
docs_dict = {}
for row in self.session.execute(
self.get_select_statement(), [ValueSequence(keys)]
):
docs_dict[row.row_id] = row.body_blob
return [docs_dict.get(key) for key in keys]

async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
from cassandra.query import ValueSequence

await self.aensure_db_setup()
docs_dict = {}
for row in await aexecute_cql(
self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
):
docs_dict[row.row_id] = row.body_blob
return [docs_dict.get(key) for key in keys]

def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
self.ensure_db_setup()
insert_statement = self.get_insert_statement()
for k, v in key_value_pairs:
self.session.execute(insert_statement, (k, v))

async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
await self.aensure_db_setup()
insert_statement = self.get_insert_statement()
for k, v in key_value_pairs:
await aexecute_cql(self.session, insert_statement, parameters=(k, v))

def mdelete(self, keys: Sequence[str]) -> None:
from cassandra.query import ValueSequence

self.ensure_db_setup()
self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])

async def amdelete(self, keys: Sequence[str]) -> None:
from cassandra.query import ValueSequence

await self.aensure_db_setup()
await aexecute_cql(
self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
)

def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
self.ensure_db_setup()
for row in self.session.execute(
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
):
key = row.row_id
if not prefix or key.startswith(prefix):
yield key

async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
await self.aensure_db_setup()
for row in await aexecute_cql(
self.session,
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
),
):
key = row.row_id
if not prefix or key.startswith(prefix):
yield key
6 changes: 5 additions & 1 deletion libs/community/langchain_community/utilities/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
from cassandra.cluster import ResponseFuture
from cassandra.cluster import ResponseFuture, Session


async def wrapped_response_future(
Expand Down Expand Up @@ -35,6 +35,10 @@ def error_handler(exc: BaseException) -> None:
return await asyncio_future


async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any:
return await wrapped_response_future(session.execute_async, query, **kwargs)


class SetupMode(Enum):
SYNC = 1
ASYNC = 2
Expand Down
155 changes: 155 additions & 0 deletions libs/community/tests/integration_tests/storage/test_cassandra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Implement integration tests for Cassandra storage."""
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from langchain_community.storage.cassandra import CassandraByteStore
from langchain_community.utilities.cassandra import SetupMode

if TYPE_CHECKING:
from cassandra.cluster import Session

KEYSPACE = "storage_test_keyspace"


@pytest.fixture(scope="session")
def session() -> Session:
from cassandra.cluster import Cluster

cluster = Cluster()
session = cluster.connect()
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {KEYSPACE} "
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
)
)
return session


def init_store(table_name: str, session: Session) -> CassandraByteStore:
store = CassandraByteStore(table=table_name, keyspace=KEYSPACE, session=session)
store.mset([("key1", b"value1"), ("key2", b"value2")])
return store


async def init_async_store(table_name: str, session: Session) -> CassandraByteStore:
store = CassandraByteStore(
table=table_name, keyspace=KEYSPACE, session=session, setup_mode=SetupMode.ASYNC
)
await store.amset([("key1", b"value1"), ("key2", b"value2")])
return store


def drop_table(table_name: str, session: Session) -> None:
session.execute(f"DROP TABLE {KEYSPACE}.{table_name}")


async def test_mget(session: Session) -> None:
"""Test CassandraByteStore mget method."""
table_name = "lc_test_store_mget"
try:
store = init_store(table_name, session)
assert store.mget(["key1", "key2"]) == [b"value1", b"value2"]
assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"]
finally:
drop_table(table_name, session)


async def test_amget(session: Session) -> None:
"""Test CassandraByteStore amget method."""
table_name = "lc_test_store_amget"
try:
store = await init_async_store(table_name, session)
assert await store.amget(["key1", "key2"]) == [b"value1", b"value2"]
finally:
drop_table(table_name, session)


def test_mset(session: Session) -> None:
"""Test that multiple keys can be set with CassandraByteStore."""
table_name = "lc_test_store_mset"
try:
init_store(table_name, session)
result = session.execute(
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset "
"WHERE row_id = 'key1';"
).one()
assert result.body_blob == b"value1"
result = session.execute(
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_mset "
"WHERE row_id = 'key2';"
).one()
assert result.body_blob == b"value2"
finally:
drop_table(table_name, session)


async def test_amset(session: Session) -> None:
"""Test that multiple keys can be set with CassandraByteStore."""
table_name = "lc_test_store_amset"
try:
await init_async_store(table_name, session)
result = session.execute(
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset "
"WHERE row_id = 'key1';"
).one()
assert result.body_blob == b"value1"
result = session.execute(
"SELECT row_id, body_blob FROM storage_test_keyspace.lc_test_store_amset "
"WHERE row_id = 'key2';"
).one()
assert result.body_blob == b"value2"
finally:
drop_table(table_name, session)


def test_mdelete(session: Session) -> None:
"""Test that deletion works as expected."""
table_name = "lc_test_store_mdelete"
try:
store = init_store(table_name, session)
store.mdelete(["key1", "key2"])
result = store.mget(["key1", "key2"])
assert result == [None, None]
finally:
drop_table(table_name, session)


async def test_amdelete(session: Session) -> None:
"""Test that deletion works as expected."""
table_name = "lc_test_store_amdelete"
try:
store = await init_async_store(table_name, session)
await store.amdelete(["key1", "key2"])
result = await store.amget(["key1", "key2"])
assert result == [None, None]
finally:
drop_table(table_name, session)


def test_yield_keys(session: Session) -> None:
table_name = "lc_test_store_yield_keys"
try:
store = init_store(table_name, session)
assert set(store.yield_keys()) == {"key1", "key2"}
assert set(store.yield_keys(prefix="key")) == {"key1", "key2"}
assert set(store.yield_keys(prefix="lang")) == set()
finally:
drop_table(table_name, session)


async def test_ayield_keys(session: Session) -> None:
table_name = "lc_test_store_ayield_keys"
try:
store = await init_async_store(table_name, session)
assert {key async for key in store.ayield_keys()} == {"key1", "key2"}
assert {key async for key in store.ayield_keys(prefix="key")} == {
"key1",
"key2",
}
assert {key async for key in store.ayield_keys(prefix="lang")} == set()
finally:
drop_table(table_name, session)

0 comments on commit 74947ec

Please sign in to comment.