In [None]:
"""
Seagrass Segmentation (Drive Tiles) — Google Colab
===================================================
This notebook trains a **neural-network (U‑Net) classifier** using the training
patches stored in Google Drive:


My Drive / GEOG 761_proj_training /
├─ composites/ # yearly summer composites (multi-band .tif)
├─ RGB Composites/ # RGB-only composites (optional)
└─ training_patches_32tile/
   ├─ images/ # input image tiles (*.tif)
   └─ labels/ # label tiles (*.tif) with class ids


It then runs **sliding‑window inference** on every composite GeoTIFF found in
`composites/` and writes one classified raster (GeoTIFF) with class ids per
composite into `predictions/`.


Class codes (expected in label tiles):
0 = background (ignored during training)
1 = sparse seagrass
2 = dense seagrass
3 = exposed sediments
4 = water
5 = interfaces (water–sediment)
"""

In [None]:
#!pip install earthengine-api geopandas shapely rasterio numpy pandas scikit-learn tensorflow==2.15 tqdm

In [None]:
#!pip install rasterio

In [None]:
import geemap
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.cluster import KMeans

import os, glob, json, math
import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.enums import Resampling
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from google.colab import drive

In [None]:
# ================================
# 1) MOUNT DRIVE & PATHS
# ================================


# Project folders on Drive. Only the predictions folder is created by this
# notebook; every other folder is expected to exist already.
ROOT = '/content/drive/My Drive/GEOG 761_proj_training'
COMPOSITES_DIR = f'{ROOT}/composites'
OUT_PRED = f'{ROOT}/predictions' # will be created
IM_DIR = f'{ROOT}/training_patches_32tile/images'
LB_DIR = f'{ROOT}/training_patches_32tile/labels'

# Drive has to be mounted before any of the paths above can be touched.
drive.mount('/content/drive', force_remount=True)
os.makedirs(OUT_PRED, exist_ok=True)

In [None]:
# ================================
# 2) DISCOVER TRAINING TILES (images ⇄ labels)
# ================================


# Every image tile, in deterministic (sorted) order.
img_files = sorted(glob.glob(f'{IM_DIR}/*.tif'))

# A label tile is expected under LB_DIR with the same file name as its image;
# images without such a label are dropped.
label_candidates = [
    os.path.join(LB_DIR, os.path.basename(img_path)) for img_path in img_files
]
pairs = [
    (img_path, label_path)
    for img_path, label_path in zip(img_files, label_candidates)
    if os.path.exists(label_path)
]

assert pairs, 'No matching image/label tile pairs found.'
print(f'Found {len(pairs)} tile pairs.')

# The first image tile defines the tile size and band count used downstream.
with rasterio.open(pairs[0][0]) as first_tile:
    C = first_tile.count
    H, W = first_tile.height, first_tile.width
print(f'Tile size: {H}x{W}, bands: {C}')

In [None]:
# ================================
# 3) LOAD TILES INTO MEMORY (small dataset friendly)
# ================================


# Label ids documented for this project are 0..5; anything else cannot be
# represented by the 6-way softmax the model ends with.
MAX_CLASS_ID = 5

X_list, y_list = [], []
for im, lb in pairs:
    with rasterio.open(im) as src:
        arr = src.read()  # (C,H,W)
    with rasterio.open(lb) as src:
        lab = src.read(1)  # (H,W) class ids

    # Fail early, naming the offending file, instead of letting np.stack or
    # model.fit raise an opaque shape error later on.
    if arr.shape != (C, H, W):
        raise ValueError(
            f'{os.path.basename(im)}: image tile has shape {arr.shape}, '
            f'expected {(C, H, W)} (bands, height, width).'
        )
    if lab.shape != (H, W):
        raise ValueError(
            f'{os.path.basename(lb)}: label tile has shape {lab.shape}, '
            f'expected {(H, W)}.'
        )

    # Written so that a NaN min/max (float label rasters) also fails the check.
    lab_min, lab_max = lab.min(), lab.max()
    if not (0 <= lab_min and lab_max <= MAX_CLASS_ID):
        raise ValueError(
            f'{os.path.basename(lb)}: label ids span [{lab_min}, {lab_max}], '
            f'expected values in 0..{MAX_CLASS_ID}.'
        )

    X_list.append(arr.transpose(1, 2, 0).astype('float32'))  # (H,W,C)
    y_list.append(lab.astype('int32'))


