# Run Temporal Patch Classifier
This notebook is currently set up to run a temporal patch classifier that takes inputs of shape `(batch_size, h, w, 24)`: two 12-channel mosaics of the same site, stacked along the channel axis.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
from keplergl import KeplerGl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import geopandas as gpd
from tensorflow import keras
from tqdm.notebook import tqdm

from scripts.viz_tools import normalize, plot_image_grid
from scripts.dl_utils import download_patch, rect_from_point, pad_patch, unit_norm
from scripts import dl_utils

In [None]:
# Run configuration read by patch_classifier_predict.
# START_DATE / END_DATE bound the imagery requested from dl_utils.download_mosaics
# and are given as 'YYYY-MM-DD' strings.
START_DATE = '2019-01-01'
END_DATE = '2020-02-01'
# Forwarded as `method=` to dl_utils.download_mosaics
METHOD = 'min'
MOSAIC_PERIOD = 3  # the period over which to mosaic image data in months
# For spectrogram analysis, the time from the start of one mosaic to the start of the next,
# in number of mosaic periods. Passed to dl_utils.pair when pairing mosaics.
SPECTROGRAM_INTERVAL = 2

In [None]:
def patch_classifier_predict(polygon, model):
    """
    Run a patch classifier on the polygon of interest.

    Mosaics for `polygon` are downloaded using the module-level START_DATE,
    END_DATE, MOSAIC_PERIOD and METHOD settings, paired with
    `dl_utils.pair(mosaics, SPECTROGRAM_INTERVAL)`, and pairs rejected by
    `dl_utils.masks_match` are dropped. The two members of each remaining pair
    are passed through `pad_patch(..., 28)` and `unit_norm`, then written into
    channels 0-11 and 12-23 of a 28x28x24 model input.

    Args:
        polygon: Area of interest, forwarded to `dl_utils.download_mosaics`.
        model: Object with a `predict` method that accepts a batch of shape
            `(n, 28, 28, 24)` and returns a 2-D array; column 1 of each output
            row is taken as the prediction.

    Returns:
        A `(preds, pairs)` tuple. `preds` is a list with one prediction per
        retained pair and `pairs` is the list of retained pairs, in the same
        order. Both are empty when no pair passes the mask check.
    """
    mosaics, _ = dl_utils.download_mosaics(polygon, START_DATE, END_DATE, MOSAIC_PERIOD, method=METHOD)
    candidate_pairs = dl_utils.pair(mosaics, SPECTROGRAM_INTERVAL)
    pairs = [p for p in candidate_pairs if dl_utils.masks_match(p)]

    if not pairs:
        print("No cloud free patches extracted. Try expanding your data time period.")
        return [], pairs

    # Assemble every input first so the model is invoked once on the whole
    # batch rather than once per pair.
    model_inputs = np.zeros((len(pairs), 28, 28, 24))
    for i, pair in enumerate(pairs):
        model_inputs[i, :, :, :12] = unit_norm(pad_patch(pair[0], 28))
        model_inputs[i, :, :, 12:] = unit_norm(pad_patch(pair[1], 28))

    # Column 1 of each output row is the value the per-pair version extracted
    # with `[0][1]`.
    preds = list(model.predict(model_inputs)[:, 1])

    return preds, pairs

In [None]:
# List the model version number for pixel classifier that generated the candidate points
pixel_classifier_version = '0.0.7'
# List the desired patch classifier version
patch_classifier_version = '0.3'
# Patch classifier outputs are written beneath the candidate-site folder of the
# pixel classifier version that produced the candidates.
output_dir = f'../data/model_outputs/candidate_sites/{pixel_classifier_version}/patches_v{patch_classifier_version}'
# exist_ok=True makes this a no-op when the directory is already present and
# removes the gap between checking for the directory and creating it.
os.makedirs(output_dir, exist_ok=True)

## Load Model

In [None]:
# Keras layer classes handed to the loader, keyed by layer name
custom_layers = {
    'LeakyReLU': keras.layers.LeakyReLU,
    'ELU': keras.layers.ELU,
    'ReLU': keras.layers.ReLU,
}
model_path = f'../models/v{patch_classifier_version}_weak_labels_28x28x24.h5'
model = keras.models.load_model(model_path, custom_objects=custom_layers)

