Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add llamaindex cassandra tool #439

Merged
merged 3 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions libs/e2e-tests/e2e_tests/llama_index/test_cassandra_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import uuid

import cassio

from llama_index.agent.openai import OpenAIAgent
from llama_index.llms.openai import OpenAI

from llama_index.tools.cassandra.base import CassandraDatabaseToolSpec
from llama_index.tools.cassandra.cassandra_database_wrapper import (
CassandraDatabase,
)


def test_tool_with_openai_tool(cassandra):
session = cassio.config.resolve_session()

session.execute(
"""
CREATE TABLE IF NOT EXISTS default_keyspace.tool_table_users (
user_id UUID PRIMARY KEY ,
user_name TEXT ,
password TEXT
);
"""
)
session.execute(
"""
CREATE INDEX user_name
ON default_keyspace.tool_table_users (user_name);
"""
)

user_id = uuid.uuid4()
session.execute(
f"""
INSERT INTO default_keyspace.tool_table_users (user_id, user_name)
VALUES ({user_id}, 'my_user');
"""
)
db = CassandraDatabase()

spec = CassandraDatabaseToolSpec(db=db)

tools = spec.to_tool_list()
for tool in tools:
print(tool.metadata.name)

llm = OpenAI(model="gpt-4o")
agent = OpenAIAgent.from_tools(tools, llm=llm, verbose=True)

response = agent.chat(
"What is the user_id of the user named 'my_user' in table default_keyspace.tool_table_users?"
)
print(response)
assert response is not None
assert str(user_id) in str(response)
1 change: 1 addition & 0 deletions libs/llamaindex/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ llama-index = "0.10.31"
llama-index-vector-stores-astra-db = "0.1.7"
llama-index-vector-stores-cassandra = "0.1.3"
llama-index-embeddings-langchain = "0.1.2"
llama-index-tools-cassandra = "0.1.1"
llama-parse = "0.4.1"
# optional integrations
## azure
Expand Down
Loading