# Catalyst segmentation tutorial

Authors: Dmitry Bleklov, Roman Tezikov

### Colab extras

First of all, do not forget to change the runtime type to GPU. <br/>
To do so click `Runtime` -> `Change runtime type` -> Select `"Python 3"` and `"GPU"` -> click `Save`. <br/>
After that you can click `Runtime` -> `Run all` and watch the tutorial.



## Requirements

Download and install the latest version of catalyst and other libraries required for this tutorial.

In [0]:
!pip install -U catalyst

# this variable will be used in `runner.train`; FP16 stays disabled unless the apex block at the end of this cell sets it to True
is_fp16_used = False

# for augmentations
!pip install -U albumentations

# for TTA
!pip install ttach

# for tensorboard
!pip install tensorflow
%load_ext tensorboard


# if your machine doesn't support FP16, comment out the 3 lines below
!git clone https://github.com/NVIDIA/apex
!pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex
is_fp16_used = True

## Setting up GPUs
PyTorch and Catalyst versions:

In [0]:
import torch, catalyst

torch.__version__, catalyst.__version__

You can also specify GPU/CPU usage for this tutorial.

Available GPUs

In [0]:
from catalyst.utils import get_available_gpus

get_available_gpus()

In [0]:
import os
from typing import List, Tuple, Callable

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU, "0,1" - MultiGPU

---

## Reproducibility first

Let's set our seed and set the CUDA settings to deterministic mode.

In [0]:
from catalyst.utils import set_global_seed, prepare_cudnn

SEED = 42

set_global_seed(SEED)
prepare_cudnn(deterministic=True)

## Dataset

As a data set we will take Carvana - binary segmentation for the "car" class.

If you are on macOS and you don't have `wget`, you can install it with: `brew install wget`.


In [0]:
%%bash

# Download a file from Google Drive.
#   $1 - Google Drive file id
#   $2 - path of the output file
function gdrive_download () {
  local cookies=/tmp/cookies.txt
  local confirm

  # The first request stores the session cookies and the "confirm" token is
  # extracted from its response body.
  confirm=$(wget --quiet --save-cookies "$cookies" --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')

  # The second request replays the cookies together with the token; the output
  # path is quoted so that names containing spaces are not word-split.
  wget --load-cookies "$cookies" "https://docs.google.com/uc?export=download&confirm=$confirm&id=$1" -O "$2"

  rm -f "$cookies"
}

gdrive_download 1iYaNijLmzsrMlAdMoUEhhJuo-5bkeAuj segmentation_data.zip

unzip segmentation_data.zip &>/dev/null

In [0]:
from pathlib import Path

ROOT = Path("segmentation_data/")

train_image_path = ROOT / "train"
train_mask_path = ROOT / "train_masks"
test_image_path = ROOT / "test"

Collect images and masks into variables.

In [0]:
ALL_IMAGES = sorted(train_image_path.glob("*.jpg"))
len(ALL_IMAGES)

In [0]:
ALL_MASKS = sorted(train_mask_path.glob("*.gif"))
len(ALL_MASKS)

In [0]:
import numpy as np

from skimage.io import imread as gif_imread
from catalyst import utils

import matplotlib.pyplot as plt

import random


def show_examples(name: str, image: np.ndarray, mask: np.ndarray):
    """Draw an image and its mask side by side in a new figure.

    Args:
        name: label appended to both subplot titles.
        image: array shown in the left subplot.
        mask: array shown in the right subplot.
    """
    plt.figure(figsize=(10, 14))

    # (array to draw, subplot title), listed left to right
    panels = (
        (image, f"Image: {name}"),
        (mask, f"Mask: {name}"),
    )
    for position, (picture, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(title)


def show(index: int, images: List[Path], masks: List[Path], transforms=None) -> None:
    """Load the image/mask pair at ``index`` and display it.

    Args:
        index: position of the pair in ``images`` and ``masks``.
        images: paths of the images.
        masks: paths of the masks, read with ``gif_imread``.
        transforms: optional callable invoked as
            ``transforms(image=..., mask=...)``; its ``"image"`` and
            ``"mask"`` entries replace the loaded arrays before display.
    """
    image_path = images[index]
    image = utils.imread(image_path)
    mask = gif_imread(masks[index])

    if transforms is not None:
        augmented = transforms(image=image, mask=mask)
        image, mask = augmented["image"], augmented["mask"]

    show_examples(image_path.name, image, mask)

def show_random(images: List[Path], masks: List[Path], transforms=None) -> None:
    """Display one randomly chosen image/mask pair.

    Args:
        images: paths of the images.
        masks: paths of the masks, aligned with ``images``.
        transforms: optional transforms forwarded to ``show``.
    """
    # randrange(n) picks from 0..n-1, the same range as randint(0, n - 1)
    index = random.randrange(len(images))
    show(index, images, masks, transforms)

You can re-run the cell below to see more examples.

In [0]:
show_random(ALL_IMAGES, ALL_MASKS)

The dataset below reads images and masks and optionally applies augmentation to them.

In [0]:
from typing import List

from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    """Dataset of images with optional segmentation masks.

    Every item is a dict holding ``"image"`` and ``"filename"`` and, when
    masks were supplied, ``"mask"``.

    Args:
        images: paths of the input images.
        masks: paths of the masks, aligned index-by-index with ``images``;
            ``None`` for data that has no masks.
        transforms: optional callable invoked as ``transforms(**sample)``
            that returns the transformed sample dict.

    Raises:
        ValueError: if ``masks`` is given and its length differs from the
            length of ``images``.
    """

    def __init__(
        self,
        images: List[Path],
        masks: List[Path] = None,
        transforms=None
    ) -> None:
        # Fail fast: a length mismatch would otherwise pair images with the
        # wrong masks or surface later as an IndexError while loading.
        if masks is not None and len(masks) != len(images):
            raise ValueError(
                f"images and masks must have the same length, "
                f"got {len(images)} images and {len(masks)} masks"
            )

        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self) -> int:
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx: int) -> dict:
        """Load the sample at ``idx`` and apply the transforms, if any."""
        image_path = self.images[idx]
        result = {"image": utils.imread(image_path)}

        if self.masks is not None:
            result["mask"] = gif_imread(self.masks[idx])

        if self.transforms is not None:
            result = self.transforms(**result)

        # Added after the transforms so they only ever receive image/mask.
        result["filename"] = image_path.name

        return result

