# Train Pixel Classifier

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import date
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils import shuffle
from tensorflow import keras
from tensorflow.keras import layers
from tqdm.notebook import tqdm

import sys
sys.path.append('../')
from scripts.get_s2_data_ee import band_descriptions
from scripts.viz_tools import stretch_histogram, normalize

from scripts.dl_utils import download_patch, rect_from_point
from scripts.nn_predict_dl import make_predictions, visualize_predictions

np.random.seed(1)

## Create a Training Dataset
Outputs will be: `x_train`, `y_train`, `x_test`, `y_test`, `x_holdout`, `y_holdout`. The holdout set contains only positive samples.

In [None]:
# Directory holding the pickled pixel vectors and their matching label files.
train_data_dir = '../data/training_data/pixel_vectors/'

# Pixel-vector files used for training. data_files and label_files are paired
# by position, so when an entry is commented out here the matching entry in
# label_files must be commented out as well.
data_files = ['city_points_30_raw_24_months_2016-01-01_pixel_vectors.pkl',
              'adjacent_north_0.015_raw_24_months_2016-01-01_pixel_vectors.pkl',
              'bali_bootstrap_raw_24_months_2019-01-01_pixel_vectors.pkl',
              #'tpa_train_raw_60_months_2016-01-01_pixel_vectors.pkl',
              'java_v12_negatives_raw_24_months_2016-01-01_pixel_vectors.pkl',
              #'java_v12_positives_raw_24_months_2016-01-01_pixel_vectors.pkl',
              'w_nusa_tenggara_v1.1_negatives_raw_12_months_2020-01-01_pixel_vectors.pkl',
              #'w_nusa_tenggara_v1.1_positives_raw_24_months_2019-01-01_pixel_vectors.pkl',
              'bali_tpa_sites_2020-01-01_2020-12-31_pixel_vectors.pkl',
              'sri_lanka_sites_2020-01-01_2020-12-31_pixel_vectors.pkl',
              'v_1.1.5_negatives_2020-04-01_2020-05-31_pixel_vectors.pkl',
              'java_v1.0_positive_polygons_2017-01-01_2020-12-31_pixel_vectors.pkl',
              'lombok_v_1.1.5_negatives_2019-01-01_2020-12-31_pixel_vectors.pkl'
             ]

label_files = ['city_points_30_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
               'adjacent_north_0.015_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
               'bali_bootstrap_raw_24_months_2019-01-01_pixel_vector_labels.pkl',
               #'tpa_train_raw_60_months_2016-01-01_pixel_vector_labels.pkl',
               'java_v12_negatives_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
               #'java_v12_positives_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
               'w_nusa_tenggara_v1.1_negatives_raw_12_months_2020-01-01_pixel_vector_labels.pkl',
               #'w_nusa_tenggara_v1.1_positives_raw_24_months_2019-01-01_pixel_vector_labels.pkl',
               'bali_tpa_sites_2020-01-01_2020-12-31_pixel_vector_labels.pkl',
               'sri_lanka_sites_2020-01-01_2020-12-31_pixel_vector_labels.pkl',
               'v_1.1.5_negatives_2020-04-01_2020-05-31_pixel_vector_labels.pkl',
               'java_v1.0_positive_polygons_2017-01-01_2020-12-31_pixel_vector_labels.pkl',
               'lombok_v_1.1.5_negatives_2019-01-01_2020-12-31_pixel_vector_labels.pkl'
              ]

# zip() stops at the shorter list without complaint, so a list edited on one
# side only would silently drop or mis-pair files. Fail loudly instead.
if len(data_files) != len(label_files):
    raise ValueError(
        f"Got {len(data_files)} data files but {len(label_files)} label files"
    )

