In [1]:
import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp
from torchmetrics import IoU

import albumentations as albu
from albumentations.pytorch import ToTensorV2 as ToTensor

In [2]:
# Print the CUDA_VISIBLE_DEVICES environment variable as seen by the shell
# (a blank line is printed when it is unset or empty).
!echo $CUDA_VISIBLE_DEVICES

GPU-f7a950ee-915e-04ac-684b-156afa1e30e1


In [3]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel: the keyword, with underscores
    turned into spaces and title-cased, is used as the panel title and the
    value is handed to ``plt.imshow``. Axis ticks are hidden.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, total, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label])
    plt.show()

def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Axis order (0, 1, 2) becomes (2, 0, 1). Extra keyword arguments are
    accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_training_augmentation(preprocessing_fn):
    """Build the albumentations pipeline applied to training samples.

    The pipeline resizes to 224x224, applies random geometric and photometric
    augmentations, runs ``preprocessing_fn`` on the image and converts the
    result to tensors with ``ToTensor``.

    ``preprocessing_fn`` is the only normalisation step. The pipeline used to
    apply ``albu.Normalize()`` after it as well, which normalised the image a
    second time.

    Args:
        preprocessing_fn (callable): image normalisation function, called as
            ``preprocessing_fn(image, **kwargs)`` and expected to return an
            array (e.g. the result of ``smp.encoders.get_preprocessing_fn``).

    Returns:
        albu.Compose: the composed training transform.
    """

    def _preprocess(image, **kwargs):
        # preprocessing_fn may return float64; cast so the image tensor ends
        # up float32 (this cast used to happen implicitly in albu.Normalize).
        return preprocessing_fn(image, **kwargs).astype('float32')

    train_transform = [
        albu.Resize(224, 224),
        albu.HorizontalFlip(p=0.5),

        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                # brightness-only jitter; replaces the deprecated albu.RandomBrightness
                albu.RandomBrightnessContrast(contrast_limit=0, p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.Sharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                # contrast-only jitter; replaces the deprecated albu.RandomContrast
                albu.RandomBrightnessContrast(brightness_limit=0, p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),

        # Normalise once, after all augmentations that work on the raw image.
        albu.Lambda(image=_preprocess),
        ToTensor(transpose_mask=False),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation(preprocessing_fn):
    """Build the deterministic pipeline used for validation and test samples.

    Resizes to 224x224, runs ``preprocessing_fn`` on the image and converts
    the result to tensors with ``ToTensor``. No random augmentation is applied.

    ``preprocessing_fn`` is the only normalisation step. The pipeline used to
    apply ``albu.Normalize()`` after it as well, which normalised the image a
    second time.

    Args:
        preprocessing_fn (callable): image normalisation function, called as
            ``preprocessing_fn(image, **kwargs)`` and expected to return an
            array.

    Returns:
        albu.Compose: the composed evaluation transform.
    """

    def _preprocess(image, **kwargs):
        # preprocessing_fn may return float64; cast so the image tensor ends
        # up float32 (this cast used to happen implicitly in albu.Normalize).
        return preprocessing_fn(image, **kwargs).astype('float32')

    test_transform = [
        albu.Resize(224, 224),
        albu.Lambda(image=_preprocess),
        ToTensor(transpose_mask=False),
    ]
    return albu.Compose(test_transform)

def get_preprocessing(preprocessing_fn):
    """Compose the normalisation and layout-conversion steps.

    Args:
        preprocessing_fn (callable): data normalisation function
            (can be specific to each pretrained network); applied to the
            image only.

    Returns:
        albu.Compose: ``preprocessing_fn`` on the image, followed by
        ``to_tensor`` on both the image and the mask.
    """
    normalise = albu.Lambda(image=preprocessing_fn)
    # NOTE(review): to_tensor is applied to the mask as well — confirm that
    # callers supply masks with a trailing channel axis.
    reorder_axes = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalise, reorder_axes])

