In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm 
import numpy as np
import os
import re
from easydict import EasyDict as edict
from PIL import Image
from skimage import io, transform
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

In [None]:
# Input directories of the paired dataset: masked faces (generator input)
# and the matching unmasked faces (generator target).
mask_path = '../input/face-mask-lite-dataset/with_mask'
face_path = '../input/face-mask-lite-dataset/without_mask'

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Hyperparameters shared by the datasets, the models and the training loops.
args = edict(
    EPOCHS=10,        # passes over the training split per training cell
    BATCH_SIZE=50,    # samples per DataLoader batch
    LR=0.0002,        # Adam learning rate (generator and discriminator)
    B1=0.5,           # Adam beta1
    B2=0.999,         # Adam beta2
    N_CPU=9,          # NOTE(review): not passed to any DataLoader in this notebook
    LATENT_DIM=100,   # width of the encoder output fed to the generator's linear layer
    IMG_SIZE=256,     # images are resized to IMG_SIZE x IMG_SIZE
    CHANNELS=3,       # image channels expected by the conv stacks
    NUM_IMG=10000,    # nominal number of image pairs; only used to derive TRAINING_SIZE
)
# First 90% of the (sorted) files are used for training, the rest for testing.
args.TRAINING_SIZE = int(0.9 * args.NUM_IMG)

In [None]:
class _FacePairDataset(Dataset):
    """Shared implementation of the paired (face, mask) image datasets.

    File names of both directories are sorted in natural (alphanumeric) order
    and paired by position, so the i-th face file is returned together with
    the i-th mask file. Subclasses only choose which slice of the sorted
    listing they expose.

    Args:
        face_path: directory containing the unmasked face images.
        mask_path: directory containing the masked face images.
        start: first index (inclusive) of the sorted listing to keep.
        stop: last index (exclusive) of the sorted listing to keep, or None
            to keep everything from ``start`` onwards.
    """

    def __init__(self, face_path, mask_path, start=None, stop=None):
        self.transforms = transforms.Compose([
            transforms.Resize([args.IMG_SIZE, args.IMG_SIZE]),
            transforms.ToTensor(),
        ])

        self.face_path = face_path
        self.mask_path = mask_path

        face_files = self._sorted_alphanumeric(os.listdir(face_path))[start:stop]
        mask_files = self._sorted_alphanumeric(os.listdir(mask_path))[start:stop]

        # Keep only indices that exist in both listings so that __getitem__
        # can never run past the end of the shorter directory.
        n_pairs = min(len(face_files), len(mask_files))
        self.face_file = face_files[:n_pairs]
        self.mask_file = mask_files[:n_pairs]

    @staticmethod
    def _sorted_alphanumeric(data):
        """Sort names so that embedded numbers compare numerically (2 < 10)."""
        def natural_key(name):
            # re.split with a capturing group alternates text / digit chunks,
            # so ints are only ever compared with ints at the same position.
            return [int(chunk) if chunk.isdigit() else chunk.lower()
                    for chunk in re.split('([0-9]+)', name)]
        return sorted(data, key=natural_key)

    def _load(self, directory, filename):
        """Open one image, force 3-channel RGB and apply the resize/tensor transform."""
        # The context manager closes the file handle; convert() yields RGB so
        # grayscale / RGBA files still match the 3-channel models.
        with Image.open(os.path.join(directory, filename)) as img:
            return self.transforms(img.convert('RGB'))

    def __len__(self):
        return len(self.face_file)

    def __getitem__(self, idx):
        """Return the ``(face_image, mask_image)`` tensor pair at position ``idx``."""
        face_image = self._load(self.face_path, self.face_file[idx])
        mask_image = self._load(self.mask_path, self.mask_file[idx])

        return (face_image, mask_image)


class FaceTrainDataset(_FacePairDataset):
    """Training split: the first ``args.TRAINING_SIZE`` pairs of the sorted listing."""

    def __init__(self, face_path, mask_path):
        super().__init__(face_path, mask_path, stop=args.TRAINING_SIZE)


class FaceTestDataset(_FacePairDataset):
    """Test split: every pair after the first ``args.TRAINING_SIZE`` of the sorted listing."""

    def __init__(self, face_path, mask_path):
        super().__init__(face_path, mask_path, start=args.TRAINING_SIZE)

