## Flood model exploration

In [None]:
import albumentations
from datetime import datetime
import numpy as np
import pandas as pd
import os
import random
import rasterio
import torch
import torchvision
from pathlib import Path
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

### Load in the data

In [None]:
# Dataset locations. The root defaults to the original author's path but can be
# overridden with the S1FLOODS_ROOT environment variable on other machines.
root_dir = os.environ.get('S1FLOODS_ROOT', '/home/k3blu3/datasets/s1floods')
meta_file = 'flood-training-metadata.csv'
image_dir = os.path.join(root_dir, 'train_features')
label_dir = os.path.join(root_dir, 'train_labels')

In [None]:
# Load the chip-level metadata table and preview its first rows.
metadata_path = os.path.join(root_dir, meta_file)
df = pd.read_csv(metadata_path)
df.head()

In [None]:
# The metadata can list a chip more than once; keep the first row per chip_id.
df = df.drop_duplicates(subset='chip_id', ignore_index=True)

In [None]:
# Add one column per layer (and one for the label) holding the full path to
# that chip's GeoTIFF. Built column-wise rather than with df.iterrows(), which
# constructs a Series per row and is far slower.
extensions = ['vv', 'vh', 'nasadem', 'jrc-gsw-occurrence']
chip_ids = df['chip_id'].tolist()

for ext in extensions:
    df[ext] = [os.path.join(image_dir, f"{chip_id}_{ext}.tif") for chip_id in chip_ids]

df['label'] = [os.path.join(label_dir, f"{chip_id}.tif") for chip_id in chip_ids]

### Split the data into training and validation sets

In [None]:
# Seed every RNG the notebook can touch. Python's `random` drives the
# validation-flood draw below. NumPy and torch RNGs are presumably used by the
# augmentation library, DataLoader shuffling and model weight initialisation.
# Seeding them does not alter Python's `random` stream, so the split is unchanged.
random.seed(9)
torch.manual_seed(9)
np.random.seed(9)

In [None]:
# Hold out three whole flood events for validation, so that no validation chip
# shares a flood with any training chip.
flood_ids = df['flood_id'].unique().tolist()
val_flood_ids = random.sample(flood_ids, k=3)
val_flood_ids

In [None]:
# Boolean mask computed once and reused for both halves of the split.
is_val = df.flood_id.isin(val_flood_ids)
df_train = df[~is_val]
df_val = df[is_val]

In [None]:
# Features are the chip id plus one path column per layer; targets are label paths.
feature_cols = ['chip_id'] + extensions
val_x, val_y = df_val[feature_cols], df_val['label']
train_x, train_y = df_train[feature_cols], df_train['label']

In [None]:
# Renumber each split 0..n-1 so that label-based row lookups by index in
# range(len(...)) line up with row positions.
train_x, train_y, val_x, val_y = (
    frame.reset_index(drop=True) for frame in (train_x, train_y, val_x, val_y)
)

### Build dataset for model

