# Run Patch Classifier

In [None]:
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from tensorflow import keras

sys.path.append('../')
from scripts.get_s2_data_ee import get_history, band_descriptions
from scripts.viz_tools import *

%load_ext autoreload
%autoreload 2

## Load Model

In [None]:
# Load the trained patch classifier from disk.
model = keras.models.load_model('../models/2d_java_classifier_0.0075_patch.h5')

# Axis 1 of the model's input shape is the width that downloaded patches are
# cropped to before prediction.
input_width = model.input_shape[1]

# Convert the model input width to degrees for the patch download.
# NOTE(review): 100 looks like pixels per km (i.e. 10 m pixels) and 111.32
# like km per degree -- confirm both against the imagery source.
scaled_width = input_width / 100
rect_width = np.round(scaled_width / 111.32, 4)

## Download Candidate Site Patches

In [None]:
# Read the candidate sites written by the detect_candidates step.
filename = 'bali_v12_candidates_pred-thresh_0.75_min-sigma_3.5_area-thresh_0.0025'
candidate_sites = pd.read_csv(f'../data/model_outputs/candidate_sites/{filename}.csv')

# Coordinates are passed to get_history as [lon, lat] pairs, one per site,
# in the same row order as the site names.
candidate_coords = [
    [lon, lat]
    for lon, lat in zip(candidate_sites['lon'], candidate_sites['lat'])
]
candidate_names = candidate_sites['name']

# Time window for the download.
# NOTE(review): presumably num_months counts forward from start_date --
# confirm against get_history.
start_date = '2020-06-01'
num_months = 1

# Fetch one patch of width rect_width per candidate site. The result is
# indexed as patch_history[date][site][band] by the cells that follow.
patch_history = get_history(
    candidate_coords,
    candidate_names,
    rect_width,
    num_months=num_months,
    start_date=start_date,
    cloud_mask=True,
)


In [None]:
# Spot-check the cloud fraction of a single patch.
#
# The original expression read `date` and `site`, which are only bound by the
# loop in a later cell, so running the notebook top to bottom raised a
# NameError here. Pick the first date and the first site explicitly instead.
example_date = next(iter(patch_history))
example_site = next(iter(patch_history[example_date]))
example_patch = patch_history[example_date][example_site]

# Per band: fraction of pixels with a negative value, then averaged over all
# bands listed in band_descriptions.
# NOTE(review): negative values are presumably the cloud-mask fill written by
# get_history(cloud_mask=True) -- confirm.
cloudiness = np.mean([
    np.sum(example_patch[band] < 0) / np.size(example_patch[band])
    for band in band_descriptions
])
cloudiness

In [None]:
# Collect the sites whose patches are (nearly) cloud free.
dates = list(patch_history.keys())
# Site names are taken from the first date only; every date is then indexed
# with the same names.
sites = list(patch_history[dates[0]].keys())

cloud_threshold = 0.05
cloud_free_sites = []
for date in dates:
    for site in sites:
        bands = patch_history[date][site]
        # Per band: fraction of negative pixels; averaged over all bands.
        # NOTE(review): negative values are presumably the cloud-mask fill --
        # confirm against get_history.
        cloudiness = np.mean([
            np.sum(bands[band] < 0) / np.size(bands[band])
            for band in band_descriptions
        ])
        if cloudiness < cloud_threshold:
            # NOTE(review): with more than one date a site can be appended
            # once per cloud-free date. Later cells index this list in
            # parallel with the predictions, so its contents are left as-is;
            # verify the ordering against create_img_stack_mean before
            # running with num_months > 1.
            cloud_free_sites.append(site)

# Report the share of distinct sites, so the figure cannot exceed 100% when a
# site passes on several dates, and avoid dividing by zero when there are no
# sites at all.
num_cloud_free = len(set(cloud_free_sites))
cloud_free_fraction = num_cloud_free / len(sites) if sites else 0.0
print(f"{cloud_free_fraction:.0%} of sites have less than {cloud_threshold:.0%} cloud cover")

## Run Network and Visualize Predictions

In [None]:
# Build one image stack per site that passes the cloud threshold.
mean_patches = create_img_stack_mean(patch_history, cloud_threshold)

# Crop every patch to the model's input width along its first two axes.
# The crop is applied patch by patch rather than after stacking the whole list
# with np.array(): stacking patches of unequal size raises a ValueError on
# NumPy >= 1.24, which is exactly the case the crop is there to handle.
patches = [np.asarray(patch)[:input_width, :input_width] for patch in mean_patches]

# RGB composites are made from the un-normalized patches for plotting.
rgb_img = create_rgb(patches)
patches = normalize(patches)

# Keep column 1 of the model output for each patch.
# NOTE(review): presumably the positive-class probability -- confirm against
# the model's output layer.
preds = model.predict(patches)[:, 1]

In [None]:
# Lay the patches out on the smallest square grid that fits every prediction.
num_img = int(np.ceil(np.sqrt(len(preds))))
plt.figure(figsize=(num_img, num_img), dpi=150, facecolor=(1, 1, 1))

# Subplot positions are 1-based, so enumerate from 1 and step back by one when
# looking up the matching site name and prediction.
for position, img in enumerate(rgb_img, start=1):
    plt.subplot(num_img, num_img, position)
    plt.imshow(stretch_histogram(img))
    # Title each tile with its site name and predicted score.
    label = f"{cloud_free_sites[position - 1]}: {preds[position - 1]:.2f}"
    plt.title(label, size=6)
    plt.axis('off')

plt.tight_layout()
plt.show()

## Write Candidate Sites

In [None]:
# Attach the patch-classifier score to each candidate site and write the table
# next to the input file with a '_patch_classifier' suffix.
filename = 'bali_v12_candidates_pred-thresh_0.75_min-sigma_3.5_area-thresh_0.0025'

# -1 marks sites that received no prediction. The column is created as float
# (-1.0, not -1) because the predictions assigned below are floats; writing
# floats into an integer column relies on silent upcasting, which pandas has
# deprecated.
candidate_sites['patch_prediction'] = -1.0

# cloud_free_sites and preds are paired by position; every row whose name
# matches the site receives that site's prediction.
for site, pred in zip(cloud_free_sites, preds):
    candidate_sites.loc[candidate_sites['name'] == site, 'patch_prediction'] = pred

candidate_sites.to_csv('../data/model_outputs/candidate_sites/' + filename + '_patch_classifier.csv', index=False)