# Create Patch Dataset
This notebook downloads the Sentinel-2 data used by the 2D patch classifier.

### Inputs
The notebook loads a set of coordinates from either a GeoJSON or a CSV file. For each location in the list, it downloads a patch of width `RECT_WIDTH` across a specified period of time.

### Outputs
Multispectral patches with the structure `[num_patches, height, width, bands]`.

In [None]:
import json
import os
import pickle

import ee
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import sys
sys.path.append('../')
from scripts.get_s2_data_ee import get_history, get_history_polygon, get_pixel_vectors
from scripts.viz_tools import visualize_history

%load_ext autoreload
%autoreload 2

In [None]:
# Sentinel 2 band descriptions
# Keys are the band identifiers used to index each downloaded patch; values are
# human-readable "name, wavelength" labels.
# create_img_stack iterates over the keys of this dict, so the insertion order
# below is the order of the bands along the last axis of every stacked patch.
# Note that B10 is not listed.
band_descriptions = {
    'B1': 'Aerosols, 442nm',
    'B2': 'Blue, 492nm',
    'B3': 'Green, 559nm',
    'B4': 'Red, 665nm',
    'B5': 'Red Edge 1, 704nm',
    'B6': 'Red Edge 2, 739nm',
    'B7': 'Red Edge 3, 779nm',
    'B8': 'NIR, 833nm',
    'B8A': 'Red Edge 4, 864nm',
    'B9': 'Water Vapor, 943nm',
    'B11': 'SWIR 1, 1610nm',
    'B12': 'SWIR 2, 2186nm'
}

In [None]:
def load_points(file_name, data_dir=None):
    """Load the point features of a GeoJSON file into a site table.

    Args:
        file_name: Name of the GeoJSON file. The text before its first
            underscore becomes the prefix of every generated site name.
        data_dir: Directory that contains the file. Defaults to the
            notebook-level ``DATA_DIR`` setting.

    Returns:
        A pandas DataFrame with one row per feature and the columns
        ``name`` (``<prefix>_<index>``), ``lon``, ``lat`` and ``coords``
        (the first two coordinate values; any further values are dropped).
    """
    if data_dir is None:
        data_dir = DATA_DIR

    # The context manager closes the file; no explicit close() is needed.
    with open(os.path.join(data_dir, file_name), encoding='utf-8') as f:
        features = json.load(f)['features']

    prefix = file_name.split('_')[0]
    positions = [feature['geometry']['coordinates'] for feature in features]

    return pd.DataFrame({
        'name': [f"{prefix}_{index}" for index in range(len(positions))],
        'lon': [position[0] for position in positions],
        'lat': [position[1] for position in positions],
        'coords': [position[0:2] for position in positions],
    })

def create_img_stack(patch_history, bands=None, max_cloud_fraction=0.2):
    """Stack the per-band patches of every (date, site) pair into images.

    Args:
        patch_history: Nested mapping indexed as
            ``patch_history[date][site][band]`` whose leaves are band images.
        bands: Iterable of band keys to stack, in output order. Defaults to
            the keys of the notebook-level ``band_descriptions``.
        max_cloud_fraction: A patch is kept only when its fraction of
            non-positive pixels is strictly below this value. Non-positive
            pixels are presumably the ones removed by cloud masking --
            TODO confirm against ``get_history``.

    Returns:
        A list of arrays shaped ``(height, width, bands)``, all cropped to
        the smallest height and width found among the kept patches. The list
        is empty when no patch passes the filters.
    """
    band_names = list(band_descriptions if bands is None else bands)

    img_stack = []
    for sites in patch_history.values():
        for site_bands in sites.values():
            layers = [site_bands[band] for band in band_names]
            # Skip sites where any band came back without rows.
            if not all(np.shape(layer)[0] > 0 for layer in layers):
                continue

            stack = np.array(layers)
            if stack.size == 0:
                # Nothing to score; also avoids a 0/0 division below.
                continue

            masked_fraction = 1 - np.count_nonzero(stack > 0) / stack.size
            if masked_fraction < max_cloud_fraction:
                # (bands, height, width) -> (height, width, bands)
                img_stack.append(np.moveaxis(stack, 0, -1))

    if not img_stack:
        # np.min over an empty list would raise; report "no images" instead.
        return []

    # Patches can differ by a pixel or so; crop all to the common extent.
    min_x = min(img.shape[0] for img in img_stack)
    min_y = min(img.shape[1] for img in img_stack)
    return [img[:min_x, :min_y, :] for img in img_stack]

