In [None]:
import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from catalyst import dl, metrics, core, contrib, utils
import torch.nn as nn
from skimage.io import imread
import os
from sklearn.model_selection import train_test_split

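This is an EDA of a chest X-ray lung segmentation dataset assembled from two Kaggle sources: the Montgomery set (https://www.kaggle.com/kmader/pulmonary-chest-xray-abnormalities/home?select=Montgomery) and the SHCXR lung masks (https://www.kaggle.com/yoctoman/shcxr-lung-mask).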

In [None]:
def visualize_examples(images, masks, n_plot=6):
    """Show a random ``n_plot`` x ``n_plot`` grid of images with their masks overlaid.

    Parameters
    ----------
    images : sequence of str
        Image file paths (anything ``imread`` accepts).
    masks : sequence of str
        Mask file paths, aligned index-by-index with ``images``.
    n_plot : int, default 6
        Number of rows and of columns in the grid.

    Notes
    -----
    Pairs are sampled without replacement. If there are fewer than
    ``n_plot**2`` pairs, every pair is shown once and the remaining cells stay empty.
    """
    # squeeze=False guarantees a 2-D array of Axes even for n_plot == 1.
    fig, axes = plt.subplots(n_plot, n_plot, figsize=(4 * n_plot, 4 * n_plot), squeeze=False)
    axes = axes.ravel()
    n_shown = min(n_plot ** 2, len(images))
    chosen = np.random.choice(len(images), size=n_shown, replace=False)
    # Index with the sampled pair positions, not with the loop counter.
    for ax, pair_idx in zip(axes, chosen):
        ax.imshow(imread(images[pair_idx]), cmap='gray')
        ax.imshow(imread(masks[pair_idx]), alpha=0.5)
    for ax in axes:
        ax.axis('off')
    fig.tight_layout()

In [None]:
# Root of the prepared data; set the CHEST_XRAY_ROOT environment variable to point
# elsewhere instead of editing this absolute path.
DATA_ROOT = os.environ.get(
    'CHEST_XRAY_ROOT',
    '/data/personal_folders/skolchenko/segmentation_benchmark/ChestXray_prepared/',
)
# images_dir -> X-ray images (model inputs); masks_dir -> segmentation masks (targets).
# Trailing slashes are kept so plain string concatenation with file names still works.
images_dir = os.path.join(DATA_ROOT, 'images/')
masks_dir = os.path.join(DATA_ROOT, 'masks/')

In [None]:
# Pair every image with the mask that has the same file name. os.listdir returns
# entries in arbitrary order, so sort to make the pairing order (and therefore the
# seeded train/valid split) reproducible. Hidden entries such as
# .ipynb_checkpoints, and images without a matching mask, are skipped.
image_names = sorted(name for name in os.listdir(images_dir) if not name.startswith('.'))
mask_names = set(os.listdir(masks_dir))
paired_names = [name for name in image_names if name in mask_names]
if not paired_names:
    raise FileNotFoundError(f"No image/mask pairs found in {images_dir} and {masks_dir}")
if len(paired_names) < len(image_names):
    print(f"Skipping {len(image_names) - len(paired_names)} image(s) without a matching mask in {masks_dir}")
images = np.array([os.path.join(images_dir, name) for name in paired_names])
masks = np.array([os.path.join(masks_dir, name) for name in paired_names])

In [None]:
# Set to True to plot a random grid of image/mask pairs (reads n_plot**2 files from disk).
SHOW_EXAMPLES = False
if SHOW_EXAMPLES:
    visualize_examples(images, masks, n_plot=6)

We immediately see that the data is heterogeneous: the images come from different providers and differ in format and scale.

In [None]:
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict

class ChestXRayDataset(Dataset):
    """Paired chest X-ray image / segmentation-mask dataset.

    Each item is a dict with
    - ``'features'``: float32 tensor of shape (1, H, W), the transformed image;
    - ``'mask'``: float32 tensor of shape (1, H, W), the transformed mask.

    Parameters
    ----------
    images : sequence of str
        Image file paths.
    masks : sequence of str
        Mask file paths, aligned index-by-index with ``images``.
    transforms : callable
        Albumentations-style callable invoked as ``transforms(image=..., mask=...)``
        and returning a dict with ``'image'`` and ``'mask'`` keys.

    Raises
    ------
    ValueError
        If ``images`` and ``masks`` have different lengths.
    """

    def __init__(self, images, masks, transforms):
        if len(images) != len(masks):
            raise ValueError(
                f"images and masks must have the same length, got {len(images)} and {len(masks)}"
            )
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _to_single_channel(array):
        """Reduce an (H, W, C) array to (H, W) by keeping channel 0; 2-D input is returned as is.

        Grayscale X-rays / masks saved as RGB(A) presumably have identical colour
        channels, so the first one is kept; the model is single-channel anyway.
        """
        if array.ndim == 3:
            return array[..., 0]
        return array

    @staticmethod
    def _to_unit_range(array):
        """Convert a freshly read array to float64 scaled by its dtype's full range.

        Integer arrays are divided by their dtype's maximum (255 for uint8, so
        uint8 data is scaled exactly as before; 65535 for uint16). Boolean arrays
        map to 0.0 / 1.0. Float arrays keep the ``/ 255`` convention.
        """
        if array.dtype == np.bool_:
            return array.astype(np.float64)
        if np.issubdtype(array.dtype, np.integer):
            return array / np.iinfo(array.dtype).max
        return array / 255

    def __getitem__(self, idx):
        """Load pair ``idx``, apply ``self.transforms`` and return channel-first float32 tensors.

        NOTE(review): masks stored as 0/1 (rather than 0/255) uint8 end up as
        0/(1/255) after scaling; confirm the mask encoding of every source.
        """
        image = self._to_unit_range(self._to_single_channel(imread(self.images[idx])))
        mask = self._to_unit_range(self._to_single_channel(imread(self.masks[idx])))
        # The image gets an explicit trailing channel axis (H, W, 1); the mask stays 2-D.
        augmented = self.transforms(image=image[..., np.newaxis], mask=mask)
        image = augmented['image'].transpose(2, 0, 1)  # HWC -> CHW
        mask = augmented['mask'][np.newaxis]  # HW -> 1HW
        # Both tensors are cast to float32: a float64 target next to float32 logits
        # breaks/upcasts the BCE and Dice losses.
        return {
            'features': torch.from_numpy(image.copy()).float(),
            'mask': torch.from_numpy(mask.copy()).float(),
        }
    
def _flip_or_rotate():
    """One of: transpose, vertical flip, horizontal flip, random 90-degree rotation, or nothing."""
    return A.OneOf(
        [
            A.Transpose(),
            A.VerticalFlip(),
            A.HorizontalFlip(),
            A.RandomRotate90(),
            A.NoOp(),
        ],
        p=1.0,
    )


def _maybe_coarse_dropout():
    """Either CoarseDropout (up to 16 holes, each at most 16x16 px) or nothing."""
    return A.OneOf(
        [
            A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
            A.NoOp(),
        ],
        p=1.0,
    )


def get_valid_transforms(crop_size=256):
    """Validation pipeline: resize the whole image (and mask) to ``crop_size`` x ``crop_size``."""
    return A.Compose([A.Resize(crop_size, crop_size)], p=1.0)


def light_training_transforms(crop_size=256):
    """Random resized crop to ``crop_size``, then a random flip/transpose/90-degree rotation."""
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        _flip_or_rotate(),
    ])


def medium_training_transforms(crop_size=256):
    """The ``light`` pipeline followed by optional coarse dropout."""
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        _flip_or_rotate(),
        _maybe_coarse_dropout(),
    ])


def heavy_training_transforms(crop_size=256):
    """The ``medium`` pipeline with a ShiftScaleRotate (p=0.75) inserted before the dropout."""
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        _flip_or_rotate(),
        A.ShiftScaleRotate(p=0.75),
        _maybe_coarse_dropout(),
    ])


_TRAINING_TRANSFORMS = {
    'light': light_training_transforms,
    'medium': medium_training_transforms,
    'heavy': heavy_training_transforms,
}


def get_training_trasnforms(transforms_type, crop_size=256):
    """Build the training augmentation pipeline named ``transforms_type``.

    Parameters
    ----------
    transforms_type : str
        One of ``'light'``, ``'medium'`` or ``'heavy'``.
    crop_size : int, default 256
        Side of the square training crop, forwarded to the pipeline factory.

    Raises
    ------
    NotImplementedError
        If ``transforms_type`` is not a known configuration.
    """
    factory = _TRAINING_TRANSFORMS.get(transforms_type)
    if factory is None:
        raise NotImplementedError(
            f"Not implemented transformation configuration: {transforms_type!r} "
            f"(expected one of {sorted(_TRAINING_TRANSFORMS)})"
        )
    return factory(crop_size=crop_size)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
get_training_transforms = get_training_trasnforms

In [None]:
# Hold out 25% of the image/mask pairs for validation; the fixed random_state makes
# the split repeatable for a given ordering of `images`.
(images_train, images_valid,
 masks_train, masks_valid) = train_test_split(
    images,
    masks,
    random_state=42,
    test_size=0.25,
)

In [None]:
# Training pairs get the 'heavy' augmentation pipeline, validation pairs are only resized.
train_dataset = ChestXRayDataset(
    images=images_train,
    masks=masks_train,
    transforms=get_training_trasnforms('heavy'),
)
val_dataset = ChestXRayDataset(
    images=images_valid,
    masks=masks_valid,
    transforms=get_valid_transforms(),
)

# Only the training loader is shuffled; validation batches keep a fixed order.
loaders = dict(
    train=DataLoader(train_dataset, batch_size=2, shuffle=True),
    valid=DataLoader(val_dataset, batch_size=2, shuffle=False),
)

In [None]:
# Pull a single augmented training batch to eyeball the image/mask pipeline below.
sample = next(iter(loaders['train']))

In [None]:
# Eyeball one augmented training pair: the grayscale X-ray with its mask tinted on top.
# Background (0) mask pixels are masked out so they do not veil the image.
preview_image = sample['features'][0, 0].cpu().numpy()
preview_mask = sample['mask'][0, 0].cpu().numpy()
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(preview_image, cmap='gray')
ax.imshow(np.ma.masked_where(preview_mask == 0, preview_mask), alpha=0.5)
ax.set_title('Augmented training sample with mask overlay')
ax.axis('off');

In [None]:
from pytorch_toolbelt.losses import DiceLoss
from pytorch_toolbelt.utils.catalyst import IoUMetricsCallback

In [None]:
model = smp.UnetPlusPlus('timm-regnety_004', classes=1, in_channels=1)
# No explicit model.cuda() here; device placement is presumably left to the
# catalyst runner. Confirm when training on GPU.

learning_rate = 5e-3
# The encoder LR is derived from the base LR (10x smaller) so tuning one updates both.
encoder_learning_rate = learning_rate / 10
weight_decay = 0.0003
encoder_weight_decay = 0.00003

# Parameters whose names match "encoder*" get the encoder-specific LR / weight decay;
# catalyst's process_model_params builds the corresponding parameter groups.
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=encoder_weight_decay)}
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=weight_decay)
optimizer = contrib.nn.Lookahead(base_optimizer)
# Multiply the LR by 0.25 after `patience=10` scheduler steps without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)

# Keys must match the `criterion_key` values used by the CriterionCallbacks.
criterion = {
    "dice": DiceLoss(mode='binary'),
    "bce": nn.BCEWithLogitsLoss(),
}

In [None]:
from catalyst.dl import CriterionCallback, MetricAggregationCallback

# Each loss term is computed by its own CriterionCallback and logged under its own prefix.
callbacks = [
    CriterionCallback(input_key="mask", prefix="loss_dice", criterion_key="dice"),
    CriterionCallback(input_key="mask", prefix="loss_bce", criterion_key="bce"),
]

# The terms are then combined into a single "loss" metric as a weighted sum:
# loss = 1.0 * loss_dice + 0.8 * loss_bce
callbacks.append(
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",
        metrics={"loss_dice": 1.0, "loss_bce": 0.8},
    )
)

# Segmentation quality metric: binary IoU against the "mask" target.
callbacks.append(IoUMetricsCallback(mode='binary', input_key='mask'))

In [None]:
# Batches are dicts produced by ChestXRayDataset: the model input is read from
# "features" and the target from "mask".
runner = dl.SupervisedRunner(
    input_target_key="mask",
    input_key="features",
)

In [None]:
# Train for 100 epochs; main_metric/minimize_metric tell catalyst that a lower
# aggregated "loss" is better. Logs and checkpoints go to ../logs/initial_test_xray.
runner.train(
    model=model,
    loaders=loaders,
    criterion=criterion,
    callbacks=callbacks,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=100,
    main_metric="loss",
    minimize_metric=True,
    logdir='../logs/initial_test_xray',
    verbose=True,
)