In [None]:
# Python Modules
import numpy as np
import os
import cv2
import math
import matplotlib.pyplot as plt
import time
import datetime
import random

from tqdm import trange
from tqdm.notebook import tqdm

from time import sleep
from datetime import datetime

import glob

from PIL import Image

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data

from skimage.util import random_noise
from skimage.filters import gaussian
from skimage.metrics import (adapted_rand_error, variation_of_information)

In [None]:
# define custom dataset for data loader

class CustomDataset(data.Dataset):
    """Paired dataset of (U-Net output, ground-truth target) ``.npy`` arrays.

    Sample ``i`` is loaded from ``unet_output_paths[i]`` and ``target_paths[i]``, so the
    two lists must have the same length and matching order.

    Each array goes through ``ToTensor`` then ``Normalize(mean=0.5, std=0.5)``, i.e.
    ``(x - 0.5) / 0.5``. NOTE(review): ToTensor only rescales uint8 input to [0, 1];
    float arrays are used as-is, so a [-1, 1] result presumably relies on the saved
    arrays already lying in [0, 1] — confirm.

    Args:
        unet_output_paths: paths to the input (U-Net output) ``.npy`` files.
        target_paths: paths to the target ``.npy`` files, aligned with the inputs.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, unet_output_paths, target_paths):
        # Fail fast: index-based pairing silently misaligns samples (or raises
        # IndexError mid-epoch) when the lists differ in length.
        if len(unet_output_paths) != len(target_paths):
            raise ValueError(
                f"got {len(unet_output_paths)} U-Net outputs but "
                f"{len(target_paths)} targets; they must be paired one-to-one"
            )
        self.unet_output_paths = unet_output_paths
        self.transforms = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize(mean=[0.5],
                                                                   std=[0.5])])
        self.target_paths = target_paths

    def __getitem__(self, index):
        """Load, transform and return the ``(input, target)`` tensor pair at ``index``."""
        unet_output = np.load(self.unet_output_paths[index])
        target = np.load(self.target_paths[index])
        return self.transforms(unet_output), self.transforms(target)

    def __len__(self):
        """Number of samples (one per U-Net output path)."""
        return len(self.unet_output_paths)

In [None]:
# get all the image and mask path and number of images
def generate_loader(path, batch_size, mode, shuffle=True, num_workers=2):
    """Build a DataLoader pairing ``<path><mode>/*.npy`` inputs with ``<path>target/*.npy``.

    Both file lists are sorted so that input ``i`` is paired with target ``i``.
    ``glob.glob`` returns files in arbitrary (filesystem) order, so two independent
    listings are not guaranteed to line up otherwise.
    NOTE(review): sorted-order pairing assumes matching inputs and targets sort to the
    same position (e.g. identical file names) — confirm for this dataset.

    Args:
        path: split directory. Paths are built by string concatenation, so it must
            end with a path separator (e.g. ``data_path + "train/"``).
        batch_size: samples per batch.
        mode: sub-directory of ``path`` holding the input arrays (e.g. ``"sigmoid"``).
        shuffle: reshuffle samples every epoch (default ``True``).
        num_workers: number of DataLoader worker processes (default ``2``).

    Returns:
        ``torch.utils.data.DataLoader`` yielding ``(input, target)`` batches.

    Raises:
        ValueError: if no input files are found or input/target counts differ.
    """
    unet_output_paths = sorted(glob.glob(path + mode + "/*.npy"))
    target_paths = sorted(glob.glob(path + 'target/' + "*.npy"))
    len_data = len(unet_output_paths)
    print(len_data)
    if len_data == 0:
        raise ValueError(f"no .npy files found in {path + mode}/")
    if len_data != len(target_paths):
        raise ValueError(
            f"{len_data} input files in {path + mode}/ but "
            f"{len(target_paths)} target files in {path}target/"
        )
    dataset = CustomDataset(unet_output_paths, target_paths)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=num_workers)
    return loader

BATCH_SIZE = 32

# Dataset root containing the train/, valid/ and test/ split directories.
# Defined here (from the DATA_PATH environment variable, default: current
# directory) so the notebook runs on a fresh kernel. os.path.join(..., "")
# appends a trailing separator when the root is non-empty, because
# generate_loader builds paths by string concatenation.
data_path = os.path.join(os.environ.get("DATA_PATH", ""), "")

train_loader = generate_loader(data_path + "train/", BATCH_SIZE, "sigmoid")
valid_loader = generate_loader(data_path + "valid/", BATCH_SIZE, "sigmoid")
test_loader = generate_loader(data_path + "test/", BATCH_SIZE, "sigmoid")

In [None]:
# https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/pix2pix

# initialize weights
def weights_init_normal(m):
    """Re-initialise a single module in place; meant to be passed to ``Module.apply``.

    Modules are matched by class name:
      * name contains "Conv" (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02);
        the bias, if any, is left untouched.
      * name contains "BatchNorm2d": weight ~ N(1, 0.02) and bias set to 0.
      * anything else is left unchanged.

    Note: the networks defined in this notebook use InstanceNorm2d, not BatchNorm2d,
    so for them only the convolution branch applies.
    """
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Define patchGAN discriminator
class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    The candidate image and the conditioning image are concatenated along the channel
    axis and passed through four stride-2 conv stages (64 -> 128 -> 256 -> 512
    filters; InstanceNorm on every stage but the first; LeakyReLU(0.2)), then an
    asymmetric zero-pad and a final 4x4 conv to a single-channel score map. For
    H x W inputs with H, W divisible by 16 the map is H/16 x W/16.

    Args:
        in_channels: channels of EACH input image (the first conv sees twice this).
    """

    def __init__(self, in_channels=1):
        super(Discriminator, self).__init__()

        # (input filters, output filters, use InstanceNorm) for each downsampling stage
        stage_specs = [
            (in_channels * 2, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]

        modules = []
        for n_in, n_out, with_norm in stage_specs:
            modules.append(nn.Conv2d(n_in, n_out, 4, stride=2, padding=1))
            if with_norm:
                modules.append(nn.InstanceNorm2d(n_out))
            modules.append(nn.LeakyReLU(0.2, inplace=True))

        # Pad right/bottom by one so the final 4x4 conv keeps the H/16 x W/16 grid.
        modules.extend([
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False),
        ])
        self.model = nn.Sequential(*modules)

    def forward(self, img_A, img_B):
        """Score ``img_A`` conditioned on ``img_B``.

        In the training loop ``img_A`` is the real or generated image and ``img_B``
        is the conditioning (U-Net output) image.
        """
        stacked = torch.cat([img_A, img_B], dim=1)
        return self.model(stacked)


In [None]:
# UNet Generator
class UNetDown(nn.Module):
    """U-Net encoder stage that halves the spatial resolution.

    Layers: 4x4 stride-2 conv (padding 1, no bias) -> optional InstanceNorm2d ->
    LeakyReLU(0.2) -> optional Dropout.

    Args:
        in_size: input channels.
        out_size: output channels.
        normalize: insert InstanceNorm2d after the convolution.
        dropout: dropout probability; a falsy value (0 / 0.0) adds no dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        downsample = nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)
        norm_layers = [nn.InstanceNorm2d(out_size)] if normalize else []
        dropout_layers = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(downsample, *norm_layers, nn.LeakyReLU(0.2), *dropout_layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    """U-Net decoder stage: upsample x2 and append the encoder skip connection.

    Layers: 4x4 stride-2 transposed conv (padding 1, no bias) -> InstanceNorm2d ->
    ReLU -> optional Dropout. The result is concatenated with ``skip_input`` along the
    channel axis, so the output has ``out_size + skip channels`` channels.

    Args:
        in_size: input channels.
        out_size: channels produced by the transposed conv (before the concat).
        dropout: dropout probability; a falsy value adds no dropout layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        upsample = nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False)
        optional_tail = (nn.Dropout(dropout),) if dropout else ()
        self.model = nn.Sequential(
            upsample,
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
            *optional_tail,
        )

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)

class GeneratorUNet(nn.Module):
    """pix2pix-style U-Net generator.

    Eight ``UNetDown`` encoder stages (``down1``..``down8``), seven ``UNetUp`` decoder
    stages (``up1``..``up7``), each concatenating the encoder activation at the
    mirrored depth, and a ``final`` head (nearest upsample x2, asymmetric pad, 4x4
    conv, Tanh).

    NOTE(review): eight stride-2 stages mean the input height/width presumably need to
    be multiples of 256 (the training cell uses 256x256) — confirm for other sizes.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        # Encoder: (in channels, out channels, normalize, dropout) for down1..down8.
        encoder_specs = [
            (in_channels, 64, False, 0.0),
            (64, 128, True, 0.0),
            (128, 256, True, 0.0),
            (256, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, False, 0.5),
        ]
        for depth, (c_in, c_out, use_norm, p_drop) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{depth}", UNetDown(c_in, c_out, normalize=use_norm, dropout=p_drop))

        # Decoder: (in channels, out channels, dropout) for up1..up7. Every input after
        # up1 is doubled because the previous stage concatenated a skip connection.
        decoder_specs = [
            (512, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 256, 0.0),
            (512, 128, 0.0),
            (256, 64, 0.0),
        ]
        for depth, (c_in, c_out, p_drop) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{depth}", UNetUp(c_in, c_out, dropout=p_drop))

        # 128 input channels = up7's 64 outputs + the 64-channel down1 skip.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder pass, remembering every stage's output for the skip connections.
        encoder_outputs = []
        for depth in range(1, 9):
            x = getattr(self, f"down{depth}")(x)
            encoder_outputs.append(x)

        # The bottleneck (down8 output) starts the decoder and has no skip partner.
        h = encoder_outputs.pop()
        # up1 pairs with down7, up2 with down6, ..., up7 with down1.
        for depth in range(1, 8):
            h = getattr(self, f"up{depth}")(h, encoder_outputs.pop())

        return self.final(h)

In [None]:
# training our cGAN

# constants
img_height, img_width = 256, 256
channels = 1
epochs = 100
lr = 0.0002
b1 = 0.5    # Adam betas
b2 = 0.999

# Seed the Python, NumPy and torch RNGs so Restart & Run All reproduces the run
# (weight init, dropout, DataLoader shuffling and the per-epoch discriminator
# batch draw all consume these). NOTE: some cuDNN kernels remain
# nondeterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

discriminator_checkpoint_path = ""
generator_checkpoint_path = ""

cuda = torch.cuda.is_available()

# Loss functions: MSE adversarial loss (against all-ones / all-zeros targets)
# plus an L1 reconstruction loss
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Output grid of the PatchGAN discriminator: four stride-2 stages -> H/16 x W/16
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)

# Initialize generator and discriminator for `channels`-channel images
generator = GeneratorUNet(in_channels=channels, out_channels=channels)
discriminator = Discriminator(in_channels=channels)

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

dataloader = train_loader

val_dataloader = valid_loader

# Tensor type
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

def _mean_or_nan(values):
    """Mean of a list of floats, or nan (without a NumPy warning) if it is empty."""
    return float(np.mean(values)) if values else float("nan")


discriminator_loss_progress = []
generator_loss_progress = []

for epoch in tqdm(range(epochs)):

    dis_loss, gen_loss, pix_loss = [], [], []
    # The discriminator trains on ~1 batch in 10: only batches whose index is
    # congruent to `dice` (mod 10). randint's upper bound is exclusive, so (0, 10)
    # covers all ten residues; capping by the loader length guarantees at least
    # one discriminator update per epoch on short loaders.
    dice = np.random.randint(0, min(10, max(1, len(dataloader))))

    for batch_num, (X, Y) in enumerate(tqdm(dataloader, leave=False)):

        # Model inputs (plain tensors; autograd.Variable is deprecated)
        real_A = X.type(Tensor)  # conditioning image (U-Net output)
        real_B = Y.type(Tensor)  # ground truth

        ### train generator ###

        optimizer_G.zero_grad()

        # GAN loss
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        # Adversarial ground truths shaped like the discriminator output, so they
        # match for any input size (the precomputed `patch` assumes 256x256).
        valid = torch.ones_like(pred_fake)
        fake = torch.zeros_like(pred_fake)
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel

        loss_G.backward()

        optimizer_G.step()

        gen_loss.append(loss_G.item())
        pix_loss.append(loss_pixel.item())

        ### train discriminator ###

        if batch_num % 10 == dice:
            optimizer_D.zero_grad()

            # Real loss
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss; detach() keeps this step from backpropagating into G
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)

            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            loss_D.backward()
            optimizer_D.step()

            dis_loss.append(loss_D.item())

    ### save checkpoint ###
    date_time = datetime.now().strftime("%H_%M_%S")
    print(f"Saving checkpoint: {date_time}")
    torch.save(discriminator, discriminator_checkpoint_path + date_time + '_disc_checkpoint.pt')
    torch.save(generator, generator_checkpoint_path + date_time + '_gen_checkpoint.pt')

    ### log training progress ###
    dis_loss = _mean_or_nan(dis_loss)
    gen_loss = _mean_or_nan(gen_loss)
    pix_loss = _mean_or_nan(pix_loss)

    discriminator_loss_progress.append(dis_loss)
    generator_loss_progress.append(gen_loss)

    print(f"Epoch: {epoch}, Discriminator Loss: {dis_loss}, Generator Loss: {gen_loss}, Pixel Loss: {pix_loss}")

    # plot visualization: first validation batch, generator in inference mode
    # (dropout off, no autograd graph), then back to training mode.
    generator.eval()
    with torch.no_grad():
        X_val, Y_val = next(iter(val_dataloader))
        output = generator(X_val.type(Tensor)).cpu().numpy()
    generator.train()

    # Fresh figure objects each epoch: numbered figures such as plt.figure(-epoch)
    # and plt.figure(epoch) are the same figure at epoch 0.
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    ax1.set_title("Unet Output")
    ax2.set_title("Generator Output")
    ax3.set_title("Ground Truth")
    ax1.imshow(X_val[0].numpy().transpose(1, 2, 0).squeeze(), cmap='gray')
    ax2.imshow(output[0].transpose(1, 2, 0).squeeze(), cmap='gray')
    ax3.imshow(Y_val[0].numpy().transpose(1, 2, 0).squeeze(), cmap='gray')

    # plot loss curves
    fig_loss, ax_loss = plt.subplots()
    epoch_idx = list(range(len(discriminator_loss_progress)))
    ax_loss.plot(epoch_idx, discriminator_loss_progress, label='discriminator')
    ax_loss.plot(epoch_idx, generator_loss_progress, label='generator')
    ax_loss.set_xlabel('epochs')
    ax_loss.set_ylabel('loss')
    ax_loss.legend()

    # Non-blocking draw of both figures so training keeps running.
    plt.pause(0.05)

In [None]:
# Evaluate Model
#
# Runs a saved generator over the test set, shows one example, and collects
# per-sample pixel error and adapted Rand error against the ground truth.

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
device = torch.device("cuda" if cuda else "cpu")

# torch.load unpickles a whole nn.Module, so only load trusted checkpoints.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only=False is required because recent PyTorch versions default to
# weights-only loading, which rejects pickled modules.
generator = torch.load(generator_checkpoint_path + "16_18_28_gen_checkpoint.pt",
                       map_location=device, weights_only=False)
generator.eval()

rand_error = []
pixel_error = []

N = 0  # number of test samples evaluated

with torch.no_grad():
    for batch_idx, (X, Y) in enumerate(test_loader):
        output = generator(X.type(Tensor)).cpu().numpy()
        target = Y.cpu().numpy()
        N += X.shape[0]  # batch size

        # NOTE(review): binarises at 0, the midpoint of the [-1, 1] range produced by
        # Normalize(0.5, 0.5) and the generator's Tanh; assumes targets are binary
        # masks — confirm.
        pred_masks = (output > 0).astype(np.int64)
        true_masks = (target > 0).astype(np.int64)
        for pred_mask, true_mask in zip(pred_masks, true_masks):
            pixel_error.append(float(np.mean(pred_mask != true_mask)))
            # NOTE(review): on a binary mask all foreground forms one segment; label
            # connected components first if per-object Rand error is intended.
            rand_error.append(adapted_rand_error(true_mask.squeeze(), pred_mask.squeeze())[0])

        # Show the first sample of the first batch as a visual sanity check.
        if batch_idx == 0:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
            ax1.set_title("Unet Output")
            ax2.set_title("Generator Output")
            ax3.set_title("Ground Truth")
            ax1.imshow(X[0].cpu().numpy().transpose(1, 2, 0).squeeze(), cmap='gray')
            ax2.imshow(output[0].transpose(1, 2, 0).squeeze(), cmap='gray')
            ax3.imshow(target[0].transpose(1, 2, 0).squeeze(), cmap='gray')
            plt.show()

print(f"Test samples: {N}, mean pixel error: {np.mean(pixel_error)}, "
      f"mean adapted Rand error: {np.mean(rand_error)}")

In [None]:
# plot predictions
#
# Reloads a saved generator and, for every test batch, shows the first sample's
# U-Net input, generator prediction and ground truth side by side.

generator_checkpoint_path = ""

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
device = torch.device("cuda" if cuda else "cpu")

# torch.load unpickles a whole nn.Module, so only load trusted checkpoints.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only=False is required because recent PyTorch versions default to
# weights-only loading, which rejects pickled modules.
generator = torch.load(generator_checkpoint_path + "16_18_28_gen_checkpoint.pt",
                       map_location=device, weights_only=False)
generator.eval()

with torch.no_grad():
    for X, Y in test_loader:
        output = generator(X.type(Tensor)).cpu().numpy()

        # A fresh figure per batch; reusing one numbered figure stacks every
        # batch's axes on top of each other.
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.set_title("Unet Output")
        ax2.set_title("Generator Output")
        ax3.set_title("Ground Truth")

        ax1.imshow(X[0].cpu().numpy().transpose(1, 2, 0).squeeze(), cmap='gray')
        ax2.imshow(output[0].transpose(1, 2, 0).squeeze(), cmap='gray')
        ax3.imshow(Y[0].cpu().numpy().transpose(1, 2, 0).squeeze(), cmap='gray')
        plt.show()