In [None]:
# ==============================================================================
# Cell 1 (Corrected): Environment Setup
# ==============================================================================
# Install required libraries
# ADDED: segmentation-models-pytorch for advanced loss functions (Dice, Focal).
# Note the hyphenated name for pip install.
!pip install -q torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
!pip install -q pytorch-lightning timm rasterio albumentations grad-cam onnx onnxruntime-gpu segmentation-models-pytorch

# Mount Google Drive
from google.colab import drive
import os
from pathlib import Path

drive.mount('/content/drive')

# Define the project's root directory for consistent pathing
ROOT = Path('/content/drive/MyDrive/opticflood_phd_project')

# Create the project's directory structure
(ROOT / 'data/raw').mkdir(parents=True, exist_ok=True)
(ROOT / 'data/processed').mkdir(parents=True, exist_ok=True)
(ROOT / 'models').mkdir(parents=True, exist_ok=True)
(ROOT / 'reports/figures').mkdir(parents=True, exist_ok=True)

print(f"Project root is set to: {ROOT}")
print("Directory structure is ready.")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m111.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m85.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m50.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# ==============================================================================
# Cell 2: Automated Local Data Preparation
# PURPOSE: Downloads, unzips, and organizes data on the fast Colab local disk.
# This makes the notebook self-contained and fully reproducible for anyone.
# ==============================================================================
from pathlib import Path
import shutil

# --- All paths are now LOCAL to the Colab machine ---
LOCAL_DATA_DIR = Path('/content/data')
LOCAL_IMAGES_DIR = LOCAL_DATA_DIR / 'images'
LOCAL_MASKS_DIR = LOCAL_DATA_DIR / 'masks'
LOCAL_ZIP_PATH = LOCAL_DATA_DIR / 'Dataset.zip'
LOCAL_EXTRACTED_DIR = LOCAL_DATA_DIR / 'Dataset'
EXPECTED_FILE_COUNT = 2675

# --- Main Logic ---
print("--- Preparing Local Data Environment for this Session ---")

# Check if data is already on the local disk
if LOCAL_IMAGES_DIR.is_dir() and len(list(LOCAL_IMAGES_DIR.glob('*.tif'))) == EXPECTED_FILE_COUNT:
    print("✅ Data is already prepared on the local disk.")
