# Run Patch Classifier

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import descarteslabs as dl
import json
from keplergl import KeplerGl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import sys
import shapely
from tensorflow import keras
from tqdm.notebook import tqdm

sys.path.append('../')
from scripts.get_s2_data_ee import get_history, band_descriptions
from scripts.viz_tools import *
from scripts.dl_utils import download_patch, rect_from_point, flatten_stack

## Load Model

In [None]:
# Restore the trained patch classifier from its saved HDF5 file
model = keras.models.load_model('../models/v1.1.0_200_4-23-21_patch_classifier_45px_patch.h5')

# Width of the network input, read from axis 1 of the model's input shape
input_width = model.input_shape[1]

# Get model input size in degrees
# NOTE(review): presumably 100 pixels per km (10 m pixels) and ~111.32 km per
# degree -- confirm against the imagery resolution actually used.
pixels_per_km = 100
km_per_degree = 111.32
rect_width = np.round(input_width / pixels_per_km / km_per_degree, 4)

## Download Candidate Site Patches

In [None]:
# Band names for the downloaded patches. Later in this notebook only the
# length of this list is used, to trim patches to this many channels.
sentinel_bands = [
    'coastal-aerosol',
    'blue',
    'green',
    'red',
    'red-edge',
    'red-edge-2',
    'red-edge-3',
    'nir',
    'red-edge-4',
    'water-vapor',
    'swir1',
    'swir2',
]

In [None]:
# Load coordinates from the detect_candidates output
filename = 'Bali_v1.1.5_2019-2020_candidates_pred-thresh_0.8_min-sigma_5_area-thresh_0.0025'

candidates_csv = '../data/model_outputs/candidate_sites/' + filename + '.csv'
candidate_sites = pd.read_csv(candidates_csv)

# One [lon, lat] pair per row of the table (longitude first)
candidate_coords = [
    [lon, lat]
    for lat, lon in zip(candidate_sites['lat'], candidate_sites['lon'])
]
candidate_names = candidate_sites['name']

# Rectangle of side rect_width built around each coordinate by rect_from_point
candidate_polygons = [rect_from_point(lon_lat, rect_width) for lon_lat in candidate_coords]

In [None]:
# Date window handed to download_patch for every candidate
START_DATE = '2020-04-01'
END_DATE = '2020-05-31'

# NOTE(review): OUTPUT_DIR is created here but not otherwise referenced in this cell.
OUTPUT_DIR = '../data/training_data'
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it also creates
# missing parent directories and has no gap between the check and the creation.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# One download per candidate rectangle, in the same order as candidate_polygons
img_stack = [
    download_patch(polygon, START_DATE, END_DATE)
    for polygon in tqdm(candidate_polygons)
]

## Run Network and Visualize Predictions

In [None]:
# Turn each downloaded stack into one fixed-size patch, then run the classifier.
patches = []
num_bands = len(sentinel_bands)
for patch in img_stack:
    # Mask-aware median over the leading axis of the stack
    patch = np.ma.median(patch, axis=0)
    h, w, _ = patch.shape
    if h < input_width or w < input_width:
        # Pad only the two spatial axes. A scalar pad width would also pad the
        # channel axis, placing reflected bands in front of the real ones so
        # that the channel slice below keeps the wrong bands.
        pad = input_width - min(h, w)
        patch = np.pad(patch, ((pad, pad), (pad, pad), (0, 0)), mode='reflect')
    # got a site with 14 bands? -> trim to the model's width and the known band count
    patch = patch[:input_width, :input_width, :num_bands]
    patches.append(patch)

# RGB previews are built from the un-normalized patches
rgb_img = create_rgb(patches)
patches = normalize(patches)
# Keep column 1 of the model output as the per-patch score
preds = model.predict(patches)[:, 1]

In [None]:
# Smallest square grid that has a cell for every prediction
num_img = int(np.ceil(np.sqrt(len(preds))))
grid_inches = (num_img, num_img)
white = (1, 1, 1)

plt.figure(figsize=grid_inches, dpi=250, facecolor=white)
for index, img in enumerate(rgb_img):
    # Subplot slots are 1-based
    plt.subplot(num_img, num_img, index + 1)
    clipped = np.clip(img, 0, 1)
    plt.imshow(clipped)
    # Title: last three characters of the site name plus its score
    caption = f"{candidate_names[index][-3:]}: {preds[index]:.2f}"
    plt.title(caption, size=6)
    plt.axis('off')
plt.tight_layout()
plt.show()

