In [1]:
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm 
import numpy as np
import os
import re
from easydict import EasyDict as edict
from PIL import Image
from skimage import io, transform
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image

In [3]:
# Paired image folders: faces with a mask (generator input) and the same faces without one (target).
face_path, mask_path = (
    '../input/face-mask-lite-dataset/without_mask',
    '../input/face-mask-lite-dataset/with_mask',
)

# Prefer the GPU whenever PyTorch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [4]:
# Hyper-parameters and dataset sizes, kept in one attribute-style dict so every cell reads the same values.
args = edict()
args.EPOCHS = 10
args.BATCH_SIZE = 50
args.LR = 0.0002      # Adam learning rate
args.B1 = 0.5         # Adam beta1
args.B2 = 0.999       # Adam beta2
args.N_CPU = 9        # NOTE(review): not referenced by the cells shown here
args.LATENT_DIM = 100 # NOTE(review): not referenced by the cells shown here (the U-Net has no latent input)
args.IMG_SIZE = 256   # every image is resized to IMG_SIZE x IMG_SIZE
args.CHANNELS = 3
args.NUM_IMG = 10000  # presumably the number of images per folder - TODO confirm against the dataset
# Files [0, TRAINING_SIZE) of each folder form the training split, the remainder the test split.
args.TRAINING_SIZE = int(0.9 * args.NUM_IMG)
args.SEED = 42

# Seed every RNG the notebook uses (weight init, DataLoader shuffling, dropout) so Restart & Run All is reproducible.
torch.manual_seed(args.SEED)
np.random.seed(args.SEED)

In [6]:
def _sorted_alphanumeric(names):
    """Return `names` in natural order, so that embedded integers compare numerically ('2.png' < '10.png')."""
    def convert(text):
        return int(text) if text.isdigit() else text.lower()
    return sorted(names, key=lambda name: [convert(part) for part in re.split('([0-9]+)', name)])


class _FacePairDataset(Dataset):
    """Pairs the n-th file of `face_path` with the n-th file of `mask_path` (both naturally sorted).

    NOTE(review): assumes the n-th unmasked and n-th masked file show the same face - TODO confirm.
    Each item is ``(face_image, mask_image)``, both resized to ``args.IMG_SIZE`` and converted by
    ``ToTensor`` into float tensors of shape (3, IMG_SIZE, IMG_SIZE).
    """

    def __init__(self, face_path, mask_path, file_slice):
        """
        Args:
            face_path: folder of unmasked face images (targets).
            mask_path: folder of masked face images (inputs).
            file_slice: `slice` selecting this split's files from each sorted listing.
        """
        self.transforms = transforms.Compose(
            [transforms.Resize([args.IMG_SIZE, args.IMG_SIZE]),
             transforms.ToTensor()])
        self.face_path = face_path
        self.mask_path = mask_path
        self.face_file = _sorted_alphanumeric(os.listdir(face_path))[file_slice]
        self.mask_file = _sorted_alphanumeric(os.listdir(mask_path))[file_slice]

    def __len__(self):
        # Only complete pairs are addressable; guards against the folders holding different counts.
        return min(len(self.face_file), len(self.mask_file))

    def _load(self, folder, name):
        """Open one image, force 3-channel RGB (PNG may be RGBA/grayscale), and close the file handle."""
        with Image.open(os.path.join(folder, name)) as img:
            return self.transforms(img.convert('RGB'))

    def __getitem__(self, idx):
        face_image = self._load(self.face_path, self.face_file[idx])
        mask_image = self._load(self.mask_path, self.mask_file[idx])
        return (face_image, mask_image)


class FaceTrainDataset(_FacePairDataset):
    """Training split: the first ``args.TRAINING_SIZE`` files of each folder."""

    def __init__(self, face_path, mask_path):
        super().__init__(face_path, mask_path, slice(None, args.TRAINING_SIZE))


class FaceTestDataset(_FacePairDataset):
    """Test split: every file after the first ``args.TRAINING_SIZE`` of each folder."""

    def __init__(self, face_path, mask_path):
        super().__init__(face_path, mask_path, slice(args.TRAINING_SIZE, None))

