<a href="https://colab.research.google.com/github/nickel525/gobi_site_prediction_2026/blob/main/CS_Final_Project_archaeology.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import glob

# Gzip-compressed TFRecord shards of labeled 13x13 chips on the mounted Drive.
# `recursive=True` only affects patterns containing "**", so it is omitted.
SHARD_PATTERN = "/content/drive/MyDrive/gobi_s2_spring_labeled_chips_13x13_shard*.tfrecord.gz"
paths = sorted(glob.glob(SHARD_PATTERN))

# Fail fast: an empty list would otherwise only surface later as an empty dataset.
assert paths, f"No TFRecord shards matched {SHARD_PATTERN!r} - is Drive mounted?"

print("Num shards:", len(paths))
print(paths[:3])



Mounted at /content/drive
Num shards: 8
['/content/drive/MyDrive/gobi_s2_spring_labeled_chips_13x13_shard0.tfrecord.gz', '/content/drive/MyDrive/gobi_s2_spring_labeled_chips_13x13_shard1.tfrecord.gz', '/content/drive/MyDrive/gobi_s2_spring_labeled_chips_13x13_shard2.tfrecord.gz']


In [None]:
import tensorflow as tf

# Chip geometry: every band is stored as one flat PATCH x PATCH plane.
PATCH = 13
PLANE = PATCH * PATCH

# Input channels, in the order they are stacked into each chip.
BANDS = ["blue","green","red","nir","swir1","swir2","ndvi","ndwi","nbr"]

# Parsing schema for tf.io.parse_single_example:
#   * one float32 vector of length PLANE per band,
#   * length-1 float32 vectors for the label and bookkeeping fields,
#   * scalar strings for the cell id and the record's system:index.
feature_spec = {band_name: tf.io.FixedLenFeature([PLANE], tf.float32) for band_name in BANDS}
feature_spec.update(
    {field: tf.io.FixedLenFeature([1], tf.float32)
     for field in ("label", "row", "col", "shard", "cell_order")}
)
feature_spec.update(
    {field: tf.io.FixedLenFeature([], tf.string)
     for field in ("cell_id", "system:index")}
)

def parse_example(x):
    """Decode one serialized record into ``(chip, label, meta)``.

    Args:
        x: scalar string tensor holding one serialized example matching
           ``feature_spec``.

    Returns:
        chip:  float32 tensor of shape (PATCH, PATCH, len(BANDS)); channels
               follow the order of ``BANDS``.
        label: float32 scalar.
        meta:  dict of bookkeeping values kept for debugging / mapping
               predictions back to locations: ``row``, ``col``, ``shard`` and
               ``cell_order`` as int32 scalars, ``cell_id`` and
               ``system_index`` as raw strings.
    """
    ex = tf.io.parse_single_example(x, feature_spec)

    # Each band is a flat (PLANE,) vector. Stacking on the last axis gives
    # (PLANE, n_bands), e.g. (169, 9) for PATCH=13; tf.reshape then fills
    # (PATCH, PATCH, n_bands) in row-major order, e.g. (13, 13, 9).
    stacked = tf.stack([ex[b] for b in BANDS], axis=-1)
    chip = tf.reshape(stacked, [PATCH, PATCH, len(BANDS)])

    # The label is declared as a length-1 float32 vector; squeeze it to a
    # scalar. It is already float32, so no cast is needed.
    label = tf.reshape(ex["label"], [])

    def _scalar_int(key):
        # Integer metadata is stored as float32; the cast truncates to int32.
        return tf.cast(tf.reshape(ex[key], []), tf.int32)

    meta = {
        "row": _scalar_int("row"),
        "col": _scalar_int("col"),
        "shard": _scalar_int("shard"),
        "cell_order": _scalar_int("cell_order"),
        "cell_id": ex["cell_id"],
        "system_index": ex["system:index"],
    }
    return chip, label, meta
