# Train Pixel Classifier

In [None]:
# IPython autoreload: re-import edited modules (such as the local `scripts` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils import shuffle
from tensorflow import keras
from tensorflow.keras import layers
from tqdm import tqdm

import sys
sys.path.append('../')
from scripts.get_s2_data_ee import get_history, get_pixel_vectors, band_descriptions
from scripts.viz_tools import stretch_histogram, normalize

# Seed NumPy's global RNG. NOTE(review): this does not seed TensorFlow, so Keras weight
# initialisation is not made reproducible by this call (sklearn calls below pass random_state explicitly).
np.random.seed(1)

## Create a Training Dataset
Outputs will be: `x_train`, `y_train`, `x_test`, `y_test`, `x_holdout`, `y_holdout`. The holdout data contains only positive samples.

In [None]:
train_data_dir = '../data/training_data/pixel_vectors/'

# Pixel-vector pickles and their label pickles, paired by position (file i of
# `data_files` is labelled by file i of `label_files`). The commented-out entries
# are additional datasets that can be switched back on.
data_files = [
    'tpa_train_raw_60_months_2016-01-01_pixel_vectors.pkl',
    'city_points_30_raw_24_months_2016-01-01_pixel_vectors.pkl',
    'adjacent_north_0.015_raw_24_months_2016-01-01_pixel_vectors.pkl',
    'bali_bootstrap_raw_24_months_2019-01-01_pixel_vectors.pkl',
    # 'java_v12_negatives_raw_24_months_2016-01-01_pixel_vectors.pkl',
    # 'java_v12_positives_raw_24_months_2016-01-01_pixel_vectors.pkl',
    # 'w_nusa_tenggara_v1.1_negatives_raw_12_months_2020-01-01_pixel_vectors.pkl',
    # 'w_nusa_tenggara_v1.1_positives_raw_24_months_2019-01-01_pixel_vectors.pkl',
]

label_files = [
    'tpa_train_raw_60_months_2016-01-01_pixel_vector_labels.pkl',
    'city_points_30_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
    'adjacent_north_0.015_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
    'bali_bootstrap_raw_24_months_2019-01-01_pixel_vector_labels.pkl',
    # 'java_v12_negatives_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
    # 'java_v12_positives_raw_24_months_2016-01-01_pixel_vector_labels.pkl',
    # 'w_nusa_tenggara_v1.1_negatives_raw_12_months_2020-01-01_pixel_vector_labels.pkl',
    # 'w_nusa_tenggara_v1.1_positives_raw_24_months_2019-01-01_pixel_vector_labels.pkl',
]


def _unpickle_training_file(filename):
    """Return the object pickled in `train_data_dir/filename` (trusted local files only)."""
    with open(os.path.join(train_data_dir, filename), 'rb') as handle:
        return pickle.load(handle)


pixel_vectors, labels = [], []
for data_file, label_file in zip(data_files, label_files):
    pixel_vectors += _unpickle_training_file(data_file)
    labels += _unpickle_training_file(label_file)
pixel_vectors = np.array(pixel_vectors)
labels = np.array(labels)

# Split by class: label 1 marks positive pixels, label 0 negative ones.
positive_vectors = pixel_vectors[labels == 1]
negative_vectors = pixel_vectors[labels == 0]

print("Loaded", len(positive_vectors), "positive pixel vectors and", len(negative_vectors), "negative pixel vectors")

In [None]:
holdout_data_files = ['tpa_holdout_raw_60_months_2016-01-01_pixel_vectors.pkl']
holdout_label_files = ['tpa_holdout_raw_60_months_2016-01-01_pixel_vector_labels.pkl']

# Accumulate holdout pixel vectors and labels from each paired (data, label) pickle,
# read from the same directory as the training data.
holdout_pixel_vectors, holdout_labels = [], []
for data_name, label_name in zip(holdout_data_files, holdout_label_files):
    with open(os.path.join(train_data_dir, data_name), 'rb') as data_handle, \
            open(os.path.join(train_data_dir, label_name), 'rb') as label_handle:
        holdout_pixel_vectors.extend(pickle.load(data_handle))
        holdout_labels.extend(pickle.load(label_handle))
holdout_pixel_vectors = np.array(holdout_pixel_vectors)
holdout_labels = np.array(holdout_labels)

