# Create Patch Dataset
This notebook downloads the Sentinel data used by the 2D patch classifier.

### Inputs
The notebook loads a set of coordinates from either a GeoJSON or a CSV file. For each location in the list, it downloads a patch of width `RECT_WIDTH` across a specified period of time.

### Outputs
Multispectral patches with the structure `[num_patches, height, width, bands]`

In [None]:
import json
import os
import pickle

import ee
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import sys
sys.path.append('../')
from scripts.get_s2_data_ee import get_history, get_history_polygon, get_pixel_vectors
from scripts.viz_tools import visualize_history, create_img_stack, band_descriptions

%load_ext autoreload
%autoreload 2

In [None]:
def load_points(file_name, data_dir=None):
    """Load point features from a GeoJSON file into a table of sites.

    Args:
        file_name: Name of the GeoJSON file. The text before its first
            underscore becomes the prefix of every generated site name.
        data_dir: Directory that contains the file. When omitted, the
            notebook-level ``DATA_DIR`` is used.

    Returns:
        A ``pandas.DataFrame`` with one row per feature and the columns
        ``name`` (``<prefix>_<index>``), ``lon``, ``lat`` and ``coords``
        (the first two coordinate values of the feature geometry).
    """
    if data_dir is None:
        data_dir = DATA_DIR

    # The context manager closes the file; no explicit close() is needed.
    with open(os.path.join(data_dir, file_name)) as f:
        features = json.load(f)['features']

    prefix = file_name.split('_')[0]
    coordinates = [feature['geometry']['coordinates'] for feature in features]

    return pd.DataFrame({
        'name': [f"{prefix}_{index}" for index in range(len(coordinates))],
        'lon': [coord[0] for coord in coordinates],
        'lat': [coord[1] for coord in coordinates],
        # Keep only the first two values in case a third one is present.
        'coords': [coord[0:2] for coord in coordinates],
    })

def visualize_patch_history(data, name):
    """Render a patch history with ``visualize_history`` to a PNG path.

    The target path is ``OUTPUT_DIR/patches/<name>_patches_<num_months>_
    months_<start_date>_<num_pixels>px_patches.png``. ``num_months`` and
    ``start_date`` are read from the notebook's global namespace.

    Args:
        data: Patch history indexed as ``data[date][site_name][band]``.
            The first site of the first date must contain a ``'B2'`` band,
            whose first dimension is used as the pixel size in the name.
        name: Prefix for the output file name.
    """
    # Measure the history that was passed in rather than a notebook global,
    # so the file name always describes the data being plotted.
    first_date = list(data)[0]
    first_site = list(data[first_date])[0]
    num_pixels = np.shape(data[first_date][first_site]['B2'])[0]

    # The 'patches' subdirectory may not exist yet on a fresh checkout.
    patch_dir = os.path.join(OUTPUT_DIR, 'patches')
    os.makedirs(patch_dir, exist_ok=True)

    file_name = f"{name}_patches_{num_months}_months_{start_date}"
    visualize_history(data, file_path=os.path.join(patch_dir, f"{file_name}_{num_pixels}px_patches.png"))
        
def save_patches(data, name, label_class):
    """Pickle a stack of patches together with a matching list of labels.

    Two files are written to ``OUTPUT_DIR/patches``; both share the stem
    ``<name>_patches_<num_months>_months_<start_date>_<num_pixels>px`` and
    end in ``_patches.pkl`` and ``_patch_labels.pkl`` respectively.
    ``num_months`` and ``start_date`` are read from the notebook's global
    namespace.

    Args:
        data: Patch stack; the size of its second dimension is used as the
            pixel size in the file name.
        name: Prefix for the output file names.
        label_class: Label repeated once per patch in the labels file.
    """
    num_pixels = np.shape(data)[1]

    # The 'patches' subdirectory may not exist yet; create it so that
    # open() does not fail with FileNotFoundError.
    patch_dir = os.path.join(OUTPUT_DIR, 'patches')
    os.makedirs(patch_dir, exist_ok=True)

    file_name = f"{name}_patches_{num_months}_months_{start_date}"
    with open(os.path.join(patch_dir, f"{file_name}_{num_pixels}px_patches.pkl"), "wb") as f:
        pickle.dump(data, f)

    with open(os.path.join(patch_dir, f"{file_name}_{num_pixels}px_patch_labels.pkl"), "wb") as f:
        pickle.dump([label_class] * len(data), f)

## Load Sampling Locations