In [None]:
# Rescale an image linearly to a fixed output range using explicit or percentile-derived bounds.
def rescale_img(img, min_val=0.0, max_val=1.0, dtype=np.float32, pmin=0.0, pmax=100.0, vmin=None, vmax=None):
    """Linearly map ``img`` from ``[vmin, vmax]`` onto ``[min_val, max_val]``.

    Any bound that is not supplied is taken from the data as its ``pmin`` /
    ``pmax`` percentile (NaNs ignored). Results are clipped to
    ``[min_val, max_val]`` and NaNs are replaced with 0.

    Args:
        img: array (plain ndarray or masked array) to rescale.
        min_val: output value that ``vmin`` maps to.
        max_val: output value that ``vmax`` maps to.
        dtype: dtype of the returned array.
        pmin: percentile used when ``vmin`` is None.
        pmax: percentile used when ``vmax`` is None.
        vmin: input value mapped to ``min_val``; 0 is a valid explicit bound.
        vmax: input value mapped to ``max_val``.

    Returns:
        Rescaled array of ``dtype``. FloodDataset calls ``.filled(0)`` on the
        result for masked inputs.
    """
    # Use `is None` rather than truthiness, so that an explicit 0 bound (e.g. the
    # (0, 500) / (0, 100) norms) is honored instead of replaced by percentiles.
    # NOTE(review): np.nanpercentile may not honour a masked array's mask;
    # confirm this if bounds are ever omitted for masked inputs.
    if vmin is None:
        vmin = np.nanpercentile(img, pmin)
    if vmax is None:
        vmax = np.nanpercentile(img, pmax)

    scale = (max_val - min_val) / (vmax - vmin)
    img_rescale = (min_val + (img - vmin) * scale).astype(dtype)
    np.ma.clip(img_rescale, min_val, max_val, out=img_rescale)

    # NaN inputs propagate through the arithmetic; zero them out last.
    return np.nan_to_num(img_rescale)

In [None]:
# Layer names and the matching (vmin, vmax) pair passed to rescale_img for each
# layer; the two lists are paired by position.
extensions = 'vv vh nasadem jrc-gsw-occurrence'.split()
norms = [(-50, 0)] * 2 + [(0, 500), (0, 100)]

In [None]:
# Training-time augmentation: random 90-degree rotations plus horizontal and
# vertical flips, applied jointly to image and mask.
_augmentations = [
    albumentations.RandomRotate90(),
    albumentations.HorizontalFlip(),
    albumentations.VerticalFlip(),
]
xforms = albumentations.Compose(_augmentations)

In [None]:
class FloodDataset(torch.utils.data.Dataset):
    """Chip-level dataset that stacks normalized raster layers channel-first.

    For each row of ``x``, every column named in the module-level
    ``extensions`` list is read as a single-band raster. It is rescaled with the
    matching ``(vmin, vmax)`` pair from the module-level ``norms`` list via
    ``rescale_img``, and masked pixels are filled with 0.

    Args:
        x: DataFrame with a ``chip_id`` column plus one path column per
            extension. Rows are fetched by position.
        y: optional Series of label raster paths, aligned by position with
            ``x``. When None, samples contain no ``'target'`` (inference use).
        xforms: optional albumentations-style callable, invoked as
            ``xforms(image=..., mask=...)`` (or ``xforms(image=...)`` without
            labels), returning a dict with ``'image'`` (and ``'mask'``).

    Each sample is a dict with ``'chip_id'``, ``'image'`` (C, H, W) and, when
    labels are given, ``'target'`` (H, W).
    """

    def __init__(self, x, y=None, xforms=None):
        self.data = x
        self.label = y
        self.xforms = xforms
        self.extensions = extensions
        self.norms = norms

    def __len__(self):
        return len(self.data)

    def _read_image(self, row):
        """Read, normalize and stack every layer for one row into (H, W, C)."""
        layers = []
        for ext, (vmin, vmax) in zip(self.extensions, self.norms):
            # Read as a masked array so nodata pixels can be filled after scaling.
            with rasterio.open(row[ext]) as src:
                band = src.read(1, masked=True)
            layers.append(rescale_img(band, vmin=vmin, vmax=vmax).filled(0))
        return np.stack(layers, axis=-1)

    def _read_label(self, path):
        """Read the first band of a label raster."""
        with rasterio.open(path) as src:
            return src.read(1)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image = self._read_image(row)
        label = None

        if self.label is not None:
            label = self._read_label(self.label.iloc[idx])
            # Use the instance's transform (not the notebook global), and only
            # when one was provided.
            if self.xforms is not None:
                augmented = self.xforms(image=image, mask=label)
                image, label = augmented['image'], augmented['mask']
        elif self.xforms is not None:
            image = self.xforms(image=image)['image']

        # HWC -> CHW; a contiguous copy keeps torch's default collate happy
        # regardless of how the transform laid the array out in memory.
        sample = {
            'chip_id': row.chip_id,
            'image': np.ascontiguousarray(np.transpose(image, axes=(2, 0, 1))),
        }
        if label is not None:
            sample['target'] = label
        return sample

