# Body segmentation with Catalyst

The purpose of this work is to find out how to use Catalyst for a segmentation task and to build the whole pipeline for it using only Catalyst.
Catalyst is a PyTorch framework for Deep Learning Research and Development. It focuses on reproducibility, rapid experimentation, and codebase reuse so you can create something new rather than write yet another train loop.

# 1) Import libraries and set hyperparameters

In [None]:
from pathlib import Path
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms.functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2 
from catalyst.contrib.nn import DiceLoss, IoULoss
from catalyst.dl import SupervisedRunner
from catalyst import dl
from torch.nn import BCEWithLogitsLoss

import numpy as np
import os
from pathlib import Path
from skimage.io import imread
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import train_test_split

# Train on the GPU whenever PyTorch can see one; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Locations of the full-body segmentation images and their masks.
IMAGE_DIR = (
    "../input/segmentation-full-body-mads-dataset/"
    "segmentation_full_body_mads_dataset_1192_img/images"
)
MASK_DIR = (
    "../input/segmentation-full-body-mads-dataset/"
    "segmentation_full_body_mads_dataset_1192_img/masks"
)

# Working resolution used by the Resize transforms, and the loader batch size.
IMAGE_HEIGHT, IMAGE_WIDTH = 200, 200
BATCH_SIZE = 4

# 2) Prepare data

In [None]:
# Collect the image and mask files. `Path.glob` yields entries in arbitrary
# filesystem order, so both lists are sorted: the dataset pairs images and
# masks by position, and sorting makes the i-th image line up with the i-th
# mask (assumes each image and its mask have matching file names — TODO
# confirm for this dataset). It also makes the seeded split below reproducible.
path_img = Path(IMAGE_DIR)
img_list = sorted(path_img.glob('*.png'))
path_mask = Path(MASK_DIR)
mask_list = sorted(path_mask.glob('*.png'))

# Fail here with a clear message rather than later inside train_test_split
# or with silently mis-paired samples.
if not img_list:
    raise FileNotFoundError(f"No .png images found in {IMAGE_DIR}")
if len(img_list) != len(mask_list):
    raise ValueError(
        f"Found {len(img_list)} images but {len(mask_list)} masks; "
        "image and mask counts must match."
    )

In [None]:
class FullBodySegmentation(data.Dataset):
    """Dataset of (RGB image, body mask) pairs for binary segmentation.

    Args:
        inputs: paths of the input images; ``inputs[i]`` is paired with
            ``targets[i]``.
        targets: paths of the mask images, in the same order as ``inputs``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` entries.

    Raises:
        ValueError: if ``inputs`` and ``targets`` differ in length.
    """

    def __init__(self, inputs: list, targets: list, transform=None):
        super().__init__()
        # Samples are paired by position, so a length mismatch means the
        # pairing is broken; report it now instead of at an arbitrary index.
        if len(inputs) != len(targets):
            raise ValueError(
                "inputs and targets must have the same length, "
                f"got {len(inputs)} and {len(targets)}"
            )
        self.inputs = inputs
        self.targets = targets
        self.transform = transform

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int):
        """Load pair ``idx`` and return ``(image, mask)``.

        The image is read as RGB; the mask is read as single-channel
        grayscale, converted to float32 and divided by 255 (so 255 -> 1.0).
        When a transform is set, image and mask are passed through it
        together and its ``"image"``/``"mask"`` outputs are returned.
        """
        input_image = self.inputs[idx]
        target_image = self.targets[idx]

        image = np.array(Image.open(input_image).convert("RGB"))
        mask = np.array(Image.open(target_image).convert("L"), dtype=np.float32)
        mask = mask / 255

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask


In [None]:
# Hold out 10% of all image/mask pairs as the test set, then carve 10% of the
# remainder off as the validation set. The fixed seed keeps the split
# reproducible for a given input order.
x_data, x_test, y_data, y_test = train_test_split(
    img_list, mask_list, shuffle=True, random_state=42, test_size=0.1
)
x_train, x_val, y_train, y_val = train_test_split(
    x_data, y_data, shuffle=True, random_state=42, test_size=0.1
)

In [None]:
def _make_loader(inputs, targets, transform, *, batch_size, shuffle,
                 num_workers, pin_memory):
    """Wrap paired path lists in a FullBodySegmentation and a DataLoader."""
    dataset = FullBodySegmentation(
        inputs=inputs,
        targets=targets,
        transform=transform,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
    )


