In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
from torch.functional import F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import cv2
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load every readable image in the chosen folder and resize it to 256x256.
ImagesPath = input("Enter path to folder containing unmasked images: ")
imgs = os.listdir(ImagesPath)
TrainImages = []

for im in imgs:
    img = cv2.imread(os.path.join(ImagesPath, im))
    if img is None:
        # cv2.imread returns None (it does not raise) for sub-folders and for
        # files it cannot decode; cv2.resize would crash on that value.
        print(f"Skipping unreadable file: {im}")
        continue
    img = cv2.resize(img, (256, 256))
    TrainImages.append(img)

if not TrainImages:
    # Fail here with a clear message rather than later inside the DataLoader.
    raise ValueError(f"No readable images found in {ImagesPath!r}")

In [None]:
class AugmentSet(Dataset):
    """Dataset that damages clean images on the fly for inpainting training.

    Indexing returns ``([masked_img, mask], y_img)``:

    * ``masked_img`` - ``X[idx]`` with randomly drawn strokes blanked to 0,
    * ``mask``       - 1.0 where a pixel was blanked, 0.0 elsewhere,
    * ``y_img``      - the untouched target ``y[idx]``.

    All three are float32 tensors scaled to [0, 1] with shape ``(3, H, W)``.
    A new random mask is drawn on every access.

    Args:
        X: sequence of uint8 images of shape ``(H, W, 3)`` to be damaged.
        y: sequence of uint8 target images, aligned with ``X``.
        dim: nominal image size; stored but not otherwise used here.
        n_channels: nominal channel count; stored but not otherwise used here.
    """

    def __init__(self, X, y, dim = (256, 256), n_channels = 3):
        super(AugmentSet, self).__init__()
        self.X = X
        self.y = y
        self.dim = dim
        self.n_channels = n_channels

    def __len__(self):
        """Number of source images."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``([masked_img, mask], y_img)`` for the image at ``idx``."""
        X_input, y_input = self.__data_generator(idx)
        return X_input, y_input

    def __create_masked_image(self, img):
        """Draw a random damage mask and apply it to ``img``.

        The mask starts black and the strokes (lines, arcs, filled circles) are
        painted white (255), so white marks the pixels to remove.

        Returns:
            ``(masked_image, mask)`` - ``img`` with the stroke pixels set to 0,
            and the uint8 mask of shape ``(H, W, 3)``.
        """
        # Size the mask from the image itself so the two always match.
        height, width = img.shape[:2]
        mask = np.zeros((height, width, 3), np.uint8)
        white = (255, 255, 255)

        # Random straight strokes.
        num_lines = np.random.randint(5, 10)
        for _ in range(num_lines):
            thickness = np.random.randint(1, 5)
            x1, y1 = np.random.randint(0, width), np.random.randint(0, height)
            x2, y2 = np.random.randint(0, width), np.random.randint(0, height)
            cv2.line(mask, (x1, y1), (x2, y2), white, thickness)

        # Random curved strokes: each iteration draws two circular arcs, one
        # for each pair of random points. The arc is centred on the midpoint of
        # the pair and spans the angles at which the two points see that centre.
        num_curves = np.random.randint(2, 4)
        for _ in range(num_curves):
            points = np.random.randint(0, min(height, width), size=(4, 2))
            thickness = np.random.randint(1, 5)
            angle = np.random.randint(0, 180)
            arc_length = np.random.randint(20, 80)
            for first, second in ((points[0], points[1]), (points[2], points[3])):
                curve_center = np.mean([first, second], axis=0)
                start_angle = np.arctan2(first[1] - curve_center[1], first[0] - curve_center[0]) * 180 / np.pi
                end_angle = np.arctan2(second[1] - curve_center[1], second[0] - curve_center[0]) * 180 / np.pi
                # Plain Python ints/floats keep OpenCV's argument parsing happy.
                center = tuple(int(v) for v in curve_center)
                cv2.ellipse(mask, center, (arc_length, arc_length), angle,
                            float(start_angle), float(end_angle), white, thickness)

        # Random filled discs (thickness -1 means "filled").
        num_circles = np.random.randint(5, 10)
        for _ in range(num_circles):
            radius = np.random.randint(1, 20)
            x, y = np.random.randint(0, width), np.random.randint(0, height)
            cv2.circle(mask, (x, y), radius, white, -1)

        # Blank the stroke pixels: keep the image only where the mask is black.
        masked_image = cv2.bitwise_and(img, cv2.bitwise_not(mask))

        return masked_image, mask

    def __data_generator(self, idx):
        """Build the tensors for one sample (see the class docstring)."""
        source = self.X[idx]
        height, width = source.shape[0], source.shape[1]
        masked_img, mask = self.__create_masked_image(source.copy())

        def to_tensor(arr):
            # NOTE(review): this is a flat reshape, not an HWC->CHW transpose.
            # It is kept because get_filled_img and the test cells undo it with
            # the inverse reshape; switching to a transpose has to be done in
            # all of those places at once.
            return torch.tensor((arr / 255.0).astype("float32")).reshape((3, height, width))

        return [to_tensor(masked_img), to_tensor(mask)], to_tensor(self.y[idx])