In [7]:
train_dataset = FaceTrainDataset(face_path=face_path, mask_path=mask_path)
# Shuffle the training split so each epoch sees a different batch composition/order;
# the test split keeps a fixed order so validation numbers are comparable across epochs.
train_dataloader = DataLoader(train_dataset, batch_size=args.BATCH_SIZE, shuffle=True)
test_dataset = FaceTestDataset(face_path=face_path, mask_path=mask_path)
test_dataloader = DataLoader(test_dataset, batch_size=args.BATCH_SIZE)

In [8]:
# Preview up to four training samples: even indices show the masked input, odd indices the unmasked target.
n_preview = min(4, len(train_dataset))
fig, axes = plt.subplots(1, 4)
for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= n_preview:
        continue  # fewer samples than panels: leave the panel blank
    face_img, mask_img = train_dataset[i]
    shown = face_img if i % 2 else mask_img
    ax.imshow(np.transpose(shown.numpy(), (1, 2, 0)))  # CHW tensor -> HWC array for imshow
    ax.set_title('Sample #{}'.format(i))
fig.tight_layout()
plt.show()

In [9]:
class DownSampleConv(nn.Module):
    """Encoder block: Conv2d, then optional BatchNorm2d, then optional LeakyReLU(0.2).

    Paper details:
    - C64-C128-C256-C512-C512-C512-C512-C512
    - All convolutions are 4×4 spatial filters applied with stride 2
    - Convolutions in the encoder downsample by a factor of 2
    """

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        super().__init__()
        # The flags are stored so forward() knows which of the optional layers were created.
        self.activation = activation
        self.batchnorm = batchnorm

        self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)
        if self.batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if self.activation:
            self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out) if self.batchnorm else out
        return self.act(out) if self.activation else out

In [10]:
class UpSampleConv(nn.Module):
    """Decoder block: ConvTranspose2d -> [BatchNorm2d] -> [ReLU] -> [Dropout2d(0.5)].

    With the defaults (kernel 4, stride 2, padding 1) the transposed convolution doubles
    the spatial size: out = (in - 1) * 2 - 2 * 1 + 4 = 2 * in.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        activation=True,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation:
            self.act = nn.ReLU(True)

        if dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)
        # Bug fix: the ReLU built in __init__ was previously never applied, leaving the decoder
        # without a nonlinearity. ReLU has no parameters, so state_dict keys are unchanged, but
        # checkpoints trained before this fix will behave differently and should be retrained.
        if self.activation:
            x = self.act(x)
        if self.dropout:
            x = self.drop(x)
        return x

In [11]:
class Generator(nn.Module):
    """U-Net generator (pix2pix style) mapping an image to an image of the same spatial size.

    Paper details:
    - Encoder: C64-C128-C256-C512-C512-C512-C512-C512
    - All convolutions are 4×4 spatial filters applied with stride 2
    - Convolutions in the encoder downsample by a factor of 2
    - Decoder: CD512-CD1024-CD1024-C1024-C1024-C512 -C256-C128

    The shape comments below assume a 256 x 256 input; the output passes through tanh.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder: eight stride-2 convolutions.
        down_blocks = [
            DownSampleConv(in_channels, 64, batchnorm=False),  # bs x 64 x 128 x 128
            DownSampleConv(64, 128),  # bs x 128 x 64 x 64
            DownSampleConv(128, 256),  # bs x 256 x 32 x 32
            DownSampleConv(256, 512),  # bs x 512 x 16 x 16
            DownSampleConv(512, 512),  # bs x 512 x 8 x 8
            DownSampleConv(512, 512),  # bs x 512 x 4 x 4
            DownSampleConv(512, 512),  # bs x 512 x 2 x 2
            DownSampleConv(512, 512, batchnorm=False),  # bs x 512 x 1 x 1
        ]

        # Decoder: input channels double wherever the previous output was concatenated with a skip.
        up_blocks = [
            UpSampleConv(512, 512, dropout=True),  # bs x 512 x 2 x 2
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 4 x 4
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 8 x 8
            UpSampleConv(1024, 512),  # bs x 512 x 16 x 16
            UpSampleConv(1024, 256),  # bs x 256 x 32 x 32
            UpSampleConv(512, 128),  # bs x 128 x 64 x 64
            UpSampleConv(256, 64),  # bs x 64 x 128 x 128
        ]
        self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
        self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

        # Registered after final_conv/tanh so the submodule order (parameters(), apply()) stays the same.
        self.encoders = nn.ModuleList(down_blocks)
        self.decoders = nn.ModuleList(up_blocks)

    def forward(self, x):
        features = []
        for down in self.encoders:
            x = down(x)
            features.append(x)

        # Skip sources, deepest first; the bottleneck output is the running `x`, not a skip.
        skips = features[-2::-1]
        *skip_decoders, last_decoder = self.decoders

        for up, skip in zip(skip_decoders, skips):
            x = torch.cat((up(x), skip), dim=1)

        # The last decoder feeds final_conv directly; the outermost encoder feature is not concatenated.
        return self.tanh(self.final_conv(last_decoder(x)))

