# Lung segmentation

In [None]:
import glob
import os

import cv2
import numpy as np
import torch
from skimage import io, transform, exposure, color, morphology, measure, img_as_ubyte
from torch import nn, distributions

In [None]:
# --- Input / output locations -------------------------------------------
# Directory that is globbed for input images.
DATA_PATH = '../data/'
# Where the masked (segmented) images and the binary masks are written.
PREDICTIONS_DIR = '../data/segmented/'
MASKS_DIR = '../data/masks/'
# Serialized model that is passed to torch.load().
MODEL_PATH = './uVAE.pt'

# --- Model / pre-processing settings ------------------------------------
# HIDDEN and LATENT are not referenced by the visible code; presumably they
# record the nhid / nlatent the checkpoint was built with -- TODO confirm.
HIDDEN = 16
LATENT = 8
# Border (in pixels) left around the resized image inside the padded canvas.
PADDING = 32
# When True, predicted masks go through post_process() before un-padding.
POST_PROCESS = True

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## lungVAE implementation

To segment the lungs, we use the model introduced in [*Lung Segmentation from Chest X-rays using Variational Data Imputation*](https://arxiv.org/abs/2005.10052) by Raghavendra Selvan et al.

The full implementation is available [here](https://github.com/raghavian/lungVAE).

In [None]:
class uVAE(nn.Module):
    """U-Net style network with an optional variational (VAE) branch.

    A U-Net encoder/decoder produces the output map. Unless ``unet`` is
    True, a second encoder maps the input to the parameters of a diagonal
    Gaussian; a sample from it is expanded by 1x1 ``Conv1d`` layers and
    concatenated to the deepest U-Net feature map before decoding.

    Args:
        nlatent: Size of the latent vector. The VAE encoder emits
            ``2 * nlatent`` values that are split into mean and log-variance.
        unet: If True, the VAE branch is neither built nor used.
        nhid: Base number of feature channels; deeper stages use multiples
            of it.
        ker: Kernel size of the stand-alone first/last convolutions
            (``enc11``, ``enc12``, ``uEnc11``, ``uEnc12``, ``dec11``,
            ``dec12``). It is not forwarded to ``convBlock``, which uses
            its own default.
        inCh: Number of input channels; the output has the same number.
        h: Input height the bottleneck layers are sized for.
        w: Input width the bottleneck layers are sized for.
    """

    def __init__(self, nlatent, unet=False,
                 nhid=8, ker=3, inCh=1, h=640, w=512):
        super(uVAE, self).__init__()
        self.latent_space = nlatent
        self.unet = unet

        if not self.unet:
            # VAE encoder: two plain convolutions followed by four blocks
            # that each downsample by 2 (16x in total).
            self.enc11 = nn.Conv2d(inCh, nhid, kernel_size=ker, padding=1)
            self.enc12 = nn.Conv2d(nhid, nhid, kernel_size=ker, padding=1)

            self.enc2 = self.convBlock(nhid, 2 * nhid, 2 * nhid, pool=True)
            self.enc3 = self.convBlock(2 * nhid, 4 * nhid, 4 * nhid, pool=True)
            self.enc4 = self.convBlock(4 * nhid, 8 * nhid, 8 * nhid, pool=True)
            self.enc5 = self.convBlock(8 * nhid, 16 * nhid, 16 * nhid, pool=True)

            # Bottleneck "to latent": collapse channels to 1, then treat the
            # (h/16 * w/16) flattened positions as channels and map them to
            # 2 * nlatent values.
            self.bot11 = nn.Conv1d(16 * nhid, 1, kernel_size=1)
            self.bot12 = nn.Conv1d(int((h / 16) * (w / 16)), 2 * nlatent, kernel_size=1)

            # Bottleneck "from latent": expand the sample to (h/64 * w/64)
            # positions and then to 16 * nhid channels, matching the deepest
            # U-Net feature map.
            self.bot21 = nn.Conv1d(nlatent, int((h / 64) * (w / 64)), kernel_size=1)
            self.bot22 = nn.Conv1d(1, nhid, kernel_size=1)
            self.bot23 = nn.Conv1d(nhid, 4 * nhid, kernel_size=1)
            self.bot24 = nn.Conv1d(4 * nhid, 16 * nhid, kernel_size=1)

        # U-Net encoder: pooling factors 4, 4, 2, 2 (64x in total).
        self.uEnc11 = nn.Conv2d(inCh, nhid, kernel_size=ker, padding=1)
        self.uEnc12 = nn.Conv2d(nhid, nhid, kernel_size=ker, padding=1)

        self.uEnc2 = self.convBlock(nhid, 2 * nhid, 2 * nhid, pool=True, pooling=4)
        self.uEnc3 = self.convBlock(2 * nhid, 4 * nhid, 4 * nhid, pool=True, pooling=4)
        self.uEnc4 = self.convBlock(4 * nhid, 8 * nhid, 8 * nhid, pool=True)
        self.uEnc5 = self.convBlock(8 * nhid, 16 * nhid, 16 * nhid, pool=True)

        # The first decoder block sees twice as many channels when the
        # latent feature map is concatenated to the deepest encoder output.
        if not self.unet:
            self.dec5 = self.convBlock(32 * nhid, 8 * nhid, 8 * nhid, pool=False)
        else:
            self.dec5 = self.convBlock(16 * nhid, 8 * nhid, 8 * nhid, pool=False)

        # Remaining decoder blocks; input channels are doubled by the skip
        # connections, upsampling factors mirror the encoder (2, 4, 4).
        self.dec4 = self.convBlock(16 * nhid, 4 * nhid, 4 * nhid, pool=False)
        self.dec3 = self.convBlock(8 * nhid, 2 * nhid, 2 * nhid, pool=False, pooling=4)
        self.dec2 = self.convBlock(4 * nhid, nhid, nhid, pool=False, pooling=4)

        self.dec11 = nn.Conv2d(2 * nhid, nhid, kernel_size=ker, padding=1)
        self.dec12 = nn.Conv2d(nhid, inCh, kernel_size=ker, padding=1)

        self.act = nn.ReLU()
        # Parameters of the standard-normal prior. These are plain tensor
        # attributes placed on the module-level `device`; they are not
        # registered as buffers, so Module.to() will not move them.
        self.mu_0 = torch.zeros((1, nlatent)).to(device)
        self.sigma_0 = torch.ones((1, nlatent)).to(device)

        self.h = h
        self.w = w

    def vae_encoder(self, x):
        """Encode ``x`` into the concatenated posterior parameters.

        Returns a tensor whose dim 1 has size ``2 * nlatent``; ``forward``
        splits it into mean and log-variance.
        """
        x = self.act(self.enc11(x))
        x = self.act(self.enc12(x))
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)
        x = self.enc5(x)

        # Flatten the spatial dims, reduce channels to 1, then move the
        # flattened positions onto the channel axis for bot12.
        z = self.act(self.bot11(x.view(x.shape[0], x.shape[1], -1)))
        z = self.bot12(z.permute(0, 2, 1))

        return z.squeeze(-1)

    def unet_encoder(self, x_in):
        """Run the U-Net encoder and return all five stage outputs.

        The list is ordered from shallowest (index 0) to deepest (index -1);
        the decoder uses every entry as a skip connection.
        """
        x = []

        x.append(self.act(self.uEnc12(self.act(self.uEnc11(x_in)))))
        x.append(self.uEnc2(x[-1]))
        x.append(self.uEnc3(x[-1]))
        x.append(self.uEnc4(x[-1]))
        x.append(self.uEnc5(x[-1]))

        return x

    def decoder(self, x_enc, z=None):
        """Decode the encoder features (and latent sample) to the output map.

        Args:
            x_enc: List returned by ``unet_encoder``.
            z: Latent sample; only read when ``self.unet`` is False.

        Returns:
            The output of the final convolution; no activation is applied
            to it here.
        """
        if not self.unet:
            # Expand the latent sample into a feature map shaped like the
            # deepest encoder output.
            x = self.act(self.bot21(z.unsqueeze(2)))
            x = self.act(self.bot22(x.permute(0, 2, 1)))
            x = self.act(self.bot23(x))
            x = self.act(self.bot24(x))

            x = x.view(x.shape[0], x.shape[1],
                       int(self.h / 64), int(self.w / 64))
            x = torch.cat((x, x_enc[-1]), dim=1)
            x = self.dec5(x)
        else:
            x = self.dec5(x_enc[-1])

        # Each decoder stage is concatenated with the matching encoder
        # stage along the channel axis before the next block.
        x = torch.cat((x, x_enc[-2]), dim=1)
        x = self.dec4(x)
        x = torch.cat((x, x_enc[-3]), dim=1)
        x = self.dec3(x)
        x = torch.cat((x, x_enc[-4]), dim=1)
        x = self.dec2(x)
        x = torch.cat((x, x_enc[-5]), dim=1)

        x = self.act(self.dec11(x))
        x = self.dec12(x)

        return x

    def forward(self, x):
        """Return ``(kl, xHat)`` for input ``x``.

        ``kl`` is the KL divergence between the posterior and the
        standard-normal prior, summed over the batch (a zero tensor when
        ``self.unet`` is True). ``xHat`` is the decoder output.
        """
        # Defaults used when the VAE branch is disabled.
        kl = torch.zeros(1).to(device)
        z = 0.
        x_enc = self.unet_encoder(x)

        if not self.unet:
            emb = self.vae_encoder(x)
            mu, log_var = torch.chunk(emb, 2, dim=1)
            # NOTE(review): softplus makes log_var strictly positive, so
            # sigma below is always > 1 -- confirm this is intended.
            log_var = nn.functional.softplus(log_var)
            sigma = torch.exp(log_var / 2)
            posterior = distributions.Independent(
                distributions.Normal(loc=mu, scale=sigma), 1)
            # rsample() draws a reparameterized sample, so the output is
            # stochastic on every call.
            z = posterior.rsample()
            prior = distributions.Independent(
                distributions.Normal(loc=self.mu_0, scale=self.sigma_0), 1)
            kl = distributions.kl.kl_divergence(posterior, prior).sum()
        xHat = self.decoder(x_enc, z)

        return kl, xHat

    class convBlock(nn.Module):
        """Rescale, batch-normalize, then apply two conv + ReLU layers.

        Args:
            inCh: Input channels (also the size of the batch norm).
            nhid: Output channels of the first convolution.
            nOp: Output channels of the second convolution.
            pool: If True the block downsamples with average pooling,
                otherwise it upsamples.
            ker: Kernel size of both convolutions.
            padding: Accepted but not used; both convolutions hard-code
                ``padding=1``.
            pooling: Pooling kernel size or upsampling scale factor.
        """

        def __init__(self, inCh, nhid, nOp, pool=True, ker=3, padding=1, pooling=2):
            super().__init__()
            self.enc1 = nn.Conv2d(inCh, nhid, kernel_size=ker, padding=1)
            self.enc2 = nn.Conv2d(nhid, nOp, kernel_size=ker, padding=1)
            self.bn = nn.BatchNorm2d(inCh)

            if pool:
                self.scale = nn.AvgPool2d(kernel_size=pooling)
            else:
                self.scale = nn.Upsample(scale_factor=pooling)
            # Stored for reference; forward() does not read it.
            self.pool = pool
            self.act = nn.ReLU()

        def forward(self, x):
            """Apply rescale -> batch norm -> conv/ReLU -> conv/ReLU."""
            x = self.scale(x)
            x = self.bn(x)
            x = self.act(self.enc1(x))
            x = self.act(self.enc2(x))
            return x

In [None]:
def largestCC(lImg, num=2):
    """Return a boolean mask of the ``num`` largest labelled components.

    Args:
        lImg: Integer label image in which 0 is the background and every
            positive value identifies one connected component.
        num: Maximum number of components to keep.

    Returns:
        Boolean array with the shape of ``lImg`` that is True on the pixels
        belonging to the (at most) ``num`` largest components. If the image
        holds fewer components than ``num``, only the existing ones are
        kept; if it holds none, the mask is all False.
    """
    # minlength=1 guarantees a background slot even for an empty image.
    counts = np.bincount(np.asarray(lImg).ravel(), minlength=1)
    counts[0] = 0  # the background must never be selected

    lcc = np.zeros(lImg.shape, dtype=bool)

    # Never ask for more components than actually exist: once every real
    # component has been consumed argmax would fall back to index 0 and
    # the whole background would be added to the mask.
    available = int(np.count_nonzero(counts))
    for _ in range(min(num, available)):
        label = np.argmax(counts)
        counts[label] = 0
        lcc |= (lImg == label)

    return lcc


def post_process(img, s=11):
    """Clean up a predicted probability map.

    The map is thresholded at 0.5, reduced to its largest connected
    components (via ``largestCC`` with its default count) and morphologically
    closed with a disk of radius ``s``.

    Args:
        img: Array of predicted values; if it has more than two dimensions
            only the last slice along the third axis is used.
        s: Radius of the disk-shaped structuring element.

    Returns:
        Float array holding the cleaned binary mask.
    """
    binary = img > 0.5
    if binary.ndim > 2:
        binary = binary[:, :, -1]

    labelled = measure.label(binary)
    kept = largestCC(labelled)

    closed = morphology.binary_closing(kept, morphology.disk(s))
    return closed.astype(float)

In [None]:
def pad_img(img):
    """Resize, equalize and centre an image on the 640x512 network canvas.

    The image is scaled (keeping its aspect ratio) to fit a 576x448 area,
    histogram-equalized, and pasted centred into a zero canvas of 640x512,
    which leaves a border of ``PADDING`` pixels when ``PADDING`` is 32.

    Args:
        img: 2-D image array (the resized image is assigned into a 2-D
            canvas, so multi-channel input is not supported).

    Returns:
        Tuple ``(padded_img, roi, padding_h, padding_w, temp_height,
        temp_width, img_height, img_width)`` where ``padded_img`` is a
        ``(1, 1, 640, 512)`` tensor, ``roi`` is a ``(640, 512)`` tensor that
        is 1 where image content was placed, ``padding_h`` / ``padding_w``
        are the top/left offsets of the content, ``temp_*`` its size on the
        canvas and ``img_*`` the original image size.
    """
    img_height = img.shape[0]
    img_width = img.shape[1]

    # Fit to the full width first; if that makes the image too tall, fit
    # to the full height instead.
    temp_width = 448
    temp_height = int(img_height / (img_width / temp_width))
    if temp_height > 576:
        temp_height = 576
        temp_width = int(img_width / (img_height / temp_height))

    img = transform.resize(img, (temp_height, temp_width))
    img = exposure.equalize_hist(img)
    img = torch.Tensor(img)

    padded_img = torch.zeros((640, 512))
    roi = torch.zeros(padded_img.shape)
    padding_h = int((576 - temp_height) / 2) + PADDING
    padding_w = int((448 - temp_width) / 2) + PADDING

    # Address the target window by its real size on both axes. Choosing a
    # branch on `padding_w == PADDING` breaks when the rescaled width is
    # 447: the offset still rounds to PADDING, but a fixed 448-column slice
    # no longer matches the image and the assignment raises.
    rows = slice(padding_h, padding_h + temp_height)
    cols = slice(padding_w, padding_w + temp_width)
    padded_img[rows, cols] = img
    roi[rows, cols] = 1.0

    padded_img = padded_img.unsqueeze(0).unsqueeze(0)
    return padded_img, roi, padding_h, padding_w, temp_height, temp_width, img_height, img_width

In [None]:
def unpad_mask(padded_mask, padding_h, padding_w, temp_height, temp_width, img_height, img_width, apply_post_process=False):
    """Crop a padded prediction back to the image and restore its size.

    Args:
        padded_mask: 2-D tensor of predicted values on the padded canvas.
        padding_h: Row offset of the image content on the canvas.
        padding_w: Column offset of the image content on the canvas.
        temp_height: Height of the image content on the canvas.
        temp_width: Width of the image content on the canvas.
        img_height: Height of the original image.
        img_width: Width of the original image.
        apply_post_process: If True, run ``post_process`` on the padded
            prediction before cropping.

    Returns:
        ``uint8`` mask of shape ``(img_height, img_width)`` obtained by
        thresholding the resized prediction at 0.5.
    """
    # detach() + cpu() so tensors that live on the GPU or still carry
    # autograd history can be converted; .data.numpy() fails on CUDA.
    padded_mask = padded_mask.detach().cpu().numpy()

    if apply_post_process:
        padded_mask = post_process(padded_mask)

    # Crop exactly the window the image occupies. Branching on
    # `padding_w == PADDING` and cutting a fixed PADDING border takes one
    # extra padded column when the content is 447 pixels wide.
    mask = padded_mask[padding_h:padding_h + temp_height,
                       padding_w:padding_w + temp_width]
    mask = transform.resize(image=mask,
                            output_shape=(img_height, img_width),
                            preserve_range=True)

    mask = img_as_ubyte(mask > 0.5)

    return mask

In [None]:
def get_mask(model, img):
    """Predict the lung mask of ``img`` with ``model``.

    Args:
        model: Callable network returning ``(kl, logits)`` for a padded
            input tensor.
        img: 2-D image array as accepted by ``pad_img``.

    Returns:
        ``uint8`` mask with the height and width of ``img`` (see
        ``unpad_mask``); ``POST_PROCESS`` controls whether the prediction
        is cleaned up first.
    """
    padded_img, roi, p_h, p_w, temp_h, temp_w, original_h, original_w = pad_img(img)

    padded_img = padded_img.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        _, logits = model(padded_img)
        # pad_img creates `roi` on the CPU; it must sit on the same device
        # as the model output or the multiplication fails on the GPU.
        roi = roi.to(logits.device)
        padded_mask = torch.sigmoid(logits * roi)

    # Hand a CPU tensor to unpad_mask, which converts it to NumPy.
    mask = unpad_mask(padded_mask.squeeze().cpu(), p_h, p_w, temp_h, temp_w,
                      original_h, original_w, POST_PROCESS)

    return mask

In [None]:
def load_img(path):
    """Read an image and scale it so that its maximum value is 1.

    Args:
        path: Location of the image file.

    Returns:
        Float array with the image divided by its maximum. An image that is
        entirely zero is returned as float zeros instead of being divided
        by zero (which would fill it with NaNs).
    """
    img = io.imread(path)

    # NOTE(review): the rest of the pipeline places the image on a 2-D
    # canvas, so it presumably expects a single-channel image; an RGB(A)
    # file would need e.g. color.rgb2gray(img[:, :, :3]) first -- confirm.
    peak = img.max()
    if peak == 0:
        return img.astype(float)

    return img / peak

### Apply CLAHE and overlay mask

In [None]:
def overlay_mask(img, mask, histogram_equalization=True):
    """Keep only the pixels of ``img`` that are selected by ``mask``.

    Args:
        img: Image that is multiplied by 255 and cast to ``uint8``.
        mask: Array combined with the image by a bitwise AND.
        histogram_equalization: Accepted but not used by this definition.

    Returns:
        The bitwise AND of the 8-bit image and ``mask``.
    """
    img_u8 = (img * 255).astype(np.uint8)
    return cv2.bitwise_and(img_u8, mask)

In [None]:
# Shared CLAHE (contrast-limited adaptive histogram equalization) operator.
clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))