In [None]:
# Build the train / test splits and their batch loaders.
train_dataset = FaceTrainDataset(face_path=face_path, mask_path=mask_path)
# Shuffle the training pairs so every epoch does not replay the files in the
# same sorted order; each item is a (face, mask) tuple, so pairs stay aligned.
train_dataloader = DataLoader(train_dataset, batch_size=args.BATCH_SIZE, shuffle=True)
test_dataset = FaceTestDataset(face_path=face_path, mask_path=mask_path)
# Evaluation keeps the deterministic sorted order.
test_dataloader = DataLoader(test_dataset, batch_size=args.BATCH_SIZE)

In [None]:
# Preview the first four training items in a single row: even positions show
# the masked image of the pair, odd positions show the unmasked face.
for i in range(len(train_dataset)):
    sample = train_dataset[i]
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')

    # sample is (face_image, mask_image); pick by parity of the index.
    chosen = sample[0] if i % 2 else sample[1]
    # CHW tensor -> HWC array, the layout imshow expects.
    sample_img = np.transpose(chosen.cpu().detach().numpy(), (1, 2, 0))
    plt.imshow(sample_img)

    if i == 3:
        plt.show()
        break

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to a latent vector.

    Four stride-2 convolution stages (each halving height and width) are
    followed by a linear layer of width ``args.LATENT_DIM`` and a sigmoid.
    """

    def __init__(self):
        super().__init__()

        def down_block(c_in, c_out, use_bn=True):
            """Conv (stride 2) -> LeakyReLU -> Dropout2d [-> BatchNorm2d]."""
            layers = [
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if use_bn:
                # NOTE(review): the second positional argument of BatchNorm2d
                # is eps, so this sets eps=0.8 (not momentum) - confirm intent.
                layers.append(nn.BatchNorm2d(c_out, eps=0.8))
            return layers

        # Channel widths of consecutive stages; the first stage has no batch norm.
        widths = [args.CHANNELS, 16, 32, 64, 128]
        stages = []
        for stage_idx in range(len(widths) - 1):
            stages.extend(
                down_block(widths[stage_idx], widths[stage_idx + 1], use_bn=stage_idx > 0)
            )
        self.model = nn.Sequential(*stages)

        # Spatial size after four halvings of the input resolution.
        ds_size = args.IMG_SIZE // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, args.LATENT_DIM),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Encode ``img`` and return a ``(batch, args.LATENT_DIM)`` tensor."""
        features = self.model(img)
        # Flatten everything except the batch dimension for the linear layer.
        flat = features.view(features.size(0), -1)
        return self.adv_layer(flat)
        

