# POLYP SEGMENTATION

## The task
Given an endoscopic image, we have to predict a binary mask in which the pixels with value 1 mark the polyp regions.

## Assumptions
Let the mask be a random variable $M$ and the input image a random variable $X$. To perform the segmentation task, we estimate the conditional probability $P(M=m|X=x)$. Because the mask is binary, $P(M=m|X=x)$ is a Bernoulli distribution, which means that $m|x \sim \text{Ber}(\lambda(x))$, where $\lambda(x)$ is modeled using a neural network $G_\theta(x)$ ($\theta$ is the set of the network's parameters). In conclusion, we have to estimate $\text{Ber}(m;G_\theta(x))$. Another assumption is that the elements of an arbitrary $m^{<i>}$ are independent (given the image), so that the joint probability of a mask used later factorizes into a product over its pixels.

## Methodology
Given a set of training data including endoscopic images $(x^{<1>}, x^{<2>}, ..., x^{<N>})$ and target masks $(m^{<1>}, m^{<2>}, ..., m^{<N>})$ (any $x^{<i>}$ or $m^{<i>}$ is vector-valued), our goal is maximizing the conditional log-likelihood of this training dataset:<br>
<center>
    <div style="display: inline-block; text-align: left;">
        $\theta^* = \text{argmax}_\theta\ \sum_{i=1}^N \log P(m^{<i>}|x^{<i>})$<br>
        $\ \ \ \  = \text{argmax}_\theta\ \sum_{i=1}^N \log \text{Ber}(m^{<i>};G_\theta(x^{<i>}))$<br>
    </div>
</center>

We have assumed that the elements in an arbitrary $m^{<i>}$ are independent. Therefore:<br>
<center>
    <div style="display: inline-block; text-align: left;">
        $\theta^* = \text{argmax}_\theta\ \sum_{i=1}^N \sum_{k=1}^K \log \text{Ber}(m^{<i>}_k;G_\theta(x^{<i>})_k)$ <br>
        $\ \ \ \ = \text{argmax}_\theta\ \sum_{i=1}^N \sum_{k=1}^K \log \left[ G_\theta(x^{<i>})_k^{m^{<i>}_k} (1 - G_\theta(x^{<i>})_k)^{1 - m^{<i>}_k} \right]$<br>
        $\ \ \ \ = \text{argmax}_\theta\ \sum_{i=1}^N \sum_{k=1}^K \left[ {m^{<i>}_k}\log G_\theta(x^{<i>})_k + ({1 - m^{<i>}_k})\log(1 - G_\theta(x^{<i>})_k) \right]$<br>
    </div>
</center>

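It turns out that maximizing this log-likelihood is the same as minimizing the binary cross-entropy loss between the predicted probabilities $G_\theta(x)$ and the target masks.<br>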
In the code below, the neural network $G_\theta(x)$ is modeled as a U-Net.

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision as tv
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def plot_images(imgs, n_rows=8):
    """Tile a batch of images into one grid and display it with matplotlib.

    Args:
        imgs: batch of images handed to ``torchvision.utils.make_grid``.
        n_rows: forwarded as ``make_grid``'s ``nrow`` argument, which is the
            number of images placed in each grid row (despite this name).
    """
    grid = tv.utils.make_grid(imgs, nrow=n_rows, padding=4)
    # make_grid returns channels-first data; imshow wants channels last.
    channels_last = grid.permute(1, 2, 0)
    plt.figure(figsize=(12, 12))
    plt.imshow(channels_last)
    plt.show()

In [None]:
class ImagePairDataset(Dataset):
    """Dataset of (image, binary mask) pairs read from two directories.

    Images and masks are paired by position after both directory listings
    are sorted, so corresponding files must sort identically in the two
    directories (for example by sharing the same file name). Sorting matters
    because ``os.listdir`` returns entries in arbitrary order; without it the
    i-th image is not guaranteed to line up with the i-th mask.
    """

    # Only files ending in one of these (case-sensitive) suffixes are used.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_dir, label_dir, img_size):
        """Collect the image and mask paths and prepare the resize transform.

        Args:
            img_dir: directory containing the input images.
            label_dir: directory containing the mask images.
            img_size: target size passed to ``transforms.Resize``.

        Raises:
            ValueError: if the two directories hold different numbers of
                image files.
        """
        super().__init__()
        self.img_paths = self._list_image_paths(img_dir)
        self.label_paths = self._list_image_paths(label_dir)
        if len(self.img_paths) != len(self.label_paths):
            raise ValueError(
                f'Found {len(self.img_paths)} images in {img_dir!r} but '
                f'{len(self.label_paths)} masks in {label_dir!r}'
            )
        self.resize = transforms.Resize(img_size, antialias=True)

    @staticmethod
    def _list_image_paths(directory):
        """Return the sorted full paths of the image files in *directory*."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(ImagePairDataset._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load, resize and normalise the pair at *index*.

        Returns:
            ``(img, label)``: the resized image divided by 255, and the first
            channel of the resized mask divided by 255 and rounded so that
            interpolated values collapse back to 0 or 1.
        """
        img = tv.io.read_image(self.img_paths[index])
        img = self.resize(img) / 255.0
        label = tv.io.read_image(self.label_paths[index])
        # Keep only the first channel: the mask is treated as single-channel.
        label = self.resize(label)[0] / 255.0
        return img, label.round()

In [None]:
class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        """Build the three layers; arguments are forwarded to ``nn.Conv2d``."""
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Apply convolution, normalisation and activation in that order."""
        return self.relu(self.norm(self.conv(x)))

In [None]:
class DeconvBlock(nn.Module):
    """ConvTranspose2d followed by BatchNorm2d and LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        """Build the three layers; arguments go to ``nn.ConvTranspose2d``."""
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Apply transposed convolution, normalisation and activation."""
        return self.relu(self.norm(self.deconv(x)))

In [None]:
class ResidualBlock(nn.Module):
    """Two 'same'-padded convolutions with a convolutional shortcut.

    The shortcut is its own convolution (``res_conv``) so the input can be
    added to the main branch even when the channel count changes. The sum is
    batch-normalised, passed through LeakyReLU(0.2) and, when ``dropout`` is
    truthy, through ``nn.Dropout(dropout)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout=None):
        super().__init__()
        self.res_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, padding='same')
        self.conv_block_1 = ConvBlock(
            in_channels, out_channels, kernel_size, padding='same')
        self.conv_2 = nn.Conv2d(
            out_channels, out_channels, kernel_size, padding='same')
        self.norm_2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)
        # The dropout attribute exists only when dropout was requested.
        if dropout:
            self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return activation(norm(main branch + shortcut)), then dropout if set."""
        shortcut = self.res_conv(x)
        main = self.conv_2(self.conv_block_1(x))
        out = self.relu(self.norm_2(main + shortcut))
        drop = getattr(self, 'dropout', None)
        return out if drop is None else drop(out)

In [None]:
class DownBlock(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ResidualBlock."""

    def __init__(self, in_channels, out_channels, kernel_size, dropout=None):
        super().__init__()
        self.res_block = ResidualBlock(
            in_channels, out_channels, kernel_size, dropout)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Pool first, then run the residual block on the pooled map."""
        return self.res_block(self.maxpool(x))

In [None]:
class UpBlock(nn.Module):
    """Decoder stage: a ResidualBlock followed by 2x bilinear upsampling."""

    def __init__(self, in_channels, out_channels, kernel_size, dropout=None):
        super().__init__()
        self.res_block = ResidualBlock(
            in_channels, out_channels, kernel_size, dropout)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        """Run the residual block first, then upsample its output."""
        return self.upsample(self.res_block(x))

In [None]:
class UNet(nn.Module):
    """U-Net that maps a 3-channel image batch to per-pixel probabilities.

    The encoder has one full-resolution ResidualBlock followed by four
    DownBlocks (each halves the resolution). The decoder upsamples step by
    step and concatenates the encoder output of matching resolution before
    each stage. A final ResidualBlock, a 1-channel convolution and a sigmoid
    produce the mask probabilities.
    """

    def __init__(self):
        super().__init__()
        self.downsample = nn.ModuleList([
            ResidualBlock(3, 64, 3, 0.1),
            DownBlock(64, 128, 3, 0.1),
            DownBlock(128, 256, 3, 0.1),
            DownBlock(256, 512, 3, 0.1),
            DownBlock(512, 1024, 3, 0.1),
        ])

        # Input channel counts are "previous stage + concatenated skip".
        self.upsample = nn.ModuleList([
            nn.Upsample(scale_factor=2, mode='bilinear'),
            UpBlock(1024 + 512, 512, 3, 0.1),
            UpBlock(512 + 256, 256, 3, 0.1),
            UpBlock(256 + 128, 128, 3, 0.1),
        ])

        self.last_blocks = nn.ModuleList([
            ResidualBlock(128 + 64, 64, 3, 0.1),
            nn.Conv2d(64, 1, 3, padding='same'),
            nn.Sigmoid(),
        ])

    def forward(self, x):
        """Predict mask probabilities.

        Args:
            x: image batch of shape (N, 3, H, W). H and W should be divisible
                by 16 (four 2x poolings) so the upsampled maps line up with
                the skip connections.

        Returns:
            Tensor of shape (N, H, W) with values in (0, 1).
        """
        skips = []
        for block in self.downsample:
            x = block(x)
            skips.append(x)
        # The deepest output is the bottleneck itself, not a skip connection;
        # the remaining ones are consumed from deepest to shallowest.
        skips = reversed(skips[:-1])
        for block, skip in zip(self.upsample, skips):
            x = block(x)
            x = torch.concat([x, skip], dim=1)
        for block in self.last_blocks:
            x = block(x)
        # Drop only the channel axis. A bare squeeze() would also remove the
        # batch axis for a batch (or DataParallel shard) of size 1, breaking
        # the loss computation and the gather across replicas.
        return x.squeeze(1)

In [None]:
@torch.no_grad()
def evaluate(val_data_loader):
    """Compute the mean BCE loss and pixel accuracy over a whole data loader.

    Uses the module-level ``net`` and ``device`` and leaves ``net`` in eval
    mode. Each batch is weighted by its number of pixels, so a smaller final
    batch does not skew the result the way an average of per-batch means
    would.

    Args:
        val_data_loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(loss, acc)`` as Python floats: per-pixel binary cross-entropy and
        the fraction of pixels whose rounded prediction equals the label.

    Raises:
        ValueError: if the loader yields no samples.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    n_pixels = 0
    for images, labels in val_data_loader:
        images = images.to(device)
        labels = labels.to(device)
        pred = net(images)
        batch_pixels = labels.numel()
        # binary_cross_entropy returns the batch mean; scale it back to a sum
        # so batches of different sizes are weighted correctly.
        loss_sum += F.binary_cross_entropy(pred, labels).item() * batch_pixels
        correct += (pred.round() == labels).sum().item()
        n_pixels += batch_pixels
    if n_pixels == 0:
        raise ValueError('val_data_loader yielded no samples to evaluate')
    return loss_sum / n_pixels, correct / n_pixels


@torch.no_grad()
def evaluate_batch(images, labels):
    """Return ``(loss, accuracy)`` of the module-level ``net`` on one batch.

    The loss is the binary cross-entropy between predictions and labels; the
    accuracy is the fraction of pixels whose rounded prediction equals the
    label. Puts ``net`` into eval mode.
    """
    net.eval()
    probs = net(images)
    bce = F.binary_cross_entropy(probs, labels)
    hits = probs.round() == labels
    pixel_acc = hits.sum() / labels.numel()
    return bce.item(), pixel_acc.item()

In [None]:
@torch.no_grad()
def test(images, labels):
    """Print batch metrics and plot inputs, overlays, predictions and targets.

    Shows four grids in order: the input images, the images blended with the
    predicted masks (40% image, 60% mask), the predicted masks, and the
    ground-truth masks.
    """
    net.eval()
    loss, acc = evaluate_batch(images, labels)
    print(f'loss: {loss:.8f} - acc: {acc:.8f}')

    # Add a channel axis so the masks broadcast against, and plot like, images.
    masks = net(images).unsqueeze(1)
    overlay = images * 0.4 + 0.6 * masks
    panels = (
        (None, images),
        (None, overlay),
        ('Predicted masks:', masks),
        ('Ground truth:', labels.unsqueeze(1)),
    )
    for caption, batch in panels:
        if caption is not None:
            print(caption)
        plot_images(batch.cpu())

In [None]:
# Build the U-Net, wrap it in DataParallel and move it to the chosen device.
net = nn.DataParallel(UNet()).to(device)
# Adam with PyTorch's default hyper-parameters.
optimizer = torch.optim.Adam(params=net.parameters())

In [None]:
# Training and validation splits, both resized to 128x128 on load.
# Note the training folders are named image/mask while the validation
# folders are named images/masks.
train_ds = ImagePairDataset(
    '/kaggle/input/intern-data/Kvasir_SEG_Training_880/Kvasir_SEG_Training_880/image',
    '/kaggle/input/intern-data/Kvasir_SEG_Training_880/Kvasir_SEG_Training_880/mask',
    (128, 128),
)
val_ds = ImagePairDataset(
    '/kaggle/input/intern-data/Kvasir_SEG_Validation_120/Kvasir_SEG_Validation_120/images',
    '/kaggle/input/intern-data/Kvasir_SEG_Validation_120/Kvasir_SEG_Validation_120/masks',
    (128, 128),
)

train_data_loader = DataLoader(
    train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    prefetch_factor=2,
)
val_data_loader = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Training loop: one optimisation step per batch, followed by a full pass over
# the validation set. Metrics are recorded per step (not per epoch) in
# `history`, and the weights with the highest validation accuracy seen so far
# are snapshotted into `best_weight_state`.
EPOCHS = 30
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}
best_weight_state = None
best_val_acc = float('-inf')

for epoch in range(1, 1 + EPOCHS):
    for images, labels in (bar := tqdm(train_data_loader)):
        # evaluate() below switches the network to eval mode, so training
        # mode has to be restored before every step.
        net.train()
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        pred = net(images)
        loss = F.binary_cross_entropy(pred, labels)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            n_correct = (pred.detach().round() == labels).sum()
            train_acc = (n_correct.float() / labels.numel()).item()
            history['train_acc'].append(train_acc)
            train_loss = loss.detach().item()
            history['train_loss'].append(train_loss)

        val_loss, val_acc = evaluate(val_data_loader)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() hands back references to the live parameters, which
            # later optimizer steps overwrite in place. Clone every tensor so
            # the snapshot really keeps the best weights.
            best_weight_state = (
                {name: tensor.detach().clone()
                 for name, tensor in net.state_dict().items()},
                {'train_acc': train_acc,
                 'val_acc': val_acc,
                 'epoch': epoch})

        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        bar.set_description(
            ' - '.join([
                f'epoch: {epoch}/{EPOCHS}',
                f'train_loss: {train_loss:.8f}',
                f'train_acc: {train_acc:.8f}',
                f'val_loss: {val_loss:.8f}',
                f'val_acc: {val_acc:.8f}',
            ])
        )
print(f'The best weight state: \n{best_weight_state[1]}')

In [None]:
# Restore the checkpoint recorded during training, then visualise the first
# 16 samples of one validation batch.
net.load_state_dict(best_weight_state[0])

val_iter = iter(val_data_loader)
images, labels = (batch.to(device) for batch in next(val_iter))
test(images[:16], labels[:16])

In [None]:
# Save the checkpoint recorded during training. It was taken from the
# DataParallel-wrapped `net`, so its keys presumably carry the `module.`
# prefix — confirm before loading it into a bare UNet.
torch.save(best_weight_state[0], 'best_weights.pt')
# Save the current weights of the underlying (unwrapped) network.
torch.save(net.module.state_dict(), 'unet_128.pt')