# Demo 5 - Sort your critters

In [None]:
!pip install --quiet "astrapy>=1.0.0" "python-dotenv>=1.0.0"

In [None]:
import getpass
import os

from dotenv import load_dotenv

from astrapy import DataAPIClient
from astrapy.constants import VectorMetric

## Setup DB

In [None]:
# Pick up connection settings from a local .env file, when there is one.
load_dotenv()

# Whatever is still missing gets requested interactively; the token is read
# through getpass rather than a plain input prompt.
if os.environ.get("ASTRA_DB_APPLICATION_TOKEN") is None:
    os.environ["ASTRA_DB_APPLICATION_TOKEN"] = getpass.getpass("Please input your Astra DB Token:")

if os.environ.get("ASTRA_DB_API_ENDPOINT") is None:
    os.environ["ASTRA_DB_API_ENDPOINT"] = input("Please input your Astra DB API Endpoint:")

# The keyspace is optional: a blank answer leaves the variable unset.
if os.environ.get("ASTRA_DB_KEYSPACE") is None:
    _namespace = input("(Optional) Input your Astra DB namespace if desired, or leave blank:")
    if _namespace != "":
        os.environ["ASTRA_DB_KEYSPACE"] = _namespace

# Settings used by the rest of the notebook (the keyspace may be None).
ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE")
ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"]
ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]

In [None]:
# Build the client from the token, then open the database at the configured
# endpoint. ASTRA_DB_KEYSPACE is None when no keyspace was provided.
_client = DataAPIClient(ASTRA_DB_APPLICATION_TOKEN)
db = _client.get_database_by_api_endpoint(
    ASTRA_DB_API_ENDPOINT,
    namespace=ASTRA_DB_KEYSPACE,
)

## Raw data

It's the same as the previous demo:

In [None]:
# Demo dataset: one dict per spider species, with its scientific name, its
# taxonomic family and four numeric traits under "specs" (units are given on
# the first entry and apply to all of them).
# Correction: Phoneutria fera is filed under Ctenidae (wandering spiders),
# not Ctenizidae (a family of trapdoor spiders).
raw_critters = [
    {
        "name": "Eratigena atrica",
        "family": "Agelenidae",
        "specs": {
            "speed": 0.8,   # m/s (max running speed)
            "size": 3,      # max body length, cm
            "threat": 5.5,  # 0=harmless, 5=sore skin, 10=lethal
            "eyesight": 5,  # cm
        },
    },
    {
        "name": "Salticus scenicus",
        "family": "Salticidae",
        "specs": {
            "speed": 0.3,
            "size": 0.4,
            "threat": 0,
            "eyesight": 35,
        },
    },
    {
        "name": "Holocnemus pluchei",
        "family": "Pholcidae",
        "specs": {
            "speed": 0.05,
            "size": 0.8,
            "threat": 0,
            "eyesight": 10,
        },
    },
    {
        "name": "Hogna radiata",
        "family": "Lycosidae",
        "specs": {
            "speed": 0.65,
            "size": 2,
            "threat": 4,
            "eyesight": 20,
        },
    },
    {
        "name": "Atrax robustus",
        "family": "Atracidae",
        "specs": {
            "speed": 0.40,
            "size": 5,
            "threat": 9,
            "eyesight": 15,
        },
    },
    {
        "name": "Argiope bruennichi",
        "family": "Araneidae",
        "specs": {
            "speed": 0.10,
            "size": 2.5,
            "threat": 6,
            "eyesight": 12,
        },
    },
    {
        "name": "Loxosceles rufescens",
        "family": "Sicariidae",
        "specs": {
            "speed": 0.45,
            "size": 0.8,
            "threat": 7.5,
            "eyesight": 8,
        },
    },
    {
        "name": "Scytodes thoracica",
        "family": "Scytodidae",
        "specs": {
            "speed": 0.15,
            "size": 0.6,
            "threat": 0,
            "eyesight": 10,
        },
    },
    {
        "name": "Phoneutria fera",
        "family": "Ctenidae",
        "specs": {
            "speed": 0.75,
            "size": 4.8,
            "threat": 10,
            "eyesight": 35,
        },
    },
    {
        "name": "Uloborus plumipes",
        "family": "Uloboridae",
        "specs": {
            "speed": 0.25,
            "size": 1.4,
            "threat": 0,
            "eyesight": 18,
        },
    },
]

In [None]:
# Traits, in the order used for the components of each critter's vector.
traits = ["speed", "size", "threat", "eyesight"]

# For every trait, record the (smallest, largest) value found in the dataset.
ranges = {}
for trait in traits:
    trait_values = [critter["specs"][trait] for critter in raw_critters]
    min_val, max_val = min(trait_values), max(trait_values)
    ranges[trait] = (min_val, max_val)
    print(f"{trait}: {min_val} to {max_val}")

### Write entries with their raw "vectors"

(Note: "vectors" in the broadest sense here: each one is simply the list of a critter's four raw trait values, not an embedding.)

In [None]:
# Collection for this demo: 4-dimensional vectors (one component per trait)
# compared through the dot-product metric.
# NOTE(review): check_exists=False presumably skips the client-side
# "already exists" check so that this cell can be re-run - confirm in the astrapy docs.
spiders_dot_collection = db.create_collection(
    "spiders_dot",
    dimension=4,
    metric=VectorMetric.DOT_PRODUCT,
    check_exists=False,
)
# just in case this demo is re-run: start again from an empty collection
spiders_dot_collection.delete_all()

In [None]:
def make_list(specs):
    """Turn a critter's `specs` dict into a plain list of trait values.

    The values come out in the order given by the global `traits` list;
    a `specs` dict lacking one of those traits raises `KeyError`.
    """
    values = []
    for trait_name in traits:
        values.append(specs[trait_name])
    return values

# quick look at the vector built for one of the entries
print(make_list(raw_critters[3]["specs"]))

In [None]:
# One vector per critter, listed in the same order as the documents themselves.
critter_vectors = [make_list(critter["specs"]) for critter in raw_critters]
spiders_dot_collection.insert_many(raw_critters, vectors=critter_vectors)

## Sort by a trait

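Courtesy of a contrived use of the dot-product similarity: a query vector that is nonzero on a single trait makes the vector search rank the entries by that trait.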

In [None]:
# The index is not designed to deal with 'similarities' below 0. Rescaling
# the query by 1 / (2 * largest absolute trait value) makes sure of that:
# the dot product with any of the raw vectors then stays within [-0.5, +0.5].
_largest_abs_value = max(abs(val) for rng in ranges.values() for val in rng)
_factor = 1.0 / (2.0 * _largest_abs_value)

def sorting_vector(trait, ascending=False):
    """Build the query vector that ranks entries by a single trait.

    The result has one component per entry of the global `traits` list:
    zero everywhere except at the position of `trait`, which holds
    `+_factor` (or `-_factor` when `ascending` is true).

    Raises `AssertionError` if `trait` is not one of `traits`.
    """
    assert trait in traits
    component = -_factor if ascending else +_factor
    return [component if seq_trait == trait else 0 for seq_trait in traits]

In [None]:
# Query vector for "speed" with the default ascending=False, then with ascending=True:
# only the sign of the single nonzero component differs.
print(sorting_vector("speed"))
print(sorting_vector("speed", ascending=True))

In [None]:
def sorted_results(trait, ascending=False, n=3):
    """Run a vector search on `spiders_dot_collection` to sort by one trait.

    The query vector comes from `sorting_vector(trait, ascending)`. The value
    returned is whatever `find` gives back for a search limited to `n`
    results, with the `$vector` field projected out and the similarity
    included in each result.
    """
    find_options = {
        "vector": sorting_vector(trait, ascending),
        "limit": n,
        "projection": {"$vector": False},
        "include_similarity": True,
    }
    return spiders_dot_collection.find(**find_options)

In [None]:
# Default call (ascending=False, n=3): three results ranked on "speed",
# each printed with its 1-based position and its similarity score.
print("By speed:")
for cr_i, cr_doc in enumerate(sorted_results("speed")):
    print(f"  [{cr_i + 1}, sim={cr_doc['$similarity']:.3f}] '{cr_doc['name']}' ({cr_doc['family']}), {cr_doc['specs']}")

In [None]:
# Five results ranked on "size", this time with ascending=True.
print("By size, ascending, top 5:")
for cr_i, cr_doc in enumerate(sorted_results("size", ascending=True, n=5)):
    print(f"  [{cr_i + 1}, sim={cr_doc['$similarity']:.3f}] '{cr_doc['name']}' ({cr_doc['family']}), {cr_doc['specs']}")

## Sort by any combination

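- A trick: instead of normalizing the stored vectors, "move" the per-trait rescaling into the query vector
- Limitation: this assumes each trait is rescaled from `[0 : M]` to `[0 : 1]` (with `M` the largest value of that trait), i.e. the lower end of the range stays fixed at zero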

In [None]:
def multiscore_sorting_vector(trait_weights):
    """Build the query vector for sorting by a weighted combination of traits.

    Args:
        trait_weights: dict mapping trait names (taken from the global
            `traits` list) to non-negative weights. Traits that are left
            out get a weight of zero.

    Returns:
        A list with one component per entry of `traits`, in that order, each
        equal to `weight / (2 * sum_of_weights * max_of_trait)`, where the
        maximum is read from the global `ranges`. With trait values in
        `[0 : max]`, the dot product with a raw vector is then half of the
        weighted average of the rescaled traits.

    Raises:
        AssertionError: if a weight is negative, if no weight is positive
            (this includes an empty `trait_weights`), or if a key is not
            one of `traits`.
    """
    assert all(w >= 0 for w in trait_weights.values()), "weights must be non-negative"
    sum_w = sum(trait_weights.values())
    # Without this check, empty or all-zero weights would only surface as an
    # unexplained ZeroDivisionError in the division below.
    assert sum_w > 0, "at least one weight must be positive"
    # Same validation as in sorting_vector: a misspelled trait would otherwise
    # be silently left out of the vector while still counting towards sum_w.
    unknown_traits = set(trait_weights) - set(traits)
    assert not unknown_traits, f"unknown traits: {unknown_traits}"
    return [
        trait_weights.get(trait, 0) / (2.0 * sum_w * ranges[trait][1])
        for trait in traits
    ]

In [None]:
# Component order of the vectors printed below.
print(traits)

# A single trait first, then speed weighted twice as much as size.
print(multiscore_sorting_vector({"speed": 1}))
print(multiscore_sorting_vector({"speed": 2, "size": 1}))

In [None]:
def multiscore_sorted_results(trait_weights, n=3):
    """Run a vector search on `spiders_dot_collection` for a weighted sort.

    The query vector comes from `multiscore_sorting_vector(trait_weights)`.
    The value returned is whatever `find` gives back for a search limited to
    `n` results, with the `$vector` field projected out and the similarity
    included in each result.
    """
    find_options = {
        "vector": multiscore_sorting_vector(trait_weights),
        "limit": n,
        "projection": {"$vector": False},
        "include_similarity": True,
    }
    return spiders_dot_collection.find(**find_options)

In [None]:
# Weights 2:1 in favour of speed; five results are requested.
print("Mainly by speed, with a little size:")
for cr_i, cr_doc in enumerate(multiscore_sorted_results({"speed": 2, "size": 1}, n=5)):
    print(f"  [{cr_i + 1}, sim={cr_doc['$similarity']:.3f}] '{cr_doc['name']}' ({cr_doc['family']}), {cr_doc['specs']}")

In [None]:
# Same two traits with the weights swapped: size now counts double.
print("Mainly by size, with a little speed:")
for cr_i, cr_doc in enumerate(multiscore_sorted_results({"speed": 1, "size": 2}, n=5)):
    print(f"  [{cr_i + 1}, sim={cr_doc['$similarity']:.3f}] '{cr_doc['name']}' ({cr_doc['family']}), {cr_doc['specs']}")

## Cleanup

In [None]:
# Remove the documents written by this demo; the collection itself is kept.
spiders_dot_collection.delete_all()

# free all resources with:
# spiders_dot_collection.drop()

## The End