## Flood model exploration

In [1]:
from datetime import datetime
import numpy as np
import pandas as pd
import os
import random
import rasterio
import torch
from pathlib import Path
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

### Load in the data

In [2]:
root_dir = '/home/k3blu3/datasets/s1floods'
meta_file = 'flood-training-metadata.csv'
image_dir = os.path.join(root_dir, 'train_features')
label_dir = os.path.join(root_dir, 'train_labels')

In [3]:
df = pd.read_csv(os.path.join(root_dir, meta_file))
df.head(5)

Unnamed: 0,image_id,chip_id,flood_id,polarization,location,scene_start
0,awc00_vh,awc00,awc,vh,Bolivia,2018-02-15
1,awc00_vv,awc00,awc,vv,Bolivia,2018-02-15
2,awc01_vh,awc01,awc,vh,Bolivia,2018-02-15
3,awc01_vv,awc01,awc,vv,Bolivia,2018-02-15
4,awc02_vh,awc02,awc,vh,Bolivia,2018-02-15


In [4]:
df= df.drop_duplicates(subset=['chip_id'], ignore_index=True)

In [5]:
# add columns for each extension data type with path to full image
extensions = ['vv', 'vh', 'nasadem', 'jrc-gsw-occurrence']
ext_paths = dict()
label_paths = list()
for ext in extensions:
    ext_paths[ext] = list()

for idx, row in tqdm(df.iterrows(), total=len(df)):
    for ext in extensions:
        fname = os.path.join(image_dir, f"{row['chip_id']}_{ext}.tif")
        ext_paths[ext].append(fname)
        
    fname = os.path.join(label_dir, f"{row['chip_id']}.tif")
    label_paths.append(fname)
        
for ext in extensions:
    df[ext] = ext_paths[ext]
    
df['label'] = label_paths

  0%|          | 0/542 [00:00<?, ?it/s]

### Split data into training and validation

In [6]:
random.seed(9)

In [7]:
# sample three floods for validation set
flood_ids = df.flood_id.unique().tolist()
val_flood_ids = random.sample(flood_ids, 3)
val_flood_ids

['pxs', 'qxb', 'jja']

In [8]:
df_val = df[df.flood_id.isin(val_flood_ids)]
df_train = df[~df.flood_id.isin(val_flood_ids)]

In [9]:
val_x = df_val[['chip_id'] + extensions]
val_y = df_val['label']

train_x = df_train[['chip_id'] + extensions]
train_y = df_train['label']

In [10]:
train_x = train_x.reset_index(drop=True)
train_y = train_y.reset_index(drop=True)
val_x = val_x.reset_index(drop=True)
val_y = val_y.reset_index(drop=True)

### Build dataset for model

In [11]:
# write a function to rescale an input image with percentile scaling
def rescale_img(img, min_val=0.0, max_val=1.0, dtype=np.float32, pmin=0.0, pmax=100.0, vmin=None, vmax=None):
    # compute min and max percentile ranges to scale with
    if not vmin:
        vmin, vmax = np.nanpercentile(img, pmin), np.nanpercentile(img, pmax)

    # rescale & clip
    img_rescale = ((img - vmin) * (1.0 / (vmax - vmin) * max_val)).astype(dtype)
    np.clip(img_rescale, min_val, max_val, out=img_rescale)

    return img_rescale

In [12]:
extensions = ['vv', 'vh', 'nasadem', 'jrc-gsw-occurrence']
norms = [(-50, 0), (-50, 0), (0, 500), (0, 100)]

In [13]:
class FloodDataset(torch.utils.data.Dataset):
    def __init__(self, x, y=None):
        self.data = x
        self.label = y
        self.extensions = extensions
        self.norms = norms
        
    def __len__(self):
        return len(self.data)
    
    def __read_image(self, img):
        data_stack = list()
        for ext, norm in zip(self.extensions, self.norms):
            # read as masked array
            with rasterio.open(img['vv']) as src:
                data = src.read(1)
                
            # normalize this layer
            data = rescale_img(data, vmin=norm[0], vmax=norm[1])
            
            # append to stack
            data_stack.append(data)
            
        data_stack = np.stack(data_stack, axis=0)
        return data_stack
    
    def __read_label(self, label):
        with rasterio.open(label) as src:
            data = src.read(1)
            
        return data
    
    def __getitem__(self, idx):
        # grab image from dataframe
        img = self.data.loc[idx]
        
        # read in image layers, normalize each layer
        x_arr = self.__read_image(img)
        
        # prepare sample dictionary
        sample = {
            'chip_id': img.chip_id,
            'chip': x_arr
        }
        
        # load label (during training)
        if self.label is not None:
            label = self.label.loc[idx]
            sample['label'] = self.__read_label(label)
            
        return sample

### Define loss function and IOU

