# Experiments with semantic segmentation 

Faisal Qureshi      
faisal.qureshi@ontariotechu.ca

Setting up a semantic segmentation pipeline in PyTorch for the Oxford-IIIT Pet Dataset.  

## Readme

- The goal is to learn the PyTorch Lightning and Albumentations packages.
- Install segmentation_models_pytorch with `pip3 install git+https://github.com/qubvel/segmentation_models.pytorch@8bf52c7e862af006e76a23aae6aa17977d7f9a79`.  This code may not work with other versions of the segmentation_models_pytorch module.

In [None]:
import os
import cv2
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import numpy as np
import albumentations as albu
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
import tqdm
import pickle
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np
import copy
from pprint import pprint

## The dataset

I downloaded the dataset from [https://www.robots.ox.ac.uk/~vgg/data/pets/](https://www.robots.ox.ac.uk/~vgg/data/pets/).

In [None]:
# Root folder of the downloaded dataset; the cells below read its
# 'images' and 'annotations/trimaps' subfolders.
DATA_FOLDER = '../../data/oxford-3t-pet-dataset'

In [None]:
# File names (not full paths) found in the image and trimap folders.
images = os.listdir(os.path.join(DATA_FOLDER, 'images'))
trimaps = os.listdir(os.path.join(DATA_FOLDER, 'annotations/trimaps'))

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Every keyword argument becomes one panel: the keyword, with underscores
    replaced by spaces and title-cased, is used as the panel title and the
    value is handed to ``plt.imshow``.
    """
    n_panels = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, n_panels, position)
        # Hide the axis ticks; only the picture and its title are shown.
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

In [None]:
# Load the first listed image together with its trimap and display both.
i = images[0]
img = cv2.imread(os.path.join(DATA_FOLDER, 'images', i))
# OpenCV loads colour images as BGR; convert to RGB for matplotlib.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# The trimap has the same stem as the image with a .png extension; flag 0 reads it as grayscale.
mask = cv2.imread(os.path.join(DATA_FOLDER, 'annotations/trimaps', os.path.splitext(i)[0]+'.png'), 0)

visualize(image=img, mask=mask)

### Constructing the Dataset Object

The dataset uses albumentations augmentations library.

In [None]:
class Dataset(BaseDataset):
    """Image/mask pairs from two folders, served as one-hot segmentation samples.

    Images are ``<stem>.jpg`` files in ``images_dir`` and masks are
    ``<stem>.png`` files in ``masks_dir``.  Only stems that have both a
    readable image and a readable mask are kept.  The list of valid stems is
    cached in ``CACHE_FILE`` in the current working directory so the (slow)
    readability scan runs only once.

    NOTE: the cache is looked up by file name only; it does not record which
    folders it was built from, so delete it when the data folders change.

    Args:
        images_dir: folder containing the ``.jpg`` images.
        masks_dir: folder containing the ``.png`` masks.
        class_values: pixel values to extract from the mask.  Each value
            becomes one channel of the returned mask, in the given order.
            Required, even though the signature defaults to ``None``.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.

    Raises:
        ValueError: if ``class_values`` is ``None``.
    """

    # Cache of validated file stems, read from / written to the working directory.
    CACHE_FILE = 'oxford-3t-pets-valid-files.pkl'

    def __init__(
        self,
        images_dir,
        masks_dir,
        class_values = None,
        augmentation = None,
        preprocessing = None
    ):
        if class_values is None:
            # Fail early with a clear message instead of an opaque
            # "object of type 'NoneType' has no len()".
            raise ValueError(
                'class_values must list the mask pixel values to extract, e.g. [1, 2, 3]'
            )

        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.class_values = class_values
        self.num_classes = len(class_values)
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        # Optional subset of positions into self.files; an EMPTY array means
        # "use every file" (so an empty subset cannot be expressed).
        self.idx = np.array([])

        try:
            # NOTE: pickle can execute arbitrary code while loading; only use
            # a cache file that this class wrote itself.
            with open(self.CACHE_FILE, 'rb') as f:
                self.files = pickle.load(f)['files']
        except (OSError, pickle.UnpicklingError, EOFError, KeyError, TypeError):
            # No usable cache: scan both folders and write a fresh one.
            self.files = self._scan_valid_files()
            with open(self.CACHE_FILE, 'wb') as f:
                pickle.dump({'files': self.files}, f)
            print('Saved oxford-3t-pets-valid-files.pkl')

    @staticmethod
    def _readable_stems(folder, extension, flags):
        """Return the stems of ``extension`` files in ``folder`` that OpenCV can decode."""
        stems = [
            stem
            for stem, ext in map(os.path.splitext, os.listdir(folder))
            if ext == extension
        ]
        readable = []
        for stem in tqdm.tqdm(stems, position=0, leave=True):
            # cv2.imread reports failure by returning None, not by raising.
            if cv2.imread(os.path.join(folder, stem + extension), flags) is not None:
                readable.append(stem)
        return readable

    def _scan_valid_files(self):
        """Return the sorted stems that have both a readable image and a readable mask."""
        print('Checking image files...')
        images = self._readable_stems(self.images_dir, '.jpg', cv2.IMREAD_COLOR)
        print('Checking mask files')
        masks = self._readable_stems(self.masks_dir, '.png', cv2.IMREAD_GRAYSCALE)
        # Keep only image/mask pairs; sorting makes the order reproducible.
        return sorted(set(images) & set(masks))

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``.

        The image is loaded as RGB.  The mask is loaded as grayscale and
        expanded to one channel per entry of ``class_values`` (last axis),
        with 1.0 where the pixel equals that value and 0.0 elsewhere.
        ``augmentation`` and then ``preprocessing`` are applied when set.

        Raises:
            OSError: if the image or mask file cannot be read.
        """
        i = index if len(self.idx) == 0 else self.idx[index]

        stem = self.files[i]
        image_path = os.path.join(self.images_dir, stem + '.jpg')
        mask_path = os.path.join(self.masks_dir, stem + '.png')

        image = cv2.imread(image_path)
        if image is None:
            raise OSError(f'Could not read image file: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f'Could not read mask file: {mask_path}')

        # One boolean plane per class value, stacked on the last axis.
        mask = np.stack([mask == v for v in self.class_values], axis=-1).astype('float')

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the subset size when ``idx`` is set, otherwise the number of files."""
        return len(self.files) if len(self.idx) == 0 else len(self.idx)


#### Checking if the dataset works without any augmentations

In [None]:
# Folders holding the images and their trimap masks.
IMAGE_FOLDER = os.path.join(DATA_FOLDER, 'images')
MASK_FOLDER = os.path.join(DATA_FOLDER, 'annotations/trimaps')
CLASS_VALUES = [1,2,3]  # 1: foreground, 2: background, and 3: not classified
                        # These are mapped to 0, 1, 2 channels, respectively.

# Plain dataset: no augmentation and no preprocessing.
dataset = Dataset(IMAGE_FOLDER, MASK_FOLDER, CLASS_VALUES)
print(f'items in dataset = {len(dataset)}')

In [None]:
# Fetch one sample from the plain dataset and display the image next to its mask.
image, mask = dataset[34]
visualize(image=image, mask=mask)

#### Looping over all data items to catch any issues

In [None]:
# for i in range(len(dataset)):
#     try:
#         image, mask = dataset[i]
#     except:
#         print(i)

### Defining Augmentations

In [None]:
def get_training_augmentation():
    """Build the albumentations pipeline used for training samples.

    The pipeline applies, in order: a random horizontal flip, a random
    shift/scale (no rotation), padding to at least 320x320, a random
    320x320 crop, optional Gaussian noise and perspective distortion, and
    then three ``OneOf`` groups (brightness-type, sharpness/blur-type and
    contrast/colour-type), each group applied with probability 0.9.

    Returns:
        albumentations.Compose wrapping the transforms above.
    """
    geometric = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        # Pad first so the crop below always has 320x320 pixels to take.
        albu.PadIfNeeded(min_height=320, min_width=320, always_apply=True, border_mode=0),
        albu.RandomCrop(height=320, width=320, always_apply=True),
    ]
    noise_and_perspective = [
        albu.GaussNoise(p=0.2),
        albu.IAAPerspective(p=0.5),
    ]
    brightness_group = albu.OneOf(
        [
            albu.CLAHE(p=1),
            albu.RandomBrightness(p=1),
            albu.RandomGamma(p=1),
        ],
        p=0.9,
    )
    sharpness_group = albu.OneOf(
        [
            albu.IAASharpen(p=1),
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.9,
    )
    colour_group = albu.OneOf(
        [
            albu.RandomContrast(p=1),
            albu.HueSaturationValue(p=1),
        ],
        p=0.9,
    )
    pipeline = geometric + noise_and_perspective
    pipeline += [brightness_group, sharpness_group, colour_group]
    return albu.Compose(pipeline)

def get_validation_augmentation():
    """Build the deterministic pipeline used for validation and test samples.

    The longest image side is resized to 320 pixels and the result is padded
    to at least 320x320.

    Returns:
        albumentations.Compose wrapping the two transforms.
    """
    resize_longest_side = albu.LongestMaxSize(320)
    pad_to_square = albu.PadIfNeeded(min_height=320, min_width=320)
    return albu.Compose([resize_longest_side, pad_to_square])

def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image only (can be specific to each pretrained network).

    Returns:
        albumentations.Compose that first normalizes the image with
        ``preprocessing_fn`` and then runs ``to_tensor`` on both the image
        and the mask.
    """
    normalize_image = albu.Lambda(image=preprocessing_fn)
    to_channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_image, to_channels_first])

