# Visualize all candidate patches
For a given list of coordinates, create a grid of image patches, each centered on one of the coordinates.

In [None]:
import os

import ee
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import sys
sys.path.append('../')
from scripts.get_s2_data_ee import get_history, get_history_polygon, get_pixel_vectors

%load_ext autoreload
%autoreload 2

In [None]:
# Sentinel-2 band identifiers mapped to a short description and central wavelength.
# Insertion order matters: bands are later stacked in this order and selected by position
# (e.g. index 3 is B4), so keep the entries in this sequence.
band_descriptions = dict([
    ('B1', 'Aerosols, 442nm'),
    ('B2', 'Blue, 492nm'),
    ('B3', 'Green, 559nm'),
    ('B4', 'Red, 665nm'),
    ('B5', 'Red Edge 1, 704nm'),
    ('B6', 'Red Edge 2, 739nm'),
    ('B7', 'Red Edge 3, 779nm'),
    ('B8', 'NIR, 833nm'),
    ('B8A', 'Red Edge 4, 864nm'),
    ('B9', 'Water Vapor, 943nm'),
    ('B11', 'SWIR 1, 1610nm'),
    ('B12', 'SWIR 2, 2186nm'),
])

## Load Sampling Locations
Depending on the input format, you may need to modify this section. The end result must be a list of `[lon, lat]` coordinates and a matching list of site names.

In [None]:
# Java Validated Sites -- alternative input; uncomment (and comment out the mining load below) to use it.
# Previously this file was read and then immediately overwritten, which only cost time or failed
# when the file was absent.
#candidates = pd.read_csv('../data/model_outputs/candidate_sites/v12_java_validated_positives.csv')

# Mining Sites
candidates = pd.read_csv('../../mining/outputs/tambopata_grid_full.csv')
display(candidates.head())

# Keep only rows whose model score in the 'pred' column exceeds this threshold.
pred_threshold = 0.5
candidates = candidates[candidates['pred'] > pred_threshold]

# Coordinates are built as [lon, lat] pairs (x, y order); this is the order passed on to get_history.
coords = [[lon, lat] for lon, lat in zip(candidates['lon'], candidates['lat'])]
names = ['candidate_' + str(i) for i in range(len(coords))]
print(len(coords), "coordinates loaded")

## Download Data

In [None]:
# Download a patch history for every coordinate from Google Earth Engine.
# This is slow: every site/date/band is fetched remotely.
# The result is indexed as patch_histories[date][site_name][band], each entry being one band image.
start_date = '2019-06-01'
num_months = 3
patch_histories = get_history(
    coords,
    names,
    0.004,  # NOTE(review): presumably the patch extent in degrees -- confirm against get_history
    start_date=start_date,
    num_months=num_months,
    cloud_mask=True,
)

## Create Spatial Patches

In [None]:
def create_img_stack(patch_history, band_combinations, bands=None):
    """Build one median-composited 3-channel image per site.

    For every site, each date whose bands are all non-empty and strictly positive
    contributes a (H, W, 3) image built from three selected bands. The per-site
    result is the pixel-wise median across those dates.

    Args:
        patch_history: Nested mapping ``patch_history[date][site][band]`` -> 2-D band image.
            The site list is taken from the first date; every date is expected to
            contain the same sites.
        band_combinations: Sequence of (at least) three integer positions into ``bands``.
            The first three are used, in order, as the output channels.
        bands: Ordered band names to stack. Defaults to the keys of the module-level
            ``band_descriptions``, so positions match that dictionary's order.

    Returns:
        A list with one entry per site, in site order. If no date was usable for a site,
        its entry is ``np.median`` of an empty list (NaN, with a RuntimeWarning). This
        keeps list positions aligned with the site order.
    """
    if bands is None:
        bands = list(band_descriptions)
    if not patch_history:
        return []

    # Take the site list from the first date instead of relying on a global start date.
    first_date = next(iter(patch_history))
    channels = band_combinations[:3]

    img_stack = []
    for site in patch_history[first_date]:
        rgb_stack = []
        for date in patch_history:
            band_imgs = [patch_history[date][site][band] for band in bands]
            # Skip dates where any band came back empty. Checking each band separately also
            # avoids building a ragged array when shapes differ.
            if not all(np.size(img) > 0 for img in band_imgs):
                continue
            spectral = np.array(band_imgs)
            # Skip dates with any non-positive (or NaN) pixel; written as `not > 0` so NaN is rejected.
            if not np.min(spectral) > 0:
                continue
            rgb_stack.append(np.stack([spectral[i] for i in channels], axis=-1))
        img_stack.append(np.median(rgb_stack, axis=0))
    return img_stack

def normalize(array):
    """Return ``array`` as a NumPy array divided by a fixed display scale of 3000."""
    values = np.array(array)
    return values / 3000

