# Explore embedding walks
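To run this notebook, download the data from [sat-searcher on Google Drive](https://drive.google.com/drive/folders/1lac_YcJHp_6GlVFZo4AE3wN4qhx3mrct?usp=drive_link). Unzip the `centroids` and `embeddings_8bit` folders and place them in the `outputs` directory. Note that the loading cell below reads embeddings from `./outputs/embeddings/`, so either rename `embeddings_8bit` to `embeddings` or update that path. The map cell also expects boundary files at `./data/boundaries/{israel,bali,alabama}.geojson`.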

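Click a start point and then an end point on the map below; each click selects the nearest tile. The notebook then interpolates linearly between the two tiles' embeddings and shows the closest matching tile for each point along that path.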

In [None]:
# Reload imported modules automatically, so edits to local modules such as
# `gee` and `tools` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import geopandas as gpd
import ipyleaflet as ipyl
import ipywidgets as ipyw
import matplotlib.pyplot as plt

import gee
import ee

import gee.utils as utils
import tools

# Initialise the Earth Engine client before any `ee.*` request in this notebook.
# NOTE(review): assumes Earth Engine credentials are already set up
# (e.g. via `ee.Authenticate()`) — confirm for new environments.
ee.Initialize()

In [None]:
# Load per-tile centroids and embeddings, then build the nearest-neighbour
# index that `show_interpolation` queries.
# NOTE(review): the intro tells users to unzip an `embeddings_8bit` folder, but
# this reads `./outputs/embeddings/` — keep the folder name and path in sync.
centroids, embeddings = tools.load_embeddings(
    embedding_dir='./outputs/embeddings/',
    centroid_dir='./outputs/centroids/',
)
index = tools.index_embeddings(embeddings)

In [None]:
def get_tile_data(tile, start_date='2023-01-01', end_date='2024-01-01',
                  clear_threshold=0.7, bands=("B2", "B3", "B4")):
    """
    Download a cloud-masked Sentinel-2 median composite for a tile.

    Harmonized Sentinel-2 L2A scenes in [start_date, end_date) are linked with
    Cloud Score+, pixels whose `cs_cdf` score is below `clear_threshold` are
    masked out, and the per-pixel median is clipped to the tile and sampled on a
    tilesize x tilesize grid.

    Inputs:
        - tile: a tile object exposing `geometry` (with `.bounds`) and
          `tilesize`, e.g. the `utils.Tile` built in `get_image`
        - start_date, end_date: date range passed to `filterDate`
        - clear_threshold: minimum Cloud Score+ `cs_cdf` value to keep a pixel
        - bands: Sentinel-2 band names to download, in output channel order
    Outputs:
        - pixels: numpy array of shape (tilesize, tilesize, len(bands))
        - tile: the input tile, returned unchanged
    """
    # Harmonized Sentinel-2 Level 2A collection.
    s2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")

    # Cloud Score+ is produced from Sentinel-2 Level 1C data and can be applied
    # to either L1C or L2A collections.
    cs_plus = ee.ImageCollection("GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED")
    qa_band = "cs_cdf"

    composite = (
        s2.filterDate(start_date, end_date)
        .linkCollection(cs_plus, [qa_band])
        .map(lambda img: img.updateMask(img.select(qa_band).gte(clear_threshold)))
        .median()
    )

    tile_geom = ee.Geometry.Rectangle(tile.geometry.bounds)
    composite_tile = composite.clipToBoundsAndScale(
        geometry=tile_geom, width=tile.tilesize, height=tile.tilesize)

    # NOTE: passing 'grid': {'crsCode': tile.crs} here previously caused
    # problems, so the request relies on the default grid.
    pixels = ee.data.computePixels({
        "bandIds": list(bands),
        "expression": composite_tile,
        "fileFormat": "NUMPY_NDARRAY",
    })

    # computePixels returns a structured array (one field per band); convert it
    # to a plain (H, W, bands) numeric array.
    pixels = np.array(pixels.tolist())

    return pixels, tile

def get_image(x, y, tilesize):
    """
    Build a `utils.Tile` for (x, y) and return its Sentinel-2 pixel array.

    Note that `utils.Tile` is constructed with y first, then x.
    """
    tile_obj = utils.Tile(y, x, tilesize)
    tile_obj.create_geometry()
    return get_tile_data(tile_obj)[0]

def show_image(x, y, tilesize=32, scale=2500):
    """
    Display the tile at (x, y) as an RGB image.

    `get_image` returns bands B2, B3, B4 (blue, green, red) as raw values.
    Passing that straight to `imshow` shows BGR colours and saturates almost
    every pixel. Instead, the channels are reordered to RGB and divided by
    `scale`, then clipped to [0, 1], matching the rendering in `show_interpolation`.

    Inputs:
        - x, y: tile location, passed to `get_image`
        - tilesize: tile size in pixels
        - scale: raw value mapped to full brightness (default 2500)
    Outputs:
        - fig, ax: the matplotlib figure and axes
    """
    pixels = get_image(x, y, tilesize)
    rgb = np.clip(pixels[:, :, (2, 1, 0)] / scale, 0, 1)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(rgb)
    ax.set_axis_off()
    return fig, ax


def show_interpolation(index_a, index_b, num_steps=8):
    """
    Plot the tiles matched along a straight line in embedding space.

    The embeddings of tiles `index_a` and `index_b` are linearly interpolated at
    `num_steps` evenly spaced points, with both endpoints included. Each point
    is matched to a tile through the FAISS `index`. The start tile, the matched
    tiles and the end tile are then drawn left to right in a single row.

    Inputs:
        - index_a: row of the start tile in `centroids` / `embeddings`
        - index_b: row of the end tile in `centroids` / `embeddings`
        - num_steps: number of interpolated points to match (default 8)
    Outputs:
        - None; the figure is shown with plt.show()
    """
    a_embedding = embeddings[index_a]
    b_embedding = embeddings[index_b]

    # Weights run from 0 to 1 inclusive, so the walk is symmetric: it starts at
    # a and ends at b. (Weights of i / num_steps never reached b.)
    # NOTE(review): assumes each embedding is a 1-D vector.
    weights = np.linspace(0, 1, num_steps)[:, np.newaxis]
    interp_embeddings = a_embedding * (1 - weights) + b_embedding * weights

    # Request two neighbours per point. The top hit is used unless it is one of
    # the endpoint tiles, which already appear in the first and last panels;
    # in that case the runner-up is used. Always taking the second hit would
    # throw away the true nearest neighbour for interior points.
    interp_neighbors = []
    for step, query in enumerate(interp_embeddings):
        print(f"matching point {step}", end="\r")
        indices, _ = tools.get_neighbors_faiss(query, index, n=2)
        if indices[0] in (index_a, index_b):
            interp_neighbors.append(indices[1])
        else:
            interp_neighbors.append(indices[0])

    panel_tiles = [index_a, *interp_neighbors, index_b]
    n_panels = len(panel_tiles)
    # One row of panels, so the width scales with the panel count. The old
    # figsize had width and height swapped and, at 300 dpi, produced a huge,
    # narrow figure.
    fig, axes = plt.subplots(1, n_panels, figsize=(2 * n_panels, 2),
                             dpi=150, squeeze=False)
    for panel, (ax, tile_index) in enumerate(zip(axes[0], panel_tiles)):
        x, y = centroids[tile_index][0], centroids[tile_index][1]
        print(f"getting image {panel}, {x}, {y}", end="\r")
        pixels = get_image(x, y, 32)
        # Bands come back as B2, B3, B4 (blue, green, red): reorder to RGB and
        # map raw values 0..2500 into [0, 1] for display.
        ax.imshow(np.clip(pixels[:, :, (2, 1, 0)] / 2500, 0, 1))
        ax.axis('off')
    plt.show()


In [None]:
# Region outlines to overlay on the map.
boundaries = [
    gpd.read_file(f'./data/boundaries/{region}.geojson')
    for region in ['israel', 'bali', 'alabama']
]

# Satellite basemap, initially centred on Alabama ([lat, lon]).
center = [32.3182, -86.9023]
m = ipyl.Map(
    basemap=ipyl.basemaps.Esri.WorldImagery,
    center=center,
    zoom=8,
    scroll_wheel_zoom=True,
)
m.layout.height = '800px'

# Draw every outline as a white border with a transparent fill.
for boundary in boundaries:
    outline = ipyl.GeoData(geo_dataframe=boundary,
                           style={'color': 'white', 'fillOpacity': 0.0})
    m.add_layer(outline)

# Empty GeoJSON layers that the click handler fills with the selected start
# tile (yellow) and end tile (red).
location_a, location_b = (
    ipyl.GeoJSON(data={'type': 'FeatureCollection', 'features': []},
                 style={'color': colour})
    for colour in ('yellow', 'red')
)
for selection_layer in (location_a, location_b):
    m.add_layer(selection_layer)

# Selection state kept on the map object: the start (a) and end (b) tile
# indices, and how many points of the current pair have been clicked.
m.a = m.b = 0
m.click_count = 0

# Widget that captures everything the click handler prints or plots. Output
# produced inside widget callbacks is not shown in the notebook cell otherwise;
# JupyterLab routes it to the log console.
interp_output = ipyw.Output()


def click(**kwargs):
    """
    Handle map clicks: the first click picks the start tile, the second the end tile.

    Each click is snapped to the tile whose centroid is nearest to the clicked
    point (Euclidean distance). The start tile is drawn in `location_a`
    (yellow) and the end tile in `location_b` (red). After the second click,
    `show_interpolation` runs with 20 steps and renders into `interp_output`.
    """
    if kwargs.get('type') != 'click':
        return

    # ipyleaflet reports coordinates as (lat, lon); centroids are queried as (x, y).
    click_y, click_x = kwargs.get('coordinates')
    tile_index, _ = tools.get_neighbors(np.array([click_x, click_y]),
                                        centroids,
                                        metric='euclid',
                                        n=1)
    tile_index = tile_index[0]
    tile_geom = tools.tile_from_point(centroids[tile_index][0],
                                      centroids[tile_index][1])

    if m.click_count == 0:
        # Start of a new pair: highlight the start tile and clear the previous
        # end tile so a stale selection is not left on the map.
        location_a.data = tile_geom
        location_b.data = {'type': 'FeatureCollection', 'features': []}
        m.a = tile_index
        m.click_count = 1
    elif m.click_count == 1:
        location_b.data = tile_geom
        m.b = tile_index
        m.click_count = 0
        with interp_output:
            # Replace the previous walk's plot instead of appending below it.
            interp_output.clear_output(wait=True)
            show_interpolation(m.a, m.b, num_steps=20)


m.on_interaction(click)
# Show the map with the interpolation output area underneath.
ipyw.VBox([m, interp_output])