In [None]:
# Shuffle the image list in place, then use the same clean images both as the
# source to damage (X) and as the reconstruction target (y).
random.shuffle(TrainImages)
AugmentedDataset = AugmentSet(X=TrainImages, y=TrainImages, dim=(256, 256), n_channels=3)

In [None]:
# Batches of 8 samples, drawn in a fresh random order every epoch.
dataloader = DataLoader(dataset=AugmentedDataset, batch_size=8, shuffle=True)

In [None]:
class Generator(nn.Module):
    """Convolutional encoder-decoder that maps an image to an image of the same size.

    Four stride-2 convolutions halve the resolution each time (channels ->
    64 -> 128 -> 256 -> 512); four stride-2 transposed convolutions bring it
    back (512 -> 256 -> 128 -> 64 -> 3). A final sigmoid keeps the output in
    [0, 1]. Input height and width must be divisible by 16 for the output size
    to match the input.

    Args:
        channels: number of channels of the input image.
    """

    def __init__(self, channels=3):
        super(Generator, self).__init__()

        # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
        # not `momentum`. The value 0.8 is therefore passed as eps here, exactly
        # as before, so previously saved checkpoints behave the same.
        # Confirm whether momentum was intended before changing it.
        def downsample(in_feat, out_feat, normalize=True):
            """Conv (halves H and W) -> optional batch norm -> LeakyReLU."""
            block = [nn.Conv2d(in_feat, out_feat, kernel_size=4, stride=2, padding=1)]
            if normalize:
                block.append(nn.BatchNorm2d(out_feat, eps=0.8))
            block.append(nn.LeakyReLU(0.2))
            return block

        def upsample(in_feat, out_feat, normalize=True):
            """Transposed conv (doubles H and W) -> optional batch norm -> ELU."""
            block = [nn.ConvTranspose2d(in_feat, out_feat, kernel_size=4, stride=2, padding=1)]
            if normalize:
                block.append(nn.BatchNorm2d(out_feat, eps=0.8))
            block.append(nn.ELU())
            return block

        # Layers are created in network order so the position of every layer
        # inside `self.model` (and thus every state-dict key) is unchanged.
        stages = downsample(channels, 64, normalize=False)
        for width_in, width_out in ((64, 128), (128, 256), (256, 512)):
            stages.extend(downsample(width_in, width_out))
        for width_in, width_out in ((512, 256), (256, 128), (128, 64), (64, 3)):
            stages.extend(upsample(width_in, width_out))
        stages.append(nn.Sigmoid())

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the encoder-decoder."""
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Patch discriminator: scores a grid of image regions instead of the whole image.

    Three stride-2 convolutions reduce the resolution by a factor of 8, one
    stride-1 convolution widens to 512 channels, and a last convolution
    produces a single-channel score map of shape ``(N, 1, H / 8, W / 8)``.
    No activation is applied to the scores.

    Args:
        channels: number of channels of the input image.
    """

    def __init__(self, channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """3x3 conv -> optional instance norm -> in-place LeakyReLU."""
            conv = nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=stride, padding=1)
            norm = [nn.InstanceNorm2d(out_filters)] if normalize else []
            return [conv, *norm, nn.LeakyReLU(0.2, inplace=True)]

        # (output channels, stride, use instance norm) for each block, in order.
        plan = ((64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True))

        stack = []
        width = channels
        for out_filters, stride, normalize in plan:
            stack += discriminator_block(width, out_filters, stride, normalize)
            width = out_filters

        # Collapse the final feature map into one score per location.
        stack.append(nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return the score map for ``img``."""
        return self.model(img)

In [None]:
# Build both networks and move them to the selected device. The discriminator
# is constructed before the generator, so its weights are initialised first.
disc = Discriminator(channels=3).to(device=device)
reckt = Generator().to(device=device)

In [None]:
# loss function created to rebuild masked area
class ReconstructLossL2(nn.Module):
    """Reconstruction penalty restricted to the masked region.

    Despite the "L2" in the name, the value returned is the *sum of absolute
    differences* (an L1 norm) between target and prediction, weighted
    element-wise by the mask. The sum is not averaged.
    """

    def __init__(self):
        super(ReconstructLossL2, self).__init__()

    def forward(self, mask, y_pred, ipt):
        """Return ``sum(|mask * (ipt - y_pred)|)`` as a scalar tensor.

        Args:
            mask: per-element weight; entries equal to 0 remove that element
                from the loss.
            y_pred: prediction.
            ipt: target the prediction is compared against.
        """
        masked_error = mask * (ipt - y_pred)
        return masked_error.abs().sum()

In [None]:
# replace masked area in image with that predicted by the model
def get_filled_img(mask, masked_img, output):
    """Paste the generator's prediction into the holes of a damaged image.

    Each argument is a tensor holding 256 * 256 * 3 values; it is moved to the
    CPU and viewed as a ``(256, 256, 3)`` NumPy array. The prediction is kept
    only where the grayscale mask, cast to uint8, is non-zero, and the result
    is added to the damaged image.

    Args:
        mask: hole mask tensor.
        masked_img: damaged image tensor.
        output: generator prediction (detached here before conversion).

    Returns:
        ``(256, 256, 3)`` NumPy array: masked prediction plus damaged image.
    """
    def as_hwc(tensor):
        # Flat reshape of the tensor's values into height x width x channel.
        return np.array(tensor).reshape(256, 256, 3)

    damaged = as_hwc(masked_img.cpu())
    hole = as_hwc(mask.cpu())
    predicted = as_hwc(output.cpu().detach())

    # Collapse the three mask channels into one so it can serve as a cv2 mask.
    hole_gray = cv2.cvtColor(cv2.cvtColor(hole, cv2.COLOR_BGR2RGB), cv2.COLOR_RGB2GRAY)

    prediction_in_holes = cv2.bitwise_and(predicted, predicted, mask=hole_gray.astype("uint8"))
    return cv2.add(prediction_in_holes, damaged)

## Training Area

In [None]:
# Load the model files written by an earlier training run (originally on a
# laptop GPU). map_location remaps the saved tensors onto the device in use
# now, so checkpoints saved from CUDA also load on a CPU-only machine.
reckt.load_state_dict(torch.load("generator_ckpt.pt", map_location=device))
disc.load_state_dict(torch.load("discriminator_ckpt.pt", map_location=device))

In [None]:
# Adversarial training: the generator (`reckt`) learns to fill the masked
# pixels, the patch discriminator (`disc`) learns to tell filled images from
# clean ones. Per-iteration losses are kept for the plots further down.
num_epochs = 30
pixelwise_loss = ReconstructLossL2()
adversarial_loss = nn.MSELoss()
optimizer_g = torch.optim.Adam(reckt.parameters(), lr = 0.0004, betas = (0.1, 0.99))
optimizer_d = torch.optim.Adam(disc.parameters(), lr = 0.0004, betas = (0.1, 0.999))
mask_size = 256
# The discriminator has three stride-2 layers, so a mask_size x mask_size
# input yields a (mask_size / 8) x (mask_size / 8) map of patch scores.
patch_h, patch_w = int(mask_size / 2 ** 3), int(mask_size / 2 ** 3)
patch = (1, patch_h, patch_w)
g_loss_items = []
d_loss_items = []

reckt.train()
disc.train()

for epoch in range(num_epochs):
    running_g_loss = 0
    running_d_loss = 0

    for i, data in enumerate(dataloader):
        inputs, labels = data
        masked_img, mask = inputs[0], inputs[1]
        masked_img, labels, mask = masked_img.to(device), labels.to(device), mask.to(device)

        # Per-patch targets, created directly on the active device so the
        # loop also runs when no GPU is available.
        valid = torch.ones((labels.shape[0], *patch), device=device)
        fake = torch.zeros((labels.shape[0], *patch), device=device)

        # ---- generator step ----
        optimizer_g.zero_grad()

        outputs = reckt(masked_img)

        # Composite image: prediction inside the holes (mask == 1), damaged
        # input everywhere else. This is the same arithmetic get_filled_img
        # performs, but done with tensor ops so the adversarial loss can
        # back-propagate into the generator; a round trip through NumPy would
        # cut the graph and leave the generator trained by the pixel loss alone.
        gen_predicts = mask * outputs + masked_img

        g_pixel = pixelwise_loss(mask, outputs, labels)
        g_adv = adversarial_loss(disc(gen_predicts), valid)
        g_loss = 1000 * g_adv + 0.5 * g_pixel

        g_loss.backward()
        optimizer_g.step()
        running_g_loss += g_loss.item()

        # ---- discriminator step ----
        # zero_grad also discards the gradients the generator step left on
        # the discriminator's parameters.
        optimizer_d.zero_grad()

        # detach(): the discriminator update must not reach into the generator.
        fake_loss = adversarial_loss(disc(gen_predicts.detach()), fake)
        real_loss = adversarial_loss(disc(labels), valid)
        d_loss = 0.5 * (fake_loss + real_loss)

        d_loss.backward()
        optimizer_d.step()
        running_d_loss += d_loss.item()

        g_loss_items.append(g_loss.item())
        d_loss_items.append(d_loss.item())

        # Every 100 iterations: report the mean losses and checkpoint both models.
        if (i % 100 == 99):
            print(f"Epoch - {epoch + 1}, iteration = {i + 1}, generator-loss: {running_g_loss / 100:.3f}, discriminator-loss: {running_d_loss / 100:.3f}")
            running_g_loss = 0
            running_d_loss = 0
            torch.save(reckt.state_dict(), "generator_ckpt.pt")
            torch.save(disc.state_dict(), "discriminator_ckpt.pt")

In [None]:
import seaborn as sns
import pandas as pd

In [None]:
# Generator loss for every training iteration, in the order it was recorded.
g_l_dict = {
    "iterations" : [*range(len(g_loss_items))],
    "generator_loss" : g_loss_items,
}
generator_loss_df = pd.DataFrame(data = g_l_dict)
sns.lineplot(x = 'iterations', y = 'generator_loss', data = generator_loss_df, color = 'g')

In [None]:
# Discriminator loss for every training iteration, in the order it was recorded.
d_l_dict = {
    "iterations" : [*range(len(d_loss_items))],
    "discriminator_loss" : d_loss_items,
}
discriminator_loss_df = pd.DataFrame(data = d_l_dict)
sns.lineplot(x = 'iterations', y = 'discriminator_loss', data = discriminator_loss_df, color = 'b')

## Testing Area

In [None]:
# List the test folder in sorted order (`os` is already imported at the top of
# the notebook).
SamplePath = input("Enter path to testing data containing masked, painted images and masks in same format as in sample testing: ")
sample_set_path = SamplePath
imgs = sorted(os.listdir(sample_set_path))
# NOTE(review): swaps the first two file names so the listing matches the
# order the later cells expect. This depends on how the sample files are
# named - check the resulting order before continuing.
imgs[0], imgs[1] = imgs[1], imgs[0]

In [None]:
# Sort the test files into three lists by file name:
#   masks      - name contains "mask"
#   inpaints   - name contains "inpainted"
#   label_imgs - everything else
masks = []
label_imgs = []
inpaints = []

for img_name in imgs:
    img_path = os.path.join(sample_set_path, img_name)
    img_arr = cv2.imread(img_path)
    is_mask = "mask" in img_name
    is_inpainted = "inpainted" in img_name
    if is_mask:
        masks.append(img_arr)
    if is_inpainted:
        inpaints.append(img_arr)
    # Only files matching neither keyword are labels. Previously the "else"
    # belonged to the "inpainted" test alone, so every mask was also appended
    # to label_imgs.
    if not (is_mask or is_inpainted):
        label_imgs.append(img_arr)

In [None]:
index = int(input("Enter index of image to paint upon: "))
# Inference only: no_grad skips building an autograd graph that is never used.
# The image is scaled to [0, 1] and flat-reshaped to (1, 3, 256, 256), the same
# way the training data is prepared.
# NOTE(review): the generator stays in whatever train/eval mode it is currently
# in, as before - confirm whether reckt.eval() is wanted for testing.
with torch.no_grad():
    res = reckt(torch.tensor((inpaints[index]/255).astype("float32").reshape(1, 3, 256, 256)).to(device))

In [None]:
# Combine the generator output with the damaged test image: prediction inside
# the holes, original pixels elsewhere.
m = torch.tensor(masks[index]).reshape(3, 256, 256)/255
mi = torch.tensor(inpaints[index]).reshape(3, 256, 256)/255
# detach().clone() copies the values without the warning that torch.tensor()
# emits when it is handed an existing tensor.
o = res.detach().clone().reshape(3, 256, 256)

mi = np.array(mi.cpu()).reshape(256, 256, 3)
# Invert the mask. Presumably the sample masks are bright where the image is
# intact, so after inversion bright marks the holes - TODO confirm.
m = 1 - m
m = np.array(m.cpu()).reshape(256, 256, 3)
o = np.array(o.cpu().detach()).reshape(256, 256, 3)
# Collapse the three mask channels into a single grayscale channel.
mt = cv2.cvtColor(m, cv2.COLOR_BGR2RGB)
mt = cv2.cvtColor(mt, cv2.COLOR_RGB2GRAY)

# Snap near-one values to exactly 1 so the uint8 cast below does not truncate
# them to 0.
mt[mt > 0.8] = 1

# Keep the prediction only where the uint8 mask is non-zero.
ex = cv2.bitwise_and(o,o,mask=mt.astype("uint8"))
# Blank the hole pixels of the damaged image so the addition below does not
# stack the prediction on top of existing content.
mi[mt > 0.5] = 0
img = cv2.add(ex, mi)

In [None]:
# Damaged input with the holes blanked. The array is in OpenCV's BGR channel
# order (it came from cv2.imread); reverse the channels so matplotlib, which
# expects RGB, shows the true colours.
plt.imshow(mi[..., ::-1])

In [None]:
# Generator prediction restricted to the holes. Channels are reversed from
# BGR (the order the model was fed) to the RGB order matplotlib expects.
plt.imshow(ex[..., ::-1])

In [None]:
# Raw generator output for the whole image. It is converted to NumPy so the
# channels can be reversed from BGR to the RGB order matplotlib expects.
plt.imshow(res.cpu().detach().reshape(256, 256, 3).numpy()[..., ::-1])

In [None]:
# Final result: prediction in the holes plus the untouched pixels, with the
# channels reversed from BGR to the RGB order matplotlib expects.
plt.imshow(cv2.add(ex, mi)[..., ::-1])