def visualize_patch_history(data, name):
    """Render a patch history to a PNG under ``OUTPUT_DIR/patches``.

    Args:
        data: Patch history indexed as ``data[date][site][band]``.
        name: Prefix of the output file name.

    The file name also embeds the notebook-level ``num_months`` and
    ``start_date`` settings and the number of rows of the first site's
    ``'B2'`` band on the first date.
    """
    # Read the patch size from the argument rather than from the
    # notebook-global ``patch_history`` so the function works for any history.
    first_date = list(data.keys())[0]
    first_site = list(data[first_date].keys())[0]
    num_pixels = np.shape(data[first_date][first_site]['B2'])[0]

    file_name = f"{name}_patches_{num_months}_months_{start_date}"
    patch_dir = os.path.join(OUTPUT_DIR, 'patches')
    # Only OUTPUT_DIR itself is created by the configuration cell.
    os.makedirs(patch_dir, exist_ok=True)
    visualize_history(data, file_path=os.path.join(patch_dir, f"{file_name}_{num_pixels}px_patches.png"))
        
def save_patches(data, name, label_class):
    """Pickle a list of patches and a matching list of labels.

    Two files are written under ``OUTPUT_DIR/patches``:
    ``..._patches.pkl`` holding ``data`` and ``..._patch_labels.pkl`` holding
    ``label_class`` repeated once per patch. The file names embed the
    notebook-level ``num_months`` and ``start_date`` settings and the size of
    the second dimension of ``data``.

    Args:
        data: Sequence of equally shaped patches.
        name: Prefix of the output file names.
        label_class: Label assigned to every patch in ``data``.

    Raises:
        IndexError: If ``data`` has fewer than two dimensions (for example an
            empty list), because the patch size cannot be determined.
    """
    # Size the file name from the argument, not from the notebook-global
    # ``patches`` variable.
    num_pixels = np.shape(data)[1]
    file_name = f"{name}_patches_{num_months}_months_{start_date}"

    patch_dir = os.path.join(OUTPUT_DIR, 'patches')
    # Only OUTPUT_DIR itself is created by the configuration cell.
    os.makedirs(patch_dir, exist_ok=True)

    with open(os.path.join(patch_dir, f"{file_name}_{num_pixels}px_patches.pkl"), "wb") as f:
        pickle.dump(data, f)

    with open(os.path.join(patch_dir, f"{file_name}_{num_pixels}px_patch_labels.pkl"), "wb") as f:
        pickle.dump([label_class] * len(data), f)

## Load Sampling Locations

In [None]:
# Configuration:
# DATA_DIR    directory where the training site json/csv files are located
# OUTPUT_DIR  directory where downloaded patches are saved
# RECT_WIDTH  rect width for all patches that are not TPA sites
#             (presumably in degrees -- TODO confirm against get_history)
DATA_DIR = '../data/sampling_locations'
OUTPUT_DIR = '../data/training_data'
RECT_WIDTH = 0.004

# makedirs also creates missing parent directories and does not fail when the
# directory already exists, unlike an exists() check followed by mkdir().
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Validated positive sites: read every feature of the candidate GeoJSON, keep
# its raw coordinates and give each site a sequential "site_<n>" name.

with open(os.path.join(DATA_DIR, 'v12_java_bali_validated_positives.geojson')) as f:
    positive_sites = json.loads(f.read())['features']
positive_coords = [feature['geometry']['coordinates'] for feature in positive_sites]
positive_names = [f"site_{index}" for index, _ in enumerate(positive_sites)]

print(f"{len(positive_coords)} positive sites loaded")

In [None]:
# Read site coordinates from candidate csv
# The 'coords' column is stored as text and converted back to a Python object
# while the file is parsed.
# NOTE(review): `eval` executes whatever expression the cell contains, so only
# load CSV files from a trusted source. If the column only ever holds plain
# literals, `ast.literal_eval` would be a safer converter -- confirm the
# column format before switching.
java_negatives = pd.read_csv(os.path.join(DATA_DIR, 'v12_java_validated_negatives.csv'), converters={'coords': eval})
java_negatives.head()

In [None]:
# Read sites from a sampling geojson
# load_points looks the file up inside DATA_DIR and returns a DataFrame with
# the columns name / lon / lat / coords; names are built from the text before
# the first underscore of the file name, i.e. city_0, city_1, ...
sampling_points = load_points('city_points_30.geojson')

## Download Data

### Negative Sites

In [None]:
# Download the patch history for every sampling point.
# The result is a nested dictionary indexed as
# patch_history[date][site_name][band], each leaf being one band image.
# Expect this cell to be slow: the data is extracted from GEE.
num_months = 1
start_date = '2019-03-01'
patch_history = get_history(
    sampling_points['coords'],
    sampling_points['name'],
    RECT_WIDTH,
    num_months=num_months,
    start_date=start_date,
    cloud_mask=True,
)

# Plot the raw history first, then stack it into images and store those with
# label 0.
visualize_patch_history(patch_history, 'city_points_30')
patches = create_img_stack(patch_history)
save_patches(patches, 'city_points_30', 0)
print(len(patches), 'images extracted')