X = np.stack(X_list, axis=0)  # (N,H,W,C)
y = np.stack(y_list, axis=0)  # (N,H,W)
print('Data shapes:', X.shape, y.shape)

In [None]:
# ================================
# 4) NORMALIZE & SPLIT
# ================================


# Per‑channel standardization using train split only
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Statistics are reduced over every axis except the channel axis, and are
# computed on the training split only so nothing leaks from validation.
_PIXEL_AXES = (0, 1, 2)
mean = X_train.mean(axis=_PIXEL_AXES, keepdims=True)
std = X_train.std(axis=_PIXEL_AXES, keepdims=True) + 1e-6  # epsilon avoids /0
X_train_n, X_val_n = ((split - mean) / std for split in (X_train, X_val))


# Persist the scaler so inference can apply exactly the same normalisation.
os.makedirs('/content/artefacts', exist_ok=True)
scaler = {'mean': mean.squeeze().tolist(), 'std': std.squeeze().tolist()}
with open('/content/artefacts/scaler.json','w') as f:
    json.dump(scaler, f)


# Balanced class weights computed from labelled (non-zero) training pixels only.
labelled_pixels = y_train[y_train > 0]
classes_present = np.unique(labelled_pixels)
cls_w = compute_class_weight('balanced', classes=classes_present, y=labelled_pixels)
class_weight = dict(zip(map(int, classes_present), map(float, cls_w)))
print('Class weights (excluding 0):', class_weight)


# Build pixel‑wise weights to **ignore background (0)**
def make_sample_weights(y_batch):
    """Build a float32 per-pixel weight map with the same shape as ``y_batch``.

    Each pixel whose label is a key of the module-level ``class_weight`` dict
    receives that class's weight. Every other pixel keeps weight 0, which is
    how background (id 0, not a key of ``class_weight``) is ignored.
    """
    weights = np.zeros_like(y_batch, dtype='float32')
    for class_id in class_weight:
        class_value = np.float32(class_weight[class_id])
        weights = np.where(y_batch == class_id, class_value, weights)
    return weights


# Per-pixel loss weights for both splits (same label -> weight mapping).
sw_train, sw_val = (make_sample_weights(labels) for labels in (y_train, y_val))

In [None]:
# ================================
# 5) MODEL — light U‑Net
# ================================


def conv_block(x, f):
    """Apply two (3x3 same-padded ReLU conv -> batch norm) stages with ``f`` filters.

    ``x`` is the incoming Keras tensor; the transformed tensor is returned.
    """
    layers = tf.keras.layers
    for _ in range(2):
        x = layers.Conv2D(f, 3, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
    return x


def encoder_block(x, f):
    """Run ``conv_block`` with ``f`` filters, then max-pool the result.

    Returns ``(features, pooled)``: the pre-pooling features (used as the skip
    connection) and the pooled tensor passed to the next stage.
    """
    features = conv_block(x, f)
    return features, tf.keras.layers.MaxPool2D()(features)


def decoder_block(x, skip, f):
    """Upsample ``x`` by 2, merge it with ``skip`` and refine with ``conv_block``.

    Upsampling is a stride-2 transposed convolution with ``f`` filters; the
    result is concatenated with ``skip`` before the convolution block.
    """
    upsampled = tf.keras.layers.Conv2DTranspose(f, 2, strides=2, padding='same')(x)
    merged = tf.keras.layers.Concatenate()([upsampled, skip])
    return conv_block(merged, f)


# Seed Python, NumPy and TensorFlow so weight initialisation and batch
# shuffling are as repeatable as the hardware allows across Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

NUM_CLASSES = 6  # class ids 0..5 (0 = background, zero weight during training)

# Spatial dimensions are left open so the network accepts any tile size; each
# side must be divisible by 8 because of the three 2x pooling stages below.
inputs = tf.keras.Input(shape=(None, None, C))


# Encoder: keep the pre-pooling features of each stage as skip connections.
s1, p1 = encoder_block(inputs, 32)
s2, p2 = encoder_block(p1, 64)
s3, p3 = encoder_block(p2, 128)


bottleneck = conv_block(p3, 256)


# Decoder: mirror the encoder, merging the matching skip at each scale.
d1 = decoder_block(bottleneck, s3, 128)
d2 = decoder_block(d1, s2, 64)
d3 = decoder_block(d2, s1, 32)


# 1x1 convolution turns the features into per-pixel class probabilities.
outputs = tf.keras.layers.Conv2D(NUM_CLASSES, 1, activation='softmax')(d3)
model = tf.keras.Model(inputs, outputs)
model.summary()


# Loss that supports per-pixel sample weights.
# `metrics` are NOT sample-weighted, so plain accuracy also counts the
# zero-weight background pixels; `weighted_metrics` applies the per-pixel
# weights and therefore reports accuracy on labelled pixels only.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
    weighted_metrics=['accuracy'],
)


callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', patience=10, restore_best_weights=True
    )
]


history = model.fit(
    X_train_n, y_train[..., None],
    validation_data=(X_val_n, y_val[..., None], sw_val),  # weights for val reported
    sample_weight=sw_train,
    epochs=80,
    batch_size=8,
    callbacks=callbacks,
    verbose=1)


model.save('/content/artefacts/seagrass_unet.keras')
print('Saved model to /content/artefacts/seagrass_unet.keras')

In [None]:
# ================================
# 6) INFERENCE — classify ALL composite GeoTIFFs
# ================================

composites = sorted(glob.glob(f'{COMPOSITES_DIR}/*.tif'))
assert len(composites) > 0, 'No composite GeoTIFFs found.'
print('Found composites:', [os.path.basename(p) for p in composites])

# Load the scaler saved at training time so inference uses identical statistics.
with open('/content/artefacts/scaler.json', 'r') as f:
    sc = json.load(f)
mu = np.array(sc['mean'], dtype='float32')
sigma = np.array(sc['std'], dtype='float32')

WIN = 256  # window side in pixels; multiple of 8 (three 2x poolings in the model)

for TARGET in composites:
    target_name = os.path.basename(TARGET)
    print('Classifying:', target_name)

    # splitext touches only the real extension; str.replace would rewrite
    # every '.tif' occurrence in the file name.
    stem, ext = os.path.splitext(target_name)
    out_path = os.path.join(OUT_PRED, f'{stem}_PRED{ext}')

    # The source stays open for the whole time the destination is written.
    with rasterio.open(TARGET) as src:
        bands_total = src.count
        Hc, Wc = src.height, src.width

        # Use first C bands; error if fewer than C
        if bands_total < C:
            raise ValueError(
                f'{target_name} has {bands_total} bands; model trained on {C}.'
            )
        band_indexes = list(range(1, C + 1))

        # Same georeferencing as the source, but a single uint8 class band.
        out_profile = src.profile.copy()
        out_profile.update({'count': 1, 'dtype': 'uint8', 'nodata': 0})

        with rasterio.open(out_path, 'w', **out_profile) as dst:
            for r0 in range(0, Hc, WIN):
                for c0 in range(0, Wc, WIN):
                    h = min(WIN, Hc - r0)
                    w = min(WIN, Wc - c0)
                    win = Window(c0, r0, w, h)

                    block = src.read(indexes=band_indexes, window=win)  # (C,h,w)
                    block = block.astype('float32')

                    # Pixels with NaN/Inf in any band cannot be classified.
                    invalid = ~np.isfinite(block).all(axis=0)  # (h,w)

                    # Edge windows are zero-padded up to WIN x WIN (raw zeros,
                    # normalised together with the data, as before).
                    x = np.zeros((1, WIN, WIN, C), dtype='float32')
                    x[0, :h, :w, :] = block.transpose(1, 2, 0)
                    x = (x - mu) / sigma

                    # Neutralise non-finite inputs so they do not spread
                    # through the convolutions into neighbouring valid pixels.
                    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

                    # A direct call avoids the per-call overhead of
                    # model.predict for a single small batch inside a loop.
                    p = np.asarray(model(tf.convert_to_tensor(x), training=False))[0]
                    pred = np.argmax(p, axis=-1).astype('uint8')[:h, :w]
                    pred[invalid] = 0  # non-finite input -> nodata/background

                    dst.write(pred, 1, window=win)

    print('Saved:', out_path)

print('Legend: 0=background, 1=sparse, 2=dense, 3=sediments, 4=water, 5=interfaces')
