In [None]:
# =====================================================================
# UPDATED CELL 1: Robust SatMAE Environment Setup (version‑pinned)
# =====================================================================
import os
import subprocess
import sys

# 1. Remove conflicting packages and any existing NumPy
print("▸ Removing conflicting packages and existing NumPy...")
# Best-effort step: pip's output is discarded and its exit status is not
# checked, so a failure here does not abort the rest of the setup.
subprocess.run(
    [sys.executable, "-m", "pip", "uninstall", "-y", "opencv-python-headless", "thinc", "numpy"],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
)
print("✓ Conflicting packages and old NumPy removed.")

# 2. Upgrade pip so it can install the pinned wheels
print("▸ Upgrading pip to the latest version...")
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "pip", "--quiet"],
    check=True,
)
print("✓ pip upgraded.")

# 3. Install pinned versions of PyTorch/Lightning/timm/numpy plus other deps
print("▸ Installing pinned versions for SatMAE (Torch 2.1.2, Lightning 2.2.x, timm 0.9.16, NumPy 1.26.4)...")
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-cache-dir",
        "--quiet",
        "torch==2.1.2",
        "torchvision==0.16.2",
        "torchaudio==2.1.2",
        "pytorch-lightning==2.2.4",
        "timm==0.9.16",
        "numpy==1.26.4",
        # SatMAE dependencies (not version-pinned)
        "rasterio",
        "s3fs",
        "shapely",
        "pystac-client",
        "pandas",
        "geopandas",
        "tqdm",
    ],
    check=True,
)
print("✓ Core libraries installed.")

# 4. Clone the SatMAE repository and apply a small patch to remove qk_scale kwargs
REPO_DIR = "/content/SatMAE"
if not os.path.isdir(REPO_DIR):
    print("▸ Cloning and patching SatMAE...")
    subprocess.run(
        ["git", "clone", "--depth", "1", "https://github.com/sustainlab-group/SatMAE.git", REPO_DIR],
        check=True,
    )
    # Patch out the qk_scale argument that is not expected in some versions of timm.
    # The patch only runs right after a fresh clone; an existing checkout is reused as-is.
    patch_file = os.path.join(REPO_DIR, "models_mae_group_channels.py")
    subprocess.run(
        ["sed", "-i", "s/qk_scale=None, //g", patch_file],
        check=True,
    )
    print("✓ SatMAE cloned and patched.")
else:
    print("✓ SatMAE repository already exists – reusing it.")

# 5. Import libraries and verify versions; assert pinned versions are loaded
import re  # used only by the version parsing helper below