In [None]:
# Configuration:
# DATA_DIR is the directory the sampling-location files (GeoJSON/CSV) are read from.
# OUTPUT_DIR is the directory the extracted patches are saved under.
# The rect width (RECT_WIDTH, for all patches that are not TPA sites) is set
# in the "Download Data" section.
DATA_DIR = '../data/sampling_locations'
OUTPUT_DIR = '../data/training_data'

# makedirs also creates missing parent directories and does nothing when the
# directory already exists, so no separate existence check is needed.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Load the validated positive sites from the candidate GeoJSON.
# Each site gets a sequential name of the form "site_<index>".

with open(os.path.join(DATA_DIR, 'v12_java_bali_validated_positives.geojson'), 'r') as f:
    positive_sites = json.loads(f.read())['features']

positive_names = [f"site_{position}" for position, _ in enumerate(positive_sites)]
positive_coords = [feature['geometry']['coordinates'] for feature in positive_sites]

print(f"{len(positive_coords)} positive sites loaded")

In [None]:
# Read the sampling locations from a candidate CSV.
# NOTE(review): the 'coords' column is parsed with eval, which executes any
# expression found in the file. Only load CSVs from trusted sources;
# ast.literal_eval is the usual safer parser for plain list literals.
sampling_points = pd.read_csv(
    os.path.join(DATA_DIR, 'w_nusa_tenggara_v1.1_negatives.csv'),
    converters=dict(coords=eval),
)
sampling_points.head()

In [None]:
# Read the sampling locations from a GeoJSON of points.
# This rebinds `sampling_points`, replacing any table loaded earlier.
sampling_points = load_points(file_name='city_points_30.geojson')

## Download Data

In [None]:
# Patch width handed to get_history for every sampling location.
# NOTE(review): presumably in decimal degrees, since the sites are lon/lat pairs — confirm.
RECT_WIDTH = 0.004

In [None]:
# Download the patch history for every sampling location.
# Layout: patch_history[date][site_name][band] -> band image
# The data is extracted from GEE, so this cell takes a while to run.
# NOTE(review): `filename` only names the outputs; make sure it matches the
# file that `sampling_points` was actually loaded from.

filename = 'w_nusa_tenggara_v1.1_negatives'
label_class = 0
num_months, start_date = 12, '2020-01-01'

patch_history = get_history(
    sampling_points['coords'],
    sampling_points['name'],
    RECT_WIDTH,
    num_months=num_months,
    start_date=start_date,
    cloud_mask=True,
)

# Plot the history, convert it with create_img_stack, then pickle the
# resulting patches together with their labels.
visualize_patch_history(patch_history, filename)
patches = create_img_stack(patch_history)
save_patches(patches, filename, label_class)
print(f"{len(patches)} images extracted")

## Example from Bali/Java Candidates

In [None]:
import pandas as pd
import numpy as np


def _save_site_table(names, coords, csv_path):
    """Write sites to ``csv_path`` with name, lon, lat and coords columns."""
    pd.DataFrame({
        'name': names,
        'lon': coords[:, 0],
        'lat': coords[:, 1],
        # tolist() yields built-in floats, so every cell is written as a
        # plain "[lon, lat]" literal on any NumPy version instead of the
        # repr of NumPy scalar objects.
        'coords': coords[:, :2].tolist(),
    }).to_csv(csv_path, index=False)


# Candidate sites produced by the model; names follow the row order.
java_bootstrap = pd.read_csv('../data/model_outputs/candidate_sites/v12_java_2D_candidates_0.4_threshold.csv')
del java_bootstrap['Unnamed: 0']
java_bootstrap_coords = [[lon, lat] for lon, lat in zip(java_bootstrap['lon'], java_bootstrap['lat'])]
java_bootstrap_names = ['java_' + str(index) for index in range(len(java_bootstrap))]

# Validated labels for the candidates.
# NOTE(review): the boolean masks below assume this file has the same rows,
# in the same order, as the candidate file above — confirm.
java_curated = pd.read_csv('../data/model_outputs/candidate_sites/v12_java_2D_candidates_0.4_threshold_validated.csv')
java_positive_index = java_curated['label'] == 1
java_positive_coords = np.array(java_bootstrap_coords)[java_positive_index]
java_positive_names = np.array(java_bootstrap_names)[java_positive_index]

java_negative_index = java_curated['label'] == 0
java_negative_coords = np.array(java_bootstrap_coords)[java_negative_index]
java_negative_names = np.array(java_bootstrap_names)[java_negative_index]

# Save the validated positives and negatives as sampling-location CSVs.
_save_site_table(
    java_positive_names,
    java_positive_coords,
    '../data/sampling_locations/v12_java_validated_positives.csv',
)
_save_site_table(
    java_negative_names,
    java_negative_coords,
    '../data/sampling_locations/v12_java_validated_negatives.csv',
)