### Filter positive samples such that NDVI is within a range
This is useful because the positive patches can include surrounding vegetation; keeping only pixels whose NDVI falls in a low range (0–0.4 by default) helps remove it.

In [None]:
def compute_ndvi(pixel_vectors):
    """Return the per-pixel NDVI of a 2-D array of pixel vectors.

    NDVI = (NIR - red) / (NIR + red), taking column 7 as NIR and column 3 as red.
    Pixels where NIR + red == 0 yield nan/inf; the corresponding numpy
    divide/invalid warnings are suppressed.
    """
    pixel_vectors = np.asarray(pixel_vectors)
    nir = pixel_vectors[:, 7]
    red = pixel_vectors[:, 3]
    with np.errstate(divide='ignore', invalid='ignore'):
        return (nir - red) / (nir + red)

def filter_ndvi(data, lower_bound=0, upper_bound=0.4):
    """Keep only the pixel vectors whose NDVI lies strictly between the bounds.

    Args:
        data: array-like of shape (num_samples, num_bands); lists are accepted.
        lower_bound: exclusive lower NDVI bound.
        upper_bound: exclusive upper NDVI bound.

    Returns:
        numpy array with the rows of `data` inside (lower_bound, upper_bound).
        Rows with undefined NDVI (nan) are dropped. Prints the retained fraction.
    """
    # Convert first: compute_ndvi indexes with [:, k], which a plain list does not support.
    data = np.asarray(data)
    if len(data) == 0:
        # Avoid a 0/0 "nan%" report on empty input.
        print("0 samples provided; nothing to filter by NDVI")
        return data
    ndvi = compute_ndvi(data)
    index = (ndvi > lower_bound) & (ndvi < upper_bound)
    print(f"{np.mean(index):.1%} of samples within NDVI range")
    return data[index]

In [None]:
# Apply the NDVI filter to both positive sets: the training positives and the all-positive holdout.
filtered_positive_vectors, filtered_holdout_vectors = map(
    filter_ndvi, (positive_vectors, holdout_pixel_vectors)
)

### Combine data and create train test split
Also add a trailing channel dimension to each sample, since `Conv1D` expects inputs of shape `(steps, channels)`.

In [None]:
# Stack positives (label 1) above negatives (label 0), then shuffle samples and labels together.
num_positive = len(filtered_positive_vectors)
num_negative = len(negative_vectors)
x = np.concatenate((filtered_positive_vectors, negative_vectors))
y = np.concatenate((np.ones(num_positive), np.zeros(num_negative)))

x, y = shuffle(x, y, random_state=42)
x = normalize(x)
x_holdout = normalize(filtered_holdout_vectors)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.20, random_state=42)

for caption, split in (("Num Train Samples:\t\t", x_train),
                       ("Num Test Samples:\t\t", x_test),
                       ("Num Holdout Samples:\t\t", x_holdout)):
    print(caption, len(split))
for caption, split_labels in (("Train", y_train), ("Test", y_test)):
    percent = 100 * np.count_nonzero(split_labels == 0.0) / len(split_labels)
    print(f"Percent Negative {caption}:\t {percent:.1f}")

# Add a trailing channel axis so each sample has shape (bands, 1) for Conv1D.
x_train, x_test, x_holdout = (np.expand_dims(arr, -1) for arr in (x_train, x_test, x_holdout))

# Labels are one-hot with two columns; column 1 is the positive (TPA) class.
# Two output classes are kept out of the author's long-standing habit from an old Theano issue.
num_classes = 2
y_train, y_test = (keras.utils.to_categorical(split_labels, num_classes)
                   for split_labels in (y_train, y_test))
y_holdout = keras.utils.to_categorical(np.ones(len(filtered_holdout_vectors)), num_classes)

## Create and Train a Model

In [None]:
# Two Conv1D layers over the band axis, then three Dense(64) layers and a softmax output.
# Variants left out of this architecture: MaxPooling, extra Dense(64) layers, Dropout(0.2).
input_shape = np.shape(x_train[0])
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv1D(16, kernel_size=3, activation="relu"),
        layers.Conv1D(32, kernel_size=3, activation="relu"),
        layers.Flatten(),
    ]
    + [layers.Dense(64, activation="relu") for _ in range(3)]
    + [layers.Dense(num_classes, activation="softmax")]
)
model.summary()

### Optional Class Weighting
In experiments, weighting the classes seemed to degrade performance, so the weights below are computed but not passed to `model.fit`. This could use further investigation.

