# Demo 4 - Similar "products"

In [None]:
!pip install --quiet "astrapy>=1.0.0" "python-dotenv>=1.0.0" "matplotlib>=3.8.4"

In [None]:
import getpass
import os

import matplotlib.pyplot as plt
from dotenv import load_dotenv

from astrapy import DataAPIClient
from astrapy.constants import VectorMetric

## Setup DB

In [None]:
# Load settings from a local .env file, if one is present.
load_dotenv()

# Interactively ask for anything still missing; getpass is used for the token
# so that the secret is not echoed.
if os.environ.get("ASTRA_DB_APPLICATION_TOKEN") is None:
    os.environ["ASTRA_DB_APPLICATION_TOKEN"] = getpass.getpass("Please input your Astra DB Token:")

if os.environ.get("ASTRA_DB_API_ENDPOINT") is None:
    os.environ["ASTRA_DB_API_ENDPOINT"] = input("Please input your Astra DB API Endpoint:")

# The keyspace is optional: a blank answer leaves the variable unset.
if os.environ.get("ASTRA_DB_KEYSPACE") is None:
    _namespace = input("(Optional) Input your Astra DB namespace if desired, or leave blank:")
    if _namespace:
        os.environ["ASTRA_DB_KEYSPACE"] = _namespace

# Token and endpoint are mandatory (KeyError if absent); the keyspace may be None.
ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"]
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")

In [None]:
# Database handle used by the rest of the notebook.
# ASTRA_DB_KEYSPACE is None when no keyspace was supplied and is passed through unchanged.
db = DataAPIClient(ASTRA_DB_APPLICATION_TOKEN).get_database_by_api_endpoint(
    ASTRA_DB_API_ENDPOINT,
    namespace=ASTRA_DB_KEYSPACE,
)

## Raw data

In [None]:
# The demo dataset: one dict per spider species, each with a scientific name,
# its taxonomic family and four numeric traits under "specs".
raw_critters = [
    {
        "name": "Eratigena atrica",
        "family": "Agelenidae",
        "specs": {
            "speed": 0.8,   # m/s (max running speed)
            "size": 3,      # max body length, cm
            "threat": 5.5,  # 0=harmless, 5=sore skin, 10=lethal
            "eyesight": 5,  # cm
        },
    },
    {
        "name": "Salticus scenicus",
        "family": "Salticidae",
        "specs": {
            "speed": 0.3,
            "size": 0.4,
            "threat": 0,
            "eyesight": 35,
        },
    },
    {
        "name": "Holocnemus pluchei",
        "family": "Pholcidae",
        "specs": {
            "speed": 0.05,
            "size": 0.8,
            "threat": 0,
            "eyesight": 10,
        },
    },
    {
        "name": "Hogna radiata",
        "family": "Lycosidae",
        "specs": {
            "speed": 0.65,
            "size": 2,
            "threat": 4,
            "eyesight": 20,
        },
    },
    {
        "name": "Atrax robustus",
        "family": "Atracidae",
        "specs": {
            "speed": 0.40,
            "size": 5,
            "threat": 9,
            "eyesight": 15,
        },
    },
    {
        "name": "Argiope bruennichi",
        "family": "Araneidae",
        "specs": {
            "speed": 0.10,
            "size": 2.5,
            "threat": 6,
            "eyesight": 12,
        },
    },
    {
        "name": "Loxosceles rufescens",
        "family": "Sicariidae",
        "specs": {
            "speed": 0.45,
            "size": 0.8,
            "threat": 7.5,
            "eyesight": 8,
        },
    },
    {
        "name": "Scytodes thoracica",
        "family": "Scytodidae",
        "specs": {
            "speed": 0.15,
            "size": 0.6,
            "threat": 0,
            "eyesight": 10,
        },
    },
    {
        "name": "Phoneutria fera",
        # Phoneutria is a genus of Ctenidae (wandering spiders),
        # not Ctenizidae (trapdoor spiders).
        "family": "Ctenidae",
        "specs": {
            "speed": 0.75,
            "size": 4.8,
            "threat": 10,
            "eyesight": 35,
        },
    },
    {
        "name": "Uloborus plumipes",
        "family": "Uloboridae",
        "specs": {
            "speed": 0.25,
            "size": 1.4,
            "threat": 0,
            "eyesight": 18,
        },
    },
]

In [None]:
# For every trait, record (and print) the smallest and largest value found in the
# dataset; `ranges` maps each trait name to its (min, max) pair.
traits = ["speed", "size", "threat", "eyesight"]
ranges = {}
for trait in traits:
    trait_values = [critter["specs"][trait] for critter in raw_critters]
    lowest, highest = min(trait_values), max(trait_values)
    print(f"{trait}: {lowest} to {highest}")
    ranges[trait] = (lowest, highest)

### A (simplistic) way to make traits 'comparable'

