# Notebook for model training

### Training Flow

Load Training Data:
- Load set of input data and label files
- Filter out heavily masked patches (strict mask limit for positives, looser limit for negatives)
- Create train/test set
- Define augmentation parameters

Train Model:
- Define model architecture
- Compile model
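- Train and evaluate an ensemble of models, one per random seed
- Save each model together with a config file describing its inputs and parameters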

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import date
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm.notebook import tqdm

parent_dir = os.path.split(os.getcwd())[0]
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from scripts import dl_utils
from scripts import viz_tools

## Open Data

In [None]:
resolution = 48

# Base dataset names; each is expanded below into
# '<name>_<start>_<end>_patch_arrays.pkl' for every (start, end) date window.
data_file_names = [
    'bolivar_2020_thresh_0.8_1_negatives',
    'amazonas_2020_thresh_0.5_2_negatives',
    'riverbank_negatives',
    'amazonas_2020_thresh_0.8_sumbsample_3_positives',
    'MinesPos2018-2020Sentinel_points',
    # Excluded from this run: 'bolivar_2020_thresh_0.8_sumbsample_5_positives'
    'full_amazon_v9_negatives',
    'v2.0_bolivar_negatives',
    'v2.1.1_bolivar_negatives',
    'v2.4_amazonas_negatives',
    'v2.4_amazon_negatives',
    'v2.4_amazon_positives',
    'v2.6_amazon_thresh_0.8_negatives',
    'v2.6_amazon_negatives',
    'v2.6_amazon_negatives_v2',
]

# Alternative date windows:
# start_dates = ['2020-01-01', '2021-01-01', '2022-01-01']
# end_dates = ['2021-01-01', '2022-01-01', '2023-01-01']
start_dates = ['2019-01-01']
end_dates = ['2020-01-01']
dates = [f'{sd}_{ed}' for sd, ed in zip(start_dates, end_dates)]

data_files = [
    f'{fn}_{sd}_{ed}_patch_arrays.pkl'
    for fn in data_file_names
    for sd, ed in zip(start_dates, end_dates)
]

# 2023 datasets whose names do not follow the '<name>_<start>_<end>' pattern above.
data_files += [
    'v3.1_2023_negatives_2023-01-01_2024-01-01_patch_arrays.pkl',
    'amazon_all_48px_v3.1_2023_positives_0.999_2023-01-01_2024-01-01_patch_arrays.pkl',
    'v3.2_2023_negatives_2023-01-01_2024-01-01_patch_arrays.pkl',
    'v3.3_2023_positives_2023-01-01_2024-01-01_patch_arrays.pkl',
    'v3.4_2023_negatives_2023-01-01_2024-01-01_patch_arrays.pkl',
    'v3.5_2023_negatives_2023-01-01_2024-01-01_patch_arrays.pkl',
    'v3.6_2023_positives_2023-01-01_2024-01-01_patch_arrays.pkl',
]


def _label_file_for(data_file):
    """Map '<stem>_patch_arrays.pkl' to its companion '<stem>_patch_array_labels.pkl'."""
    suffix = 's.pkl'
    if not data_file.endswith(suffix):
        raise ValueError(f"Unexpected patch array file name: {data_file}")
    return data_file[:-len(suffix)] + '_labels.pkl'


# Derived rather than hand-maintained so the two lists cannot drift apart.
label_files = [_label_file_for(f) for f in data_files]

data_dir = os.path.join('..', 'data', 'training_data', f"{resolution}_px")

patches = []
label_arrays = []
for data_file, label_file in tqdm(zip(data_files, label_files), total=len(data_files)):
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted local files.
    with open(os.path.join(data_dir, data_file), 'rb') as f:
        file_patches = pickle.load(f)
    with open(os.path.join(data_dir, label_file), 'rb') as f:
        file_labels = pickle.load(f)
    # A per-file length mismatch would silently misalign patches and labels.
    if len(file_patches) != len(file_labels):
        raise ValueError(
            f"{data_file} has {len(file_patches)} patches but "
            f"{label_file} has {len(file_labels)} labels"
        )
    patches.extend(file_patches)
    # float dtype matches the original incremental concatenation onto an empty list.
    label_arrays.append(np.asarray(file_labels, dtype=float))

patches = np.array(patches)
# Concatenate once instead of once per file (which is quadratic).
labels = np.concatenate(label_arrays) if label_arrays else np.array([])
positive_patches = patches[labels == 1]
negative_patches = patches[labels == 0]

print(len(patches), "samples loaded")
print(np.sum(labels == 1), "positive samples")
print(np.sum(labels == 0), "negative samples")

In [None]:
# Preview 64 randomly chosen loaded patches (drawn with replacement) with their labels.
num_samples = 64
sample_idx = np.random.randint(0, len(patches), num_samples)
viz_tools.plot_image_grid(patches[sample_idx], labels=list(map(int, labels[sample_idx])))

In [None]:
def filter_black(data, mask_limit=0.1, return_rejects=False, brightness_limit=10):
    """Drop patches whose fraction of dark (likely masked) pixels is too high.

    A pixel counts as dark when its mean over the last (channel) axis is below
    ``brightness_limit``. A patch is kept when its dark-pixel fraction is
    strictly below ``mask_limit``.

    Args:
        data: array of patches, indexed along axis 0, with channels on the
            last axis of each patch.
        mask_limit: maximum allowed dark-pixel fraction (exclusive).
        return_rejects: if True, also return the rejected patches.
        brightness_limit: per-pixel channel mean below which a pixel is dark.

    Returns:
        The kept patches, or ``(kept, rejected)`` when ``return_rejects`` is True.
    """
    data = np.asarray(data)
    # Mean of a boolean mask == (number of dark pixels) / (pixels per patch).
    masked_fraction = np.array(
        [np.mean(np.mean(patch, axis=-1) < brightness_limit) for patch in data],
        dtype=float,
    )
    keep = masked_fraction < mask_limit
    filtered_data = data[keep]
    if len(data) > 0:
        print(f"{len(filtered_data) / len(data) :.1%} of data below brightness limit")
    else:
        # Avoid ZeroDivisionError on empty input.
        print("No data to filter")
    if return_rejects:
        return filtered_data, data[~keep]
    return filtered_data

In [None]:
# Filter positive patches that are masked beyond a threshold. We don't want to give
# positive examples of cloud-masked patches.
filtered_positives, positive_rejects = filter_black(positive_patches, mask_limit=0.1, return_rejects=True)
if len(positive_rejects) > 0:
    # Show at most 16**2 rejects (sampled with replacement).
    num_samples = min(16 ** 2, len(positive_rejects))
    indices = np.random.randint(0, len(positive_rejects), num_samples)
    # Channels 3, 2, 1 (presumably the RGB bands); /3000 scales values for display.
    viz_tools.plot_numpy_grid(positive_rejects[indices, :, :, 3:0:-1] / 3000)
    plt.title("Positive Masked Rejects")
    plt.show()

if len(filtered_positives) > 0:
    num_samples = 25 ** 2
    indices = np.random.randint(0, len(filtered_positives), num_samples)
    viz_tools.plot_numpy_grid(filtered_positives[indices, :, :, 3:0:-1] / 3000)
    plt.title('Positive Filtered Samples')
    plt.show()
else:
    # randint(0, 0, ...) would raise; report instead of crashing the cell.
    print("No positive patches passed the mask filter")

In [None]:
# Negatives use a looser mask limit than positives: partially masked patches are
# still acceptable examples of "no mine".
filtered_negatives, negative_rejects = filter_black(negative_patches, mask_limit=0.6, return_rejects=True)
if len(negative_rejects) > 0:
    # Show at most 16**2 rejects (sampled with replacement). The original clamped this
    # count using positive_rejects, a copy-paste bug.
    num_samples = min(16 ** 2, len(negative_rejects))
    indices = np.random.randint(0, len(negative_rejects), num_samples)
    # Channels 3, 2, 1 (presumably the RGB bands); /3000 scales values for display.
    viz_tools.plot_numpy_grid(negative_rejects[indices, :, :, 3:0:-1] / 3000)
    plt.title("Negative Mask Rejects")
    plt.show()

if len(filtered_negatives) > 0:
    num_samples = 25 ** 2
    indices = np.random.randint(0, len(filtered_negatives), num_samples)
    viz_tools.plot_numpy_grid(filtered_negatives[indices, :, :, 3:0:-1] / 3000)
    plt.title('Negative Filtered Samples')
    plt.show()
else:
    # randint(0, 0, ...) would raise; report instead of crashing the cell.
    print("No negative patches passed the mask filter")

## Prepare Data for Training

In [None]:
# Stack negatives (target 0) ahead of positives (target 1), then shuffle
# patches and targets together with a fixed seed.
# NOTE(review): an RGB+IR variant would select channels [1, 2, 3, 8] here instead.
neg_targets = np.zeros(len(filtered_negatives))
pos_targets = np.ones(len(filtered_positives))
x = np.concatenate([filtered_negatives, filtered_positives])
y = np.concatenate([neg_targets, pos_targets])
x, y = shuffle(x, y, random_state=33)

## Data Augmentation
Augmentation helps the model generalize. These parameters have generally worked well, but there is certainly room for tuning.

Plot one batch of augmented images as an example.

In [None]:
# Random transforms applied on the fly by Keras' ImageDataGenerator.
augmentation_parameters = dict(
    featurewise_center=False,
    rotation_range=360,
    # NOTE(review): per the Keras docs a list means "pick one element at random";
    # since the max is >= 1 these are pixel offsets (~1 px), not width fractions.
    # Confirm this is the intended shift.
    width_shift_range=[0.9, 1.1],
    height_shift_range=[0.9, 1.1],
    shear_range=10,
    zoom_range=[0.9, 1.1],
    vertical_flip=True,
    horizontal_flip=True,
    # Fill options: "constant", "nearest", "reflect" or "wrap"
    fill_mode='reflect',
)

datagen = ImageDataGenerator(**augmentation_parameters)

# Preview a single augmented batch of 64 drawn from the (unnormalized) pool.
plt.figure(figsize=(12, 12), facecolor=(1, 1, 1), dpi=150)
preview_batches = datagen.flow(x, y, batch_size=64)
aug_img, aug_labels = next(preview_batches)
viz_tools.plot_image_grid(aug_img, labels=list(map(int, aug_labels)), norm=True);


In [None]:
# Divide raw values by 10000 (presumably reflectance scaling — confirm) and clip into [0, 1].
x_norm = np.clip(x.astype("float32") / 10000, 0, 1)

## Train Ensemble

In [None]:

def _conv_stage(pool_size):
    """Return three 32-filter 3x3 'same'-padded ReLU convolutions followed by max pooling."""
    return [
        layers.Conv2D(32, kernel_size=3, padding='same', activation="relu"),
        layers.Conv2D(32, kernel_size=3, padding='same', activation="relu"),
        layers.Conv2D(32, kernel_size=3, padding='same', activation="relu"),
        layers.MaxPooling2D(pool_size=pool_size),
    ]


def _build_model(input_shape):
    """Return the compiled single-sigmoid-output CNN used for every ensemble member."""
    model = keras.Sequential([
        keras.Input(shape=input_shape),
        *_conv_stage(pool_size=2),
        *_conv_stage(pool_size=2),
        *_conv_stage(pool_size=3),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(64, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(32, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(3e-4),
        loss=keras.losses.BinaryCrossentropy(from_logits=False),
        metrics=[keras.metrics.BinaryAccuracy(name="acc")],
    )
    return model


def _print_split_summary(x_train, x_test, y_train, y_test):
    """Print the value range and class balance of a train/test split."""
    print(f"Min Value:\t\t {np.min(x_train):.3f}")
    print(f"Max Value:\t\t {np.max(x_train):.3f}")
    print(f"Median Value:\t\t {np.median(x_train):.3f}")
    print("Num Train:\t\t", len(x_train))
    print("Num Test:\t\t", len(x_test))
    print(f"Percent Negative Train:\t {100 * sum(y_train == 0.0) / len(y_train):.1f}")
    print(f"Percent Negative Test:\t {100 * sum(y_test == 0.0) / len(y_test):.1f}")


def _plot_accuracy(train_accuracy, test_accuracy, y_train):
    """Plot per-epoch train/validation accuracy against a majority-class baseline."""
    epoch_axis = range(1, len(train_accuracy) + 1)
    plt.figure(figsize=(8, 5), dpi=100, facecolor=(1, 1, 1))
    plt.plot(epoch_axis, train_accuracy, label='Train Acc')
    plt.plot(epoch_axis, test_accuracy, c='r', label='Val Acc')
    # Accuracy of always predicting the majority class of the training split.
    percent_negative = float(np.mean(y_train == 0.0))
    baseline = max(percent_negative, 1 - percent_negative)
    plt.plot([1, len(train_accuracy)], [baseline, baseline], '--', c='gray', label='Baseline')
    plt.grid()
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Network Train and Val Accuracy')
    plt.show()


def _write_config(path, *, input_files, aug_params, batch_size, epochs_trained, seed,
                  threshold, report):
    """Write a plain-text record of the inputs and settings used to train one model."""
    with open(path, 'w') as f:
        f.write('Input Data:\n')
        f.writelines(f'\t{file}\n' for file in input_files)
        f.write('\n\nAugmentation Parameters:\n')
        f.writelines(f"\t{k}: {v}\n" for k, v in aug_params.items())
        f.write(f"\nBatch Size: {batch_size}")
        f.write(f"\nTraining Epochs: {epochs_trained}")
        # Ensemble members differ only by seed/split, so record it for reproducibility.
        f.write(f"\nRandom Seed: {seed}")
        f.write(f'\n\nClassification Report at {threshold}\n')
        f.write(report)


ensemble_seeds = [1, 2, 3, 4, 5, 6, 7, 8, 9]
ensemble_names = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
batch_size = 32
epochs = 160
# Per-class loss weights; currently equal, i.e. the loss is unweighted.
class_weight = {0: 1, 1: 1}
threshold = 0.8
models_dir = os.path.join('..', 'models')
os.makedirs(models_dir, exist_ok=True)

for seed, name in zip(ensemble_seeds, ensemble_names):
    keras.utils.set_random_seed(seed)
    x_train, x_test, y_train, y_test = train_test_split(x_norm, y, test_size=0.30, random_state=seed)
    _print_split_summary(x_train, x_test, y_train, y_test)

    input_shape = x_train.shape[1:]
    print("Input Shape:", input_shape)
    model = _build_model(input_shape)
    model.summary()

    # The batch size must be set on the generator. Keras ignores fit()'s batch_size
    # and shuffle arguments for generator inputs, so the value passed there before
    # never took effect. flow() shuffles by default.
    history = model.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
        verbose=1,
        class_weight=class_weight,
    )
    train_accuracy = history.history['acc']
    test_accuracy = history.history['val_acc']
    _plot_accuracy(train_accuracy, test_accuracy, y_train)

    print("Test Set Metrics:")
    preds = model.predict(x_test)
    # Computed once and reused for the config file, instead of predicting twice.
    report = classification_report(y_test, preds > threshold, target_names=['No Mine', 'Mine'])
    print(report)

    version_number = f'3.7-{name}'
    current_date = date.today()
    model_name = f"{resolution}px_v{version_number}_{current_date.isoformat()}"

    # NOTE: an existing model/config with the same name (same version and date) is overwritten.
    _write_config(
        os.path.join(models_dir, f'{model_name}_config.txt'),
        input_files=data_files,
        aug_params=augmentation_parameters,
        batch_size=batch_size,
        epochs_trained=len(train_accuracy),
        seed=seed,
        threshold=threshold,
        report=report,
    )
    model.save(os.path.join(models_dir, f'{model_name}.h5'))

# Evaluate Model Performance Characteristics

Find the decision threshold that maximizes accuracy on the test set. This evaluates only the last model trained in the ensemble loop, on that model's own held-out split. This threshold may be optimal on the test set, but it does not account for false positives being functionally worse than false negatives.

In [None]:
# Sweep decision thresholds for the most recently trained ensemble member.
test_model = model

val_images = x_test
val_labels = y_test

test_preds = test_model.predict(val_images)
# Labels are scalar 0/1 values (built from np.zeros/np.ones), not one-hot vectors.
# The previous np.argmax(label) always returned 0, so every score was computed
# against all-negative labels.
test_labels = np.asarray(val_labels).astype(int)
positive_scores = np.asarray(test_preds)[:, 0]

# Thresholds 0.02, 0.04, ..., 0.98; score is the fraction of correct predictions.
thresh = [t / 100 for t in range(2, 100, 2)]
score = [float(np.mean((positive_scores > t) == test_labels)) for t in thresh]

optimal_threshold = thresh[int(np.argmax(score))]
plt.plot(thresh, score)
plt.ylabel('Success Rate')
plt.xlabel('Threshold')
plt.title(f"Optimal Threshold: {optimal_threshold:.2f}")
plt.show()

Plot the test images that the model misclassifies. This can be useful for spotting model bias.

In [None]:
threshold = 0.8
test_model = model
val_images = x_test
val_labels = y_test
test_labels = val_labels
test_preds = test_model.predict(val_images)

# Scores at or above the threshold count as a positive (mine) prediction.
scores = test_preds[:, 0]
binary_preds = np.where(scores < threshold, 0, 1)
wrong = np.flatnonzero(test_labels != binary_preds)
for i in wrong:
    label, pred, img = test_labels[i], scores[i], val_images[i]
    # Undo the /10000 normalization, then apply the /3000 display scaling used in
    # the earlier previews (channels 3, 2, 1).
    rgb = img[:, :, 3:0:-1] * 10000 / 3000
    plt.figure(figsize=(2, 2), facecolor=(1, 1, 1), dpi=150)
    plt.imshow(np.clip(rgb, 0, 1))
    plt.title(f"label: {label} - pred: {pred:.2f}")
    plt.axis('off')
    plt.show()