# NOTE: pickle.load can execute arbitrary code contained in the file; only
# point this cell at pixel-vector files you created or otherwise trust.
pixel_vectors = []
labels = []
for data, label in tqdm(zip(data_files, label_files), total=len(data_files)):
    with open(os.path.join(train_data_dir, data), 'rb') as f:
        file_vectors = pickle.load(f)
    with open(os.path.join(train_data_dir, label), 'rb') as f:
        file_labels = pickle.load(f)
    # Every vector needs exactly one label; a mismatch means the two files do
    # not belong together.
    if len(file_vectors) != len(file_labels):
        raise ValueError(
            f"{data} holds {len(file_vectors)} vectors but "
            f"{label} holds {len(file_labels)} labels"
        )
    pixel_vectors.extend(file_vectors)
    labels.extend(file_labels)
pixel_vectors = np.array(pixel_vectors)
labels = np.array(labels)

# Split by label value: 1 marks a positive sample, 0 a negative one.
positive_vectors = pixel_vectors[labels == 1]
negative_vectors = pixel_vectors[labels == 0]

print(f"Loaded {len(positive_vectors):,} positive pixel vectors and {len(negative_vectors):,} negative pixel vectors")

In [None]:
# Holdout set: read from its own files in the same directory as the training
# data, and never mixed into the train/test split.
holdout_data_files = ['tpa_holdout_raw_60_months_2016-01-01_pixel_vectors.pkl']
holdout_label_files = ['tpa_holdout_raw_60_months_2016-01-01_pixel_vector_labels.pkl']
holdout_pixel_vectors, holdout_labels = [], []

for vector_file, label_file in zip(holdout_data_files, holdout_label_files):
    # Vectors first, labels second, appended to the running lists.
    with open(os.path.join(train_data_dir, vector_file), 'rb') as handle:
        holdout_pixel_vectors.extend(pickle.load(handle))
    with open(os.path.join(train_data_dir, label_file), 'rb') as handle:
        holdout_labels.extend(pickle.load(handle))

# Convert the accumulated lists to arrays for vectorised filtering below.
holdout_pixel_vectors = np.array(holdout_pixel_vectors)
holdout_labels = np.array(holdout_labels)

### Filter positive samples such that NDVI is within a range
This is useful since the positive patches can include surrounding vegetation

In [None]:
def compute_ndvi(pixel_vectors):
    """
    Compute a normalized difference index for every pixel vector.

    The index is (col7 - col3) / (col7 + col3), evaluated row by row.
    Presumably column 7 is the near-infrared band and column 3 the red band
    (which would make this the NDVI) -- confirm against the band ordering
    used when the pixel vectors were extracted.

    Args:
        pixel_vectors: array-like of shape (num_samples, num_bands) with at
            least 8 columns.

    Returns:
        1-D float array with one value per row. Rows where both bands sum to
        zero yield NaN (or +/-inf), which fail any range comparison.
    """
    bands = np.asarray(pixel_vectors)
    # Unsigned integer inputs would wrap around on subtraction (e.g. uint16
    # 1000 - 3000 == 63536), so do the arithmetic in floating point.
    if not np.issubdtype(bands.dtype, np.floating):
        bands = bands.astype(float)
    nir = bands[:, 7]
    red = bands[:, 3]
    # A zero denominator is expected for empty/no-data pixels; keep the NaN
    # result but silence the RuntimeWarning.
    with np.errstate(divide='ignore', invalid='ignore'):
        return (nir - red) / (nir + red)

def filter_ndvi(data, lower_bound=0, upper_bound=0.4):
    """
    Keep only the rows of `data` whose NDVI lies strictly between the bounds.

    Prints the fraction of rows retained.

    Args:
        data: pixel vectors, one row per sample, as accepted by compute_ndvi.
        lower_bound: exclusive lower NDVI limit.
        upper_bound: exclusive upper NDVI limit.

    Returns:
        Array containing the rows of `data` that fall inside the NDVI range.
    """
    ndvi_values = compute_ndvi(data)
    # Both comparisons are strict; NaN NDVI values fail both and are dropped.
    in_range = (ndvi_values > lower_bound) & (ndvi_values < upper_bound)
    kept_rows = np.array(data)[in_range]
    kept_fraction = np.sum(in_range) / len(data)
    print(f"{kept_fraction:.1%} of samples within NDVI range")
    return kept_rows


