# Early Interactive Demo of Satellite Similarity Search
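To run this notebook, you must download the data from [sat-searcher on Google Drive](https://drive.google.com/drive/folders/1lac_YcJHp_6GlVFZo4AE3wN4qhx3mrct?usp=drive_link). Unzip the `centroids` and `embeddings_8bit` folders and place them in the `outputs` directory, so that the notebook can read `./outputs/centroids/` and `./outputs/embeddings_8bit/`. The map cell also reads region outlines from `./data/boundaries/` (`israel.geojson`, `bali.geojson`, `alabama.geojson`).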

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import faiss
import numpy as np
import geopandas as gpd
import pandas as pd
import os
import shapely
import ipyleaflet as ipyl
import ipywidgets as ipyw
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
import matplotlib.pyplot as plt

import gee.utils as utils


In [None]:
def get_neighbors(search_vec, vectors, metric='cosine', n=5):
    """
    Find the n nearest neighbors to a search vector in a set of vectors.

    Parameters
    ----------
    search_vec : numpy.ndarray
        The query vector; it is reshaped to a single row before comparison.
    vectors : array-like of shape (n_vectors, n_features)
        Candidate vectors to rank against the query.
    metric : {'cosine', 'euclid'}
        'cosine' ranks by cosine similarity (highest first);
        'euclid' ranks by Euclidean distance (lowest first).
    n : int
        Maximum number of neighbors to return.

    Returns
    -------
    indices : numpy.ndarray
        Row indices into ``vectors`` of the best matches, best first.
    scores : numpy.ndarray
        The cosine similarities ('cosine') or Euclidean distances ('euclid')
        of those matches, in the same order.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    # Validate up front: otherwise an unknown metric falls through both
    # branches and fails later with an unhelpful UnboundLocalError.
    if metric not in ('cosine', 'euclid'):
        raise ValueError(f"Unknown metric {metric!r}; expected 'cosine' or 'euclid'")

    query = search_vec.reshape(1, -1)
    if metric == 'cosine':
        scores = cosine_similarity(query, vectors)[0]
        # Higher similarity is better, so rank in descending order.
        order = np.argsort(scores)[::-1]
    else:
        scores = euclidean_distances(query, vectors)[0]
        # Smaller distance is better, so rank in ascending order.
        order = np.argsort(scores)

    top = order[:n]
    return top, scores[top]

def get_neighbors_faiss(search_vec, index, n=5):
    """
    Query a faiss-style index with a single vector.

    ``index.search`` operates on a batch of queries, so ``search_vec`` is
    lifted to a one-row batch and only the first row of each result array is
    returned.

    Returns
    -------
    (indices, distances)
        Neighbor ids and their distances for the query, in the order the index
        reports them. Note that ids come first here, whereas ``index.search``
        returns distances first.
    """
    batch = np.asanyarray(search_vec)[np.newaxis, ...]
    dist_batch, id_batch = index.search(batch, n)
    return id_batch[0], dist_batch[0]


def tile_from_point(x, y, size=32):
    """
    Return the tile geometry around a point as a GeoJSON-style mapping.

    A ``utils.Tile`` is built with the arguments in ``(y, x, size)`` order
    (note the swap), and its ``create_geometry()`` result is converted with
    ``shapely.geometry.mapping``. Nothing is added to a map here; callers
    assign the returned mapping to a layer themselves.

    NOTE(review): presumably x is longitude and y latitude, matching how the
    click handler unpacks ipyleaflet coordinates — confirm against utils.Tile.
    """
    return shapely.geometry.mapping(utils.Tile(y, x, size).create_geometry())


def retrieve_neighbors(search_vec,
                       index,
                       map_data,
                       centroids,
                       threshold = 100,
                       n=100):
    """
    Search ``index`` for tiles similar to ``search_vec`` and draw them on a map layer.

    The ``n + 1`` nearest hits are requested. Hits whose distance is below
    ``threshold`` are kept, and the closest kept hit is dropped, presumably so
    the query tile itself is not shown. Each remaining hit becomes a tile
    polygon around its centroid, and the set replaces ``map_data.data`` as a
    GeoJSON FeatureCollection.

    NOTE(review): when the query is an average of several tiles, it is not
    itself in the index, so dropping the first hit discards a genuine match.

    Parameters
    ----------
    search_vec : numpy.ndarray
        Query embedding.
    index : faiss index
        Searched via ``get_neighbors_faiss``. With the ``IndexFlatL2`` built in
        this notebook, distances are *squared* L2, so ``threshold`` is on that
        scale.
    map_data : ipyleaflet.GeoJSON
        Layer whose ``data`` is overwritten with the results.
    centroids : array-like
        Tile centroids aligned with the index rows; ``centroids[i][0]`` and
        ``centroids[i][1]`` are passed to ``tile_from_point`` as x and y.
    threshold : float
        Exclusive upper bound on the distance of a kept hit.
    n : int
        Maximum number of neighbors to display.
    """
    neighbors, distances = get_neighbors_faiss(search_vec, index, n=n+1)
    # faiss pads results with label -1 when the index holds fewer vectors than
    # requested; such labels must never reach centroids[-1].
    keep = (distances < threshold) & (neighbors >= 0)
    neighbors = neighbors[keep][1:]
    if len(distances) > 1:
        print(f"Found {len(neighbors)} neighbors beneath threshold. Min distance: {distances[1]}, max distance: {distances[-1]}", end='\r')
    else:
        # n == 0: there is no "second" distance to report.
        print(f"Found {len(neighbors)} neighbors beneath threshold.", end='\r')

    # A FeatureCollection must contain Feature objects, so wrap each tile geometry.
    features = [
        {
            "type": "Feature",
            "geometry": tile_from_point(centroids[neighbor_id][0], centroids[neighbor_id][1]),
            "properties": {},
        }
        for neighbor_id in neighbors
    ]
    map_data.data = {"type": "FeatureCollection", "features": features}


def normalize_and_clip(embeddings, n_std=4):
    """
    Standardize ``embeddings``, clip outliers, and rescale into [0, 1].

    The mean and standard deviation are single scalars computed over the whole
    array (not per dimension). Values are converted to z-scores, clipped to
    ``[-n_std, n_std]``, and then min-max scaled so the smallest remaining
    value maps to 0 and the largest to 1.

    Parameters
    ----------
    embeddings : array-like
        Values to transform; any shape.
    n_std : float, default 4
        Clip threshold, in standard deviations from the mean.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``embeddings``. Empty input returns an
        empty float array. Constant input (zero spread) returns all zeros
        instead of NaNs.
    """
    embeddings = np.asarray(embeddings)
    if embeddings.size == 0:
        # mean/std/min/max are undefined on an empty array.
        return np.zeros(embeddings.shape, dtype=float)

    mean = np.mean(embeddings)
    std_dev = np.std(embeddings)
    print(f"Mean: {mean}, std: {std_dev}")
    if std_dev == 0:
        # Constant input: z-scores would be 0/0; map everything to 0.
        return np.zeros(embeddings.shape, dtype=float)

    # z-score standardization (the min-max scaling happens after clipping).
    standardized = (embeddings - mean) / std_dev
    clipped = np.clip(standardized, -n_std, n_std)

    # Min-max scale the clipped values into [0, 1].
    min_value = clipped.min()
    max_value = clipped.max()
    print(f"Min value: {min_value}, max value: {max_value}")
    value_range = max_value - min_value
    if value_range == 0:
        # Only possible for n_std == 0; avoid a 0/0 division.
        return np.zeros(clipped.shape, dtype=float)
    return (clipped - min_value) / value_range

In [None]:
centroid_dir = './outputs/centroids/'
embedding_dir = './outputs/embeddings_8bit/'
# os.listdir returns entries in arbitrary order; sort so the concatenated row
# order (and therefore every faiss index position) is reproducible across runs.
file_names = sorted(os.path.splitext(f)[0] for f in os.listdir(centroid_dir) if f.endswith('.npy'))
if not file_names:
    raise FileNotFoundError(f"No .npy files found in {centroid_dir}")

# Load centroids and embeddings pairwise from same-named files, so that row i
# of `centroids` and row i of `embeddings` come from the same file position.
# The click handler indexes both arrays with the same tile index.
centroids = []
embeddings = []
for name in file_names:
    tile_centroids = np.load(os.path.join(centroid_dir, name + '.npy'))
    tile_embeddings = np.load(os.path.join(embedding_dir, name + '.npy'))
    if len(tile_centroids) != len(tile_embeddings):
        raise ValueError(
            f"{name}: {len(tile_centroids)} centroids but {len(tile_embeddings)} embeddings"
        )
    centroids.append(tile_centroids)
    embeddings.append(tile_embeddings)
centroids = np.concatenate(centroids, axis=0)
embeddings = np.concatenate(embeddings, axis=0)

# faiss operates on float32 vectors; convert explicitly rather than relying on
# the Python wrapper (the "8bit" embeddings presumably load as an integer dtype).
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(np.ascontiguousarray(embeddings, dtype='float32'))
print(f"Loaded and indexed {len(embeddings):,} embeddings")

In [None]:
# Region outlines to overlay on the imagery.
boundaries = [
    gpd.read_file(f'./data/boundaries/{region}.geojson')
    for region in ('israel', 'bali', 'alabama')
]

# Satellite basemap, initially centered on Alabama.
center = [32.3182, -86.9023]
m = ipyl.Map(
    basemap=ipyl.basemaps.Esri.WorldImagery,
    center=center,
    zoom=8,
    scroll_wheel_zoom=True,
)
m.layout.height = '800px'

# Draw each region as a white outline with a fully transparent fill.
for region_gdf in boundaries:
    m.add_layer(ipyl.GeoData(
        geo_dataframe=region_gdf,
        style={'color': 'white', 'fillOpacity': 0.0},
    ))

# Initially empty GeoJSON layers that later code overwrites: blue for the
# selected tile, red for search results, yellow (nothing visible here writes it).
search_data = ipyl.GeoJSON(data={'type': 'FeatureCollection', 'features': []}, style={'color': 'blue'})
result_data = ipyl.GeoJSON(data={'type': 'FeatureCollection', 'features': []}, style={'color': 'red'})
result_data_8_bit = ipyl.GeoJSON(data={'type': 'FeatureCollection', 'features': []}, style={'color': 'yellow'})
for overlay in (search_data, result_data, result_data_8_bit):
    m.add_layer(overlay)

# Two buttons that set m.mode, which decides whether the next clicked tile
# counts as a positive or a negative example.
m.pos = ipyw.Button(description='✅', button_style='success')
m.neg = ipyw.Button(description='❌', button_style='danger')
m.pos.on_click(lambda b: setattr(m, 'mode', 'pos'))
m.neg.on_click(lambda b: setattr(m, 'mode', 'neg'))
for button in (m.pos, m.neg):
    button.layout.width = '35px'
    button.layout.margin = '3px'

# Stack the buttons vertically in the bottom-right corner of the map.
mode_selector = ipyl.WidgetControl(widget=ipyw.VBox([m.pos, m.neg]),
                                   position='bottomright')
m.add_control(mode_selector)
m.mode = 'pos'

# Embeddings collected from tiles clicked in each mode.
m.pos_vectors = []
m.neg_vectors = []

# Map click handler: select a tile, record it as a positive/negative example,
# and rerun the similarity search.
def click(**kwargs):
    """
    Handle a map interaction; only events of type 'click' are acted on.

    The tile whose centroid is nearest the click point is drawn in
    ``search_data``. Its embedding is appended to ``m.pos_vectors`` or
    ``m.neg_vectors`` depending on ``m.mode``. If at least one positive example
    exists, a query embedding is built and its neighbors are drawn in
    ``result_data`` via ``retrieve_neighbors``.
    """
    if kwargs.get('type') != 'click':
        return

    # ipyleaflet passes coordinates as [lat, lon]; the centroid lookup below
    # compares against (x, y), presumably (lon, lat) — confirm against the
    # centroid files.
    click_y, click_x = kwargs.get('coordinates')
    tile_index, _ = get_neighbors(np.array([click_x, click_y]),
                                  centroids,
                                  metric='euclid',
                                  n=1)
    tile_index = tile_index[0]

    # Show only the most recently clicked tile in the blue search layer.
    search_data.data = tile_from_point(centroids[tile_index][0],
                                       centroids[tile_index][1])

    tile_embedding = embeddings[tile_index]
    if m.mode == 'pos':
        m.pos_vectors.append(tile_embedding)
    elif m.mode == 'neg':
        m.neg_vectors.append(tile_embedding)

    # Without a positive example there is nothing to search for: the mean of
    # an empty list is NaN and would turn the whole query into NaNs.
    if not m.pos_vectors:
        return

    pos_mean = np.mean(m.pos_vectors, axis=0)
    if m.neg_vectors:
        # 2*pos - neg == pos + (pos - neg): the positive mean pushed further
        # away from the negative mean.
        search_embedding = 2 * pos_mean - np.mean(m.neg_vectors, axis=0)
    else:
        search_embedding = pos_mean
    retrieve_neighbors(search_embedding,
                       index,
                       result_data,
                       centroids,
                       threshold=200000,
                       n=200)


m.on_interaction(click)
m