# Run Blob Detection on Pixel Heatmap to Identify Candidate Sites
Note: This currently only works on inputs whose CRS is EPSG:4326. It may need to be generalized to other coordinate reference systems in the future.

In [None]:
import json

import geojson
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio as rs
from rasterio.windows import Window
from rasterio import warp
from skimage.feature import blob_doh
from sklearn.neighbors import KDTree

In [None]:
def pixels_to_coords(x, y, src_img):
    """
    Convert a pixel position in src_img to a (lon, lat) pair in EPSG:4326.

    x and y are forwarded unchanged to src_img.xy(x, y) to get the pixel's
    position in the raster's own CRS, which is then reprojected to EPSG:4326.

    NOTE(review): rasterio documents DatasetReader.xy as taking (row, col),
    while coords_to_pixels returns (col, row). The two functions are therefore
    not argument-order inverses of each other -- confirm the intended
    convention before chaining them.
    """
    # Resolve the pixel once instead of calling src_img.xy separately per axis
    src_x, src_y = src_img.xy(x, y)
    lon, lat = warp.transform(src_img.crs, rs.crs.CRS.from_epsg(4326), [src_x], [src_y])
    return lon[0], lat[0]

def coords_to_pixels(lon, lat, src_img):
    """
    Convert an EPSG:4326 (lon, lat) pair to a pixel position in src_img.

    The point is reprojected into the raster's CRS and looked up with
    src_img.index. The result is returned as (pixel_x, pixel_y), i.e. the
    second value from src_img.index first and the first value second.
    """
    wgs84 = rs.crs.CRS.from_epsg(4326)
    xs, ys = warp.transform(wgs84, src_img.crs, [lon], [lat])
    first_axis, second_axis = src_img.index(xs, ys)
    return second_axis[0], first_axis[0]


def _merge_nearby_sites(candidate_sites, radius):
    """Collapse candidate points lying within `radius` of each other into median points."""
    if not candidate_sites:
        # KDTree cannot be built from an empty array
        return []

    points = np.array(candidate_sites)
    tree = KDTree(points, leaf_size=2)
    neighbor_sets = tree.query_radius(points, r=radius)

    # Every point's neighbourhood (which includes the point itself) is one group.
    # Sorting the indices makes identical neighbourhoods compare equal so they
    # are only kept once. Groups are not merged transitively: partially
    # overlapping neighbourhoods still produce separate sites.
    groups = sorted({tuple(sorted(int(i) for i in neighbors)) for neighbors in neighbor_sets})
    return [np.median(points[list(group)], axis=0).tolist() for group in groups]


def detect_candidates(source_img, name, pred_threshold=0.75, min_sigma=3.5, max_sigma=100, area_threshold=0.0025, window_size=5000, nearest_neighbor_threshold=0.01):
    """
    Identify candidate sites using blob detection on the heatmap.

    The first band of source_img is scaled by its maximum value and processed
    in window_size x window_size pixel tiles. Blob centres are converted to
    raster coordinates with the window transform, nearby detections are merged,
    and the result is written to CSV and GeoJSON.

    Parameters:
        source_img: open rasterio dataset holding the prediction heatmap.
            Coordinates are written out as lon/lat without reprojection, so the
            raster is expected to already be in EPSG:4326.
        name: prefix for the site names and the output file names.
        pred_threshold: scaled predictions below this 0-1 value are set to 0
            before blob detection.
        min_sigma: passed to blob_doh. Keep low to detect smaller blobs.
        max_sigma: passed to blob_doh.
        area_threshold: passed to blob_doh as `threshold`. Establishes a lower
            bound on candidate blob size. Reduce to detect smaller blobs.
        window_size: edge length in pixels of each processing tile.
        nearest_neighbor_threshold: radius, in raster coordinate units, within
            which detections are grouped and replaced by their median.

    Returns:
        pandas DataFrame with columns 'lon', 'lat' and 'name', one row per
        merged candidate site.

    Side effects:
        Writes '<name>_candidates_...csv' and '.geojson' files under
        ../data/model_outputs/candidate_sites/.
    """
    height, width = source_img.shape[0], source_img.shape[1]
    row_starts = range(0, height, window_size)
    col_starts = range(0, width, window_size)

    # The whole band is read once here just to find the scaling maximum
    max_val = source_img.read(1).max()

    candidate_sites = []
    if max_val > 0:
        for row_num, row_off in enumerate(row_starts, start=1):
            for col_num, col_off in enumerate(col_starts, start=1):
                print(f"Processing row {row_num} of {len(row_starts)}, column {col_num} of {len(col_starts)}")
                # Analyze a subset of the image, clamped to the raster edge
                window = Window.from_slices(
                    (row_off, min(row_off + window_size, height)),
                    (col_off, min(col_off + window_size, width)),
                )
                window_median = (source_img.read(1, window=window) / max_val).astype('float')
                # Zero out predictions below the threshold
                window_median[window_median < pred_threshold] = 0

                blobs = blob_doh(window_median, min_sigma=min_sigma, max_sigma=max_sigma, threshold=area_threshold)
                print(len(blobs), "candidates detected in window")

                transform = source_img.window_transform(window)
                for candidate in blobs:
                    # blob_doh rows are (row, col, sigma); the affine transform
                    # takes (col, row) and returns raster coordinates.
                    # Size (candidate[2]) doesn't mean anything at the moment.
                    lon, lat = transform * (candidate[1], candidate[0])
                    candidate_sites.append([lon, lat])
    else:
        # Scaling by a non-positive maximum would divide by zero or flip signs,
        # and nothing could exceed the threshold anyway.
        print("Heatmap has no positive values; skipping blob detection")

    print(len(candidate_sites), "candidate sites detected in total")

    # Combine nearby points into a single point with a median value of the set of grouped points
    unique_sites = _merge_nearby_sites(candidate_sites, nearest_neighbor_threshold)
    print(len(unique_sites), "unique candidate sites detected")

    site_names = [f"{name}_{i+1}" for i in range(len(unique_sites))]
    candidate_site_df = pd.DataFrame(unique_sites, columns=['lon', 'lat'])
    candidate_site_df['name'] = site_names

    # Write candidates to CSV and GeoJSON
    file_path = f"../data/model_outputs/candidate_sites/{name}_candidates_pred-thresh_{pred_threshold}_min-sigma_{min_sigma}_area-thresh_{area_threshold}"
    candidate_site_df.to_csv(file_path + '.csv', index=False)

    geojson_features = [
        geojson.Feature(geometry=geojson.Point(site), properties={'name': site_name})
        for site, site_name in zip(unique_sites, site_names)
    ]
    feature_collection = geojson.FeatureCollection(geojson_features)
    with open(file_path + '.geojson', 'w') as f:
        f.write(geojson.dumps(feature_collection))

    return candidate_site_df