# Read all shards in parallel and decode each record on the fly.
ds = (
    tf.data.TFRecordDataset(paths, compression_type="GZIP", num_parallel_reads=tf.data.AUTOTUNE)
    .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
)

# Peek at the first decoded record as a quick schema / value-range check.
chip, y, meta = next(iter(ds.take(1)))
row_idx = int(meta["row"].numpy())
col_idx = int(meta["col"].numpy())
chip_lo = float(tf.reduce_min(chip).numpy())
chip_hi = float(tf.reduce_max(chip).numpy())

print("chip:", chip.shape, chip.dtype)
print("label:", y.numpy())
print("row/col:", row_idx, col_idx)
print("bands:", BANDS)
print("min/max per chip:", chip_lo, chip_hi)


chip: (13, 13, 9) <dtype: 'float32'>
label: 0.0
row/col: 1 541
bands: ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'ndvi', 'ndwi', 'nbr']
min/max per chip: -0.276074081659317 0.37489134073257446


In [None]:
import numpy as np
import tensorflow as tf

def filter_by_shards(ds, shard_set):
    """Keep only the elements whose ``meta["shard"]`` is in ``shard_set``.

    Args:
        ds: dataset of ``(chip, y, meta)`` triples (as produced by parse_example).
        shard_set: iterable of integer shard ids to retain.

    Returns:
        The filtered dataset, with the same element structure.
    """
    wanted_ids = tf.constant(list(shard_set), dtype=tf.int32)

    def _keep(chip, y, meta):
        # Scalar-vs-vector equality broadcasts; keep the element if any id matches.
        hits = tf.equal(meta["shard"], wanted_ids)
        return tf.reduce_any(hits)

    return ds.filter(_keep)

# ---- Split by shard: 0-5 train, 6 validation, 7 test ----
TRAIN_SHARDS, VAL_SHARDS, TEST_SHARDS = {0, 1, 2, 3, 4, 5}, {6}, {7}
train_ds = filter_by_shards(ds, TRAIN_SHARDS)
val_ds   = filter_by_shards(ds, VAL_SHARDS)
test_ds  = filter_by_shards(ds, TEST_SHARDS)

# ---- Drop metadata; elements stay unbatched ----
def drop_meta(chip, y, meta):
    """Reduce a (chip, label, meta) triple to the (chip, label) pair Keras expects."""
    return chip, y

train_xy, val_xy, test_xy = (
    split.map(drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
    for split in (train_ds, val_ds, test_ds)
)

# ---- Augmentation (applied per chip, before batching) ----
def augment(x, y):
    """Randomly flip a single chip left/right and up/down, then rotate it by
    k * 90 degrees with k drawn uniformly from {0, 1, 2, 3}. The label is unchanged."""
    flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(x))
    quarter_turns = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(flipped, k=quarter_turns), y

train_xy_aug = train_xy.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

# ---- Per-band mean/std from the first N un-augmented training chips ----
# Augmentation only flips/rotates pixels within a chip, so the un-augmented
# chips give the same per-band statistics without extra randomness.
N = 5000
X = np.concatenate(
    [chips.numpy() for chips, _ in train_xy.take(N).batch(1000)], axis=0
)  # (<=N, 13, 13, 9)

pixels = X.reshape(-1, X.shape[-1])   # one row per pixel, one column per band
mean = pixels.mean(axis=0)
std  = pixels.std(axis=0) + 1e-6      # epsilon guards against a zero-variance band

print("mean:", mean)
print("std :", std)

mean_tf = tf.constant(mean, dtype=tf.float32)
std_tf  = tf.constant(std,  dtype=tf.float32)

def normalize(x, y):
    """Standardize every band with the training statistics (broadcast over H and W)."""
    return (x - mean_tf) / std_tf, y

# ---- Normalize every split with the training statistics (still unbatched) ----
train_xy_n_unbatched, val_xy_n_unbatched, test_xy_n_unbatched = (
    split.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    for split in (train_xy_aug, val_xy, test_xy)
)

