In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import cv2
from PIL import Image
import random

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
import torchvision.transforms.functional as TF
from torchvision import transforms
import albumentations as A

In [None]:
root_path = '../input/chest-xray-masks-and-labels/Lung Segmentation/'
img_dir = root_path + 'CXR_png/'
mask_dir = root_path + 'masks/'

# Seed every RNG the notebook draws from (`random` for augmentation, torch for weight
# init / dropout, numpy for completeness) so a Restart & Run All is as reproducible
# as the hardware allows.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# os.listdir returns names in arbitrary, filesystem-dependent order; sort them so the
# train/val split in LungDataset (train_test_split with random_state=42) selects the
# same files on every machine.
fname_imgs = sorted(os.listdir(img_dir))
fname_masks = sorted(os.listdir(mask_dir))
print("Images:", len(fname_imgs), "\nMasks:", len(fname_masks))

In [None]:
# Peek at one X-ray to check its resolution. Reuse img_dir instead of repeating the
# path, and use `with` because Image.open is lazy and keeps the file open until closed.
with Image.open(img_dir + 'CHNCXR_0001_0.png') as img:
    print(img.size)

# Creating our custom Dataset and DataLoader

In [None]:
class LungDataset(Dataset):
    """Paired (chest X-ray, lung mask) dataset with an internal 80/20 train/val split.

    ``mask_list`` holds mask file names found in ``mask_dir``. The matching X-ray in
    ``img_dir`` is the mask name with ``'_mask.png'`` replaced by ``'.png'``; names that
    do not contain ``'mask'`` are used unchanged for both files.
    Both splits come from ``train_test_split(mask_list, test_size=0.2, random_state=42)``,
    so a train instance and a val instance built from the same list do not overlap.

    Args:
        img_dir: directory containing the X-ray PNGs.
        mask_dir: directory containing the mask PNGs.
        mask_list: mask file names to split.
        train: if True, serve the training split with random flip/rotation
            augmentation; otherwise serve the validation split unaugmented.
        tfms: optional dict with callables under ``'img'`` and ``'mask'``, applied
            after augmentation.
    """

    def __init__(self, img_dir, mask_dir, mask_list, train=True, tfms=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.mask_list = mask_list
        self.train = train
        self.tfms = tfms

        self.train_mask_list, self.val_mask_list = train_test_split(self.mask_list, test_size=0.2, random_state=42)

    def __len__(self):
        return len(self.train_mask_list if self.train else self.val_mask_list)

    def _augment(self, img, mask):
        """Apply the same random h-flip, v-flip and rotation to both image and mask."""
        if random.random() > 0.5:
            img, mask = TF.hflip(img), TF.hflip(mask)
        if random.random() > 0.5:
            img, mask = TF.vflip(img), TF.vflip(mask)
        if random.random() > 0.8:
            # one of the multiples of 15 degrees in [-90, 90]
            angle = random.choice(range(-90, 105, 15))
            img, mask = TF.rotate(img, angle), TF.rotate(mask, angle)
        return img, mask

    def __getitem__(self, idx):
        """Return ``{'image': img, 'mask': mask}`` for sample ``idx`` of the active split.

        Always returns the dict. The previous version returned only inside the
        ``tfms is not None`` branch, so ``tfms=None`` yielded ``None``. Without
        ``tfms`` the values are grayscale (mode ``'L'``) PIL images.
        """
        mask_name = (self.train_mask_list if self.train else self.val_mask_list)[idx]
        img_name = mask_name.replace('_mask.png', '.png') if 'mask' in mask_name else mask_name

        img = Image.open(os.path.join(self.img_dir, img_name)).convert('L')
        mask = Image.open(os.path.join(self.mask_dir, mask_name)).convert('L')

        if self.train:
            img, mask = self._augment(img, mask)

        if self.tfms is not None:
            img = self.tfms['img'](img)
            mask = self.tfms['mask'](mask)

        return {'image': img, 'mask': mask}

In [None]:
tfms_img = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
])

