# PFA-GAN

In [1]:
from torch import nn, Tensor, zeros, zeros_like, cat, eye, mean, sum as tsum, arange, cuda, load
from torch.optim import Adam
from torch.utils.data import Dataset, DistributedSampler, DataLoader
from torchvision import transforms
from torchvision.datasets.folder import pil_loader
from collections import OrderedDict
from matplotlib import pyplot as plt

import os
import math
import h5py
import torch
import random
import itertools
import numpy as np
import torch.nn.functional as F
import torchvision

## Data

In [2]:
def age2group(age, age_group):
    if isinstance(age, np.ndarray):
        groups = np.zeros_like(age)
    else:
        groups = zeros_like(age).to(age.device)

    if age_group == 4:
        section = [30, 40, 50]
    elif age_group == 5:
        section = [20, 30, 40, 50]
    elif age_group == 7:
        section = [10, 20, 30, 40, 50, 60]
    else:
        raise NotImplementedError
    
    for i, thresh in enumerate(section, 1):
        groups[age > thresh] = i
        
    return groups[0]

In [3]:
class CACD_Dataset(Dataset):
    def __init__(
            self,
            root_dir='materials',
            dataset_name='cacd',
            age_group=4,
            train=False,
            source=0,
            max_iter=200000,
            batch_size=64,
            transforms=None
    ):
        self.root_dir = root_dir
        self.dataset_name = dataset_name
        self.age_group = age_group
        self.train = train
        self.batch_size = batch_size
        self.max_iter = max_iter
        self.total_pairs = batch_size * max_iter
        self.transforms = transforms

        self._load_meta_data()

        self.mean_ages = np.array(
            [np.mean(self.ages[self.age_groups == i])
            for i in range(self.age_group)]
        ).astype(np.float32)

        self.label_group_images = []
        self.label_group_ages = []

        for i in range(self.age_group):
            self.label_group_images.append(
                self.image_names[self.age_groups == i].tolist())
            self.label_group_ages.append(
                self.ages[self.age_groups == i].astype(np.float32).tolist())

        self.target_labels = np.random.randint(source + 1, self.age_group, self.total_pairs)

        pairs = np.array(list(itertools.combinations(range(age_group), 2)))
        p = [1, 1, 1, 0.5, 0.5, 0.5]
        p = np.array(p) / np.sum(p)
        pairs = pairs[np.random.choice(range(len(pairs)), self.total_pairs, p=p), :]
        source_labels, target_labels = pairs[:, 0], pairs[:, 1]
        self.source_labels = source_labels
        self.target_labels = target_labels

        self.true_labels = np.random.randint(0, self.age_group, self.total_pairs)

    def _load_meta_data(self):
        meta = h5py.File(os.path.join(self.root_dir, f"{self.dataset_name}.mat"), 'r')

        self.ages = meta['celebrityImageData']['age'][0,:]
        self.age_groups = np.asanyarray([age2group(np.asanyarray([age]), self.age_group) for age in self.ages])
        self.image_names = np.asanyarray(
            [''.join(chr(i[0])
                     for i in hdf5_object)
                     for hdf5_object in [meta[hdf5_object_reference][:] 
                     for hdf5_object_references in meta['celebrityImageData']['name']
                     for hdf5_object_reference in hdf5_object_references]
            ]
        )

    def __len__(self):
        if len(self.ages) == len(self.image_names) and len(self.image_names) == len(self.age_groups):
            return len(self.ages)
        else:
            return -1
    
    def __getitem__(self, idx):
        source_label = self.source_labels[idx]
        target_label = self.target_labels[idx]
        true_label = self.true_labels[idx]

        source_img = transforms.ToTensor()(pil_loader(os.path.join(self.root_dir, self.dataset_name, random.choice(self.label_group_images[source_label]))))

        index = random.randint(0, len(self.label_group_images[true_label]) - 1)
        true_img = transforms.ToTensor()(pil_loader(os.path.join(self.root_dir, self.dataset_name, self.label_group_images[true_label][index])))
        true_age = self.label_group_ages[true_label][index]
        mean_age = self.mean_ages[target_label]

        if self.transforms is not None:
            source_img = self.transforms(source_img)
            true_img = self.transforms(true_img)

        return source_img, true_img, source_label, target_label, true_label, true_age, mean_age