# ---- Batch exactly once ----
# Only the training split is shuffled (buffer of 5000) and trimmed to full
# batches. No split is repeated, so one pass over a dataset is one epoch.
BATCH = 64

def _batch_and_prefetch(split, drop_remainder=False):
    """Batch a split by BATCH and prefetch so input loading overlaps compute."""
    return split.batch(BATCH, drop_remainder=drop_remainder).prefetch(tf.data.AUTOTUNE)

train_xy_n = _batch_and_prefetch(train_xy_n_unbatched.shuffle(5000), drop_remainder=True)
val_xy_n   = _batch_and_prefetch(val_xy_n_unbatched)
test_xy_n  = _batch_and_prefetch(test_xy_n_unbatched)

# ---- Sanity check: shape of the first train and validation batch ----
for xb, yb in train_xy_n.take(1):
    print("TRAIN batch:", xb.shape, yb.shape)
for xb, yb in val_xy_n.take(1):
    print("VAL batch:", xb.shape, yb.shape)


mean: [ 0.11804193  0.1643592   0.2300159   0.28792363  0.372446    0.3405337
  0.11277924 -0.2710927  -0.08047115]
std : [0.00143058 0.00199728 0.00535583 0.00653893 0.00404383 0.00410531
 0.00407734 0.00919885 0.0118668 ]
TRAIN batch: (64, 13, 13, 9) (64,)
VAL batch: (64, 13, 13, 9) (64,)


In [None]:
# Count positive / negative labels in the batched training pipeline.
# take(20000) only caps the scan: if the pipeline holds fewer examples, every
# example is counted. With drop_remainder=True the partial last batch is excluded.
labels_seen = [float(lbl) for _, lbl in train_xy_n.unbatch().take(20000)]
tot = len(labels_seen)
pos = sum(1 for value in labels_seen if value == 1.0)
neg = tot - pos
print("sampled train pos/neg:", pos, neg)


sampled train pos/neg: 285 9507


In [None]:
import tensorflow as tf
import math

PATCH = 13
C = len(BANDS)
BATCH = 64  # keep in sync with how you batched train_xy_n/val_xy_n

# Example counts measured earlier for each split. They are only used when
# tf.data cannot infer cardinality (filter() makes it UNKNOWN).
N_TRAIN_EXAMPLES = 9840
N_VAL_EXAMPLES = 1576

# ---- Steps per epoch ----
train_batches = train_xy_n.cardinality().numpy()
val_batches   = val_xy_n.cardinality().numpy()

# cardinality() is negative when unknown (-2) or infinite (-1); fall back to counts.
# Train is batched with drop_remainder=True, so partial batches are dropped (floor);
# validation keeps its final partial batch (ceil).
if train_batches < 0:
    train_batches = N_TRAIN_EXAMPLES // BATCH
if val_batches < 0:
    val_batches = math.ceil(N_VAL_EXAMPLES / BATCH)

steps_per_epoch = int(train_batches)
validation_steps = int(val_batches)

print("steps_per_epoch:", steps_per_epoch, "validation_steps:", validation_steps)

# ---- Model ----
# Two conv + 2x2 max-pool stages (13x13 -> 6x6 -> 3x3 with "same" convs and
# default "valid" pooling), a 128-filter conv, global average pooling, then a
# small dense head ending in a single sigmoid probability.
model_stack = [tf.keras.layers.Input(shape=(PATCH, PATCH, C))]
for n_filters in (32, 64):
    model_stack += [
        tf.keras.layers.Conv2D(n_filters, 3, padding="same", activation="relu"),
        tf.keras.layers.MaxPool2D(),
    ]
model_stack += [
    tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Dense(1, activation="sigmoid"),
]
model = tf.keras.Sequential(model_stack)

