# Run the Pixel Spectrogram Classifier on a Test Site

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np 
from tqdm.notebook import tqdm

from scripts import dl_utils

In [None]:
def _overlay_detections(rgb, pred, threshold):
    """Paint pixels where ``pred > threshold`` a fixed red colour in ``rgb`` (in place); return ``rgb``.

    The boolean mask ``pred > threshold`` indexes the leading axes of ``rgb``,
    so ``pred`` is assumed to match the first two axes of ``rgb`` — TODO confirm
    against the prediction pipeline.
    """
    mask = pred > threshold  # computed once and reused for all three channels
    rgb[mask, 0] = 0.9
    rgb[mask, 1] = 0
    rgb[mask, 2] = 0.1
    return rgb


def visualize_time_series(pairs, preds, dates, threshold=0.6, title=None, path=None,
                          *, rgb_scale=3000):
    """Plot a grid of per-date RGB panels with detections overlaid, plus a 'Mean' panel.

    Args:
        pairs: sequence of image pairs. Each pair is averaged over axis 0 to give
            one panel; the result is sliced ``[:, :, 3:0:-1]``, so it is assumed to
            be (H, W, C) with C >= 4 and bands 3, 2, 1 presumably being R, G, B —
            TODO confirm band order.
        preds: per-pair prediction maps. Pixels with value > ``threshold`` are
            painted red. Assumed to be (H, W) to match the panels.
        dates: per-pair date strings. The first 7 characters are used as the panel
            title (e.g. 'YYYY-MM' if the strings are ISO dates).
        threshold: detection threshold applied to ``preds``.
        title: optional figure suptitle.
        path: if given, the figure is saved to ``path + '.png'`` and closed;
            otherwise it is shown with ``plt.show()``.
        rgb_scale: divisor mapping raw band values into the [0, 1] display range
            (values are clipped to [0, 1] after division).

    Raises:
        ValueError: if ``pairs`` is empty (there is nothing to plot).
    """
    if len(pairs) == 0:
        raise ValueError('visualize_time_series needs at least one pair to plot')

    # Square grid with at least one spare cell, so the extra 'Mean' panel always fits.
    num_img = int(np.ceil(np.sqrt(len(pairs)))) + 1
    patches = [np.mean(pair, axis=0) for pair in pairs]
    plt.figure(figsize=(num_img, num_img), dpi=250, facecolor=(1, 1, 1))

    # `panel` is the 1-based subplot index of the last drawn panel. It is
    # initialised so the mean panel's position is defined even if no panel is drawn.
    panel = 0
    for panel, (img, pred, date) in enumerate(zip(patches, preds, dates), start=1):
        # Division creates a new float array, so the caller's data is not modified.
        rgb = _overlay_detections(img[:, :, 3:0:-1] / rgb_scale, pred, threshold)
        plt.subplot(num_img, num_img, panel)
        plt.title(date[:7], size=5, y=0.9)
        plt.imshow(np.clip(rgb, 0, 1))
        plt.axis('off')

    # NOTE(review): the per-date panels average both images of each pair, but the
    # mean panel averages only the first image of each pair — confirm this is intended.
    mean_patch = np.ma.mean([pair[0] for pair in pairs], axis=0)
    mean_pred = np.ma.mean(preds, axis=0)
    mean_rgb = _overlay_detections(mean_patch[:, :, 3:0:-1] / rgb_scale, mean_pred, threshold)
    plt.subplot(num_img, num_img, panel + 1)
    plt.title('Mean', size=5)
    plt.imshow(np.clip(mean_rgb, 0, 1))
    plt.axis('off')

    if title:
        plt.suptitle(title, size=num_img * 2, y=0.93)
    plt.tight_layout()
    if path:
        plt.savefig(path + '.png', bbox_inches='tight')
        plt.close()  # release the figure when writing to disk instead of displaying
    else:
        plt.show()

In [None]:
# Width handed to dl_utils.rect_from_point to build the query rectangle
# (presumably degrees, since `coord` looks like lon/lat — confirm).
RECT_WIDTH = 0.004

# Date window (ISO strings) passed to dl_utils.download_mosaics; how the
# endpoints are treated is up to that helper.
START_DATE, END_DATE = '2019-06-01', '2021-09-01'

# Period passed to dl_utils.download_mosaics (presumably months per mosaic — confirm).
MOSAIC_PERIOD = 3
# Interval passed to dl_utils.pair (presumably the mosaic spacing within a pair — confirm).
SPECTROGRAM_INTERVAL = 2

In [None]:
# Ensemble to load from ../models/ (the name appears to encode a version and a date).
ensemble_name = 'v0.0.11_ensemble-8-25-21'
model_list = dl_utils.load_ensemble('../models/' + ensemble_name)

In [None]:
# Test-site centre passed to dl_utils.rect_from_point; presumably
# [longitude, latitude] — confirm against that helper.
coord = [
    113.39,
    -1.82,
]

In [None]:
# Download mosaics (method='min') for the rectangle built around `coord`.
mosaics, metadata = dl_utils.download_mosaics(
    dl_utils.rect_from_point(coord, RECT_WIDTH),
    START_DATE,
    END_DATE,
    MOSAIC_PERIOD,
    method='min',
)
# Characters 15:25 of each mosaic id are used as its date string.
# NOTE(review): this relies on the id naming scheme — confirm.
dates = [entry['metadata']['']['id'][15:25] for entry in metadata]
pairs, pair_dates = dl_utils.pair(mosaics, SPECTROGRAM_INTERVAL, dates=dates)
preds = dl_utils.predict_ensemble(pairs, model_list, method='median')
# Not used further in this cell; kept in case later cells rely on it.
patches = [np.mean(pair, axis=0) for pair in pairs]

# The figure title doubles as the output file stem under ../figures/.
title = '{:.2f}, {:.2f}'.format(coord[0], coord[1])
path = '../figures/' + title
visualize_time_series(pairs, preds, pair_dates, threshold=0.6, title=title, path=path)