## Embeddings 

Compute SSL4EO foundation-model embeddings for the training patches, then train a small MLP probe on top of them.


In [None]:
from datetime import date, datetime
import glob
import joblib
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from sklearn.metrics import classification_report, f1_score
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import legacy as legacy_optimizers
import torch
import torch.nn.functional as F
from tqdm import tqdm

import gee
import model_library
import tile_utils

# Local SSL4EO directory; the pretrained checkpoint is read from its `pretrained/` subfolder.
SSL4EO_PATH = 'SSL4EO'

# Root of the training patches. The cells below read `train/` and `val/` splits from it,
# each with `0/` and `1/` class subfolders of .tif files.
data_dir = '../data/training_patches2025-10-21T23:00'

In [None]:
def load_dataset(data_dir, bands_to_use=None):
    """
    Load every .tif image under the '0' and '1' subdirectories of `data_dir` into RAM.

    The label of each image is the name of its parent directory (0 or 1).
    Files are read in sorted order within each class (class 0 first, then
    class 1) so repeated runs yield identically ordered arrays.

    Args:
        data_dir: Directory containing the '0' and '1' class subdirectories.
        bands_to_use: Optional index applied to the band axis to keep only a
            subset of bands. None keeps every band.

    Returns:
        X: np.ndarray of shape (num_samples, H, W, C), float32, pixel values
            divided by 10000. All images must share one shape (np.stack).
        y: np.ndarray of shape (num_samples,), int32 class labels.

    Raises:
        FileNotFoundError: If neither subdirectory contains a .tif file.
    """
    # glob returns paths in arbitrary filesystem order; sort for reproducibility.
    files_class_0 = sorted(glob.glob(os.path.join(data_dir, '0', '*.tif')))
    files_class_1 = sorted(glob.glob(os.path.join(data_dir, '1', '*.tif')))
    files = files_class_0 + files_class_1

    if not files:
        raise FileNotFoundError(f"No .tif files found in '0' or '1' subdirectories of {data_dir}")

    imgs, labels = [], []

    for file_path in files:
        # rasterio is imported at the top of the file; no per-iteration import needed.
        with rasterio.open(file_path) as src:
            arr = src.read()  # (bands, H, W)

        # The pixel data is already in memory, so the file handle can be released.
        if bands_to_use is not None:
            arr = arr[bands_to_use, :, :]
        arr = np.moveaxis(arr, 0, -1)  # (H, W, C)
        imgs.append(arr.astype(np.float32) / 10000.0)

        # The parent directory name ('0' or '1') is the class label.
        labels.append(int(os.path.basename(os.path.dirname(file_path))))

    X = np.stack(imgs, axis=0)
    y = np.array(labels, dtype=np.int32)

    return X, y

In [None]:
# Input data

# List the patch files per split and class, only to report how many there are.
positive_paths =  glob.glob(f"{data_dir}/train/1/*.tif")
negative_paths = glob.glob(f"{data_dir}/train/0/*.tif")
pos_val_paths = glob.glob(f"{data_dir}/val/1/*.tif")
neg_val_paths = glob.glob(f"{data_dir}/val/0/*.tif")
print(f"{len(positive_paths)} train positives")
print(f"{len(negative_paths)} train negatives")
print(f"{len(pos_val_paths)} val positives")
print(f"{len(neg_val_paths)} val negatives")

# Read both splits fully into memory (load_dataset globs the same folders again).
X_train, y_train = load_dataset(os.path.join(data_dir, 'train'))
X_val, y_val = load_dataset(os.path.join(data_dir, 'val'))
print(f'Training data shape: {X_train.shape}')

### Embedding inference

In [None]:
# Side length, in pixels, of the square chips the embedding model is fed.
model_chip_size = 224

# Use the Apple-silicon GPU backend (MPS) when PyTorch reports it, otherwise the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print(f'Device: {device}')

In [None]:
# Model - SSL4EO

embed_model_name = 'dino_vit_small_patch16_224.pt'