### Augmentations

The [albumentations](https://github.com/albu/albumentations) library works with images and masks at the same time, which is what we need.

In [0]:
import albumentations as A
from albumentations.pytorch import ToTensor


def pre_transforms(image_size=224):
    """Return a one-element list that resizes to image_size x image_size."""
    resize = A.Resize(image_size, image_size, p=1)
    return [resize]
  

def resize_transforms(image_size=224):
    """Return the resize stage of the training pipeline.

    The stage is a single ``A.OneOf`` that picks one of three ways to obtain
    an ``image_size`` x ``image_size`` square:

    * scale the smallest side to ``1.5 * image_size`` and take a random crop;
    * resize directly to the target size;
    * scale the longest side to ``1.5 * image_size``, pad if the result is
      smaller than the crop, and take a random crop.

    Args:
        image_size: side length of the square output.

    Returns:
        A list containing the ``A.OneOf`` transform.
    """
    BORDER_CONSTANT = 0  # numeric value of cv2.BORDER_CONSTANT
    pre_size = int(image_size * 1.5)

    random_crop = A.Compose([
        A.SmallestMaxSize(pre_size, p=1),
        A.RandomCrop(image_size, image_size, p=1),
    ])

    rescale = A.Compose([A.Resize(image_size, image_size, p=1)])

    random_crop_big = A.Compose([
        A.LongestMaxSize(pre_size, p=1),
        # Scaling by the longest side can leave the other side shorter than
        # image_size for elongated inputs, which RandomCrop cannot handle;
        # padding is a no-op when the image is already large enough.
        A.PadIfNeeded(
            min_height=image_size,
            min_width=image_size,
            border_mode=BORDER_CONSTANT,
            p=1,
        ),
        A.RandomCrop(image_size, image_size, p=1),
    ])

    # Converts the image to a square of size image_size x image_size
    return [A.OneOf([random_crop, rescale, random_crop_big], p=1)]
  
def post_transforms():
    """Return the final pipeline stage: normalization, then tensor conversion."""
    # Normalize() is used with its default parameters (ImageNet statistics)
    normalize = A.Normalize()
    # converts the result to torch.Tensor
    to_tensor = ToTensor()
    return [normalize, to_tensor]
  
def compose(transforms_to_compose):
    """Flatten a list of transform lists into one ``A.Compose`` pipeline.

    Args:
        transforms_to_compose: iterable of transform lists, applied in the
            order given.
    """
    flattened = []
    for stage in transforms_to_compose:
        flattened.extend(stage)
    return A.Compose(flattened)


def hard_transforms():
    """Return the list of augmentations used for training."""
    return [
        A.RandomRotate90(),
        A.Cutout(),
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.3,
        ),
        A.GridDistortion(p=0.3),
        A.HueSaturationValue(p=0.3),
    ]


