# Notebook for model training

### Training Flow

Load Training Data:
- Load set of input data and label files
- Prefilter only the positive set, to rule out heavily masked patches
- Create train/test set
- Define augmentation parameters

Train Model:
- Define model architecture
- Compile model
- Train and evaluate model
- Save model

In [None]:
# IPython magics: reload imported modules (e.g. the local `scripts` helpers)
# before executing each cell, so edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import date
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm.notebook import tqdm

# Make the directory above the working directory importable so that the
# local `scripts` package (imported below) resolves; presumably the notebook
# lives one level beneath the repository root. Prepending gives it priority.
parent_dir = os.path.dirname(os.getcwd())
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from scripts import dl_utils
from scripts import viz_tools

## Open Data

In [None]:
# Side length (pixels) that each stored patch is trimmed to via dl_utils.trim_patch.
resolution = 44

data_files = ['bolivar_2020_thresh_0.8_1_negatives_2020-01-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              'amazonas_2020_thresh_0.5_2_negatives_2020-01-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              'riverbank_negatives_2019-06-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              'amazonas_2020_thresh_0.8_sumbsample_3_positives_2020-01-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              'MinesPos2018-2020Sentinel_points_2020-01-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              'bolivar_2020_thresh_0.8_sumbsample_5_positives_2020-01-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              'full_amazon_v9_negatives_2020-01-01_2021-02-01_period_3_method_median_patch_arrays.pkl',
              'v2.0_bolivar_negatives_2020-01-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              #'v2.1_bolivar_positives_v2_2020-01-01_2021-07-01_period_6_method_median_patch_arrays.pkl',
              'v2.1.1_bolivar_negatives_2020-01-01_2021-01-01_period_4_method_median_patch_arrays.pkl',
              'v2.4_amazonas_negatives_2020-01-01_2021-02-01_period_4_method_median_patch_arrays.pkl',
              'v2.4_amazon_negatives_2020-01-01_2021-02-01_period_4_method_median_patch_arrays.pkl',
              'v2.4_amazon_positives_2020-01-01_2021-02-01_period_6_method_median_patch_arrays.pkl',
              'random_land_patch_arrays.pkl',
              'v2.6_amazon_thresh_0.8_negatives_2020-01-01_2021-02-01_period_4_method_median_patch_arrays.pkl'
              ]

# Label files share each data file's stem:
# '..._patch_arrays.pkl' -> '..._patch_array_labels.pkl'.
label_files = [f.split('s.pkl')[0] + '_labels.pkl' for f in data_files]

data_dir = os.path.join('..', 'data', 'training_data', 'patch_composites_48px')

patches = []
label_chunks = []
for data_file, label_file in tqdm(zip(data_files, label_files), total=len(data_files)):
    # pickle can execute arbitrary code on load: only point this at trusted,
    # locally generated files.
    with open(os.path.join(data_dir, data_file), 'rb') as f:
        file_patches = [dl_utils.trim_patch(elem, resolution) for elem in pickle.load(f)]
    with open(os.path.join(data_dir, label_file), 'rb') as f:
        file_labels = pickle.load(f)
    # Check alignment per file: a surplus in one file and a deficit in another
    # would otherwise cancel out and silently shift every later label.
    if len(file_patches) != len(file_labels):
        raise ValueError(
            f"{data_file}: {len(file_patches)} patches but {len(file_labels)} labels in {label_file}")
    patches.extend(file_patches)
    label_chunks.append(file_labels)

patches = np.array(patches)
# Concatenate once (instead of once per file); the leading empty float array
# keeps the same float dtype the incremental concatenation produced.
labels = np.concatenate([np.empty(0)] + label_chunks)
positive_patches = patches[labels == 1]
negative_patches = patches[labels == 0]

print(len(patches), "samples loaded")
print(sum(labels == 1), "positive samples")
print(sum(labels == 0), "negative samples")