# SECURITY: weights_only=False unpickles arbitrary Python objects from the file,
# so only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved on another accelerator deserialize
# on this machine; the model is moved to `device` in the next cell.
embed_model = torch.load(
    os.path.join(SSL4EO_PATH, f'pretrained/{embed_model_name}'),
    map_location='cpu',
    weights_only=False,
)

In [None]:
# Move the embedding model to the selected device.
embed_model.to(device)
# Switch to evaluation mode; as the last expression of the cell its return value is displayed.
embed_model.eval()

In [None]:
# Side length of the input patches, in pixels (assumed from the name; compare with model_chip_size).
geo_chip_size = 48
# Number of patches sent through the embedding model per forward pass.
batch_size = 4 
# Embedding width, read from the model's final normalization layer.
output_dim = embed_model.norm.normalized_shape[0]
# One column name per embedding dimension.
# NOTE(review): the prefix says "patch14" while the loaded checkpoint is named patch16;
# presumably kept so columns match previously saved parquet files - confirm before renaming.
feature_columns = [f"vit-dino-patch14_{i}" for i in range(output_dim)] 

def quantize(embeddings, lower_bound=-5, upper_bound=5):
    """Map embedding values onto the uint8 range [0, 255].

    Values are clipped to [lower_bound, upper_bound] and rescaled linearly so
    that lower_bound maps to 0 and upper_bound to 255. The final cast to uint8
    truncates the fractional part rather than rounding.
    """
    span = upper_bound - lower_bound
    unit_interval = (np.clip(embeddings, lower_bound, upper_bound) - lower_bound) / span
    return (unit_interval * 255).astype(np.uint8)

# Notebook-level switch for uint8 quantization of embeddings (see quantize above).
quantized = False


In [None]:
def embed(X, y, model, batch_size=4, geo_chip_size=48, model_chip_size=224, quantized=False):
    """Run `model` over `X` in batches and return the embeddings as a DataFrame.

    Relies on the module-level `device` (where each batch is sent) and
    `feature_columns` (the column names of the result).

    Args:
        X: NumPy array of images, channels-last (it is permuted with
            (0, 3, 1, 2) before being passed to the model).
        y: Labels stored in the 'label' column of the result.
        model: Callable returning a tensor, or a dict of tensors of which the
            first entry is used.
        batch_size: Number of images per forward pass.
        geo_chip_size: Declared input size. Batches are resized only when it
            differs from `model_chip_size`.
        model_chip_size: Target height and width for the bicubic resize.
        quantized: If True, embeddings are passed through `quantize`.

    Returns:
        pandas.DataFrame with one row per image: the `feature_columns`
        followed by 'label'.
    """
    inputs = torch.from_numpy(X)
    needs_resize = geo_chip_size != model_chip_size
    target_size = (model_chip_size, model_chip_size)

    collected = []
    for start in tqdm(range(0, len(inputs), batch_size)):
        # Channels-last to channels-first.
        chunk = inputs[start : start + batch_size].permute(0, 3, 1, 2)

        if needs_resize:
            chunk = F.interpolate(chunk, size=target_size,
                                  mode='bicubic', align_corners=False)

        chunk = chunk.to(device)

        with torch.no_grad():
            out = model(chunk)
            if isinstance(out, dict):
                # Dict outputs: use the first entry.
                out = out[list(out.keys())[0]]
            collected.append(out.cpu())
            # Drop the device-side references before the next batch.
            del chunk, out

    embeddings = torch.cat(collected).numpy()
    if quantized:
        embeddings = quantize(embeddings)

    features_df = pd.DataFrame(embeddings, columns=feature_columns)
    features_df['label'] = y

    return features_df

In [None]:
# Embed the validation split. The notebook-level inference settings are passed
# explicitly so that editing batch_size / geo_chip_size / model_chip_size /
# quantized in the cells above actually takes effect.
df_val = embed(
    X_val,
    y_val,
    embed_model,
    batch_size=batch_size,
    geo_chip_size=geo_chip_size,
    model_chip_size=model_chip_size,
    quantized=quantized,
)
df_val['split'] = 'val'
df_val