def filter_bright(data, brightness_threshold=2500):
    """
    Drop pixel vectors whose mean value across all bands is too high.

    Prints the fraction of rows retained.

    Args:
        data: array-like of shape (num_samples, num_bands).
        brightness_threshold: rows whose mean over axis 1 is greater than or
            equal to this value are removed.

    Returns:
        Array containing only the rows with mean below the threshold. An empty
        input returns an empty array rather than raising.
    """
    # Accept lists as well as arrays, matching filter_ndvi.
    data = np.asarray(data)
    filtered_data = data[np.mean(data, axis=1) < brightness_threshold]
    if len(data):
        print(f"{len(filtered_data) / len(data) :.1%} of data below brightness limit")
    else:
        # An upstream filter may remove every sample; avoid dividing by zero.
        print("No data to filter by brightness")
    return filtered_data

In [None]:
# Run the NDVI filter first and the brightness filter on its output, for both
# the training positives and the holdout vectors. Negatives are not filtered.
filtered_positive_vectors = filter_bright(filter_ndvi(positive_vectors))
filtered_holdout_vectors = filter_bright(filter_ndvi(holdout_pixel_vectors))

### Combine data and create train test split
Also add a trailing channel dimension so that each pixel vector has the shape the `Conv1D` layers expect

In [None]:
# Stack the filtered positives on top of the negatives and build the matching
# label vector: 1 for every positive row, 0 for every negative row.
num_positive = len(filtered_positive_vectors)
num_negative = len(negative_vectors)
x = np.concatenate((filtered_positive_vectors, negative_vectors))
y = np.concatenate((np.ones(num_positive), np.zeros(num_negative)))

x, y = shuffle(x, y, random_state=42)
x = normalize(x)
x_holdout = normalize(filtered_holdout_vectors)

# Fixed random_state keeps the 80/20 split reproducible between runs.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.20, random_state=42)
print("Num Train Samples:\t\t", len(x_train))
print("Num Test Samples:\t\t", len(x_test))
print("Num Holdout Samples:\t\t", len(x_holdout))
print(f"Percent Negative Train:\t {np.mean(y_train == 0.0):.1%}")
print(f"Percent Negative Test:\t {np.mean(y_test == 0.0):.1%}")

# Append a length-1 trailing axis to every feature array.
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
x_holdout = x_holdout[..., np.newaxis]

# Labels are one-hot encoded over two classes even though the task is binary.
# The original author kept this habit from an old theano issue and chose not
# to change it.
num_classes = 2
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# Every holdout sample is labelled positive.
holdout_ones = np.ones(len(filtered_holdout_vectors))
y_holdout = keras.utils.to_categorical(holdout_ones, num_classes)

## Create and Train a Model

In [None]:
# Shape of a single sample, i.e. x_train without its leading sample axis.
input_shape = x_train.shape[1:]

network_layers = [keras.Input(shape=input_shape)]
# Two stacked 1-D convolutions (16 then 32 filters) with no pooling between.
for num_filters in (16, 32):
    network_layers.append(layers.Conv1D(num_filters, kernel_size=3, activation="relu"))
network_layers.append(layers.Flatten())
# Three 32-unit dense layers with light dropout before the last one.
network_layers += [
    layers.Dense(32, activation="relu"),
    layers.Dense(32, activation="relu"),
    layers.Dropout(0.1),
    layers.Dense(32, activation="relu"),
]
# Softmax head with one output per class.
network_layers.append(layers.Dense(num_classes, activation="softmax"))

model = keras.Sequential(network_layers)
model.summary()

### Optional Class Weighting
In experimental testing, weighting the classes seemed to degrade performance. This could use further investigation.