In [0]:
train_transforms = compose([resize_transforms(), hard_transforms(), post_transforms()])
valid_transforms = compose([pre_transforms(), post_transforms()])

show_transforms = compose([resize_transforms(), hard_transforms()])

You can re-run the cell below to see more examples of augmentations.

In [0]:
show_random(ALL_IMAGES, ALL_MASKS, transforms=show_transforms)

-------

## Loaders

In [0]:
import collections
from sklearn.model_selection import train_test_split

from torch.utils.data import DataLoader

def get_loaders(
    images: List[Path],
    masks: List[Path],
    random_state: int,
    valid_size: float = 0.2,
    batch_size: int = 32,
    num_workers: int = 4,
    train_transforms_fn = None,
    valid_transforms_fn = None,
) -> dict:
    """Split the data into train/valid parts and wrap each in a DataLoader.

    Args:
        images: paths of all images.
        masks: paths of all masks, aligned with ``images``.
        random_state: seed passed to ``train_test_split``.
        valid_size: ``test_size`` passed to ``train_test_split``.
        batch_size: batch size of both loaders.
        num_workers: number of worker processes of both loaders.
        train_transforms_fn: transforms for the train dataset.
        valid_transforms_fn: transforms for the valid dataset.

    Returns:
        An ``OrderedDict`` with the ``"train"`` loader (shuffled) followed by
        the ``"valid"`` loader (not shuffled).
    """
    all_indices = np.arange(len(images))

    # Split positions rather than paths so images and masks stay paired.
    train_indices, valid_indices = train_test_split(
        all_indices,
        test_size=valid_size,
        random_state=random_state,
        shuffle=True,
    )

    image_array = np.array(images)
    mask_array = np.array(masks)

    # (loader key, sample positions, transforms, shuffle flag)
    splits = (
        ("train", train_indices, train_transforms_fn, True),
        ("valid", valid_indices, valid_transforms_fn, False),
    )

    # Catalyst expects an OrderedDict of plain torch DataLoaders.
    loaders = collections.OrderedDict()
    for key, positions, transforms_fn, reshuffle in splits:
        dataset = SegmentationDataset(
            images=image_array[positions].tolist(),
            masks=mask_array[positions].tolist(),
            transforms=transforms_fn,
        )
        loaders[key] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=reshuffle,
            num_workers=num_workers,
        )

    return loaders

In [0]:
# the larger batch is used only when FP16 mode is enabled
batch_size = 64 if is_fp16_used else 32

loaders = get_loaders(
    ALL_IMAGES,
    ALL_MASKS,
    SEED,
    batch_size=batch_size,
    train_transforms_fn=train_transforms,
    valid_transforms_fn=valid_transforms,
)

## Model

Catalyst has many segmentation models. 

In [0]:
from catalyst.contrib.models import ResnetFPNUnet

model = ResnetFPNUnet(num_classes=1, arch="resnet18", pretrained=True)

## Model training
### Preparing training params.

We will optimize loss as the sum of IoU, Dice and BCE, specifically this function: $IoU + Dice + 0.8*BCE$.


In [0]:
from torch import nn

from catalyst.contrib.criterion import DiceLoss, IoULoss

# we have multiple criterions
criterion = {
    "dice": DiceLoss(),
    "iou": IoULoss(),
    "bce": nn.BCEWithLogitsLoss()
}

In [0]:
from torch import optim

from catalyst.contrib.optimizers import RAdam, Lookahead

learning_rate = 0.001
encoder_learning_rate = 0.0005

# Since we use a pre-trained encoder, we will reduce the learning rate on it.
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}

# This function removes weight_decay for biases and applies our layerwise_params
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)

# Catalyst has new SOTA optimizers out of box
base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = Lookahead(base_optimizer)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)

In [0]:
from catalyst.dl import SupervisedRunner

logdir = "./logs/segmentation"
num_epochs = 3

device = utils.get_device()
print(f"device: {device}")

# params for FP16; None when FP16 mode is not used
fp16_params = {"opt_level": "O1"} if is_fp16_used else None
print(f"FP16 params: {fp16_params}")

# SupervisedRunner defaults to the "features" and "targets" keys, while the
# dataset's __getitem__ returns "image" and "mask", so both keys are overridden
runner = SupervisedRunner(
    device=device,
    input_key="image",
    input_target_key="mask",
)