def stretch_histogram(array, min_val=0.1, max_val=0.75, gamma=1.2):
    """Linearly stretch ``[min_val, max_val]`` onto ``[0, 1]``, then apply a gamma power.

    Args:
        array: Array-like of pixel values (typically already passed through ``normalize``).
        min_val: Values at or below this map to 0.
        max_val: Values at or above this map to 1.
        gamma: Exponent applied to the stretched values. 1 keeps the stretch linear;
            values above 1 darken mid-tones because x**gamma < x on (0, 1).

    Returns:
        ndarray of stretched values, within [0, 1] whenever ``max_val > min_val``.
    """
    clipped = np.clip(array, min_val, max_val)
    # The ratio is parenthesized so gamma applies to the stretched value. Without the
    # parentheses, `**` binds tighter than `/`, so only the denominator was raised to gamma
    # and max_val mapped above 1.
    stretched = ((clipped - min_val) / (max_val - min_val)) ** gamma
    return stretched

In [None]:
# Positions 3, 2, 1 in band_descriptions are B4, B3, B2 -> a true-colour (R, G, B) composite.
patches = create_img_stack(patch_histories, band_combinations=[3, 2, 1])
print(f'{len(patches)} candidate images extracted')

In [None]:
# Manually filter duplicate patches: drop the positions listed here and keep the rest in order.
# NOTE(review): these indices are hand-picked for one specific candidate set -- re-check them
# whenever the input data changes.
duplicate_list = [19, 28, 30, 65, 68]
unique_patch_index = sorted(set(range(len(patches))) - set(duplicate_list))
patches = np.array(patches)[unique_patch_index]

In [None]:
# Figure title and output file name; switch to the commented line for the Java dataset.
#name = 'Confirmed Dump Sites on Java'
name = 'Mining Sites Detected'

# Lay the patches out on the smallest square grid that fits them all.
num_images = int(np.ceil(np.sqrt(len(patches))))
plt.figure(figsize=(12,12), dpi=150)
for index, patch in enumerate(patches):
    plt.subplot(num_images, num_images, index + 1)
    plt.imshow(np.clip(stretch_histogram(normalize(patch)), 0, 1))
    plt.axis('off')
plt.suptitle(f'{len(patches)} {name}', size=16)
plt.tight_layout()
# savefig does not create missing directories, so make sure the output folder exists first.
os.makedirs('figures/patches', exist_ok=True)
plt.savefig(f'figures/patches/{len(patches)} {name} - RGB.png', bbox_inches='tight')
plt.show()

In [None]:
# Named three-band composites. Each list names the bands shown in the red, green and blue
# display channels, in that order.
band_combinations = dict(
    rgb=['B4', 'B3', 'B2'],
    false_color=['B12', 'B11', 'B4'],
    color_infrared=['B8', 'B4', 'B3'],
    agriculture=['B11', 'B8', 'B2'],
    atmosphere=['B12', 'B11', 'B8'],
    healthy_vegetation=['B8', 'B11', 'B2'],
    land_water=['B8', 'B11', 'B4'],
    swir=['B12', 'B8A', 'B4'],
    vegetation=['B11', 'B8', 'B4'],
    geology=['B12', 'B11', 'B2'],
)

## Visualize a set of multispectral predictions

In [None]:
# Render one grid figure per named band combination.
band_names = list(band_descriptions)
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs('figures/patches', exist_ok=True)
for combination, combo_bands in band_combinations.items():
    # Position of each band in band_descriptions, which is how create_img_stack selects bands.
    # list.index raises ValueError on a misspelled band name instead of silently using B1.
    band_combo = [band_names.index(band) for band in combo_bands[:3]]
    # Apply the same manual duplicate filter as the RGB figure so every figure shows the same
    # sites. A separate name is used so the deduplicated RGB `patches` is not overwritten.
    combo_patches = np.array(create_img_stack(patch_histories, band_combo))[unique_patch_index]
    num_images = int(np.ceil(np.sqrt(len(combo_patches))))
    plt.figure(figsize=(12,12), dpi=150)
    for index, patch in enumerate(combo_patches):
        plt.subplot(num_images, num_images, index + 1)
        scaled = normalize(patch)
        # Per-patch upper stretch limit: mean plus four standard deviations of the scaled values.
        max_val = np.mean(scaled) + 4 * np.std(scaled)
        plt.imshow(np.clip(stretch_histogram(scaled, max_val=max_val), 0, 1))
        plt.axis('off')
    plt.suptitle(f'{len(combo_patches)} {name} - {combination} {combo_bands}', size=16)
    plt.tight_layout()
    plt.savefig(f'figures/patches/{len(combo_patches)} {name} - {combination}.png', bbox_inches='tight')
    plt.show()