In [None]:
from sklearn.utils import class_weight

# y_train is one-hot encoded; column 1 is the positive-class indicator and
# serves as the flat 0/1 label vector that compute_class_weight expects.
# The class list is taken from that same vector (not from the one-hot matrix)
# so that `classes` and `y` are guaranteed to describe the same labels.
binary_labels = y_train[:, 1]
negative_weight, positive_weight = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(binary_labels),
    y=binary_labels,
)
print(f"Negative Weight: {negative_weight:.2f}")
print(f"Positive Weight: {positive_weight:.2f}")

In [None]:
# Compile model. Note that many of these metrics are extraneous, but they can
# be useful to track during training.
# Each metric's name matches the quantity it computes, so the 'precision' and
# 'recall' entries in the training log and history report the right values.
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[
        keras.metrics.Precision(thresholds=0.7, name='precision'),
        keras.metrics.Recall(thresholds=0.7, name='recall'),
        keras.metrics.AUC(curve='PR', name='auc'),
        "accuracy",
    ],
)

# Accuracy histories; the plotting cell appends to these after each fit so
# that repeated training runs form one continuous curve.
train_accuracy = []
test_accuracy = []

### Train the Model

In [None]:
batch_size = 128
epochs = 15

# Fit on the training split; the test split is evaluated after every epoch.
# To experiment with class weighting, also pass
#   class_weight={0: negative_weight, 1: positive_weight}
model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    epochs=epochs,
    batch_size=batch_size,
)

In [None]:
# Append the latest fit's per-epoch accuracies so that re-running the training
# cell extends the curves instead of replacing them.
train_accuracy += model.history.history['accuracy']
test_accuracy += model.history.history['val_accuracy']

# Accuracy of always predicting "negative": the share of training samples
# whose positive-class indicator (column 1 of the one-hot labels) is 0.
percent_negative = np.mean(y_train[:, 1] == 0.0)
# The histories may cover several fits, so span the baseline over every
# accumulated epoch rather than only the `epochs` of the most recent run.
last_epoch = max(len(train_accuracy) - 1, 0)

plt.figure(figsize=(8,5), dpi=100, facecolor=(1,1,1))
plt.plot(train_accuracy, label='Train Acc')
plt.plot(test_accuracy, c='r', label='Val Acc')
plt.plot([0, last_epoch], [percent_negative, percent_negative], '--', c='gray', label='Baseline')
plt.grid()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Network Train and Val Accuracy')
plt.show()

In [None]:
# A sample is predicted positive when its class-1 probability exceeds this.
threshold = 0.8
class_names = ['No TPA', 'TPA']
# Passing the label set explicitly keeps the report valid even when only one
# class appears in the data: the holdout set is all-positive, so if every
# holdout prediction is also positive there is a single observed class and
# classification_report would reject the two target_names.
class_labels = [0, 1]

test_predicted = (model.predict(x_test)[:,1] > threshold).astype(int)
print("Test Set Metrics:")
print(classification_report(y_test[:,1].astype(int), test_predicted,
                            labels=class_labels, target_names=class_names))

holdout_predicted = (model.predict(x_holdout)[:,1] > threshold).astype(int)
print("\nHoldout Positive Set Metrics:")
print(classification_report(y_holdout[:,1].astype(int), holdout_predicted,
                            labels=class_labels, target_names=class_names))

## Save Model

In [None]:
version_number = '1.1.8'

# Model files are named by version and today's date (month-day-year).
current_date = date.today()
model_name = f"model_v{version_number}_{current_date.month}-{current_date.day}-{current_date.year}"
model_path = '../models/' + model_name + '.h5'
config_path = '../models/' + model_name + '_config.txt'

# Refuse to overwrite an existing model. An explicit raise is used instead of
# assert so the guard still runs when Python is started with -O.
if os.path.exists(model_path):
    raise FileExistsError(f"Model of name {model_name} already exists")

