Based on https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/cars%20segmentation%20(camvid).ipynb

# Segmentation Models

[Segmentation Models for PyTorch](https://github.com/qubvel/segmentation_models.pytorch) (SMP) is a Python library with which you can easily build neural networks for **semantic image segmentation**. You can install it with 

```bash
pip install segmentation-models-pytorch
```

and import it in your code base with

```python
import segmentation_models_pytorch as smp
```

## Create the model

The first step is to set up the semantic segmentation model itself. The library gives you a lot of choices.

For example, you can choose:

* The **model architecture**: UNet, FPN, ...
* The **encoder**: ResNet, Inception, VGG, MobileNet, ...
* The pretrained encoder **weights**: usually ImageNet; some encoders also offer weights pretrained on other large datasets. If `None`, the weights are randomly initialized.
* The number of **output channels** (= number of classes)
* The **activation function** applied to the output of the final convolutional layer
* Other options, such as the number of input channels (`in_channels`) and the encoder depth (`encoder_depth`)

In [None]:
import segmentation_models_pytorch as smp


ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['car']
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation


model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)
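
To make the activation choice more concrete, here is a minimal NumPy sketch (not SMP's own code) of the two options named in the comment above. It assumes that `'sigmoid'` treats every output value independently, while `'softmax2d'` normalizes across the channel (class) axis for each pixel; the three-class logits are made up for illustration.

```python
import numpy as np

def sigmoid(logits):
    """Map each logit independently to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-logits))

def softmax2d(logits):
    """Softmax over the channel axis of a (C, H, W) array."""
    shifted = logits - logits.max(axis=0, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=0, keepdims=True)

# Binary case: one output channel, as with CLASSES = ['car'].
binary_logits = np.array([[[-2.0, 0.0], [0.5, 3.0]]])  # shape (1, 2, 2)
car_probability = sigmoid(binary_logits)

# Multiclass case: three hypothetical classes competing for each pixel.
multi_logits = np.random.default_rng(0).normal(size=(3, 2, 2))
class_probabilities = softmax2d(multi_logits)

print(car_probability.shape, class_probabilities.sum(axis=0))
```

With a single class, a sigmoid output can be thresholded directly; with several mutually exclusive classes, the per-pixel probabilities from the softmax sum to one.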

## Define train and test transforms

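To reduce overfitting, we apply some basic data augmentation transforms to the training images.

Note that the skip-connection concatenations in the decoder only work when the input height and width are **divisible by 32** (and larger than 64 pixels): with the default settings the encoder halves the resolution five times, and the upsampled decoder features must match the encoder features in size.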

In [None]:
import albumentations as A

train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0,
                       shift_limit=0.1, p=1, border_mode=0),
    A.PadIfNeeded(min_height=320, min_width=320,
                  always_apply=True, border_mode=0),
    A.RandomCrop(height=320, width=320, always_apply=True),
])


test_transform = A.Compose([
    # Add paddings to make image shape divisible by 32
    A.PadIfNeeded(
        min_height=None,
        min_width=None,
        pad_height_divisor=32,
        pad_width_divisor=32,
    )
])
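
The divisibility requirement can be checked by hand. The sketch below is only an illustration: `padded_size` and `halving_chain` are hypothetical helper names, the 360 x 480 frame size is an assumption about the CamVid images, and the floor division simplifies how real layers round.

```python
def padded_size(size, divisor=32):
    """Smallest multiple of `divisor` that is greater than or equal to `size`."""
    return ((size + divisor - 1) // divisor) * divisor

def halving_chain(size, depth=5):
    """Feature-map sizes after each of `depth` stride-2 stages.

    Simplified: real layers may round up instead of down.
    """
    return [size // 2 ** k for k in range(1, depth + 1)]

# Assumed frame size (height x width) for illustration.
height, width = 360, 480

print(padded_size(height), padded_size(width))
print(halving_chain(height))               # hits an odd size along the way
print(halving_chain(padded_size(height)))  # halves cleanly at every stage
```

An odd intermediate size cannot be recovered by doubling, which is why the padded size is the one that lets the decoder concatenate matching feature maps.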

## Define preprocessing transform

Apart from the data augmentations, we also apply a preprocessing step that makes the image compatible with the model: the same normalization that was used when the encoder weights were pretrained, and the channel-first `float32` layout that PyTorch expects.

In [None]:
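def to_tensor(x, **kwargs):
    # Reorder HWC -> CHW and cast to float32
    return x.transpose(2, 0, 1).astype('float32')

# Normalization function that matches the chosen encoder and weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocess_transform = A.Compose([
    A.Lambda(image=preprocessing_fn),           # normalize the image only
    A.Lambda(image=to_tensor, mask=to_tensor),  # reorder image and mask
])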
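
As a rough picture of what this preprocessing does to one image, here is a NumPy-only sketch. It assumes the commonly used ImageNet channel statistics and a `[0, 1]` input range; the function actually returned by `get_preprocessing_fn` depends on the encoder and weights, so treat `normalize` as a hypothetical stand-in.

```python
import numpy as np

# Commonly used ImageNet channel statistics for RGB inputs scaled to [0, 1]
# (an assumption here; the real values come from the chosen encoder).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize(image):
    """Scale a uint8 HWC image to [0, 1], then standardize each channel."""
    image = image.astype('float32') / 255.0
    return (image - IMAGENET_MEAN) / IMAGENET_STD

def to_tensor(x):
    """Reorder HWC -> CHW and cast to float32."""
    return x.transpose(2, 0, 1).astype('float32')

# Random stand-in for a 320 x 320 RGB crop.
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(320, 320, 3), dtype=np.uint8)

model_input = to_tensor(normalize(image))
print(model_input.shape, model_input.dtype)
```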

## Create datasets and dataloaders

In [None]:
from lib.camvid import CamVid

DATA_DIR = './data/CamVid/'

train_dataset = CamVid(
    DATA_DIR, 
    'train', 
    augmentation=train_transform, 
    preprocessing=preprocess_transform,
    classes=CLASSES,
)

valid_dataset = CamVid(
    DATA_DIR,
    'val',
    augmentation=test_transform,
    preprocessing=preprocess_transform,
    classes=CLASSES,
)

In [None]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=8,
                          shuffle=True, num_workers=12)
valid_loader = DataLoader(valid_dataset, batch_size=1,
                          shuffle=False, num_workers=4)

## Define loss, optimizer and metrics

### Soft Dice loss

We use the **soft Dice loss** as a loss function. It is derived from a "soft" variant of the Dice score (or Dice similarity coefficient, DSC):

$$
\text{DSC} = \frac{2|Y_{\text{true}}\cap Y_{\text{pred}}|}{|Y_{\text{true}}| + |Y_{\text{pred}}|}
$$

or, in other words,

$$
\text{DSC} = \frac{2\cdot\text{TP}}{2\text{TP}+\text{FP}+\text{FN}}
$$

(Note that this is exactly the same as the F1 score, which is the harmonic mean of precision and recall.)

As you can see, we need a binary prediction, positive (car) or negative (no car), for each pixel to compute the regular Dice similarity coefficient. However, this does not take into account how confident the network is in each prediction. The *soft* Dice score, on the other hand, is calculated as

$$
\text{DSC}_{\text{soft}} = \frac{2\sum\left(Y_{\text{true}} \odot \tilde{Y}_{\text{pred}}\right)}{\sum Y_{\text{true}} + \sum\tilde{Y}_{\text{pred}}}
$$

with $\tilde{Y}_{\text{pred}}$ containing the confidence scores, $\odot$ the element-wise multiplication of both matrices (or tensors, in general), and each $\sum$ running over all elements (pixels). The **soft Dice loss** is then defined as one minus the soft Dice score:

$$
\text{Soft Dice loss} = 1 - \text{DSC}_{\text{soft}}
$$


See [here](https://github.com/qubvel/segmentation_models.pytorch/blob/740dab561ccf54a9ae4bb5bda3b8b18df3790025/segmentation_models_pytorch/losses/_functional.py#L172) and [here](https://github.com/qubvel/segmentation_models.pytorch/blob/740dab561ccf54a9ae4bb5bda3b8b18df3790025/segmentation_models_pytorch/losses/dice.py#L111) for the implementation in `smp`.
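
The difference between the hard and the soft score is easiest to see on a tiny example. The following NumPy sketch is not the `smp` implementation; the small `eps` term is an assumed smoothing constant that avoids division by zero on empty masks.

```python
import numpy as np

def dice_score(y_true, y_pred, eps=1e-7):
    """Dice score; works for binary masks and for soft scores in [0, 1]."""
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + eps) / (np.sum(y_true) + np.sum(y_pred) + eps)

y_true = np.array([[1, 1], [0, 0]], dtype=float)
confident = np.array([[0.9, 0.9], [0.1, 0.1]])  # correct and sure
hesitant = np.array([[0.6, 0.6], [0.4, 0.4]])   # correct but unsure

# After thresholding at 0.5 both predictions are identical and perfect.
hard_confident = dice_score(y_true, (confident > 0.5).astype(float))
hard_hesitant = dice_score(y_true, (hesitant > 0.5).astype(float))

# The soft Dice loss still separates them.
soft_loss_confident = 1.0 - dice_score(y_true, confident)
soft_loss_hesitant = 1.0 - dice_score(y_true, hesitant)

print(hard_confident, hard_hesitant)
print(soft_loss_confident, soft_loss_hesitant)
```

Both predictions reach a hard Dice score of 1, but the hesitant one has the larger soft Dice loss, so the gradient keeps pushing the network toward more confident, correct outputs.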

### Jaccard index

To evaluate our result, we use the **Jaccard index**, which is actually another name for the Intersection over Union (IoU):

$$
\text{Jaccard index} = \frac{|Y_{\text{true}}\cap Y_{\text{pred}}|}{|Y_{\text{true}}\cup Y_{\text{pred}}|}
$$

Unlike the soft Dice loss, the Jaccard index is computed on binary predictions, so we pass in a threshold that turns the model's confidence scores into binary classifications (car/no car).

In [None]:
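import torch
# Import the utils submodule explicitly; some SMP versions do not expose
# smp.utils automatically when importing the top-level package.
import segmentation_models_pytorch.utils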

# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

loss = smp.utils.losses.DiceLoss()
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]
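
To illustrate the thresholded Jaccard index described above, here is a small NumPy sketch. It is a simplified stand-in rather than the `smp` metric itself, and `eps` is again an assumed smoothing constant.

```python
import numpy as np

def iou_score(y_true, y_prob, threshold=0.5, eps=1e-7):
    """Jaccard index after binarizing the confidence scores."""
    y_pred = (y_prob > threshold).astype(float)
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + eps) / (union + eps)

y_true = np.array([[1, 1], [0, 0]], dtype=float)
y_prob = np.array([[0.8, 0.3], [0.7, 0.1]])

# Binarized prediction: [[1, 0], [1, 0]] -> 1 true positive, union of 3 pixels.
iou = iou_score(y_true, y_prob)

# For binary masks, IoU and Dice are linked by IoU = DSC / (2 - DSC).
dice = 2 * 1 / (2 + 2)
print(iou, dice / (2 - dice))
```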

## Create epoch runners

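Each epoch runner is a simple loop that iterates over the batches of the dataloader it is given: `TrainEpoch` also updates the weights with the optimizer, while `ValidEpoch` only evaluates the loss and metrics.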

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=device,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=device,
    verbose=True,
)
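
For readers who want to see what such a runner boils down to, here is a minimal plain-PyTorch sketch. It is not the `smp.utils.train` code: `run_epoch` is a hypothetical helper, metrics and progress bars are left out, and the toy model and data exist only to make the example self-contained.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def run_epoch(model, loader, loss_fn, optimizer=None, device='cpu'):
    """Run one pass over `loader`; train if an optimizer is given, else evaluate."""
    training = optimizer is not None
    model.to(device)
    model.train(training)
    total_loss, num_batches = 0.0, 0
    # Gradients are only needed while training.
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    return {'loss': total_loss / max(num_batches, 1)}

# Toy stand-ins (hypothetical): a 1x1 convolution and random binary masks.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1), nn.Sigmoid())
toy_data = TensorDataset(torch.rand(8, 3, 32, 32),
                         (torch.rand(8, 1, 32, 32) > 0.5).float())
toy_loader = DataLoader(toy_data, batch_size=4)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)

toy_train_logs = run_epoch(toy_model, toy_loader, nn.BCELoss(), toy_optimizer)
toy_valid_logs = run_epoch(toy_model, toy_loader, nn.BCELoss())
print(toy_train_logs, toy_valid_logs)
```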

## Train the model

In [None]:
num_epochs = 40

max_score = 0

for i in range(0, num_epochs):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    if max_score < valid_logs['iou_score']:
        # Save model when it is better than the previous best
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')
        
    if i == 25:
        # The optimizer has a single parameter group covering the whole model
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease learning rate to 1e-5!')
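
The manual learning-rate drop in the loop above could also be expressed with a scheduler. The sketch below is one possible alternative, not what the notebook does: it assumes `MultiStepLR` with a milestone of 26, so that epochs 0 to 25 run at `1e-4` and later epochs at `1e-5`, and it uses a dummy parameter in place of the model.

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

# A single dummy parameter stands in for model.parameters().
dummy_parameter = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.Adam([dummy_parameter], lr=1e-4)
demo_scheduler = MultiStepLR(demo_optimizer, milestones=[26], gamma=0.1)

lr_per_epoch = []
for epoch in range(40):
    # Learning rate that would be used while training this epoch.
    lr_per_epoch.append(demo_optimizer.param_groups[0]['lr'])
    demo_optimizer.step()   # stands in for a full training epoch
    demo_scheduler.step()   # advance the schedule once per epoch

print(lr_per_epoch[25], lr_per_epoch[26])
```

A scheduler keeps the schedule in one place and is saved along with the optimizer state, which can help when training is resumed from a checkpoint.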

## Test best saved model

In [None]:
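# load best saved checkpoint
# The whole model object was pickled, so recent PyTorch versions (2.6 and later)
# need weights_only=False; the keyword itself requires PyTorch 1.13 or newer.
best_model = torch.load('./best_model.pth', map_location=device, weights_only=False)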

In [None]:
test_dataset = CamVid(
    DATA_DIR, 'test',
    augmentation=test_transform,
    preprocessing=preprocess_transform,
    classes=CLASSES,
)

test_dataloader = DataLoader(test_dataset)

In [None]:
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=device,
)

logs = test_epoch.run(test_dataloader)

## Visualize predictions

In [None]:
from lib.plot import visualize
import numpy as np

In [None]:
# test dataset without transformations for image visualization
test_dataset_vis = CamVid(
    DATA_DIR, 'test',
    classes=CLASSES,
)

In [None]:
for i in range(5):
    n = np.random.choice(len(test_dataset))
    
    image_vis = test_dataset_vis[n][0].astype('uint8')
    image, gt_mask = test_dataset[n]
    
    gt_mask = gt_mask.squeeze()
    
    # Add a batch dimension, predict, and round the scores to a binary mask
    x_tensor = torch.from_numpy(image).to(device).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy().round()
        
    visualize(
        image=image_vis, 
        ground_truth_mask=gt_mask, 
        predicted_mask=pr_mask
    )