# Demo 6 - Multi-vectors

In [None]:
!pip install --quiet "astrapy>=1.2.0" "python-dotenv>=1.0.0"

In [None]:
import getpass
import os

from dotenv import load_dotenv

from astrapy import DataAPIClient
from astrapy.constants import VectorMetric

## Setup DB

In [None]:
# Load settings from a local .env file, if there is one (python-dotenv).
load_dotenv()

# Required settings: prompt only for what is still missing from the environment.
# The token is read with getpass so that it is not echoed while typing.
if os.environ.get("ASTRA_DB_APPLICATION_TOKEN") is None:
    os.environ["ASTRA_DB_APPLICATION_TOKEN"] = getpass.getpass("Please input your Astra DB Token:")
if os.environ.get("ASTRA_DB_API_ENDPOINT") is None:
    os.environ["ASTRA_DB_API_ENDPOINT"] = input("Please input your Astra DB API Endpoint:")

# Optional setting: a blank answer leaves ASTRA_DB_KEYSPACE unset.
if os.environ.get("ASTRA_DB_KEYSPACE") is None:
    _namespace = input("(Optional) Input your Astra DB namespace if desired, or leave blank:")
    if _namespace != "":
        os.environ["ASTRA_DB_KEYSPACE"] = _namespace

ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"]
# Stays None when no keyspace was provided above.
ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE")

In [None]:
# Open a handle to the database behind the given API endpoint, authenticated with the token.
# ASTRA_DB_KEYSPACE is None when no keyspace was configured.
# NOTE(review): presumably the client then falls back to the database's default keyspace - confirm.
db = DataAPIClient(ASTRA_DB_APPLICATION_TOKEN).get_database(ASTRA_DB_API_ENDPOINT, namespace=ASTRA_DB_KEYSPACE)

## Raw data

Each entry has _two_ vectors, each individually normalized to unit length, so that they are suited to cosine similarity (which, for unit vectors, coincides with the dot product).

_(Note: there are probably better representations for a color than a **unit-norm** vector, but bear with me for the sake of the example.)_

In [None]:
def vector_norm(v):
    """Return the Euclidean (L2) norm of the vector ``v``.

    ``v`` is any iterable of numbers; an empty one has norm zero.
    """
    squared_total = sum(component * component for component in v)
    return squared_total ** 0.5

def normalize_vector(v):
    """Return a new list with the components of ``v`` rescaled to unit Euclidean norm.

    ``v`` is traversed twice (once for the norm, once for the rescaling), so it
    must be a re-iterable sequence. An all-zero vector raises ZeroDivisionError.
    """
    length = vector_norm(v)
    return [component / length for component in v]

In [None]:
# Each row: (species name, main color as an R-G-B point, web orientation as a 2D direction).
# Both vectors are normalized separately, so each part has unit norm on its own.
raw_critters = [
    {
        "name": species,
        "main_color": normalize_vector(rgb),
        "web_orientation": normalize_vector(orientation),
    }
    for species, rgb, orientation in [
        ("Argiope bruennichi", [255, 215, 68], [1, 0]),  # vertical web
        ("Tetragnatha extensa", [84, 255, 115], [0.5, 0.5]),
        ("Dysdera crocata", [114, 30, 40], [0.2, 1]),
        ("Eresus cinnabarinus", [230, 84, 100], [0.5, 0.3]),
    ]
]

## Collation of vectors

In [None]:
# Each document stores a single vector: 3 color components followed by 2 web-orientation components.
full_dimension = 3 + 2

# The dot-product metric makes the score of a concatenated vector equal to the sum of the
# dot products of its two parts, which is what the combined search relies on.
spiders_multivector_collection = db.create_collection(
    "spiders_multivector",
    dimension=full_dimension,
    metric=VectorMetric.DOT_PRODUCT,
    # NOTE(review): presumably skips the pre-existence check so a re-run does not fail - confirm.
    check_exists=False,
)
# just in case this demo is re-run
spiders_multivector_collection.delete_all()

In [None]:
def full_vector(raw_critter):
    """Concatenate the color vector and the web-orientation vector of one entry.

    The color components come first, the web-orientation components last.
    """
    color_part = raw_critter["main_color"]
    web_part = raw_critter["web_orientation"]
    return color_part + web_part

# Sanity check on one entry: the collated vector has 3 + 2 = 5 components.
print(full_vector(raw_critters[2]))

In [None]:
# Store each critter together with its collated vector under the reserved "$vector" field.
# This replaces the `vectors=` keyword, which is deprecated in later astrapy 1.x releases
# and no longer exists in 2.x. The original dicts in raw_critters are left untouched.
spiders_multivector_collection.insert_many(
    [
        {**raw_critter, "$vector": full_vector(raw_critter)}
        for raw_critter in raw_critters
    ]
)

## Combined similarity search

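i.e. maximizing `alpha * S_dot(q_a, v_a) + (1-alpha) * S_dot(q_b, v_b)` through a single **Dot**-product search: the dot product between the weighted, concatenated query vector and a stored concatenated vector is exactly this weighted sum.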

In [None]:
def full_query_vector(q_a, q_b, alpha):
    """Build the concatenated query vector for a weighted two-part search.

    Each part is first normalized to unit norm; the first part is then scaled
    by ``alpha`` and the second by ``1 - alpha``. The result is the scaled
    first part followed by the scaled second part.
    """
    weighted_a = [alpha * component for component in normalize_vector(q_a)]
    weighted_b = [(1 - alpha) * component for component in normalize_vector(q_b)]
    return weighted_a + weighted_b

In [None]:
# Unnormalized query parts; full_query_vector normalizes each one before weighting.
query_color = [100, 20, 60]  # that would be ~q_a
query_web = [0.1, 0.4]  # that would be ~q_b

# The same query at two weightings: alpha scales the color part, (1 - alpha) the web part.
print(full_query_vector(query_color, query_web, 0.15))
print(full_query_vector(query_color, query_web, 0.85))

In [None]:
# Each part is unit-norm before weighting, so the color part (first 3 components) has
# norm alpha and the web part (last 2 components) has norm 1 - alpha.
qv = full_query_vector(query_color, query_web, 0.4)
print(f"vector_norm(qv):     {vector_norm(qv)}")
print(f"vector_norm(qv[:3]): {vector_norm(qv[:3])}")  # = alpha
print(f"vector_norm(qv[3:]): {vector_norm(qv[3:])}")  # = 1-alpha

In [None]:
def multivector_search(q_a, q_b, alpha, n=3):
    """Run a combined similarity search over the two-part vectors.

    ``q_a`` and ``q_b`` are the (unnormalized) query parts, ``alpha`` is the
    weight given to the first part (the second gets ``1 - alpha``) and ``n``
    is the maximum number of documents to return. Returns whatever
    ``Collection.find`` returns (an iterable of matching documents, with the
    stored ``$vector`` left out by the projection).
    """
    full_qv = full_query_vector(q_a, q_b, alpha)
    # Express the vector search through `sort`: the `vector=` shorthand is deprecated
    # in later astrapy 1.x releases and is not available in 2.x.
    return spiders_multivector_collection.find(
        sort={"$vector": full_qv},
        limit=n,
        projection={"$vector": False},
    )

In [None]:
# alpha = 1.0 gives the web part of the query a weight of 0, so only the color drives the ranking.
print("By color only:")
for cr_i, cr_doc in enumerate(multivector_search(query_color, query_web, 1.0)):
    print(f"  [{cr_i + 1}] '{cr_doc['name']}'")

In [None]:
# alpha = 0.0 gives the color part of the query a weight of 0, so only the web orientation drives the ranking.
print("By web only:")
for cr_i, cr_doc in enumerate(multivector_search(query_color, query_web, 0.0)):
    print(f"  [{cr_i + 1}] '{cr_doc['name']}'")

In [None]:
def list_initials(alpha):
    """Return the first letters of the matched names, in result order, joined by dashes.

    The search uses the notebook-level ``query_color`` and ``query_web`` with
    the given ``alpha`` weighting.
    """
    hits = multivector_search(query_color, query_web, alpha=alpha)
    initials = [hit["name"][0] for hit in hits]
    return "-".join(initials)

# Sweep alpha from 0.0 (web orientation only) to 1.0 (color only) in steps of 0.1
# to see how the ranking changes with the weighting.
for alpha in [i / 10 for i in range(11)]:
    print(f"Alpha {alpha:.2f} ==> results {list_initials(alpha)}")

In [None]:
### To delete only the inserted data (keeping the collection), run:
# spiders_multivector_collection.delete_all()

### To remove the collection itself, run the line below.
### NOTE: it is not commented out, so executing this cell drops the collection.
spiders_multivector_collection.drop()

## The End