# Text and ANN Search on Astra DB (powered by cassandra)

This notebook demonstrates how Astra DB can combine vector similarity search with term search to improve performance, and increase relevance on Generative AI use cases.

## Dependencies

In [49]:
pip install datasets cassandra-driver sentence-transformers transformers

Note: you may need to restart the kernel to use updated packages.


## Astra DB Setup

In [50]:
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

In [51]:
import os
from getpass import getpass

try:
    from google.colab import files
    IS_COLAB = True
except ModuleNotFoundError:
    IS_COLAB = False

In [52]:
# Your database's Secure Connect Bundle zip file is needed:
if IS_COLAB:
    print('Please upload your Secure Connect Bundle zipfile: ')
    uploaded = files.upload()
    if uploaded:
        astraBundleFileTitle = list(uploaded.keys())[0]
        ASTRA_DB_SECURE_BUNDLE_PATH = os.path.join(os.getcwd(), astraBundleFileTitle)
    else:
        raise ValueError(
            'Cannot proceed without Secure Connect Bundle. Please re-run the cell.'
        )
else:
    # you are running a local-jupyter notebook:
    ASTRA_DB_SECURE_BUNDLE_PATH = input("Please provide the full path to your Secure Connect Bundle zipfile: ")

ASTRA_DB_APPLICATION_TOKEN = getpass("Please provide your Database Token ('AstraCS:...' string): ")
ASTRA_DB_KEYSPACE = input("Please provide the Keyspace name for your Database: ")

In [56]:
# Don't mind the "Closing connection" error after "downgrading protocol..." messages,
# it is really just a warning: the connection will work smoothly.
cluster = Cluster(
    cloud={
        "secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH,
    },
    auth_provider=PlainTextAuthProvider(
        "token",
        ASTRA_DB_APPLICATION_TOKEN,
    ),
)

session = cluster.connect()
keyspace = ASTRA_DB_KEYSPACE

## Load Dataset

