# Run Patch Classifier

In [None]:
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from tensorflow import keras

sys.path.append('../')
from scripts.get_s2_data_ee import get_history, band_descriptions
from scripts.viz_tools import *

%load_ext autoreload
%autoreload 2

## Load Model

In [None]:
# Load the trained patch classifier and read its expected input width
# (index 1 of input_shape; patches are later cropped square to this size).
model_path = '../models/v1.1.0_200_4-23-21_patch_classifier_45px_patch.h5'
model = keras.models.load_model(model_path)
input_width = model.input_shape[1]

# Patch width in degrees: pixels / 100 gives km (presumably 10 m pixels —
# confirm for the bands used), and ~111.32 km spans one degree at the
# equator. Rounded to 4 decimal places.
km_width = input_width / 100
rect_width = np.round(km_width / 111.32, 4)

## Download Candidate Site Patches

In [None]:
# Read the candidate sites produced by the detect_candidates step
filename = 'lombok_v1.1_2020_candidates_pred-thresh_0.85_min-sigma_5_area-thresh_0.0025'
candidate_dir = '../data/model_outputs/candidate_sites/'
candidate_sites = pd.read_csv(f'{candidate_dir}{filename}.csv')

# get_history receives one [lon, lat] pair and one name per site
candidate_coords = [[lon, lat] for lon, lat in zip(candidate_sites['lon'], candidate_sites['lat'])]
candidate_names = candidate_sites['name']

# Request `num_months` month(s) of imagery starting at `start_date`, with cloud masking on
start_date = '2020-06-01'
num_months = 1
patch_history = get_history(
    candidate_coords,
    candidate_names,
    rect_width,
    num_months=num_months,
    start_date=start_date,
    cloud_mask=True,
)


In [None]:
# Select the sites worth classifying: a site is kept when at least one of its
# dates has a cloud fraction below `cloud_threshold`. A pixel counts as cloudy
# when its value is negative — presumably how get_history(..., cloud_mask=True)
# marks masked pixels; confirm against scripts.get_s2_data_ee.
dates = list(patch_history.keys())
# All dates are assumed to hold the same sites, so the site list comes from the
# first date; an empty history yields no sites instead of an IndexError.
sites = list(patch_history[dates[0]].keys()) if dates else []

cloud_threshold = 0.05
cloud_free_sites = []
for site in sites:
    # One value per date: fraction of negative pixels per band, averaged over bands
    site_cloudiness = [
        np.mean([np.mean(patch_history[date][site][band] < 0) for band in band_descriptions])
        for date in dates
    ]
    if np.min(site_cloudiness) < cloud_threshold:
        cloud_free_sites.append(site)

# Guard the ratio against an empty site list (previously a ZeroDivisionError)
if sites:
    print(f"{len(cloud_free_sites)} / {len(sites)} sites "
          f"({len(cloud_free_sites) / len(sites):.0%}) have less than {cloud_threshold:.0%} cloud cover")
else:
    print("No sites found in patch_history")

## Run Network and Visualize Predictions

In [None]:
# Build one image per site, then crop each to the model input size, keeping the
# top-left corner. NOTE(review): `preds` is later paired index-by-index with
# `cloud_free_sites`; that presumes create_img_stack_mean applies the same
# cloud filter and site order as the previous cell — confirm.
patches = [
    site_patch[:input_width, :input_width]
    for site_patch in np.array(create_img_stack_mean(patch_history, cloud_threshold))
]

# RGB previews are rendered from the un-normalized crops
rgb_img = create_rgb(patches)

# Score the normalized crops; column 1 of the output is kept as the prediction
# (presumably the positive-class score — confirm against training labels)
patches = normalize(patches)
class_scores = model.predict(patches)
preds = class_scores[:, 1]

In [None]:
# Lay the RGB previews out on a near-square grid; each title shows the last
# three characters of the site name and its patch prediction.
num_img = int(np.ceil(np.sqrt(len(preds))))  # grid side length
plt.figure(figsize=(num_img, num_img), dpi=250, facecolor=(1, 1, 1))
for position, preview in enumerate(rgb_img, start=1):
    plt.subplot(num_img, num_img, position)
    plt.imshow(stretch_histogram(preview))
    i = position - 1
    plt.title(f"{cloud_free_sites[i][-3:]}: {preds[i]:.2f}", size=6)
    plt.axis('off')
