## Load all tiles as a GeoDataFrame: this is easier than a DuckDB spatial table, but less scalable

In [5]:
%load_ext autoreload
%autoreload 2

import geopandas as gpd
import duckdb
import os
from tqdm import tqdm

# Presumably the directory holding the validated tiles (inferred from the name only).
# NOTE(review): absolute, user-specific path that no other visible cell references;
# confirm it is still needed before relying on it.
valid_tile_dir = "/home/christopher.x.ren/embeddings/ra_tea/valid_tiles"

from datetime import datetime
import json
import os

import annoy
import geopandas as gpd
import ipyleaflet as ipyl
from IPython.display import display
import ipywidgets as ipyw
import joblib
import numpy as np
import pandas as pd

import sys
sys.path.insert(0, '/home/christopher.x.ren/earth-index-ml/demeter/tea/src')

from ui import GeoLabeler


# TODO: all this should go into init function of GeoLabeler
# Load the AOI configuration; every path and query parameter below is read from it.
with open('../config/ra_tea_aoi.json', 'r') as f:
    config = json.load(f)

local_dir = config['local_dir']

# Annoy index over the embeddings, using the angular metric. The dimensionality is
# taken from the config. NOTE(review): an earlier comment stated 384 dimensions for
# ViT-DINO embeddings; confirm config['index_dim'] agrees with the index on disk.
annoy_index_path = os.path.join(local_dir, 'embeddings.ann')
annoy_index = annoy.AnnoyIndex(config['index_dim'], 'angular')
annoy_index.load(annoy_index_path)

# GeoDataFrame read from 'centroid_gdf.parquet'; this is what GeoLabeler receives.
tile_centroid_path = os.path.join(local_dir, 'centroid_gdf.parquet')
tile_centroid_gdf = gpd.read_parquet(tile_centroid_path)

# DuckDB database holding the embeddings; the connection is handed to GeoLabeler.
duckdb_path = os.path.join(local_dir, 'embeddings.db')
embeddings_con = duckdb.connect(duckdb_path)

# Query parameters forwarded unchanged to GeoLabeler.
mgrs_ids = config['mgrs_ids']
start_date = config['start_date']
end_date = config['end_date']
imagery = config['imagery']

# `gdf` was previously produced by reading the same parquet file a second time.
# Copy the frame already in memory instead: same content, still independent of
# `tile_centroid_gdf`, but without the redundant disk read.
gdf = tile_centroid_gdf.copy()



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:

from ui import GeoLabeler

# Boundary GeoJSON. The path is handed to GeoLabeler; the parsed frame is kept in
# BOUNDARY for ad-hoc inspection.
BOUNDARY_PATH = "/home/christopher.x.ren/earth-index-ml/places/ra_aoi_indonesia.geojson"
BOUNDARY = gpd.read_file(BOUNDARY_PATH)

# Attribution HTML passed to GeoLabeler (MapTiler + OpenStreetMap credits).
maptiler_attribution = (
    '<a href="https://www.maptiler.com/copyright/" target="_blank">&copy; MapTiler</a> '
    '<a href="https://www.openstreetmap.org/copyright" target="_blank">&copy; OpenStreetMap contributors</a>'
)

# Build the interactive labeling widget from the objects loaded in the setup cell.
labeler = GeoLabeler(
    gdf=tile_centroid_gdf,
    geojson_path=BOUNDARY_PATH,
    mgrs_ids=mgrs_ids,
    start_date=start_date,
    end_date=end_date,
    imagery=imagery,
    annoy_index=annoy_index,
    duckdb_connection=embeddings_con,
    attribution=maptiler_attribution,
)

# Text widget that the mouse-move handler keeps up to date.
label = ipyw.Label()
display(label)