#### Checking out augmented dataset.  

In [None]:
# Same folders and class values, with the training augmentation attached (no preprocessing).
augmented_dataset = Dataset(IMAGE_FOLDER, MASK_FOLDER, CLASS_VALUES, get_training_augmentation())

In [None]:
# Fetching the same item repeatedly shows how the random augmentations vary between draws.
for i in range(3):
    image, mask = augmented_dataset[1] # pick the same item three times
    visualize(image=image, mask=mask)

## Model

Now we construct a semantic segmentation model.  We will use the `segmentation_models_pytorch` library to set up our model.

In [None]:
# Settings handed to smp.FPN inside PetModel (encoder_name, encoder_weights, activation, classes).
ENCODER = 'se_resnext50_32x4d' 
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'softmax2d'
# One output channel per entry of CLASS_VALUES.
NUM_CLASSES = 3
# First CUDA device when available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Select the appropriate preprocessing function, which is used to normalize the images correctly for a particular encoder.  It will be used within the dataset object.

In [None]:
# Look up the normalization function that matches the chosen encoder and pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
print(preprocessing_fn)

### Constructing training, validation, and test datasets

Since the dataset objects use different augmentations for training and validation/test, we will split the dataset without augmentations first and then we will set the augmentations accordingly.

In [None]:
# Base dataset: preprocessing only. Augmentations are attached per split below.
dataset = Dataset(
    IMAGE_FOLDER,
    MASK_FOLDER,
    CLASS_VALUES,
    preprocessing=get_preprocessing(preprocessing_fn)
)

