Code reference - https://www.kaggle.com/code/balraj98/context-encoder-gan-for-image-inpainting-pytorch

In [1]:
import numpy as np
import pandas as pd
import os, math, sys
import glob, itertools
import argparse, random
import sewar

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
 
from torchvision.utils import save_image, make_grid

import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

random.seed(42)
import warnings
warnings.filterwarnings("ignore")

  Referenced from: <955A2762-10C6-381C-99FA-06B93B8D6AED> /Users/kushagraagrawal/anaconda3/envs/newbase/lib/python3.8/site-packages/torchvision/image.so
  Expected in:     <6607DFFE-F5CB-30CC-8D45-014046A9CD96> /Users/kushagraagrawal/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [9]:
folders = sorted(list(os.listdir('StyleGAN.pytorch/ffhq')))[1:]


In [10]:
allImages = []
root = 'StyleGAN.pytorch/ffhq'
for folder in folders:
    allImages.extend(sorted(glob.glob("%s/%s/*.png" %(root,folder))))

In [11]:
allImages, _ = train_test_split(allImages, test_size=0.25, random_state=42) # training on 75% data

In [12]:
trainImage, testImage = train_test_split(allImages, test_size=0.2, random_state=42)
valImage, testImage = train_test_split(testImage, test_size=0.5, random_state=42)