In [None]:
# Open the prediction heatmap raster that blob detection runs on.
# The dataset handle is left open because the next cell passes it to detect_candidates.
# The commented-out line is an alternative input heatmap.
#source = rs.open('../data/model_outputs/heatmaps/tpa_nn_toa_main_island_median.tif')
source = rs.open('../data/model_outputs/heatmaps/bali_v12_wgs84.tif')

In [None]:
# Prefix used for the candidate site names and for the output CSV/GeoJSON file names.
# NOTE(review): the heatmap opened above is 'bali_v12_wgs84.tif' but this prefix says
# 'java' -- confirm which one is intended before trusting the output file names.
name = 'java_v12'

# These values are tuned for Bali detections and are still a work in progress
# pred_threshold: scaled predictions below this value are zeroed before blob detection
pred_threshold = 0.75
# min_sigma / max_sigma / area_threshold are forwarded to skimage's blob_doh
min_sigma=3.5
max_sigma=100
area_threshold=0.0025

# Run detection over the whole heatmap in 5000-pixel windows; also writes the CSV and GeoJSON outputs
candidate_site_df = detect_candidates(source, name, pred_threshold=pred_threshold, min_sigma=min_sigma, max_sigma=max_sigma, area_threshold=area_threshold, window_size=5000)

In [None]:
# Kepler.gl map configuration: a single point layer drawn as red outlined
# circles on top of a satellite base map.
kepler_config = {
  "version": "v1",
  "config": {
    "visState": {
      "layers": [
        {
          "id": "iik903a",
          "type": "point",
          "config": {
            # Must match the dataset name passed to add_data() when the map is built
            "dataId": "Candidate Sites",
            "label": "Point",
            "color": [
              218,
              0,
              0
            ],
            # Column names of the candidate DataFrame ('lat' and 'lon')
            "columns": {
              "lat": "lat",
              "lng": "lon",
              "altitude": None
            },
            "isVisible": True,
            # Unfilled circles with a visible outline
            "visConfig": {
              "radius": 20,
              "fixedRadius": False,
              "opacity": 0.99,
              "outline": True,
              "thickness": 3,
              "strokeColor": [
                210,
                0,
                0
              ],
              "filled": False
            },
          },
        }
      ],
    },
    "mapStyle": {
      "styleType": "satellite",
    }
  }
}

In [None]:
# Plot blob locations on a satellite base image
from keplergl import KeplerGl

# Build the map with the layer styling defined in kepler_config
candidate_map = KeplerGl(height=800, config=kepler_config)
# The dataset name must match the "dataId" used in kepler_config
candidate_map.add_data(data=candidate_site_df, name='Candidate Sites')
# Bare expression so the notebook renders the map as the cell output
candidate_map