In [12]:
# Model on the selected device, optimised with Adam using the configured learning rate and betas.
generator = Generator(3, 3).to(device)
optimizer_G = torch.optim.Adam(generator.parameters(), betas=(args.B1, args.B2), lr=args.LR)

# Reconstruction criteria: L1 drives training and validation; MSE is additionally reported at evaluation.
l1_loss, mse_loss = nn.L1Loss(), nn.MSELoss()

In [None]:
def _weights_init(m):
    """Re-initialise a single module in place, DCGAN / pix2pix style; meant for ``nn.Module.apply``.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias (if any) = 0.
    - BatchNorm2d: weight (gamma) ~ N(1, 0.02), bias (beta) = 0.
    Any other module is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.BatchNorm2d):
        # gamma must be centred on 1: drawing it from N(0, 0.02) (the previous code) would scale every
        # normalised activation to ~0 and nearly silence each BatchNorm layer at initialisation.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0.0)

In [13]:
# Re-initialise the generator's layers with _weights_init (which handles Conv/ConvTranspose and BatchNorm modules).
# nn.Module.apply visits every submodule before the module itself, so the order of random draws, and hence
# the exact initial weights under a fixed seed, follows that traversal. Keep this call as is.
generator.apply(_weights_init)

In [14]:
# Checkpoint folder for the per-epoch weights; exist_ok makes the cell safe to re-run (no FileExistsError).
os.makedirs('./saved_model', exist_ok=True)

In [None]:
train_loss = []  # mean training L1 loss, one entry per epoch
val_loss = []    # mean validation L1 loss, one entry per epoch
for epoch in range(args.EPOCHS):
    train_loss_ = 0
    val_loss_ = 0

    # ---- training: learn masked face -> unmasked face ----
    generator.train()
    for i, (face_imgs, mask_imgs) in enumerate(train_dataloader):
        face_imgs = face_imgs.to(device, dtype=torch.float32)  # targets
        mask_imgs = mask_imgs.to(device, dtype=torch.float32)  # inputs

        optimizer_G.zero_grad()
        gen_imgs = generator(mask_imgs)
        loss = l1_loss(gen_imgs, face_imgs)
        train_loss_ += loss.item()
        loss.backward()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [loss: %f]"
            % (epoch, args.EPOCHS, i, len(train_dataloader), loss.item())
        )

        if i % 20 == 0:
            # Side-by-side preview: generated face (left) vs. ground-truth unmasked face (right).
            fig, (ax_gen, ax_true) = plt.subplots(1, 2)
            previews = (
                # tanh output can fall outside [0, 1]; clip so imshow shows it without warnings.
                (ax_gen, gen_imgs[0].detach().cpu().numpy().clip(0, 1)),
                (ax_true, face_imgs[0].cpu().numpy()),
            )
            for ax, img in previews:
                ax.imshow(np.transpose(img, (1, 2, 0)))
                ax.set_title('Sample #{}'.format(i))
                ax.axis('off')
            fig.tight_layout()
            plt.show()

    train_loss_ /= len(train_dataloader)

    torch.save(generator.state_dict(), "./saved_model/unet_{}.pth".format(epoch))

    # ---- validation: same direction as training (mask -> face), no autograd graph ----
    generator.eval()
    with torch.no_grad():
        for i, (face_imgs, mask_imgs) in enumerate(test_dataloader):
            face_imgs = face_imgs.to(device, dtype=torch.float32)
            mask_imgs = mask_imgs.to(device, dtype=torch.float32)

            # Bug fix: previously generator(face_imgs) was scored against mask_imgs (the reverse task).
            gen_imgs = generator(mask_imgs)
            loss = l1_loss(gen_imgs, face_imgs)
            val_loss_ += loss.item()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [loss: %f]"
                % (epoch, args.EPOCHS, i, len(test_dataloader), loss.item())
            )

    val_loss_ /= len(test_dataloader)

    train_loss.append(train_loss_)
    val_loss.append(val_loss_)

In [None]:
# NOTE(review): this replaces the weights trained above with an externally provided checkpoint;
# skip this cell to evaluate the freshly trained model instead.
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime (and vice versa).
# torch.load unpickles the file, so only load checkpoints from a trusted source.
generator.load_state_dict(torch.load('../input/vae-unet/unet.pth', map_location=device))

In [None]:
# Compare the generator's reconstruction of test pair 0 (left) with its unmasked ground truth (right).
sample_idx = 0
generator.eval()
face_img, mask_img = test_dataset[sample_idx]
with torch.no_grad():
    gen_img = generator(mask_img.to(device, dtype=torch.float32).unsqueeze(0))

fig, (ax_gen, ax_true) = plt.subplots(1, 2)
# tanh output can fall outside [0, 1]; clip so imshow shows it without warnings.
ax_gen.imshow(np.clip(np.transpose(gen_img.squeeze(0).cpu().numpy(), (1, 2, 0)), 0, 1))
ax_true.imshow(np.transpose(face_img.numpy(), (1, 2, 0)))
for ax in (ax_gen, ax_true):
    # Title uses this cell's own index, not a loop variable leaked from an earlier cell.
    ax.set_title('Sample #{}'.format(sample_idx))
    ax.axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Compare the generator's reconstruction of test pair 1 (left) with its unmasked ground truth (right).
sample_idx = 1
generator.eval()
face_img, mask_img = test_dataset[sample_idx]
with torch.no_grad():
    gen_img = generator(mask_img.to(device, dtype=torch.float32).unsqueeze(0))

fig, (ax_gen, ax_true) = plt.subplots(1, 2)
# tanh output can fall outside [0, 1]; clip so imshow shows it without warnings.
ax_gen.imshow(np.clip(np.transpose(gen_img.squeeze(0).cpu().numpy(), (1, 2, 0)), 0, 1))
ax_true.imshow(np.transpose(face_img.numpy(), (1, 2, 0)))
for ax in (ax_gen, ax_true):
    # Title uses this cell's own index, not a loop variable leaked from an earlier cell.
    ax.set_title('Sample #{}'.format(sample_idx))
    ax.axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Learning curves: per-epoch mean L1 loss on the training and test splits.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(train_loss)
ax.plot(val_loss)
ax.set(title='L1 reconstruction loss per epoch', xlabel='epoch', ylabel='mean L1 loss')
ax.legend(["loss", "val_loss"])
plt.show()

In [None]:
# Final test-split metrics: per-pixel MSE and L1 between generated and ground-truth unmasked faces.
generator.eval()
total_mse_loss = 0.0
total_l1_loss = 0.0
n_images = 0
with torch.no_grad():  # evaluation only: no autograd graph, far less memory
    for face_imgs, mask_imgs in test_dataloader:
        Y_imgs = face_imgs.to(device, dtype=torch.float32)  # targets: unmasked faces
        X_imgs = mask_imgs.to(device, dtype=torch.float32)  # inputs: masked faces
        gen_imgs = generator(X_imgs)
        batch_size = X_imgs.size(0)
        # Losses are batch means; weight by batch size so a short final batch is not over-counted.
        total_mse_loss += mse_loss(gen_imgs, Y_imgs).item() * batch_size
        total_l1_loss += l1_loss(gen_imgs, Y_imgs).item() * batch_size
        n_images += batch_size

print("MSE loss: {}".format(total_mse_loss / n_images))
print("L1 loss: {}".format(total_l1_loss / n_images))