# Train Patch Classifier
This notebook loads inputs created in `create_patch_dataset` and uses them to train a spatial patch classifier.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import pickle

import cv2
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm

from scripts.get_s2_data_ee import get_history, get_history_polygon, get_pixel_vectors
from scripts.viz_tools import stretch_histogram, create_img_stack, normalize, band_descriptions

## Load Data
Data is generated in the `create_spatial_patches` notebook

In [None]:
# Directory that holds the pickled patch arrays and their label lists.
train_data_dir = '../data/training_data/patches/'

# data_files[i] and label_files[i] must describe the same set of patches,
# so the two lists have to be kept in the same order.
data_files = ['positive_patches_12_months_2019-01-01_45px_patches.pkl',
              'negative_patches_6_months_2019-03-01_45px_patches.pkl',
              'w_nusa_tenggara_v1.1_positives_patches_12_months_2020-01-01_45px_patches.pkl',
              'w_nusa_tenggara_v1.1_negatives_patches_12_months_2020-01-01_45px_patches.pkl'
             ]

label_files = ['positive_patches_12_months_2019-01-01_45px_patch_labels.pkl',
               'negative_patches_6_months_2019-03-01_45px_patch_labels.pkl',
               'w_nusa_tenggara_v1.1_positives_patches_12_months_2020-01-01_45px_patch_labels.pkl',
               'w_nusa_tenggara_v1.1_negatives_patches_12_months_2020-01-01_45px_patch_labels.pkl'
              ]

# zip() would silently drop trailing files if the lists had different lengths.
if len(data_files) != len(label_files):
    raise ValueError(
        f"Expected one label file per data file, got {len(data_files)} data files "
        f"and {len(label_files)} label files"
    )

patches = []
labels = []
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
for data_file, label_file in zip(data_files, label_files):
    with open(os.path.join(train_data_dir, data_file), 'rb') as f:
        file_patches = pickle.load(f)
    with open(os.path.join(train_data_dir, label_file), 'rb') as f:
        file_labels = pickle.load(f)
    # A count mismatch inside one pair would shift every later label onto the
    # wrong patch, so fail loudly instead of training on misaligned data.
    if len(file_patches) != len(file_labels):
        raise ValueError(
            f"{data_file} holds {len(file_patches)} patches but "
            f"{label_file} holds {len(file_labels)} labels"
        )
    patches += file_patches
    labels += file_labels
patches = np.array(patches)
labels = np.array(labels)

# Split by class only for reporting; training below uses `patches` and `labels`.
positive_patches = patches[labels == 1]
negative_patches = patches[labels == 0]

print("Loaded", len(positive_patches), "positive patches and", len(negative_patches), "negative patches")

## Prepare training dataset

In [None]:
# Normalize the raw patches and shuffle them together with their labels.
# The fixed random_state keeps the ordering reproducible between runs.
x, y = shuffle(normalize(patches), labels, random_state=42)

In [None]:
# Hold out 10% of the samples; the fixed seed keeps the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.10, random_state=42)

# Class balance is computed on the integer labels, before they are one-hot encoded below.
pct_negative_train = 100 * np.count_nonzero(y_train == 0.0) / len(y_train)
pct_negative_test = 100 * np.count_nonzero(y_test == 0.0) / len(y_test)

print("Num Train:\t\t", len(x_train))
print("Num Test:\t\t", len(x_test))
print(f"Percent Negative Train:\t {pct_negative_train:.1f}")
print(f"Percent Negative Test:\t {pct_negative_test:.1f}")

# Convert the integer labels to one-hot vectors for the softmax output layer.
num_classes = 2
y_train, y_test = (keras.utils.to_categorical(split, num_classes) for split in (y_train, y_test))

In [None]:
# Visualize data with no augmentation
datagen = ImageDataGenerator(
    rotation_range=0
)

plt.figure(figsize=(12,12), facecolor=(1,1,1))
# Use the builtin next() rather than the iterator's .next() method, which is not
# available on newer Keras iterators. The batch is bound to its own names so the
# notebook-level `labels` array from the load cell is not overwritten.
batch_images, batch_labels = next(datagen.flow(x_train, y_train, batch_size=36))
for index, (image, label) in enumerate(zip(batch_images, batch_labels)):

    # Channels 3, 2, 1 are stacked as R, G, B for display
    # (presumably the Sentinel-2 red/green/blue bands -- TODO confirm band order).
    rgb = np.stack((image[:,:,3],
                    image[:,:,2],
                    image[:,:,1]), axis=-1)
    rgb_stretch = stretch_histogram(rgb, max_val = 1)
    # 6 x 6 grid matches the batch size of 36.
    plt.subplot(6,6,index+1)
    plt.imshow(np.clip(rgb_stretch, 0, 1))
    # Labels are one-hot; index 1 is the positive ("Waste") class.
    if label[1] == 1:
        plt.title('Waste')
    else:
        plt.title('No Waste')
    plt.axis('off')
plt.suptitle('Dataset Examples - No Augmentation', size=16)
plt.tight_layout()
plt.show()

