In [1]:
import base64
import json
from io import BytesIO
import time
import numpy as np
from PIL import Image
import faiss
import boto3
import psycopg
import requests
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
STACK_NAME = "pg-vectors-similarity-search-test"
cloudformation = boto3.client("cloudformation")
stack_outputs = cloudformation.describe_stacks(StackName=STACK_NAME)["Stacks"][0]["Outputs"]
[secrets_arn] = [o["OutputValue"] for o in stack_outputs if o["OutputKey"]== "databasesecretsarn"]
secretsmanager = boto3.client("secretsmanager")
database_secrets = json.loads(secretsmanager.get_secret_value(SecretId=secrets_arn)['SecretString'])
database_secrets = {**{k:v for k,v in database_secrets.items() if k in ["host", "password", "port", "dbname"]}, "user":database_secrets["username"]}

In [11]:
def b64_image_to_tensor(image: str) -> torch.Tensor:
    """convert input b64image to torch tensor"""
    # handle image
    img_bytes = base64.b64decode(image)
    tmp = BytesIO()
    tmp.write(img_bytes)
    try:
        img = Image.open(tmp)
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail="Cannot recognize image format.")
    np_img = np.expand_dims(np.moveaxis(np.array(img), -1, 0), axis=0)
    if len(np_img.shape) != 4:
        raise HTTPException(
            status_code=400,
            detail=f"Image with shape {np.moveaxis(np_img, 1, -1).shape[1:]} is not processable. Use image with 3 channels.",
        )
    return torch.tensor(np_img).float()

def get_model():
    """start torch model"""
    # load encoder
    model = torch.jit.load("./encoder.pt", map_location=torch.device("cpu"))
    return model

def prep_query_vector(image_bytes): 

    tensor = b64_image_to_tensor(encoded_image_bytes)
    embedding = model(tensor)[0].detach().numpy()
    
    reduced_embedding = pca_matrix.apply(embedding)
    query_vector = str(reduced_embedding.tolist()[0])
    return query_vector

model = get_model()
pca_matrix = faiss.read_VectorTransform("./512_to_128_pca_matrix.pca")

gibs_image_url = "https://gibs.earthdata.nasa.gov/wmts/epsg3857/best/MODIS_Terra_CorrectedReflectance_TrueColor/default/2023-08-08/GoogleMapsCompatible_Level9/8/111/15.jpg"
encoded_image_bytes = base64.b64encode(requests.get(gibs_image_url).content)


In [26]:
start = time.time()

query_vector = prep_query_vector(encoded_image_bytes)
neighbors_query = f"""
SELECT *, embedding <-> '{query_vector}' as distance 
FROM images 
ORDER BY embedding <-> '{query_vector}' 
LIMIT 5"""

with psycopg.connect(**database_secrets) as conn:
    with conn.cursor() as cursor:     
        results = list(cursor.execute(neighbors_query).fetchall())

elapsed = round(time.time() - start, 2)
print(f"Total results: {len(results)}. Took {elapsed} seconds")

Total results: 5. Took 2.59 seconds


In [32]:
distance_query = f"""
SELECT *, embedding <-> '{query_vector}' as distance 
FROM images 
WHERE embedding <-> '{query_vector}' < 5 
ORDER BY embedding <-> '{query_vector}'
"""

start = time.time()
query_vector = prep_query_vector(encoded_image_bytes)
with psycopg.connect(**database_secrets) as conn:
    with conn.cursor() as cursor:     
        results = list(cursor.execute(distance_query).fetchall())

elapsed = round(time.time() - start, 2)
print(f"Total results: {len(results)}. Took {elapsed} seconds")

Total results: 938. Took 3.07 seconds


In [36]:
distance_query = f"""
SELECT *, embedding <-> '{query_vector}' as distance 
FROM images 
WHERE embedding <-> '{query_vector}' < 5 
AND datetime BETWEEN '2020-01-01'::timestamp AND '2020-06-01'::timestamp
ORDER BY embedding <-> '{query_vector}'
"""

start = time.time()
query_vector = prep_query_vector(encoded_image_bytes)
with psycopg.connect(**database_secrets) as conn:
    with conn.cursor() as cursor:     
        results = list(cursor.execute(distance_query).fetchall())

elapsed = round(time.time() - start, 2)
print(f"Total results: {len(results)}. Took {elapsed} seconds")

Total results: 305. Took 2.67 seconds


In [14]:
STACK_NAME = "similarity-search-api-v2-dev"
cloudformation = boto3.client("cloudformation")
stack_outputs = cloudformation.describe_stacks(StackName=STACK_NAME)["Stacks"][0]["Outputs"]
[endpoint_url] = [o["OutputValue"] for o in stack_outputs if o["OutputKey"].startswith("apiEndpoint")]


In [25]:
start = time.time()
resp = requests.post(f"https://{endpoint_url}/search", data=json.dumps({"image":encoded_image_bytes.decode(), "neighbors":3})).json()
elapsed = round(time.time() - start, 2)
print(f"Total results: {resp['numberMatched']}. Took {elapsed} seconds")

Total results: 3. Took 1.43 seconds


In [30]:
start = time.time()
resp = requests.post(f"https://{endpoint_url}/distance", data=json.dumps({"image":encoded_image_bytes.decode(), "distance":5})).json()
elapsed = round(time.time() - start, 2)
print(f"Total results: {resp}. Took {elapsed} seconds")

Total results: {'type': 'FeatureCollection', 'features': [{'type': 'Feature', 'bbox': [-157.5, -21.943045533438177, -135.0, 0.0], 'id': 'Tile(x=1, y=8, z=4)', 'geometry': {'type': 'Polygon', 'coordinates': [[[-157.5, -21.943045533438177], [-157.5, 0.0], [-135.0, 0.0], [-135.0, -21.943045533438177], [-157.5, -21.943045533438177]]]}, 'properties': {'title': 'XYZ tile Tile(x=1, y=8, z=4)', 'bbox': [-157.5, -21.943045533438177, -135.0, 0.0], 'bin_start_time': '2001-01-01T00:00:00+00:00', 'count': 6}}, {'type': 'Feature', 'bbox': [-135.0, -21.943045533438177, -112.5, 0.0], 'id': 'Tile(x=2, y=8, z=4)', 'geometry': {'type': 'Polygon', 'coordinates': [[[-135.0, -21.943045533438177], [-135.0, 0.0], [-112.5, 0.0], [-112.5, -21.943045533438177], [-135.0, -21.943045533438177]]]}, 'properties': {'title': 'XYZ tile Tile(x=2, y=8, z=4)', 'bbox': [-135.0, -21.943045533438177, -112.5, 0.0], 'bin_start_time': '2001-01-01T00:00:00+00:00', 'count': 3}}, {'type': 'Feature', 'bbox': [45.0, 0.0, 67.5, 21.943