In [1]:
import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from catalyst import dl, metrics, core, contrib, utils
import torch.nn as nn

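This notebook is an EDA of the PanNuke dataset (https://jgamper.github.io/PanNukeDataset/). The data is split into 3 folds, each stored as a separate .npy file; it covers 19 tissue types and 5 nuclei classes, with ~200,000 labeled nuclei in total. Masks are stored in individual channels, one per class (OHE-style: neoplastic, non-neoplastic epithelial, inflammatory, connective, dead, plus background). Note that, according to the dataset README, the channel order on disk is neoplastic, inflammatory, connective, dead, non-neoplastic epithelial, background. The dataset also contains an instance segmentation of each nucleus, which we are not going to use.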

In [None]:
# Load fold 1 (images, per-class masks, tissue-type labels) for the EDA below.
images, masks, types = (
    np.load(npy_path)
    for npy_path in (
        '../data/PanNuke/images/fold1_images.npy',
        '../data/PanNuke/masks/fold1_masks.npy',
        '../data/PanNuke/types/fold1_types.npy',
    )
)

In [None]:
def visualize_examples(images, masks, types, n_plot = 6):
    """Plot a random ``n_plot`` x ``n_plot`` grid of images with their class masks overlaid.

    Args:
        images: array indexed by sample on axis 0; each item is drawn with
            ``imshow`` after casting to int (assumes HxWx3 values in 0..255 —
            TODO confirm against the data).
        masks: array indexed like ``images``; each item has the class
            channels on its last axis (axis 2) and is collapsed to a label map
            with ``argmax``.
        types: per-sample labels, used as subplot titles.
        n_plot: number of rows and columns of the grid.

    Returns:
        None. The figure is drawn on the current matplotlib backend.
    """
    n_cells = n_plot ** 2
    n_images = images.shape[0]
    # squeeze=False keeps `axes` a 2-D array even when n_plot == 1; without it
    # plt.subplots returns a bare Axes object that has no .flatten().
    fig, axes = plt.subplots(n_plot, n_plot, figsize=(4*n_plot, 4*n_plot), squeeze=False)
    axes = axes.flatten()
    # Draw distinct samples whenever there are enough of them, so the grid does
    # not show the same image twice; fall back to replacement otherwise.
    idx_choice = np.random.choice(n_images, size=n_cells, replace=n_images < n_cells)
    for cell_ax, idx_plot in zip(axes, idx_choice):
        cell_ax.imshow(images[idx_plot].astype(int))
        # vmin/vmax pin the colormap to the 6 label values (0..5) so colours
        # stay comparable across subplots.
        cell_ax.imshow(np.argmax(masks[idx_plot].astype(int), axis=2), alpha=0.5, cmap='Accent', vmin=0, vmax=5)
        cell_ax.set_title(types[idx_plot])
        cell_ax.axis('off')
    plt.tight_layout()

In [None]:
# Show a random grid of fold-1 samples with their masks overlaid.
visualize_examples(images=images, masks=masks, types=types)

The simplest case: we can train a model to segment nuclei into one of the 5 classes (+ background), i.e. a plain example of multiclass semantic segmentation. We can test different models (for example, Unet, Unet++, FPN, Linknet, PSPNet, PAN) with different encoders (different variations of ResNet, EfficientNet, RegNet, ResNeSt, etc.) and different augmentation techniques (D4, scaling, crops, cutouts, cutmix, etc.). Testing different learning rates and optimizers on top of that would be too much, so we may just use RAdam + Lookahead, since they perform OK and don't require rigorous LR + scheduler tuning.

We will be using Catalyst, since it works OK, is easy to use, and supports logging results, saving multiple checkpoints, convenient callbacks, and so on.

We start by defining the datasets and dataloaders.

In [None]:
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict

class PanNukeDataset(Dataset):
    """In-memory dataset yielding (image, label-map) pairs for multiclass segmentation.

    Args:
        images: array indexed by sample; each item is a channels-last image.
        masks: array indexed like ``images``; each item has the class channels
            on axis 2 and is collapsed to a single label map with ``argmax``.
        types: per-sample labels; stored but not returned by ``__getitem__``.
        transforms: callable invoked as ``transforms(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries
            (e.g. an albumentations ``Compose``).
    """

    def __init__(
        self,
        images,
        masks,
        types,
            transforms):
        self.images = images
        self.masks = masks
        self.types = types
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the augmented sample at ``idx``.

        Returns:
            dict with
            ``'features'``: float tensor, channels first (C, H, W);
            ``'mask'``: int64 tensor (H, W) of class indices.
        """
        sample_image = self.images[idx]
        # Collapse the per-class channels into one label map.
        sample_mask = np.argmax(self.masks[idx].astype(int), axis=2)

        augmented = self.transforms(image=sample_image, mask=sample_mask)
        # Take BOTH outputs from the transform: geometric augmentations
        # (flips, rotations, ...) move the mask together with the image, so
        # returning the pre-augmentation mask would misalign the labels.
        image = augmented['image'].transpose(2, 0, 1)  # channels first
        mask = augmented['mask']

        # .copy() makes the arrays contiguous (flips can yield negative
        # strides, which torch.from_numpy rejects).
        return {
            'features': torch.from_numpy(image.copy()).float(),
            # CrossEntropyLoss expects integer (long) class indices.
            'mask': torch.from_numpy(mask.copy()).long(),
        }
    
def get_valid_transforms():
    """Build the validation pipeline: normalization only, no augmentation."""
    normalize_only = [A.Normalize()]
    return A.Compose(normalize_only, p=1.0)

def light_training_transforms():
    """Build the 'light' training pipeline: one flip/rotation choice, then normalization."""
    # Candidates for the geometric step; NoOp keeps "leave untouched" as an option.
    flip_rotate_candidates = [
        A.Transpose(),
        A.VerticalFlip(),
        A.HorizontalFlip(),
        A.RandomRotate90(),
        A.NoOp(),
    ]
    flip_or_rotate = A.OneOf(flip_rotate_candidates, p=1.0)
    return A.Compose([flip_or_rotate, A.Normalize()])

def medium_training_transforms():
    """Build the 'medium' training pipeline.

    Steps, in order: one flip/rotation choice, one dropout-style choice
    (grid dropout, coarse dropout or nothing), then normalization.
    """
    return A.Compose([
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.OneOf(
            [
                # `GridMask` is neither defined nor imported in this notebook,
                # so referencing it raised NameError. albumentations'
                # GridDropout is used instead, with a 6x6 grid to mirror the
                # original num_grid=6.
                A.GridDropout(holes_number_x=6, holes_number_y=6),
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
        A.Normalize()
    ])


def heavy_training_transforms():
    """Build the 'heavy' training pipeline.

    Steps, in order: one flip/rotation choice, one geometric distortion
    choice, one noise/blur choice, one colour/contrast choice, one
    dropout-style choice, then normalization. Every group contains NoOp so
    that "leave untouched" stays a possible outcome.
    """
    return A.Compose([
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.OneOf(
            [
                A.ElasticTransform(),
                A.GridDistortion(),
                A.OpticalDistortion(),
                A.NoOp(),
                A.ShiftScaleRotate(),
            ], p=1.0),
        A.OneOf(
            [
                A.GaussNoise(),
                A.GaussianBlur(),
                A.NoOp()
            ], p=1.0),
        A.OneOf(
            [
                # NOTE(review): some of these colour transforms may only
                # accept uint8 input — confirm the dtype of the images
                # passed in before enabling this pipeline.
                A.CLAHE(),
                A.RGBShift(),
                A.RandomBrightnessContrast(),
                A.RandomGamma(),
                A.HueSaturationValue(),
                A.NoOp()
            ], p=1.0),
        A.OneOf(
            [
                # `GridMask` is neither defined nor imported in this notebook,
                # so referencing it raised NameError. albumentations'
                # GridDropout is used instead, with a 6x6 grid to mirror the
                # original num_grid=6.
                A.GridDropout(holes_number_x=6, holes_number_y=6),
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
        A.Normalize()
    ])

def get_training_trasnforms(transforms_type):
    """Return the training augmentation pipeline for the given preset.

    Args:
        transforms_type: one of ``'light'``, ``'medium'`` or ``'heavy'``.

    Raises:
        NotImplementedError: for any other value.
    """
    if transforms_type == 'light':
        pipeline = light_training_transforms()
    elif transforms_type == 'medium':
        pipeline = medium_training_transforms()
    elif transforms_type == 'heavy':
        pipeline = heavy_training_transforms()
    else:
        raise NotImplementedError("Not implemented transformation configuration")
    return pipeline

In [None]:
# Pre-load everything into memory: fold 1 is used for training ...
images, masks, types = (
    np.load(npy_path)
    for npy_path in (
        '../data/PanNuke/images/fold1_images.npy',
        '../data/PanNuke/masks/fold1_masks.npy',
        '../data/PanNuke/types/fold1_types.npy',
    )
)

# ... and fold 2 for validation.
images_val, masks_val, types_val = (
    np.load(npy_path)
    for npy_path in (
        '../data/PanNuke/images/fold2_images.npy',
        '../data/PanNuke/masks/fold2_masks.npy',
        '../data/PanNuke/types/fold2_types.npy',
    )
)

In [None]:
# Training data gets the 'light' augmentation preset; validation is only normalized.
train_dataset = PanNukeDataset(
    images=images,
    masks=masks,
    types=types,
    transforms=get_training_trasnforms('light'),
)
val_dataset = PanNukeDataset(
    images=images_val,
    masks=masks_val,
    types=types_val,
    transforms=get_valid_transforms(),
)

# Loader names ('train' / 'valid') are the keys handed to the runner below.
loaders = dict(
    train=DataLoader(train_dataset, shuffle=True, batch_size=8),
    valid=DataLoader(val_dataset, shuffle=False, batch_size=8),
)

In [3]:
# The dangling comma after `DiceLoss` was a SyntaxError (a trailing comma in a
# from-import is only legal inside parentheses).
from pytorch_toolbelt.losses import DiceLoss
from pytorch_toolbelt.utils.catalyst import IoUMetricsCallback

In [None]:
# Unet++ with a ResNet-18 encoder; 6 output channels = 5 nuclei classes + background.
model = smp.UnetPlusPlus('resnet18', classes=6)
# Use the GPU when one is present, but keep the notebook runnable on CPU-only
# machines (an unconditional model.cuda() raises there).
model.to('cuda' if torch.cuda.is_available() else 'cpu')

# The encoder gets its own, lower learning rate and weight decay than the rest
# of the network.
learning_rate = 0.001
encoder_learning_rate = 0.0005
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)

# RAdam wrapped in Lookahead, as motivated in the text above.
base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = contrib.nn.Lookahead(base_optimizer)
# Multiply the LR by 0.25 after 2 epochs without improvement of the monitored metric.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)

# Two losses, referenced by key from the CriterionCallbacks and combined there.
criterion = {
    "dice": DiceLoss(mode='multiclass'),
    "ce": nn.CrossEntropyLoss()
}

In [None]:
from catalyst.dl import  CriterionCallback, MetricAggregationCallback

callbacks = [
    # Each criterion is calculated separately.
    CriterionCallback(
        input_key="mask",
        prefix="loss_dice",
        criterion_key="dice"
    ),
    CriterionCallback(
        input_key="mask",
        prefix="loss_ce",
        criterion_key="ce"
    ),

    # And only then we aggregate everything into one loss.
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",
        metrics={
            "loss_dice": 1.0,
            "loss_ce": 0.8
        },
    ),

    # metrics
    IoUMetricsCallback(
        mode='multiclass',
        input_key='mask',
        # Names must follow the label indices, i.e. the channel order of the
        # PanNuke mask arrays that the dataset collapses with argmax. Per the
        # dataset README that order is: 0 neoplastic, 1 inflammatory,
        # 2 connective, 3 dead, 4 (non-neoplastic) epithelial, 5 background.
        # The previous list had 'non-neoplastic epithelial' in position 1,
        # which mislabelled the per-class scores of classes 1-4.
        class_names=[
            'neoplastic',
            'inflammatory',
            'connective',
            'dead',
            'non-neoplastic epithelial',
            'background'
        ]
    )

]

In [None]:
# The keys match the dict produced by PanNukeDataset.__getitem__.
runner = dl.SupervisedRunner(
    input_target_key="mask",
    input_key="features",
)

In [None]:
# Short smoke-test run; checkpoints and logs go to ../logs/initial_test.
runner.train(
    # what to train
    model=model,
    loaders=loaders,
    # how to train it
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    callbacks=callbacks,
    num_epochs=3,
    # model selection: lowest aggregated "loss" wins
    main_metric="loss",
    minimize_metric=True,
    # bookkeeping
    logdir='../logs/initial_test',
    verbose=True,
)

Now we can write a train.py that runs through all our model/encoder combinations, something like
https://github.com/rwightman/pytorch-image-models/blob/master/train.py