In [None]:
from sklearn.utils import class_weight

# 'balanced' weights are n_samples / (n_classes * class_count). Column 1 of the
# one-hot labels holds each training sample's 0/1 class.
train_class_labels = y_train[:, 1]
class_weights = class_weight.compute_class_weight(
    'balanced', classes=np.unique(y_train), y=train_class_labels
)
negative_weight, positive_weight = class_weights
for class_name, weight in (("Negative", negative_weight), ("Positive", positive_weight)):
    print(f"{class_name} Weight: {weight:.2f}")

In [None]:
# Compile the model. Several of these metrics are extraneous, but they can be useful
# to monitor during training.
# The metric names were previously swapped (Recall was logged as 'precision' and vice versa).
# Labels are one-hot with two columns, so `class_id=1` restricts precision/recall to
# the positive (TPA) column instead of pooling both columns together.
# Optional extras: loss_weights=..., weighted_metrics=['accuracy'].
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[
        keras.metrics.Recall(thresholds=0.7, class_id=1, name='recall'),
        keras.metrics.Precision(thresholds=0.7, class_id=1, name='precision'),
        keras.metrics.AUC(curve='PR', name='auc'),
        "accuracy",
    ],
)

### Train the Model

In [None]:
batch_size = 128
epochs = 15

# The test split doubles as validation data. To try class weighting, add
# class_weight={0: negative_weight, 1: positive_weight} to the arguments below.
fit_arguments = dict(
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_test, y_test),
)
model.fit(x_train, y_train, **fit_arguments)

In [None]:
history = model.history.history
train_accuracy = history['accuracy']
test_accuracy = history['val_accuracy']

# Fraction of negative training samples, i.e. the accuracy of always predicting "No TPA".
# Column 1 of the one-hot labels is the positive class.
percent_negative = np.count_nonzero(y_train[:, 1] == 0.0) / len(y_train)

plt.figure(figsize=(8, 5), dpi=100, facecolor=(1, 1, 1))
plt.plot(train_accuracy, label='Train Acc')
plt.plot(test_accuracy, c='r', label='Val Acc')
plt.plot([0, epochs - 1], [percent_negative] * 2, '--', c='gray', label='Baseline')
plt.grid()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Network Train and Val Accuracy')
plt.show()

In [None]:
threshold = 0.8

# (heading, features, one-hot labels) for each evaluation set.
evaluation_sets = [
    ("Test Set Metrics:", x_test, y_test),
    ("\nHoldout Positive Set Metrics:", x_holdout, y_holdout),
]
for heading, features, one_hot_labels in evaluation_sets:
    print(heading)
    predicted_positive = model.predict(features)[:, 1] > threshold
    # labels=[0, 1] keeps both report rows even when only one class is present (the holdout
    # set is all-positive); otherwise the two target_names would not match a single label.
    print(classification_report(one_hot_labels[:, 1], predicted_positive,
                                labels=[0, 1], target_names=['No TPA', 'TPA']))

In [None]:
model_name = 'model_v1.1.1_4-22-21.h5'
# The .h5 extension makes Keras write the model in HDF5 format.
model.save(os.path.join('../models/', model_name))

## Evaluate Failures

In [None]:
# Positive-class (column 1) score for every test pixel.
test_probabilities = model.predict(x_test)
test_preds = test_probabilities[:, 1]

In [None]:
threshold = 0.8

# Boolean masks over the test set; column 1 of the one-hot labels is the positive class.
predicted_positive = test_preds > threshold
predicted_negative = test_preds <= threshold
actual_positive = y_test[:, 1] == 1
actual_negative = y_test[:, 1] == 0

# Squeeze only the trailing channel axis: a bare .squeeze() would collapse a category
# with exactly one sample to 1-D and break the plotting cells below.
tp = x_test[predicted_positive & actual_positive].squeeze(axis=-1)
tn = x_test[predicted_negative & actual_negative].squeeze(axis=-1)
fp = x_test[predicted_positive & actual_negative].squeeze(axis=-1)
fn = x_test[predicted_negative & actual_positive].squeeze(axis=-1)