n_total = len(dataset)
# NOTE: n_valid and n_train are used as slice boundaries of the shuffled
# index array: [0, n_valid) -> validation, [n_valid, n_train) -> training,
# [n_train, n_total) -> test. Training therefore gets roughly 60% of the
# items, and validation and test roughly 20% each.
n_train = int(0.8*n_total)
n_valid = int(0.2*n_total)

# Random permutation of all item positions (not seeded, so it differs between runs).
idx = np.arange(n_total)
np.random.shuffle(idx)

valid_idx, train_idx, test_idx = np.split(idx, [n_valid, n_train])

# Each split is an independent copy of the base dataset restricted to its own
# indices. Only the training split gets the random training augmentations.
train_dataset, valid_dataset, test_dataset = (copy.deepcopy(dataset) for _ in range(3))
train_dataset.idx, valid_dataset.idx, test_dataset.idx = train_idx, valid_idx, test_idx

train_dataset.augmentation = get_training_augmentation()
valid_dataset.augmentation = get_validation_augmentation()
test_dataset.augmentation = get_validation_augmentation()

for label, split in (
    ('dataset', dataset),
    ('train dataset', train_dataset),
    ('valid dataset', valid_dataset),
    ('test dataset', test_dataset),
):
    print(f'{label} = {len(split)}')