### Check that a sample looks right

In [None]:
# Build the augmented training dataset and pull one sample to inspect.
ds = FloodDataset(train_x, train_y, xforms)
idx = 3
sample = ds[idx]

In [None]:
# Show each input layer next to the label for the sample above.
fig, axes = plt.subplots(1, 5, dpi=200)

for ax, channel, title in zip(axes[:4], sample['image'], extensions):
    ax.imshow(channel)
    ax.set_title(title, fontsize=6)

# Build a display copy with missing (255) pixels shown as 0, instead of editing
# sample['target'] in place.
label = np.where(sample['target'] == 255, 0, sample['target'])
axes[4].imshow(label)
axes[4].set_title('label', fontsize=6)

for ax in axes:
    ax.axis('off')

# tight_layout must run after the axes exist to have any effect.
fig.tight_layout()

In [None]:
# Distinct values in the label array displayed in the previous cell (255 was mapped to 0 there).
np.unique(label)

### Define the loss function and IoU metric

In [None]:
class XEDiceLoss(torch.nn.Module):
    """
    Computes (0.5 * CrossEntropyLoss) + (0.5 * DiceLoss) over valid pixels.

    Pixels whose label equals ``ignore_index`` are excluded from both terms.

    Args:
        ignore_index: label value marking pixels to ignore (default 255).
    """

    def __init__(self, ignore_index=255):
        super().__init__()
        self.ignore_index = ignore_index
        self.xe = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, pred, true):
        """
        Args:
            pred: logits with classes on dim 1; channel 1 is the positive class
                for the Dice term. Presumably (N, 2, H, W); confirm with the model.
            true: integer labels matching ``pred`` minus the class dim, with
                ``ignore_index`` for pixels to skip.

        Returns:
            Scalar loss tensor.
        """
        valid_pixel_mask = true.ne(self.ignore_index)

        # Cross-entropy: ignored labels are out of class range, so map them to 0
        # temporarily and drop them from the average afterwards.
        safe_true = true.masked_fill(~valid_pixel_mask, 0)
        xe_valid = self.xe(pred, safe_true).masked_select(valid_pixel_mask)
        # Divide by at least 1 so a batch with no valid pixels gives 0, not NaN.
        xe_loss = xe_valid.sum() / max(xe_valid.numel(), 1)

        # Dice loss on the positive-class probability.
        probs = torch.softmax(pred, dim=1)[:, 1].masked_select(valid_pixel_mask)
        target = true.masked_select(valid_pixel_mask).to(probs.dtype)
        dice_loss = 1 - (2.0 * torch.sum(probs * target)) / (torch.sum(probs + target) + 1e-7)

        return (0.5 * xe_loss) + (0.5 * dice_loss)

In [None]:
def intersection_and_union(pred, true, ignore_index=255):
    """
    Calculates intersection and union for a batch of images.

    Pixels whose label equals ``ignore_index`` are excluded from both counts.
    Nonzero values count as positive in both ``pred`` and ``true``.

    Args:
        pred (torch.Tensor): binary predictions, same shape as ``true``
        true (torch.Tensor): a tensor of labels
        ignore_index (int): label value to ignore (default 255)

    Returns:
        intersection (torch.Tensor): 0-dim integer tensor (CPU), total intersection of pixels
        union (torch.Tensor): 0-dim integer tensor (CPU), total union of pixels
    """
    valid_pixel_mask = true.ne(ignore_index)
    true = true.masked_select(valid_pixel_mask)
    pred = pred.masked_select(valid_pixel_mask)

    # Use torch's own logical ops rather than NumPy ufuncs on tensors, whose
    # return type depends on NumPy/torch interop details.
    intersection = torch.logical_and(true, pred).sum()
    union = torch.logical_or(true, pred).sum()
    # Only the two scalars leave the device.
    return intersection.cpu(), union.cpu()