# Element 1 of the model's input shape is used as the patch height
input_height = model.input_shape[1]
# Get model input size in degrees
# NOTE(review): presumably 4 px of margin, 100 px per km and 111.1 km per degree — confirm
padded_height = input_height + 4
rect_height = (padded_height / 100) / 111.1

## Download Candidate Site Patches

In [None]:
# Load coordinates from the detect_candidates output
filename = 'west_timor_v0.0.7_2019-01-01_2021-06-01mosaic-median_blobs_thresh_0.8_min-sigma_5_area-thresh_0.0025'

candidate_dir = f'../data/model_outputs/candidate_sites/{pixel_classifier_version}/'
candidate_sites = gpd.read_file(candidate_dir + filename + '.geojson')

# Build one rect per candidate point with rect_from_point, sized by rect_height
site_rects = []
for point in candidate_sites['geometry']:
    site_rects.append(rect_from_point([point.x, point.y], rect_height))
candidate_sites['rects'] = site_rects

print(len(candidate_sites), "candidate sites loaded")

## Run Network and Export Data

In [None]:
# Classify every candidate site. A failure on one site is printed and recorded
# as empty results so that the remaining sites are still processed.
patch_predictions = {}
site_pairs = zip(candidate_sites['rects'], candidate_sites['name'])
for polygon, name in tqdm(site_pairs, total=len(candidate_sites['rects'])):
    try:
        preds, patches = patch_classifier_predict(polygon, model)

        print(f"{name}, {preds}")
        patch_predictions[name] = dict(preds=preds, patches=patches)
    except KeyboardInterrupt:
        # Stop early; results gathered so far stay in patch_predictions
        print("Keyboard Interrupt!")
        break
    except Exception as error:
        print('Failure', name)
        print(error)
        patch_predictions[name] = dict(preds=[], patches=[])

In [None]:
# Fill candidate sites with -1 values. -1 indicates no data
# These will only be replaced if the patch classifier predicted at that location
# The statistic columns start out as floats so that writing float results into
# them below does not change the column dtype; 'count' stays an integer column.
for stat_column in ('mean', 'min', 'max', 'std'):
    candidate_sites[stat_column] = -1.0
candidate_sites['count'] = -1

# Values are rounded to 6 decimals since kepler.gl can sometimes be confused in
# thinking scientific notation is a string
for site, result in patch_predictions.items():
    preds = result['preds']
    if len(preds) == 0:
        continue
    # Position of the first row whose name matches this site
    index = np.argmax(candidate_sites['name'] == site)
    # float(...) replaces the removed `np.float` alias and yields a plain
    # Python float for each statistic.
    candidate_sites.loc[index, 'mean'] = round(float(np.mean(preds)), 6)
    candidate_sites.loc[index, 'min'] = round(float(np.min(preds)), 6)
    candidate_sites.loc[index, 'max'] = round(float(np.max(preds)), 6)
    candidate_sites.loc[index, 'std'] = round(float(np.std(preds)), 6)
    candidate_sites.loc[index, 'count'] = len(preds)

In [None]:
# Lon, lat are redundant given the geometry. Drop them from the exported file
# (the 'rects' column added while loading the sites is dropped as well)
export_path = os.path.join(output_dir, 'patch_' + filename + '.geojson')
candidate_sites_export = candidate_sites.drop(columns=['lon', 'lat', 'rects'])
candidate_sites_export.to_file(export_path, driver='GeoJSON', index=False)

## Visualize Predictions

In [None]:
# Draw one tile per site that produced predictions: the two patches of the
# site's first pair side by side, titled with the site's mean prediction.
sites_with_preds = [v for v in patch_predictions.values() if len(v['preds']) > 0]

if not sites_with_preds:
    # Nothing to draw; skip plotting rather than request a zero-sized figure
    print("No predictions to visualize.")