# Write a small human-readable record of how the model was trained.
with open(config_path, 'w') as f:
    f.write('Input Data:\n')
    f.writelines(file + '\n' for file in data_files)
    f.write(f"\nBatch Size: {batch_size}")
    # train_accuracy holds one entry per epoch trained so far.
    f.write(f"\nTraining Epochs: {len(train_accuracy)}")
    f.write('\n\nClassification Report\n')
    f.write(classification_report(y_test[:,1], model.predict(x_test)[:,1] > threshold,
                                  target_names=['No TPA', 'TPA']))
model.save(model_path)

## Visualize Network Predictions

In [None]:
# Download imagery for a small rectangle around one point of interest.
# NOTE(review): rect_width is presumably in degrees and coords presumably
# [longitude, latitude] -- confirm against scripts.dl_utils.rect_from_point.
rect_width = 0.02
coords = [116.0908,-8.6451]
# Date range passed to download_patch, as ISO date strings.
start_date = '2020-05-01'
end_date = '2020-06-01'
patches = download_patch(rect_from_point(coords, rect_width), start_date, end_date)

In [None]:
# Run the freshly trained model over the downloaded patches and display the
# result using the same 0.8 threshold as the metric reports.
pred_stack = make_predictions(patches, model)
visualize_predictions(patches, pred_stack, threshold=0.8)

In [None]:
# Compare to a baseline model
baseline_model = keras.models.load_model('../models/65_mo_tpa_bootstrap_toa-12-20-2020.h5')
pred_stack_baseline = make_predictions(patches, baseline_model)
visualize_predictions(patches, pred_stack_baseline, threshold=0.8)

### Show timeseries predictions

In [None]:
threshold = 0.8
# Side length of the smallest square grid that fits one panel per patch.
num_img = int(np.ceil(np.sqrt(len(patches))))

plt.figure(figsize=(num_img,num_img), dpi=250, facecolor=(1,1,1))
for panel, (patch, patch_pred) in enumerate(zip(patches, pred_stack), start=1):
    # Channels 3, 2, 1 of the patch, in that order, form the displayed image.
    rgb = normalize(patch[:,:,3:0:-1])
    # Paint every pixel predicted above the threshold in a fixed highlight colour.
    detected = patch_pred > threshold
    rgb[detected] = (0.9, 0, 0.1)

    plt.subplot(num_img, num_img, panel)
    plt.imshow(np.clip(rgb, 0, 1))
    plt.axis('off')
plt.show()

## Evaluate Failures

In [None]:
# Class-1 (positive) probability for every test sample; column 1 of the
# two-column softmax output.
test_preds = model.predict(x_test)[:,1]

In [None]:
threshold = 0.8
# Split the test samples into the four confusion-matrix groups using the
# thresholded class-1 probability and column 1 of the one-hot labels.
predicted_positive = test_preds > threshold
predicted_negative = test_preds <= threshold
actual_positive = y_test[:,1] == 1
actual_negative = y_test[:,1] == 0

# x_test carries a trailing length-1 axis. Remove only that axis: a bare
# squeeze() would also collapse the sample axis when a group happens to
# contain exactly one sample, handing a 1-D vector to the plotting code.
tp = x_test[predicted_positive & actual_positive].squeeze(axis=-1)
tn = x_test[predicted_negative & actual_negative].squeeze(axis=-1)
fp = x_test[predicted_positive & actual_negative].squeeze(axis=-1)
fn = x_test[predicted_negative & actual_positive].squeeze(axis=-1)

