# Run Pixel Spectrogram Classifier on a Test Site

In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to imported project
# modules such as scripts.dl_utils are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np 
from tqdm.notebook import tqdm
import os
import sys

# Put the directory above the notebook's working directory at the front of
# sys.path (once) so the local `scripts` package imported below is found.
parent_dir = os.path.dirname(os.getcwd())
if parent_dir not in sys.path:
    sys.path[:0] = [parent_dir]

from scripts import dl_utils

In [None]:
def _overlay_predictions(rgb, pred, threshold):
    """Recolour, in place, every pixel of ``rgb`` whose prediction exceeds ``threshold``.

    Args:
        rgb: Display image whose last axis has (at least) three channels.
        pred: Prediction map usable as a boolean index over the first two
            axes of ``rgb``.
        threshold: Pixels with ``pred > threshold`` are painted (0.9, 0, 0.1).

    Returns:
        The same ``rgb`` object, modified in place.
    """
    mask = pred > threshold
    rgb[mask, 0] = 0.9
    rgb[mask, 1] = 0
    rgb[mask, 2] = 0.1
    return rgb


def visualize_time_series(pairs, preds, dates, threshold=0.6, title=None, path=None):
    """Plot one panel per date with thresholded predictions overlaid, plus a "Mean" panel.

    Args:
        pairs: Sequence of image stacks. Each element is averaged over axis 0
            and then indexed ``[:, :, 3:0:-1]``, so each must be 4-D with at
            least 4 channels on the last axis.
        preds: Per-date prediction maps, one per pair, used as boolean masks
            over the image's first two axes after comparison with ``threshold``.
        dates: Per-date labels; the first 7 characters are used as panel titles
            (e.g. 'YYYY-MM' for ISO date strings).
        threshold: Pixels with prediction strictly greater than this are
            painted red.
        title: Optional figure super-title.
        path: If given, the figure is saved to ``path + '.png'`` and closed;
            otherwise it is shown with ``plt.show()``.

    Raises:
        ValueError: If ``pairs`` is empty (there is nothing to plot).
    """
    if len(pairs) == 0:
        raise ValueError('visualize_time_series requires at least one pair')

    # One panel per pair plus a trailing "Mean" panel. Using a fixed count
    # (instead of the loop variable) keeps the mean panel last even if zip()
    # truncates because preds/dates are shorter than pairs.
    n_panels = len(pairs) + 1

    patches = [np.mean(pair, axis=0) for pair in pairs]
    plt.figure(figsize=(len(pairs), 1), dpi=250, facecolor=(1, 1, 1))
    for i, (img, pred, date) in enumerate(zip(patches, preds, dates)):
        # Channels 3, 2, 1 (in that order) become the displayed R, G, B;
        # /3000 scales values before clipping to [0, 1] for imshow.
        # NOTE(review): presumably Sentinel-2 B4/B3/B2 reflectance — confirm.
        rgb = img[:, :, 3:0:-1] / 3000
        _overlay_predictions(rgb, pred, threshold)
        plt.subplot(1, n_panels, i + 1)
        plt.title(date[:7], size=5, y=0.9)
        plt.imshow(np.clip(rgb, 0, 1))
        plt.axis('off')

    # NOTE(review): the mean panel averages only the first image of each pair,
    # whereas the per-date panels average the whole pair — confirm intent.
    mean_patch = np.ma.mean([pair[0] for pair in pairs], axis=0)
    mean_pred = np.ma.mean(preds, axis=0)
    mean_patch = mean_patch[:, :, 3:0:-1] / 3000
    _overlay_predictions(mean_patch, mean_pred, threshold)
    plt.subplot(1, n_panels, n_panels)
    plt.title('Mean', size=5)
    plt.imshow(np.clip(mean_patch, 0, 1))
    plt.axis('off')

    if title:
        plt.suptitle(title, size=len(pairs) * 2, y=0.93)
    plt.tight_layout()
    if path:
        plt.savefig(path + '.png', bbox_inches='tight')
        plt.close()
    else:
        plt.show()

In [None]:
# Width passed to dl_utils.rect_from_point when building the area of interest
# around `coord` (NOTE(review): units not visible here — presumably degrees).
RECT_WIDTH = 0.004

# Imagery search window as 'YYYY-MM-DD' strings.
START_DATE, END_DATE = '2021-01-01', '2022-03-01'

# Forwarded verbatim to dl_utils.SentinelData; their meaning is defined there
# (NOTE(review): likely mosaic length and spectrogram step, in months — confirm).
MOSAIC_PERIOD = 3
SPECTROGRAM_INTERVAL = 2

In [None]:
# Ensemble of models, loaded through the project helper from a path under ../models/.
ensemble_name = 'v0.0.11_ensemble-8-25-21'
model_list = dl_utils.load_ensemble(f'../models/{ensemble_name}')

# A single stand-alone Keras model, compared against the ensemble further below.
model_name = 'spectrogram_v0.0.11_2021-07-13'
model_file = f'../models/{model_name}.h5'
from tensorflow import keras
model = keras.models.load_model(model_file)

In [None]:
# Point of interest: passed to rect_from_point and printed in figure titles.
# NOTE(review): ordering looks like (lon, lat) — confirm against rect_from_point.
coord = [2.051444, 43.78097]

In [None]:
# Build the Sentinel data pipeline for a RECT_WIDTH box around `coord`, then
# search, download, composite, and pair the imagery.
# NOTE(review): 'min' is forwarded to SentinelData — presumably the compositing
# reducer; confirm in dl_utils.
data = dl_utils.SentinelData(
    dl_utils.rect_from_point(coord, RECT_WIDTH),
    START_DATE,
    END_DATE,
    MOSAIC_PERIOD,
    SPECTROGRAM_INTERVAL,
    'min',
)
data.search_scenes()
data.download_scenes()
data.create_composites()
data.create_pairs()

pairs = data.pairs
dates = data.pair_starts
# Mean image of each pair, averaged over the pair's first axis.
patches = [np.mean(pair, axis=0) for pair in pairs]
# Footprint vertices of the first scene; [:-1] drops the last vertex
# (presumably the closing point of a GeoJSON-style ring — confirm).
bounds = data.metadata[0]["wgs84Extent"]["coordinates"][0][:-1]

In [None]:
threshold = 0.2
ensemble_method = 'mean'

# Figures are saved under ../figures/; create the directory up front so
# plt.savefig does not raise FileNotFoundError on a fresh checkout.
figure_dir = '../figures'
os.makedirs(figure_dir, exist_ok=True)

# Shared pieces of both figure titles (also used as file names).
site_label = f'{coord[0]:.2f}, {coord[1]:.2f}'
period_label = f'{START_DATE} - {END_DATE}'

# Ensemble prediction: every model in model_list, combined with `ensemble_method`.
ensemble_preds = dl_utils.predict_ensemble(pairs, model_list, method=ensemble_method)
title = f'{site_label} - {ensemble_method} ensemble composite - threshold {threshold} - {period_label}'
path = f'{figure_dir}/{title}'
visualize_time_series(pairs, ensemble_preds, dates, threshold=threshold, title=title, path=path)

# Single-model prediction, one call per pair, for comparison with the ensemble.
title = f'{site_label} - single prediction - threshold {threshold} - {period_label}'
path = f'{figure_dir}/{title}'
preds = [dl_utils.predict_spectrogram(pair, model) for pair in pairs]
visualize_time_series(pairs, preds, dates, threshold=threshold, title=title, path=path)