_Nothing replaces careful inspection, statistics and human-made decisions here. Nonlinear scales and/or a [-1, +1] range might also be a good idea ..._

In [None]:
def _rescale(val, v_range):
    if val < v_range[0]:
        return 0
    elif val > v_range[1]:
        return 1
    else:
        return (val - v_range[0]) / (v_range[1] - v_range[0])


def normalize_specs(raw_specs):
    """Rescale every trait in ``raw_specs`` using its entry in the global ``ranges``.

    Args:
        raw_specs: dict mapping trait name to raw value; each name must be a key
            of ``ranges``.

    Returns:
        A new dict with the same keys, each value passed through ``_rescale``.
    """
    normalized = {}
    for trait_name, raw_value in raw_specs.items():
        normalized[trait_name] = _rescale(raw_value, ranges[trait_name])
    return normalized

In [None]:
def make_list(specs):
    """Return the values of ``specs`` as a list ordered like the global ``traits``."""
    ordered_values = []
    for trait in traits:
        ordered_values.append(specs[trait])
    return ordered_values

In [None]:
# Show raw vs. normalized trait lists for the first three critters, one per line.
trait_names = ", ".join(traits)
print(f"Specs normalization (traits = '{trait_names}'):")
for critter in raw_critters[:3]:
    raw_list = make_list(critter["specs"])
    norm_list = make_list(normalize_specs(critter["specs"]))
    print(f"    {raw_list} ==> {norm_list}, for '{critter['name']}'")

### A sample plot

In [None]:
# Pick two traits and collect their raw and normalized values for plotting.
trait_x = "eyesight"
trait_y = "speed"
raw_values_x = [critter["specs"][trait_x] for critter in raw_critters]
raw_values_y = [critter["specs"][trait_y] for critter in raw_critters]
# Normalize each critter once and read both coordinates from the same result.
normalized_specs = [normalize_specs(critter["specs"]) for critter in raw_critters]
norm_values_x = [specs[trait_x] for specs in normalized_specs]
norm_values_y = [specs[trait_y] for specs in normalized_specs]

# Actually create the figure (the previous `fig = plt.figure` only bound the function
# object); it is sized like the "Normalized traits" figure so the two plots compare.
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(raw_values_x, raw_values_y, s=150)
ax.set_title("Raw traits")
ax.set_xlabel(trait_x)
ax.set_ylabel(trait_y)
plt.show()

In [None]:
# Same scatter plot on the normalized values; both axes are pinned to [0, 1].
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(norm_values_x, norm_values_y, s=150, color="purple")
ax.set_title("Normalized traits")
ax.set_xlabel(trait_x)
ax.set_ylabel(trait_y)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.show()

## Write to DB

The vector comes from the _normalized_ traits (... or it will hardly make sense!)

In [None]:
# Create the "spiders" vector collection. The dimension is derived from `traits`
# because make_list() emits exactly one vector component per trait.
# NOTE(review): check_exists=False presumably lets a re-run proceed when the
# collection already exists - confirm against the astrapy documentation.
spiders_collection = db.create_collection(
    "spiders",
    dimension=len(traits),
    metric=VectorMetric.EUCLIDEAN,
    check_exists=False,
)
# just in case this demo is re-run
spiders_collection.delete_all()

In [None]:
# One vector per critter, built from its normalized traits, in the same order
# as the documents in raw_critters.
critter_vectors = []
for critter in raw_critters:
    critter_vectors.append(make_list(normalize_specs(critter["specs"])))

spiders_collection.insert_many(raw_critters, vectors=critter_vectors)

## Search

Find a similar ~~product~~ spider:

In [None]:
def find_similar(query_specs, limit=2):
    """Search the spiders collection for entries closest to the given raw specs.

    Args:
        query_specs: dict of raw (un-normalized) trait values keyed by trait name;
            it is normalized and turned into the query vector here.
        limit: maximum number of matches requested from the collection
            (defaults to 2, the value this demo previously hard-coded).

    Returns:
        The result of ``spiders_collection.find`` for the query vector, with
        similarity scores requested via ``include_similarity=True``.
    """
    query_vector = make_list(normalize_specs(query_specs))
    return spiders_collection.find(
        vector=query_vector,
        limit=limit,
        include_similarity=True,
    )

In [None]:
# Raw traits of the reference spider we want look-alikes for.
ref_specs = dict(speed=0.12, size=1.2, threat=4, eyesight=30)

print("Your results:")
# Number the matches starting from 1 for display.
for rank, match in enumerate(find_similar(ref_specs), start=1):
    print(f"  [{rank}] '{match['name']}' (fam. {match['family']})")

## Cleanup

In [None]:
# Empty the collection of the documents inserted by this demo
# (the same call is used after creation to make re-runs start clean).
spiders_collection.delete_all()

# free all resources with:
# spiders_collection.drop()