def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    test_dir,
    test_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the train, validation and test DataLoaders.

    Despite the ``*_dir`` names, each of the first six arguments is a list of
    file paths (images or masks) that is handed straight to
    ``FullBodySegmentation``; image and mask lists are paired by position.

    Args:
        train_dir, train_maskdir: training image / mask paths.
        val_dir, val_maskdir: validation image / mask paths.
        test_dir, test_maskdir: test image / mask paths.
        batch_size: batch size for all three loaders.
        train_transform: transform for the training set.
        val_transform: transform for both the validation and the test set.
        num_workers: DataLoader worker processes per loader.
        pin_memory: forwarded to every DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles.
    """
    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = _make_loader(
        train_dir, train_maskdir, train_transform, shuffle=True, **common
    )
    val_loader = _make_loader(
        val_dir, val_maskdir, val_transform, shuffle=False, **common
    )
    # The test set reuses the deterministic validation transform.
    test_loader = _make_loader(
        test_dir, test_maskdir, val_transform, shuffle=False, **common
    )
    return train_loader, val_loader, test_loader

In [None]:
# Training pipeline: resize to the working resolution, then apply random
# geometric augmentation (albumentations applies the same geometry to the mask)
# before normalizing and converting to torch tensors.
train_transform = A.Compose(transforms=[
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=35, p=1.0),   # always rotate by a random angle within +/-35 degrees
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Normalize(),
    ToTensorV2(),
])

# Evaluation pipeline: the same resize/normalize/tensor steps, no randomness.
val_transform = A.Compose(transforms=[
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Normalize(),
    ToTensorV2(),
])

In [None]:
 train_loader, val_loader, test_loader = get_loaders(
                                            x_train,
                                            y_train,
                                            x_val,
                                            y_val,
                                            x_test,
                                            y_test,   
                                            BATCH_SIZE,
                                            train_transform,
                                            val_transform,
                                            num_workers=2,
                                            pin_memory=True,
                                        )
loaders = {"train": train_loader, "valid": val_loader}

# 3) Let's look at the data and check it

In [None]:
# Report the number of samples per split. len(loader) counts batches, so
# multiplying it by BATCH_SIZE overstates the size whenever the last batch is
# partial; the dataset length is exact.
print(f'Length of the train data: {len(train_loader.dataset)} images')
print(f'Length of the validation_data: {len(val_loader.dataset)} images')
print(f'Length of the test data: {len(test_loader.dataset)} images')

In [None]:
# Pull the first batch from the train loader, then from the val loader, to
# inspect the tensor shapes they produce.
(train_image, train_mask), (val_image, val_mask) = (
    next(iter(loader)) for loader in (train_loader, val_loader)
)
print(f'Shape of input images:\n train -  {train_image.shape},\n val -  {val_image.shape}')

In [None]:
# Masks were scaled by 1/255 in the dataset, so their values should lie in
# [0, 1]; print shapes and extremes of the sampled batches to confirm.
print(
    "Check that the mask images in [0,1] range",
    f'Shape of input masks:\n train -  {train_mask.shape},\n val -  {val_mask.shape}',
    f'Train mask values: \n max - {train_mask.max()} \n min - {train_mask.min()}',
    f'Validate mask values: \n max - {val_mask.max()} \n min - {val_mask.min()}',
    sep="\n",
)

In [None]:
# Show one validation batch: input images on the top row, masks below.
# Only one batch is plotted, so take the first one directly instead of
# iterating (and decoding) the whole validation set to redraw the same axes.
images, mask_target = next(iter(val_loader))
n_cols = len(images)
# squeeze=False keeps `axs` 2-D even when the batch holds a single image.
fig, axs = plt.subplots(2, n_cols, figsize=(15, 15), squeeze=False)

# The images went through A.Normalize() with its default arguments, i.e.
# (pixel / 255 - mean) / std with these per-channel statistics. Undo it so
# imshow receives RGB values in [0, 1] instead of clipping them.
_norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for col, (image, target) in enumerate(zip(images, mask_target)):
    shown = (image.cpu() * _norm_std + _norm_mean).clamp(0, 1)
    axs[0, col].imshow(shown.permute(1, 2, 0).numpy())
    axs[1, col].imshow(target.cpu().numpy(), cmap='gray')

# 4) Make a training loop
Let's take the U-Net model from the Catalyst library.

In [None]:
from catalyst.contrib.models.cv.segmentation.unet import Unet

# U-Net from Catalyst's contrib models, built with the library's defaults
# (no arguments are passed).
# NOTE(review): CustomRunner.handle_batch squeezes dim 1 of the model output,
# so this presumably relies on the default configuration producing a single
# output channel — confirm against the installed Catalyst version.
model = Unet()

In [None]:
class CustomRunner(dl.Runner):
    """Catalyst runner for single-channel binary segmentation.

    ``handle_batch`` publishes the model outputs in ``self.batch`` under the
    keys used by the callbacks: ``"logits"`` (raw scores, channel dimension
    squeezed out), ``"binar"`` (their sigmoid) and ``"targets"`` (the masks).
    """

    def predict_batch(self, batch, **kwargs):
        """Run the model on ``batch[0]`` (the images) and return raw outputs.

        ``**kwargs`` is accepted so that Catalyst can forward extra keyword
        arguments; they are unused here.
        """
        return self.model(batch[0].to(self.device))

    def handle_batch(self, batch):
        """Forward one ``(images, masks)`` batch and store results for callbacks."""
        x, y = batch
        # Assumes the model returns (B, 1, H, W); squeezing dim 1 gives
        # (B, H, W), the shape of the masks — TODO confirm model output shape.
        logits = self.model(x).squeeze(1)
        self.batch = {
            "features": x,
            "logits": logits,                 # consumed by BCEWithLogitsLoss
            "targets": y,
            "binar": torch.sigmoid(logits),   # probabilities for Dice / IoU
        }

In [None]:
# Three losses; each is logged separately by its CriterionCallback below and
# then combined into a single "loss" by MetricAggregationCallback.
criterion = {
    "dice": DiceLoss(),
    "iou": IoULoss(),
    "bce": BCEWithLogitsLoss()
}
# Adam with PyTorch's default learning rate (none is passed).
optimizer = torch.optim.Adam(model.parameters())
# Lower the learning rate when the monitored value stops decreasing
# (patience=5 steps).
# NOTE(review): ReduceLROnPlateau needs a metric at step time; this relies on
# Catalyst's scheduler callback feeding it valid_loader/valid_metric — confirm.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)
# Train for 10 epochs. Catalyst writes its logs/checkpoints under ./logdir and
# tracks the epoch with the lowest "loss" on the "valid" loader.
runner = CustomRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir="./logdir",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    num_epochs=10,
    callbacks=[
        # Dice and IoU are computed on the sigmoid probabilities ("binar").
        # NOTE(review): these are (B, H, W) tensors after the squeeze in
        # handle_batch; if DiceLoss/IoULoss default to class_dim=1 they would
        # treat H as the class axis — verify against the Catalyst docs.
        dl.CriterionCallback(
            input_key="binar",
            target_key="targets",
            metric_key="loss_dice",
            criterion_key="dice",
        ),
        dl.CriterionCallback(
            input_key="binar",
            target_key="targets",
            metric_key="loss_iou",
            criterion_key="iou",
        ),
        # BCEWithLogitsLoss applies the sigmoid itself, so it gets raw logits.
        dl.CriterionCallback(
            input_key="logits",
            target_key="targets",
            metric_key="loss_bce",
            criterion_key="bce",
        ),
        # loss aggregation: loss = 1.0 * dice + 1.0 * iou + 0.8 * bce
        dl.MetricAggregationCallback(
            metric_key="loss",
            metrics={"loss_dice": 1.0, "loss_iou": 1.0, "loss_bce": 0.8},
            mode="weighted_sum",
        ),
        # Back-propagate the aggregated "loss" and step the optimizer.
        dl.OptimizerCallback(metric_key="loss"),
    ],
)

In [None]:
# Visual check on one test batch: predicted probabilities (top row), input
# images (middle row) and ground-truth masks (bottom row).
model.eval()
images, mask = next(iter(test_loader))
# Inference only: no_grad avoids building an autograd graph, and only the
# first batch is run because only one batch is plotted.
with torch.no_grad():
    batch_preds = torch.sigmoid(model(images.to(device))).cpu()

n_cols = len(images)
# squeeze=False keeps `axs` 2-D even when the batch holds a single image.
fig, axs = plt.subplots(3, n_cols, figsize=(15, 15), squeeze=False)

# The images went through A.Normalize() with its default arguments, i.e.
# (pixel / 255 - mean) / std with these per-channel statistics. Undo it so
# imshow receives RGB values in [0, 1] instead of clipping them.
_norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for col, (pred, image, target) in enumerate(zip(batch_preds, images, mask)):
    axs[0, col].imshow(pred.squeeze(0).numpy(), cmap='gray')
    shown = (image.cpu() * _norm_std + _norm_mean).clamp(0, 1)
    axs[1, col].imshow(shown.permute(1, 2, 0).numpy())
    axs[2, col].imshow(target.cpu().numpy(), cmap='gray')

In [None]:
# Mean BCE-with-logits loss of the trained model over the validation set.
# eval() must actually be called, so that layers such as batch norm or
# dropout (if the model has any) run in inference mode.
model.eval()
with torch.no_grad():
    loss = 0
    criterion_bce = torch.nn.BCEWithLogitsLoss()
    for batch in val_loader:
        # Use the configured device instead of hard-coding CUDA.
        image, mask = batch[0].to(device), batch[1].to(device)
        result = model(image)
        result = result.squeeze(1)
        loss_bce = criterion_bce(result, mask)
        # The criterion averages within the batch; weight by batch size so a
        # short final batch counts proportionally.
        loss += loss_bce.item() * image.size(0)
    # Divide the sample-weighted sum by the number of samples, not batches.
    epoch_loss = loss / len(val_loader.dataset)
    print("--------------------")
    print(epoch_loss)

In [None]:
# Mean BCE-with-logits loss of the trained model over the test set.
# eval() must actually be called, so that layers such as batch norm or
# dropout (if the model has any) run in inference mode.
model.eval()
with torch.no_grad():
    loss = 0
    criterion_bce = torch.nn.BCEWithLogitsLoss()
    for batch in test_loader:
        # Use the configured device instead of hard-coding CUDA.
        image, mask = batch[0].to(device), batch[1].to(device)
        result = model(image)
        result = result.squeeze(1)
        loss_bce = criterion_bce(result, mask)
        # The criterion averages within the batch; weight by batch size so a
        # short final batch counts proportionally.
        loss += loss_bce.item() * image.size(0)
    # Normalize by the number of *test* samples (not validation batches).
    epoch_loss = loss / len(test_loader.dataset)
    print("--------------------")
    print(epoch_loss)