### Build UNet model

In [None]:
class FloodModel(pl.LightningModule):
    """LightningModule training an ``smp.Unet`` for binary flood segmentation.

    All settings are read from the ``hparams`` dict given to ``__init__``.
    ``x_train``, ``y_train``, ``x_val`` and ``y_val`` are required; every other
    key falls back to the default shown in ``__init__``. Call :meth:`fit` to
    build a ``pl.Trainer`` and train.

    NOTE: the epoch-level validation metric is IoU (higher is better) but is
    logged under the key ``"val_loss"``. The LR scheduler, checkpoint and
    early-stopping callbacks all monitor that key with ``mode="max"``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams.update(hparams)
        self.save_hyperparameters()
        self.backbone = self.hparams.get("backbone", "resnet34")
        self.weights = self.hparams.get("weights", "imagenet")
        self.learning_rate = self.hparams.get("lr", 1e-3)
        self.max_epochs = self.hparams.get("max_epochs", 1000)
        self.min_epochs = self.hparams.get("min_epochs", 6)
        self.patience = self.hparams.get("patience", 4)
        self.num_workers = self.hparams.get("num_workers", 2)
        self.batch_size = self.hparams.get("batch_size", 32)
        self.x_train = self.hparams.get("x_train")
        self.y_train = self.hparams.get("y_train")
        self.x_val = self.hparams.get("x_val")
        self.y_val = self.hparams.get("y_val")
        self.output_path = self.hparams.get("output_path", "model-outputs")
        self.gpu = self.hparams.get("gpu", False)
        self.xforms = self.hparams.get("xforms", None)

        # Where final model will be saved
        self.output_path = Path.cwd() / self.output_path
        self.output_path.mkdir(exist_ok=True)

        # Track validation IOU globally (reset each epoch)
        self.intersection = 0
        self.union = 0

        # The loss has no learnable state, so build it once instead of per step.
        self.criterion = XEDiceLoss()

        # Instantiate datasets, model, and trainer params
        self.train_dataset = FloodDataset(
            self.x_train, self.y_train, xforms=self.xforms
        )
        self.val_dataset = FloodDataset(self.x_val, self.y_val, xforms=None)
        self.model = self._prepare_model()
        self.trainer_params = self._get_trainer_params()

    ## Required LightningModule methods ##

    def forward(self, image):
        # Forward pass
        return self.model(image)

    def _unpack_batch(self, batch):
        """Return ``(image, label)`` from a FloodDataset batch, moved to GPU if enabled."""
        # FloodDataset yields the input under "image" and the label under "target".
        x = batch["image"]
        y = batch["target"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
        return x, y

    def training_step(self, batch, batch_idx):
        # Lightning puts the module in train mode and enables grads itself, so
        # no manual model.train() / torch.set_grad_enabled(...) is needed here.
        x, y = self._unpack_batch(batch)

        preds = self.forward(x)
        xe_dice_loss = self.criterion(preds, y)

        # Log batch xe_dice_loss
        self.log(
            "xe_dice_loss",
            xe_dice_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return xe_dice_loss

    def validation_step(self, batch, batch_idx):
        # Lightning handles eval mode and no-grad for validation.
        x, y = self._unpack_batch(batch)

        # Positive-class probability thresholded at 0.5 -> {0, 1} mask
        probs = torch.softmax(self.forward(x), dim=1)[:, 1]
        preds = (probs > 0.5) * 1

        # Accumulate counts for the global (epoch-level) IoU
        intersection, union = intersection_and_union(preds, y)
        self.intersection += intersection
        self.union += union

        # IoU is undefined when a batch has neither positive labels nor
        # positive predictions; skip logging it instead of dividing by zero.
        if union == 0:
            return None

        batch_iou = intersection / union
        self.log(
            "iou", batch_iou, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return batch_iou

    def train_dataloader(self):
        # DataLoader class for training
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            # Pinned memory only helps host->GPU copies.
            pin_memory=bool(self.gpu),
        )

    def val_dataloader(self):
        # DataLoader class for validation
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=bool(self.gpu),
        )

    def configure_optimizers(self):
        # Define optimizer
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Halve the LR when the monitored IoU (logged as "val_loss") plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=self.patience
        )
        scheduler = {
            "scheduler": scheduler,
            "interval": "epoch",
            "monitor": "val_loss",
        }  # logged value to monitor
        return [optimizer], [scheduler]

    def validation_epoch_end(self, outputs):
        # Epoch IoU from summed counts (not the mean of per-batch IoUs). Fall
        # back to 0.0 when nothing was counted, so the monitored value stays finite.
        if self.union > 0:
            epoch_iou = self.intersection / self.union
        else:
            epoch_iou = 0.0

        # Reset metrics before next epoch
        self.intersection = 0
        self.union = 0

        # Log epoch validation IOU (under the historical key "val_loss")
        self.log("val_loss", epoch_iou, on_epoch=True, prog_bar=True, logger=True)
        return epoch_iou

    ## Convenience Methods ##

    def _prepare_model(self):
        # 4 input channels (one per stacked layer), 2 output classes
        unet_model = smp.Unet(
            encoder_name=self.backbone,
            encoder_weights=self.weights,
            in_channels=4,
            classes=2,
        )
        if self.gpu:
            unet_model.cuda()
        return unet_model

    def _get_trainer_params(self):
        # Define callback behavior
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=self.output_path,
            monitor="val_loss",
            mode="max",
            verbose=True,
        )
        early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val_loss",
            patience=(self.patience * 3),
            mode="max",
            verbose=True,
        )

        # Specify where TensorBoard logs will be saved
        self.log_path = Path.cwd() / self.hparams.get("log_path", "tensorboard-logs")
        self.log_path.mkdir(exist_ok=True)
        logger = pl.loggers.TensorBoardLogger(self.log_path, name="benchmark-model")

        trainer_params = {
            "callbacks": [checkpoint_callback, early_stop_callback],
            "max_epochs": self.max_epochs,
            "min_epochs": self.min_epochs,
            "default_root_dir": self.output_path,
            "logger": logger,
            "gpus": 1 if self.gpu else None,
            "fast_dev_run": self.hparams.get("fast_dev_run", False),
            "num_sanity_val_steps": self.hparams.get("val_sanity_checks", 0),
        }
        return trainer_params

    def fit(self):
        # Set up and fit Trainer object
        self.trainer = pl.Trainer(**self.trainer_params)
        self.trainer.fit(self)

### Train the model

In [None]:
# Training configuration; FloodModel.__init__ reads each key via hparams.get(...).
hparams = dict(
    # Required hparams
    x_train=train_x,
    x_val=val_x,
    y_train=train_y,
    y_val=val_y,
    # Optional hparams
    backbone="resnet34",
    weights="imagenet",
    lr=1e-3,
    min_epochs=6,
    max_epochs=1000,
    patience=4,
    batch_size=32,
    num_workers=6,
    val_sanity_checks=0,
    fast_dev_run=False,
    output_path="model-outputs",
    log_path="tensorboard_logs",
    gpu=torch.cuda.is_available(),
    # NOTE: the `xforms` pipeline defined earlier is not passed here, so the
    # training dataset is built without augmentation.
    xforms=None,
)

In [None]:
# Constructing the module also builds both datasets, the U-Net and the trainer settings.
flood_model = FloodModel(hparams)

In [None]:
# Create the pl.Trainer from the stored trainer params and run training.
flood_model.fit()

In [None]:
# Display the trained module's state dict as the cell output.
flood_model.state_dict()