Code reference - https://www.kaggle.com/code/balraj98/context-encoder-gan-for-image-inpainting-pytorch

In [1]:
import numpy as np
import pandas as pd
import os, math, sys
import glob, itertools
import argparse, random
import sewar

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid

import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split

random.seed(42)
import warnings
warnings.filterwarnings("ignore")

In [2]:
folders = sorted(list(os.listdir('/home/k4agrawal/Efficient_Image_Reconstruction/StyleGAN.pytorch/ffhq')))[1:]


In [3]:
allImages = []
root = '/home/k4agrawal/Efficient_Image_Reconstruction/StyleGAN.pytorch/ffhq'
for folder in folders:
    allImages.extend(sorted(glob.glob("%s/%s/*.png" %(root,folder))))

In [4]:
trainImage, testImage = train_test_split(allImages, test_size=0.2, random_state=42)
valImage, testImage = train_test_split(testImage, test_size=0.5, random_state=42)

In [5]:
class ImageDataset(Dataset):
    def __init__(self, files, transforms_=None, img_size=128, mask_size=64, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.img_size = img_size
        self.mask_size = mask_size
        self.mode = mode
        self.files = files
        # for folder in folders:
        #    self.files.extend(sorted(glob.glob("%s/%s/*.png" %(root,folder))))
        # self.files = self.files[:-4000] if mode == "train" else self.files[-4000:]

    def apply_random_mask(self, img):
        """Randomly masks image"""
        y1, x1 = np.random.randint(0, self.img_size - self.mask_size, 2)
        y2, x2 = y1 + self.mask_size, x1 + self.mask_size
        masked_part = img[:, y1:y2, x1:x2]
        masked_img = img.clone()
        masked_img[:, y1:y2, x1:x2] = 1

        return masked_img, masked_part

    def apply_center_mask(self, img):
        """Mask center part of image"""
        # Get upper-left pixel coordinate
        i = (self.img_size - self.mask_size) // 2
        masked_img = img.clone()
        masked_part = masked_img[:, i : i + self.mask_size, i : i + self.mask_size]
        masked_img[:, i : i + self.mask_size, i : i + self.mask_size] = 1

        return masked_img, masked_part, i

    def __getitem__(self, index):

        img = Image.open(self.files[index % len(self.files)])
        img = self.transform(img)
        if(self.mode=="train"):
            masked_img, aux = self.apply_random_mask(img)
            return img, masked_img, aux
        else:
            masked_img, masked_part, i = self.apply_center_mask(img)
            return img, masked_img, masked_part, i

        

    def __len__(self):
        return len(self.files)

In [6]:
transforms_ = [
    transforms.Resize((128, 128), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
trainDL = DataLoader(
    ImageDataset(files=trainImage, transforms_=transforms_, mode="train"),
    batch_size=12,
    shuffle=True,
    num_workers=1,
)
valDL = DataLoader(
    ImageDataset(files=valImage, transforms_=transforms_, mode="train"),
    batch_size=12,
    shuffle=True,
    num_workers=1,
)
testDL = DataLoader(
    ImageDataset(files=testImage, transforms_=transforms_, mode="test"),
    batch_size=12,
    shuffle=True,
    num_workers=1,
)

In [7]:
class Generator(nn.Module):
    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, normalize=True):
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.ReLU())
            return layers

        self.model = nn.Sequential(
            *downsample(channels, 64, normalize=False),
            *downsample(64, 64),
            *downsample(64, 128),
            *downsample(128, 256),
            *downsample(256, 512),
            nn.Conv2d(512, 4000, 1),
            *upsample(4000, 512),
            *upsample(512, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            nn.Conv2d(64, channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = channels
        for out_filters, stride, normalize in [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [8]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

# def save_sample(batches_done):
    

    
# Loss function
adversarial_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize generator and discriminator
generator = Generator(channels=3)
discriminator = Discriminator(channels=3)


generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
pixelwise_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

Tensor = torch.cuda.FloatTensor

In [9]:
patch_h, patch_w = int(64 / 2 ** 3), int(64 / 2 ** 3)
patch = (1, patch_h, patch_w)

In [None]:
gen_adv_losses, gen_pixel_losses, disc_losses, counter = [], [], [], []
best_loss = np.inf

for epoch in range(50):
    
    ### Training ###
    generator.train()
    discriminator.train()
    gen_adv_loss, gen_pixel_loss, disc_loss = 0, 0, 0
    tqdm_bar = tqdm(trainDL, desc=f'Training Epoch {epoch} ', total=int(len(trainDL)))
    for i, (imgs, masked_imgs, masked_parts) in enumerate(tqdm_bar):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)

        # Configure input
        imgs = Variable(imgs.type(Tensor))
        masked_imgs = Variable(masked_imgs.type(Tensor))
        masked_parts = Variable(masked_parts.type(Tensor))

        ## Train Generator ##
        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_parts = generator(masked_imgs)

        # Adversarial and pixelwise loss
        g_adv = adversarial_loss(discriminator(gen_parts), valid)
        # print(gen_parts.shape, masked_parts.shape)
        g_pixel = pixelwise_loss(gen_parts, masked_parts)
        # Total loss
        g_loss = 0.001 * g_adv + 0.999 * g_pixel

        g_loss.backward()
        optimizer_G.step()

        ## Train Discriminator ##
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(masked_parts), valid)
        fake_loss = adversarial_loss(discriminator(gen_parts.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()
        optimizer_D.step()
        
        gen_adv_loss, gen_pixel_loss, disc_loss
        gen_adv_losses, gen_pixel_losses, disc_losses, counter
        
        gen_adv_loss += g_adv.item()
        gen_pixel_loss += g_pixel.item()
        gen_adv_losses.append(g_adv.item())
        gen_pixel_losses.append(g_pixel.item())
        disc_loss += d_loss.item()
        disc_losses.append(d_loss.item())
        counter.append(i*12 + imgs.size(0) + epoch*len(trainDL.dataset))
        tqdm_bar.set_postfix(gen_adv_loss=gen_adv_loss/(i+1), gen_pixel_loss=gen_pixel_loss/(i+1), disc_loss=disc_loss/(i+1))
        
        # Generate sample at sample interval
        batches_done = epoch * len(trainDL) + i
#         if batches_done % 1000 == 0:
#             save_sample(batches_done)
    
    with torch.no_grad():
        generator.eval()
        discriminator.eval()
        adv_val_loss, pixel_val_loss, disc_val_loss = 0, 0, 0
        val_tqdm_bar = tqdm(valDL, desc=f'Val Epoch {epoch} ', total=int(len(valDL)))
        for i, (imgs, masked_imgs, masked_parts) in enumerate(val_tqdm_bar):

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)

            # Configure input
            imgs = Variable(imgs.type(Tensor))
            masked_imgs = Variable(masked_imgs.type(Tensor))
            masked_parts = Variable(masked_parts.type(Tensor))

            # Generate a batch of images
            gen_parts = generator(masked_imgs)

            # Adversarial and pixelwise loss
            g_adv = adversarial_loss(discriminator(gen_parts), valid)
            g_pixel = pixelwise_loss(gen_parts, masked_parts)
            # Total loss
            g_loss = 0.001 * g_adv + 0.999 * g_pixel

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(masked_parts), valid)
            fake_loss = adversarial_loss(discriminator(gen_parts.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)

        
            adv_val_loss += g_adv.item()
            pixel_val_loss += g_pixel.item()
            disc_val_loss += d_loss.item()
            val_tqdm_bar.set_postfix(adv_val_loss=adv_val_loss/(i+1), pixel_val_loss=pixel_val_loss/(i+1), disc_val_loss=disc_val_loss/(i+1))
        print("Epoch: %d, val adv loss: %f, val pixel loss: %f, val disc loss: %f"%(epoch, adv_val_loss, pixel_val_loss, disc_val_loss))

        if((0.001 * adv_val_loss + 0.999 * pixel_val_loss) < best_loss):
            best_loss = 0.001 * adv_val_loss + 0.999 * pixel_val_loss
            torch.save(generator.state_dict(), "saved_models/generator.pth")
            torch.save(discriminator.state_dict(), "saved_models/discriminator.pth")

Training Epoch 0 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 0 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 0, val adv loss: 561.456576, val pixel loss: 142.622745, val disc loss: 1.808908


Training Epoch 1 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 1 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 1, val adv loss: 554.029182, val pixel loss: 144.708080, val disc loss: 1.609072


Training Epoch 2 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 2 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 2, val adv loss: 580.227565, val pixel loss: 126.584137, val disc loss: 0.295926


Training Epoch 3 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 3 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 3, val adv loss: 578.140746, val pixel loss: 123.464378, val disc loss: 0.154754


Training Epoch 4 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 4 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 4, val adv loss: 575.521801, val pixel loss: 120.222854, val disc loss: 0.391044


Training Epoch 5 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 5 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 5, val adv loss: 579.301153, val pixel loss: 118.336148, val disc loss: 0.074646


Training Epoch 6 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 6 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 6, val adv loss: 583.143052, val pixel loss: 116.269866, val disc loss: 0.071606


Training Epoch 7 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 7 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 7, val adv loss: 579.654421, val pixel loss: 114.635081, val disc loss: 0.211934


Training Epoch 8 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 8 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 8, val adv loss: 584.002416, val pixel loss: 113.146289, val disc loss: 0.045975


Training Epoch 9 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 9 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 9, val adv loss: 581.280777, val pixel loss: 111.342367, val disc loss: 0.098447


Training Epoch 10 :   0%|          | 0/4667 [00:00<?, ?it/s]

Val Epoch 10 :   0%|          | 0/584 [00:00<?, ?it/s]

Epoch: 10, val adv loss: 594.953227, val pixel loss: 112.666951, val disc loss: 0.071743


Training Epoch 11 :   0%|          | 0/4667 [00:00<?, ?it/s]

In [None]:
# Load best model
import cv2
def normalize_img(img):
    norm_image = cv2.normalize(img, None, alpha = 0, beta = 255, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
    norm_image = norm_image.astype(np.uint8)
    return norm_image

with torch.no_grad():
    generator.load_state_dict(torch.load('saved_models/generator.pth'))
    discriminator.load_state_dict(torch.load('saved_models/discriminator.pth'))
    generator.eval()
    discriminator.eval()

    test_tqdm_bar = tqdm(testDL, desc=f'Test Epoch 1 ', total=int(len(testDL)))
    pixel_test_loss = 0
    for step, (imgs, masked_imgs, masked_parts, i) in enumerate(test_tqdm_bar):

#     samples, masked_samples, i = next(iter(testDL))
        samples = Variable(imgs.type(Tensor))
        masked_samples = Variable(masked_imgs.type(Tensor))
        masked_parts = Variable(masked_parts.type(Tensor))
        i = i[0].item()  # Upper-left coordinate of mask
        # Generate inpainted image
        gen_mask = generator(masked_samples)
    
        g_pixel = pixelwise_loss(gen_mask, masked_parts)
        # print(gen_mask.shape, coords)
        filled_samples = masked_samples.clone()
        filled_samples[:, :, i : i + 64, i : i + 64] = gen_mask
        # Save sample
        sample = torch.cat((masked_samples.data, filled_samples.data, samples.data), -2)
        psnr = 0
        ssim = 0
        for i in range(12):
            pred = normalize_img(filled_samples[i].permute(1,2,0).cpu().numpy())
            gt = normalize_img(samples[i].permute(1,2,0).cpu().numpy())
            # print(pred.shape, gt.shape)
            psnr_score = sewar.psnr(pred,gt)
            # print(psnr_score)
            #ssim
            ssim_score = sewar.ssim(pred,gt)[0]
            psnr += psnr_score
            ssim += ssim_score
        print("psnr_score: %f, ssim_score: %f"%(psnr/12, ssim/12))
        pixel_test_loss += g_pixel.item()
        if(step % 50 == 0):
            save_image(sample, "images/%d.png" % step, nrow=6, normalize=True)
    print("final pixel loss: {:.4f}".format(pixel_test_loss/int(len(testDL))))