## Write Candidate Sites

In [None]:
threshold = 0.1

# Start every site at the sentinel -1.0 ("no data"). The column is created as
# float so that writing float predictions into it below is not a dtype upcast
# (silent int -> float upcasting on setitem is deprecated in pandas >= 2.1).
candidate_sites['patch_prediction'] = -1.0
for site, pred in zip(candidate_names, preds):
    candidate_sites.loc[candidate_sites['name'] == site, 'patch_prediction'] = pred

# Write only sites with predictions greater than a threshold, or with a value of -1 (no data)
filtered_candidate_sites = candidate_sites.query(f'patch_prediction > {threshold} or patch_prediction == -1')
print(f"{len(filtered_candidate_sites)} / {len(preds)} sites found above the threshold of {threshold}")
# NOTE(review): there is no separator between {filename} and the "v1.1.0" suffix
# in the output name. Left unchanged so existing output paths stay valid --
# confirm whether this is intended.
filtered_candidate_sites.to_csv(f'../data/model_outputs/candidate_sites/{filename}v1.1.0_200_patch_classifier_thresh_{threshold}.csv', index=False)

In [None]:
def _global_warming_range():
    """Return a fresh copy of the 'Global Warming' sequential color range."""
    return {
        "name": "Global Warming",
        "type": "sequential",
        "category": "Uber",
        "colors": ["#5A1846", "#900C3F", "#C70039", "#E3611C", "#F1920E", "#FFC300"],
    }


# Map configuration: a single point layer bound to the "Candidate Sites"
# dataset, with the stroke color driven by the patch_prediction column,
# on a satellite map style.
kepler_config = {
    "version": "v1",
    "config": {
        "visState": {
            "filters": [],
            "layers": [
                {
                    "id": "iik903a",
                    "type": "point",
                    "config": {
                        "dataId": "Candidate Sites",
                        "label": "Point",
                        "color": [218, 0, 0],
                        "columns": {"lat": "lat", "lng": "lon", "altitude": None},
                        "isVisible": True,
                        "visConfig": {
                            "radius": 30,
                            "fixedRadius": False,
                            "opacity": 0.82,
                            "outline": True,
                            "thickness": 3,
                            "strokeColor": None,
                            # Separate dict objects so the two ranges stay independent
                            "colorRange": _global_warming_range(),
                            "strokeColorRange": _global_warming_range(),
                            "radiusRange": [0, 20],
                            "filled": False,
                        },
                        "hidden": False,
                        "textLabel": [
                            {
                                "field": None,
                                "color": [255, 255, 255],
                                "size": 10,
                                "offset": [0, 0],
                                "anchor": "start",
                                "alignment": "center",
                            }
                        ],
                    },
                    "visualChannels": {
                        "colorField": None,
                        "colorScale": "quantile",
                        "strokeColorField": {"name": "patch_prediction", "type": "real"},
                        "strokeColorScale": "quantile",
                        "sizeField": None,
                        "sizeScale": "linear",
                    },
                }
            ],
        },
        "mapStyle": {"styleType": "satellite"},
    },
}

In [None]:
# Keep only the previews whose score is above the threshold
above_threshold = preds > threshold
positive_patches = np.array(rgb_img)[above_threshold]

# Smallest square grid that has a cell for every kept preview
num_img = int(np.ceil(np.sqrt(len(positive_patches))))
plt.figure(figsize=(num_img, num_img), dpi=250, facecolor=(1, 1, 1))
for index, img in enumerate(positive_patches):
    plt.subplot(num_img, num_img, index + 1)
    plt.imshow(np.clip(img, 0, 1))
    # Label text is taken from the row at the same position in filtered_candidate_sites
    site_row = filtered_candidate_sites.iloc[index]
    site_id = site_row['name'].split('_')[-1]
    plt.title(f"Site {site_id}: {site_row['patch_prediction']:.2f}", size=6)
    plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Plot blob locations on a satellite base image
# The dataset name passed to add_data matches the layer's "dataId" in
# kepler_config ("Candidate Sites").
candidate_map = KeplerGl(height=800, config=kepler_config)
candidate_map.add_data(data=filtered_candidate_sites, name='Candidate Sites')
# Alternative: show every candidate instead of only the filtered ones
#candidate_map.add_data(data=candidate_sites, name='Candidate Sites')
# Bare expression as the last line of the cell so the notebook displays the map
candidate_map

In [None]:
# Display the filtered candidate table as the cell output
filtered_candidate_sites