class Generator(nn.Module):
    """Encoder-decoder generator: image -> latent vector -> image.

    The input is compressed by an ``Encoder``, expanded by a linear layer to a
    128-channel feature map of side ``IMG_SIZE // 4``, then upsampled twice by
    a factor of 2 and projected to ``args.CHANNELS`` channels through a Tanh.
    """

    def __init__(self):
        super().__init__()

        self.encoder = Encoder().to(device)
        # Side length of the feature map before the two 2x upsampling steps.
        self.init_size = args.IMG_SIZE // 4
        self.l1 = nn.Sequential(
            nn.Linear(args.LATENT_DIM, 128 * self.init_size ** 2)
        )

        # NOTE(review): the 0.8 below is BatchNorm2d's eps (second positional
        # parameter), not its momentum - confirm intent.
        decoder_layers = [
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, args.CHANNELS, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Translate the image batch ``x`` into a generated image batch."""
        latent = self.encoder(x)
        expanded = self.l1(latent)
        # Reshape the flat activations into a (batch, 128, init_size, init_size) map.
        feature_map = expanded.view(expanded.size(0), 128, self.init_size, self.init_size)
        return self.conv_blocks(feature_map)


class Discriminator(nn.Module):
    """Convolutional critic that scores an image batch with one sigmoid output each.

    Four stride-2 convolution stages are followed by a single-unit linear
    layer and a sigmoid, giving a value in (0, 1) per image.
    """

    def __init__(self):
        super().__init__()

        def down_block(c_in, c_out, use_bn=True):
            """Conv (stride 2) -> LeakyReLU -> Dropout2d [-> BatchNorm2d]."""
            layers = [
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if use_bn:
                # NOTE(review): the second positional argument of BatchNorm2d
                # is eps, so this sets eps=0.8 (not momentum) - confirm intent.
                layers.append(nn.BatchNorm2d(c_out, eps=0.8))
            return layers

        # Channel widths of consecutive stages; the first stage has no batch norm.
        widths = [args.CHANNELS, 16, 32, 64, 128]
        stages = []
        for stage_idx in range(len(widths) - 1):
            stages.extend(
                down_block(widths[stage_idx], widths[stage_idx + 1], use_bn=stage_idx > 0)
            )
        self.model = nn.Sequential(*stages)

        # Spatial size after four halvings of the input resolution.
        ds_size = args.IMG_SIZE // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of validity scores for ``img``."""
        features = self.model(img)
        # Flatten everything except the batch dimension for the linear layer.
        flat = features.view(features.size(0), -1)
        return self.adv_layer(flat)

In [None]:
# Loss functions: BCE for the adversarial term, MSE / L1 for pixel-wise comparison.
adversarial_loss = nn.BCELoss().to(device)
mse_loss = nn.MSELoss().to(device)
l1_loss = nn.L1Loss().to(device)

# Instantiate both networks on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# One Adam optimizer per network, sharing the same learning rate and betas.
_adam_betas = (args.B1, args.B2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.LR, betas=_adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.LR, betas=_adam_betas)

In [None]:
# Function to initialize weights
def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [None]:
# weight initialization
# `apply` visits every submodule recursively and calls `_weights_init` on it,
# modifying the parameters in place. The last expression's return value (the
# discriminator module) is what the notebook displays as this cell's output.
generator.apply(_weights_init)
discriminator.apply(_weights_init)

In [None]:
# Create the checkpoint directory; exist_ok=True keeps the cell re-runnable
# when the directory is already there (plain makedirs raises FileExistsError).
os.makedirs('./saved_model', exist_ok=True)

In [None]:
# Weight of the pixel-wise reconstruction term in the generator loss.
# The training loop multiplies `mse_loss(gen_imgs, Y_imgs)` by this value,
# i.e. the weighted term is an MSE loss (not L1) between generated and real images.
lambda_pixel = 1

In [None]:
# Train generator and discriminator for args.EPOCHS epochs with the current
# `lambda_pixel`, recording the mean (generator + discriminator) loss per epoch
# on the training and on the held-out split, then save both state dicts.
val_loss = []
train_loss = []
for epoch in range(args.EPOCHS):
    val_loss_ = 0
    train_loss_ = 0

    # ---------------- training pass ----------------
    generator.train()
    discriminator.train()
    for i, (face_imgs, mask_imgs) in enumerate(train_dataloader):
        batch_size = face_imgs.shape[0]
        # Adversarial targets: 1 for real images, 0 for generated ones.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)
        # Y = unmasked target faces, X = masked inputs.
        Y_imgs = face_imgs.to(device=device, dtype=torch.float32)
        X_imgs = mask_imgs.to(device=device, dtype=torch.float32)

        # Generator step: fool the discriminator and stay close to the target.
        optimizer_G.zero_grad()
        gen_imgs = generator(X_imgs)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid) + lambda_pixel*mse_loss(gen_imgs, Y_imgs)
        g_loss.backward()
        optimizer_G.step()

        # Discriminator step: gen_imgs is detached so no gradient reaches the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(Y_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        train_loss_ += g_loss.item() + d_loss.item()
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, args.EPOCHS, i, len(train_dataloader), d_loss.item(), g_loss.item())
        )

        # Every 20 batches show the first generated image next to its target.
        if i%20 == 0:
            ax = plt.subplot(2, 2, 1)
            plt.tight_layout()
            ax.set_title('Sample #{}'.format(i))
            ax.axis('off')
            sample_img = np.transpose(gen_imgs[0].cpu().detach().numpy(), (1,2,0))
            plt.imshow(sample_img)

            ax = plt.subplot(2, 2, 2)
            plt.tight_layout()
            ax.set_title('Sample #{}'.format(i))
            ax.axis('off')
            sample_img = np.transpose(face_imgs[0].cpu().detach().numpy(), (1,2,0))
            plt.imshow(sample_img)

            plt.show()

    train_loss_ /= len(train_dataloader)

    # ---------------- validation pass ----------------
    generator.eval()
    discriminator.eval()
    # No parameter update happens here, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for i, (face_imgs, mask_imgs) in enumerate(test_dataloader):
            batch_size = face_imgs.shape[0]
            # Build the targets for THIS batch instead of reusing the tensors
            # left over from the last training batch (their size may differ).
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)
            Y_imgs = face_imgs.to(device=device, dtype=torch.float32)
            X_imgs = mask_imgs.to(device=device, dtype=torch.float32)

            gen_imgs = generator(X_imgs)

            g_loss = adversarial_loss(discriminator(gen_imgs), valid) + lambda_pixel*mse_loss(gen_imgs, Y_imgs)
            real_loss = adversarial_loss(discriminator(Y_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
            d_loss = (real_loss + fake_loss) / 2

            val_loss_ += g_loss.item() + d_loss.item()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, args.EPOCHS, i, len(test_dataloader), d_loss.item(), g_loss.item())
            )

    val_loss_ /= len(test_dataloader)

    train_loss.append(train_loss_)
    val_loss.append(val_loss_)