In [None]:
# Augmentation pipeline; this `datagen` is the one used for training further down.
# NOTE(review): unlike zoom_range, a list passed to width_shift_range /
# height_shift_range is not a [lower, upper] interval. Per the Keras docs the
# shift is drawn from the listed values, so these appear to give ~1 px shifts.
# Confirm this is the intended amount of translation.
datagen = ImageDataGenerator(
    rotation_range=360,
    width_shift_range=[0.8, 1.2],
    height_shift_range=[0.8, 1.2],
    #shear_range=10,
    zoom_range=[0.8, 1.5],
    vertical_flip=True,
    horizontal_flip=True,
    fill_mode='reflect'
)


plt.figure(figsize=(12,12), facecolor=(1,1,1))
# Use the builtin next() rather than the iterator's .next() method, which is not
# available on newer Keras iterators. The batch is bound to its own names so the
# notebook-level `labels` array from the load cell is not overwritten.
batch_images, batch_labels = next(datagen.flow(x_train, y_train, batch_size=36))
for index, (image, label) in enumerate(zip(batch_images, batch_labels)):

    # Channels 3, 2, 1 are stacked as R, G, B for display
    # (presumably the Sentinel-2 red/green/blue bands -- TODO confirm band order).
    rgb = np.stack((image[:,:,3],
                    image[:,:,2],
                    image[:,:,1]), axis=-1)
    rgb_stretch = stretch_histogram(rgb, max_val = .8)
    # 6 x 6 grid matches the batch size of 36.
    plt.subplot(6,6,index+1)
    plt.imshow(np.clip(rgb_stretch, 0, 1))
    # Labels are one-hot; index 1 is the positive ("Waste") class.
    if label[1] == 1:
        plt.title('Waste')
    else:
        plt.title('No Waste')
    plt.axis('off')
plt.suptitle('Data Augmentation Examples', size=16)
plt.tight_layout()
plt.show()
    

## Train Network

In [None]:
# Shape of a single training sample; used as the network's input shape.
input_shape = x_train[0].shape
print("Input Shape:", input_shape)

In [None]:
# Small CNN: three Conv(3x3) + MaxPool(2x2) stages with 16/32/64 filters,
# followed by two Dense(32) + Dropout(0.1) blocks and a softmax classifier.
network_layers = [keras.Input(shape=input_shape)]

for n_filters in (16, 32, 64):
    network_layers.append(layers.Conv2D(n_filters, kernel_size=3, activation="relu"))
    network_layers.append(layers.MaxPooling2D(pool_size=2))

network_layers.append(layers.Flatten())

for _ in range(2):
    network_layers.append(layers.Dense(32, activation="relu"))
    network_layers.append(layers.Dropout(0.1))

# One output unit per class; softmax turns the outputs into class probabilities.
network_layers.append(layers.Dense(num_classes, activation="softmax"))

model = keras.Sequential(network_layers)
model.summary()

In [None]:
# The targets are one-hot vectors and the output layer is a softmax over
# `num_classes` units, so categorical cross-entropy is the matching loss.
# For two classes it equals the previous binary cross-entropy, and it stays
# correct if `num_classes` is ever increased.
model.compile(loss="categorical_crossentropy",
              optimizer="adam",
              metrics=["accuracy"])

# Accuracy curves accumulated across successive fit() runs.
# Re-running this cell resets them.
train_accuracy = []
test_accuracy = []

In [None]:
# Training hyperparameters.
batch_size = 32
epochs = 200

# Training batches are produced on the fly by `datagen`; the held-out
# x_test / y_test arrays are passed unmodified as validation data.
train_batches = datagen.flow(x_train, y_train, batch_size=batch_size)
model.fit(
    train_batches,
    epochs=epochs,
    validation_data=(x_test, y_test),
    verbose=1,
)

In [None]:
# Append the most recent fit() run to the accumulated curves.
# NOTE: run this cell exactly once after each fit(); re-running it without
# training again appends the same history a second time.
train_accuracy += model.history.history['accuracy']
test_accuracy += model.history.history['val_accuracy']

plt.figure(figsize=(8,5), dpi=100, facecolor=(1,1,1))
plt.plot(train_accuracy, label='Train Acc')
plt.plot(test_accuracy, c='r', label='Val Acc')

# y_train is one-hot here; column 1 is the positive class, so a 0 there marks a negative sample.
percent_negative = np.mean(y_train[:, 1] == 0.0)
# A trivial classifier that always predicts the majority class scores this accuracy,
# whichever class happens to be the majority.
baseline_accuracy = max(percent_negative, 1 - percent_negative)
# Span the whole accumulated curve, which can be longer than `epochs`
# when training has been resumed.
plt.plot([0, len(train_accuracy)], [baseline_accuracy, baseline_accuracy], '--', c='gray', label='Baseline')

plt.grid()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Network Train and Val Accuracy')
plt.show()

In [None]:
# The file name encodes the model version tag and the patch width in pixels.
model_name = 'v1.1.0_200_4-23-21'
model_path = f'../models/{model_name}_patch_classifier_{patches.shape[1]}px_patch.h5'

# Saving fails if the target directory is missing (e.g. on a fresh checkout),
# so create it first.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

print('Saving model to', model_path)
model.save(model_path)