# Run Blob Detection on Pixel Heatmap to Identify Candidate Sites
Note: This currently works only on inputs whose CRS is EPSG:4326. It may need to be generalized to other coordinate reference systems in the future.

In [None]:
from keplergl import KeplerGl
import matplotlib.pyplot as plt
import numpy as np
import geopandas as gpd
import rasterio as rs
from rasterio.windows import Window
from skimage.feature import blob_doh
from skimage.feature.peak import peak_local_max
from sklearn.neighbors import KDTree

In [None]:
def merge_similar_sites(candidate_sites, search_radius=0.01, name="site"):
    """
    Merge candidate points that lie close together into single representative sites.

    Each point is first replaced by the mean position of every original point within
    search_radius of it. The KD tree of the original points is then queried repeatedly
    around the current means; each distinct set of matched points forms one cluster, and
    the cluster means become the next estimate. The algorithm stops once the number of
    distinct clusters is the same as in the previous round.

    Inputs:
      candidate_sites: sequence of [lon, lat] pairs
      search_radius: merge radius, given in degrees
      name: prefix used for the generated site names ("<name>_1", "<name>_2", ...)

    Returns:
      GeoDataFrame with 'lon', 'lat', 'geometry' and 'name' columns, one row per merged site.
      When at least one site exists, a scatter plot of the original and merged points is
      shown as a side effect.
    """
    coords = np.asarray(candidate_sites, dtype=float)

    if coords.size == 0:
        # A KD tree cannot be built from zero points; continue with an empty result instead.
        print(0, "unique sites detected")
        mean_coords = np.empty((0, 2))
    else:
        # Create a KD tree for efficient lookup of points within a radius
        tree = KDTree(coords, leaf_size=2)

        # Seed the search with the neighbourhood mean around every original point
        mean_coords = np.array([
            coords[idx].mean(axis=0)
            for idx in tree.query_radius(coords, search_radius)
        ])

        num_coords = len(mean_coords)
        while True:
            search = tree.query_radius(mean_coords, search_radius)
            # query_radius does not guarantee the order of the returned indices, so sort each
            # match before de-duplicating; otherwise the same cluster could be counted twice.
            # Sorting the clusters themselves keeps the site numbering reproducible.
            uniques = sorted({tuple(sorted(idx.tolist())) for idx in search})
            mean_coords = np.array([coords[list(idx)].mean(axis=0) for idx in uniques])
            if len(mean_coords) == num_coords:
                print(len(mean_coords), "unique sites detected")
                break
            num_coords = len(mean_coords)

    unique_sites = gpd.GeoDataFrame(
        mean_coords,
        columns=['lon', 'lat'],
        geometry=gpd.points_from_xy(*mean_coords.T),
    )
    unique_sites['name'] = [f"{name}_{i+1}" for i in unique_sites.index]

    if len(unique_sites):
        plt.figure(figsize=(10,10), dpi=150, facecolor=(1,1,1))
        plt.scatter(coords[:,0], coords[:,1], s=5, label='Original')
        plt.scatter(mean_coords[:,0], mean_coords[:,1], s=3, c='r', label='Unique')
        plt.axis('equal')
        plt.legend()
        plt.show()

    return unique_sites

def detect_blobs(source, name, pred_threshold=0.75, min_sigma=3.5, max_sigma=100, area_threshold=0.0025, window_size=5000, save=True, search_radius=0.01):
    """
    Identify candidates using blob detection on the heatmap.

    Inputs:
      source: rasterio geotiff object
      name: file name; used for the output file names and as the prefix of the site names
      pred_threshold: predictions below this 0-1 threshold (after dividing by the raster
        maximum) are set to zero before blob detection
      min_sigma, max_sigma: passed to blob_doh. Keep min_sigma low to detect smaller blobs
      area_threshold: passed to blob_doh as `threshold`; establishes a lower bound on
        candidate blob size. Reduce to detect smaller blobs
      window_size: the image is processed in square windows of this many pixels
      save: boolean to write outputs (.csv and .geojson) to disk
      search_radius: radius in degrees within which detected blobs are merged into one site

    Returns:
      GeoDataFrame of merged candidate sites (see merge_similar_sites).
    """
    candidate_sites = []
    # NOTE: this reads the whole band once to find the normalisation constant.
    max_val = source.read(1).max()
    # An all-zero heatmap would otherwise divide by zero and fill every window with NaN.
    scale = max_val if max_val != 0 else 1

    # Ceiling division so the reported totals are right when the size is an exact multiple
    n_rows = int(np.ceil(source.shape[0] / window_size))
    n_cols = int(np.ceil(source.shape[1] / window_size))

    for x in range(0, source.shape[0], window_size):
        for y in range(0, source.shape[1], window_size):
            print(f"Processing row {(x // window_size) + 1} of {n_rows}, column {(y // window_size) + 1} of {n_cols}")
            # Set min and max to analyze a subset of the image
            window = Window.from_slices((x,x + window_size), (y, y + window_size))
            window_median = (source.read(1, window=window) / scale).astype('float')
            # mask predictions below a threshold
            window_median[window_median < pred_threshold] = 0

            blobs = blob_doh(window_median, min_sigma=min_sigma, max_sigma=max_sigma, threshold=area_threshold)
            print(len(blobs), "candidates detected in window")

            transform = source.window_transform(window)
            for candidate in blobs:
                # blob_doh rows are (row, col, sigma); the affine transform expects (col, row)
                lon, lat = transform * (candidate[1], candidate[0])
                # Size (candidate[2]) doesn't mean anything at the moment. Should look into this later
                candidate_sites.append([lon, lat])

    print(len(candidate_sites), "candidate sites detected in total")

    candidate_gdf = merge_similar_sites(candidate_sites, search_radius=search_radius, name=name)

    if save:
        if candidate_gdf.empty:
            print("No candidate sites to save")
        else:
            file_path = f"../data/model_outputs/candidate_sites/{name}_blobs_thresh_{pred_threshold}_min-sigma_{min_sigma}_area-thresh_{area_threshold}"
            candidate_gdf.loc[:, ['lon', 'lat', 'name']].to_csv(file_path + '.csv', index=False)
            candidate_gdf.to_file(file_path + '.geojson', driver='GeoJSON')

    return candidate_gdf

