# Create Pixel Spectrogram Dataset
This notebook downloads Sentinel data to produce inputs to a spectral-temporal pixelwise-classifier. 
Currently the notebook supports two time-steps, spaced `SPECTROGRAM_INTERVAL` mosaic periods apart (see the parameters below). 

## Inputs
The notebook operates by loading a set of sampling sites from a geojson. If the geojson contains `Point` features, a bounding rect is constructed. If the geojson contains `Polygon` or `MultiPolygon` features, only pixels within the polygon will be extracted.

The `download_patch` script attempts to mask clouds. However, cloudy pixels and patches can still come through.

Pixels that fall outside of a polygon are also masked using a numpy masked array. These pixels are not stored in the output pixel arrays.

## Outputs

### Pixel Arrays:
The output list of arrays is saved as a pickle. The arrays are not normalized. The dimensionality of each array is `[bands][num_time_steps]`, with num_time_steps hard-coded to 2. 
### Image Plot:
To log the data in a pixel array dataset, a grid of input images is exported along with the dataset.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import descarteslabs as dl
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm

from scripts import dl_utils
from scripts.viz_tools import band_descriptions, plot_image_grid

In [None]:
def save_pixel_arrays(data, basepath, label_class):
    """Pickle a collection of pixel arrays and a matching list of labels.

    Two files are written:
    ``<basepath>_pixel_arrays.pkl`` holds ``data`` as-is, and
    ``<basepath>_pixel_array_labels.pkl`` holds a list containing
    ``label_class`` repeated once per element of ``data``.

    Args:
        data: Sized collection of pixel arrays (must support ``len()``).
        basepath: String path prefix for the output files, without suffix.
        label_class: Label assigned to every element of ``data``.
    """
    # Create the destination directory up front so a fresh checkout does not
    # fail with FileNotFoundError. A bare filename has no directory part.
    parent_dir = os.path.dirname(basepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(basepath + '_pixel_arrays.pkl', "wb") as f:
        pickle.dump(data, f)
    with open(basepath + '_pixel_array_labels.pkl', "wb") as f:
        pickle.dump([label_class] * len(data), f)
        
def save_patch_arrays(data, basepath, label_class):
    """Pickle a collection of image patches and a matching list of labels.

    Two files are written:
    ``<basepath>_patch_arrays.pkl`` holds ``data`` as-is, and
    ``<basepath>_patch_array_labels.pkl`` holds a list containing
    ``label_class`` repeated once per element of ``data``.

    Args:
        data: Sized collection of patches (must support ``len()``).
        basepath: String path prefix for the output files, without suffix.
        label_class: Label assigned to every element of ``data``.
    """
    # The patch output directory is not created anywhere else in the notebook,
    # so make sure it exists before writing. A bare filename has no directory
    # part, in which case there is nothing to create.
    parent_dir = os.path.dirname(basepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(basepath + '_patch_arrays.pkl', "wb") as f:
        pickle.dump(data, f)
    with open(basepath + '_patch_array_labels.pkl', "wb") as f:
        pickle.dump([label_class] * len(data), f)

### Define Parameters for data extraction
### Attention: make sure to set appropriate label class!
Negative sites = 0, Positive sites = 1

In [None]:
# Name of the sampling-site geojson (without the .geojson extension) and the
# directory it is read from.
sampling_file = 'central_america_intersection_confirmed_negatives'
data_dir = '../data/sampling_locations/'
# Label stored with every extracted sample: 0 = negative site, 1 = positive site.
label_class = 0

START_DATE = '2021-01-01'
END_DATE = '2022-01-01'
# Passed through as `method=` to dl_utils.SentinelData; presumably the
# compositing reducer — confirm against dl_utils.
METHOD = 'min'
MOSAIC_PERIOD = 3  # the period over which to mosaic image data in months
# For spectrogram analysis, the time from the start of one mosaic to the start
# of the next, in number of mosaic periods.
SPECTROGRAM_INTERVAL = 2

OUTPUT_DIR = f'../data/training_data/pixel_arrays_{MOSAIC_PERIOD}mo-mosaics_{SPECTROGRAM_INTERVAL}x-int'
PATCH_OUTPUT_DIR = f'../data/training_data/spectrogram_patches_{MOSAIC_PERIOD}mo-mosaics_{SPECTROGRAM_INTERVAL}x-int'
# Both directories are written to by later cells. makedirs also creates any
# missing parent directories and is a no-op when the directory already exists.
for output_dir in (OUTPUT_DIR, PATCH_OUTPUT_DIR):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Read the sampling-site features from the geojson file.
geojson_path = os.path.join(data_dir, sampling_file + '.geojson')
with open(geojson_path, 'r') as f:
    data = json.load(f)['features']

# Side length, in pixels, of the rect built around Point features.
# Polygon and MultiPolygon features do not use it.
num_pixels = 48
# Rough pixel -> degree conversion (heuristic, not geographically sound).
# Presumably 100 pixels per km and 111.32 km per degree — confirm.
# Slightly oversized patches are preferred because they can be cropped later.
rect_width = np.round((num_pixels / 100) / 111.32, 4)

# Points become rects; (Multi)Polygons are used unchanged.
# Features with any other geometry type are skipped without a message.
polygons = []
for feature in data:
    geometry = feature['geometry']
    geometry_type = geometry['type']
    if geometry_type == 'Point':
        polygons.append(dl_utils.rect_from_point(geometry['coordinates'], rect_width))
    elif geometry_type in ('MultiPolygon', 'Polygon'):
        polygons.append(geometry)
print(f'{len(polygons)} polygons loaded.')

### Download Sentinel data

In [None]:
# Download imagery for each polygon and collect the image pairs for which
# dl_utils.masks_match() is true.
#
# NOTE(review): processing starts at polygon index 300, so the first 300
# polygons are skipped, and nothing is processed when fewer than 301 polygons
# are loaded. This looks like a leftover from resuming an interrupted run.
# Confirm, and set to 0 to process every polygon.
first_polygon_index = 300

pairs = []
for polygon in tqdm(polygons[first_polygon_index:]):
    try:
        data = dl_utils.SentinelData(polygon, START_DATE, END_DATE, MOSAIC_PERIOD, SPECTROGRAM_INTERVAL, method=METHOD)
        data.search_scenes()
        data.download_scenes()
        data.create_composites()
        # composites, dates and bounds are not read by the cells visible in
        # this notebook. They are kept so the most recent polygon can be
        # inspected interactively. A failing metadata lookup still marks the
        # polygon as a failure, as it did before.
        composites = data.composites
        dates = data.composite_dates
        bounds = data.metadata[0]["wgs84Extent"]["coordinates"][0][:-1]
        data.create_pairs()
        new_pairs = data.pairs
        pairs += [p for p in new_pairs if dl_utils.masks_match(p)]
        dates = data.pair_starts
    except KeyboardInterrupt:
        # Let the user stop a long download while keeping the pairs collected so far.
        print("Keyboard Interrupt!")
        break
    except Exception as err:
        # Best effort: report the failing polygon and the reason, then continue.
        # Catching Exception rather than using a bare except lets SystemExit
        # and similar signals propagate.
        print('Failure', polygon, repr(err))
print(len(pairs), "pairs of images extracted")

In [None]:
# Plot every downloaded image in a single grid. The figure path is handed to
# plot_image_grid, presumably so the grid is saved next to the patch data.
figure_name = f"{sampling_file}-Class_{label_class}-{START_DATE}-{END_DATE}-{METHOD}"
figure_file_path = os.path.join(PATCH_OUTPUT_DIR, figure_name)

# Flatten the list of pairs into one flat list of images, preserving order.
unpaired = []
for pair in pairs:
    unpaired.extend(pair)

plot_image_grid(unpaired, file_path=figure_file_path)

In [None]:
# Path prefix for the pixel-array outputs, encoding all extraction parameters.
# NOTE(review): the "Create pixel arrays" cell below reassigns `basepath` to a
# shorter name before anything is saved, so this value is currently unused.
pixel_file_stem = f"{sampling_file}_{START_DATE}_{END_DATE}_period_{MOSAIC_PERIOD}_interval_{SPECTROGRAM_INTERVAL}_method_{METHOD}"
basepath = os.path.join(OUTPUT_DIR, pixel_file_stem)

In [None]:
# Save the image pairs themselves (patch arrays) together with their labels.
patch_file_stem = f"{sampling_file}_{START_DATE}_{END_DATE}_period_{MOSAIC_PERIOD}_interval_{SPECTROGRAM_INTERVAL}_method_{METHOD}"
patch_basepath = os.path.join(PATCH_OUTPUT_DIR, patch_file_stem)
save_patch_arrays(data=pairs, basepath=patch_basepath, label_class=label_class)

In [None]:
# Turn each image pair into per-pixel arrays and keep only the pixels whose
# mean is positive in both columns 0 and 1 (the two time steps, per the
# `[bands][num_time_steps]` layout described at the top of the notebook).
pixel_arrays = []
for pair in tqdm(pairs):
    pixels = dl_utils.shape_gram_as_pixels(pair)
    for pixel in pixels:
        if np.mean(pixel[:, 0]) > 0 and np.mean(pixel[:, 1]) > 0:
            pixel_arrays.append(pixel)

print(f"{len(pixel_arrays):,} pixel arrays extracted")

# NOTE(review): unlike the patch outputs, this file name omits the period,
# interval and method. The period and interval are already part of OUTPUT_DIR,
# but runs with a different METHOD would write to the same files.
basepath = os.path.join(OUTPUT_DIR, f"{sampling_file}_{START_DATE}_{END_DATE}")
save_pixel_arrays(pixel_arrays, basepath, label_class)

### Optional: plot the mean pixel spectra of the extracted dataset.
Process can take time with many samples

In [None]:
# Stack the pixel arrays, move the last axis in front of the band axis, and
# flatten to one row per (pixel, time step) with one column per band.
# The factor 2 is the hard-coded number of time steps.
stacked_pixels = np.array(pixel_arrays)
time_before_band = np.moveaxis(stacked_pixels, -1, -2)
all_pixels = time_before_band.reshape(2 * len(pixel_arrays), len(band_descriptions))

# Long format: one row per (sample, band) value, as seaborn expects.
data = pd.DataFrame(all_pixels, columns=band_descriptions.keys())
data = data.melt(var_name='band', value_name='value')

plt.figure(figsize=(6,4), dpi=150, facecolor=(1,1,1))
# NOTE(review): `ci="sd"` is deprecated in newer seaborn releases in favour of
# `errorbar="sd"`. Left unchanged because the installed version is unknown.
sns.lineplot(x='band', y='value', data=data, ci="sd")
plt.title('Mean Value +/- SD')
plt.show()