# IN clauses and CassIO

_Stefano, 2023-12-06_

Let's look at what can be done to support IN-type filtering on metadata in CassIO.

## Setup

In [4]:
! pip install -q "cassio>=0.1.3"

In [5]:
import os
from getpass import getpass

import cassio

In [6]:
if "ASTRA_DB_DATABASE_ID" not in os.environ:
    os.environ["ASTRA_DB_DATABASE_ID"] = input("ASTRA_DB_DATABASE_ID = ")

if "ASTRA_DB_APPLICATION_TOKEN" not in os.environ:
    os.environ["ASTRA_DB_APPLICATION_TOKEN"] = getpass("ASTRA_DB_APPLICATION_TOKEN = ")

if "ASTRA_DB_KEYSPACE" not in os.environ:
    ks = input("(Optional) ASTRA_DB_KEYSPACE = ")
    if ks:
        os.environ["ASTRA_DB_KEYSPACE"] = ks

(Optional) ASTRA_DB_KEYSPACE =  


In [7]:
cassio.init(
    database_id=os.environ["ASTRA_DB_DATABASE_ID"],
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    keyspace=os.environ.get("ASTRA_DB_KEYSPACE"),
)

## Create a standard CassIO table

In [10]:
v_table = cassio.table.MetadataVectorCassandraTable(table="in_clause_test", vector_dimension=3)

In [34]:
v_table.put(row_id="01", metadata={"chunk_label": "1"}, body_blob="body 1", vector=[1, 0, 0])
v_table.put(row_id="01b", metadata={"chunk_label": "1"}, body_blob="body 1", vector=[1, 0.4, 0])
v_table.put(row_id="01c", metadata={"chunk_label": "1"}, body_blob="body 1", vector=[1, 0, 0])
v_table.put(row_id="02", metadata={"chunk_label": "2"}, body_blob="body 2", vector=[0, 1, 0.4])
v_table.put(row_id="02b", metadata={"chunk_label": "2"}, body_blob="body 2", vector=[0.8, 0.2, 0.3])
v_table.put(row_id="03", metadata={"chunk_label": "3"}, body_blob="body 3", vector=[1, 1, 1])

## CQL

### Metadata-filtering reads with CQL (baseline)

In [35]:
c_session = cassio.config.resolve_session()
c_keyspace = cassio.config.resolve_keyspace()

In [36]:
rows = list(c_session.execute(f"select * from {c_keyspace}.in_clause_test where metadata_s['chunk_label'] = '1' limit 3"))

print(rows)

[Row(row_id='01b', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.4000000059604645, 0.0]), Row(row_id='01c', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.0, 0.0]), Row(row_id='01', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.0, 0.0])]


### Metadata-filtering IN clause usage

We know this raises an error: CQL rejects `IN` on a map entry with a syntax error.

In [37]:
try:
    rows = list(c_session.execute(f"select * from {c_keyspace}.in_clause_test where metadata_s['chunk_label'] IN ('1', '2') limit 3"))
    print(rows)
except Exception as e:
    print(str(e))

<Error from server: code=2000 [Syntax error in CQL query] message="line 1:78 no viable alternative at input 'IN' (....in_clause_test where metadata_s['chunk_label'] [IN]...)">


### Workaround: use OR

The above can be reformulated by exploiting the capability of SAI indexing to handle arbitrary AND/OR conditions:

In [38]:
rows = list(c_session.execute(
    f"select * from {c_keyspace}.in_clause_test where metadata_s['chunk_label'] = '1' or metadata_s['chunk_label'] = '2' limit 3"
))

print(rows)

[Row(row_id='01b', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.4000000059604645, 0.0]), Row(row_id='01c', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.0, 0.0]), Row(row_id='02', attributes_blob=None, body_blob='body 2', metadata_s=OrderedMapSerializedKey([('chunk_label', '2')]), vector=[0.0, 1.0, 0.4000000059604645])]
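
With more than two values, the chain of OR conditions can be generated rather than typed by hand. The following is a minimal sketch, not part of CassIO: `build_or_query` is a hypothetical helper, and it assumes the `%s` placeholder style that the Python driver accepts for non-prepared statements, so that the values are passed as parameters instead of being pasted into the string. Keyspace and table names cannot be parameters, so they are still interpolated.

```python
def build_or_query(keyspace, table, md_key, values, limit):
    """Return (cql, params) for an OR-chained metadata filter.

    Hypothetical helper for illustration: one "metadata_s[key] = value"
    condition per requested value, joined with OR.
    """
    if not values:
        raise ValueError("at least one value is required")
    # one condition per value; key and value are both left as placeholders
    condition = " OR ".join(["metadata_s[%s] = %s"] * len(values))
    cql = f"SELECT * FROM {keyspace}.{table} WHERE {condition} LIMIT %s"
    params = []
    for value in values:
        params.extend([md_key, value])
    params.append(limit)
    return cql, tuple(params)


# "my_keyspace" is a placeholder name used only for this printout
cql, params = build_or_query("my_keyspace", "in_clause_test", "chunk_label", ["1", "2"], 3)
print(cql)
print(params)
```

In this notebook the result would presumably be run as `c_session.execute(cql, params)`, with `c_keyspace` in place of the placeholder keyspace name.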


### Possibly better

A condition such as the above puts the load on a single coordinator, which is generally best avoided when possible: it is preferable to spread the load across coordinators (I would say this still holds even when SAI is used to run the queries).

Since the conditions in this case are mutually exclusive, there is not even the problem of double counting, so I think a solution like the following is preferable.

In [39]:
from concurrent.futures import ThreadPoolExecutor

def label_to_results(label, limit=3):
    # run one single-label query (plain string interpolation: acceptable for a demo only)
    return list(c_session.execute(f"select * from {c_keyspace}.in_clause_test where metadata_s['chunk_label'] = '{label}' limit {limit}"))

# Not done here: using a prepared statement, fixing "limit" with functools.partial ...

with ThreadPoolExecutor(max_workers=10) as tpe:
    result_list = list(
        tpe.map(
            label_to_results,
            ['1', '2'],
        )
    )

# flatten
results = [row for rows in result_list for row in rows]

print(results)

[Row(row_id='01b', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.4000000059604645, 0.0]), Row(row_id='01c', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.0, 0.0]), Row(row_id='01', attributes_blob=None, body_blob='body 1', metadata_s=OrderedMapSerializedKey([('chunk_label', '1')]), vector=[1.0, 0.0, 0.0]), Row(row_id='02', attributes_blob=None, body_blob='body 2', metadata_s=OrderedMapSerializedKey([('chunk_label', '2')]), vector=[0.0, 1.0, 0.4000000059604645]), Row(row_id='02b', attributes_blob=None, body_blob='body 2', metadata_s=OrderedMapSerializedKey([('chunk_label', '2')]), vector=[0.800000011920929, 0.20000000298023224, 0.30000001192092896])]


Notes:

1. You will have to cut the final list to keep only the first `limit` items. In other words, be aware that there is some waste in retrieval: each subquery must use the same limit as the final one.
2. The arbitrariness of the result (due to the cut) is no different from the one you would get with a genuine `IN` clause, if it worked.
3. On the other hand, once the items in the "IN clause" grow beyond a handful, it becomes even more important to avoid overloading a single query coordinator, which more than makes up for the waste mentioned in point 1.
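
The cut described in note 1 can be folded into the fan-out itself. Below is a small sketch under stated assumptions: `fan_out_and_cut` is a hypothetical helper, and `fake_fetch` with its `FAKE_ROWS` data is a stand-in for a real per-label query function (in this notebook, that role would be played by `label_to_results`).

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import chain, islice


def fan_out_and_cut(fetch_one, labels, limit, max_workers=10):
    """Run fetch_one(label, limit) for each label concurrently,
    then keep only the first `limit` rows of the merged result."""
    with ThreadPoolExecutor(max_workers=max_workers) as tpe:
        # every subquery uses the same limit as the final result
        per_label = list(tpe.map(lambda label: fetch_one(label, limit), labels))
    # flatten in label order and cut
    return list(islice(chain.from_iterable(per_label), limit))


# Stand-in data and fetcher, for illustration only (no database involved)
FAKE_ROWS = {"1": ["01b", "01c", "01"], "2": ["02", "02b"]}


def fake_fetch(label, limit):
    return FAKE_ROWS.get(label, [])[:limit]


print(fan_out_and_cut(fake_fetch, ["1", "2"], limit=3))
# → ['01b', '01c', '01']
```

Note that with this simple cut the rows of the first label crowd out the others. That is one legitimate instance of the arbitrariness mentioned in note 2, but a different merge policy (e.g. round-robin across labels) may be preferable depending on the use case.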

## Within CassIO

Such "custom" queries as the ones with an "OR" are not _yet_ supported in CassIO (there are plans to extend the syntax for the metadata filtering control).

Likewise (though perhaps made irrelevant by the solutions presented in this notebook), there are plans to allow for a _user-provided_ table schema in CassIO. But this is not something coming in the next 1-2 weeks, to be clear.

In other words: if you want to avoid descending to the CQL level (to be avoided when possible), concurrency is your friend:

In [40]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def cassio_labeled_ann(label, query_vector, limit):
    return v_table.metric_ann_search(
        query_vector,
        n=limit,
        metric="cos",
        metadata={"chunk_label": label},
    )

with ThreadPoolExecutor(max_workers=10) as tpe:
    searcher = partial(cassio_labeled_ann, query_vector=[2, 1, 1], limit=3)
    result_list = list(
        tpe.map(
            searcher,
            ['1', '2'],
        )
    )

# flatten
results = [row for rows in result_list for row in rows]

print(results)

[{'metadata': {'chunk_label': '1'}, 'row_id': '01b', 'body_blob': 'body 1', 'vector': [1.0, 0.4000000059604645, 0.0], 'distance': 0.90971765268422}, {'metadata': {'chunk_label': '1'}, 'row_id': '01c', 'body_blob': 'body 1', 'vector': [1.0, 0.0, 0.0], 'distance': 0.8164965809277261}, {'metadata': {'chunk_label': '1'}, 'row_id': '01', 'body_blob': 'body 1', 'vector': [1.0, 0.0, 0.0], 'distance': 0.8164965809277261}, {'metadata': {'chunk_label': '2'}, 'row_id': '02b', 'body_blob': 'body 2', 'vector': [0.800000011920929, 0.20000000298023224, 0.30000001192092896], 'distance': 0.9770084215486352}, {'metadata': {'chunk_label': '2'}, 'row_id': '02', 'body_blob': 'body 2', 'vector': [0.0, 1.0, 0.4000000059604645], 'distance': 0.53066863167384}]


Remark: each subquery must be passed the same `limit` as the full query, and the merged results must then be cut to keep at most `limit` items. But of course, at this point they are not sorted!

So this is a "final" recipe:

In [43]:
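# sort and cut: with metric="cos" the "distance" values above behave as a
# cosine similarity (higher means closer), hence reverse=True
final_result = sorted(results, key=lambda res: res["distance"], reverse=True)[:3]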

print(final_result)

[{'metadata': {'chunk_label': '2'}, 'row_id': '02b', 'body_blob': 'body 2', 'vector': [0.800000011920929, 0.20000000298023224, 0.30000001192092896], 'distance': 0.9770084215486352}, {'metadata': {'chunk_label': '1'}, 'row_id': '01b', 'body_blob': 'body 1', 'vector': [1.0, 0.4000000059604645, 0.0], 'distance': 0.90971765268422}, {'metadata': {'chunk_label': '1'}, 'row_id': '01c', 'body_blob': 'body 1', 'vector': [1.0, 0.0, 0.0], 'distance': 0.8164965809277261}]
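
For convenience, the three steps (concurrent per-label search, flatten, sort and cut) can be wrapped in one function. This is only a sketch: `ann_search_in` is a hypothetical name, not a CassIO API, and it relies solely on the `metric_ann_search` call exactly as used in the cells above. It also assumes, as observed in the output above, that a higher `"distance"` value means a closer match for `metric="cos"`.

```python
from concurrent.futures import ThreadPoolExecutor


def ann_search_in(table, md_key, values, query_vector, limit, metric="cos", max_workers=10):
    """Emulate an ANN search filtered by `md_key IN values` (hypothetical helper).

    `table` is expected to expose metric_ann_search(vector, n=..., metric=..., metadata=...)
    returning dicts with a "distance" key, as in the cells above.
    """
    def search_one(value):
        # one single-value ANN search, with the same limit as the final result
        return table.metric_ann_search(
            query_vector,
            n=limit,
            metric=metric,
            metadata={md_key: value},
        )

    with ThreadPoolExecutor(max_workers=max_workers) as tpe:
        per_value = list(tpe.map(search_one, values))

    # flatten; distinct values of one key are mutually exclusive, so no de-duplication
    merged = [hit for hits in per_value for hit in hits]
    # sort (higher "distance" first, see the assumption above) and cut
    return sorted(merged, key=lambda hit: hit["distance"], reverse=True)[:limit]
```

With the table from this notebook, the call would be `ann_search_in(v_table, "chunk_label", ["1", "2"], [2, 1, 1], limit=3)`.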


Note: the above shows the standard way of querying a vector store, i.e. ANN searches.

Nothing prevents you from doing the same with a metadata-only query (expect arbitrariness in the results if cuts have to be made):

In [44]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def cassio_labeled_md_search(label, limit):
    return v_table.find_entries(
        n=limit,
        metadata={"chunk_label": label},
    )

with ThreadPoolExecutor(max_workers=10) as tpe:
    searcher = partial(cassio_labeled_md_search, limit=3)
    result_list = list(
        tpe.map(
            searcher,
            ['1', '2'],
        )
    )

# flatten
results = [row for rows in result_list for row in rows]

print(results)

[{'metadata': {'chunk_label': '1'}, 'row_id': '01b', 'body_blob': 'body 1', 'vector': [1.0, 0.4000000059604645, 0.0]}, {'metadata': {'chunk_label': '1'}, 'row_id': '01c', 'body_blob': 'body 1', 'vector': [1.0, 0.0, 0.0]}, {'metadata': {'chunk_label': '1'}, 'row_id': '01', 'body_blob': 'body 1', 'vector': [1.0, 0.0, 0.0]}, {'metadata': {'chunk_label': '2'}, 'row_id': '02', 'body_blob': 'body 2', 'vector': [0.0, 1.0, 0.4000000059604645]}, {'metadata': {'chunk_label': '2'}, 'row_id': '02b', 'body_blob': 'body 2', 'vector': [0.800000011920929, 0.20000000298023224, 0.30000001192092896]}]