adam_opt = tf.keras.optimizers.Adam(1e-3)
# Positives are rare: track ROC-AUC and PR-AUC, plus precision/recall at a
# 0.30 decision threshold instead of the default 0.5.
tracked_metrics = [
    tf.keras.metrics.AUC(name="auc"),
    tf.keras.metrics.AUC(curve="PR", name="prauc"),
    tf.keras.metrics.Precision(name="prec_t30", thresholds=0.30),
    tf.keras.metrics.Recall(name="rec_t30", thresholds=0.30),
]
model.compile(
    optimizer=adam_opt,
    loss="binary_crossentropy",
    metrics=tracked_metrics,
)

# ---- Class weights ----
# Positives are roughly 3% of training chips, so full inverse-frequency
# weighting would be ~33x. A deliberately milder 6x up-weighting is used to
# avoid a "trigger-happy" model that floods predictions with false positives.
POS_CLASS_WEIGHT = 6.0
class_weight = {0: 1.0, 1: POS_CLASS_WEIGHT}

# ---- Callbacks ----
# Checkpointing, early stopping and LR decay all track validation PR-AUC
# (higher is better), chosen because positives are rare.
MONITOR_KW = dict(monitor="val_prauc", mode="max")

callbacks = [
    # NOTE(review): relative path, presumably resolved under /content, i.e. not
    # on Drive - the checkpoint is lost if the Colab runtime resets.
    tf.keras.callbacks.ModelCheckpoint(
        filepath="best_cnn.keras",
        save_best_only=True,
        verbose=1,
        **MONITOR_KW,
    ),
    tf.keras.callbacks.EarlyStopping(
        patience=10,
        restore_best_weights=True,
        verbose=1,
        **MONITOR_KW,
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        factor=0.5,
        patience=3,
        min_lr=1e-5,
        verbose=1,
        **MONITOR_KW,
    ),
]

# ---- Train ----
# Both datasets are finite and not repeated: each epoch runs until train_xy_n
# is exhausted and validation covers all of val_xy_n. No steps arguments are
# passed. A validation_steps value on a non-repeated dataset can lead to
# "ran out of data" behavior if the count is off or the iterator is reused
# across epochs, depending on the Keras version.
history = model.fit(
    train_xy_n,
    validation_data=val_xy_n,
    epochs=50,
    class_weight=class_weight,
    callbacks=callbacks,
    verbose=2,
)


steps_per_epoch: 153 validation_steps: 25
Epoch 1/50





Epoch 1: val_prauc improved from -inf to 0.14923, saving model to best_cnn.keras
153/153 - 16s - 101ms/step - auc: 0.5820 - loss: 0.5292 - prauc: 0.0431 - prec_t30: 0.0400 - rec_t30: 0.3227 - val_auc: 0.5843 - val_loss: 0.2573 - val_prauc: 0.1492 - val_prec_t30: 0.1060 - val_rec_t30: 0.2963 - learning_rate: 1.0000e-03
Epoch 2/50

Epoch 2: val_prauc did not improve from 0.14923
153/153 - 14s - 91ms/step - auc: 0.5822 - loss: 0.4957 - prauc: 0.0419 - prec_t30: 0.0539 - rec_t30: 0.2000 - val_auc: 0.6266 - val_loss: 0.2083 - val_prauc: 0.1048 - val_prec_t30: 0.0000e+00 - val_rec_t30: 0.0000e+00 - learning_rate: 1.0000e-03
Epoch 3/50

Epoch 3: val_prauc did not improve from 0.14923
153/153 - 18s - 120ms/step - auc: 0.6308 - loss: 0.4808 - prauc: 0.0495 - prec_t30: 0.0548 - rec_t30: 0.1789 - val_auc: 0.6184 - val_loss: 0.2529 - val_prauc: 0.0493 - val_prec_t30: 0.0417 - val_rec_t30: 0.0556 - learning_rate: 1.0000e-03
Epoch 4/50

Epoch 4: val_prauc did not improve from 0.14923

Epoch 4: Redu