In [None]:
def plot_pixel_colors(pixels, plot=True):
    """
    Arrange pixel vectors into a square grid of RGB colours sorted by brightness.

    Args:
        pixels: 2-D array of shape (num_samples, num_bands). Columns 1, 2 and 3 are
            read as blue, green and red (the band order make_predictions uses for
            RGB), so at least 4 bands are required. Previously 12 bands were assumed.
        plot: if True, also display the grid with matplotlib.

    Returns:
        Float array of shape (side, side, 3) in RGB order, where
        side = ceil(sqrt(num_samples)). Pixels fill the grid row by row,
        brightest first (Euclidean norm of the three colour values); leftover cells
        are zero (black).
    """
    pixels = np.asarray(pixels, dtype=float)
    side = int(np.ceil(np.sqrt(len(pixels))))
    padding_len = side ** 2 - len(pixels)
    # Pad with all-zero pixels so the samples fill a complete square.
    padding = np.zeros((padding_len, pixels.shape[1]))
    bgr = np.concatenate((pixels, padding))[:, 1:4]
    brightness = np.linalg.norm(bgr, axis=1)
    bgr = bgr[np.argsort(brightness)[::-1]]
    # `np.float` was removed in NumPy 1.24; the builtin float is the same dtype.
    colors = np.reshape(bgr, (side, side, 3)).astype(float)
    colors = np.flip(colors, 2)  # BGR -> RGB
    if plot:
        plt.figure(figsize=(8,8), dpi=150)
        plt.imshow(np.clip(colors, 0, 1))
        plt.xticks([])
        plt.yticks([])
        plt.show()
    return colors

In [None]:
# One brightness-sorted colour grid per confusion-matrix category, in a 2x2 layout.
categories = {
    'True Positives': tp,
    'True Negatives': tn,
    'False Positives': fp,
    'False Negatives': fn,
}
fig, axes = plt.subplots(2, 2, figsize=(8, 8), dpi=150)
for ax, (title, vectors) in zip(axes.flat, categories.items()):
    grid = plot_pixel_colors(vectors, plot=False)
    ax.set_title(title)
    ax.imshow(np.clip(grid, 0, 1))
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()
plt.show()

In [None]:
from matplotlib.lines import Line2D
num_samples = 200
alpha = 0.05


def _overlay_spectra(groups, max_samples, line_alpha):
    """Overlay up to `max_samples` pixel vectors from each (vectors, colour) group.

    Groups are drawn interleaved, one sample from each per step. A group with
    fewer rows than `max_samples` is simply exhausted early; the original loop
    raised IndexError in that case (common for the fp/fn groups).
    """
    for sample in range(max_samples):
        for vectors, color in groups:
            if sample < len(vectors):
                plt.plot(vectors[sample], c=color, alpha=line_alpha)


def _two_colour_legend(names):
    """Draw a legend with solid 'C0' and 'r' lines labelled by `names`."""
    handles = [Line2D([0], [0], color='C0', lw=2),
               Line2D([0], [0], color='r', lw=2)]
    plt.legend(handles, names, loc='upper left')


plt.figure(figsize=(12,4), dpi=150)

plt.subplot(1,2,1)
_overlay_spectra([(tp, 'C0'), (fp, 'r')], num_samples, alpha)
_two_colour_legend(['True Positives', 'False Positives'])
plt.ylim([0, 2])

plt.subplot(1,2,2)
_overlay_spectra([(tn, 'C0'), (fn, 'r')], num_samples, alpha)
_two_colour_legend(['True Negatives', 'False Negatives'])
plt.ylim([0, 2])
plt.show()

## Visualize Network Predictions

