In [22]:
import os
from dotenv import load_dotenv
import hashlib

# Read the local `.env` file so the credentials looked up via os.environ later are available;
# the return value is deliberately discarded.
_ = load_dotenv('.env')

In [9]:
import cassio

In [10]:
# Initialize cassio with the Astra DB credentials taken from the environment.
# A missing variable raises KeyError right here rather than failing later.
cassio.init(
    token=os.environ['ASTRA_DB_APPLICATION_TOKEN'],
    database_id=os.environ['ASTRA_DB_ID'],
)

In [14]:
# The statements in this notebook are 'bare CQL', so we need the session and the keyspace
# name that cassio resolved during initialization.
session = cassio.config.resolve_session()
keyspace = cassio.config.resolve_keyspace()

# Toward Hybrid search

Problem: we have several text snippets which will be vector-searched.

However, as the examples below will show, vector search alone is not enough. What we explore here are approaches to hybrid search that keep scaling well at the data volumes Astra DB is meant for.

In [91]:
# Sample customer-support messages: the corpus that every search below runs against.
# The text is stored verbatim (and hashed into the row key), so it is kept exactly as written.
snippets = [
    "I would like to buy gift cards. Where can I get discounts?",
    "The support operator is using foul language.",
    "I cannot open the support chat.",
    "I see no messages in the support chat.",
    "Are special offers available?",
    "The support chat on the website is lagging.",
    "I cannot speak with the support operator!",
    "I want to inquire about a specific product line.",
    "I have tried multiple times to make a payment but it does not get processed.",
    "I am having trouble opening my shopping cart!",
    "Speaking to a technicial is impossible, WTF?",
]

We are intentionally leaving out any field such as "metadata", in order to focus exclusively on the text. Adding metadata would require an effort (such as manual or AI-assisted labeling) that scales with the number of rows; we want to avoid that by shifting all the load to the query side.

In [16]:
# Must match the length of the vectors stored in the `embedding` column:
# 1536 is the output size of OpenAI's text-embedding-ada-002, the model used below.
vector_dimension = 1536

In [17]:
# IF NOT EXISTS keeps this cell re-runnable (e.g. after a kernel restart): without it,
# the second execution fails because the table and the index are already there.
CREATE_CQL = f"CREATE TABLE IF NOT EXISTS {keyspace}.snippets (snippet_id TEXT PRIMARY KEY, snippet TEXT, embedding VECTOR<FLOAT,{vector_dimension}>);"
session.execute(CREATE_CQL)

# Vector (ANN) index on the embedding column.
CREATE_V_IDX = f"CREATE CUSTOM INDEX IF NOT EXISTS snippets_embedding_idx ON {keyspace}.snippets (embedding) USING 'StorageAttachedIndex';"
session.execute(CREATE_V_IDX)

<cassandra.cluster.ResultSet at 0x7f3788bcdfc0>

## Get embeddings and insert rows

In [20]:
import openai

embedding_model_name = "text-embedding-ada-002"


def get_embeddings(texts):
    """Return one embedding vector per input text, in the order the API returns them.

    Args:
        texts: the texts to embed (the notebook always passes a list of strings).

    Returns:
        A list of embedding vectors, one per item in the API response.
        An empty input yields an empty list without contacting the API.
    """
    if not texts:
        # Nothing to embed: skip the network round-trip altogether
        # (the embeddings endpoint is not expected to accept an empty input).
        return []
    result = openai.Embedding.create(
        input=texts,
        engine=embedding_model_name,
    )
    return [res.embedding for res in result.data]

In [25]:
def snippet_id(sn):
    """Derive a deterministic row key from a snippet: the hex MD5 digest of its encoded text."""
    hasher = hashlib.md5()
    hasher.update(sn.encode())
    return hasher.hexdigest()


print(snippet_id("Test snippet."))

50f0866734db8ec79171ddc6b13988d9


In [92]:
# TODO if ever needed, add batching to this.
#
# Rows are keyed by the MD5 of the snippet text: re-running this cell with unchanged snippets
# writes the same keys again, whereas an edited snippet gets a new key and its previous
# version is left in the table.

embeddings = get_embeddings(snippets)
# zip() below would silently drop rows on a length mismatch, so check the invariant first.
assert len(embeddings) == len(snippets), "expected exactly one embedding per snippet"

INSERT_ROW = session.prepare(f"INSERT INTO {keyspace}.snippets (snippet_id, snippet, embedding) VALUES (?, ?, ?);")

for snippet, embedding in zip(snippets, embeddings):
    session.execute(INSERT_ROW, (
        snippet_id(snippet),
        snippet,
        embedding,
    ))

## Simple retrieval

In [29]:
# Sample user query for the vector-search demo below.
query = "I cannot even use the frigging website!"

In [37]:
# Prepared once and reused: three bind markers, i.e. the query vector (for the
# similarity column), the query vector again (for the ANN ordering) and the row limit.
SIMPLE_ANN = session.prepare(
    "SELECT snippet, similarity_cosine(embedding, ?) as similarity "
    f"FROM {keyspace}.snippets ORDER BY embedding ANN OF ? LIMIT ?"
)


def simple_ann(query, top_k=3):
    """Run a pure vector (ANN) search for `query`.

    Args:
        query: the text to search for; it is embedded with get_embeddings().
        top_k: value bound to the statement's LIMIT.

    Returns:
        A list of (snippet, similarity) tuples, in the order the database returns them.
    """
    query_vector = get_embeddings([query])[0]
    # The same vector is bound twice, matching the two '?' markers that expect it.
    rows = session.execute(SIMPLE_ANN, (query_vector, query_vector, top_k))
    hits = []
    for row in rows:
        hits.append((row.snippet, row.similarity))
    return hits

In [40]:
def show(results):
    """Print (snippet, score) pairs, one per line, as: [rank] score "snippet".

    Ranks start at 1 and scores are printed with five decimal places.
    """
    for rank, (text, score) in enumerate(results, start=1):
        print(f'[{rank}] {score:.5f} "{text}"')

In [41]:
# Five nearest snippets for the sample query, each with its cosine similarity.
show(simple_ann(query, 5))

[1] 0.92787 "I am having trouble opening my shopping cart!"
[2] 0.92661 "I cannot speak with the support operator!"
[3] 0.91543 "I cannot open the support chat."
[4] 0.91177 "The support chat on the website is lagging."
[5] 0.89659 "The support operator is using foul language."


In [42]:
# A second website-related query, phrased differently, to compare rankings and scores.
query2 = "It seems that the website is broken"
show(simple_ann(query2, 5))

[1] 0.92036 "The support chat on the website is lagging."
[2] 0.90841 "I am having trouble opening my shopping cart!"
[3] 0.90525 "I cannot open the support chat."
[4] 0.90367 "I see no messages in the support chat."
[5] 0.89755 "I cannot speak with the support operator!"


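#### Lesson: beware of setting a threshold on just-ANN and calling it a day!

For both queries, the top five similarities sit in a narrow band (roughly 0.90 to 0.93), whether or not the snippet is actually about the website, so no single cutoff cleanly separates relevant from irrelevant results.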

## Supplemental indexing for Hybrid

General idea: add a tokenizing, stemming index on the `snippet` column, and then run hybrid queries of some sort.

Let's try to use the "stemming + untouched query" case from the previous part of this journey:

In [45]:
# Don't mind the "{{" and "}}", it's just to escape the F-string syntax here.
# IF NOT EXISTS keeps this cell re-runnable: without it, a second execution fails
# because the index already exists.
CREATE_S_IDX = f'''CREATE CUSTOM INDEX IF NOT EXISTS snippets_snippet_idx ON {keyspace}.snippets (snippet) USING 'StorageAttachedIndex'
  WITH OPTIONS = {{
    'index_analyzer': '{{
      "tokenizer": {{
        "name": "standard"
      }},
      "filters": [
        {{
          "name": "lowercase"
        }},
        {{
          "name": "porterstem"
        }}
      ]
    }}',
    'query_analyzer': 'keyword'
  }};'''

session.execute(CREATE_S_IDX)

<cassandra.cluster.ResultSet at 0x7f37803d4a60>

#### A test query to see what this index finds

In [95]:
# We're not bothering with preparing statements here: the WHERE clause has a variable
# number of terms (preparing them is left as an exercise).
# Hence, note the '%s' in place of the '?'.
# The doubled braces survive the f-string as a literal '{where_clause}' slot, which is
# filled in later with str.format().

KEYWORD_QUERY_TEMPLATE = f"SELECT snippet FROM {keyspace}.snippets{{where_clause}} LIMIT %s ALLOW FILTERING"

def create_where_parts(keywords):
    """Build the WHERE clause and its bound values for a keyword query.

    Keywords are de-duplicated and sorted; each one becomes a "snippet : %s"
    condition, and the conditions are AND-ed together.

    Args:
        keywords: any iterable of keyword strings (may be empty, may be a one-shot iterator).

    Returns:
        A (where_clause, values) tuple. `where_clause` starts with ' WHERE ' and
        `values` lists the keywords in placeholder order; with no keywords the
        result is ('', []).
    """
    # Materialize once: iterating `keywords` twice would leave a generator exhausted
    # on the second pass, producing placeholders with no matching values.
    unique_keywords = sorted(set(keywords))
    if not unique_keywords:
        return ('', [])
    conditions = ' AND '.join("snippet : %s" for _ in unique_keywords)
    return (' WHERE ' + conditions, unique_keywords)

def find_by_keywords(keywords, n=3):
    """Return up to `n` snippets matching all of the given keywords.

    Args:
        keywords: iterable of keyword strings; an empty one means no WHERE clause at all.
        n: value bound to the query's LIMIT.

    Returns:
        A list of (snippet, 1.0) tuples. The constant score only keeps the output
        shape compatible with the (snippet, similarity) pairs of the vector search.
    """
    wc, wc_vals = create_where_parts(keywords)
    # Only the WHERE clause is formatted into the text; the limit is bound as the
    # trailing '%s' value.
    keyword_query = KEYWORD_QUERY_TEMPLATE.format(where_clause=wc)
    vals = tuple(wc_vals + [n])
    return [
        # let's pass a number to keep the output shape
        (row.snippet, 1.0)
        for row in session.execute(keyword_query, vals)
    ]

Note that "speak" matches "speaking" and "Speaking".

But also remember that this index is (purposefully) configured not to process the query, so you should not expect results when passing keywords such as `"Speak", "speaking", "having trouble"`.

In [94]:
# 'speak' is passed exactly as given: the index is configured with a 'keyword' query analyzer.
show(find_by_keywords(['speak'], 10))

[1] 1.00000 "The support operator is using foul language when speaking."
[2] 1.00000 "I cannot speak with the support operator!"
[3] 1.00000 "Speaking to a technicial is impossible, WTF?"
[4] 1.00000 "I cannot speak with support operators!"


In [96]:
# None of these is expected to print rows: the index lowercases and stems the stored text,
# but its 'keyword' query analyzer leaves the search term untouched, so a capitalized,
# unstemmed or multi-word term is not expected to match.
show(find_by_keywords(['Speak'], 10))
show(find_by_keywords(['speaking'], 10))
show(find_by_keywords(['having trouble'], 10))

## _To be continued ..._