# Notebook for model training

Trains a model on pre-downloaded raster patches read from disk. Expected data folder structure:

```
training_patches/
├── train/
│   ├── 0/
│   └── 1/
└── val/
    ├── 0/
    └── 1/
```

In [None]:
from datetime import date, datetime
import glob
import math
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from sklearn.metrics import classification_report, f1_score
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm.notebook import tqdm

# Put the parent of the current working directory on sys.path so that the
# project-local `model_library` module imported below can be resolved
# (presumably it lives one level above the notebook - confirm).
parent_dir = os.path.split(os.getcwd())[0]
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import model_library

# Base directory that the data and checkpoint paths in later cells are joined onto.
WORK_DIR = '..'

%load_ext autoreload
%autoreload 2

In [None]:
# GPU functionality sanity check - should run fast
# Print the GPU devices TensorFlow can see.
print(tf.config.list_physical_devices('GPU'))
# Multiply a random 4096x4096 matrix by itself as a quick smoke test.
x = tf.random.normal([4096, 4096])
y = tf.matmul(x, x)
# Last expression of the cell, so the notebook displays the resulting tensor.
y

### Data generation / augmentation

#### TF / In-RAM version - current preferred

In [None]:

# ----------------------------
# 1️⃣ Define augmentation pipeline (reuse a single instance)
# ----------------------------
def get_satellite_augmentation_pipeline():
    """Build the Keras preprocessing stack used to augment image patches.

    The layers are applied in this order: random horizontal/vertical flip,
    random rotation, random translation, random zoom, random contrast and
    additive Gaussian noise. The rotation, translation and zoom layers use
    ``fill_mode="reflect"``.

    Returns:
        keras.Sequential: a new model wrapping freshly created layers.
    """
    geometric_layers = [
        keras.layers.RandomFlip("horizontal_and_vertical"),
        # factor=1.0 -> full 360 degree rotation range
        keras.layers.RandomRotation(factor=1.0, fill_mode="reflect"),
        keras.layers.RandomTranslation(
            height_factor=0.1,
            width_factor=0.1,
            fill_mode="reflect",
        ),
        keras.layers.RandomZoom(
            height_factor=(-0.1, 0.1),
            width_factor=(-0.1, 0.1),
            fill_mode="reflect",
        ),
    ]
    pixel_value_layers = [
        keras.layers.RandomContrast(factor=0.1),
        keras.layers.GaussianNoise(stddev=0.01),
    ]
    return keras.Sequential(geometric_layers + pixel_value_layers)

aug_pipeline = get_satellite_augmentation_pipeline()

def load_dataset(data_dir, bands_to_use=None):
    """
    Loads all images from '0' and '1' subdirectories into RAM.

    The label of each image is the integer name of its parent directory.
    Files are read in sorted path order (class '0' first, then class '1') so
    that repeated runs return identically ordered arrays; ``glob`` on its own
    returns paths in arbitrary order.

    Args:
        data_dir: directory containing the '0' and '1' subdirectories.
        bands_to_use: optional sequence of band indices (positions along the
            first axis of the array returned by rasterio's ``read()``) to
            keep. ``None`` keeps every band.

    Returns:
        X: np.ndarray of shape (num_samples, H, W, C), float32, with the raw
           pixel values divided by 10000.0
        y: np.ndarray of shape (num_samples,), int32

    Raises:
        FileNotFoundError: if neither subdirectory contains a ``.tif`` file.
    """
    files = [
        path
        for class_dir in ('0', '1')
        for path in sorted(glob.glob(os.path.join(data_dir, class_dir, '*.tif')))
    ]

    if not files:
        raise FileNotFoundError(f"No .tif files found in '0' or '1' subdirectories of {data_dir}")

    imgs, labels = [], []

    for file_path in files:
        # rasterio is imported once at the top of the file; no per-file import needed.
        with rasterio.open(file_path) as src:
            arr = src.read()  # (bands, H, W)

        if bands_to_use is not None:
            arr = arr[bands_to_use, :, :]
        arr = np.moveaxis(arr, 0, -1)  # (H, W, C)
        imgs.append(arr.astype(np.float32) / 10000.0)

        # The parent directory name ('0' or '1') is the class label.
        labels.append(int(os.path.basename(os.path.dirname(file_path))))

    # np.stack requires every image to have the same (H, W, C) shape.
    X = np.stack(imgs, axis=0)
    y = np.array(labels, dtype=np.int32)

    return X, y