In [14]:
class XEDiceLoss(torch.nn.Module):
    """
    Computes (0.5 * CrossEntropyLoss) + (0.5 * DiceLoss).
    """

    def __init__(self):
        super().__init__()
        self.xe = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, pred, true):
        valid_pixel_mask = true.ne(255)  # valid pixel mask

        # Cross-entropy loss
        temp_true = torch.where((true == 255), 0, true)  # cast 255 to 0 temporarily
        xe_loss = self.xe(pred, temp_true)
        xe_loss = xe_loss.masked_select(valid_pixel_mask).mean()

        # Dice loss
        pred = torch.softmax(pred, dim=1)[:, 1]
        pred = pred.masked_select(valid_pixel_mask)
        true = true.masked_select(valid_pixel_mask)
        dice_loss = 1 - (2.0 * torch.sum(pred * true)) / (torch.sum(pred + true) + 1e-7)

        return (0.5 * xe_loss) + (0.5 * dice_loss)

In [15]:
def intersection_and_union(pred, true):
    """
    Calculates intersection and union for a batch of images.

    Args:
        pred (torch.Tensor): a tensor of predictions
        true (torc.Tensor): a tensor of labels

    Returns:
        intersection (int): total intersection of pixels
        union (int): total union of pixels
    """
    valid_pixel_mask = true.ne(255)  # valid pixel mask
    true = true.masked_select(valid_pixel_mask).to("cpu")
    pred = pred.masked_select(valid_pixel_mask).to("cpu")

    # Intersection and union totals
    intersection = np.logical_and(true, pred)
    union = np.logical_or(true, pred)
    return intersection.sum(), union.sum()

### Build UNet model

In [16]:
class FloodModel(pl.LightningModule):
    def __init__(self, hparams):
        super(FloodModel, self).__init__()
        self.hparams.update(hparams)
        self.save_hyperparameters()
        self.backbone = self.hparams.get("backbone", "resnet34")
        self.weights = self.hparams.get("weights", "imagenet")
        self.learning_rate = self.hparams.get("lr", 1e-3)
        self.max_epochs = self.hparams.get("max_epochs", 1000)
        self.min_epochs = self.hparams.get("min_epochs", 6)
        self.patience = self.hparams.get("patience", 4)
        self.num_workers = self.hparams.get("num_workers", 2)
        self.batch_size = self.hparams.get("batch_size", 32)
        self.x_train = self.hparams.get("x_train")
        self.y_train = self.hparams.get("y_train")
        self.x_val = self.hparams.get("x_val")
        self.y_val = self.hparams.get("y_val")
        self.output_path = self.hparams.get("output_path", "model-outputs")
        self.gpu = self.hparams.get("gpu", False)

        # Where final model will be saved
        self.output_path = Path.cwd() / self.output_path
        self.output_path.mkdir(exist_ok=True)

        # Track validation IOU globally (reset each epoch)
        self.intersection = 0
        self.union = 0

        # Instantiate datasets, model, and trainer params
        self.train_dataset = FloodDataset(
            self.x_train, self.y_train
        )
        self.val_dataset = FloodDataset(self.x_val, self.y_val)
        self.model = self._prepare_model()
        self.trainer_params = self._get_trainer_params()

    ## Required LightningModule methods ##

    def forward(self, image):
        # Forward pass
        return self.model(image)

    def training_step(self, batch, batch_idx):
        # Switch on training mode
        self.model.train()
        torch.set_grad_enabled(True)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass
        preds = self.forward(x)

        # Calculate training loss
        criterion = XEDiceLoss()
        xe_dice_loss = criterion(preds, y)

        # Log batch xe_dice_loss
        self.log(
            "xe_dice_loss",
            xe_dice_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return xe_dice_loss

    def validation_step(self, batch, batch_idx):
        # Switch on validation mode
        self.model.eval()
        torch.set_grad_enabled(False)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass & softmax
        preds = self.forward(x)
        preds = torch.softmax(preds, dim=1)[:, 1]
        preds = (preds > 0.5) * 1

        # Calculate validation IOU (global)
        intersection, union = intersection_and_union(preds, y)
        self.intersection += intersection
        self.union += union

        # Log batch IOU
        batch_iou = intersection / union
        self.log(
            "iou", batch_iou, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return batch_iou

    def train_dataloader(self):
        # DataLoader class for training
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        # DataLoader class for validation
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=True,
        )

    def configure_optimizers(self):
        # Define optimizer
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Define scheduler
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=self.patience
        )
        scheduler = {
            "scheduler": scheduler,
            "interval": "epoch",
            "monitor": "val_loss",
        }  # logged value to monitor
        return [optimizer], [scheduler]

    def validation_epoch_end(self, outputs):
        # Calculate IOU at end of epoch
        epoch_iou = self.intersection / self.union

        # Reset metrics before next epoch
        self.intersection = 0
        self.union = 0

        # Log epoch validation IOU
        self.log("val_loss", epoch_iou, on_epoch=True, prog_bar=True, logger=True)
        return epoch_iou

    ## Convenience Methods ##

    def _prepare_model(self):
        unet_model = smp.Unet(
            encoder_name=self.backbone,
            encoder_weights=self.weights,
            in_channels=4,
            classes=2,
        )
        if self.gpu:
            unet_model.cuda()
        return unet_model

    def _get_trainer_params(self):
        # Define callback behavior
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=self.output_path,
            monitor="val_loss",
            mode="max",
            verbose=True,
        )
        early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val_loss",
            patience=(self.patience * 3),
            mode="max",
            verbose=True,
        )

        # Specify where TensorBoard logs will be saved
        self.log_path = Path.cwd() / self.hparams.get("log_path", "tensorboard-logs")
        self.log_path.mkdir(exist_ok=True)
        logger = pl.loggers.TensorBoardLogger(self.log_path, name="benchmark-model")

        trainer_params = {
            "callbacks": [checkpoint_callback, early_stop_callback],
            "max_epochs": self.max_epochs,
            "min_epochs": self.min_epochs,
            "default_root_dir": self.output_path,
            "logger": logger,
            "gpus": None if not self.gpu else 1,
            "fast_dev_run": self.hparams.get("fast_dev_run", False),
            "num_sanity_val_steps": self.hparams.get("val_sanity_checks", 0),
        }
        return trainer_params

    def fit(self):
        # Set up and fit Trainer object
        self.trainer = pl.Trainer(**self.trainer_params)
        self.trainer.fit(self)