In [None]:
num_samples = 64
# Draw distinct patches (randint samples with replacement and can repeat a
# patch in the grid) and never ask for more patches than were loaded.
indices = np.random.choice(len(patches), size=min(num_samples, len(patches)), replace=False)
viz_tools.plot_image_grid(patches[indices], labels=[int(label) for label in labels[indices]])

In [None]:
def filter_black(data, mask_limit=0.1, return_rejects=False, brightness_threshold=10):
    """Drop patches whose fraction of dark (masked) pixels reaches ``mask_limit``.

    A pixel counts as dark when its mean over the last (presumably channel)
    axis is below ``brightness_threshold``. A patch's masked fraction is the
    share of its pixels that are dark.

    Args:
        data: Array of patches indexed along the first axis.
        mask_limit: Patches whose masked fraction is strictly below this are kept.
        return_rejects: If True, also return the patches that were dropped.
        brightness_threshold: Mean value under which a pixel counts as dark.
            Defaults to 10.

    Returns:
        ``filtered_data``, or ``(filtered_data, rejected_data)`` when
        ``return_rejects`` is True.
    """
    data = np.asarray(data)
    # Vectorised over all patches; the per-pixel mean is computed once.
    dark_pixels = np.mean(data, axis=-1) < brightness_threshold
    # Average the dark-pixel mask over every axis except the patch axis.
    masked_fraction = np.mean(dark_pixels, axis=tuple(range(1, dark_pixels.ndim)))
    filtered_data = data[masked_fraction < mask_limit]
    if len(data):
        print(f"{len(filtered_data) / len(data):.1%} of data below mask limit")
    else:
        # The percentage is undefined (division by zero) for empty input.
        print("No data to filter")
    if return_rejects:
        rejected_data = data[masked_fraction >= mask_limit]
        return filtered_data, rejected_data
    return filtered_data

In [None]:
# Screen only the positives: a patch labelled as a positive that is mostly
# cloud-masked (dark) would be a misleading training example. The dropped
# patches are kept in `rejects` for inspection.
filtered_positives, rejects = filter_black(
    positive_patches,
    return_rejects=True,
    mask_limit=0.1,
)

## Prepare Data for Training

In [None]:
# Stack negatives ahead of the mask-filtered positives, with matching 0/1 targets.
x = np.concatenate((negative_patches, filtered_positives))
y = np.repeat([0.0, 1.0], [len(negative_patches), len(filtered_positives)])
# Normalise every patch independently with the project's unit_norm helper.
x = np.array(list(map(dl_utils.unit_norm, x)))
x, y = shuffle(x, y, random_state=22)

In [None]:
# Stratify on the label so train and test keep the same mine / no-mine ratio;
# an unstratified random split lets the class balance printed below drift.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.10, random_state=42, stratify=y)
print("Num Train:\t\t", len(x_train))
print("Num Test:\t\t", len(x_test))
print(f"Percent Negative Train:\t {100 * np.mean(y_train == 0.0):.1f}")
print(f"Percent Negative Test:\t {100 * np.mean(y_test == 0.0):.1f}")

# One-hot encode for the two-unit softmax output layer.
num_classes = 2
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

## Data Augmentation
Augmentation is important to help the model train. These are parameters that generally worked, but there is certainly room for improvement.

Plot a batch of augmented training images as examples.

In [None]:
# Keyword arguments for ImageDataGenerator; also written to the model's
# config file when the model is saved.
augmentation_parameters = {
    'featurewise_center': False,
    'rotation_range': 360,
    # NOTE(review): per the Keras ImageDataGenerator docs, a list here is
    # sampled as discrete shift values, and values >= 1 are read as pixels,
    # so this likely shifts by ~1px only. Confirm this is intended (it looks
    # like a zoom-style range).
    'width_shift_range': [0.9, 1.1],
    'height_shift_range': [0.9, 1.1],
    'shear_range': 10,
    'zoom_range': [0.8, 1.2],
    'vertical_flip': True,
    'horizontal_flip': True,
    # Fill options: "constant", "nearest", "reflect" or "wrap"
    'fill_mode': 'reflect'
}