plt.tight_layout()
plt.show()

## Write Candidate Sites

In [None]:
# Attach each site's patch-classifier score to `candidate_sites`, then write out
# the sites that pass `threshold` plus the sites that have no prediction at all.
threshold = 0.1

# Sites absent from `cloud_free_sites` get the sentinel -1 ("no data"). One
# vectorized lookup replaces a boolean scan per site, and the column is float64
# from the start, so float predictions never go into an int column (pandas'
# deprecated incompatible-dtype upcast).
site_predictions = dict(zip(cloud_free_sites, preds))
candidate_sites['patch_prediction'] = (
    candidate_sites['name'].map(site_predictions).astype(float).fillna(-1.0)
)

# Keep sites above the threshold, or with -1 (no data) so they can still be reviewed
above_threshold = candidate_sites['patch_prediction'] > threshold
no_data = candidate_sites['patch_prediction'] == -1
filtered_candidate_sites = candidate_sites[above_threshold | no_data]

# Only sites that actually received a prediction are compared against len(preds);
# no-data sites are reported separately.
print(f"{int(above_threshold.sum())} / {len(preds)} sites found above the threshold of {threshold}"
      f" ({int(no_data.sum())} sites without a prediction also kept)")
# NOTE(review): there is no separator between `filename` and "v1.1.0" in the output
# name (e.g. "...area-thresh_0.0025v1.1.0_..."); left unchanged in case downstream
# code reads this exact path.
filtered_candidate_sites.to_csv(f'../data/model_outputs/candidate_sites/{filename}v1.1.0_200_patch_classifier_thresh_{threshold}.csv', index=False)

In [None]:
# Display the filtered candidate sites (the frame written to CSV in the previous cell)
filtered_candidate_sites

In [None]:
# Kepler.gl map configuration: one point layer fed by the 'Candidate Sites'
# dataset (lat/lon columns), unfilled markers whose outlines are coloured by
# `patch_prediction` on a quantile scale, drawn over a satellite basemap.
def _global_warming_palette():
    """Return a fresh copy of the 'Global Warming' sequential colour range."""
    return {
        "name": "Global Warming",
        "type": "sequential",
        "category": "Uber",
        "colors": ["#5A1846", "#900C3F", "#C70039", "#E3611C", "#F1920E", "#FFC300"],
    }


kepler_config = {
    "version": "v1",
    "config": {
        "visState": {
            "filters": [],
            "layers": [
                {
                    "id": "iik903a",
                    "type": "point",
                    "config": {
                        "dataId": "Candidate Sites",
                        "label": "Point",
                        "color": [218, 0, 0],
                        "columns": {"lat": "lat", "lng": "lon", "altitude": None},
                        "isVisible": True,
                        "visConfig": {
                            "radius": 30,
                            "fixedRadius": False,
                            "opacity": 0.82,
                            "outline": True,
                            "thickness": 3,
                            "strokeColor": None,
                            "colorRange": _global_warming_palette(),
                            "strokeColorRange": _global_warming_palette(),
                            "radiusRange": [0, 20],
                            "filled": False,
                        },
                        "hidden": False,
                        "textLabel": [
                            {
                                "field": None,
                                "color": [255, 255, 255],
                                "size": 10,
                                "offset": [0, 0],
                                "anchor": "start",
                                "alignment": "center",
                            }
                        ],
                    },
                    "visualChannels": {
                        "colorField": None,
                        "colorScale": "quantile",
                        "strokeColorField": {"name": "patch_prediction", "type": "real"},
                        "strokeColorScale": "quantile",
                        "sizeField": None,
                        "sizeScale": "linear",
                    },
                }
            ],
        },
        "mapStyle": {"styleType": "satellite"},
    },
}

In [None]:
# Show the filtered candidate sites on an interactive Kepler.gl map using the
# satellite-basemap configuration in `kepler_config`.
from keplergl import KeplerGl

candidate_map = KeplerGl(config=kepler_config, height=800)
# The dataset name matches the layer "dataId" in kepler_config ('Candidate Sites').
# To map every candidate instead of only the filtered ones, pass `candidate_sites`.
candidate_map.add_data(name='Candidate Sites', data=filtered_candidate_sites)
candidate_map