# ELEGANT: Face Attribute Editing
This notebook implements the ELEGANT model for face attribute editing using GANs.

## 1. Setup and Imports

In [None]:
import argparse
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from tensorboardX import SummaryWriter
from torch.nn import Parameter
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

# Report the CUDA toolkit PyTorch was built against and the installed
# torchvision release (torchvision.__version__ is a package version, not a
# CUDA version, so it is labelled accordingly).
print(f"PyTorch CUDA Version: {torch.version.cuda}")
print(f"torchvision Version: {torchvision.__version__}")

## 2. Configuration and Dataset

In [None]:
class Config:
    """Paths and hyper-parameters shared by the dataset and training code.

    Each ``*_dir`` property makes sure its directory exists and then returns
    the path, so callers can read from / write into it without further setup.
    """

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (including parents) if it is missing; return it.

        ``exist_ok=True`` avoids the check-then-create race of calling
        ``os.path.exists`` followed by ``os.makedirs`` (e.g. when several
        DataLoader workers or processes start at the same time).
        """
        os.makedirs(path, exist_ok=True)
        return path

    @property
    def data_dir(self):
        """Root directory of the CelebA data (created if missing)."""
        return self._ensure_dir('./datasets/celebA')

    @property
    def exp_dir(self):
        """Root directory for all experiment output (created if missing)."""
        return self._ensure_dir(os.path.join('train_log'))

    @property
    def model_dir(self):
        """``<exp_dir>/model`` (created if missing)."""
        return self._ensure_dir(os.path.join(self.exp_dir, 'model'))

    @property
    def log_dir(self):
        """``<exp_dir>/log`` (created if missing)."""
        return self._ensure_dir(os.path.join(self.exp_dir, 'log'))

    @property
    def img_dir(self):
        """``<exp_dir>/img`` (created if missing)."""
        return self._ensure_dir(os.path.join(self.exp_dir, 'img'))

    # Model parameters
    # nchw[0] is used as the DataLoader batch size and nchw[-2:] as the
    # resize target (height, width) by the dataset classes.
    nchw = [16,3,256,256]
    # NOTE(review): the values below are not consumed in the visible cells;
    # presumably generator / discriminator optimizer and LR-schedule settings.
    G_lr = 2e-4
    D_lr = 2e-4
    betas = [0.5, 0.999]
    weight_decay = 1e-5
    step_size = 3000
    gamma = 0.97
    # Passed straight to torch.utils.data.DataLoader.
    shuffle = True
    num_workers = 4
    max_iter = 200000
    # Optional cap on how many images are used; None keeps all of them.
    num_samples = None

config = Config()

In [None]:
class SingleCelebADataset(Dataset):
    """Map-style dataset over a fixed list of image files and their labels.

    Args:
        im_names: Sequence of image file paths.
        labels: Sequence of labels, indexable in step with ``im_names``.
        config: Object providing ``nchw`` (batch size at index 0, target
            height/width as the last two entries), ``shuffle`` and
            ``num_workers``.
    """

    def __init__(self, im_names, labels, config):
        self.im_names = im_names
        self.labels = labels
        self.config = config

    def __len__(self):
        """Return the number of images in this dataset."""
        return len(self.im_names)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is resized to ``config.nchw[-2:]`` and rescaled from
        ToTensor's [0, 1] range to [-1, 1]. The label is mapped through
        ``(x + 1) / 2`` (so -1/+1 labels become 0/1).
        """
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it once the pixels have been decoded. convert('RGB')
        # guarantees three channels even for grayscale/palette/RGBA files.
        with Image.open(self.im_names[idx]) as img:
            image = self.transform(img.convert('RGB')) * 2 - 1
        label = (self.labels[idx] + 1) / 2
        return image, label

    @property
    def transform(self):
        """Resize to the configured height/width, then convert to a tensor."""
        return transforms.Compose([
            transforms.Resize(self.config.nchw[-2:]),
            transforms.ToTensor(),
        ])

    def gen(self):
        """Yield batches forever, restarting the DataLoader after each epoch.

        Raises:
            ValueError: On the first ``next()`` if the dataset holds fewer
                items than one batch. With ``drop_last=True`` such a loader
                produces no batches, so the endless loop below would
                otherwise spin without ever yielding.
        """
        batch_size = self.config.nchw[0]
        if len(self) < batch_size:
            raise ValueError(
                "Dataset has {} images but batch size is {}; "
                "no full batch can be produced".format(len(self), batch_size))
        dataloader = DataLoader(self,
                                batch_size=batch_size,
                                shuffle=self.config.shuffle,
                                num_workers=self.config.num_workers,
                                drop_last=True)
        while True:
            yield from dataloader

    def file_exist(self):
        """Return True if every image path exists.

        Prints the first missing path and returns False as soon as one is
        found.
        """
        for im_name in self.im_names:
            if not os.path.exists(im_name):
                print("File not found:", im_name)
                return False
        return True