# Masks are label images: resize with nearest-neighbour so pixels keep their original
# values instead of picking up interpolated grey levels along the lung boundary
# (Resize defaults to bilinear).
tfms_mask = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

tfms = {'img': tfms_img, 'mask': tfms_mask}

train_dataset = LungDataset(img_dir, mask_dir, fname_masks, train=True, tfms=tfms)
val_dataset = LungDataset(img_dir, mask_dir, fname_masks, train=False, tfms=tfms)

In [None]:
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=2)
# No shuffling for validation: save_img takes the first val batch every epoch, so a
# fixed order means the saved per-epoch grids show the same X-rays and are comparable.
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
def imshow(img, mask=False):
    """Turn a (C, H, W) tensor into a clipped (H, W, 3) numpy array for plt.imshow.

    Despite its name, ``mask=True`` applies ``x * 0.5 + 0.5``, the inverse of the
    ``Normalize((0.5,), (0.5,))`` in ``tfms_img``. That makes it the right setting
    for normalized X-ray images. With the default ``mask=False`` the values are only
    broadcast to three channels (for single-channel input), which suits the [0, 1]
    masks produced by ``tfms_mask``. The result is clipped to [0, 1], and the
    transposed shape is printed as a debugging aid.
    """
    arr = img.cpu().clone().detach().numpy().transpose(1, 2, 0)
    print(arr.shape)

    ones = np.ones(3)
    if not mask:
        arr = arr * ones
    else:
        arr = arr * (0.5 * ones) + 0.5 * ones

    return arr.clip(0, 1)

In [None]:
n_train_samples = len(train_dataset)  # size of the 80% training split
print(n_train_samples)

In [None]:
# Preview the first training batch: X-rays on the top row, their masks below.
fig = plt.figure(figsize=(15, 6))
sample_batched = next(iter(train_loader))
print(sample_batched['image'].size(), sample_batched['mask'].size())

n_show = min(2, sample_batched['image'].size(0))
for index in range(n_show):
    fig.add_subplot(2, 2, index + 1)  # subplot index starts from 1
    # imshow's mask=True branch undoes Normalize((0.5,), (0.5,)), which the X-rays need;
    # without it every pixel darker than mid-grey was clipped to black.
    plt.imshow(imshow(sample_batched['image'][index], mask=True))
    fig.add_subplot(2, 2, index + 3)
    # masks come straight from ToTensor in [0, 1], so no de-normalization
    plt.imshow(imshow(sample_batched['mask'][index]))
plt.show()

# Pix2Pix architecture