In [None]:
# Persist the trained network to the mounted Google Drive in native Keras
# format. Despite its name, MODEL_DIR is the path of a single .keras file.
MODEL_DIR = "/content/drive/MyDrive/gobi_cnn_model.keras"
tf.keras.models.save_model(model, MODEL_DIR)

print("Model saved to:", MODEL_DIR)

Model saved to: /content/drive/MyDrive/gobi_cnn_model.keras


In [None]:
import json, os

import numpy as np

# Store the per-band normalization statistics next to the model. Inference
# must apply the same (x - mean) / std, in the same band order.
# np.asarray(...) accepts both the NumPy arrays computed during training and
# the tf constants that the reload cell binds to mean/std, so re-running this
# cell after a reload no longer fails on a missing .tolist().
norm = {
    "bands": list(BANDS),
    "mean": np.asarray(mean).tolist(),
    "std": np.asarray(std).tolist(),
}

with open("/content/drive/MyDrive/gobi_cnn_model_norm.json", "w") as f:
    json.dump(norm, f, indent=2)

print("Normalization saved")


Normalization saved


In [None]:
import json

import tensorflow as tf

MODEL_PATH = "/content/drive/MyDrive/gobi_cnn_model.keras"
NORM_PATH = "/content/drive/MyDrive/gobi_cnn_model_norm.json"

# Reload the trained network and its normalization statistics from Drive.
model = tf.keras.models.load_model(MODEL_PATH)

with open(NORM_PATH) as fh:
    norm = json.load(fh)

# NOTE: this rebinds mean/std (previously NumPy arrays) to tf constants.
# Pipelines built earlier normalize with mean_tf/std_tf and are unaffected.
mean, std = (tf.constant(norm[key], dtype=tf.float32) for key in ("mean", "std"))

print("Reload successful")

Reload successful


In [None]:


def get_preds_and_labels(ds, clf=None):
    """Run a classifier over a batched ``(x, y)`` dataset and collect the results.

    Args:
        ds:  iterable of ``(x_batch, y_batch)`` pairs, e.g. a batched tf.data.Dataset.
        clf: callable model; defaults to the notebook-level ``model``. It is
             called as ``clf(x_batch, training=False)`` so dropout is inactive.

    Returns:
        ``(y, p)``: 1-D int32 labels and 1-D float32 predicted probabilities,
        aligned element-for-element. Both are empty if ``ds`` yields nothing.
    """
    if clf is None:
        clf = model
    ys, ps = [], []
    for x_batch, y_batch in ds:
        # Calling the model directly avoids model.predict()'s per-call setup,
        # which is costly when invoked once per batch.
        ps.append(np.asarray(clf(x_batch, training=False)).reshape(-1))
        ys.append(np.asarray(y_batch).reshape(-1))
    if not ys:
        return np.empty(0, dtype=np.int32), np.empty(0, dtype=np.float32)
    y = np.concatenate(ys).astype(np.int32)
    p = np.concatenate(ps).astype(np.float32)
    return y, p

import numpy as np

# Score the held-out test split (shard 7): 1-D int labels and float probabilities.
y_true, p_site = get_preds_and_labels(test_xy_n)

def confusion_at_threshold(y_true, p_site, t):
    """Confusion counts and summary metrics when calling ``p_site >= t`` positive.

    Args:
        y_true: 1-D array of 0/1 labels.
        p_site: 1-D array of predicted probabilities aligned with ``y_true``.
        t:      decision threshold.

    Returns:
        ``(TP, FP, TN, FN, acc, prec, rec, f1)``. Counts are ints. Each ratio is
        an exact division, defined as 0.0 when its denominator is zero
        (including empty input).
    """
    y_true = np.asarray(y_true)
    pred_pos = np.asarray(p_site) >= t
    is_pos = y_true == 1
    is_neg = y_true == 0

    TP = int((pred_pos & is_pos).sum())
    FP = int((pred_pos & is_neg).sum())
    TN = int((~pred_pos & is_neg).sum())
    FN = int((~pred_pos & is_pos).sum())

    def _ratio(num, den):
        # Exact division instead of an epsilon in the denominator: an epsilon
        # turns e.g. 3/5 into 0.5999... and wrongly fails ">= 0.60" filters.
        return num / den if den else 0.0

    prec = _ratio(TP, TP + FP)
    rec  = _ratio(TP, TP + FN)
    acc  = _ratio(TP + TN, TP + FP + TN + FN)
    f1   = _ratio(2 * prec * rec, prec + rec)
    return TP, FP, TN, FN, acc, prec, rec, f1