def overlay_mask(img, mask, histogram_equalization=True):
    """Mask ``img`` with ``mask``, optionally applying CLAHE first.

    This definition replaces the earlier ``overlay_mask`` in this file, so
    it keeps that signature: callers that pass ``histogram_equalization``
    keep working, and the default (True) applies CLAHE.

    Args:
        img: Image that is multiplied by 255 and cast to ``uint8``.
        mask: Array combined with the image by a bitwise AND.
        histogram_equalization: If True, apply the module-level ``clahe``
            operator to the 8-bit image before masking.

    Returns:
        The bitwise AND of the (optionally equalized) 8-bit image and
        ``mask``.
    """
    # NOTE(review): clahe.apply presumably needs a single-channel image;
    # convert colour input to grayscale before calling -- confirm.
    img = (img * 255).astype(np.uint8)

    if histogram_equalization:
        img = clahe.apply(img)

    return cv2.bitwise_and(img, mask)

# Load model and segment images

In [None]:
# Load the serialized model onto the device chosen above. Without
# map_location a checkpoint saved from a GPU cannot be loaded on a
# CPU-only machine.
# SECURITY: torch.load unpickles the file and can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled module; pass weights_only=False there if needed.
# NOTE(review): the module keeps whatever train/eval mode it was saved in;
# consider net.eval() if batch norm should use running statistics -- confirm.
net = torch.load(MODEL_PATH, map_location=device)