In [4]:
class DataPrefetcher():
    def __init__(self, loader, *norm_index):
        self.loader = iter(loader)
        self.normlize = lambda x: x.sub_(0.5).div_(0.5)
        self.norm_index = norm_index
        #self.stream = cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_input = next(self.loader)
        except StopIteration:
            self.next_input = None
            return
        #with cuda.stream(self.stream):
        
        self.next_input = [
            self.normlize(x) if i in self.norm_index and type(x) is torch.tensor else x
            for i, x in enumerate(self.next_input)
        ]

    def next(self):
        #cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        self.preload()
        return input

## PFA-GAN

### Discriminator

In [5]:
class Discriminator(nn.Module):
    def __init__(self, age_group, conv_dim=64, repeat_num=3):
        super(Discriminator, self).__init__()
        self.age_group = age_group
        self.conv_dim = conv_dim
        self.repeat_num = repeat_num
        self._init_model()

    def _init_model(self):
        self.conv1 = nn.Conv2d(
            3,
            self.conv_dim,
            kernel_size=4,
            stride=2,
            padding=1
        )

        layers = []
        nf_mult = 1

        # gradually increase the number of filters
        for n in range(1, self.repeat_num):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            
            layers += [
                nn.utils.spectral_norm(
                    nn.Conv2d(
                        in_channels=self.conv_dim * nf_mult_prev + (self.age_group if n == 1 else 0),
                        out_channels=self.conv_dim * nf_mult,
                        kernel_size=4,
                        stride=2,
                        padding=1,
                        bias=True
                    )
                ),
                nn.LeakyReLU(0.2, True)
            ]
        
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** self.repeat_num, 8)
        
        layers += [
            nn.Conv2d(
                in_channels=self.conv_dim * nf_mult_prev,
                out_channels=self.conv_dim * nf_mult,
                kernel_size=4,
                stride=1,
                padding=1
            ),
            nn.BatchNorm2d(self.conv_dim * nf_mult),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(
                in_channels=self.conv_dim * nf_mult,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=1

            )
        ]

        self.main = nn.Sequential(*layers)
    
    def forward(self, inputs, condition):
        x = F.leaky_relu(self.conv1(inputs), 0.2, inplace=True)
        condition = self._group2feature(
            condition,
            feature_size=x.size(2),
            age_group=self.age_group
        ).to(x)
        return self.main(cat([x, condition], dim=1))

    def _group2feature(self, group, age_group, feature_size):
        onehot = self._group2onehot(
            group, 
            age_group
        )
        return onehot.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, feature_size, feature_size)

    def _group2onehot(self, groups, age_group):
        code = eye(age_group)[groups.squeeze()]
        if len(code.size()) > 1:
            return code
        return code.unsqueeze(0)


### ResidualBlock

In [6]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.n_channels = channels
        self._init_model()
    
    def _init_model(self):
        layers = [
            nn.Conv2d(
                self.n_channels,
                self.n_channels,
                3,
                1,
                1
            ),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(
                self.n_channels,
                self.n_channels,
                3,
                1,
                1
            ),
            nn.BatchNorm2d(self.n_channels),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        residual = x
        x = self.main(x)
        return F.leaky_relu(residual + x, 0.2, inplace=True)

### Generator

In [7]:
class GeneratorSubNetwork(nn.Module):
    def __init__(self, in_channels=3, n_residual_blocks=4):
        super(GeneratorSubNetwork, self).__init__()

        self.in_channels = in_channels
        self.n_residual_blocks = n_residual_blocks

        self._init_model()
    
    def _init_model(self):
        layers = [
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=32,
                kernel_size=9,
                stride=1,
                padding=4
            ),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=4,
                stride=2,
                padding=1
            ),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=4,
                stride=2,
                padding=1
            ),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        for _ in range(self.n_residual_blocks):
            layers.append(ResidualBlock(128))

        layers.extend([
            nn.ConvTranspose2d(
                in_channels=128,
                out_channels=64,
                kernel_size=4,
                stride=2,
                padding=1
            ),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.ConvTranspose2d(
                in_channels=64,
                out_channels=32,
                kernel_size=4,
                stride=2,
                padding=1
            ),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(
                in_channels=32,
                out_channels=3,
                kernel_size=9,
                stride=1,
                padding=4
            )
        ])
    
        self.main = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.main(x)