# Persist the trained weights, tagged with the lambda_pixel used for this run.
torch.save(generator.state_dict(), "./saved_model/dcgan_generator_lambda{}.pth".format(lambda_pixel))
torch.save(discriminator.state_dict(), "./saved_model/dcgan_discriminator_lambda{}.pth".format(lambda_pixel))

In [None]:
# Weight of the pixel-wise reconstruction term in the generator loss.
# The training loop multiplies `mse_loss(gen_imgs, Y_imgs)` by this value,
# i.e. the weighted term is an MSE loss (not L1) between generated and real images.
lambda_pixel = 100

In [None]:
# Train generator and discriminator for args.EPOCHS epochs with the current
# `lambda_pixel`, recording the mean (generator + discriminator) loss per epoch
# on the training and on the held-out split, then save both state dicts.
# NOTE(review): this cell does not re-create or re-initialize the networks or
# their optimizers, so when run after an earlier training cell it continues
# from those weights - confirm that this warm start is intended.
val_loss = []
train_loss = []
for epoch in range(args.EPOCHS):
    val_loss_ = 0
    train_loss_ = 0

    # ---------------- training pass ----------------
    generator.train()
    discriminator.train()
    for i, (face_imgs, mask_imgs) in enumerate(train_dataloader):
        batch_size = face_imgs.shape[0]
        # Adversarial targets: 1 for real images, 0 for generated ones.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)
        # Y = unmasked target faces, X = masked inputs.
        Y_imgs = face_imgs.to(device=device, dtype=torch.float32)
        X_imgs = mask_imgs.to(device=device, dtype=torch.float32)

        # Generator step: fool the discriminator and stay close to the target.
        optimizer_G.zero_grad()
        gen_imgs = generator(X_imgs)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid) + lambda_pixel*mse_loss(gen_imgs, Y_imgs)
        g_loss.backward()
        optimizer_G.step()

        # Discriminator step: gen_imgs is detached so no gradient reaches the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(Y_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        train_loss_ += g_loss.item() + d_loss.item()
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, args.EPOCHS, i, len(train_dataloader), d_loss.item(), g_loss.item())
        )

        # Every 20 batches show the first generated image next to its target.
        if i%20 == 0:
            ax = plt.subplot(2, 2, 1)
            plt.tight_layout()
            ax.set_title('Sample #{}'.format(i))
            ax.axis('off')
            sample_img = np.transpose(gen_imgs[0].cpu().detach().numpy(), (1,2,0))
            plt.imshow(sample_img)

            ax = plt.subplot(2, 2, 2)
            plt.tight_layout()
            ax.set_title('Sample #{}'.format(i))
            ax.axis('off')
            sample_img = np.transpose(face_imgs[0].cpu().detach().numpy(), (1,2,0))
            plt.imshow(sample_img)

            plt.show()

    train_loss_ /= len(train_dataloader)

    # ---------------- validation pass ----------------
    generator.eval()
    discriminator.eval()
    # No parameter update happens here, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for i, (face_imgs, mask_imgs) in enumerate(test_dataloader):
            batch_size = face_imgs.shape[0]
            # Build the targets for THIS batch instead of reusing the tensors
            # left over from the last training batch (their size may differ).
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)
            Y_imgs = face_imgs.to(device=device, dtype=torch.float32)
            X_imgs = mask_imgs.to(device=device, dtype=torch.float32)

            gen_imgs = generator(X_imgs)
            g_loss = adversarial_loss(discriminator(gen_imgs), valid) + lambda_pixel*mse_loss(gen_imgs, Y_imgs)
            real_loss = adversarial_loss(discriminator(Y_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
            d_loss = (real_loss + fake_loss) / 2

            val_loss_ += g_loss.item() + d_loss.item()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, args.EPOCHS, i, len(test_dataloader), d_loss.item(), g_loss.item())
            )

    val_loss_ /= len(test_dataloader)

    train_loss.append(train_loss_)
    val_loss.append(val_loss_)