In [None]:
def make_predictions(model, data, site_name, threshold):
    """
    Classify each month of an image history and plot median composites.

    For every month in `data` that yields pixels, the pixel vectors are normalized.
    An RGB image is built from columns 3/2/1, and the model's column-1 (positive
    class) output is reshaped into a prediction image. The per-month stacks are
    median-combined and shown side by side: RGB, raw prediction, and the prediction
    clipped to [threshold, 1].

    Args:
        model: trained Keras classifier with a two-column output.
        data: dict-like image history keyed by month (as returned by `get_history`).
        site_name: label used in the plot titles.
        threshold: lower clip bound for the third panel.

    Returns:
        (rgb_stack, preds_stack, threshold_stack): lists with one entry per month
        for which `get_pixel_vectors` reported a non-zero width. If no month
        produced pixels, all three are empty and nothing is plotted.
    """
    rgb_stack = []
    preds_stack = []
    threshold_stack = []
    print("Making Predictions")
    for month in list(data.keys()):
        test_pixel_vectors, width, height = get_pixel_vectors(data, month)
        if width <= 0:
            continue  # no usable pixels for this month
        test_pixel_vectors = normalize(test_pixel_vectors)
        pixel_array = np.array(test_pixel_vectors)  # convert once, not once per channel

        rgb = np.stack(
            [np.reshape(pixel_array[:, band], (width, height)) for band in (3, 2, 1)],
            axis=-1,
        )
        rgb_stack.append(rgb)

        preds = model.predict(np.expand_dims(test_pixel_vectors, axis=-1))
        preds_img = np.reshape(preds, (width, height, 2))[:, :, 1]
        preds_stack.append(preds_img)

        # Values below `threshold` are raised to it (a clip, not a binary mask).
        threshold_stack.append(np.clip(preds_img, threshold, 1))

    if not rgb_stack:
        # np.median over an empty stack only yields NaNs (with warnings); report instead.
        print(f"No months with pixel data for {site_name}; skipping plots")
        return rgb_stack, preds_stack, threshold_stack

    rgb_median = np.median(rgb_stack, axis=0)
    preds_median = np.median(preds_stack, axis=0)
    threshold_median = np.median(threshold_stack, axis=0)

    plt.figure(dpi=150, facecolor=(1,1,1), figsize=(15,5))

    plt.subplot(1,3,1)
    adjusted_image = stretch_histogram(rgb_median)
    plt.imshow(adjusted_image)
    plt.title(f'{site_name} Median', size=8)
    plt.axis('off')

    plt.subplot(1,3,2)
    plt.imshow(preds_median, vmin=0, vmax=1, cmap='seismic')
    plt.title('Classification Median', size=8)
    plt.axis('off')

    plt.subplot(1,3,3)
    plt.imshow(threshold_median, cmap='gray')
    plt.title(f"Positive Pixels Median: Threshold {threshold}", size=8)
    plt.axis('off')

    title = f"{site_name} - Median Values - Neural Network Classification - Threshold {threshold}"
    plt.suptitle(title, y=1.01)
    plt.tight_layout()
    # To save the figure, define output_dir and uncomment:
    # plt.savefig(os.path.join(output_dir, title + '.png'), bbox_inches='tight')
    plt.show()

    return rgb_stack, preds_stack, threshold_stack

### Download a test patch

In [None]:
# Test site in Java. NOTE(review): rect_width presumably sets the patch size in degrees;
# confirm in get_history.
rect_width = 0.02
coords = [106.99772434682319, -6.355577477446754]  # [longitude, latitude]
num_months = 3
start_date = '2020-03-01'
name = 'Java 01'

site_coords, site_names = [coords], [name]
patch_history = get_history(
    site_coords,
    site_names,
    rect_width,
    num_months=num_months,
    start_date=start_date,
    cloud_mask=True,
)

### Show composite predictions

In [None]:
threshold = 0.8
# Per-month stacks are kept for the timeseries cell below.
rgb_stack, pred_stack, threshold_stack = make_predictions(
    model, data=patch_history, site_name=name, threshold=threshold
)

### Show timeseries predictions

In [None]:
threshold = 0.95

# make_predictions skips months for which get_pixel_vectors reports zero width, so its
# stacks can be shorter than patch_history. In that case rebuild the list of kept months
# so each figure is titled with its own date (plain zip would shift the titles).
dates = list(patch_history.keys())
if len(dates) != len(rgb_stack):
    dates = [month for month in dates if get_pixel_vectors(patch_history, month)[1] > 0]

for image, pred, date in zip(rgb_stack, pred_stack, dates):
    plt.figure(figsize=(9,3), facecolor=(1,1,1), dpi=100)
    plt.subplot(1,3,1)
    plt.imshow(np.clip(stretch_histogram(image), 0, 1))
    plt.axis('off')
    plt.title('RGB')

    plt.subplot(1,3,2)
    # Paint pixels scoring above the threshold dark red on the stretched RGB image.
    combo = stretch_histogram(np.copy(image))
    combo[pred > threshold] = (0.7, 0, 0)
    plt.imshow(combo)
    plt.axis('off')
    plt.title(f'Composite - Thresh {threshold}')

    plt.subplot(1,3,3)
    plt.imshow(pred, cmap='seismic', vmin=0.5, vmax=1)
    plt.title('Prediction')
    plt.axis('off')

    plt.suptitle(date, size=14)
    plt.tight_layout()
    plt.show()