# Sweep decision thresholds and tabulate metrics. Columns of `rows`:
#   0 t, 1 TP, 2 FP, 3 TN, 4 FN, 5 acc, 6 prec, 7 rec, 8 f1
COL_FP, COL_REC, COL_F1 = 2, 7, 8
TARGET_RECALL = 0.60

ts = np.linspace(0, 1, 401)
rows = np.array(
    [(t,) + confusion_at_threshold(y_true, p_site, t) for t in ts],
    dtype=float,
)

# best F1
best_f1 = rows[np.argmax(rows[:, COL_F1])]
print("Best F1:", best_f1)

# Threshold that keeps recall >= TARGET_RECALL with the fewest false positives.
# The filter can be empty (e.g. no positives in y_true), so guard argmin.
rows_recall = rows[rows[:, COL_REC] >= TARGET_RECALL]
if len(rows_recall):
    best_fp = rows_recall[np.argmin(rows_recall[:, COL_FP])]
    print(f"Min FP with recall>={TARGET_RECALL:.2f}:", best_fp)
else:
    best_fp = None
    print(f"No threshold reaches recall>={TARGET_RECALL:.2f}")

# show a few candidate thresholds around where FP drops
for t in [0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    out = confusion_at_threshold(y_true, p_site, t)
    print(t, out)

print("p min/max:", p_site.min(), p_site.max())
print("p quantiles:", np.quantile(p_site, [0, .01, .05, .1, .5, .9, .95, .99, 1.0]))

# return_dict=True labels every metric correctly. In Keras 3, model.metrics_names
# is ['loss', 'compile_metrics'], so zipping it with the flat result list
# mislabels and drops metrics.
test_metrics = model.evaluate(test_xy_n, verbose=2, return_dict=True)
print(test_metrics)

Best F1: [3.20000000e-01 1.60000000e+01 5.00000000e+01 1.54400000e+03
 4.10000000e+01 9.44881890e-01 2.42424242e-01 2.80701754e-01
 2.60162601e-01]
Min FP with recall>=0.60: [2.02500000e-01 3.50000000e+01 3.22000000e+02 1.27200000e+03
 2.20000000e+01 7.91641429e-01 9.80392157e-02 6.14035088e-01
 1.69082125e-01]
0.1 (57, 1001, 593, 0, 0.3937007874015748, 0.05387523629484511, 0.9999999999824563, 0.10224215236916888)
0.2 (36, 329, 1265, 21, 0.788007268322229, 0.09863013698603115, 0.6315789473573408, 0.17061611350961345)
0.25 (28, 230, 1364, 29, 0.8431253785584494, 0.1085271317825251, 0.4912280701668206, 0.1777777774802318)
0.3 (16, 53, 1541, 41, 0.9430648092065415, 0.23188405796765385, 0.28070175438104034, 0.2539682534687579)
0.4 (12, 43, 1551, 45, 0.9466989703210176, 0.21818181817785126, 0.21052631578578027, 0.21428571378204722)
0.5 (0, 0, 1594, 57, 0.9654754694124773, 0.0, 0.0, 0.0)
0.6 (0, 0, 1594, 57, 0.9654754694124773, 0.0, 0.0, 0.0)
0.7 (0, 0, 1594, 57, 0.9654754694124773, 0.0, 0.0