datagen = ImageDataGenerator(**augmentation_parameters)


# Preview one augmented batch of 64 on an 8x8 grid. The batch uses its own
# names so the dataset-level `labels` array loaded earlier is not overwritten.
plt.figure(figsize=(12,12), facecolor=(1,1,1), dpi=150)
aug_images, aug_labels = next(datagen.flow(x_train, y_train, batch_size=64))
for index, (image, label) in enumerate(zip(aug_images, aug_labels)):
    # Channels 3, 2, 1 shown as R, G, B (assumed visible bands — TODO confirm);
    # (+1) / 4 is a display rescale of the normalised values.
    rgb = (image[:,:,3:0:-1] + 1) / 4
    plt.subplot(8, 8, index+1)
    plt.imshow(np.clip(rgb, 0, 1))
    # One-hot target: column 1 is the "mine" class.
    plt.title('Mine' if label[1] == 1 else 'No Mine')
    plt.axis('off')
plt.suptitle('Data Augmentation Examples')
plt.tight_layout()
plt.show()
    

## Create Model

In [None]:
# Drop the leading sample axis: the network's Input layer takes one patch's shape.
input_shape = np.shape(x_train)[1:]
print("Input Shape:", input_shape)

In [None]:
def _dense_block(units=64, rate=0.3):
    """Return a ReLU Dense layer followed by Dropout, in that order."""
    return [layers.Dense(units, activation="relu"), layers.Dropout(rate)]


# Two conv/pool stages, then three 64-unit ReLU layers (each followed by 30%
# dropout) and a softmax over `num_classes`. Layers are created in the same
# order as a single list literal would create them.
feature_extractor = [
    keras.Input(shape=input_shape),
    layers.Conv2D(16, kernel_size=3, padding='same', activation="relu"),
    layers.MaxPooling2D(pool_size=3),
    layers.Conv2D(32, kernel_size=3, padding='same', activation="relu"),
    layers.MaxPooling2D(pool_size=2),
    layers.Flatten(),
]
dense_head = _dense_block() + _dense_block() + _dense_block()
output_layer = layers.Dense(num_classes, activation="softmax")

model = keras.Sequential(feature_extractor + dense_head + [output_layer])
model.summary()

# Two-unit softmax trained against one-hot targets with binary cross-entropy.
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

# Per-epoch accuracy histories, extended by each run of the training cell.
train_accuracy = []
test_accuracy = []

## Train Model

In [None]:
batch_size = 32
epochs = 256
# NOTE(review): the held-out test set doubles as the validation set here.
history = model.fit(
    datagen.flow(x_train, y_train, batch_size=batch_size),
    epochs=epochs,
    validation_data=(x_test, y_test),
    verbose=1,
)
# fit() returns the same History object as model.history. Extending (not
# replacing) lets repeated runs of this cell build one continuous curve.
train_accuracy.extend(history.history['accuracy'])
test_accuracy.extend(history.history['val_accuracy'])

### Evaluate model training and test set performance

In [None]:
plt.figure(figsize=(8,5), dpi=100, facecolor=(1,1,1))
plt.plot(train_accuracy, label='Train Acc')
plt.plot(test_accuracy, c='r', label='Val Acc')
# Baseline = accuracy of always predicting the most frequent training class.
# y_train is one-hot here, so its column means are the class frequencies;
# taking the max keeps the baseline right whichever class is the majority.
baseline_accuracy = np.max(np.mean(y_train, axis=0))
plt.plot([0, len(train_accuracy)], [baseline_accuracy, baseline_accuracy], '--', c='gray', label='Baseline')
plt.grid()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Network Train and Val Accuracy')
plt.show()

