# Run Patch Classifier

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
from keplergl import KeplerGl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from tensorflow import keras
from tqdm.notebook import tqdm

from scripts.viz_tools import normalize, plot_image_grid
from scripts.dl_utils import download_patch, rect_from_point, pad_patch

In [None]:
def patch_classifier_predict(polygon, model, start_date, end_date):
    """
    Run a patch classifier on the polygon of interest.

    Downloads the image stack for the polygon between start_date and end_date,
    pads every image to the model's input width, and scores each patch whose
    masked fraction is below the module-level `cloud_threshold` (which must be
    defined before this function is called).

    Args:
        polygon: Area of interest, passed through to `download_patch`.
        model: Model exposing `input_shape` and `predict`. `input_shape[1]` is
            used as the patch width, and element [0, 1] of each prediction is
            kept as the patch score.
        start_date: Start of the period, passed through to `download_patch`.
        end_date: End of the period, passed through to `download_patch`.

    Returns:
        A `(preds, patches)` tuple of equal-length lists: one score and one
        padded patch for every image that passed the cloud filter.

    Raises:
        AssertionError: If no image passed the cloud filter.
    """
    input_width = model.input_shape[1]

    img_stack = download_patch(polygon, start_date, end_date)
    img_stack = [pad_patch(img, input_width) for img in img_stack]

    preds = []
    patches = []
    for patch in img_stack:
        # Keep only patches whose fraction of masked values is under the threshold
        if np.sum(patch.mask) / patch.size < cloud_threshold:
            # NOTE(review): every image was already padded above; this second
            # call looks redundant -- confirm pad_patch is idempotent before removing.
            patch = pad_patch(patch, input_width)
            patches.append(patch)
            # The model expects a batch dimension; keep index 1 of the single output row
            batch = np.expand_dims(normalize(patch), axis=0)
            preds.append(model.predict(batch)[0, 1])
    # The original check was inverted (`== 0`) and failed on every successful run
    assert len(preds) > 0, "No cloud free patches extracted. Try expanding your data time period."
    return preds, patches

## Load Model

In [None]:
# Load the trained patch classifier and read the patch width it expects
model_path = '../models/v1.1.0_200_4-23-21_patch_classifier_45px_patch.h5'
model = keras.models.load_model(model_path)
input_width = model.input_shape[1]

# Get model input size in degrees
# NOTE(review): presumably 100 pixels per km and ~111.1 km per degree -- confirm
width_scaled = input_width / 100
rect_width = np.round(width_scaled / 111.1, 4)

## Download Candidate Site Patches

In [None]:
# Read the candidate sites written by detect_candidates
filename = 'Nusa_Tenggara_v1.1.8_2019-2020_candidates_pred-thresh_0.645_min-sigma_3.5_area-thresh_0.0025'
candidates_csv = f'../data/model_outputs/candidate_sites/{filename}.csv'

candidate_sites = pd.read_csv(candidates_csv)
candidate_names = candidate_sites['name']

# Coordinates are stored as [lon, lat] pairs; each one becomes a rectangle of rect_width
candidate_coords = [
    [lon, lat]
    for lon, lat in zip(candidate_sites['lon'], candidate_sites['lat'])
]
candidate_polygons = [rect_from_point(coord, rect_width) for coord in candidate_coords]

In [None]:
# Time window passed to the patch download, and the directory for the outputs
start_date = '2020-01-01'
end_date = '2020-06-01'
output_dir = '../data/model_outputs/candidate_sites/'
# makedirs creates missing parent directories too, and exist_ok avoids the
# check-then-create race of the previous os.path.exists / os.mkdir pair
os.makedirs(output_dir, exist_ok=True)

## Run Network and Visualize Predictions

In [None]:
# Read as a global by patch_classifier_predict when filtering patches
cloud_threshold = 0.1

# Map each site name to its per-patch predictions and the patches themselves
patch_predictions = {}
site_progress = tqdm(zip(candidate_polygons, candidate_names), total=len(candidate_polygons))
for polygon, name in site_progress:
    preds, patches = patch_classifier_predict(polygon, model, start_date, end_date)
    patch_predictions[name] = dict(preds=preds, patches=patches)

In [None]:
# Plot one median composite per site, labelled with the site suffix and mean prediction
file_path = os.path.join(output_dir, filename)