def _release_tuple(version):
    """Return the leading dotted-integer part of ``version`` as a tuple of ints.

    Suffixes such as ``+cu121``, ``.dev0`` or ``rc1`` are ignored, so
    ``"2.1.2+cu121"`` becomes ``(2, 1, 2)``. An unparsable string yields ``()``,
    which compares lower than any real release.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


try:
    import numpy as np
    import torch
    import pytorch_lightning as pl
    import timm

    # Make sure the versions that are actually loaded match the pinned requirements.
    # A module imported earlier in this kernel session keeps its old version until
    # the runtime is restarted, which is exactly what these checks detect.
    assert torch.__version__.startswith("2.1"), f"Torch version mismatch: {torch.__version__}"
    assert pl.__version__.startswith("2.2"), f"Lightning version mismatch: {pl.__version__}"
    assert np.__version__.startswith("1.26"), f"NumPy version mismatch: {np.__version__}"
    # timm >= 0.6 has removed the deprecated np.float reference, so no need for monkey‑patch
    assert _release_tuple(timm.__version__) >= (0, 6, 0), f"Old timm detected: {timm.__version__}"

    print("-" * 50)
    print("✅ ENVIRONMENT SETUP COMPLETE AND VERIFIED ✅")
    print(f"  - PyTorch version: {torch.__version__}")
    print(f"  - NumPy version:   {np.__version__}")
    print(f"  - timm version:    {timm.__version__}")
    print(f"  - Lightning:       {pl.__version__}")
    print("-" * 50)
except Exception as e:
    print("❌ Environment setup failed:", e)
    print("Please restart the runtime and re-run this cell.")
    # Re-raise so "Run all" stops here instead of continuing with a broken environment.
    raise


▸ Removing conflicting packages and existing NumPy...
✓ Conflicting packages and old NumPy removed.
▸ Upgrading pip to the latest version...
✓ pip upgraded.
▸ Installing pinned versions for SatMAE (Torch 2.1.2, Lightning 2.2.x, timm 0.9.16, NumPy 1.26.4)...
✓ Core libraries installed.
✓ SatMAE repository already exists – reusing it.
--------------------------------------------------
✅ ENVIRONMENT SETUP COMPLETE AND VERIFIED ✅
  - PyTorch version: 2.1.2+cu121
  - NumPy version:   1.26.4
  - timm version:    0.9.16
  - Lightning:       2.2.4
--------------------------------------------------


In [None]:
# =====================================================================
# UPDATED CELL 2: Mount Drive and load the bounding‑box summary
# =====================================================================
from google.colab import drive
import pandas as pd
import os

# Make Google Drive available under /content/drive.
drive.mount('/content/drive')

# Location of the event summary CSV inside My Drive (edit to match your Drive).
summary_path = '/content/drive/MyDrive/sturm_flood_event_summary.csv'

# Read the table when the file is present; otherwise stop with a clear error.
if os.path.isfile(summary_path):
    summary_df = pd.read_csv(summary_path)
else:
    raise FileNotFoundError(f"Could not find {summary_path}. Check the path and filename.")

event_count = len(summary_df)
print(f"Loaded {event_count} events from {summary_path}")
summary_df.head()

Mounted at /content/drive
Loaded 60 events from /content/drive/MyDrive/sturm_flood_event_summary.csv


Unnamed: 0,EMSR_code,lat_min,lon_min,lat_max,lon_max,num_boxes,central_lat,central_lon
0,EMSR260,44.624375,10.5229,44.919442,10.849205,10,44.771909,10.686053
1,EMSR261,51.816565,9.360113,52.786385,10.733915,1028,52.301475,10.047014
2,EMSR265,47.777161,0.035549,49.619315,2.762011,929,48.698238,1.39878
3,EMSR268,56.399934,23.65603,56.688011,23.860839,57,56.543972,23.758435
4,EMSR275,45.155585,16.147086,45.696565,17.215032,920,45.426075,16.681059


In [None]:
# =====================================================================
# UPDATED CELL 3: Generate Sentinel‑2 asset prefixes (with correct bounding box)
# =====================================================================
import shapely.geometry as sg
import random
from pystac_client import Client
from tqdm import tqdm

# STAC API endpoint that is searched for Sentinel-2 L2A scenes.
api = Client.open("https://earth-search.aws.element84.com/v1")

prefixes = []
for _, row in tqdm(summary_df.iterrows(), total=len(summary_df)):
    # Event bounding box (lon/lat), grown by 0.1 in every direction.
    poly = sg.box(row.lon_min, row.lat_min, row.lon_max, row.lat_max).buffer(0.1)
    bbox = poly.bounds
    search = api.search(
        collections=["sentinel-2-l2a"],
        bbox=bbox,
        datetime="2021-01-01/2024-12-31",
        query={"eo:cloud_cover": {"lt": 40}},
        max_items=200,
    )
    # items() is the current iterator API; get_items() is deprecated in pystac-client.
    for it in search.items():
        asset = it.assets.get("B04") or it.assets.get("visual")
        if asset is None:
            continue
        href = asset.href
        prefix = href.rsplit("/", 1)[0]  # drop the file name, keep full prefix
        prefixes.append(prefix)

# A scene can match more than one event box; keep one entry per prefix
# (dict.fromkeys preserves first-seen order).
prefixes = list(dict.fromkeys(prefixes))

# Shuffle and optionally limit the number of prefixes
random.shuffle(prefixes)
max_keys = 50000
prefixes = prefixes[:max_keys]

# Save prefixes to a text file for later streaming: exactly one prefix per line,
# because the dataset class reads this file line by line.
prefix_path = '/content/s2_prefixes.txt'
with open(prefix_path, 'w') as f:
    f.writelines(f"{prefix}\n" for prefix in prefixes)

print(f"Wrote {len(prefixes)} prefixes to {prefix_path}")

100%|██████████| 60/60 [05:51<00:00,  5.86s/it]

Wrote 11997 prefixes to /content/s2_prefixes.txt





In [None]:
# =====================================================================
# UPDATED CELL 4: sentinel_dataset.py with caching for prefix-based streaming
# =====================================================================
%%writefile /content/sentinel_dataset.py
import os
import hashlib
import torch
import rasterio
import numpy as np
import random
import time
from torch.utils.data import Dataset

# Sentinel-2 band files fetched for every scene. The order of this list is
# also the channel order of the stacked cube built in S2Stream.__getitem__.
BANDS = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B11", "B12"]

# GDAL options applied via rasterio.Env(**ENV_OPTS) around every remote open
# in this module: skip directory listing on open and use unsigned AWS access.
ENV_OPTS = {
    'GDAL_DISABLE_READDIR_ON_OPEN': 'YES',  # do not list the remote directory when a file is opened
    'AWS_NO_SIGN_REQUEST': 'YES'            # send unsigned requests; the bucket is treated as public
}

# Directory for locally cached patches. Each prefix is hashed (MD5) and its
# 9-band cube is stored here as <hash>.npy.
CACHE_DIR = "/content/s2_cache"
# Created at import time so the dataset can write to it without further checks.
os.makedirs(CACHE_DIR, exist_ok=True)

def safe_read(url, window, attempts=3, delay=2):
    """Read band 1 of a window from a remote GeoTIFF, retrying on failure.

    Args:
        url: Path or URL passed to ``rasterio.open``.
        window: Window passed through unchanged to ``src.read``.
        attempts: Maximum number of tries; must be at least 1.
        delay: Base pause in seconds. The wait before retry ``k`` is
            ``delay * k`` (linear back-off).

    Returns:
        The array returned by ``src.read(1, window=window)``.

    Raises:
        ValueError: If ``attempts`` is smaller than 1 (previously this case
            fell through the loop and silently returned ``None``).
        Exception: Whatever the last failed attempt raised, with its
            original traceback.
    """
    if attempts < 1:
        raise ValueError(f"attempts must be >= 1, got {attempts}")

    for attempt in range(1, attempts + 1):
        try:
            # Use a GDAL environment to ensure we don't request directory listings
            with rasterio.Env(**ENV_OPTS):
                with rasterio.open(url) as src:
                    return src.read(1, window=window)
        except Exception:
            if attempt == attempts:
                # Out of retries: propagate the original error unchanged.
                raise
            time.sleep(delay * attempt)

class S2Stream(Dataset):
    """Map-style dataset yielding one 9-band Sentinel-2 patch per scene prefix.

    Each item is a float32 tensor of shape ``(len(BANDS), size, size)`` with
    reflectance divided by 10000. The first access to a prefix streams the
    patch from the remote band files and stores it under ``CACHE_DIR``; later
    accesses load the cached ``.npy`` file instead.

    One random window is drawn per prefix on the grid of the first band in
    ``BANDS`` and mapped onto every other band, so all channels show the same
    ground area. Bands whose raster has a different pixel grid are read over
    the matching window and resampled (nearest neighbour) to ``size``.
    NOTE(review): this assumes all band rasters of a scene cover the same
    footprint - confirm for the data source in use.

    Cubes cached before this alignment was introduced keep their old content;
    clear ``CACHE_DIR`` to regenerate them.
    """

    def __init__(self, prefix_file, size=128):
        """Load the prefix list.

        Args:
            prefix_file: Text file with one full URL prefix per line
                (without the band file name). Blank lines are skipped.
            size: Edge length in pixels of the square patch.
        """
        # Context manager so the file handle is closed deterministically.
        with open(prefix_file) as fh:
            self.prefixes = [line.strip() for line in fh if line.strip()]
        self.size = size

    def __len__(self):
        """Return the number of scene prefixes."""
        return len(self.prefixes)

    @staticmethod
    def _scale_window(row_off, col_off, size, ref_shape, band_shape):
        """Map a ``size`` x ``size`` window on the reference grid onto a band grid.

        Args:
            row_off, col_off: Top-left corner of the window on the reference grid.
            size: Window edge length on the reference grid.
            ref_shape: ``(height, width)`` of the reference raster.
            band_shape: ``(height, width)`` of the band raster.

        Returns:
            ``((row_start, row_stop), (col_start, col_stop))`` on the band grid.
            For a band with the reference shape this is the window itself.
        """
        ref_h, ref_w = ref_shape
        band_h, band_w = band_shape
        row_start = row_off * band_h // ref_h
        col_start = col_off * band_w // ref_w
        # Integer floor keeps the window inside the band raster; at least one pixel.
        n_rows = max(1, size * band_h // ref_h)
        n_cols = max(1, size * band_w // ref_w)
        return ((row_start, row_start + n_rows), (col_start, col_start + n_cols))

    @staticmethod
    def _resize_nearest(patch, size):
        """Return ``patch`` resampled to ``(size, size)`` by nearest neighbour."""
        if patch.shape == (size, size):
            return patch
        rows = np.arange(size) * patch.shape[0] // size
        cols = np.arange(size) * patch.shape[1] // size
        return patch[np.ix_(rows, cols)]

    def __getitem__(self, idx):
        """Return the patch for ``self.prefixes[idx]`` as a torch tensor.

        Raises:
            ValueError: If the reference raster is smaller than ``size``.
        """
        prefix = self.prefixes[idx]

        # Generate a deterministic filename based on the prefix URL.
        file_key = hashlib.md5(prefix.encode('utf-8')).hexdigest()
        cache_file = os.path.join(CACHE_DIR, f"{file_key}.npy")

        # If a cached patch exists, load it directly from disk.
        if os.path.isfile(cache_file):
            return torch.from_numpy(np.load(cache_file))

        # Otherwise fetch the same ground window from every band.
        ref_shape = None
        row_off = col_off = 0
        layers = []
        for band in BANDS:
            url = f"{prefix}/{band}.tif"
            with rasterio.Env(**ENV_OPTS):
                with rasterio.open(url) as src:
                    band_shape = (src.height, src.width)

            if ref_shape is None:
                # The window position is drawn once per prefix, on the first band's grid.
                ref_shape = band_shape
                if min(ref_shape) < self.size:
                    raise ValueError(
                        f"{url} is {ref_shape[1]}x{ref_shape[0]} px, "
                        f"smaller than the requested {self.size} px patch"
                    )
                row_off = random.randint(0, ref_shape[0] - self.size)
                col_off = random.randint(0, ref_shape[1] - self.size)

            window = self._scale_window(row_off, col_off, self.size, ref_shape, band_shape)
            # Use safe_read to retry on transient HTTP errors.
            band_data = safe_read(url, window)
            layers.append(self._resize_nearest(band_data.astype(np.float32), self.size))
        cube = np.stack(layers) / 10000.0  # normalise reflectance

        # Persist the cube for future epochs. Write to a private temp file and
        # rename it into place so another worker never loads a partial file.
        tmp_file = f"{cache_file}.{os.getpid()}.tmp"
        try:
            with open(tmp_file, "wb") as fh:
                np.save(fh, cube)
            os.replace(tmp_file, cache_file)
        except OSError:
            # Caching is best-effort: on failure the patch is re-fetched next time.
            try:
                os.remove(tmp_file)
            except OSError:
                pass

        return torch.from_numpy(cube)


Overwriting /content/sentinel_dataset.py


In [None]:
# =====================================================================
# UPDATED CELL 6: Lightning wrapper with grouped_bands (Lightning 2.x compatible)
# =====================================================================
%%writefile /content/mae_pretrain_lightning.py
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from sentinel_dataset import S2Stream
from SatMAE.models_mae_group_channels import mae_vit_base_patch16_dec512d8b

class MAE9(pl.LightningModule):
    """Lightning wrapper around SatMAE's group-channel MAE for 9-band 128x128 input.

    Args:
        lr: Learning rate handed to AdamW.
        mask_ratio: Masking ratio forwarded to the model on every training step.
    """

    def __init__(self, lr=1.5e-4, mask_ratio=0.6):
        super().__init__()
        # Stores lr and mask_ratio under self.hparams.
        self.save_hyperparameters()
        # The model factory takes the band grouping as `channel_groups`
        # (not `grouped_bands`): three groups of three consecutive band indices.
        self.model = mae_vit_base_patch16_dec512d8b(
            img_size=128,
            in_chans=9,
            channel_groups=((0, 1, 2), (3, 4, 5), (6, 7, 8)),
        )

    def training_step(self, batch, _):
        """Run the model on one batch, log the loss and return it."""
        # Only the first element of the model output (the loss) is used here.
        loss, *_ = self.model(batch, mask_ratio=self.hparams.mask_ratio)
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with the configured lr."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)


def make_loader(prefix_file, batch_size, workers=4):
    """Build a shuffling DataLoader over the patches listed in ``prefix_file``.

    Args:
        prefix_file: Path handed to ``S2Stream``.
        batch_size: Number of patches per batch.
        workers: Number of DataLoader worker processes.
    """
    dataset = S2Stream(prefix_file)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)

Overwriting /content/mae_pretrain_lightning.py


In [None]:
# =====================================================================
# CORRECTED CELL 7: Clean prefixes, fix channel groups, disable flash attention, and train on GPU
# =====================================================================
import os
import numpy as np
import torch

# 1. Set the GDAL options process-wide: no directory listing when a remote
#    file is opened, and unsigned requests to AWS.
os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "YES"
os.environ["AWS_NO_SIGN_REQUEST"] = "YES"

# 2. Restore deprecated NumPy alias for older timm versions.
#    NOTE(review): presumably unnecessary with the timm version pinned by the
#    setup cell - confirm before removing.
np.float = np.float64

# 3. Clean up the prefixes file: split concatenated URLs into separate lines.
#    This handles a raw file whose URLs were written back-to-back without a
#    newline; lines that hold a single URL pass through unchanged.
raw_prefix_file = '/content/s2_prefixes.txt'
clean_prefix_file = '/content/s2_prefixes_clean.txt'
cleaned = []
with open(raw_prefix_file) as f:
    for line in f:
        l = line.strip()
        if not l:
            # Skip blank lines.
            continue
        if l.count('https://') > 1:
            # Several URLs on one line: cut at every scheme marker and put
            # the marker back in front of each non-empty piece.
            # NOTE(review): any text before the first 'https://' would also
            # be emitted with 'https://' prepended.
            parts = l.split('https://')
            for part in parts:
                part = part.strip()
                if part:
                    cleaned.append('https://' + part)
        else:
            cleaned.append(l)
# One prefix per line, without a trailing newline after the last entry.
with open(clean_prefix_file, 'w') as f:
    f.write('\n'.join(cleaned))
print(f"Using {len(cleaned)} cleaned prefixes from {clean_prefix_file}")

# 4. Probe dataset shape, but handle failure gracefully.
#    Indexing S2Stream at 0 fetches the first patch (or loads it from the
#    cache). On any error the shape is assumed instead of stopping the cell.
from sentinel_dataset import S2Stream
try:
    sample_shape = S2Stream(clean_prefix_file)[0].shape
    print("One sample shape:", sample_shape)
except Exception as e:
    print("Could not read a sample from S2Stream:", e)
    sample_shape = (9, 128, 128)
    print("Assuming sample shape is", sample_shape)

# 5. Disable the flash and memory-efficient scaled-dot-product attention
#    kernels for stability; only the math implementation stays enabled.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

# 6. Import the SatMAE model factory.
#    The repository directory is put first on sys.path before the import.
#    The module is reloaded after import - presumably so a file patched on
#    disk is re-read in an already-running kernel (TODO confirm).
import importlib
import sys
satmae_dir = '/content/SatMAE'
if satmae_dir not in sys.path:
    sys.path.insert(0, satmae_dir)
import SatMAE.models_mae_group_channels as mmgc
importlib.reload(mmgc)
from SatMAE.models_mae_group_channels import mae_vit_base_patch16_dec512d8b

import pytorch_lightning as pl
from torch.utils.data import DataLoader

# Lightning module using correct channel groups
class MAE9Fixed(pl.LightningModule):
    """Lightning wrapper around SatMAE's group-channel MAE for 9-band 128x128 input.

    Args:
        lr: Learning rate handed to AdamW.
        mask_ratio: Masking ratio forwarded to the model on every training step.
    """

    def __init__(self, lr=1.5e-4, mask_ratio=0.75):
        super().__init__()
        # Stores lr and mask_ratio under self.hparams.
        self.save_hyperparameters()
        # Three groups of three consecutive band indices cover the 9 input channels.
        band_groups = ((0, 1, 2), (3, 4, 5), (6, 7, 8))
        self.model = mae_vit_base_patch16_dec512d8b(
            img_size=128,
            in_chans=9,
            channel_groups=band_groups,
        )

    def training_step(self, batch, _):
        """Run the model on one batch, log the loss and return it."""
        ratio = self.hparams.mask_ratio
        # Only the first element of the model output (the loss) is needed.
        loss, *_aux = self.model(batch, mask_ratio=ratio)
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with the configured lr."""
        learning_rate = self.hparams.lr
        return torch.optim.AdamW(self.parameters(), lr=learning_rate)