In [None]:
# Probability of the "mine" class (column 1) above which a patch is called a mine.
threshold = 0.8
print("Test Set Metrics:")
test_set_predictions = model.predict(x_test)[:, 1] > threshold
print(classification_report(y_test[:, 1], test_set_predictions,
                            target_names=['No Mine', 'Mine']))

### Save Model

In [None]:
version_number = '2.8'
current_date = date.today()
model_name = f"{resolution}px_v{version_number}_{current_date.isoformat()}"

model_dir = os.path.join('..', 'models')
# Create the output directory on first use instead of failing in open().
os.makedirs(model_dir, exist_ok=True)

# NOTE: an existing model with the same name is overwritten. Re-enable this
# check to guard against that:
#assert not os.path.exists(os.path.join(model_dir, model_name + '.h5')), f"Model of name {model_name} already exists"

# Plain-text record of the data, augmentation and training settings used.
with open(os.path.join(model_dir, model_name + '_config.txt'), 'w', encoding='utf-8') as f:
    f.write('Input Data:\n')
    f.writelines(f"\t{data_file}\n" for data_file in data_files)
    f.write('\n\nAugmentation Parameters:\n')
    for key, value in augmentation_parameters.items():
        f.write(f"\t{key}: {value}\n")
    f.write(f"\nBatch Size: {batch_size}")
    f.write(f"\nTraining Epochs: {len(train_accuracy)}")
    f.write(f'\n\nClassification Report at {threshold}\n')
    f.write(classification_report(y_test[:,1], model.predict(x_test)[:,1] > threshold,
                                  target_names=['No Mine', 'Mine']))

model.save(os.path.join(model_dir, f'{model_name}.h5'))

# Evaluate Model Performance Characteristics

Find the threshold that maximizes accuracy on the test set. Note that this threshold is tuned on the same test set that also served as the validation set during training, so the resulting score is optimistic. It also does not account for the fact that false positives are functionally worse than false negatives.

In [None]:
test_model = model

val_images = x_test
val_labels = y_test

# Run inference once; every candidate threshold reuses these probabilities.
mine_probabilities = test_model.predict(val_images)[:, 1]
true_labels = np.argmax(val_labels, axis=1)

# Candidate thresholds 0.02, 0.04, ..., 0.98. The comprehension keeps the loop
# variable out of the notebook's global `threshold`, which later cells use.
thresh = [step / 100 for step in range(2, 100, 2)]
# Score = accuracy: fraction of samples whose `prob > t` call matches the label.
score = [1 - np.count_nonzero(true_labels != (mine_probabilities > t)) / len(true_labels)
         for t in thresh]
# NOTE: x_test was also the validation set during training, so this choice is
# tuned on the evaluation data.
optimal_threshold = thresh[np.argmax(score)]

plt.plot(thresh, score)
plt.ylabel('Success Rate')
plt.xlabel('Threshold')
plt.title(f"Optimal Threshold: {optimal_threshold:.2f}")
plt.show()

Plot the test images that the model classifies incorrectly at the chosen threshold. This can be useful for evaluating model bias.

In [None]:
threshold = optimal_threshold

# Single inference pass (the set was previously predicted twice).
test_labels = np.argmax(val_labels, axis=1)
mine_probabilities = test_model.predict(val_images)[:, 1]
# Use the same `prob > threshold` rule as the threshold sweep, so a
# probability exactly equal to the threshold is called "no mine".
test_preds = mine_probabilities > threshold

for label, pred, binary_pred, img in zip(test_labels, mine_probabilities, test_preds, val_images):
    if label == binary_pred:
        continue
    # Channels 3, 2, 1 shown as R, G, B (assumed visible bands — TODO confirm).
    rgb = (img[:,:,3:0:-1] + 1) / 3
    plt.imshow(np.clip(rgb, 0, 1))
    plt.title(f"label: {label} - pred: {pred:.2f}")
    plt.axis('off')
    plt.show()