In [None]:
class MultiCelebADataset(object):
    """Per-attribute positive/negative batch generators over CelebA.

    Reads ``list_attr_celeba.txt`` from ``config.data_dir``, keeps the columns
    for the requested attributes, and builds one endless batch generator per
    (attribute, polarity) pair.

    Args:
        attributes: Attribute names exactly as they appear in the header line
            of the attribute file.
        config: Configuration object; defaults to the module-level ``config``.

    Attributes set on the instance:
        all_labels: float32 array of shape (num_images, num_attributes).
        im_names: array of image paths, ``align_5p/000001.jpg`` onwards, in
            the same row order as ``all_labels``.
        dict: ``dict[attribute_id][is_positive]`` -> batch generator.

    Raises:
        ValueError: If a requested attribute is not in the file header.
    """

    def __init__(self, attributes, config=config):
        self.attributes = attributes
        self.config = config

        # Resolve once: on Config this is a property that touches the
        # filesystem on every access, and it is needed for every image path.
        data_dir = self.config.data_dir

        with open(os.path.join(data_dir, 'list_attr_celeba.txt'), 'r') as f:
            lines = f.read().strip().split('\n')

        # Line index 1 is the attribute header; data rows start at index 2.
        header = lines[1].split()
        unknown = [name for name in self.attributes if name not in header]
        if unknown:
            raise ValueError(
                "Attributes not found in list_attr_celeba.txt: {}".format(unknown))
        # +1: data rows carry one extra leading column that the header lacks
        # (presumably the image file name).
        col_ids = [header.index(name) + 1 for name in self.attributes]

        # Split every row once instead of once per requested attribute.
        rows = [line.split() for line in lines[2:]]
        # reshape keeps the array 2-D even when there are no rows.
        self.all_labels = np.array(
            [[int(row[col_id]) for col_id in col_ids] for row in rows],
            dtype=np.float32,
        ).reshape(len(rows), len(col_ids))
        self.im_names = np.array([
            os.path.join(data_dir, 'align_5p/{:06d}.jpg'.format(idx + 1))
            for idx in range(len(rows))
        ])

        if self.config.num_samples is not None:
            self.all_labels = self.all_labels[:self.config.num_samples]
            self.im_names = self.im_names[:self.config.num_samples]
        print("Total images:",len(self.im_names))

        self.dict = {i: {True: None, False: None} for i in range(len(self.attributes))}
        for attribute_id in range(len(self.attributes)):
            for is_positive in [True, False]:
                # Labels are compared against +1 (positive) / -1 (negative).
                target = 1 if is_positive else -1
                idxs = np.where(self.all_labels[:, attribute_id] == target)[0]
                im_names = self.im_names[idxs]
                labels = self.all_labels[idxs]
                # gen() is a generator function, so no DataLoader is built
                # until the first batch is actually requested.
                data_gen = SingleCelebADataset(im_names, labels, self.config).gen()
                self.dict[attribute_id][is_positive] = data_gen

    def gen(self, attribute_id, is_positive):
        """Return the batch generator for one attribute index and polarity."""
        return self.dict[attribute_id][is_positive]

    def file_exist(self):
        """Return how many image paths exist, printing each missing one."""
        cnt = 0
        for im_name in self.im_names:
            if not os.path.exists(im_name):
                print("File not found:", im_name)
            else:
                cnt = cnt + 1
        return cnt