In [None]:
# Embed the training split. The notebook-level inference settings are passed
# explicitly so that editing them above takes effect (e.g. lowering batch_size
# to 1 if the device runs short of memory).
df_train = embed(
    X_train,
    y_train,
    embed_model,
    batch_size=batch_size,
    geo_chip_size=geo_chip_size,
    model_chip_size=model_chip_size,
    quantized=quantized,
)
df_train['split'] = 'train'

In [None]:
# Stack both splits into one table; rows stay distinguishable through the 'split' column.
# The original per-split row indices are kept, so index labels repeat across splits.
df = pd.concat([df_val, df_train])
df

In [None]:
# Count rows per (label, split) combination.
df[['label', 'split']].value_counts()

In [None]:
# Optional. If there were previous training embeddings, load and concat them here: 
# The path is hard-coded to one earlier run; edit it (or skip this cell) as needed.
prev = pd.read_parquet('../data/training_patches2025-10-21T13:25ssl4eo.parquet')
print(f"Prev len: {len(prev)}")
# Previous rows go first, followed by the embeddings computed in this session.
df = pd.concat([prev, df])
print(f"New len: {len(df)}")

In [None]:
# data_dir has no trailing separator, so the file is written next to the patch
# directory as '<data_dir>ssl4eo.parquet', not inside it.
df.to_parquet(data_dir + 'ssl4eo.parquet', index=False)

### Or restore embeddings

In [None]:
# Reload the embeddings table written by the save cell above (same '<data_dir>ssl4eo.parquet' path).
df = pd.read_parquet(data_dir + 'ssl4eo.parquet')
df

In [None]:
# Partition the embeddings table by the value of its 'split' column.
df_train = df.loc[df['split'] == 'train']
df_val = df.loc[df['split'] == 'val']

### MLP training

In [None]:
# Name used as the prefix of the checkpoint file written during training.
model_name = '48px_v1.3SSL4EO-MLP'

In [None]:
def make_tf_dataset(X, y, batch_size=8, shuffle=True):
    """
    Wrap feature/label arrays in a batched, prefetching tf.data.Dataset.

    X, y: NumPy arrays, sliced along their first axis into (x, y) pairs
    batch_size: number of pairs per batch
    shuffle: if True, shuffle with a buffer covering all of X
    """
    ds = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        # Buffer as large as the data.
        ds = ds.shuffle(buffer_size=len(X))
    # small prefetch buffer reduces GPU memory spikes
    return ds.batch(batch_size).prefetch(buffer_size=8)

# A timestamp suffix keeps checkpoints from separate runs from overwriting each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = "../checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Fall back to a generic file name when the model_name cell has not been run.
try: 
    checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}{timestamp}.h5")
except NameError: 
    checkpoint_path = os.path.join(checkpoint_dir, f"best_model{timestamp}.h5")

# Consider adding something dynamic like this to checkpoint_path: "/model_{epoch:02d}_{val_acc:.4f}.h5" 
# Save the full model whenever validation accuracy improves.
checkpoint_cb = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor="val_acc",
    save_best_only=True,
    save_weights_only=False,
    mode="max",
    verbose=1
)

# Stop after 20 epochs without improvement and restore the best weights.
earlystop_cb = EarlyStopping(
    monitor="val_acc",
    patience=20,
    mode="max",
    restore_best_weights=True,
    verbose=1
)

# Scale the learning rate by 0.33 after 10 epochs without a gain of at least 0.005.
# mode="max" is explicit, matching the other callbacks, so the direction does
# not depend on Keras inferring it from the metric name.
reduce_lr_cb = ReduceLROnPlateau(
    monitor="val_acc",
    factor=0.33,
    patience=10,
    min_delta=0.005,
    min_lr=1e-6,
    mode="max",
    verbose=1
)

In [None]:
# Build the tf.data pipelines from the embedding columns; batch size is make_tf_dataset's default (8).
# Only the training set is shuffled, so validation predictions stay in df_val row order.
train_ds = make_tf_dataset(df_train[feature_columns].values, df_train['label'].values, shuffle=True)
val_ds = make_tf_dataset(df_val[feature_columns].values, df_val['label'].values, shuffle=False)