In [None]:
class DownConv(nn.Module):
    """Encoder block: Conv2d, then optional BatchNorm2d, then optional LeakyReLU(0.2).

    The defaults (kernel=4, strides=2, padding=1) follow the pix2pix paper's 4x4,
    stride-2 convolutions and halve an even spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        super().__init__()
        self.activation, self.batchnorm = activation, batchnorm

        # Submodule names (conv / bn / relu) are part of the state_dict keys.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)
        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out) if self.batchnorm else out
        return self.relu(out) if self.activation else out

In [None]:
class UpConv(nn.Module):
    """Decoder block: ConvTranspose2d, then optional BatchNorm2d, ReLU and Dropout2d(0.5).

    The defaults (kernel=4, strides=2, padding=1) double the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True, dropout=False):
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        if self.batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if self.activation:
            self.relu = nn.ReLU(True)

        if self.dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)
        # Test the flag, not `self.relu`: that attribute exists only when
        # activation=True, so `if self.relu:` raised AttributeError otherwise.
        if self.activation:
            x = self.relu(x)
        if self.dropout:
            x = self.drop(x)
        return x

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator over an (X-ray, mask) pair.

    ``x`` and ``y`` are concatenated along the channel dimension, so ``input_channels``
    must equal their combined channel count (2 in this notebook: one-channel image plus
    one-channel mask). Four DownConv blocks (64-128-256-512 channels, no BatchNorm on
    the first) are followed by a 1x1 convolution to a single-channel map. No sigmoid is
    applied, so the output is raw logits (the notebook pairs it with BCEWithLogitsLoss).
    """

    def __init__(self, input_channels):
        super().__init__()
        # Attribute names d1..d4 / final are part of the state_dict keys.
        self.d1 = DownConv(input_channels, 64, batchnorm=False)
        self.d2 = DownConv(64, 128)
        self.d3 = DownConv(128, 256)
        self.d4 = DownConv(256, 512)
        self.final = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x, y):
        out = torch.cat((x, y), dim=1)
        for block in (self.d1, self.d2, self.d3, self.d4):
            out = block(out)
        return self.final(out)

In [None]:
class Generator(nn.Module):
    """U-Net generator: eight DownConv encoders, seven UpConv decoders with skips, and a final
    ConvTranspose2d followed by tanh.

    Each encoder halves the spatial size (stride 2), so the input height and width should be
    multiples of 256 (e.g. 256x256) for the skip concatenations to line up.
    Decoder i's output is concatenated with the mirrored encoder output. That includes the
    outermost encoder (64 channels), so ``final_conv`` receives 64 + 64 = 128 channels.
    NOTE: checkpoints saved with the earlier 64-input ``final_conv`` (which dropped the
    outermost skip) will not load into this layout.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        self.encoders = [
            DownConv(input_channels, 64, batchnorm=False), #batch_size x 64 x 128 x 128
            DownConv(64, 128),
            DownConv(128, 256),
            DownConv(256, 512),
            DownConv(512, 512),
            DownConv(512, 512),
            DownConv(512, 512),
            DownConv(512, 512, batchnorm=False)
        ]

        # in_channels after the first decoder = previous decoder output + skip channels
        self.decoders = [
            UpConv(512, 512, dropout=True),
            UpConv(1024, 512, dropout=True),
            UpConv(1024, 512, dropout=True),
            UpConv(1024, 512),
            UpConv(1024, 256),
            UpConv(512, 128),
            UpConv(256, 64)
        ]

        # Output channels of each decoder (informational; forward does not read it).
        self.dec_channels = [512, 512, 512, 512, 256, 128, 64]
        # 64 from the last decoder + 64 from the outermost encoder skip
        self.final_conv = nn.ConvTranspose2d(128, output_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

        self.encoders = nn.ModuleList(self.encoders)
        self.decoders = nn.ModuleList(self.decoders)

    def forward(self, x):
        skip_conns = []
        for encoder in self.encoders:
            x = encoder(x)
            skip_conns.append(x)

        # The bottleneck (last encoder output) feeds the first decoder; the other seven
        # encoder outputs are the skips, deepest first.
        skip_conns = list(reversed(skip_conns[:-1]))

        # Seven decoders pair one-to-one with seven skips. The earlier version zipped
        # decoders[:-1] with these skips, so zip() silently discarded the outermost skip.
        for decoder, skip in zip(self.decoders, skip_conns):
            x = torch.cat((decoder(x), skip), dim=1)

        x = self.final_conv(x)
        return self.tanh(x)

In [None]:
def sanity_check():
    """Shape smoke test on random input.

    The generator must map a (1, 1, 256, 256) tensor to the same shape. The
    discriminator must accept the (image, mask) pair; ``y`` was built for that purpose
    but was previously never used. No gradients are needed, so everything runs under
    ``torch.no_grad()``.
    """
    x = torch.randn((1, 1, 256, 256))
    y = torch.randn((1, 1, 256, 256))
    with torch.no_grad():
        gen1 = Generator(1, 1)
        pred_gen = gen1(x)
        print("Generator Output", pred_gen.shape)
        assert pred_gen.shape == x.shape, "generator must preserve the input shape"

        disc1 = Discriminator(2)
        pred_disc = disc1(x, y)
        print("Discriminator Output", pred_disc.shape)

sanity_check()

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Adam settings shared by both networks
LR = 0.0001
BETAS = (0.5, 0.99)

# Construction order (discriminator first) is kept so seeded weight init is unchanged.
disc = Discriminator(input_channels=2).to(device)                 # X-ray + mask channels
gen = Generator(input_channels=1, output_channels=1).to(device)

disc_opt = optim.Adam(disc.parameters(), lr=LR, betas=BETAS)
gen_opt = optim.Adam(gen.parameters(), lr=LR, betas=BETAS)

BCE = nn.BCEWithLogitsLoss()  # discriminator outputs raw logits
L1 = nn.L1Loss()

In [None]:
def train_disc(disc, gen, imgs, masks, loss, disc_opt):
    """Run one discriminator update and return the generated masks.

    Args:
        disc: conditional discriminator, called as ``disc(imgs, masks)`` and returning logits.
        gen: generator mapping ``imgs`` to predicted masks.
        imgs: input X-ray batch.
        masks: ground-truth masks for ``imgs``.
        loss: criterion applied to logits, e.g. ``nn.BCEWithLogitsLoss()``.
        disc_opt: optimizer over the discriminator's parameters.

    Returns:
        ``(fake_masks, disc_loss)``. ``fake_masks`` still carries the generator's autograd
        graph so the generator step can backpropagate through it.
    """
    disc.train()

    fake_masks = gen(imgs)
    # real pair: X-ray + ground-truth mask; fake pair: X-ray + generated mask.
    # Detach the fake masks so this update does not send gradients into the generator.
    disc_real = disc(imgs, masks)
    disc_fake = disc(imgs, fake_masks.detach())

    # Use the `loss` argument (not a global) and score the *fake* logits against 0.
    # The earlier version passed disc_real here, so fakes were never penalized.
    disc_real_loss = loss(disc_real, torch.ones_like(disc_real))
    disc_fake_loss = loss(disc_fake, torch.zeros_like(disc_fake))
    disc_loss = (disc_real_loss + disc_fake_loss) / 2

    disc.zero_grad()
    disc_loss.backward()
    disc_opt.step()

    return fake_masks, disc_loss

In [None]:
def train_gen(disc, gen, imgs, masks, fake_masks, loss_bce, loss_l1, gen_opt, l1_lambda=100):
    """Run one generator update: adversarial loss plus ``l1_lambda``-weighted L1 loss.

    Args:
        disc: conditional discriminator returning logits for ``(imgs, masks)`` pairs.
        gen: generator being trained. ``fake_masks`` must be its output for ``imgs``,
            with the autograd graph still attached.
        imgs: input X-ray batch.
        masks: ground-truth masks.
        fake_masks: generator output for ``imgs`` (as returned by ``train_disc``).
        loss_bce: adversarial criterion on logits, e.g. ``nn.BCEWithLogitsLoss()``.
        loss_l1: reconstruction criterion, e.g. ``nn.L1Loss()``.
        gen_opt: optimizer over the generator's parameters.
        l1_lambda: weight of the L1 term (100, as previously hard-coded).

    Returns:
        The total generator loss tensor.
    """
    gen.train()

    # Do NOT detach: the adversarial term must backpropagate through the discriminator
    # into the generator. Detaching left the generator trained by L1 alone.
    # (backward also fills disc's .grad; train_disc zeroes those before its own step.)
    disc_fake = disc(imgs, fake_masks)

    gen_fake_loss = loss_bce(disc_fake, torch.ones_like(disc_fake))
    l1 = loss_l1(fake_masks, masks) * l1_lambda
    gen_loss = gen_fake_loss + l1

    gen_opt.zero_grad()
    gen_loss.backward()
    gen_opt.step()

    return gen_loss

In [None]:
def save_img(gen, valid_dataloader, device, epoch_num, dir_path):
    """Save grids of generated and ground-truth masks for the first validation batch.

    Writes ``fake_masks_{epoch_num}.png`` and ``real_masks_{epoch_num}.png`` into
    ``dir_path``, creating the directory if needed. The generator runs in eval mode
    under ``torch.no_grad()``, and its previous train/eval mode is restored afterwards
    so the caller's training loop is not left with dropout/BatchNorm in eval mode.
    """
    os.makedirs(dir_path, exist_ok=True)

    sample = next(iter(valid_dataloader))
    imgs = sample['image'].to(device)
    masks = sample['mask'].to(device)

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            fake_masks = gen(imgs)
    finally:
        gen.train(was_training)

    fake_mask_grid = make_grid(fake_masks, nrow=4)
    real_mask_grid = make_grid(masks, nrow=4)

    save_image(fake_mask_grid, os.path.join(dir_path, f'fake_masks_{epoch_num}.png'))
    save_image(real_mask_grid, os.path.join(dir_path, f'real_masks_{epoch_num}.png'))
    print("Saved intermediate images\n")

In [None]:
def train(gen,
          disc,
          train_dataloader,
          val_dataloader,
          loss_bce,
          loss_l1,
          num_epochs,
          gen_opt,
          disc_opt,
          device,
          dir_path,
          log_every=20
          ):
    """Alternate discriminator and generator updates for ``num_epochs`` epochs.

    Every ``log_every``-th step of each epoch (steps 0, log_every, 2*log_every, ...)
    records both losses and prints a progress line. Once per epoch, sample predictions
    on a validation batch are written to ``dir_path`` via ``save_img``.

    Returns:
        ``{"disc_loss": [...], "gen_loss": [...]}`` holding the logged loss values.
    """
    gen_losses = []
    disc_losses = []

    num_steps = len(train_dataloader)

    for epoch in range(num_epochs):
        for i, sample in enumerate(train_dataloader):
            imgs = sample['image'].to(device)
            masks = sample['mask'].to(device)

            fake_masks, disc_loss = train_disc(disc, gen, imgs, masks, loss_bce, disc_opt)
            gen_loss = train_gen(disc, gen, imgs, masks, fake_masks, loss_bce, loss_l1, gen_opt)

            if i % log_every == 0:
                disc_losses.append(disc_loss.item())
                gen_losses.append(gen_loss.item())

                print('Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(epoch, num_epochs, i, num_steps, disc_loss.item(), gen_loss.item()))

        # Once per epoch. This call used to sit inside the step loop, which ran the
        # generator on a fresh validation batch and wrote two PNGs after every step.
        save_img(gen, val_dataloader, device, epoch, dir_path)

    return {"disc_loss": disc_losses, "gen_loss": gen_losses}

In [None]:
# 100 epochs; per-epoch sample grids go to ./results
history = train(
    gen=gen,
    disc=disc,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_bce=BCE,
    loss_l1=L1,
    num_epochs=100,
    gen_opt=gen_opt,
    disc_opt=disc_opt,
    device=device,
    dir_path='results',
)

In [None]:
# Save next to the notebook. '/checkpoints' is at the filesystem root, which usually
# does not exist or is not writable, and torch.save does not create directories.
CHECKPOINT_DIR = 'checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
torch.save(gen.state_dict(), os.path.join(CHECKPOINT_DIR, 'gen_lung2mask_1.pth'))
torch.save(disc.state_dict(), os.path.join(CHECKPOINT_DIR, 'disc_lung2mask_1.pth'))

In [None]:
# train() records one value at steps 0, 20, 40, ... of each epoch (not once per
# epoch), so the x-axis counts logged steps rather than epochs.
fig, ax = plt.subplots()
ax.plot(history['disc_loss'], '-', label='Discriminator Loss')
ax.plot(history['gen_loss'], '-', label='Generator Loss')
ax.set_xlabel('Logged step (every 20 iterations)')
ax.set_ylabel('Loss')
ax.set_title('Loss during training')
ax.legend()
plt.show()