def handle_mouse_move(**kwargs):
    """Show the cursor position and the current labeling state in ``label``.

    Written as a callback for ``labeler.map.on_interaction``. Reads the
    module-level ``labeler`` (``select_val``, ``lasso_mode``) and writes the
    formatted status text to the module-level ``label`` widget.

    Args:
        **kwargs: Interaction event payload. Only ``coordinates`` is used and
            it is unpacked as a ``(lat, lon)`` pair.

    Returns:
        None. Events that carry no ``coordinates`` are ignored and leave the
        label untouched (previously they raised ``TypeError`` on unpacking).
    """
    coordinates = kwargs.get('coordinates')
    if coordinates is None:
        return
    lat, lon = coordinates

    # select_val encoding as used here: -100 -> erase, 0 -> negative, anything else -> positive.
    if labeler.select_val == -100:
        label_type = "Erase"
    elif labeler.select_val == 0:
        label_type = "Negative"
    else:
        label_type = "Positive"

    mode = "lasso" if labeler.lasso_mode else "single"
    label.value = f'Lat/lon: {lat:.4f}, {lon:.4f}. Mode: {mode}. Labeling: {label_type}'

# Register handle_mouse_move as the map's interaction callback so the status label
# follows the cursor. NOTE(review): assumes labeler.map is an ipyleaflet Map — confirm.
labeler.map.on_interaction(handle_mouse_move)

Initializing GeoLabeler...
Adding controls...


VBox(children=(Map(center=[-1.7088306035977048, 102.5723006937729], controls=(ZoomControl(options=['position',…

Label(value='')

## Search: the first query takes a while because the table is loaded into memory

## Make sure to set a memory limit on the connection too: with no limit, the Java + Sumatra embeddings take up about 60 GB

In [9]:
# TODO: this should probably be a class method and internal
# Similarity search: build a query vector from the labeled tiles, fetch its nearest
# neighbours from the Annoy index, and show them on the map as detections.
pos = labeler.gdf.loc[labeler.pos_ids]
neg = labeler.gdf.loc[labeler.neg_ids]
if len(pos) == 0:
    # Without positives the mean below is NaN and the search fails obscurely.
    raise ValueError("No positive labels yet: label at least one positive tile before searching.")

# Mean embedding of the positive tiles ('tile_id' and 'row_number' are metadata columns).
pos_embeddings = labeler.get_embeddings_by_tile_ids(pos['tile_id'].values)
pos_vec = pos_embeddings.drop(columns=['tile_id', 'row_number']).mean(axis=0).values

# Mean embedding of the negative tiles, or a zero vector when none are labeled.
if len(neg) > 0:
    neg_embeddings = labeler.get_embeddings_by_tile_ids(neg['tile_id'].values)
    neg_vec = neg_embeddings.drop(columns=['tile_id', 'row_number']).mean(axis=0).values
else:
    neg_vec = np.zeros(pos_vec.shape[0])

# Weight the positive mean by 2 and subtract the negative mean.
query_vector = 2 * pos_vec - neg_vec

n_nbors = 2000
nbors = labeler.annoy_index.get_nns_by_vector(query_vector, n_nbors, include_distances=True)

# Filter out any IDs that are already in positive labels. A set keeps each membership
# test O(1) instead of scanning pos_ids for every neighbour.
labeled_pos_ids = set(labeler.pos_ids)
nbors_filtered = [n for n in nbors[0] if n not in labeled_pos_ids]

# NOTE(review): this assumes the Annoy item ids are the index labels of labeler.gdf — confirm.
detections = labeler.gdf.loc[nbors_filtered]
labeler.detection_gdf = detections[['geometry']]
labeler.update_layer(
    labeler.points, json.loads(detections.geometry.to_json()))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [9]:
# Persist the current positive and negative labels as parquet files.
# NOTE(review): the output location and version tag are hard-coded; bump the tag
# (or move both into the config) before saving a new labeling round, otherwise the
# previous files are overwritten.
label_out_dir = "/home/christopher.x.ren/datasets/ra_tea"
label_version = "v1_sumatra_2024-11-10"

# to_parquet does not create missing directories; make sure the target exists so an
# interactive labeling session is not lost to a FileNotFoundError.
os.makedirs(label_out_dir, exist_ok=True)

pos_gdf = labeler.gdf.loc[labeler.pos_ids]
pos_gdf.to_parquet(os.path.join(label_out_dir, f"pos_gdf_{label_version}.parquet"))
neg_gdf = labeler.gdf.loc[labeler.neg_ids]
neg_gdf.to_parquet(os.path.join(label_out_dir, f"neg_gdf_{label_version}.parquet"))