In [13]:
class ImageDataset(Dataset):
    def __init__(self, files, transforms_=None, img_size=128, mask_size=64, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.img_size = img_size
        self.mask_size = mask_size
        self.mode = mode
        self.files = files
        # for folder in folders:
        #    self.files.extend(sorted(glob.glob("%s/%s/*.png" %(root,folder))))
        # self.files = self.files[:-4000] if mode == "train" else self.files[-4000:]

    def apply_random_mask(self, img):
        """Randomly masks image"""
        y1, x1 = np.random.randint(0, self.img_size - self.mask_size, 2)
        y2, x2 = y1 + self.mask_size, x1 + self.mask_size
        masked_part = img[:, y1:y2, x1:x2]
        masked_img = img.clone()
        masked_img[:, y1:y2, x1:x2] = 1

        return masked_img, masked_part

    def apply_center_mask(self, img):
        """Mask center part of image"""
        # Get upper-left pixel coordinate
        i = (self.img_size - self.mask_size) // 2
        masked_img = img.clone()
        masked_part = masked_img[:, i : i + self.mask_size, i : i + self.mask_size]
        masked_img[:, i : i + self.mask_size, i : i + self.mask_size] = 1

        return masked_img, masked_part, i

    def __getitem__(self, index):

        img = Image.open(self.files[index % len(self.files)])
        img = self.transform(img)
        if(self.mode=="train"):
            masked_img, aux = self.apply_random_mask(img)
            return img, masked_img, aux
        else:
            masked_img, masked_part, i = self.apply_center_mask(img)
            return img, masked_img, masked_part, i

        

    def __len__(self):
        return len(self.files)

In [14]:
transforms_ = [
    transforms.Resize((128, 128), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
trainDL = DataLoader(
    ImageDataset(files=trainImage, transforms_=transforms_, mode="train"),
    batch_size=12,
    shuffle=True,
    num_workers=1,
)
valDL = DataLoader(
    ImageDataset(files=valImage, transforms_=transforms_, mode="train"),
    batch_size=12,
    shuffle=True,
    num_workers=1,
)
testDL = DataLoader(
    ImageDataset(files=testImage, transforms_=transforms_, mode="test"),
    batch_size=12,
    shuffle=False,
    num_workers=1,
)

In [15]:
class Generator(nn.Module):
    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, normalize=True):
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.ReLU())
            return layers

        self.model = nn.Sequential(
            *downsample(channels, 64, normalize=False),
            *downsample(64, 64),
            *downsample(64, 128),
            *downsample(128, 256),
            *downsample(256, 512),
            nn.Conv2d(512, 4000, 1),
            *upsample(4000, 512),
            *upsample(512, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            nn.Conv2d(64, channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = channels
        for out_filters, stride, normalize in [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [17]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

# def save_sample(batches_done):
    

    
# Loss function
adversarial_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize generator and discriminator
generator = Generator(channels=3)
discriminator = Discriminator(channels=3)


generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
pixelwise_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

Tensor = torch.cuda.FloatTensor

In [18]:
patch_h, patch_w = int(64 / 2 ** 3), int(64 / 2 ** 3)
patch = (1, patch_h, patch_w)

In [19]:
best_loss = np.inf

In [12]:
checkpoint = torch.load('saved_models/checkpoint_0.75_data.ckpt')
generator.load_state_dict(checkpoint['generator'])
discriminator.load_state_dict(checkpoint['discriminator'])
optimizer_G.load_state_dict(checkpoint['optimizer_G'])
optimizer_D.load_state_dict(checkpoint['optimizer_D'])
best_loss = checkpoint['loss']
e = checkpoint['epoch']

In [21]:
gen_adv_losses, gen_pixel_losses, disc_losses, counter = [], [], [], []


for epoch in range(e+1, 50):
    
    ### Training ###
    generator.train()
    discriminator.train()
    gen_adv_loss, gen_pixel_loss, disc_loss = 0, 0, 0
    tqdm_bar = tqdm(trainDL, desc=f'Training Epoch {epoch} ', total=int(len(trainDL)))
    for i, (imgs, masked_imgs, masked_parts) in enumerate(tqdm_bar):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)

        # Configure input
        imgs = Variable(imgs.type(Tensor))
        masked_imgs = Variable(masked_imgs.type(Tensor))
        masked_parts = Variable(masked_parts.type(Tensor))

        ## Train Generator ##
        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_parts = generator(masked_imgs)

        # Adversarial and pixelwise loss
        g_adv = adversarial_loss(discriminator(gen_parts), valid)
        # print(gen_parts.shape, masked_parts.shape)
        g_pixel = pixelwise_loss(gen_parts, masked_parts)
        # Total loss
        g_loss = 0.001 * g_adv + 0.999 * g_pixel

        g_loss.backward()
        optimizer_G.step()

        ## Train Discriminator ##
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(masked_parts), valid)
        fake_loss = adversarial_loss(discriminator(gen_parts.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()
        optimizer_D.step()
        
        gen_adv_loss, gen_pixel_loss, disc_loss
        gen_adv_losses, gen_pixel_losses, disc_losses, counter
        
        gen_adv_loss += g_adv.item()
        gen_pixel_loss += g_pixel.item()
        gen_adv_losses.append(g_adv.item())
        gen_pixel_losses.append(g_pixel.item())
        disc_loss += d_loss.item()
        disc_losses.append(d_loss.item())
        counter.append(i*12 + imgs.size(0) + epoch*len(trainDL.dataset))
        tqdm_bar.set_postfix(gen_adv_loss=gen_adv_loss/(i+1), gen_pixel_loss=gen_pixel_loss/(i+1), disc_loss=disc_loss/(i+1))
        
        # Generate sample at sample interval
        batches_done = epoch * len(trainDL) + i
#         if batches_done % 1000 == 0:
#             save_sample(batches_done)
    
    with torch.no_grad():
        generator.eval()
        discriminator.eval()
        adv_val_loss, pixel_val_loss, disc_val_loss = 0, 0, 0
        val_tqdm_bar = tqdm(valDL, desc=f'Val Epoch {epoch} ', total=int(len(valDL)))
        for i, (imgs, masked_imgs, masked_parts) in enumerate(val_tqdm_bar):

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)

            # Configure input
            imgs = Variable(imgs.type(Tensor))
            masked_imgs = Variable(masked_imgs.type(Tensor))
            masked_parts = Variable(masked_parts.type(Tensor))

            # Generate a batch of images
            gen_parts = generator(masked_imgs)

            # Adversarial and pixelwise loss
            g_adv = adversarial_loss(discriminator(gen_parts), valid)
            g_pixel = pixelwise_loss(gen_parts, masked_parts)
            # Total loss
            g_loss = 0.001 * g_adv + 0.999 * g_pixel

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(masked_parts), valid)
            fake_loss = adversarial_loss(discriminator(gen_parts.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)

        
            adv_val_loss += g_adv.item()
            pixel_val_loss += g_pixel.item()
            disc_val_loss += d_loss.item()
            val_tqdm_bar.set_postfix(adv_val_loss=adv_val_loss/(i+1), pixel_val_loss=pixel_val_loss/(i+1), disc_val_loss=disc_val_loss/(i+1))
        print("Epoch: %d, val adv loss: %f, val pixel loss: %f, val disc loss: %f"%(epoch, adv_val_loss, pixel_val_loss, disc_val_loss))

        if((0.001 * adv_val_loss + 0.999 * pixel_val_loss) < best_loss):
            best_loss = 0.001 * adv_val_loss + 0.999 * pixel_val_loss
            PATH = 'saved_models/checkpoint_0.75_data.ckpt'
            torch.save({
                       'generator': generator.state_dict(),
                       'discriminator': discriminator.state_dict(),
                       'optimizer_G':optimizer_G.state_dict(),
                       'optimizer_D':optimizer_D.state_dict(),
                       'epoch': epoch,
                       'loss': best_loss
                       }, PATH)
            # torch.save(generator.state_dict(), "saved_models/generator.pth")
            # torch.save(discriminator.state_dict(), "saved_models/discriminator.pth")

Training Epoch 0 :   0%|          | 0/3450 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/kushagraagrawal/anaconda3/envs/newbase/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/kushagraagrawal/anaconda3/envs/newbase/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'ImageDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 

In [15]:
# Load best model
import cv2
def normalize_img(img):
    norm_image = cv2.normalize(img, None, alpha = 0, beta = 255, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
    norm_image = norm_image.astype(np.uint8)
    return norm_image

with torch.no_grad():
    checkpoint = torch.load('saved_models/checkpoint_0.75_data.ckpt')
    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])
    generator.eval()
    discriminator.eval()

    test_tqdm_bar = tqdm(testDL, desc=f'Test Epoch 1 ', total=int(len(testDL)))
    pixel_test_loss = 0
    for step, (imgs, masked_imgs, masked_parts, i) in enumerate(test_tqdm_bar):

#     samples, masked_samples, i = next(iter(testDL))
        samples = Variable(imgs.type(Tensor))
        masked_samples = Variable(masked_imgs.type(Tensor))
        masked_parts = Variable(masked_parts.type(Tensor))
        i = i[0].item()  # Upper-left coordinate of mask
        # Generate inpainted image
        gen_mask = generator(masked_samples)
    
        g_pixel = pixelwise_loss(gen_mask, masked_parts)
        # print(gen_mask.shape, coords)
        filled_samples = masked_samples.clone()
        filled_samples[:, :, i : i + 64, i : i + 64] = gen_mask
        # Save sample
        sample = torch.cat((masked_samples.data, filled_samples.data, samples.data), -2)
        psnr = 0
        ssim = 0
        for i in range(12):
            pred = normalize_img(filled_samples[i].permute(1,2,0).cpu().numpy())
            gt = normalize_img(samples[i].permute(1,2,0).cpu().numpy())
            # print(pred.shape, gt.shape)
            psnr_score = sewar.psnr(pred,gt)
            # print(psnr_score)
            #ssim
            ssim_score = sewar.ssim(pred,gt)[0]
            psnr += psnr_score
            ssim += ssim_score
        print("psnr_score: %f, ssim_score: %f"%(psnr/12, ssim/12))
        pixel_test_loss += g_pixel.item()
        if(step % 100 == 0):
            save_image(sample, "images/%d_75.png" % step, nrow=6, normalize=True)
    print("final pixel loss: {:.4f}".format(pixel_test_loss/int(len(testDL))))

Test Epoch 1 :   0%|          | 0/584 [00:00<?, ?it/s]

psnr_score: 25.773566, ssim_score: 0.879482
psnr_score: 25.377148, ssim_score: 0.874790
psnr_score: 25.435174, ssim_score: 0.880449
psnr_score: 24.378827, ssim_score: 0.862627
psnr_score: 26.267398, ssim_score: 0.881934
psnr_score: 25.207904, ssim_score: 0.881557
psnr_score: 26.333401, ssim_score: 0.891545
psnr_score: 24.858086, ssim_score: 0.876712
psnr_score: 25.638654, ssim_score: 0.875017
psnr_score: 24.545486, ssim_score: 0.869624
psnr_score: 25.164248, ssim_score: 0.879841
psnr_score: 26.072815, ssim_score: 0.888910
psnr_score: 25.204582, ssim_score: 0.873574
psnr_score: 24.860844, ssim_score: 0.876341
psnr_score: 25.655674, ssim_score: 0.876466
psnr_score: 24.571022, ssim_score: 0.873553
psnr_score: 25.687770, ssim_score: 0.869561
psnr_score: 24.185347, ssim_score: 0.861311
psnr_score: 23.971751, ssim_score: 0.855961
psnr_score: 24.023803, ssim_score: 0.864102
psnr_score: 24.645585, ssim_score: 0.869901
psnr_score: 24.857900, ssim_score: 0.877855
psnr_score: 24.877574, ssim_scor

psnr_score: 26.330039, ssim_score: 0.876460
psnr_score: 25.947574, ssim_score: 0.872745
psnr_score: 23.625593, ssim_score: 0.867077
psnr_score: 26.361006, ssim_score: 0.882490
psnr_score: 25.221725, ssim_score: 0.870087
psnr_score: 24.809656, ssim_score: 0.873340
psnr_score: 25.030916, ssim_score: 0.868861
psnr_score: 25.465862, ssim_score: 0.868168
psnr_score: 25.363409, ssim_score: 0.875014
psnr_score: 25.186649, ssim_score: 0.884772
psnr_score: 23.297334, ssim_score: 0.844888
psnr_score: 25.306413, ssim_score: 0.870745
psnr_score: 23.963944, ssim_score: 0.859380
psnr_score: 24.196656, ssim_score: 0.861481
psnr_score: 23.340527, ssim_score: 0.860390
psnr_score: 26.176173, ssim_score: 0.879665
psnr_score: 25.460982, ssim_score: 0.872411
psnr_score: 24.406899, ssim_score: 0.880058
psnr_score: 24.449274, ssim_score: 0.881711
psnr_score: 25.199332, ssim_score: 0.867236
psnr_score: 24.687757, ssim_score: 0.865710
psnr_score: 25.595474, ssim_score: 0.875868
psnr_score: 24.711742, ssim_scor

psnr_score: 25.358484, ssim_score: 0.875778
psnr_score: 25.511493, ssim_score: 0.872808
psnr_score: 24.021179, ssim_score: 0.857626
psnr_score: 24.605205, ssim_score: 0.881840
psnr_score: 25.044654, ssim_score: 0.866119
psnr_score: 25.118880, ssim_score: 0.869683
psnr_score: 22.872596, ssim_score: 0.857620
psnr_score: 25.055624, ssim_score: 0.863633
psnr_score: 25.294164, ssim_score: 0.868477
psnr_score: 23.373371, ssim_score: 0.845203
psnr_score: 23.819467, ssim_score: 0.870052
psnr_score: 25.348655, ssim_score: 0.869713
psnr_score: 24.343320, ssim_score: 0.874778
psnr_score: 25.506458, ssim_score: 0.878996
psnr_score: 25.872450, ssim_score: 0.875548
psnr_score: 25.144040, ssim_score: 0.888900
psnr_score: 24.097496, ssim_score: 0.861122
psnr_score: 25.064972, ssim_score: 0.877852
psnr_score: 25.662737, ssim_score: 0.885109
psnr_score: 24.903539, ssim_score: 0.861173
psnr_score: 24.753197, ssim_score: 0.882566
psnr_score: 25.598601, ssim_score: 0.870610
psnr_score: 25.199219, ssim_scor

psnr_score: 24.755054, ssim_score: 0.869547
psnr_score: 24.890571, ssim_score: 0.870262
psnr_score: 23.687698, ssim_score: 0.866114
psnr_score: 24.053134, ssim_score: 0.864823
psnr_score: 25.246977, ssim_score: 0.878193
psnr_score: 24.539676, ssim_score: 0.864724
psnr_score: 24.937000, ssim_score: 0.875284
psnr_score: 26.062562, ssim_score: 0.883664
psnr_score: 25.200018, ssim_score: 0.868201
psnr_score: 24.391823, ssim_score: 0.858853
psnr_score: 24.713842, ssim_score: 0.867012
psnr_score: 25.297549, ssim_score: 0.868810
psnr_score: 25.544650, ssim_score: 0.880071
psnr_score: 24.472348, ssim_score: 0.872732
psnr_score: 25.268668, ssim_score: 0.871984
psnr_score: 25.348591, ssim_score: 0.881878
psnr_score: 24.688783, ssim_score: 0.881149
psnr_score: 23.738825, ssim_score: 0.865069
psnr_score: 25.664326, ssim_score: 0.881375
psnr_score: 23.975125, ssim_score: 0.859770
psnr_score: 24.835357, ssim_score: 0.868103
psnr_score: 25.359298, ssim_score: 0.891104


IndexError: index 4 is out of bounds for dimension 0 with size 4