# ----------------------------
# 2️⃣ Dataset creation function
# ----------------------------
def make_tf_dataset(X, y, batch_size=8, shuffle=True, augment=True):
    """
    Build a batched, prefetched ``tf.data.Dataset`` from in-memory arrays.

    Args:
        X, y: NumPy arrays, sliced together along their first axis
        batch_size: int, number of samples per batch
        shuffle: whether to shuffle dataset (the shuffle buffer spans all of X)
        augment: whether to apply the module-level ``aug_pipeline`` to each
            sample before batching

    Returns:
        tf.data.Dataset yielding (images, labels) batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((X, y))

    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(X))

    if augment:
        # Apply augmentation on CPU to avoid GPU memory spikes
        def augment_on_cpu(image, label):
            with tf.device("/CPU:0"):
                # training=True must be passed explicitly: inside Dataset.map
                # there is no fit() context to switch the layers into training
                # mode, and GaussianNoise in particular does nothing at inference.
                image_aug = aug_pipeline(image, training=True)
            return image_aug, label

        dataset = dataset.map(augment_on_cpu, num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=8)  # small buffer reduces GPU memory spikes
    return dataset


Usage pattern:

```
X_train, y_train = load_dataset(os.path.join(data_dir, 'train'))
X_val, y_val = load_dataset(os.path.join(data_dir, 'val'))

train_ds = make_tf_dataset(X_train, y_train, batch_size=batch_size, augment=True, shuffle=True)
val_ds = make_tf_dataset(X_val, y_val, batch_size=batch_size, augment=False, shuffle=False)
```

## Train Model

In [None]:
# Shape of a single model input. Presumably (height, width, bands), matching
# the (H, W, C) arrays produced by load_dataset - confirm against model_library.
input_shape = (48, 48, 13)
print("Input Shape:", input_shape)

In [None]:
# Build a fresh ResNet18 from the project-local model_library for the chosen input shape.
model = model_library.ResNet18(input_shape=input_shape)
# Used by the callback cell as the prefix of the checkpoint filename.
model_name = '48px_v1.1_ResNet18_2025-10-14'

model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(1e-3),
    # from_logits=False: the loss treats model outputs as probabilities
    # (presumably the network ends in a sigmoid - confirm in model_library).
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    # Named "acc" so the validation metric the callbacks monitor is "val_acc".
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
    # The reload cell notes this setting is required w/ GPU on Mac.
    run_eagerly=True
)

model.summary()

In [None]:
# Root of the training patches; expected to contain train/{0,1} and val/{0,1}.
data_dir = os.path.join(WORK_DIR, 'data/training_patches2025-09-26T15:38')

# List the .tif patches per split and class to report the class balance.
positive_paths =  glob.glob(f"{data_dir}/train/1/*.tif")
negative_paths = glob.glob(f"{data_dir}/train/0/*.tif")
pos_val_paths = glob.glob(f"{data_dir}/val/1/*.tif")
neg_val_paths = glob.glob(f"{data_dir}/val/0/*.tif")
print(f"{len(positive_paths)} train positives")
print(f"{len(negative_paths)} train negatives")
print(f"{len(pos_val_paths)} val positives")
print(f"{len(neg_val_paths)} val negatives")

In [None]:
# Read both splits fully into memory (all bands, since bands_to_use is not passed).
X_train, y_train = load_dataset(os.path.join(data_dir, 'train'))
X_val, y_val = load_dataset(os.path.join(data_dir, 'val'))

batch_size = 8
# Training data is shuffled and augmented; validation data is left untouched
# and in file order.
train_ds = make_tf_dataset(X_train, y_train, batch_size=batch_size, augment=True, shuffle=True)
val_ds = make_tf_dataset(X_val, y_val, batch_size=batch_size, augment=False, shuffle=False)

In [None]:
# Reload a model for further training
# NOTE: this rebinds both `model` and `model_name` set by the fresh-model cell.

model_name = 'ResNet1820250829_131606'
model = keras.models.load_model(os.path.join(WORK_DIR, f'checkpoints/{model_name}.h5'))

# Re-compile with a smaller learning rate (3e-5 rather than the 1e-3 used for
# a freshly built model).
model.compile(
    optimizer=keras.optimizers.legacy.Adam(3e-5),
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
    run_eagerly=True   # Required w/ GPU on Mac
)

In [None]:
# Timestamp appended to the checkpoint filename so each run writes its own file.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join(WORK_DIR, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

# `model_name` is set by an earlier cell; fall back to a generic prefix when
# that cell has not been run in this session.
try:
    checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}{timestamp}.h5")
except NameError:
    checkpoint_path = os.path.join(checkpoint_dir, f"best_model{timestamp}.h5")

# Save the full model (not just weights) whenever validation accuracy improves.
checkpoint_cb = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor="val_acc",
    save_best_only=True,
    save_weights_only=False,
    mode="max",
    verbose=1,
)

# Stop after 40 epochs without a val_acc improvement and restore the best weights.
earlystop_cb = EarlyStopping(
    monitor="val_acc",
    patience=40,
    mode="max",
    restore_best_weights=True,
    verbose=1,
)

# Multiply the learning rate by 0.33 after 20 epochs in which val_acc has not
# improved by at least 0.005, never going below 1e-6.
reduce_lr_cb = ReduceLROnPlateau(
    monitor="val_acc",
    factor=0.33,
    patience=20,
    min_delta=0.005,
    min_lr=1e-6,
    # Explicit, like the two callbacks above: accuracy is a higher-is-better
    # metric, so this should not depend on "auto" inferring it from the name.
    mode="max",
    verbose=1,
)



In [None]:
# Train for 100 epochs, validating on val_ds after each epoch.
# Only the checkpoint and learning-rate callbacks are active; earlystop_cb is
# currently disabled (commented out of the list below).
model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=100,
        verbose=1,
        callbacks=[checkpoint_cb, reduce_lr_cb]#, earlystop_cb,]
)

In [None]:
# Metadata used to build the filename of the manually saved model.
epoch = 171
resolution = 48
version_number = 'v1.1_ResNet18_2025-10-14'
current_date = date.today()

save_dir = os.path.join(WORK_DIR, "checkpts-tmp")
# `version_number` already starts with "v", so no extra "v" is added here
# (the previous template produced names like "48px_vv1.1...").
model_path = os.path.join(
    save_dir,
    f"{resolution}px_{version_number}ep{epoch}_{current_date.isoformat()}.h5",
)

# Refuse to overwrite an existing model file. An explicit raise is used rather
# than `assert`, which is skipped when Python runs with -O.
if os.path.exists(model_path):
    raise FileExistsError(f"Model {model_path} already exists")

# Saving fails if the target directory is missing, so create it first.
os.makedirs(save_dir, exist_ok=True)
model.save(model_path)
print(f"Saved {model_path}")

# Evaluate Model Performance Characteristics

Find the threshold that maximizes performance on the validation set. Note that while this threshold may be optimal for the validation set, it does not account for the fact that false positives are functionally worse than false negatives.

In [None]:
# Load a previously saved model from checkpts-tmp for evaluation.
# NOTE: rebinds `model` and `model_name`, replacing any model from the training cells.
model_name = '48px_v1.1_ResNet18ep171_2025-10-16'
model = keras.models.load_model(os.path.join(WORK_DIR, f'checkpts-tmp/{model_name}.h5'))


In [None]:
# Print the layer-by-layer summary of the loaded model.
model.summary()

In [None]:
# Batch size used for evaluation (also recorded in the config file written below).
batch_size = 32

# Assuming S2L1C 13 band input imagery
bands_to_use = list(range(13))
# Models whose input has 12 channels get the band at index 10 removed.
if model.input_shape[-1] == 12:
    bands_to_use.remove(10)  # drop B10 for old models

# Reload the validation split with the selected bands; no shuffling or
# augmentation, so predictions line up with y_val.
X_val, y_val = load_dataset(os.path.join(data_dir, 'val'), bands_to_use=bands_to_use)
val_ds = make_tf_dataset(X_val, y_val, batch_size=batch_size, augment=False, shuffle=False)

In [None]:
# Run inference over the validation dataset with ops placed on the CPU.
with tf.device("/CPU:0"):
    preds = model.predict(val_ds, verbose=1)


In [None]:
# Inspect the raw prediction array shape before it is squeezed.
preds.shape

In [None]:
# New models
preds = preds.squeeze()
preds.shape

In [None]:
# For the old ensemble
#print(preds.shape)
#preds = np.mean(preds, axis=1)
#preds.shape

In [None]:
def acc_curve(preds, y_true, thresholds=np.arange(.01, 1.01, .01)):
    """Plot accuracy as a function of the decision threshold.

    A prediction counts as positive when ``preds >= threshold``. The curve is
    drawn with the ``plt`` state-machine interface and the title reports the
    first threshold that reaches the highest accuracy.

    Args:
        preds: predicted scores.
        y_true: ground-truth labels compared elementwise with the
            thresholded predictions.
        thresholds: thresholds to evaluate.
    """
    n_samples = len(y_true)
    score = []
    for cutoff in thresholds:
        n_correct = np.sum((preds >= cutoff).astype('int') == y_true)
        score.append(n_correct / n_samples)

    plt.plot(thresholds, score)
    plt.ylabel('Success Rate')
    plt.xlabel('Threshold')
    best = np.argmax(score)
    plt.title(f"Optimal Threshold: {thresholds[best]:.2f} w/ accuracy {score[best]:.2f}")

acc_curve(preds, y_val)


In [None]:
def f1_curve(preds, y_true, thresholds=np.arange(.01, 1.01, .01)):
    """Plot the F1 score as a function of the decision threshold.

    A prediction counts as positive when ``preds >= threshold``. The title
    reports the first threshold that reaches the highest F1 score.

    Args:
        preds: predicted scores.
        y_true: ground-truth labels passed to ``f1_score``.
        thresholds: thresholds to evaluate.

    Returns:
        (fig, ax): the matplotlib figure and axes holding the curve.
    """
    f1s = [f1_score(y_true, (preds >= cutoff)) for cutoff in thresholds]

    fig, ax = plt.subplots()
    ax.plot(thresholds, f1s, label='Patchwise')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('F1 score')
    ax.legend(loc='lower left')
    best = np.argmax(f1s)
    plt.title(f"Optimal Threshold: {thresholds[best]:.2f} w/ F1 {f1s[best]:.2f}")
    return fig, ax

f1_curve(preds, y_val)

In [None]:
# Per-class precision/recall/F1 at a fixed threshold, shown as a DataFrame.
# NOTE(review): this uses a strict `>` comparison, while acc_curve/f1_curve
# use `>=` - confirm which is intended.
threshold = 0.9
report = classification_report(y_val, preds > threshold, target_names=['No Mine', 'Mine'], output_dict=True)
df = pd.DataFrame(report).transpose()
df


In [None]:
# Write a small text file next to the saved model recording the training
# dataset, the batch size and the classification report at `threshold`.
threshold = 0.99
target_names = ['No Mine', 'Mine']
training_dataset = 'collected_locations2025-09-26T15:38.geojson'

model_path = os.path.join(WORK_DIR, f'checkpts-tmp/{model_name}.h5')
# splitext strips only the final extension; str.split('.h5') would cut the
# path at the first ".h5" occurring anywhere in it.
config_path = os.path.splitext(model_path)[0] + f"_config-t{threshold}.txt"

# Build the report before opening the file, so a failure here cannot leave a
# truncated config file behind.
report_text = classification_report(y_val, preds > threshold, target_names=target_names)

with open(config_path, 'w', encoding='utf-8') as f:
    f.write(f'Training dataset: {training_dataset}')
    f.write(f"\nBatch Size: {batch_size}")
    f.write(f'\n\nClassification Report at {threshold}\n')
    f.write(report_text)
   

Plot images that the model classifies incorrectly. This can be useful for evaluating model bias.

In [None]:
# Show every validation patch whose thresholded prediction disagrees with its label.
threshold = 0.99
test_model = model
# Use the validation arrays loaded earlier in this notebook. The previous
# names `x_test` / `y_test` are not defined anywhere in it and raised NameError.
val_images = X_val
val_labels = y_val
test_labels = val_labels
test_preds = test_model.predict(val_images)
for label, pred, img in zip(test_labels, test_preds, val_images):
    pred = pred[0]  # each prediction row holds a single score
    binary_pred = 0 if pred < threshold else 1
    if label == binary_pred:
        continue

    # Channels 3, 2, 1 in that order (presumably red/green/blue for the
    # Sentinel-2 band layout - confirm). `* 10000` undoes the scaling applied
    # in load_dataset; dividing by 3000 then stretches values for display.
    rgb = (img[:, :, 3:0:-1] * 10000 / 3000)
    plt.figure(figsize=(2, 2), facecolor=(1, 1, 1), dpi=150)
    plt.imshow(np.clip(rgb, 0, 1))
    plt.title(f"label: {label} - pred: {pred:.2f}")
    plt.axis('off')
    plt.show()