In [4]:
class TikTokDataset(Dataset):
    """Dataset of (image, mask) pairs read from parallel lists of file paths.

    Args:
        images_fps (sequence of str): paths to the input images.
        masks_fps (sequence of str): paths to the masks; ``masks_fps[i]``
            belongs to ``images_fps[i]``.
        augmentation (callable, optional): called as
            ``augmentation(image=..., mask=...)`` and expected to return a
            mapping with ``"image"`` and ``"mask"`` entries.

    Raises:
        ValueError: if the two path lists differ in length.
    """

    def __init__(
            self,
            images_fps,
            masks_fps,
            augmentation=None,
    ):
        if len(images_fps) != len(masks_fps):
            raise ValueError(
                f"images_fps and masks_fps must have the same length, "
                f"got {len(images_fps)} and {len(masks_fps)}"
            )
        self.images_fps = images_fps
        self.masks_fps = masks_fps
        self.augmentation = augmentation

    def __getitem__(self, idx):
        """Load, scale and (optionally) augment the sample at ``idx``.

        Returns:
            tuple: ``(image, mask)``. Without augmentation the image is the
            RGB array read from disk and the mask is a float32 array scaled
            from 0-255 to 0-1.

        Raises:
            FileNotFoundError: if the image or mask cannot be read.
        """
        image_path = self.images_fps[idx]
        mask_path = self.masks_fps[idx]

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")
        # Scale to [0, 1]; float32 avoids carrying float64 targets around.
        mask = mask.astype('float32') / 255.

        # Explicit None check: truthiness of a pipeline object may depend on
        # its length rather than on whether one was supplied.
        if self.augmentation is not None:
            sample = self.augmentation(image=image, mask=mask)
            image = sample["image"]
            mask = sample["mask"]

        return image, mask

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.images_fps)

