# Catalyst tiled inference tutorial

Authors: [Pavel Danilov](https://github.com/pdanilov), [Sergey Kolesnikov](https://github.com/scitator)

[![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst)

### Colab setup

First, change the runtime type to GPU. <br/>
To do so, click `Runtime` -> `Change runtime type`, select `"Python 3"` and `"GPU"`, then click `Save`. <br/>
After that, click `Runtime` -> `Run all` and follow the tutorial.



## Requirements

Download and install the latest version of catalyst and other libraries required for this tutorial.

In [None]:
# for augmentations
!pip install albumentations==0.4.3

# for pretrained segmentation models for PyTorch
!pip install segmentation-models-pytorch==0.1.0

################
# Catalyst itself
!pip install -U catalyst
# For specific version of catalyst, uncomment:
# ! pip install git+http://github.com/catalyst-team/catalyst.git@{master/commit_hash}
################

# for tensorboard
!pip install tensorflow

### Colab extras – Plotly

To integrate the `plotly` visualization library into Colab, run:

In [None]:
import IPython

def configure_plotly_browser_state():
    display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              plotly: 'https://cdn.plot.ly/plotly-latest.min.js?noext',
            },
          });
        </script>
        '''))


IPython.get_ipython().events.register('pre_run_cell', configure_plotly_browser_state)

## Setting up GPUs

In [None]:
import os

from typing import Callable, List, Tuple

import torch
import catalyst

from catalyst.dl import utils

print(f"torch: {torch.__version__}, catalyst: {catalyst.__version__}")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU, "0,1" - MultiGPU

SEED = 42
utils.set_global_seed(SEED)
utils.prepare_cudnn(deterministic=True)

-------

## Dataset

We will use the Carvana dataset, a binary segmentation task for the "car" class.

> If you are on MacOS and you don’t have `wget`, you can install it with: `brew install wget`.

After Catalyst is installed, the `download-gdrive` command becomes available for downloading files from Google Drive.
We use it to download the dataset.

Usage: `download-gdrive {FILE_ID} {FILENAME}`

If the cell below fails, just run it again.

In [None]:
%%bash

DATA_ARCHIVE=segmentation_data.zip
if [ ! -f "${DATA_ARCHIVE}" ]; then
    download-gdrive 1iYaNijLmzsrMlAdMoUEhhJuo-5bkeAuj "${DATA_ARCHIVE}" &> /dev/null
    extract-archive "${DATA_ARCHIVE}" &> /dev/null
fi

In [None]:
from pathlib import Path

ROOT = Path("segmentation_data")

train_image_path = ROOT / "train"
train_mask_path = ROOT / "train_masks"
test_image_path = ROOT / "test"

Collect images and masks into variables.

In [None]:
ALL_IMAGES = sorted(train_image_path.glob("*.jpg"))
len(ALL_IMAGES)

In [None]:
ALL_MASKS = sorted(train_mask_path.glob("*.gif"))
len(ALL_MASKS)

In [None]:
TEST_IMAGES = sorted(test_image_path.glob("*.jpg"))
len(TEST_IMAGES)

In [None]:
import random
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread as gif_imread
from catalyst import utils


def show_examples(name: str, image: np.ndarray, mask: np.ndarray, mask_cmap_gray: bool = True):
    plt.figure(figsize=(10, 14))
    
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"Image: {name}")
    
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 2, 2)
    cmap = "gray" if mask_cmap_gray else None
    plt.imshow(mask, cmap=cmap)
    plt.title(f"Mask: {name}")
    
    plt.xticks([])
    plt.yticks([])


def show(index: int, images: List[Path], masks: List[Path], transforms=None) -> None:
    image_path = images[index]
    name = image_path.name

    image = utils.imread(image_path)
    mask = gif_imread(masks[index])

    if transforms is not None:
        temp = transforms(image=image, mask=mask)
        image = temp["image"]
        mask = temp["mask"]

    show_examples(name, image, mask)

def show_random(images: List[Path], masks: List[Path], transforms=None) -> None:
    length = len(images)
    index = random.randint(0, length - 1)
    show(index, images, masks, transforms)

You can rerun the cell below to see more examples.

In [None]:
show_random(ALL_IMAGES, ALL_MASKS)

The dataset below reads images and masks and optionally applies augmentation to them.

In [None]:
from typing import List

from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    def __init__(
        self,
        images: List[Path],
        masks: List[Path] = None,
        transforms=None
    ) -> None:
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> dict:
        image_path = self.images[idx]
        image = utils.imread(image_path)
        
        result = {"image": image}
        
        if self.masks is not None:
            mask = gif_imread(self.masks[idx])
            result["mask"] = mask
        
        if self.transforms is not None:
            result = self.transforms(**result)
        
        result["filename"] = image_path.name

        return result

-------

### Augmentations

[![Albumentation logo](https://albumentations.readthedocs.io/en/latest/_static/logo.png)](https://github.com/albu/albumentations)

The [albumentations](https://github.com/albu/albumentations) library transforms images and masks at the same time, which is exactly what we need.

In [None]:
from itertools import chain

import albumentations as albu
from albumentations.pytorch import ToTensor


def hard_transforms():
    result = [
        albu.Cutout(),
        albu.RandomBrightnessContrast(
            brightness_limit=0.2, contrast_limit=0.2, p=0.3
        ),
        albu.GridDistortion(p=0.3),
        albu.HueSaturationValue(p=0.3),
    ]
    return result


def resize_transforms(pre_size=448, image_size=448, train=True):
    result = [
        albu.RandomScale(scale_limit=(0.8, 1.25), p=1 if train else 0),
        albu.RandomCrop(pre_size, pre_size, p=1 if train else 0),
        albu.CenterCrop(pre_size, pre_size, p=0 if train else 1),
        albu.Resize(image_size, image_size, p=1),
    ]
    return result
  
    
def post_transforms():
    # we use ImageNet image normalization
    # and convert it to torch.Tensor
    return [albu.Normalize(), ToTensor()]


def compose(*transforms_to_compose, p=1):
    # combine all augmentations into one single pipeline
    transforms_to_compose = chain.from_iterable(transforms_to_compose)
    result = albu.Compose([*transforms_to_compose], p=p)
    return result

In [None]:
show_transforms = compose(
    hard_transforms(),
    resize_transforms(train=False),
)

Let's look at the augmented results. <br/>
You can rerun the cell below to see more examples of augmentations.

In [None]:
show_random(ALL_IMAGES, ALL_MASKS, transforms=show_transforms)

-------

## Experiment
### Model

Catalyst has [several segmentation models](https://github.com/catalyst-team/catalyst/blob/master/catalyst/contrib/models/segmentation/__init__.py#L16) (Unet, Linknet, FPN, PSPnet and their versions with pretrain from Resnet).

> You can read more about them in [our blog post](https://github.com/catalyst-team/catalyst-info#catalyst-info-1-segmentation-models).

For now, though, let's take the model from [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) (SMP for short). That repository implements the same segmentation architectures but offers many more pre-trained encoders.

[![Segmentation Models logo](https://raw.githubusercontent.com/qubvel/segmentation_models.pytorch/master/pics/logo-small-w300.png)](https://github.com/qubvel/segmentation_models.pytorch)

In [None]:
import segmentation_models_pytorch as smp

# We will use Feature Pyramid Network with pre-trained ResNeXt50 backbone
model = smp.FPN(encoder_name="resnext50_32x4d", classes=1)
init_state_dict = model.state_dict()

In [None]:
from catalyst.dl import SupervisedRunner

logdir = "./logs/segmentation"

device = utils.get_device()
print(f"device: {device}")

# by default SupervisedRunner uses "features" and "targets",
# in our case we get "image" and "mask" keys in dataset __getitem__
runner = SupervisedRunner(device=device, input_key="image", input_target_key="mask")

## Tiled Inference
This tutorial is all about inference, so we provide a checkpoint of a trained model that you can use to look at the inference results.

The main purpose of the method is to allow inference on high-resolution images that do not fit into GPU memory.

The main idea is to split the image into overlapping square patches, run inference on each patch, and then compute a weighted sum of the predictions to form the resulting mask.
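
The idea can be sketched in plain NumPy. This is only an illustration under stated assumptions, not the code behind `TiledImageDataset` or `TiledInferenceCallback`: the helper names `tile_starts` and `tiled_predict`, the pyramid-shaped weights, and the normalization by the accumulated weights (a weighted average rather than a raw sum) are choices made for this sketch.

```python
import numpy as np


def tile_starts(length: int, tile: int, step: int) -> list:
    """Start offsets along one axis; the last tile is shifted back to end at the border."""
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile, step))
    starts.append(length - tile)
    return starts


def tiled_predict(image, predict_fn, tile=448, step=224):
    """Run predict_fn on overlapping square tiles and blend the outputs (hypothetical helper)."""
    height, width = image.shape[:2]
    assert height >= tile and width >= tile, "sketch assumes the image holds at least one tile"

    # Assumed weighting: 1 at the tile border, growing towards the center,
    # so predictions near tile edges contribute less to the blend.
    ramp = np.minimum(np.arange(1, tile + 1), np.arange(tile, 0, -1)).astype(np.float32)
    weights = np.minimum.outer(ramp, ramp)

    prob_sum = np.zeros((height, width), dtype=np.float32)
    weight_sum = np.zeros((height, width), dtype=np.float32)

    for y in tile_starts(height, tile, step):
        for x in tile_starts(width, tile, step):
            patch = image[y:y + tile, x:x + tile]
            prob_sum[y:y + tile, x:x + tile] += predict_fn(patch) * weights
            weight_sum[y:y + tile, x:x + tile] += weights

    # Every pixel is covered by at least one tile, so weight_sum is strictly positive.
    return prob_sum / weight_sum


# Toy check: a "model" that returns the first channel, so the blend must reproduce it.
rng = np.random.default_rng(0)
image = rng.random((100, 130, 3)).astype(np.float32)
merged = tiled_predict(image, predict_fn=lambda patch: patch[..., 0], tile=32, step=16)
print(merged.shape)  # (100, 130)
```

Only one tile has to pass through the model at a time, which is why peak memory depends on the tile size rather than on the full image resolution.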

If the cell below fails, just run it again.

In [None]:
%%bash

CHECKPOINT=tiled_inference_best.pth
if [ ! -f "${CHECKPOINT}" ]; then
    download-gdrive 16qfA51v8ver0wUjyT7HSEZoA69U6uSbs "${CHECKPOINT}" &> /dev/null
fi

For Full HD images, we will use relatively big patches of size 448, stepping from one to another by 224 pixels.
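
As a rough check of what these numbers mean, the sketch below counts the tiles needed to cover a 1920x1080 image. It assumes tiles are placed every `step` pixels and one extra tile covers the remainder at the border. How `TiledImageDataset` actually handles borders is not described here, so treat the result as an estimate; `tiles_per_axis` is a hypothetical helper.

```python
import math


def tiles_per_axis(length: int, tile: int = 448, step: int = 224) -> int:
    """Tiles needed to cover `length` pixels, assuming one extra tile handles the border."""
    if length <= tile:
        return 1
    return math.ceil((length - tile) / step) + 1


width, height = 1920, 1080
cols = tiles_per_axis(width)
rows = tiles_per_axis(height)
print(cols, rows, cols * rows)  # 8 4 32
```

With a step equal to half the tile size, most interior pixels are covered by two tiles along each axis, which gives the blending step several overlapping predictions to combine.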

In [None]:
import collections

from skimage.io import imread
from torch.utils.data import DataLoader

from catalyst.contrib.data.cv.datasets import TiledImageDataset
from catalyst.contrib.dl.callbacks.tiled_inference import TiledInferenceCallback

In [None]:
NUM_IMAGES = 5
infer_transforms = compose(post_transforms())
infer_images = TEST_IMAGES[:NUM_IMAGES]

infer_dataset = TiledImageDataset(
    images=infer_images,
    train=False,
    tile_size=448,
    tile_step=224,
    input_key="image",
    transform=infer_transforms,
)

batch_size: int = 32

infer_loader = DataLoader(
    infer_dataset,
    batch_size=batch_size,
    shuffle=False,
)

runner.infer(
    model=model,
    loaders=collections.OrderedDict(infer=infer_loader),
    # we use relatively high threshold of 0.9 instead of default 0.5 to avoid inference artifacts
    callbacks=[TiledInferenceCallback(save_dir="predictions", threshold=0.9)],
    # use default option to load pre-trained checkpoint
    resume="tiled_inference_best.pth",
    # or train model by yourself and use option below
    # resume=f"{logdir}/checkpoints/best.pth",
    verbose=True,
)

Now we are ready to look at the segmentation results.

In [None]:
masks = np.load("predictions/masks.npy")

for i, image_path in enumerate(infer_images):
    image = imread(image_path)
    mask = masks[i]
    show_examples(name="", image=image, mask=mask)

We can also look at the probability heatmaps.
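
To see how a probability heatmap relates to a binary mask, here is a minimal NumPy sketch on a toy array. It assumes the mask is obtained as `prob > threshold`; whether the callback uses a strict or non-strict comparison is not stated in this tutorial.

```python
import numpy as np

# Toy probability map standing in for one predicted heatmap.
prob = np.array([
    [0.05, 0.55, 0.95],
    [0.60, 0.92, 0.99],
])

mask_default = (prob > 0.5).astype(np.uint8)  # common default threshold
mask_strict = (prob > 0.9).astype(np.uint8)   # threshold used in this tutorial

print(mask_default.sum(), mask_strict.sum())  # 5 3
```

A higher threshold keeps only the pixels the model is most confident about, which is why it suppresses uncertain regions such as blending artifacts at tile borders.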

In [None]:
probs = np.load("predictions/probs.npy")

for i, image_path in enumerate(infer_images):
    image = imread(image_path)
    prob = probs[i][0]
    show_examples(name="", image=image, mask=prob, mask_cmap_gray=False)

### Conclusion
The inference results are quite good, but you may think they are not perfect, so why use this method instead of downscale-inference-upscale? The answer is that you can tune your model further and obtain a scalable solution that works for 4K or even 8K images.

## Optional
You can train the model yourself and reproduce the results.

Note: this option requires a huge amount of computational resources. The reference solution was trained on a TITAN V (32 GB VRAM). You can also try a multi-GPU machine, but the result may be worse, because the DataParallel module averages gradients across GPUs and the batch on each of them will hold only around 8-12 samples, which seems to be too small.

### Loaders

In [None]:
from sklearn.model_selection import train_test_split


def get_loaders(
    images: List[Path],
    masks: List[Path],
    random_state: int,
    valid_size: float = 0.2,
    batch_size: int = 32,
    num_workers: int = 4,
    train_transforms_fn = None,
    valid_transforms_fn = None,
) -> dict:

    indices = np.arange(len(images))

    # Let's divide the data set into train and valid parts.
    train_indices, valid_indices = train_test_split(
        indices, test_size=valid_size, random_state=random_state, shuffle=True
    )

    np_images = np.array(images)
    np_masks = np.array(masks)

    # Creates our train dataset
    train_dataset = SegmentationDataset(
        images=np_images[train_indices].tolist(),
        masks=np_masks[train_indices].tolist(),
        transforms=train_transforms_fn
    )

    # Creates our valid dataset
    valid_dataset = SegmentationDataset(
        images=np_images[valid_indices].tolist(),
        masks=np_masks[valid_indices].tolist(),
        transforms=valid_transforms_fn
    )

    # Catalyst uses normal torch.utils.data.DataLoader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=True,
    )
    
    loaders = collections.OrderedDict(train=train_loader, valid=valid_loader)
    
    return loaders

### Training preparation
We will train our model on square patches with crop size 448.
This configuration of parameters was tested on a TITAN V.

In [None]:
num_epochs: int = 6
batch_size: int = 48

train_transforms = compose(
    resize_transforms(),
    hard_transforms(),
    post_transforms(),
)
valid_transforms = compose(
    resize_transforms(),
    post_transforms(),
)    

loaders = get_loaders(
    images=ALL_IMAGES,
    masks=ALL_MASKS,
    random_state=SEED,
    train_transforms_fn=train_transforms,
    valid_transforms_fn=valid_transforms,
    batch_size=batch_size
)

In [None]:
from torch import optim

from catalyst.contrib.nn import RAdam, Lookahead

learning_rate = 0.001
encoder_learning_rate = 0.0005

# Since we use a pre-trained encoder, we will reduce the learning rate on it.
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}

# This function removes weight_decay for biases and applies our layerwise_params
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)

# Catalyst has new SOTA optimizers out of the box
base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = Lookahead(base_optimizer)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)

### Loss
We will optimize the loss as a weighted sum of the IoU, Dice and BCE losses, specifically $IoU + Dice + 0.8 \cdot BCE$.
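
To make the formula concrete, here is a minimal NumPy sketch of such a combined loss. It assumes "soft" Dice and IoU losses of the form `1 - score`, computed on sigmoid probabilities with a small `eps` for stability. The exact formulations inside Catalyst's `DiceLoss` and `IoULoss` may differ, and all function names below are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def dice_loss(probs, target, eps=1e-7):
    # 1 - 2|A ∩ B| / (|A| + |B|), with soft (probabilistic) set sizes
    intersection = (probs * target).sum()
    return 1.0 - 2.0 * intersection / (probs.sum() + target.sum() + eps)


def iou_loss(probs, target, eps=1e-7):
    # 1 - |A ∩ B| / |A ∪ B|, with soft (probabilistic) set sizes
    intersection = (probs * target).sum()
    union = probs.sum() + target.sum() - intersection
    return 1.0 - intersection / (union + eps)


def bce_with_logits(logits, target):
    # Numerically stable form: max(x, 0) - x * y + log(1 + exp(-|x|))
    return np.mean(
        np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits)))
    )


def combined_loss(logits, target, bce_weight=0.8):
    probs = sigmoid(logits)
    return (
        iou_loss(probs, target)
        + dice_loss(probs, target)
        + bce_weight * bce_with_logits(logits, target)
    )


target = np.array([[1.0, 0.0], [1.0, 1.0]])
good_logits = 10.0 * (2.0 * target - 1.0)  # confident and correct
bad_logits = -good_logits                  # confident and wrong

print(combined_loss(good_logits, target) < combined_loss(bad_logits, target))  # True
```

The overlap-based terms reward a correct mask shape as a whole, while the BCE term scores each pixel on its own.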

In [None]:
from torch import nn

from catalyst.dl.callbacks import DiceCallback, IouCallback, CriterionCallback, MetricAggregationCallback
from catalyst.contrib.nn import DiceLoss, IoULoss

# we have multiple criterions
criterion = {
    "dice": DiceLoss(),
    "iou": IoULoss(),
    "bce": nn.BCEWithLogitsLoss()
}

callbacks = [
    # Each criterion is calculated separately.
    CriterionCallback(
        input_key="mask",
        prefix="loss_dice",
        criterion_key="dice"
    ),
    CriterionCallback(
        input_key="mask",
        prefix="loss_iou",
        criterion_key="iou"
    ),
    CriterionCallback(
        input_key="mask",
        prefix="loss_bce",
        criterion_key="bce"
    ),

    # And only then we aggregate everything into one loss.
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum", # can be "sum", "weighted_sum" or "mean"
        # because we want weighted sum, we need to add scale for each loss
        metrics={"loss_dice": 1.0, "loss_iou": 1.0, "loss_bce": 0.8},
    ),

    # metrics
    DiceCallback(input_key="mask"),
    IouCallback(input_key="mask"),
]

### Running train-loop

In [None]:
model.load_state_dict(init_state_dict)

runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    # our dataloaders
    loaders=loaders,
    # We can specify the callbacks list for the experiment;
    callbacks=callbacks,
    # path to save logs
    logdir=logdir,
    num_epochs=num_epochs,
    # save our best checkpoint by IoU metric
    main_metric="iou",
    # IoU needs to be maximized.
    minimize_metric=False,
    # prints train logs
    verbose=True,
)

Your checkpoint is located at `logs/segmentation/checkpoints/best.pth`, and you can also use it for inference.