<a href="https://colab.research.google.com/github/krishnannarayanaswamy/text2cql-datastax-astra-demo/blob/main/Text2CQL_DataStax_Astra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using LLMs to Generate CQL

Since LLMs handle many text-generation tasks well, we wanted to show how they can be used to generate CQL to query your Cassandra tables. This notebook provides a guide, derived from the [SQL-PaLM](https://arxiv.org/abs/2306.00739) paper, on how to automatically show the LLM your DB schema so that it can write queries against your data.

## Setup

#### Requirements

In [1]:
# Install requirements, if not already installed
!pip install openai cassandra-driver


#### Connect to Services

In [2]:
# Initialize the OpenAI Client
import os

from getpass import getpass
import openai

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")

client = openai.OpenAI()


OpenAI API Key: ··········


In [6]:
# Connect to a Cassandra Cluster and initialize the session
import re

from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from getpass import getpass
from google.colab import files

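# Read each setting from the environment and prompt only when it is not set.
# (Passing the prompt call as the default of os.environ.get would always run it,
# because the default argument is evaluated before the lookup.)
ASTRA_TOKEN = os.environ.get("ASTRA_DB_TOKEN") or getpass("Astra DB Token: ")

ASTRA_BUNDLE_PATH = (
    os.environ.get("ASTRA_DB_BUNDLE_PATH")
    or list(files.upload().keys())[0]
)

ASTRA_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE") or input("Astra DB Keyspace: ")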

cloud_config = {
    'secure_connect_bundle': ASTRA_BUNDLE_PATH
}
auth_provider = PlainTextAuthProvider("token", ASTRA_TOKEN)

def execute_statement(statement: str):
    # This is a simple wrapper around executing CQL statements in our
    # Cassandra cluster, and either raising an error or returning the results
    try:
        rows = session.execute(statement)
        return rows.all()
    except Exception:
        print(f"Query Failed: {statement}")
        raise


Astra DB Token: ··········


Saving secure-connect-multilingual.zip to secure-connect-multilingual (1).zip
Astra DB Keyspace: fintech


In [46]:
cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
session = cluster.connect(keyspace=ASTRA_KEYSPACE)

ERROR:cassandra.connection:Closing connection <AsyncoreConnection(135265819302864) 49293bf1-9bb8-4c65-a060-cf566066cc00-us-east1.db.astra.datastax.com:29042:b1317c54-97aa-4844-85b6-fa62d7a344ef> due to protocol error: Error from server: code=000a [Protocol error] message="Beta version of the protocol used (5/v5-beta), but USE_BETA flag is unset"


#### (Optional) Dummy DB Setup

Feel free to skip this section if you are instead adapting the notebook to fit your existing Cassandra database. Here, we will use the Python `cassandra-driver` package to connect to a DB and create some fake tables. This schema is pulled from [this DataStax example](https://www.datastax.com/learn/data-modeling-by-example/digital-library-data-model) on creating a data model for a digital music library.

In [31]:
# Create all necessary tables
create_tables_cql = """CREATE TABLE IF NOT EXISTS customerprofile (
    client_id INT,
    surname TEXT,
    credit_score INT,
    location TEXT,
    gender TEXT,
    age INT,
    balance DECIMAL,
    has_credit_card BOOLEAN,
    estimated_salary DECIMAL,
    satisfaction_score INT,
    card_type TEXT,
    point_earned INT,
    PRIMARY KEY (client_id)
);"""

create_index_cql = """CREATE CUSTOM INDEX IF NOT EXISTS location_idx ON customerprofile (location) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';

CREATE CUSTOM INDEX IF NOT EXISTS has_credit_card_idx ON customerprofile (has_credit_card) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';

CREATE CUSTOM INDEX IF NOT EXISTS credit_score_idx ON customerprofile (credit_score) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';
"""

In [None]:
# Split the text above into individual statements that the driver can execute
for statement in create_tables_cql.split(";"):
    if len(statement.strip()):
        execute_statement(statement.strip())

In [32]:
# Split the text above into individual statements that the driver can execute
for statement in create_index_cql.split(";"):
    if len(statement.strip()):
        execute_statement(statement.strip())
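
Splitting on every `;` works for the DDL above, but it would cut a statement in two if a string literal (for example a table comment) contained a semicolon. Below is a minimal sketch of a quote-aware splitter. `split_cql_statements` is a hypothetical helper, not part of `cassandra-driver`, and it assumes the script uses only single-quoted literals with `''` as the escape (no `$$` strings and no CQL comments).

```python
def split_cql_statements(script: str) -> list[str]:
    """Split a CQL script on ';' while ignoring semicolons inside '...' literals."""
    statements, current, in_string = [], [], False
    for ch in script:
        if ch == "'":
            # An escaped quote ('') toggles twice, so we stay inside the literal
            in_string = not in_string
        if ch == ";" and not in_string:
            stmt = "".join(current).strip()
            if stmt:
                statements.append(stmt)
            current = []
        else:
            current.append(ch)
    # Keep a trailing statement that has no terminating ';'
    tail = "".join(current).strip()
    if tail:
        statements.append(tail)
    return statements


demo_script = "CREATE TABLE t (id INT PRIMARY KEY);\n\nALTER TABLE t WITH comment = 'a; b';\n"
for stmt in split_cql_statements(demo_script):
    print(stmt)
```

Each returned string could then be passed to `execute_statement` in place of the `split(";")` loop.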

In [None]:
import csv
from cassandra.query import SimpleStatement

with open('clients-dataset.csv', 'r', newline='') as file:
    reader = csv.reader(file)
    headers = next(reader)
    query = SimpleStatement(
        f"INSERT INTO {ASTRA_KEYSPACE}.customerprofile (client_id, surname, credit_score, location, gender, age, "
        "balance, has_credit_card, estimated_salary, satisfaction_score, card_type, point_earned"
        ") VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
    )

    for row in reader:
        # Create a dictionary for the row using headers as keys
        row_dict = dict(zip(headers, row))

        # Insert values into Astra database
        session.execute(query, (
            int(row_dict['CustomerId']), row_dict['Surname'], int(row_dict['CreditScore']),
            row_dict['Geography'], row_dict['Gender'], int(row_dict['Age']), float(row_dict['Balance']),
            # bool() of any non-empty string (even '0') is True, so parse the flag
            # as an integer first; this assumes the CSV stores it as 0 or 1
            bool(int(row_dict['HasCrCard'])),
            float(row_dict['EstimatedSalary']), int(row_dict['Satisfaction Score']),
            row_dict['Card Type'], int(row_dict['Point Earned']),
        ))

        print(f"Inserted client {row_dict['CustomerId']} into Astra DB")


## (Optional) Give the LLM Additional Context with the Built-in `comment` Property

LLM response quality depends heavily on the context the model is given: the more concise, descriptive information it has access to, the better. We can augment the DB schema we pass to the model by using the built-in `comment` property of CQL tables.

NOTE: You can also include these comments at table creation by using the `WITH <table property 1> AND <table property 2> ... AND comment = '<comment>'` syntax
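
As a rough illustration of the note above, the comment can be attached when the table is created. The sketch below only builds the CQL string and does not touch the database. `create_table_with_comment` and the `demo_customers` table are hypothetical names used for illustration; single quotes in the comment are escaped by doubling them, as CQL string literals require.

```python
def create_table_with_comment(table: str, columns: dict[str, str], primary_key: str, comment: str) -> str:
    """Build a CREATE TABLE statement that sets the table comment at creation time."""
    column_defs = ",\n".join(f"    {name} {cql_type}" for name, cql_type in columns.items())
    escaped = comment.replace("'", "''")  # CQL escapes ' inside a literal as ''
    return (
        f"CREATE TABLE IF NOT EXISTS {table} (\n"
        f"{column_defs},\n"
        f"    PRIMARY KEY ({primary_key})\n"
        f") WITH comment = '{escaped}'"
    )


statement = create_table_with_comment(
    table="demo_customers",  # hypothetical table, not the one used in this notebook
    columns={"client_id": "INT", "surname": "TEXT"},
    primary_key="client_id",
    comment="Customer's profile",
)
print(statement)
```

The resulting string has no trailing `;`, so it could be handed to `execute_statement` directly.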

In [21]:
add_comments_cql = f"""
ALTER TABLE customerprofile WITH comment = 'Customers profile with their credit scores and balance';
"""

In [22]:
# Execute each line of the text above that contains a ';'-terminated statement
for line in add_comments_cql.split("\n"):
    sc_loc = line.find(";")
    if sc_loc > -1:
        execute_statement(line[:sc_loc])

## Run Queries from User Questions

#### Generating & Executing CQL

Now, we can ask ChatGPT to provide us with some queries that answer our questions! The prompt template we use is taken from [SQL-PaLM](https://arxiv.org/abs/2306.00739) and adapted to fit the CQL use case. To use it, though, we first need to retrieve the schema from our DB.

In [25]:
TEXT2CQL_PROMPT = """Convert the question to CQL (Cassandra Query Language) that can retrieve an appropriate answer, or answer saying that the data model does not support answering such a question in a performant way:

[Schema : values (type)]
{schema}

[Partition Keys]
{partition_keys}

[Clustering Keys]
{clustering_keys}

[Q]
{question}

[CQL]
"""


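def generate_schema_partition_clustering_keys(keyspace: str = ASTRA_KEYSPACE) -> tuple[str, str, str]:
    """Generates a TEXT2CQL_PROMPT compatible schema, partition keys, and clustering keys for a keyspace"""
    # Get the name and comment of the demo table in our keyspace
    # (the table_name filter restricts the prompt to 'customerprofile')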
    table_names = execute_statement(
        f"SELECT table_name, comment FROM system_schema.tables WHERE keyspace_name = '{keyspace}' AND table_name = 'customerprofile'"
    )
    tn_str = ", ".join(["'" + tn.table_name + "'" for tn in table_names])

    # Now get all the column names corresponding to those tables
    columns = execute_statement(
        f"SELECT * FROM system_schema.columns WHERE table_name IN ({tn_str}) AND keyspace_name = '{keyspace}' ALLOW FILTERING"
    )

    # Now, we construct our prompt template formatted schema, partition_keys, and clustering keys
    # from the table and column objects returned from the DB
    schema = " | ".join([
        f"{table.table_name} '{table.comment}' : " + " , ".join([
            f"{col.column_name} ({col.type})"
            for col in columns
            if col.table_name == table.table_name
        ])
        for table in table_names
    ])
    partition_keys = " | ".join([
        f"{table.table_name} : " + " , ".join([
            col.column_name for col in columns
            if col.table_name == table.table_name
            and col.kind == "partition_key"
        ])
        for table in table_names
    ])
    clustering_keys = " | ".join([
        f"{table.table_name} : " + " , ".join([
            f"{col.column_name} ({col.clustering_order})" for col in columns
            if col.table_name == table.table_name
            and col.kind == "clustering"
        ])
        for table in table_names
    ])
    return schema, partition_keys, clustering_keys


def execute_query_from_question(question: str, debug_cql: bool = True, debug_prompt: bool = False, return_cql: bool = False):
    """Generates and executes CQL from a user question based on LLM output"""
    # Get all of the variables necessary to fill out the prompt
    schema, partition_keys, clustering_keys = generate_schema_partition_clustering_keys()
    prompt = TEXT2CQL_PROMPT.format(
        schema=schema,
        partition_keys=partition_keys,
        clustering_keys=clustering_keys,
        question=question,
    )

    if debug_prompt:
        print(f"Prompting model with:\n{prompt}")

    # Get generated CQL from the LLM (in this case gpt-3.5-turbo)
    completion = client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": prompt,
        }],
        model="gpt-3.5-turbo",
    ).choices[0].message.content

    if debug_cql:
        print(f"Question: {question}\nGenerated Query: {completion}\n")

    # Keep only the text before the first ';' (if present); this also drops
    # anything the model appended after the statement
    if completion.find(";") > -1:
        completion = completion[:completion.find(";")]

    results = execute_statement(completion)

    if return_cql:
        return (results, completion)
    else:
        return results
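
To see the shape of the strings these helpers produce without a live cluster, the same formatting can be reproduced on hand-written rows. This is only a sketch: `Table` and `Column` are hypothetical stand-ins for the rows returned from `system_schema.tables` and `system_schema.columns`, and the `orders` table is invented here so that a clustering key shows up.

```python
from collections import namedtuple

# Stand-ins for the driver's result rows (only the fields the formatter reads)
Table = namedtuple("Table", ["table_name", "comment"])
Column = namedtuple("Column", ["table_name", "column_name", "type", "kind", "clustering_order"])

tables = [Table("orders", "Orders placed by each client")]
columns = [
    Column("orders", "client_id", "int", "partition_key", "none"),
    Column("orders", "order_ts", "timestamp", "clustering", "desc"),
    Column("orders", "amount", "decimal", "regular", "none"),
]


def columns_for(table_name, kind=None):
    """Columns of one table, optionally filtered by kind."""
    return [c for c in columns if c.table_name == table_name and (kind is None or c.kind == kind)]


schema = " | ".join(
    f"{t.table_name} '{t.comment}' : "
    + " , ".join(f"{c.column_name} ({c.type})" for c in columns_for(t.table_name))
    for t in tables
)
partition_keys = " | ".join(
    f"{t.table_name} : " + " , ".join(c.column_name for c in columns_for(t.table_name, "partition_key"))
    for t in tables
)
clustering_keys = " | ".join(
    f"{t.table_name} : "
    + " , ".join(f"{c.column_name} ({c.clustering_order})" for c in columns_for(t.table_name, "clustering"))
    for t in tables
)

print(schema)
print(partition_keys)
print(clustering_keys)
```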

In [53]:
# Show full prompting trace
execute_query_from_question("List 3 Male customers in Thailand?", debug_prompt=True)

Prompting model with:
Convert the question to CQL (Cassandra Query Language) that can retrieve an appropriate answer, or answer saying that the data model does not support answering such a question in a performant way:

[Schema : values (type)]
customerprofile 'Customers profile with their credit scores and balance' : age (int) , balance (decimal) , card_type (text) , client_id (int) , credit_score (int) , estimated_salary (decimal) , gender (text) , has_credit_card (boolean) , location (text) , point_earned (int) , satisfaction_score (int) , surname (text)

[Partition Keys]
customerprofile : client_id

[Clustering Keys]
customerprofile : 

[Q]
List 3 Male customers in Thailand?

[CQL]

Question: List 3 Male customers in Thailand?
Generated Query: SELECT * FROM customerprofile WHERE gender = 'Male' AND location = 'Thailand' LIMIT 3; 

OR

The data model does not support answering such a question in a performant way.



[Row(client_id=15741643, age=35, balance=Decimal('122917.69'), card_type='DIAMOND', credit_score=777, estimated_salary=Decimal('76169.68'), gender='Male', has_credit_card=True, location='Thailand', point_earned=624, satisfaction_score=2, surname='Chiang'),
 Row(client_id=15622993, age=28, balance=Decimal('124695.72'), card_type='DIAMOND', credit_score=709, estimated_salary=Decimal('145251.35'), gender='Male', has_credit_card=True, location='Thailand', point_earned=665, satisfaction_score=3, surname='Boyd'),
 Row(client_id=15785899, age=33, balance=Decimal('151607.56'), card_type='SILVER', credit_score=789, estimated_salary=Decimal('4389.4'), gender='Male', has_credit_card=True, location='Thailand', point_earned=513, satisfaction_score=4, surname="Ch'en")]
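
The completion above contains both a query and the fallback sentence, and the notebook copes by keeping everything before the first `;`. A slightly more defensive sketch is shown below, assuming we only ever want to run a single read-only `SELECT`. `extract_select` is a hypothetical helper, and like the original trimming it would truncate a query whose string literal contains a semicolon.

```python
import re
from typing import Optional


def extract_select(completion: str) -> Optional[str]:
    """Return the first SELECT statement in an LLM completion, or None if there is none."""
    # Drop Markdown code-fence markers such as ```sql that chat models often add
    text = re.sub(r"```[a-zA-Z]*", "", completion)
    # Keep only the text before the first ';'
    first = text.split(";", 1)[0].strip()
    # Refuse anything that does not start with SELECT (refusals, DDL, writes)
    if not re.match(r"(?i)select\b", first):
        return None
    return first


completion = (
    "SELECT * FROM customerprofile WHERE gender = 'Male' AND location = 'Thailand' LIMIT 3; \n\n"
    "OR\n\n"
    "The data model does not support answering such a question in a performant way."
)
print(extract_select(completion))
```

When the helper returns `None`, the caller can report that no runnable query was produced instead of sending free text to the cluster.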

In [39]:
# Show full prompting trace
execute_query_from_question("List 3 customers in Cambodia who have credit card?", debug_prompt=True)

Prompting model with:
Convert the question to CQL (Cassandra Query Language) that can retrieve an appropriate answer, or answer saying that the data model does not support answering such a question in a performant way:

[Schema : values (type)]
customerprofile 'Customers profile with their credit scores and balance' : age (int) , balance (decimal) , card_type (text) , client_id (int) , credit_score (int) , estimated_salary (decimal) , gender (text) , has_credit_card (boolean) , location (text) , point_earned (int) , satisfaction_score (int) , surname (text)

[Partition Keys]
customerprofile : client_id

[Clustering Keys]
customerprofile : 

[Q]
List 3 customers in Cambodia who have credit card?

[CQL]

Question: List 3 customers in Cambodia who have credit card?
Generated Query: SELECT * FROM customerprofile WHERE location = 'Cambodia' AND has_credit_card = true LIMIT 3;



[Row(client_id=15740147, age=44, balance=Decimal('0.0'), card_type='SILVER', credit_score=725, estimated_salary=Decimal('93777.61'), gender='Female', has_credit_card=True, location='Cambodia', point_earned=696, satisfaction_score=5, surname='Cremonesi'),
 Row(client_id=15625716, age=33, balance=Decimal('113913.53'), card_type='PLATINUM', credit_score=637, estimated_salary=Decimal('65316.5'), gender='Female', has_credit_card=True, location='Cambodia', point_earned=705, satisfaction_score=1, surname='Genovesi'),
 Row(client_id=15605263, age=33, balance=Decimal('140931.57'), card_type='PLATINUM', credit_score=552, estimated_salary=Decimal('10921.5'), gender='Male', has_credit_card=True, location='Cambodia', point_earned=330, satisfaction_score=5, surname='Chin')]

#### End to End Question Answering

Now, let's wrap up by showing how we can make a subsequent LLM call to answer the user's question with natural language. This completes a full "RAG" style pipeline!

In [40]:
ANSWER_PROMPT = """Query:
```
{cql}
```

Output:
```
{results_repr}
```
===

Given the above results from querying the DB, answer the following user question:

{question}
"""


def answer_question(question: str, debug_cql: bool = False, debug_prompt: bool = False) -> str:
    """Conducts a full RAG pipeline where the LLM retrieves relevant information
    and references it to answer the question in natural language.
    """
    # Get necessary fields to fill out prompt
    query_results, cql = execute_query_from_question(
        question=question,
        debug_cql=debug_cql,
        debug_prompt=debug_prompt,
        return_cql=True,
    )
    prompt = ANSWER_PROMPT.format(
        question=question,
        results_repr=str(query_results),
        cql=cql,
    )

    if debug_prompt:
        print(f"Prompting model with:\n{prompt}")

    # Return the generated answer from the LLM
    return client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": prompt,
        }],
        model="gpt-3.5-turbo",
    ).choices[0].message.content
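
`str(query_results)` places the raw `Row(...)` repr in the prompt, and that text grows with every row returned. One option is to render a capped, more compact view instead. The sketch below assumes the rows are named tuples, as the `Row(...)` output in this notebook suggests; `results_to_text` and its `max_rows` parameter are hypothetical, and the sample rows are made up.

```python
from collections import namedtuple
from decimal import Decimal


def results_to_text(rows, max_rows: int = 20) -> str:
    """Render query rows as one 'col=value' line per row, capped at max_rows."""
    lines = [
        ", ".join(f"{name}={value}" for name, value in row._asdict().items())
        for row in rows[:max_rows]
    ]
    if len(rows) > max_rows:
        lines.append(f"... ({len(rows) - max_rows} more rows omitted)")
    return "\n".join(lines)


# Made-up rows standing in for what execute_statement returns
Row = namedtuple("Row", ["client_id", "surname", "balance"])
sample_rows = [
    Row(1, "Chin", Decimal("140931.57")),
    Row(2, "Boyd", Decimal("0.0")),
    Row(3, "Chiang", Decimal("122917.69")),
]
print(results_to_text(sample_rows, max_rows=2))
```

The returned string could be passed as `results_repr` when formatting `ANSWER_PROMPT`.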


In [41]:
# Show full prompting trace
print(
    answer_question("List 3 customers in Cambodia who have credit card?", debug_prompt=True)
)

Prompting model with:
Convert the question to CQL (Cassandra Query Language) that can retrieve an appropriate answer, or answer saying that the data model does not support answering such a question in a performant way:

[Schema : values (type)]
customerprofile 'Customers profile with their credit scores and balance' : age (int) , balance (decimal) , card_type (text) , client_id (int) , credit_score (int) , estimated_salary (decimal) , gender (text) , has_credit_card (boolean) , location (text) , point_earned (int) , satisfaction_score (int) , surname (text)

[Partition Keys]
customerprofile : client_id

[Clustering Keys]
customerprofile : 

[Q]
List 3 customers in Cambodia who have credit card?

[CQL]

Prompting model with:
Query:
```
SELECT * FROM customerprofile WHERE location = 'Cambodia' AND has_credit_card = true LIMIT 3
```

Output:
```
[Row(client_id=15740147, age=44, balance=Decimal('0.0'), card_type='SILVER', credit_score=725, estimated_salary=Decimal('93777.61'), gender='Fema

Awesome! Our model is answering questions based on just the data in our dummy DB, and is able to construct queries for retrieving that data in a fully automated way.