## 3. Model Architecture

In [None]:
class NTimesTanh(nn.Module):
    """Tanh activation scaled by a constant factor: ``tanh(x) * N``.

    Args:
        N: Multiplier applied to the tanh output, so values lie in (-N, N).
    """

    def __init__(self, N):
        super().__init__()
        self.tanh = nn.Tanh()
        self.N = N

    def forward(self, x):
        """Apply tanh element-wise and scale the result by ``N``."""
        squashed = self.tanh(x)
        return squashed * self.N

class Normalization(nn.Module):
    """L2-normalise along dim 1, then apply a learnable scalar scale and shift.

    ``alpha`` (initialised to 1) multiplies and ``beta`` (initialised to 0) is
    added to the normalised tensor; both are single-element parameters that
    broadcast over the whole input.
    """

    def __init__(self):
        super(Normalization, self).__init__()
        # Use nn.Parameter so the class relies only on the ``nn`` namespace
        # the other modules in this file already use.
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta  = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``normalize(x, dim=1) * alpha + beta``."""
        x = torch.nn.functional.normalize(x, dim=1)
        return x * self.alpha + self.beta

In [None]:
class Encoder(nn.Module):
    """Five stride-2 convolution stages taking 3 input channels to 512.

    Every stage is ``Conv2d(k=3, s=2, p=1) -> Normalization -> LeakyReLU(0.2)``
    with channel widths 3 -> 64 -> 128 -> 256 -> 512 -> 512.
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256, 512, 512]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, 2, 1, bias=True),
                Normalization(),
                nn.LeakyReLU(negative_slope=0.2),
            ))
        self.main = nn.ModuleList(stages)

        # Weight initialisation: N(0, 0.02) for conv / linear weights,
        # N(1, 0.02) for batch-norm weights, zero bias for batch-norm / linear.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, mean=0, std=0.02)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, mean=1, std=0.02)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.02)
                nn.init.constant_(module.bias, 0)

    def forward(self, x, return_skip=True):
        """Run all stages on ``x``.

        Args:
            x: Input tensor fed to the first convolution.
            return_skip: When True, also return the intermediate outputs.

        Returns:
            ``(code, skip)`` if ``return_skip`` else ``code``, where ``skip``
            lists the outputs of every stage except the last, in order.
        """
        skip = []
        last_index = len(self.main) - 1
        for index, stage in enumerate(self.main):
            x = stage(x)
            if index < last_index:
                skip.append(x)
        return (x, skip) if return_skip else x

In [None]:
class Decoder(nn.Module):
    """Five stride-2 transposed-convolution stages taking 1024 channels to 3.

    The first four stages are
    ``ConvTranspose2d(k=3, s=2, p=1, output_padding=1) -> Normalization -> ReLU``
    with widths 1024 -> 512 -> 256 -> 128 -> 64; the last stage is a bare
    transposed convolution to 3 channels. The result goes through
    ``NTimesTanh(2)``.
    """

    def __init__(self):
        super().__init__()
        widths = [1024, 512, 256, 128, 64]
        stages = [
            nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, 1, bias=True),
                Normalization(),
                nn.ReLU(),
            )
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # Output stage: no normalisation / ReLU before the final activation.
        stages.append(nn.Sequential(
            nn.ConvTranspose2d(64, 3, 3, 2, 1, 1, bias=True),
        ))
        self.main = nn.ModuleList(stages)
        self.activation = NTimesTanh(2)

        # Weight initialisation: N(0, 0.02) for conv / linear weights,
        # N(1, 0.02) for batch-norm weights, zero bias for batch-norm / linear.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, mean=0, std=0.02)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, mean=1, std=0.02)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.02)
                nn.init.constant_(module.bias, 0)

    def forward(self, enc1, enc2, skip=None):
        """Decode two encodings concatenated along the channel dimension.

        Args:
            enc1, enc2: Tensors concatenated on dim 1 before the first stage.
            skip: Optional list of tensors. After stage ``i`` (for
                ``i < len(skip)``) the tensor ``skip[-i-1]`` is added, i.e.
                the list is consumed from its end.

        Returns:
            The final stage's output passed through ``NTimesTanh(2)``.
        """
        x = torch.cat([enc1, enc2], 1)
        for index, stage in enumerate(self.main):
            x = stage(x)
            if skip is not None and index < len(skip):
                x = x + skip[-index-1]
        return self.activation(x)