In [None]:
def plot_pixel_colors(pixels, plot=True):
    """
    Arrange pixel colors into a square image sorted from brightest to darkest.

    Columns 1 to 3 of each pixel vector are used as the color and are reversed
    before display (presumably B, G, R bands flipped into R, G, B order --
    confirm against the band ordering of the pixel vectors). The grid is
    padded with black pixels when the sample count is not a perfect square.

    Args:
        pixels: array-like of shape (num_samples, num_bands), num_bands >= 4.
        plot: when True, also display the image with matplotlib.

    Returns:
        Float array of shape (side, side, 3), side = ceil(sqrt(num_samples)),
        filled row by row in order of decreasing brightness.
    """
    pixels = np.asarray(pixels)
    num_samples = int(np.ceil(np.sqrt(len(pixels))))
    padding_len = num_samples ** 2 - len(pixels)
    # Pad with all-zero rows of the same width as the input so the samples
    # fill a complete square, whatever the number of bands.
    padding = np.zeros((padding_len, pixels.shape[1]))
    padded = np.concatenate((pixels, padding))[:, 1:4]
    # Brightness is the Euclidean norm of the three color columns.
    brightness = np.linalg.norm(padded, axis=1)
    padded = padded[np.argsort(brightness)[::-1]]
    # The builtin float replaces the np.float alias removed in NumPy 1.24.
    colors = np.reshape(padded, (num_samples, num_samples, 3)).astype(float)
    colors = np.flip(colors, 2)
    if plot:
        plt.figure(figsize=(8,8), dpi=150)
        plt.imshow(np.clip(colors, 0, 1))
        plt.xticks([])
        plt.yticks([])
        plt.show()
    return colors

In [None]:
# One panel per confusion-matrix group, laid out on a 2x2 grid, each showing
# that group's pixel colors sorted by brightness.
confusion_groups = {
    'True Positives': tp,
    'True Negatives': tn,
    'False Positives': fp,
    'False Negatives': fn,
}

plt.figure(figsize=(8,8), dpi=150)
for panel, (name, data) in enumerate(confusion_groups.items(), start=1):
    color_array = plot_pixel_colors(data, plot=False)
    plt.subplot(2, 2, panel)
    plt.title(name)
    plt.imshow(np.clip(color_array, 0, 1))
    # Hide both axes' ticks; the grid positions carry no meaning.
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()

In [None]:
# Plot the mean pixel spectra of the different predicted classes
tp_df = pd.DataFrame(tp, columns=band_descriptions.keys()).melt(var_name='band', value_name='value')
fp_df = pd.DataFrame(fp, columns=band_descriptions.keys()).melt(var_name='band', value_name='value')
plt.figure(figsize=(12,4), dpi=150, facecolor=(1,1,1))
plt.subplot(1,2,1)
sns.lineplot(x='band', y='value', data=fp_df, ci="sd", color='r', label='False Positives')
sns.lineplot(x='band', y='value', data=tp_df, ci="sd", label='True Positives')
plt.legend()
plt.title('Mean Value +/- SD')

plt.subplot(1,2,2)
tn_df = pd.DataFrame(tn, columns=band_descriptions.keys()).melt(var_name='band', value_name='value')
fn_df = pd.DataFrame(fn, columns=band_descriptions.keys()).melt(var_name='band', value_name='value')
sns.lineplot(x='band', y='value', data=fn_df, ci="sd", color='r', label='False Negatives')
sns.lineplot(x='band', y='value', data=tn_df, ci="sd", label='True Negatives')
plt.legend()
plt.title('Mean Value +/- SD')
plt.show()

### Inspect Data

In [None]:
# Plot the mean pixel spectra of the extracted dataset
positive_df = pd.DataFrame(positive_vectors, columns=band_descriptions.keys()).melt(var_name='band', value_name='value')
negative_df = pd.DataFrame(negative_vectors, columns=band_descriptions.keys()).melt(var_name='band', value_name='value')
plt.figure(figsize=(6,4), dpi=150, facecolor=(1,1,1))
sns.lineplot(x='band', y='value', data=negative_df, ci="sd", color='r', label='Negative')
sns.lineplot(x='band', y='value', data=positive_df, ci="sd", label='Positive')
plt.legend()
plt.title('Mean Value +/- SD')
plt.show()