In [8]:
class Generator(nn.Module):
    def __init__(self, age_group, n_residual_blocks=4):
        super(Generator, self).__init__()
        self.age_group = age_group
        self._init_model(n_residual_blocks)
    
    def _init_model(self, n_residual_blocks):
        self.sub_networks = nn.ModuleList()

        for _ in range(self.age_group - 1):
            self.sub_networks.append(GeneratorSubNetwork(
                in_channels=3,
                n_residual_blocks=n_residual_blocks
            ))
    
    def forward(self, x, source_label: Tensor, target_label: Tensor):
        condition = self._pfa_encoding(source_label, target_label, self.age_group)
        for i in range(self.age_group - 1):
            aging_effects = self.sub_networks[i](x)
            x = x + aging_effects * condition[:, i]
        return x

    def _pfa_encoding(self, source, target, age_group):
        source, target = source.long(), target.long()
        code = zeros((source.size(0), age_group - 1, 1, 1, 1)).to(source)
        for i in range(source.size(0)):
            code[i, source[i]: target[i], ...] = 1
        return code


### Age Classifier

In [9]:
class AuxiliaryAgeClassifier(nn.Module):
    def __init__(self, conv_dim=64, channels=3, classes=101):
        super(AuxiliaryAgeClassifier, self).__init__()

        self.conv_dim = conv_dim
        self.channels = channels
        self.classes = classes

        self._init_model()

    def _add_vgg_block(self, in_channels, out_channels, more=False):
        layers = [
            (
                'conv1',
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            ),
            (
                'relu1',
                nn.ReLU(inplace=True)
            ),
            (
                'conv2',
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            ),
            (
                'relu2',
                nn.ReLU(inplace=True)
            ),
        ]

        if more:
            layers.extend([
                (
                    'conv3',
                    nn.Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1
                    )
                ),
                (
                    'relu3',
                    nn.ReLU(inplace=True)
                ),
            ])

        layers.append(('maxpool', nn.MaxPool2d(kernel_size=2, stride=2)))

        return nn.Sequential(OrderedDict(layers))

    def _init_model(self):
        self.conv = nn.Sequential(
            self._add_vgg_block(self.channels, self.conv_dim),
            self._add_vgg_block(self.conv_dim, self.conv_dim*2),
            self._add_vgg_block(self.conv_dim*2, self.conv_dim*4, True),
            self._add_vgg_block(self.conv_dim*4, self.conv_dim*8, True),
            self._add_vgg_block(self.conv_dim*8, self.conv_dim*8, True),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(self.conv_dim*7*7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5, inplace=True),
        )

        self.fc2 = nn.Sequential(
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5, inplace=True),
        )

        self.cls = nn.Linear(4096, self.classes)

    def forward(self, x):
        """
        Forward operation of the net.
        """
        in_size = x.shape[0]
        x = self.conv(x)
        x = x.view(in_size, -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.cls(x)
        return F.softmax(x, dim=1)

### PFA-GAN

In [10]:
class PFA_GAN(object):
    def __init__(
            self,
            alpha,
            pix_loss_weight,
            gan_loss_weight,
            id_loss_weight,
            age_loss_weight,
            age_group=4,
            image_size=256,
            pretrained_image_size=256,
            init_lr=1e-4,
            restore_iter=0,
            max_iter=200000,
            save_iter=2000,
            decay_pix_factor=0,
            decay_pix_n=2000,
            num_workers=0,
            batch_size=256):
        self.age_group = age_group
        self.image_size = image_size
        self.pretrained_image_size = pretrained_image_size
        self.init_lr = init_lr
        self.restore_iter = restore_iter
        self.max_iter = max_iter
        self.save_iter = save_iter
        self.pix_loss_weight = pix_loss_weight
        self.decay_pix_factor = decay_pix_factor
        self.decay_pix_n = decay_pix_n
        self.gan_loss_weight = gan_loss_weight
        self.alpha = alpha
        self.id_loss_weight = id_loss_weight
        self.age_loss_weight = age_loss_weight
        self.num_workers = num_workers
        self.batch_size = batch_size

        self.prefetcher = self._get_train_loader()

        self._init_model()
    
    def fit(self):
        for n_iter in range(self.restore_iter + 1, self.max_iter + 1):
            inputs = self.prefetcher.next()
            self.train(inputs, n_iter=n_iter)
            if n_iter % self.save_iter == 0 or n_iter == self.max_iter:
                pass

    def train(self, inputs, n_iter):
        source_img, true_img, source_label, target_label, true_label, true_age, mean_age = inputs
        
        self.generator.train()
        self.discriminator.train()

        if self.image_size < self.pretrained_image_size:
            source_img_small = F.interpolate(source_img, self.image_size)
            true_img_small = F.interpolate(true_img, self.image_size)
        else:
            source_img_small = source_img
            true_img_small = true_img
        
        g_source = self.generator(source_img_small, source_label, target_label)
        
        if self.image_size < self.pretrained_image_size:
            g_source_pretrained = F.interpolate(g_source, self.pretrained_image_size)
        else:
            g_source_pretrained = g_source

        ###########################
        # Train Discriminator
        ###########################
        self.optimizer_discriminator.zero_grad()
        d1_logit = self.discriminator(true_img_small, true_label)
        # d2_logit = self.discriminator(true_img, source_label)
        d3_logit = self.discriminator(g_source.detach(), target_label)

        # d_loss = 0.5 * (ls_gan(d1_logit, 1.) + ls_gan(d2_logit, 0.) + ls_gan(d3_logit, 0.))
        d_loss = 0.5 * (self._ls_gan(d1_logit, 1.) + self._ls_gan(d3_logit, 0.))

        # COMMENT IN 
        #with amp.scale_loss(d_loss, self.optimizer_discriminator) as scaled_loss:
        #    scaled_loss.backward()
        self.optimizer_discriminator.step()

        ###########################
        # Train Generator
        ###########################
        self.optimizer_generator.zero_grad()

        ###########################
        # GAN Loss
        ###########################
        gan_logit = self.discriminator(g_source, target_label)
        g_loss = self._ls_gan(gan_logit, 1.)

        ###########################
        # Age Loss
        ###########################
        age_loss = self._age_criterion(g_source_pretrained, mean_age)

        ###########################
        # L1 Loss
        ###########################
        l1_loss = F.l1_loss(g_source_pretrained, source_img)

        ###########################
        # SSIM loss
        ###########################
        ssim_loss = self._compute_ssim_loss(g_source_pretrained, source_img, window_size=10)

        ###########################
        # ID Loss
        ###########################
        id_loss = F.mse_loss(
            self._extract_vgg_face(g_source_pretrained),
            self._extract_vgg_face(source_img)
        )

        pix_loss_weight = max(
            self.pix_loss_weight,
            self.pix_loss_weight * (self.decay_pix_factor ** (n_iter // self.decay_pix_n))
        )

        total_loss = \
            g_loss * self.gan_loss_weight + \
            (l1_loss * (1 - self.alpha) + ssim_loss * self.alpha) * pix_loss_weight + \
            id_loss * self.id_loss_weight + \
            age_loss * self.age_loss_weight
        
        # COMMENT IN 
        # with amp.scale_loss(total_loss, self.optimizer_generator) as scaled_loss:
        #     scaled_loss.backward()
        self.optimizer_generator.step()

    def _ls_gan(self, inputs, targets):
        return mean((inputs - targets) ** 2)
    
    def _age_criterion(self, input, gt_age):
        age_logit, group_logit = self.age_classifier(input)
        return F.mse_loss(self._get_dex_age(age_logit), gt_age)# + \
               #F.cross_entropy(group_logit, age2group(gt_age, self.age_group).long())
    
    def _get_dex_age(self, pred):
        pred = F.softmax(pred, dim=1)
        value = tsum(pred * arange(pred.size(1)).to(pred.device), dim=1)
        return value
    
    def _compute_ssim_loss(self, img1, img2, window_size=11):
        channel = img1.size(1)
        window = self._create_window(window_size, channel).to(img1.device)

        mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
        mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        return 1.0 - ssim_map.mean()

    def _create_window(self, window_size, channel):
        _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
        return window
    
    def _gaussian(self, window_size, sigma):
        gauss = Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
        return gauss / gauss.sum()
    
    def _extract_vgg_face(self, inputs):
        inputs = self._normalize((F.hardtanh(inputs) * 0.5 + 0.5) * 255,
                           [129.1863, 104.7624, 93.5940],
                           [1.0, 1.0, 1.0])
        return self.vgg_face(inputs)
    
    def _normalize(self, input, mean, std):
        mean = Tensor(mean).to(input.device)
        std = Tensor(std).to(input.device)
        return input.sub(mean[None, :, None, None]).div(std[None, :, None, None])

    def _init_model(self):
        self.generator = Generator(self.age_group)
        self.generator.apply(self._weights_init)

        self.discriminator = Discriminator(
            age_group=self.age_group,
            repeat_num=int(np.log2(self.image_size) - 4),

        )

        self.vgg_face = torchvision.models.vgg16(num_classes=2622)
        self.vgg_face.eval()

        self.age_classifier = AuxiliaryAgeClassifier(
            age_group=self.age_group
        )
        self.age_classifier.load_state_dict(
            self._load_network(
                os.path.join(
                    'materials',
                    'models',
                    'age_sd.pth.obj'
                )
            )
        )
        self.age_classifier.eval()

        self.optimizer_discriminator = Adam(
            self.discriminator.parameters(),
            self.init_lr,
            betas=(0.5, 0.99)
        )

        self.optimizer_generator = Adam(
            self.generator.parameters(),
            self.init_lr,
            betas=(0.5, 0.99)
        )

    def _load_network(self, state_dict):
        if isinstance(state_dict, str):
            state_dict = load(state_dict, map_location='cpu')
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            namekey = k.replace('module.', '')  # remove `module.`
            new_state_dict[namekey] = v
        return new_state_dict

    def _get_train_loader(self):
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.pretrained_image_size),
            torchvision.transforms.ToTensor(),
        ])

        train_dataset = CACD_Dataset(
            'materials',
            'cacd'
        )

        #train_sampler = DistributedSampler(
        #    train_dataset, shuffle=False
        #)

        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            drop_last=True,
            num_workers=self.num_workers,
            pin_memory=True,
            #sampler=train_sampler
        )

        return DataPrefetcher(train_loader, [0, 1])

    def _weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # torch.nn.init.kaiming_normal(m.weight.data, mode='fan_in')
            # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            # m.weight.data.normal_(0, math.sqrt(2. / n))
            m.weight.data.normal_(0, 0.02)
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find('Linear') != -1:
            m.bias.data.zero_()


In [11]:
pfa_gan = PFA_GAN(
    alpha=1e-4,
    pix_loss_weight=1e-4,
    gan_loss_weight=1e-4,
    id_loss_weight=1e-4,
    age_loss_weight=1e-4,
)

TypeError: AuxiliaryAgeClassifier.__init__() got an unexpected keyword argument 'age_group'

In [82]:
pfa_gan.fit()