images = []
labels = []
for site, result in patch_predictions.items():
    composite = np.ma.median(result['patches'], axis=0)
    site_suffix = site.split('_')[-1]
    mean_pred = np.mean(result['preds'])
    images.append(composite)
    labels.append(f"{site_suffix}: {mean_pred:.2f}")
plot_image_grid(images, labels=labels, file_path=file_path)

## Write Candidate Sites

In [None]:
# Per-site mean and variance of the patch predictions, in patch_predictions order
mean_preds = [np.mean(result['preds']) for result in patch_predictions.values()]
var_preds = [np.var(result['preds']) for result in patch_predictions.values()]

# Column name -> statistic applied to each site's list of patch predictions
summary_stats = {
    'mean': np.mean,
    'median': np.median,
    'min': np.min,
    'max': np.max,
    'variance': np.var,
}
for column, stat in summary_stats.items():
    stat_by_site = {site: stat(result['preds']) for site, result in patch_predictions.items()}
    # Align on the site name rather than on dict order, so duplicate or missing
    # names cannot shift values onto the wrong row or break the assignment
    candidate_sites[column] = candidate_sites['name'].map(stat_by_site)

In [None]:
threshold = 0.25
# Write only the sites whose mean prediction is greater than the threshold
filtered_candidate_sites = candidate_sites.query(f'mean > {threshold}')
# Report against the total number of candidate sites; `preds` only holds the
# patch predictions of the last site processed
print(f"{len(filtered_candidate_sites)} / {len(candidate_sites)} sites found above the threshold of {threshold}")
filtered_candidate_sites.to_csv(f'../data/model_outputs/candidate_sites/{filename}_patch_clf_thresh_{threshold}.csv', index=False)

In [None]:
# Show the filtered candidate sites as the cell output
filtered_candidate_sites

In [None]:
def _global_warming_range():
    """Return a fresh copy of the "Global Warming" sequential color range."""
    return {
        "name": "Global Warming",
        "type": "sequential",
        "category": "Uber",
        "colors": [
            "#5A1846",
            "#900C3F",
            "#C70039",
            "#E3611C",
            "#F1920E",
            "#FFC300",
        ],
    }


# Appearance of the point markers: outlined, unfilled circles
_point_vis_config = {
    "radius": 30,
    "fixedRadius": False,
    "opacity": 0.8,
    "outline": True,
    "thickness": 3,
    "strokeColor": None,
    "colorRange": _global_warming_range(),
    "strokeColorRange": _global_warming_range(),
    "radiusRange": [0, 50],
    "filled": False,
}

# Single text label entry with no field bound to it
_point_text_label = {
    "field": None,
    "color": [255, 255, 255],
    "size": 18,
    "offset": [0, 0],
    "anchor": "start",
    "alignment": "center",
}

# Point layer reading the "lat"/"lon" columns of the "Candidate Sites" dataset,
# with the stroke color driven by the "mean" column
_point_layer = {
    "id": "anqfulm",
    "type": "point",
    "config": {
        "dataId": "Candidate Sites",
        "label": "Point",
        "color": [221, 178, 124],
        "columns": {"lat": "lat", "lng": "lon", "altitude": None},
        "isVisible": True,
        "visConfig": _point_vis_config,
        "hidden": False,
        "textLabel": [_point_text_label],
    },
    "visualChannels": {
        "colorField": None,
        "colorScale": "quantile",
        "strokeColorField": {"name": "mean", "type": "real"},
        "strokeColorScale": "quantile",
        "sizeField": None,
        "sizeScale": "linear",
    },
}

# Kepler.gl configuration: one point layer on a satellite base map
kepler_config = {
    "version": "v1",
    "config": {
        "visState": {
            "filters": [],
            "layers": [_point_layer],
        },
        "mapStyle": {"styleType": "satellite"},
    },
}

In [None]:
# Median composite for every site that passed the threshold, labelled by site suffix
positive_names = filtered_candidate_sites['name']
positive_patches = [
    np.ma.median(patch_predictions[site]['patches'], axis=0)
    for site in positive_names
]
positive_labels = [name.split('_')[-1] for name in positive_names]
plot_image_grid(positive_patches, labels=positive_labels)

In [None]:
# Plot blob locations on a satellite base image
candidate_map = KeplerGl(height=800, config=kepler_config)
candidate_map.add_data(data=filtered_candidate_sites, name='Candidate Sites')
candidate_map