# Persist the trained weights, tagged with the lambda_pixel used for this run.
torch.save(generator.state_dict(), "./saved_model/dcgan_generator_lambda{}.pth".format(lambda_pixel))
torch.save(discriminator.state_dict(), "./saved_model/dcgan_discriminator_lambda{}.pth".format(lambda_pixel))

In [None]:
import pandas as pd

# Store the per-epoch loss curves as a two-column CSV (one row per epoch).
# `train_loss` / `val_loss` are whatever the most recently executed training
# cell left behind, since each training cell starts them from empty lists.
loss_columns = {'train_loss': train_loss, 'val_loss': val_loss}
loss_df = pd.DataFrame.from_dict(loss_columns)
loss_df.to_csv(path_or_buf='./saved_model/dcgan_loss.csv', index=False)

In [None]:
# Load the generator checkpoint trained with lambda_pixel = 1.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
generator.load_state_dict(
    torch.load('../input/dcgan-generator/dcgan_generator_lambda1.pth', map_location=device)
)

In [None]:
# Show the generator output for one held-out masked image (left) next to the
# corresponding unmasked face from the dataset (right).
generator.eval()
sample_idx = 0  # index into test_dataset; also used for the panel titles
sample = test_dataset[sample_idx]

# sample is (face_image, mask_image); the generator takes the masked image.
mask_img = sample[1].to(device=device, dtype=torch.float32)
with torch.no_grad():  # inference only, no autograd graph needed
    gen_img = generator(mask_img.unsqueeze(0))  # add the batch dimension