else:
    print("⚠️ Data not found on local disk. Starting full download and setup...")
    LOCAL_DATA_DIR.mkdir(parents=True, exist_ok=True)

    # 1. Download directly to the local disk
    print(f"   -> Downloading master ZIP to {LOCAL_ZIP_PATH}...")
    !wget -c "https://zenodo.org/records/12748983/files/Dataset.zip?download=1" -O {LOCAL_ZIP_PATH}

    # 2. Unzip locally
    print(f"   -> Extracting ZIP locally...")
    !unzip -q -o {LOCAL_ZIP_PATH} -d {LOCAL_DATA_DIR}

    # 3. Organize files locally
    print("   -> Organizing files...")
    source_images_path = LOCAL_EXTRACTED_DIR / 'Sentinel2/S2'
    source_masks_path = LOCAL_EXTRACTED_DIR / 'Sentinel2/Floodmaps'

    LOCAL_IMAGES_DIR.mkdir(exist_ok=True)
    LOCAL_MASKS_DIR.mkdir(exist_ok=True)

    !mv {source_images_path}/*.tif {LOCAL_IMAGES_DIR}/
    !mv {source_masks_path}/*.tif {LOCAL_MASKS_DIR}/

    # 4. Clean up the large intermediate files
    print("   -> Cleaning up...")
    shutil.rmtree(LOCAL_EXTRACTED_DIR)
    LOCAL_ZIP_PATH.unlink() # Delete the 4GB zip file
    print("✅ Local data preparation complete.")

# --- Final Verification ---
num_images = len(list(LOCAL_IMAGES_DIR.glob('*.tif')))
num_masks = len(list(LOCAL_MASKS_DIR.glob('*.tif')))
print(f"\nVerification: Found {num_images} images and {num_masks} masks on local disk.")
assert num_images == EXPECTED_FILE_COUNT and num_masks == EXPECTED_FILE_COUNT

--- Preparing Local Data Environment for this Session ---
⚠️ Data not found on local disk. Starting full download and setup...
   -> Downloading master ZIP to /content/data/Dataset.zip...
--2025-07-15 12:39:00--  https://zenodo.org/records/12748983/files/Dataset.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.185.45.92, 188.185.48.194, 188.185.43.25, ...
Connecting to zenodo.org (zenodo.org)|188.185.45.92|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3985215183 (3.7G) [application/octet-stream]
Saving to: ‘/content/data/Dataset.zip’


2025-07-15 12:57:46 (3.38 MB/s) - ‘/content/data/Dataset.zip’ saved [3985215183/3985215183]

   -> Extracting ZIP locally...
   -> Organizing files...
   -> Cleaning up...
✅ Local data preparation complete.

Verification: Found 2675 images and 2675 masks on local disk.


In [None]:
# ==============================================================================
# FINAL BASELINE CELL: Event-Wise Splitting and Training
# This cell implements the academically rigorous "leave-one-event-out" style split.
# ==============================================================================
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import pytorch_lightning as pl
from pathlib import Path
import numpy as np
import collections
import albumentations as A
from albumentations.pytorch import ToTensorV2
import rasterio
import segmentation_models_pytorch as smp
from torchmetrics import F1Score
from sklearn.model_selection import GroupShuffleSplit
import os

# --- 1. Define Necessary Classes (Unchanged) ---
class SturmFloodDataset9Band(Dataset):
    """Paired (image, binary water mask) dataset read from GeoTIFF files with rasterio.

    Image: all bands are read as float32. NaN/inf values are replaced via ``np.nan_to_num``,
    then values are rescaled linearly from ``[MIN_VAL_SCALE, MAX_VAL_SCALE]`` and clipped
    to ``[0, 1]``.
    Mask: band 1 is read; pixels whose value is in ``water_classes`` become 1 and all
    others become 0 (int64).

    Args:
        image_paths: Sequence of image paths, index-aligned with ``mask_paths``.
        mask_paths: Sequence of mask paths, same length as ``image_paths``.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=<HWC array>, mask=<HW array>)`` and returning a dict with
            ``'image'`` and ``'mask'``. When falsy, the raw numpy arrays are returned
            (image in band-first layout, float32; mask 2-D, int64).
        water_classes: Mask values counted as water. Defaults to ``[1, 2, 3, 4, 5]``.
        scale_range: ``(min, max)`` used for the linear rescaling. Defaults to ``(0.0, 10.0)``.

    Raises:
        ValueError: if the two path lists differ in length or ``scale_range`` is not increasing.
    """

    def __init__(self, image_paths, mask_paths, transform=None,
                 water_classes=None, scale_range=(0.0, 10.0)):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and mask_paths ({len(mask_paths)}) "
                "must have the same length")
        min_val, max_val = scale_range
        if max_val <= min_val:
            raise ValueError(f"scale_range must satisfy min < max, got {scale_range}")
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.water_classes = [1, 2, 3, 4, 5] if water_classes is None else list(water_classes)
        self.MIN_VAL_SCALE, self.MAX_VAL_SCALE = float(min_val), float(max_val)

    def __len__(self):
        return len(self.image_paths)

    def _scale_image(self, image):
        """Replace NaN/inf, rescale linearly with the configured range and clip to [0, 1]."""
        image = np.nan_to_num(image)
        scaled = (image - self.MIN_VAL_SCALE) / (self.MAX_VAL_SCALE - self.MIN_VAL_SCALE)
        return np.clip(scaled, 0.0, 1.0)

    def _binarize_mask(self, mask):
        """Map mask values in ``water_classes`` to 1 and everything else to 0 (int64)."""
        return np.isin(mask, self.water_classes).astype(np.int64)

    def __getitem__(self, idx):
        with rasterio.open(self.image_paths[idx]) as src:
            image = self._scale_image(src.read().astype(np.float32))
        with rasterio.open(self.mask_paths[idx]) as src:
            mask = self._binarize_mask(src.read(1))
        if self.transform:
            # albumentations expects channels-last, so move the band axis to the end.
            augmented = self.transform(image=image.transpose(1, 2, 0), mask=mask)
            image, mask = augmented['image'], augmented['mask']
        return image, mask

class SmpUNetBaseline(pl.LightningModule):
    """U-Net (ResNet-34 encoder, ImageNet weights) for two-class water segmentation.

    Trained with a multiclass Dice loss on logits. Validation accumulates a per-class F1
    score and logs the water class (index 1) as ``val_f1_water``. AdamW is paired with
    ReduceLROnPlateau driven by ``val_loss``.

    Args:
        learning_rate: AdamW learning rate; stored in ``self.hparams``.
    """

    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.save_hyperparameters()
        # 9 input channels for the 9-band imagery; 2 output classes (0 = non-water, 1 = water).
        self.model = smp.Unet("resnet34", encoder_weights="imagenet", in_channels=9, classes=2)
        self.loss_fn = smp.losses.DiceLoss(mode='multiclass', from_logits=True, smooth=1.0)
        # average='none' keeps one F1 per class, so the water class can be logged on its own.
        self.val_f1 = F1Score(task='multiclass', num_classes=2, average='none')

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Run the model on a batch and return ``(logits, integer targets, dice loss)``."""
        x, y = batch
        y = y.long()
        logits = self(x)
        return logits, y, self.loss_fn(logits, y)

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self._shared_step(batch)
        self.val_f1.update(logits, y)
        self.log('val_loss', loss, prog_bar=True)

    def on_validation_epoch_end(self):
        f1 = self.val_f1.compute()
        if f1.numel() > 1:
            self.log('val_f1_water', f1[1], prog_bar=True)
        self.val_f1.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        # `verbose` is deprecated for LR schedulers (PyTorch >= 2.2), so it is not passed here;
        # use a LearningRateMonitor callback if LR changes need to be tracked.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }

# --- 2. Setup Data with EVENT-WISE SPLIT ---
print("--- Setting up Data with Event-Wise (Grouped) Splitting ---")
LOCAL_DATA_DIR = Path('/content/data')
image_paths = sorted(LOCAL_DATA_DIR.glob('images/*.tif'))
mask_paths = sorted(LOCAL_DATA_DIR.glob('masks/*.tif'))

# Images and masks are paired purely by sorted position below, so the lists must line up.
# NOTE(review): this assumes image and mask filenames sort in the same order; confirm for this dataset.
assert image_paths, f"No images found in {LOCAL_DATA_DIR / 'images'}; run the data-preparation cell first."
assert len(image_paths) == len(mask_paths), (
    f"Image/mask count mismatch: {len(image_paths)} images vs {len(mask_paths)} masks.")

# Extract event IDs (e.g., 'EMSR470') from filenames to use as groups
event_groups = [p.name.split('_')[0] for p in image_paths]

SPLIT_SEED = 42
# With a float test_size, GroupShuffleSplit holds out that fraction of *groups* (events),
# not of samples. Per-split sample counts therefore depend on how many tiles each event has.
TEST_EVENT_FRACTION = 0.20
VAL_EVENT_FRACTION = 0.15

# First split: hold out ~20% of the events as the test set. All images from one event
# go into either train_val or test, never both.
gss_test = GroupShuffleSplit(n_splits=1, test_size=TEST_EVENT_FRACTION, random_state=SPLIT_SEED)
train_val_indices, test_indices = next(gss_test.split(image_paths, groups=event_groups))

# Second split: hold out ~15% of the remaining *events* as the validation set.
train_val_groups = [event_groups[i] for i in train_val_indices]
gss_val = GroupShuffleSplit(n_splits=1, test_size=VAL_EVENT_FRACTION, random_state=SPLIT_SEED)
train_indices, val_indices = next(gss_val.split(
    [image_paths[i] for i in train_val_indices], groups=train_val_groups))

# The indices from the second split are relative to the train_val set, so map them back
original_train_indices = train_val_indices[train_indices]
original_val_indices = train_val_indices[val_indices]

# Verify that no event IDs are shared between sets
train_events = {event_groups[i] for i in original_train_indices}
val_events = {event_groups[i] for i in original_val_indices}
test_events = {event_groups[i] for i in test_indices}
print(f"Train/Val intersection: {train_events.intersection(val_events)}")
print(f"Train/Test intersection: {train_events.intersection(test_events)}")
print(f"Val/Test intersection: {val_events.intersection(test_events)}")
assert not (train_events & val_events), "Leakage detected between train and val sets!"
assert not (train_events & test_events), "Leakage detected between train and test sets!"
assert not (val_events & test_events), "Leakage detected between val and test sets!"
# Every sample must land in exactly one split.
assert len(original_train_indices) + len(original_val_indices) + len(test_indices) == len(image_paths)

# Create Datasets with the correct transforms and indices
train_transform = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), ToTensorV2()])
val_test_transform = A.Compose([ToTensorV2()])

train_dataset = Subset(SturmFloodDataset9Band(image_paths, mask_paths, transform=train_transform), original_train_indices)
val_dataset = Subset(SturmFloodDataset9Band(image_paths, mask_paths, transform=val_test_transform), original_val_indices)
test_dataset = Subset(SturmFloodDataset9Band(image_paths, mask_paths, transform=val_test_transform), test_indices)

print(f"\nEvent-Wise Split Complete:")
print(f"  -> Training samples: {len(train_dataset)} ({len(train_events)} unique events)")
print(f"  -> Validation samples: {len(val_dataset)} ({len(val_events)} unique events)")
print(f"  -> Test samples: {len(test_dataset)} ({len(test_events)} unique events)")


# --- 3. Training Execution ---
pl.seed_everything(42)
torch.set_float32_matmul_precision('medium')  # trade some float32 matmul precision for speed
GDRIVE_SAVE_DIR = Path('/content/drive/MyDrive/opticflood_phd_project')

BATCH_SIZE = 64
LEARNING_RATE = 1e-4
MAX_EPOCHS = 50  # train a bit longer on the new, harder split
NUM_WORKERS = 0  # rasterio reads run in the main process; raise this for faster loading
# Lightning's default log_every_n_steps=50 is larger than the number of training batches per
# epoch here (Lightning warns about this), so train_loss was rarely written to the CSV log.
LOG_EVERY_N_STEPS = 10

# DataLoaders (the held-out test_dataset is intentionally not used during fitting)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Model
model = SmpUNetBaseline(learning_rate=LEARNING_RATE)

# Logger and callbacks
logger = pl.loggers.CSVLogger(GDRIVE_SAVE_DIR / 'reports', name='unet_9band_event_split')
# Keep only the checkpoint with the best water-class F1 on the validation events.
# NOTE(review): if a checkpoint with this filename already exists in models/, Lightning may
# save a versioned copy (e.g. "-v1") instead of overwriting; check best_model_path below.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=GDRIVE_SAVE_DIR / 'models', filename='unet-9band-event-split-best',
    monitor='val_f1_water', mode='max', save_top_k=1,
)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator='gpu', devices=1, logger=logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=LOG_EVERY_N_STEPS,
)

print("\n--- Starting Training with Academically Rigorous Event-Wise Split ---")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
print("--- Training complete ---")
print(f"Best checkpoint: {checkpoint_callback.best_model_path} "
      f"(val_f1_water={checkpoint_callback.best_model_score})")

INFO:lightning_fabric.utilities.seed:Seed set to 42


--- Setting up Data with Event-Wise (Grouped) Splitting ---
Train/Val intersection: set()
Train/Test intersection: set()
Val/Test intersection: set()

Event-Wise Split Complete:
  -> Training samples: 2212 (19 unique events)
  -> Validation samples: 134 (4 unique events)
  -> Test samples: 329 (6 unique events)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs



--- Starting Training with Academically Rigorous Event-Wise Split ---


/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /content/drive/MyDrive/opticflood_phd_project/models exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | Unet              | 24.5 M | train
1 | loss_fn | DiceLoss          | 0      | train
2 | val_f1  | MulticlassF1Score | 0      | train
------------------------------------------------------
24.5 M    Trainable params
0         Non-trainable params
24.5 M    Total params
97.821    Total estimated model params size (MB)
190       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (35) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


--- Training complete ---
