## Image Segmentation for Medical diagnosis using PyTorch

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import glob
import json
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
import cv2
from torch.utils.data import random_split
from torchvision.transforms import transforms
from torchinfo import summary

### Downlaod the CVC clinic dataset and extract

In [2]:
# !curl -L -o ~/Downloads/cvcclinicdb.zip\
#   https://www.kaggle.com/api/v1/datasets/download/balraj98/cvcclinicdb
# !mkdir data
# !unzip ~/Downloads/cvcclinicdb.zip -d data

In [3]:
with open("./config.json","r") as f:
    config = json.load(f)
config
image_path = config['data_path']+"Original/"
mask_path  = config['data_path']+"Ground Truth/"

In [4]:
images = glob.glob(image_path+"*"+config['file_extn'])
masks  = glob.glob(mask_path+"*"+config['file_extn'])

## Create CVC custom dataset and pre process

In [5]:
"""
Make CVC dataset
"""

class CVCDataset(Dataset):
    def __init__(self, images, transform=None):
        super(CVCDataset, self).__init__()
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        img_split = img_path.split('/')
        img_name = img_split[-1]
        img_dir = '/'.join(img_split[:-2])
        mask_path = img_dir+"/Ground Truth/"+img_name
        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask


### Create train and test set

In [6]:
image_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    #transforms.Normalize()
])

mask_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

imgset = CVCDataset(images, image_transforms)
trainset, testset = random_split(imgset, [0.9, 0.1])
batch_size = 16
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True, drop_last=True)

In [7]:
""" Parts of the U-Net model """

class DoubleConv(nn.Module):
    """(conv => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(DoubleConv, self).__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """ Downscaling with maxpool """

    def __init__(self):
        super(Down,self).__init__()
        self.down = nn.MaxPool2d(2)
    
    def forward(self, x):
        return self.down(x)


class DownDoubleConv(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super(DownDoubleConv, self).__init__()
        self.down_doubleconv = nn.Sequential(
            Down(),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.down_doubleconv(x)


class UpDoubleConv(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(UpDoubleConv, self).__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.double_conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is BCHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.double_conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [8]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False, checkpointing=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.checkpointing = checkpointing

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (DownDoubleConv(64, 128))
        self.down2 = (DownDoubleConv(128, 256))
        self.down3 = (DownDoubleConv(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (DownDoubleConv(512, 1024 // factor))
        self.up1 = (UpDoubleConv(1024, 512 // factor, bilinear))
        self.up2 = (UpDoubleConv(512, 256 // factor, bilinear))
        self.up3 = (UpDoubleConv(256, 128 // factor, bilinear))
        self.up4 = (UpDoubleConv(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        if self.checkpointing:
            x1 = checkpoint(self.inc, x)
            x2 = checkpoint(self.down1, x1)
            x3 = checkpoint(self.down2, x2)
            x4 = checkpoint(self.down3, x3)
            x5 = checkpoint(self.down4, x4)
            x = checkpoint(self.up1, x5,x4)
            x = checkpoint(self.up2, x,x3)
            x = checkpoint(self.up3, x,x2)
            x = checkpoint(self.up4, x,x1)
            logits = checkpoint(self.outc, x)
            return logits

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


In [9]:
# test model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_input = torch.randn((64, 3, 256, 256)).to(device)
test_model = UNet(3,1,checkpointing=True).to(device)
print(summary(test_model))
torch.cuda.reset_peak_memory_stats()
test_out = test_model(test_input)
print(f"Peak Memory Usage With Checkpointing: {torch.cuda.max_memory_allocated()/1e6} MB")
print(test_out.shape)

Layer (type:depth-idx)                        Param #
UNet                                          --
├─DoubleConv: 1-1                             --
│    └─Sequential: 2-1                        --
│    │    └─Conv2d: 3-1                       1,728
│    │    └─BatchNorm2d: 3-2                  128
│    │    └─ReLU: 3-3                         --
│    │    └─Conv2d: 3-4                       36,864
│    │    └─BatchNorm2d: 3-5                  128
│    │    └─ReLU: 3-6                         --
├─DownDoubleConv: 1-2                         --
│    └─Sequential: 2-2                        --
│    │    └─Down: 3-7                         --
│    │    └─DoubleConv: 3-8                   221,696
├─DownDoubleConv: 1-3                         --
│    └─Sequential: 2-3                        --
│    │    └─Down: 3-9                         --
│    │    └─DoubleConv: 3-10                  885,760
├─DownDoubleConv: 1-4                         --
│    └─Sequential: 2-4                       

  return fn(*args, **kwargs)


Peak Memory Usage With Checkpointing: 9236.502016 MB
torch.Size([64, 1, 256, 256])


In [None]:
# BCHW
im1 = torch.randn(64,3,10,14)
im1 = F.pad(im1, [1,1,2,2])
im1.shape