In [None]:
# Collect the input images in a stable order. Only regular files are kept:
# the output folders (PREDICTIONS_DIR, MASKS_DIR) live inside DATA_PATH and
# would otherwise be matched by the glob and fail when read as images.
images = sorted(
    path for path in glob.glob(f'{DATA_PATH}*')
    if os.path.isfile(path)
)

In [None]:
# Make sure the output folders exist before anything is written to them.
os.makedirs(PREDICTIONS_DIR, exist_ok=True)
os.makedirs(MASKS_DIR, exist_ok=True)

# Segment every input image and store both the masked image and the mask.
for idx, img_path in enumerate(images):
    # Skip sub-directories (e.g. the output folders inside DATA_PATH).
    if not os.path.isfile(img_path):
        continue

    # Strip only the directory and the final extension, so names that
    # contain dots ("scan.1.png", "scan.2.png") stay distinct and do not
    # overwrite each other's results.
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    pred_path = f'{PREDICTIONS_DIR}{img_name}_pred.png'
    mask_path = f'{MASKS_DIR}{img_name}_mask.png'

    img = load_img(img_path)

    mask = get_mask(net, img)
    pred = overlay_mask(img, mask)

    io.imsave(pred_path, pred)
    io.imsave(mask_path, mask)

    print(f"Segmenting {idx + 1}/{len(images)}: {img_name}")