In [None]:
# In Google Colab (or any fresh environment), install catalyst first.
# For the master version of catalyst, uncomment:
# (the master version should be fully compatible with this notebook)
# ! pip install git+https://github.com/catalyst-team/catalyst.git

# For the latest release version of catalyst, uncomment:
# ! pip install catalyst

# For a specific commit of catalyst, set `commit_hash` to the desired hash
# (IPython substitutes {commit_hash} from the Python namespace), then uncomment:
# ! pip install git+https://github.com/catalyst-team/catalyst.git@{commit_hash}

# Segmentation

If you have Unet, all CV is segmentation now.

## Goals

- Train a U-Net (`Unet`) on the ISBI dataset
- Visualize its predictions next to the ground-truth masks

# Preparation

In [None]:
# Get the data:
! wget -P ./data/ https://www.dropbox.com/s/0rvuae4mj6jn922/isbi.tar.gz
! tar -xf ./data/isbi.tar.gz -C ./data/ 

Expected folder structure after extraction (training data only):
```bash
catalyst-examples/
    data/
        isbi/
            train-volume.tif
            train-labels.tif
```

In [None]:
import os

# Make only GPU 0 visible to this kernel. The variable takes effect only if it is
# set before CUDA is initialised, which is why this cell runs before torch is imported.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

# Data

In [None]:
# Uncomment if tifffile (used below to read the .tif stacks) is not installed:
# ! pip install tifffile

In [None]:
import tifffile as tiff

NUM_VALID = 4  # number of trailing (image, mask) pairs held out for validation

# Presumably each file is a multi-page stack, one page per slice — verify shapes if the data changes.
images = tiff.imread('./data/isbi/train-volume.tif')
masks = tiff.imread('./data/isbi/train-labels.tif')

# zip() silently truncates to the shorter input, so make sure the stacks line up.
assert len(images) == len(masks), (
    f"got {len(images)} images but {len(masks)} masks")

# Pair every image slice with its mask.
data = list(zip(images, masks))

# Use an explicit split index: data[:-0] would be empty, so the negative-slice form is fragile.
assert 0 < NUM_VALID < len(data), (
    f"NUM_VALID={NUM_VALID} must be between 1 and {len(data) - 1}")
split = len(data) - NUM_VALID
train_data = data[:split]
valid_data = data[split:]

In [None]:
import collections
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from catalyst.data import Augmentor
from catalyst.dl import utils

bs = 4
num_workers = 4


def to_float_tensor(array):
    """Scale a numpy array by 1/255 to float32 and prepend a channel axis.

    Presumably the input is a 2-D slice with 0..255 values, giving a (1, H, W) tensor.
    """
    scaled = array.copy().astype(np.float32) / 255.
    return torch.from_numpy(scaled).unsqueeze_(0)


# Images: scale to [0, 1] then normalise with mean 0.5 / std 0.5. Masks: scale only.
data_transform = transforms.Compose([
    Augmentor(dict_key="features", augment_fn=to_float_tensor),
    Augmentor(
        dict_key="features",
        augment_fn=transforms.Normalize((0.5, ), (0.5, ))),
    Augmentor(dict_key="targets", augment_fn=to_float_tensor),
])


def open_fn(sample):
    """Turn an (image, mask) pair into the dict consumed by the transforms above."""
    return {"features": sample[0], "targets": sample[1]}


# Build both loaders in one pass; only the training split is shuffled.
loaders = collections.OrderedDict(
    (
        split_name,
        utils.get_loader(
            split_data,
            open_fn=open_fn,
            dict_transform=data_transform,
            batch_size=bs,
            num_workers=num_workers,
            shuffle=split_name == "train"),
    )
    for split_name, split_data in (("train", train_data), ("valid", valid_data))
)
train_loader = loaders["train"]
valid_loader = loaders["valid"]

# Model

In [None]:
from catalyst.contrib.models.segmentation import Unet

# Train

In [None]:
import torch
import torch.nn as nn
from catalyst.dl.runner import SupervisedRunner

# experiment setup
num_epochs = 50
logdir = "./logs/segmentation_notebook"
SEED = 42  # fixed seed so the model's weight initialisation is reproducible

torch.manual_seed(SEED)

# model, criterion, optimizer, lr schedule
# single-channel input images, single-class (binary) output
model = Unet(num_classes=1, in_channels=1, num_channels=64, num_blocks=4)
# BCEWithLogitsLoss expects raw logits and applies the sigmoid internally
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Multiply the lr by 0.3 at each milestone. Milestones count scheduler steps;
# presumably catalyst steps the scheduler once per epoch — confirm for your version.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[10, 20, 40], gamma=0.3)

# model runner
runner = SupervisedRunner()

# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    # The scheduler must be handed to the runner; otherwise it is never stepped
    # and the learning rate stays at 1e-3 for all epochs.
    scheduler=scheduler,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True
)

# Inference

In [None]:
import collections

from catalyst.dl.callbacks import InferCallback, CheckpointCallback

# Use a separate name instead of rebinding `loaders`. Overwriting it dropped the
# "valid" key, so re-running this cell (or the training cell) failed.
infer_loaders = collections.OrderedDict([("infer", loaders["valid"])])
runner.infer(
    model=model,
    loaders=infer_loaders,
    callbacks=[
        # load the weights stored in best.pth under logdir before predicting
        CheckpointCallback(
            resume=f"{logdir}/checkpoints/best.pth"),
        # stores the predictions; the visualization below reads
        # runner.callbacks[1].predictions["logits"]
        InferCallback()
    ],
)

# Predictions visualization

In [None]:
import matplotlib.pyplot as plt
# use the "ggplot" style sheet for every figure drawn below
plt.style.use("ggplot")
# render figures inline in the notebook output
# NOTE(review): on some IPython/matplotlib-inline versions, enabling the inline backend
# applies its own rc settings, which may override parts of the style set above — confirm.
%matplotlib inline

In [None]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), evaluated as 0.5 * (1 + tanh(x / 2)).

    The tanh form is mathematically identical but cannot overflow. By contrast,
    np.exp(-x) overflows for large negative logits and emits RuntimeWarnings.
    """
    return 0.5 * (1.0 + np.tanh(0.5 * x))


threshold = 0.5  # probability cut-off that turns sigmoid outputs into a binary mask

predicted_logits = runner.callbacks[1].predictions["logits"]

# Valid loader is not shuffled, so predictions line up with valid_data order.
for (image, mask), logits in zip(valid_data, predicted_logits):
    # Assumes each prediction is (1, H, W): one channel, since the model has num_classes=1 — TODO confirm.
    probabilities = sigmoid(logits[0])
    predicted_mask = (probabilities > threshold).astype(np.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(10, 8))
    panels = (
        (image, "image"),
        (predicted_mask, "prediction"),
        (mask, "ground truth"),
    )
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture, cmap="gray")
        ax.set_title(title)

    plt.show()
    plt.close(fig)  # release the figure so many iterations don't accumulate memory