In [5]:
class TikTokDataModule(pl.LightningDataModule):
    """LightningDataModule that splits an image/mask folder pair into
    train, validation and test ``TikTokDataset`` instances.

    Every entry of ``images_dir`` is treated as a sample, and its mask is
    looked up under the same file name in ``masks_dir``.

    Args:
        images_dir (str): directory containing the input images.
        masks_dir (str): directory containing the masks.
        batch_size (int): batch size for all three loaders.
        train_augs, val_augs, test_augs (callable, optional): transforms
            passed to the respective ``TikTokDataset``.
        num_workers (int): DataLoader worker processes (default 8).
        seed (int): random state for the train/val/test split, so that every
            ``setup()`` call yields the same partition.
    """

    def __init__(
        self,
        images_dir,
        masks_dir,
        batch_size,
        train_augs=None,
        val_augs=None,
        test_augs=None,
        num_workers=8,
        seed=42,
    ):
        super().__init__()

        # os.listdir order is arbitrary; sort so the split is reproducible.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
        self.batch_size = batch_size
        self.train_augs = train_augs
        self.val_augs = val_augs
        self.test_augs = test_augs
        self.num_workers = num_workers
        self.seed = seed

    def setup(self, stage=None):
        """Create the train/val/test datasets (10% test, then 20% of the rest for val).

        The split is seeded: setup() can run more than once (e.g. for the fit
        and the test stage), and an unseeded split would then place training
        images in the test set.
        """
        train_imgs, test_imgs, train_masks, test_masks = train_test_split(
            self.images_fps, self.masks_fps, test_size=0.1, random_state=self.seed
        )
        train_imgs, val_imgs, train_masks, val_masks = train_test_split(
            train_imgs, train_masks, test_size=0.2, random_state=self.seed
        )

        self.train_data = TikTokDataset(train_imgs, train_masks, self.train_augs)
        self.val_data = TikTokDataset(val_imgs, val_masks, self.val_augs)
        self.test_data = TikTokDataset(test_imgs, test_masks, self.test_augs)

    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Test loader (fixed order)."""
        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)

    # Backward-compatible alias for the original, misspelt method name.
    test_dataloder = test_dataloader

In [6]:
class SegmentationModel(pl.LightningModule):
    """LightningModule wrapping an ``smp.FPN`` segmentation network.

    Args:
        encoder (str): encoder name passed to ``smp.FPN`` as ``encoder_name``.
        encoder_weights (str): pretrained weights identifier for the encoder.
        num_classes (int): number of output channels (``classes``).
        activation (str): output activation passed to ``smp.FPN``.
        criterion (callable): loss, called as ``criterion(prediction, target)``.
        metrics (callable): metric, called as ``metrics(prediction, target)``.
        learning_rate (float): learning rate for the Adam optimizer.
    """

    def __init__(
        self,
        encoder,
        encoder_weights,
        num_classes,
        activation,
        criterion,
        metrics,
        learning_rate,
    ):
        super().__init__()

        self.model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=num_classes,
            activation=activation,
        )

        self.loss = criterion
        self.metrics = metrics
        self.learning_rate = learning_rate

    def forward(self, x):
        """Run the wrapped FPN on a batch of images."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch, stage):
        """Compute loss and metric for one batch and log them as ``<stage>_*``."""
        x, y = batch
        # Insert a channel axis at dim 1 so the target lines up with the
        # model output (assumes masks arrive as (batch, H, W) — TODO confirm).
        y = y.unsqueeze(1)

        # Output of the network after whatever activation was configured.
        preds = self.forward(x)
        loss = self.loss(preds, y)
        score = self.metrics(preds, y)

        logs = {f'{stage}_loss': loss, f'{stage}_metrics': score}
        self.log_dict(logs, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def training_step(self, train_batch, batch_idx):
        """Training step; logs ``train_loss`` / ``train_metrics``.

        These values were previously logged under the ``valid_*`` keys, which
        collided with the validation logs.
        """
        return self._shared_step(train_batch, 'train')

    def validation_step(self, valid_batch, batch_idx):
        """Validation step; logs ``valid_loss`` / ``valid_metrics``."""
        return self._shared_step(valid_batch, 'valid')

In [7]:
# Dataset locations, relative to the notebook's working directory.
IMAGES_DIR = "../data/dataset/images"
MASKS_DIR = "../data/dataset/masks"

# Input preprocessing for the se_resnext50_32x4d encoder with imagenet weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(
    encoder_name="se_resnext50_32x4d",
    pretrained="imagenet",
)

# IoU metric configured with a 0.5 threshold.
metrics = smp.utils.metrics.IoU(threshold=0.5)

In [8]:
# Data module: random augmentations for training, deterministic resize +
# preprocessing for validation and test.
data = TikTokDataModule(
    images_dir=IMAGES_DIR,
    masks_dir=MASKS_DIR,
    batch_size=16,
    train_augs=get_training_augmentation(preprocessing_fn),
    val_augs=get_validation_augmentation(preprocessing_fn),
    test_augs=get_validation_augmentation(preprocessing_fn),
)

# Single-class FPN with a sigmoid output, trained with Dice loss.
model = SegmentationModel(
    encoder="se_resnext50_32x4d",
    encoder_weights="imagenet",
    num_classes=1,
    activation="sigmoid",
    criterion=smp.utils.losses.DiceLoss(),
    metrics=metrics,
    learning_rate=0.0001,
)



In [9]:
# Train on one GPU for 30 epochs; the data module supplies the loaders.
trainer = pl.Trainer(gpus=1, max_epochs=30)
trainer.fit(model, datamodule=data)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-f7a950ee-915e-04ac-684b-156afa1e30e1]

  | Name    | Type     | Params
-------------------------------------
0 | model   | FPN      | 28.1 M
1 | loss    | DiceLoss | 0     
2 | metrics | IoU      | 0     
-------------------------------------
28.1 M    Trainable params
0         Non-trainable params
28.1 M    Total params
112.476   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]