In [None]:
# Size the input layer from the embedding columns actually fed to the model
# instead of a hard-coded 384, so it follows the embedding model's width.
mlp = model_library.MLP(input_dim=len(feature_columns), hidden_layers=(64,16))
mlp.compile(
    optimizer=legacy_optimizers.Adam(learning_rate=3e-4),
    # from_logits=False: the loss is configured for probability outputs.
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False), 
    # Named "acc" so the callbacks can monitor "val_acc".
    metrics=[tf.keras.metrics.BinaryAccuracy(name="acc")],
    run_eagerly=True
)

In [None]:
# Or reload to continue training
# NOTE(review): this rebinds model_name; a checkpoint_path already built by the
# callbacks cell keeps the earlier name - confirm that is intended.
model_name = '48px_v1.1.1SSL4EO-MLP20251021_172507'
mlp = tf.keras.models.load_model(f'../checkpoints/{model_name}.h5')

# Recompile with a lower learning rate (1e-4) than the fresh-model cell (3e-4).
mlp.compile(
    optimizer=legacy_optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False), 
    metrics=[tf.keras.metrics.BinaryAccuracy(name="acc")],
    run_eagerly=True
)

In [None]:
# Train for up to 50 epochs with checkpointing and learning-rate reduction.
# Early stopping is defined above but currently left out of the callback list.
mlp.fit(
    train_ds,
    validation_data=val_ds,
    epochs=50, 
    verbose=1,
    callbacks=[checkpoint_cb, reduce_lr_cb]#, earlystop_cb]
)

In [None]:
# Manually snapshot the current model under a name encoding resolution,
# version, epoch and today's date.
epoch = 31
resolution = 48
version_number = '1.2SSL4EO-MLP'
current_date = date.today()
model_name = f'{resolution}px_v{version_number}ep{epoch}_{current_date.isoformat()}'
model_path = f"../checkpts-tmp/{model_name}.h5"

# Refuse to overwrite an existing snapshot. An explicit raise is used because
# assert statements are skipped when Python runs with -O.
if os.path.exists(model_path):
    raise FileExistsError(f"Model {model_path} already exists")

# The target directory is not created anywhere else in this notebook.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

mlp.save(model_path)
print(f"Saved {model_path}")

#### Eval

In [None]:
# Reload a saved model 
# Reads from ../models/, unlike the training cells that write to ../checkpoints/ or ../checkpts-tmp/.
model_name = '48px_v0.X_SSL4EO-MLPensemble_2025-10-21'
mlp = tf.keras.models.load_model(f'../models/{model_name}.h5')

In [None]:
# Print the layer-by-layer summary of the model under evaluation.
mlp.summary()

In [None]:
# Run validation inference pinned to the CPU.
with tf.device("/CPU:0"):
    preds = mlp.predict(val_ds, verbose=1)
preds.shape

In [None]:
# Drop all length-1 axes from the prediction array.
preds = preds.squeeze()
preds.shape

In [None]:
# For an ensemble
# Average over axis 1 (presumably the ensemble-member axis - confirm against
# the model's output shape). The ndim guard lets single-model predictions,
# already 1-D after squeeze(), pass through unchanged instead of raising.
if preds.ndim > 1:
    preds = preds.mean(axis=1)

In [None]:
def acc_curve(preds, y_true, thresholds=np.arange(.01, 1.01, .01)):
    """Plot accuracy against decision threshold and title the plot with the best threshold.

    A sample counts as positive when its prediction is >= the threshold.
    Draws on the current matplotlib axes and returns nothing.
    """
    n_total = len(y_true)
    score = []
    for t in thresholds:
        hits = (preds >= t).astype('int') == y_true
        score.append(np.sum(hits) / n_total)

    # Index of the first threshold reaching the maximum accuracy.
    best = np.argmax(score)

    plt.plot(thresholds, score)
    plt.ylabel('Success Rate')
    plt.xlabel('Threshold')
    plt.title(f"Optimal Threshold: {thresholds[best]:.2f} w/ accuracy {score[best]:.2f}")