else:
    # Side length of the smallest square grid that holds every tile
    num_img = int(np.ceil(np.sqrt(len(sites_with_preds))))
    plt.figure(figsize=(num_img * 2, num_img), dpi=150, facecolor=(1, 1, 1))

    for counter, v in enumerate(sites_with_preds, start=1):
        plt.subplot(num_img, num_img, counter)
        # Two 28-column panels with column 28 left at zero as a separator
        images = np.zeros((28, 57, 12))
        patches = v['patches'][0]
        images[:, :28, :] = unit_norm(pad_patch(patches[0], 28))
        images[:, 29:, :] = unit_norm(pad_patch(patches[1], 28))
        # Channels 3, 2, 1 (in that order) are shown as the image's three
        # colour channels after rescaling and clipping to [0, 1]
        plt.imshow(np.clip(images[:, :, 3:0:-1] / 4 + 0.5, 0, 1))
        plt.axis('off')
        plt.title(f"{np.mean(v['preds']):.2f}")
    plt.tight_layout()
    plt.show()

## Display Map

In [None]:
# Keep only the sites whose mean patch prediction is above the threshold
threshold = 0.25
above_threshold = candidate_sites_export['mean'] > threshold
filtered_candidate_sites = candidate_sites_export[above_threshold]
print(f"{len(filtered_candidate_sites)} / {len(candidate_sites_export)} sites found above the threshold of {threshold}")

In [None]:
# Show the sites that passed the threshold as this cell's output
filtered_candidate_sites

In [None]:
# kepler.gl configuration passed to KeplerGl(config=...).
# It defines a single point layer bound to the dataset id "Candidate Sites",
# drawn as unfilled outlines whose stroke colour is driven by the "mean" field
# on a quantile scale, over a satellite base map.
# NOTE(review): the layer reads its position from "lat"/"lon" columns, but the
# frame handed to the map is derived from candidate_sites_export, from which
# 'lon' and 'lat' were dropped — confirm the layer still binds as intended.
kepler_config={
  "version": "v1",
  "config": {
    "visState": {
      "filters": [],
      "layers": [
        {
          "id": "anqfulm",
          "type": "point",
          "config": {
            "dataId": "Candidate Sites",  # same label as the dataset name given to add_data
            "label": "Point",
            "color": [
              221,
              178,
              124
            ],
            "columns": {
              "lat": "lat",
              "lng": "lon",
              "altitude": None
            },
            "isVisible": True,
            "visConfig": {
              "radius": 30,
              "fixedRadius": False,
              "opacity": 0.8,
              "outline": True,
              "thickness": 3,
              "strokeColor": None,
              "colorRange": {
                "name": "Global Warming",
                "type": "sequential",
                "category": "Uber",
                "colors": [
                  "#5A1846",
                  "#900C3F",
                  "#C70039",
                  "#E3611C",
                  "#F1920E",
                  "#FFC300"
                ]
              },
              "strokeColorRange": {
                "name": "Global Warming",
                "type": "sequential",
                "category": "Uber",
                "colors": [
                  "#5A1846",
                  "#900C3F",
                  "#C70039",
                  "#E3611C",
                  "#F1920E",
                  "#FFC300"
                ]
              },
              "radiusRange": [
                0,
                50
              ],
              "filled": False
            },
            "hidden": False,
            "textLabel": [
              {
                "field": None,
                "color": [
                  255,
                  255,
                  255
                ],
                "size": 18,
                "offset": [
                  0,
                  0
                ],
                "anchor": "start",
                "alignment": "center"
              }
            ]
          },
          "visualChannels": {
            "colorField": None,
            "colorScale": "quantile",
            # Outline colour follows the per-site "mean" column
            "strokeColorField": {
              "name": "mean",
              "type": "real"
            },
            "strokeColorScale": "quantile",
            "sizeField": None,
            "sizeScale": "linear"
          }
        }
      ],
      },
      "mapStyle":{
         "styleType":"satellite"
      }
   }
}

In [None]:
# Plot blob locations on a satellite base image
map_height = 800
dataset_name = 'Candidate Sites'  # same label as the "dataId" in kepler_config
candidate_map = KeplerGl(height=map_height, config=kepler_config)
# The map receives a copy of the filtered frame, not the frame itself
candidate_map.add_data(data=filtered_candidate_sites.copy(), name=dataset_name)
# Last expression in the cell, so the map widget is rendered as its output
candidate_map