# DataLoader factory with workers=0 to avoid HTTP contention
def make_loader(prefix_file, batch_size, workers=0):
    """Build a shuffling DataLoader over the patches listed in ``prefix_file``.

    Args:
        prefix_file: Path handed to ``S2Stream``.
        batch_size: Number of patches per batch.
        workers: Number of DataLoader worker processes; 0 (the default)
            loads in the main process.
    """
    dataset = S2Stream(prefix_file)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)

# 7. Mount Drive and set up checkpoint directory
import pathlib
from google.colab import drive

# Remount Drive; the checkpoint directory lives under MyDrive.
drive.mount('/content/drive', force_remount=True)
ckpt_dir = '/content/drive/MyDrive/satmae_pretrain_ckpts'
ckpt_path = pathlib.Path(ckpt_dir)
ckpt_path.mkdir(parents=True, exist_ok=True)

# 8. Build the model and a single-GPU trainer
model = MAE9Fixed(lr=1.5e-4, mask_ratio=0.75)
trainer_options = dict(
    accelerator="gpu",
    devices=1,
    precision="32-true",
    max_epochs=20,
    accumulate_grad_batches=4,
    gradient_clip_val=1.0,
    log_every_n_steps=50,
    default_root_dir=ckpt_dir,
)
trainer = pl.Trainer(**trainer_options)