For this example we are going to be using a variation of the [yahoo_answers dataset](https://huggingface.co/datasets/yahoo_answers_topics/viewer/yahoo_answers_topics/train) from huggingface. For demonstration purposes, we included the label for topic (instead of its id), so that we can use term based search to filter our queries. This dataset has over 1,000,000 answers so it's great for demonstrating Astra DB's performance, scalability, and to leverage filtering capabilities with hybrid search.

In [58]:
from datasets import load_dataset
dataset = load_dataset("jdabello/yahoo_answers_topics", split="train")
dataset

Downloading readme: 100%|██████████| 616/616 [00:00<00:00, 7.62MB/s]
Downloading data: 100%|██████████| 241M/241M [04:20<00:00, 927kB/s] 
Downloading data: 100%|██████████| 270M/270M [05:01<00:00, 897kB/s] 
Downloading data files: 100%|██████████| 1/1 [09:21<00:00, 561.60s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 225.11it/s]
Generating train split: 100%|██████████| 1400000/1400000 [00:01<00:00, 964290.31 examples/s]


Dataset({
    features: ['id', 'topic', 'question_title', 'question_content', 'best_answer'],
    num_rows: 1400000
})

Lets inspect one record.

In [67]:
dataset[0]

{'id': 0,
 'topic': 'Computers & Internet',
 'question_title': "why doesn't an optical mouse work on a glass table?",
 'question_content': 'or even on some surfaces?',
 'best_answer': 'Optical mice use an LED and a camera to rapidly capture images of the surface beneath the mouse.  The infomation from the camera is analyzed by a DSP (Digital Signal Processor) and used to detect imperfections in the underlying surface and determine motion. Some materials, such as glass, mirrors or other very shiny, uniform surfaces interfere with the ability of the DSP to accurately analyze the surface beneath the mouse.  \\nSince glass is transparent and very uniform, the mouse is unable to pick up enough imperfections in the underlying surface to determine motion.  Mirrored surfaces are also a problem, since they constantly reflect back the same image, causing the DSP not to recognize motion properly. When the system is unable to see surface changes associated with movement, the mouse will not work pr

## LLM Setup

In [60]:
OPENAI_API_KEY = getpass("Please enter your OpenAI API Key: ")

In [61]:
import openai

openai.api_key = OPENAI_API_KEY

The yahoo answers dataset, has 2 separate columns for the Question Title, and the Question Content. For this excercise we are concatenating those into a single column that we will call Question. For demonstration purposes, we are going to get the embedding of the first question, and get the number of dimensions for that vector. This is needed to create the table.

In [68]:
embedding_model_name = "text-embedding-ada-002"

result = openai.Embedding.create(
    input=f"{dataset[0]['question_title']}. {'question_content'}",
    engine=embedding_model_name,
)
result

<OpenAIObject list at 0x2a1015df0> JSON: {
  "object": "list",
  "data": [
    {
      "object": "embedding",
      "index": 0,
      "embedding": [
        -0.01286814920604229,
        0.005121403839439154,
        -0.015341008082032204,
        -0.03235268592834473,
        -0.0010972069576382637,
        0.022859029471874237,
        -0.012410703115165234,
        -0.029939495027065277,
        -0.023004882037639618,
        -0.03224661201238632,
        0.016057010740041733,
        0.02049887552857399,
        -0.024344071745872498,
        -0.01151570025831461,
        -0.01695864275097847,
        0.002686665393412113,
        0.00276290625333786,
        0.020909912884235382,
        0.0140415970236063,
        -0.008406395092606544,
        -0.04341094195842743,
        0.003802103688940406,
        -0.017396198585629463,
        -0.01603049226105213,
        -0.007199798710644245,
        -0.02487444318830967,
        0.017900051549077034,
        -0.02903786301612854,
     

In [72]:
#Let's get the number of dimensions for our embedding
len(result.data[0].embedding)

1536

In [57]:
# Run this to drop the table and indexes before starting over
session.execute(f"DROP TABLE IF EXISTS {ASTRA_DB_KEYSPACE}.yahoo_answers")

<cassandra.cluster.ResultSet at 0x34d502750>

Now we are going to create our table. We will include the id of the answer, the topic, the Question (concatenated title, and content), the best answer, and the embedding of the question. That's what we will use for the ANN search. We now know that the embedding model we are using generates an embedding with 1536 dimensions.

In [89]:
mktable_cql = f"""CREATE TABLE {ASTRA_DB_KEYSPACE}.yahoo_answers (
answer_id int PRIMARY KEY,
topic text,
question text,
best_answer text,
question_embedding vector<float, 1536>
);
"""
session.execute(mktable_cql)

<cassandra.cluster.ResultSet at 0x34d5a9150>

In [90]:
session.execute(f"CREATE CUSTOM INDEX ON {ASTRA_DB_KEYSPACE}.yahoo_answers(question_embedding) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'")

<cassandra.cluster.ResultSet at 0x3957fe210>

In [92]:
session.execute(f"""
    CREATE CUSTOM INDEX ON {ASTRA_DB_KEYSPACE}.yahoo_answers(topic)
    USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'
    WITH OPTIONS = {{
    'index_analyzer': '{{
    "tokenizer" : {{"name" : "standard"}},
    "filters" : [{{"name" : "porterstem"}},{{"name" : "lowercase",	"args": {{}}}}]
    }}'}};""")

<cassandra.cluster.ResultSet at 0x34ca6b850>

We created the embedding for a single record. Now let's do it for 1M+ questions we'll have to process. Note that you are calling the LLM API ~700 times. This will take ~5 minutes.

In [149]:
embedding_params = []
embedding_model_name = "text-embedding-ada-002"

input=[]
for row in dataset:
    print("*", end=" ")
    input.append(f"{row['question_title']}. {row['question_content']}")
    if len(input) % 2000 == 0:  #sending batches of 2000 questions
        print("")
        result = openai.Embedding.create(
            input=input,
            engine=embedding_model_name,
        )
        for result_data in result.data:
            print(".", end=" ")
            embedding_params.append(result_data.embedding)
        input=[]
        print("")
        print(f"{len(embedding_params)} rows processed.")


*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*


KeyboardInterrupt: 

In [124]:
params_list = []
i=0
for row in dataset:
    params_list.append((row['id'],row['best_answer'],f"{row['question_title']}. {row['question_content']}",embedding_params[i], row['topic']))
    i=+1

In [126]:
from cassandra.concurrent import execute_concurrent_with_args
request = session.prepare(
                    f"""
                INSERT INTO {ASTRA_DB_KEYSPACE}.yahoo_answers
                (answer_id, best_answer, question, question_embedding, topic)
                VALUES (?, ?, ?, ?, ?)
                """
)
execute_concurrent_with_args(session, request, params_list)

[ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34cc4b050>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34ca31990>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34cc4a050>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34cab0790>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34ca30590>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34cc488d0>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34cc49450>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34ca33f10>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34ca32f90>),
 ExecutionResult(success=True, result_or_exc=<cassandra.cluster.ResultSet object at 0x34ca32210>)]

In [137]:
from cassandra.query import SimpleStatement
question = 'Why people get embarrased?'
#question = 'What is a good vehicle to purchase?'
#question = 'What is the best way to ship a package from USA to UK?'

result = openai.Embedding.create(
    input=question,
    engine=embedding_model_name,
)
embedding=result.data[0].embedding
print(embedding)
query = SimpleStatement(
    f"""
    SELECT question, best_answer
    FROM {ASTRA_DB_KEYSPACE}.yahoo_answers
    WHERE topic: 'computers'
    ORDER BY question_embedding ANN OF {embedding} LIMIT 5;
    """
    )

results = session.execute(query)
top_5_products = results._current_rows

for row in top_5_products:
  print(f"""{row.question}, {row.best_answer}\n""")

[0.007383933290839195, -0.01889094151556492, 0.042631033807992935, -0.013244403526186943, -0.018942803144454956, -0.0011053209891542792, 0.004012866411358118, -0.006943101529031992, -0.004019349347800016, -0.0232085008174181, -0.013419440016150475, 0.0028654069174081087, -0.00897870771586895, 0.01317309308797121, -0.013704684562981129, 0.04322745278477669, 0.03451453894376755, 0.011513490229845047, -0.006852342281490564, -0.023234430700540543, -0.0024731962475925684, 0.002531541744247079, -0.009309330955147743, 0.01233032625168562, -0.039130307734012604, 0.002862165682017803, 0.03736698254942894, -0.016492297872900963, 0.023791953921318054, -0.02158779464662075, 0.01570139266550541, -0.010560516268014908, -0.013834340497851372, -0.026488807052373886, -0.050099242478609085, -0.002811923623085022, -0.006385578773915768, -0.01331571489572525, 0.02764274924993515, -0.012745226733386517, -0.0033321701921522617, -0.005429362878203392, 0.00027936906553804874, -0.0024553686380386353, -0.027357