ax = plt.subplot(1, 2, 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
# Drop the batch dimension and go from CHW to HWC for imshow.
sample_img = np.transpose(gen_img.cpu().detach().numpy().squeeze(0), (1,2,0))
plt.imshow(sample_img)

ax = plt.subplot(1, 2, 2)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
sample_img = np.transpose(sample[0].cpu().detach().numpy(), (1,2,0))
plt.imshow(sample_img)

plt.show()

In [None]:
# Show the generator output for one held-out masked image (left) next to the
# corresponding unmasked face from the dataset (right).
generator.eval()
sample_idx = 1  # index into test_dataset; also used for the panel titles
sample = test_dataset[sample_idx]

# sample is (face_image, mask_image); the generator takes the masked image.
mask_img = sample[1].to(device=device, dtype=torch.float32)
with torch.no_grad():  # inference only, no autograd graph needed
    gen_img = generator(mask_img.unsqueeze(0))  # add the batch dimension

ax = plt.subplot(1, 2, 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
# Drop the batch dimension and go from CHW to HWC for imshow.
sample_img = np.transpose(gen_img.cpu().detach().numpy().squeeze(0), (1,2,0))
plt.imshow(sample_img)

ax = plt.subplot(1, 2, 2)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
sample_img = np.transpose(sample[0].cpu().detach().numpy(), (1,2,0))
plt.imshow(sample_img)

plt.show()

In [None]:
# Report the generator's reconstruction error on the held-out split:
# the mean over batches of the per-batch L1 and MSE losses.
generator.eval()
l1_loss_value = 0
mse_loss_value = 0
with torch.no_grad():  # metrics only: avoid building autograd graphs per batch
    for i, (face_imgs, mask_imgs) in enumerate(test_dataloader):
        face_imgs = face_imgs.to(device=device, dtype=torch.float32)
        mask_imgs = mask_imgs.to(device=device, dtype=torch.float32)
        gen_imgs = generator(mask_imgs)
        l1_loss_value += l1_loss(gen_imgs, face_imgs).item()
        mse_loss_value += mse_loss(gen_imgs, face_imgs).item()

l1_loss_value /= len(test_dataloader)
mse_loss_value /= len(test_dataloader)
print("Generator with lambda_pixel = 1")
print("L1 loss: {}".format(l1_loss_value))
print("MSE loss: {}".format(mse_loss_value))

In [None]:
# Load the generator checkpoint trained with lambda_pixel = 100.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
generator.load_state_dict(
    torch.load('../input/dcgan-generator/dcgan_generator_lambda100.pth', map_location=device)
)

In [None]:
# Show the generator output for one held-out masked image (left) next to the
# corresponding unmasked face from the dataset (right).
generator.eval()
sample_idx = 0  # index into test_dataset; also used for the panel titles
sample = test_dataset[sample_idx]

# sample is (face_image, mask_image); the generator takes the masked image.
mask_img = sample[1].to(device=device, dtype=torch.float32)
with torch.no_grad():  # inference only, no autograd graph needed
    gen_img = generator(mask_img.unsqueeze(0))  # add the batch dimension

ax = plt.subplot(1, 2, 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
# Drop the batch dimension and go from CHW to HWC for imshow.
sample_img = np.transpose(gen_img.cpu().detach().numpy().squeeze(0), (1,2,0))
plt.imshow(sample_img)

ax = plt.subplot(1, 2, 2)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
sample_img = np.transpose(sample[0].cpu().detach().numpy(), (1,2,0))
plt.imshow(sample_img)

plt.show()

In [None]:
# Show the generator output for one held-out masked image (left) next to the
# corresponding unmasked face from the dataset (right).
generator.eval()
sample_idx = 1  # index into test_dataset; also used for the panel titles
sample = test_dataset[sample_idx]

# sample is (face_image, mask_image); the generator takes the masked image.
mask_img = sample[1].to(device=device, dtype=torch.float32)
with torch.no_grad():  # inference only, no autograd graph needed
    gen_img = generator(mask_img.unsqueeze(0))  # add the batch dimension

ax = plt.subplot(1, 2, 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
# Drop the batch dimension and go from CHW to HWC for imshow.
sample_img = np.transpose(gen_img.cpu().detach().numpy().squeeze(0), (1,2,0))
plt.imshow(sample_img)

ax = plt.subplot(1, 2, 2)
plt.tight_layout()
ax.set_title('Sample #{}'.format(sample_idx))
ax.axis('off')
sample_img = np.transpose(sample[0].cpu().detach().numpy(), (1,2,0))
plt.imshow(sample_img)

plt.show()

In [None]:
# Report the generator's reconstruction error on the held-out split:
# the mean over batches of the per-batch L1 and MSE losses.
generator.eval()
l1_loss_value = 0
mse_loss_value = 0
with torch.no_grad():  # metrics only: avoid building autograd graphs per batch
    for i, (face_imgs, mask_imgs) in enumerate(test_dataloader):
        face_imgs = face_imgs.to(device=device, dtype=torch.float32)
        mask_imgs = mask_imgs.to(device=device, dtype=torch.float32)
        gen_imgs = generator(mask_imgs)
        l1_loss_value += l1_loss(gen_imgs, face_imgs).item()
        mse_loss_value += mse_loss(gen_imgs, face_imgs).item()

l1_loss_value /= len(test_dataloader)
mse_loss_value /= len(test_dataloader)
print("Generator with lambda_pixel = 100")
print("L1 loss: {}".format(l1_loss_value))
print("MSE loss: {}".format(mse_loss_value))