### Monitoring in tensorboard

If you do not have a Tensorboard opened after you have run the cell below, try running the cell again.

In [0]:
%tensorboard --logdir {logdir}

### Running the train-loop

In [0]:
from catalyst.dl.callbacks import DiceCallback, IouCallback, \
  CriterionCallback, CriterionAggregatorCallback

runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    
    # our dataloaders
    loaders=loaders,
    
    callbacks=[
        # Each criterion is calculated separately.
        CriterionCallback(
            input_key="mask",
            prefix="loss_dice",
            criterion_key="dice"
        ),
        CriterionCallback(
            input_key="mask",
            prefix="loss_iou",
            criterion_key="iou"
        ),
        CriterionCallback(
            input_key="mask",
            prefix="loss_bce",
            criterion_key="bce",
            multiplier=0.8
        ),
        
        # And only then we aggregate everything into one loss.
        CriterionAggregatorCallback(
            prefix="loss",
            loss_keys=["loss_dice", "loss_iou", "loss_bce"],
            loss_aggregate_fn="sum" # or "mean"
        ),
        
        # metrics
        DiceCallback(input_key="mask"),
        IouCallback(input_key="mask"),
    ],
    # path to save logs
    logdir=logdir,
    
    num_epochs=num_epochs,
    
    # save our best checkpoint by IoU metric
    main_metric="iou",
    # IoU needs to be maximized.
    minimize_metric=False,
    
    # for FP16. It uses the variable from the very first cell
    fp16=fp16_params,
    
    # prints train logs
    verbose=True
)

## Model inference


In [0]:
TEST_IMAGES = sorted(test_image_path.glob("*.jpg"))
len(TEST_IMAGES)

In [0]:
# create test dataset
test_dataset = SegmentationDataset(
    TEST_IMAGES, 
    transforms=valid_transforms
)


num_workers: int = 4

infer_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

In [0]:

# this gets predictions for the whole loader
predictions = runner.predict_loader(
    model=model,
    loader=infer_loader,
    resume=f"{logdir}/checkpoints/best.pth",
    verbose=True,
)

The prediction type is `np.ndarray`

In [0]:
print(type(predictions))
print(predictions.shape)

In [0]:
threshold = 0.5
max_count = 5


def detach(tensor: torch.Tensor) -> np.ndarray:
    """Return the tensor as a NumPy array, detached and moved to the CPU."""
    cpu_tensor = tensor.detach().cpu()
    return cpu_tensor.numpy()

# Show the first `max_count` predictions. `range(max_count)` comes first in the
# zip, so iteration stops after exactly `max_count` examples and no further
# sample is loaded from the dataset (the old post-draw check showed one extra).
for i, features, logits in zip(range(max_count), test_dataset, predictions):
    # move the channel axis last, the layout matplotlib draws
    image = detach(features["image"]).transpose(1, 2, 0)

    # sigmoid of the first channel, then binarize with the threshold
    mask_ = torch.from_numpy(logits[0]).sigmoid()
    mask = detach(mask_ > threshold).astype("uint8")

    show_examples(name="", image=image, mask=mask)

### Test-time augmentations (TTA)

`ttach` is a new awesome library for test-time augmentation for segmentation or classification tasks

In [0]:
import ttach as tta

# D4 makes horizontal and vertical flips + rotations for [0, 90, 180, 270] angles.
# and then merges the result masks with merge_mode="mean"
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")

tta_runner = SupervisedRunner(
    model=tta_model,
    device=utils.get_device(),
    input_key="image"
)

In [0]:
infer_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=num_workers
)

In [0]:
batch = next(iter(infer_loader))

if is_fp16_used:
  # Move our batch to FP16
  batch["image"] = batch["image"].half()

batch

In [0]:
# predict_batch will automatically move the batch to the Runner's device

tta_predictions = tta_runner.predict_batch(batch)

In [0]:
tta_predictions

Shape is `batch_size x channels x height x width`

In [0]:
tta_predictions["logits"].shape

Let's see our mask after TTA

In [0]:
threshold = 0.5

# first image of the batch as a float array with the channel axis moved last
image = detach(batch["image"][0]).astype(float).transpose(1, 2, 0)

# sigmoid of the first sample's first channel, binarized with the threshold
logits = tta_predictions["logits"][0, 0].sigmoid()
mask = detach(logits > threshold).astype("uint8")

show_examples(name="", image=image, mask=mask)