In [None]:
class Discriminator(nn.Module):
    """Label-conditioned discriminator producing one sigmoid score per image.

    Args:
        n_attributes: Number of label values per image; each is broadcast to
            a full-resolution plane and stacked onto the image channels.
        img_size: Side length the discriminator operates at. Inputs of a
            different size are repeatedly halved by average pooling until
            both spatial dims equal this value.
    """

    def __init__(self, n_attributes, img_size):
        super().__init__()
        self.n_attributes = n_attributes
        self.img_size = img_size

        # Four stride-2 conv stages: (3 + n_attributes) -> 64 -> 128 -> 256 -> 512.
        widths = [3 + n_attributes, 64, 128, 256, 512]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 3, 2, 1, bias=True),
                Normalization(),
                nn.LeakyReLU(negative_slope=0.2),
            ]
        self.conv = nn.Sequential(*layers)

        # Four halvings shrink each side by 16 before flattening.
        side = self.img_size // 16
        self.linear = nn.Sequential(
            nn.Linear(512 * side * side, 1),
            nn.Sigmoid(),
        )
        self.downsample = torch.nn.AvgPool2d(2, stride=2)

        # Weight initialisation: N(0, 0.02) for conv / linear weights,
        # N(1, 0.02) for batch-norm weights, zero bias for batch-norm / linear.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, mean=0, std=0.02)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, mean=1, std=0.02)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.02)
                nn.init.constant_(module.bias, 0)

    def forward(self, image, label):
        """Score ``image`` conditioned on ``label``.

        Args:
            image: 4-D image batch; halved by 2x2 average pooling until its
                last two dims both equal ``img_size``.
            label: Tensor viewable as ``(batch, n_attributes, 1, 1)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid outputs.
        """
        target = (self.img_size, self.img_size)
        while (image.shape[-2], image.shape[-1]) != target:
            image = self.downsample(image)

        batch = image.shape[0]
        height, width = image.shape[2], image.shape[3]
        # Broadcast every label value to a constant plane matching the image.
        label_planes = label.view((batch, self.n_attributes, 1, 1)).expand(
            (batch, self.n_attributes, height, width))

        features = self.conv(torch.cat([image, label_planes], 1))
        flat = features.view(features.shape[0], -1)
        return self.linear(flat)

## 4. Training Setup

In [None]:
# Stand-in for parsed command-line arguments: in a notebook there is no
# argv to parse, so the namespace is populated by hand. It is handed to
# ELEGANT(args) and main(args, model) in the cells below.
args = argparse.Namespace(
    attributes=['Arched_Eyebrows', 'Black_Hair'],  # CelebA attribute names
    gpu=[],
    mode='train',
    restore=None,
    swap=False,
    linear=False,
    matrix=False,
)

In [None]:
# Build the model from the simulated command-line arguments.
# NOTE(review): ELEGANT is not defined in the visible cells of this notebook;
# it must be defined or imported before this cell runs.
print("init model")
model = ELEGANT(args)

## 5. Training Execution

In [None]:
%%time
# Run the entry point with the arguments and model built above; the %%time
# cell magic reports how long the whole cell takes.
# NOTE(review): main is not defined in the visible cells of this notebook;
# it must be defined or imported before this cell runs.
main(args,model)