# 9. Train on the cleaned prefixes file.
#    Note: workers=4 here overrides make_loader's default of 0 workers.
train_loader = make_loader(clean_prefix_file, batch_size=8, workers=4)
trainer.fit(model, train_loader)


Using 11997 cleaned prefixes from /content/s2_prefixes_clean.txt
One sample shape: torch.Size([9, 128, 128])
Mounted at /content/drive


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                             | Params
-----------------------------------------------------------
0 | model | MaskedAutoenco

Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


In [None]:
# Save the current state of `trainer` (created in the training cell) as a
# checkpoint file inside the Drive checkpoint folder.
trainer.save_checkpoint("/content/drive/MyDrive/satmae_pretrain_ckpts/mae9_epoch20.ckpt")


In [None]:
# List the checkpoint folder with human-readable sizes to confirm the file was written.
!ls -lh /content/drive/MyDrive/satmae_pretrain_ckpts


total 1.3G
drwx------ 16 root root 4.0K Jul 19 09:54 lightning_logs
-rw-------  1 root root 1.3G Jul 19 18:33 mae9_epoch20.ckpt


In [None]:
from google.colab import drive
import shutil
import os

# (Re)mount Drive so the copies below land in My Drive.
drive.mount('/content/drive', force_remount=True)

# Source and destination of the two artefacts that are backed up.
src_prefixes = '/content/s2_prefixes_clean.txt'   # adjust if your file is elsewhere
dst_prefixes = '/content/drive/MyDrive/s2_prefixes_clean.txt'
src_cache = '/content/s2_cache'
dst_cache = '/content/drive/MyDrive/s2_cache'

# 1. Prefix list: a single file, copied unconditionally.
shutil.copyfile(src_prefixes, dst_prefixes)
print(f"Prefixes file saved to {dst_prefixes}")

# 2. Patch cache: copy the whole directory, but only when the destination
#    does not exist yet; an existing destination is left untouched.
if not os.path.exists(dst_cache):
    shutil.copytree(src_cache, dst_cache)
    print(f"Cache directory saved to {dst_cache}")
else:
    print(f"Destination {dst_cache} already exists; you may want to delete it first or choose another name.")


Mounted at /content/drive
Prefixes file saved to /content/drive/MyDrive/s2_prefixes_clean.txt
Cache directory saved to /content/drive/MyDrive/s2_cache


In [None]:
# Print the total size of the cache copy on Drive as a single human-readable value.
!du -sh /content/drive/MyDrive/s2_cache


6.0G	/content/drive/MyDrive/s2_cache