### Train the model

In [17]:
hparams = {
    # Required hparams
    "x_train": train_x,
    "x_val": val_x,
    "y_train": train_y,
    "y_val": val_y,
    # Optional hparams
    "backbone": "resnet34",
    "weights": "imagenet",
    "lr": 1e-3,
    "min_epochs": 6,
    "max_epochs": 1000,
    "patience": 4,
    "batch_size": 32,
    "num_workers": 0,
    "val_sanity_checks": 0,
    "fast_dev_run": False,
    "output_path": "model-outputs",
    "log_path": "tensorboard_logs",
    "gpu": torch.cuda.is_available(),
}

In [18]:
flood_model = FloodModel(hparams=hparams)

In [19]:
flood_model.fit()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type | Params
-------------------------------
0 | model | Unet | 24.4 M
-------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.759    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Validating: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.132
Epoch 0, global step 11: val_loss reached 0.13194 (best 0.13194), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=0-step=11.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.140
Epoch 1, global step 23: val_loss reached 0.13985 (best 0.13985), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=1-step=23.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.136 >= min_delta = 0.0. New best score: 0.275
Epoch 2, global step 35: val_loss reached 0.27548 (best 0.27548), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=2-step=35.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.027 >= min_delta = 0.0. New best score: 0.303
Epoch 3, global step 47: val_loss reached 0.30254 (best 0.30254), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=3-step=47.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 59: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.307
Epoch 5, global step 71: val_loss reached 0.30660 (best 0.30660), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=5-step=71.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.316
Epoch 6, global step 83: val_loss reached 0.31600 (best 0.31600), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=6-step=83.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 0.337
Epoch 7, global step 95: val_loss reached 0.33691 (best 0.33691), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=7-step=95.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 107: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 119: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 131: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 11, global step 143: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 155: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 13, global step 167: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 179: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 191: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.347
Epoch 16, global step 203: val_loss reached 0.34716 (best 0.34716), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=16-step=203.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.359
Epoch 17, global step 215: val_loss reached 0.35949 (best 0.35949), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=17-step=215.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 18, global step 227: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 239: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.365
Epoch 20, global step 251: val_loss reached 0.36486 (best 0.36486), saving model to "/home/k3blu3/dev/s1floods/notebooks/model-outputs/epoch=20-step=251.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 21, global step 263: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 22, global step 275: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 23, global step 287: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 24, global step 299: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 25, global step 311: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 26, global step 323: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 27, global step 335: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 28, global step 347: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 29, global step 359: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 30, global step 371: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 31, global step 383: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 12 records. Best score: 0.365. Signaling Trainer to stop.
Epoch 32, global step 395: val_loss was not in top 1


In [21]:
flood_model.state_dict()

OrderedDict([('model.encoder.conv1.weight',
              tensor([[[[ 9.0117e-03, -2.4840e-05,  8.0679e-03,  ...,  3.2420e-02,
                          1.4833e-02,  1.5086e-02],
                        [ 3.3601e-02,  2.4113e-02,  1.9424e-02,  ...,  1.3613e-02,
                          8.3309e-03,  2.4300e-02],
                        [ 5.1429e-03, -2.6324e-02, -5.4711e-02,  ..., -8.6984e-02,
                         -1.0312e-01, -1.0005e-01],
                        ...,
                        [-8.0314e-03, -2.0910e-02, -1.3609e-02,  ...,  4.4407e-03,
                         -9.5135e-03,  5.5440e-03],
                        [ 4.1653e-03,  2.8424e-03,  2.5188e-02,  ...,  7.5664e-02,
                          4.9421e-02,  4.0625e-02],
                        [ 1.7020e-02,  1.1079e-02,  2.6930e-02,  ...,  7.4759e-02,
                          6.1966e-02,  7.2752e-02]],
              
                       [[-8.3582e-04, -1.4213e-03,  1.2762e-02,  ...,  2.1041e-02,
                  