def detect_peaks(source, name, threshold_abs=0.85, min_distance=100, window_size=5000, save=True):
    """
    Identify candidates using heatmap peak detection.
    Inputs:
      source: rasterio geotiff object
      name: file name; used for the output file names and as the prefix of the site names
      threshold_abs: threshold for minimum prediction value
      min_distance: candidates within this distance will be merged by default. Distance in pixel space
      window_size: chunk the image into windows to reduce memory load
      save: boolean to write outputs to disk
    Returns:
      GeoDataFrame with 'lon', 'lat', 'geometry' and 'name' columns, one row per peak.
      The frame is empty (and nothing is written) when no peak is found.
    """
    candidate_peaks = []
    for x in range(0, source.shape[0], window_size):
        for y in range(0, source.shape[1], window_size):
            window = Window.from_slices((x,x + window_size), (y, y + window_size))
            transform = source.window_transform(window)
            subset = source.read(1, window=window)
            # Peaks are found per window, so min_distance is not enforced across window borders
            peaks = peak_local_max(subset, threshold_abs=threshold_abs, min_distance=min_distance)
            # peak_local_max rows are (row, col); the affine transform expects (col, row)
            candidate_peaks.extend(list(transform * (col, row)) for row, col in peaks)
    print(len(candidate_peaks), "peaks detected")

    # Keep a (0, 2) shape when nothing was found so the column unpacking below still works
    if candidate_peaks:
        candidate_peaks = np.array(candidate_peaks, dtype=float)
    else:
        candidate_peaks = np.empty((0, 2))

    candidate_gdf = gpd.GeoDataFrame(candidate_peaks, columns=['lon', 'lat'],
                                     geometry=gpd.points_from_xy(*candidate_peaks.T))
    candidate_gdf['name'] = [f"{name}_{i+1}" for i in candidate_gdf.index]

    if save:
        if candidate_gdf.empty:
            print("No candidate peaks to save")
        else:
            file_path = f"../data/model_outputs/candidate_sites/{name}_peaks_thresh_{threshold_abs}_min_dist_{min_distance}"
            candidate_gdf.loc[:, ['lon', 'lat', 'name']].to_csv(file_path + '.csv', index=False)
            candidate_gdf.to_file(file_path + '.geojson', driver='GeoJSON')

    return candidate_gdf

In [None]:
def _point_layer(layer_id, data_id, fill_rgb, stroke_rgb):
    """Build one kepler.gl point-layer entry that reads coordinates from the 'lat'/'lon' columns."""
    return {
        "id": layer_id,
        "type": "point",
        "config": {
            "dataId": data_id,
            "label": "Point",
            "color": list(fill_rgb),
            "columns": {"lat": "lat", "lng": "lon", "altitude": None},
            "isVisible": True,
            "visConfig": {
                "radius": 20,
                "fixedRadius": False,
                "opacity": 0.99,
                "outline": True,
                "thickness": 3,
                "strokeColor": list(stroke_rgb),
                "filled": False,
            },
        },
    }


# Map configuration: two outlined point layers (one per detection method) on a satellite
# basemap. Each layer's dataId must match the name given when its data is added to the map.
kepler_config = {
    "version": "v1",
    "config": {
        "visState": {
            "layers": [
                _point_layer("iik903a", "Candidate Peaks", (218, 0, 0), (210, 0, 0)),
                _point_layer("kyoc7uj", "Candidate Blobs", (246, 218, 0), (246, 218, 0)),
            ],
        },
        "mapStyle": {"styleType": "satellite"},
    },
}

In [None]:
# Create the kepler.gl map widget using the layer styling defined in kepler_config.
candidate_map = KeplerGl(height=800, config=kepler_config)
# Bare expression as the last statement of the cell so the notebook displays the map.
candidate_map

In [None]:
# Stem of the heatmap file to analyse; it is also passed to the detectors below as `name`.
name = 'Bali_v1.1.5_2019-2020'
# Open the heatmap raster. The handle is left open because the following cells read from it.
source = rs.open(f'../data/model_outputs/heatmaps/{name}.tif')

In [None]:
# Run peak detection on the heatmap and write the results to disk (save=True).
peaks = detect_peaks(source, name=name, min_distance=100, threshold_abs=0.85, save=True)
# The dataset name must equal the "dataId" of the matching layer in kepler_config.
candidate_map.add_data(data=peaks, name='Candidate Peaks')

In [None]:
# Run blob detection on the heatmap and write the results to disk (save=True).
blobs = detect_blobs(source, name=name, pred_threshold=0.8, save=True)
# The dataset name must equal the "dataId" of the matching layer in kepler_config.
candidate_map.add_data(data=blobs, name='Candidate Blobs')