acc_curve(preds, df_val['label'].values)

In [None]:
def f1_curve(preds, y_true, thresholds=np.arange(.01, 1.01, .01)):
    """Plot F1 score against decision threshold on a new figure.

    A sample counts as positive when its prediction is >= the threshold.
    Returns the (fig, ax) pair of the created plot.
    """
    f1s = [f1_score(y_true, (preds >= t)) for t in thresholds]

    # Index of the first threshold reaching the maximum F1.
    best = np.argmax(f1s)

    fig, ax = plt.subplots()
    ax.plot(thresholds, f1s, label='Patchwise')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('F1 score')
    ax.legend(loc='lower left')
    plt.title(f"Optimal Threshold: {thresholds[best]:.2f} w/ F1 {f1s[best]:.2f}")
    return fig, ax

f1_curve(preds, df_val['label'].values)

In [None]:
# Per-class precision/recall/F1 on the validation set at a 0.925 threshold.
# Uses a strict '>' comparison, whereas the curve functions above use '>='.
threshold = 0.925
report = classification_report(df_val['label'].values, preds > threshold, target_names=['No Mine', 'Mine'], output_dict=True)
# Transposed so that each class/average is a row and each metric a column.
report = pd.DataFrame(report).transpose()
report


In [None]:
# Same report as the previous cell at a stricter 0.99 threshold.
threshold = 0.99
report = classification_report(df_val['label'].values, preds > threshold, target_names=['No Mine', 'Mine'], output_dict=True)
report = pd.DataFrame(report).transpose()
report

In [None]:
# Write a small text file next to the saved model recording the training
# dataset, a batch size and the validation classification report.
threshold = 0.925
target_names = ['No Mine', 'Mine']
training_dataset = 'collected_locations2025-10-21T23:00.geojson'

# NOTE(review): uses whichever model_name is currently bound; if the eval
# "reload" cell ran, that name points at ../models/, not ../checkpts-tmp/ - confirm.
model_path = f'../checkpts-tmp/{model_name}.h5'
with open(model_path.split('.h5')[0] + f"_config-t{threshold}.txt", 'w') as f:
    f.write(f'Training dataset: {training_dataset}')
    # NOTE(review): batch_size is the embedding-inference batch size set earlier (4),
    # not the MLP training batch size (make_tf_dataset defaults to 8) - confirm which is meant.
    f.write(f"\nBatch Size: {batch_size}")
    f.write(f'\n\nClassification Report at {threshold}\n')
    f.write(classification_report(df_val['label'].values, preds > threshold, target_names=target_names))

#### Sklearn version

In [None]:
from sklearn.neural_network import MLPClassifier
# Same hidden-layer sizes as the Keras MLP above.
layer_sizes = (64, 16)
# Rebinds `mlp`, replacing the Keras model with a scikit-learn classifier.
mlp = MLPClassifier(hidden_layer_sizes=layer_sizes, n_iter_no_change=40, max_iter=1000, verbose=True)

In [None]:
# Fit the scikit-learn classifier on the training embeddings.
mlp.fit(df_train[feature_columns], df_train['label'])

In [None]:
# Probability of the second class (column 1 of predict_proba) for every validation row.
preds = mlp.predict_proba(df.loc[df.split == 'val', feature_columns])[:, 1]

In [None]:
# Timestamp truncated to minutes, e.g. 2025-10-21T23:00.
now = datetime.today().isoformat()[:16]
model_path = f'../checkpts-tmp/SSL4EO-MLP{"-".join([str(s) for s in layer_sizes])}_{now}.joblib'

# The target directory is not created anywhere else in this notebook.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Persist the fitted scikit-learn classifier. It is bound to `mlp`; the name
# `model` is never defined in this notebook.
joblib.dump(mlp, model_path)
# Report success only after the dump has completed.
print(f'Model saved to: {model_path}')