### Set up the training, validation and test dataloader

In [None]:
# Batches of 16 loaded by 8 worker processes; only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=8)
valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=False, num_workers=8)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8)

### Model, Loss, error metrics

PetModel is derived from pl.LightningModule

In [None]:
class PetModel(pl.LightningModule):
    """Lightning wrapper around an ``smp.FPN`` segmentation network.

    The network is configured from the module-level ENCODER, ENCODER_WEIGHTS,
    NUM_CLASSES and ACTIVATION settings.  Training, validation and test all
    share the same ``step`` / ``step_end`` / ``epoch_end`` logic; the stage
    name ('train', 'valid' or 'test') is only used as the prefix of the
    logged metric names.
    """

    def __init__(self):
        super().__init__()

        self.model = smp.FPN(
            encoder_name = ENCODER,
            encoder_weights = ENCODER_WEIGHTS,
            classes = NUM_CLASSES,
            activation = ACTIVATION
        )

        # from_logits=False because the network already applies ACTIVATION.
        # NOTE(review): BINARY_MODE is kept from the original; with a
        # multi-channel one-hot mask, MULTILABEL_MODE may be the better
        # fit - confirm before changing, as it alters training.
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=False)

    def forward(self, x):
        """Return the network output for a batch of images."""
        mask = self.model(x)
        return mask

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def step(self, stage, batch, batch_idx):
        """Run one batch and return its loss and confusion-matrix counts.

        Args:
            stage: 'train', 'valid' or 'test' (not used here).
            batch: ``(images, masks)`` pair from the dataloader.
            batch_idx: index of the batch (not used here).

        Returns:
            dict with 'loss' and the 'tp', 'fp', 'fn', 'tn' statistics.
        """
        x, y = batch # x is image and y is mask
        yhat = self.forward(x)
        loss = self.loss_fn(yhat, y)

        # Pass the float predictions so that get_stats applies the 0.5
        # threshold itself; casting them to long first would truncate every
        # value below 1 to 0 before thresholding.
        tp, fp, fn, tn = smp.metrics.get_stats(yhat, y.long(), mode='binary', threshold=0.5)

        return {
            'loss': loss, # this is required by PyTorch Lightning framework.
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'tn': tn
        }

    def step_end(self, stage, outputs):
        """Pass the step outputs through unchanged.

        Returning the outputs (rather than None) keeps the loss and the
        statistics available to Lightning and to ``epoch_end``.
        """
        return outputs

    def epoch_end(self, stage, outputs):
        """Aggregate the per-step outputs of one epoch and log the metrics.

        Logs ``{stage}_loss``, ``{stage}_per_image_iou`` and
        ``{stage}_dataset_iou``.
        """
        # aggregate step metrics
        loss = torch.stack([x['loss'] for x in outputs]).mean()
        tp = torch.cat([x['tp'] for x in outputs])
        fp = torch.cat([x['fp'] for x in outputs])
        fn = torch.cat([x['fn'] for x in outputs])
        tn = torch.cat([x['tn'] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset
        # with "empty" images (images without target class) a large gap could be observed.
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f'{stage}_loss': loss,
            f'{stage}_per_image_iou': per_image_iou,
            f'{stage}_dataset_iou': dataset_iou,
        }

        self.log_dict(metrics)

        # self.logger.experiment.add_scalar(f'loss/{stage}', loss, self.current_epoch)

    # train
    def training_step(self, batch, batch_idx):
        return self.step('train', batch, batch_idx)

    def training_step_end(self, outputs):
        return self.step_end('train', outputs)

    def training_epoch_end(self, outputs):
        return self.epoch_end('train', outputs)

    # test
    def test_step(self, batch, batch_idx):
        return self.step('test', batch, batch_idx)

    def test_step_end(self, outputs):
        return self.step_end('test', outputs)

    def test_epoch_end(self, outputs):
        # Must go through self: a bare epoch_end(...) raises NameError.
        return self.epoch_end('test', outputs)

    # evaluate
    def validation_step(self, batch, batch_idx):
        return self.step('valid', batch, batch_idx)

    def validation_step_end(self, outputs):
        return self.step_end('valid', outputs)

    def validation_epoch_end(self, outputs):
        return self.epoch_end('valid', outputs)

## Training

In [None]:
# Monitor a metric that PetModel actually logs: epoch_end logs
# '<stage>_loss', so the validation loss is available as 'valid_loss'
# (a plain 'loss' key is never logged). Lower is better.
checkpoint_callback = ModelCheckpoint(monitor='valid_loss', mode='min', dirpath='./ckpt')
logger = TensorBoardLogger('tb_logs', name='pet_model_1')

model = PetModel()
trainer = pl.Trainer(gpus=1, max_epochs=5, logger=logger, callbacks=[checkpoint_callback])
trainer.fit(
    model=model,
    train_dataloaders = train_dataloader,
    val_dataloaders = valid_dataloader
)

In [None]:
# Path of the best checkpoint recorded by the callback (shown as the cell output).
checkpoint_callback.best_model_path

In [None]:
# Continue training from the best checkpoint, up to 10 epochs in total.
# The validation loader is passed again, as in the first fit call, so that
# validation metrics keep being computed for the checkpoint callback.
trainer = pl.Trainer(gpus=1, max_epochs=10, logger=logger, callbacks=[checkpoint_callback])
trainer.fit(
    model=model,
    train_dataloaders = train_dataloader,
    val_dataloaders = valid_dataloader,
    ckpt_path = checkpoint_callback.best_model_path
)

## Validation and test metrics

In [None]:
# Evaluate on the validation split and print the returned metrics.
valid_metrics = trainer.validate(model, dataloaders=valid_dataloader, verbose=False)
pprint(valid_metrics)

In [None]:
# Evaluate on the test split.
# NOTE: trainer.validate runs the validation hooks, so PetModel logs these
# test-split metrics under the 'valid_' prefix rather than 'test_'.
test_metrics = trainer.validate(model, dataloaders=test_dataloader, verbose=False)
pprint(test_metrics)

## Visualizing results

In [None]:
# Predict masks for the first test batch and show image / ground truth /
# prediction side by side.
images, gt_masks = next(iter(test_dataloader))
model.eval()
with torch.no_grad():
    # Run on whichever device the model currently lives on, then bring the
    # predictions back to the CPU so they can be converted to NumPy.
    masks = model(images.to(model.device)).cpu()

for image, gt_mask, mask in zip(images, gt_masks, masks):
    plt.figure(figsize=(10,5))
    panels = (('image', image), ('gt mask', gt_mask), ('mask', mask))
    for position, (title, tensor) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        # Tensors are channels-first; matplotlib expects channels-last.
        plt.imshow(tensor.numpy().transpose(1,2,0))
        plt.axis('off')
        plt.title(title)
    plt.show()