## Import Modules

In [102]:
import os
import math
from io import BytesIO
import lmdb
import tqdm
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
# from torchvision.utils import make_grid
# from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt

## Dataset Tool

In [3]:
class AnimeFace(Dataset):
    def __init__(
        self, 
        root_folder='/amax/data/LHao/dataset/', 
        transform=None, 
        sizes=(4, 8, 16, 32, 64), 
        db_path='./lmdb_data', 
        resolution=64
    ):
        super().__init__()
        self.sizes = sizes
        self.resolution = resolution
        imgs = ImageFolder(root_folder)
        self._save_db(imgs, db_path)
        self._db = lmdb.open(db_path)
        self.transform = transform
        
    def _save_db(self, files, db_path):
        img_count = 0
        db = lmdb.open(db_path, map_size=1024 ** 4)
        with db.begin(write=True) as file:
            for i, (original_img, _) in enumerate(files.imgs):
                resize_imgs = self._resize_resolution(original_img)
                for size, img in zip(self.sizes, resize_imgs):
                    key = f'{size}-{str(i).zfill(5)}'
                    file.put(key.encode('utf-8'), img)
                img_count += 1
            file.put('length'.encode('utf-8'), str(img_count).encode('utf-8'))
            self._length = img_count
        db.close()
        
    def _resize_resolution(self, img):
        imgs = []
        img = Image.open(img)
        for size in self.sizes:
            buffer = BytesIO()
            img_resize = transforms.F.resize(img, size)
            img_resize = transforms.F.center_crop(img_resize, size)
            img_resize.save(buffer, format='jpeg')
            imgs.append(buffer.getvalue())
        return imgs
                
    def __len__(self):
        return self._length
    
    def __getitem__(self, index):
        with self._db.begin(write=False) as file:
            img_bytes = file.get(f'{self.resolution}-{str(index).zfill(5)}'.encode())
        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        if self.transform:
            img = self.transform(img)
        return img
    
    def __del__(self):
        self._db.close()
        
    def show(self, index):
        plt.imshow(self[index].permute(1, 2, 0))

## Generator

In [80]:
class StyledGenerator(nn.Module):
    def __init__(self, num_resolution, w_channel=512, device='cpu'):
        super().__init__()
        self.mapping = MappingNetwork()
        self.synthesis = SynthesisNetwork(num_resolution, w_channel=w_channel, device=device)

    def forward(self, latent_code, level):
        return False

### Mapping Network

In [5]:
class MappingNetworkBlock(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.LeakyReLU(0.2, inplace=True)
    
    def forward(self, x):
        return self.activation(self.linear(x))

class MappingNetwork(nn.Module):
    def __init__(self, latent_in_dim=512, latent_out_dim=512, num_mlp=8, broadcast=18):
        super().__init__()
        self.latent_in_dim = latent_in_dim
        self.broadcast = broadcast
        self.norm = PixelNorm()
        self._mapping = [MappingNetworkBlock(latent_in_dim, latent_out_dim) for i in range(num_mlp)]
        self.mapping = nn.Sequential(*self._mapping)
    
    def forward(self, x):
        x = self.norm(x)
        x = x.view(-1, self.latent_in_dim)
        out = self.mapping(x)
        out = out.unsqueeze(1)
        return out.repeat(1, self.broadcast, 1)

### tools

#### FadeIn

In [6]:
class FadeIn():
    def __init__(self):
        self.alpha = 1e-5
        
    def __call__(self, x, y):
        return x * (1 - self.alpha) + y * self.alpha
    
    def _update_alpha(self, delta):
        self.alpha += abs(delta)
        self.alpha = float(min(self.alpha, 1.))
        
    def _reset(self):
        self.alpha = 1e-5

#### PixelNorm

![pixelnorm](./source/20201215105113807.png)

In [7]:
class PixelNorm(nn.Module):
    def __init__(self, sigma=1e-8):
        super().__init__()
        self.sigma = sigma
    
    def forward(self, x):
        # x = x/√(x2_avg+ε)
        return x.div(x.pow(2.).mean(dim=1, keepdim=True).add(self.sigma).sqrt()) 

#### InstanceNorm

In [8]:
class InstanceNorm(nn.Module):
    def __init__(self, sigma=1e-8):
        super().__init__()
        self.sigma = sigma
    
    def forward(self, x):
        # x = (x-x_mean)/√(x2_avg+ε)
        return (x - x.mean(dim=(2, 3), keepdim=True)).div(x.pow(2.).mean(dim=(2, 3), keepdims=True).add(self.sigma).sqrt())

#### ConstantInput

In [9]:
class ConstantInput(nn.Module):
    def __init__(self, channel=512, size=4):
        super().__init__()
        self.const_input = nn.Parameter(torch.randn(channel, size, size))
    
    def forward(self, batch_size):
        out = self.const_input.repeat(batch_size, 1, 1, 1)
        return out

#### AdaIn

![AdaIN](./source/AdaIN.png)

In [10]:
class AdaIN(nn.Module):
    def __init__(self, in_features=512, out_features=1024):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
    
    def forward(self, x, w):
        w = self.linear(w)
        style = w.reshape([-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2))
        return x * (style[:, 0] + 1) + style[:, 1]

#### Noise

In [11]:
class Noise(nn.Module):
    def __init__(self, channel):
        super().__init__()
        # channel为每层输入的张量的通道数
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
    
    def forward(self, noise):
        return noise * self.weight

#### LayerEpilogue (每层最后都会有一个添加噪声和AdaIn的操作，这里定义一个LayerEpliogue进行封装)

In [12]:
class LayerEpilogue(nn.Module):
    def __init__(self, layer_channel, w_channel=512, use_noise=True, use_pixel_norm=False, use_instance_norm=True, device='cpu'):
        '''
        :channel: x的通道数
        '''
        super().__init__()
        self.device = device
        self.use_noise = use_noise
        self.B = Noise(layer_channel)
        self.adain = AdaIN(w_channel, layer_channel*2)
        
        self.pixel_norm = PixelNorm()
        self.instance_norm = InstanceNorm()
        
        self.use_pixel_norm = use_pixel_norm
        self.use_instance_norm = use_instance_norm
    
    def forward(self, x, w, noise=None):
        # 加入noise
        if self.use_noise:
            if noise is None:
                noise = torch.randn([1, 1, x.shape[2], x.shape[3]], device=device)
            x = x + self.B(noise)
            
        # 使用激活函数
        
        # 归一化处理    
        if self.use_pixel_norm:
            x = self.pixel_norm(x)
        if self.use_instance_norm:
            x = self.instance_norm(x)
        
        return self.adain(x, w)

#### StyleBlock

In [13]:
class StyleBlock(nn.Module):
    def __init__(self, resolution, w_channel=512, in_channels=512, out_channels=512, device='cpu'):
        super().__init__()
        
        assert resolution >= 4
        self.resolution = resolution
        if resolution == 4:
            self.constant_input = ConstantInput()
        else:
            self.upsample = nn.Upsample(scale_factor=2)
            self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
            
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.lrelu = nn.LeakyReLU(0.2)
        self.layer_epilogue = LayerEpilogue(layer_channel=out_channels, w_channel=w_channel, device=device)
        
    def forward(self, w, x=None, noises=None):
        if self.resolution == 4:
            x = self.constant_input(w[0].shape[0])
        else:
            assert x is not None
            x = self.lrelu(self.conv1(self.upsample(x)))
        x = self.layer_epilogue(x, w[0], noises[0])
        x = self.lrelu(self.conv2(x))
        x = self.layer_epilogue(x, w[1], noises[1])
        return x

#### MinibatchStddevLayer

In [14]:
class MinibatchStdDev(nn.Module):

    def __init__(self):
        super().__init__()
        
    def forward(self, x, alpha=1e-8):
        batch_size, _, height, width = x.shape
        y = x - x.mean(dim=0, keepdim=True)
        y = torch.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)
        y = y.mean().view(1, 1, 1, 1)
        y = y.repeat(batch_size, 1, height, width)
        y = torch.cat([x, y], 1)

        return y

### SynthesisNetwork

In [15]:
class SynthesisNetwork(nn.Module):
    def __init__(self, num_resolution, w_channel=512, device='cpu'):
        super().__init__()
        max_channel = 512

        # 生成 noises
        self.noises = []
        self._generate_noise(num_resolution, device)
        
        self.style, self.to_rgb = nn.ModuleDict(), nn.ModuleList()
        for resolution_idx in range(num_resolution):
            resolution = np.power(2, (resolution_idx + 2))
            if 4 <= resolution <= 32:
                self.style[f'res{resolution}'] = StyleBlock(resolution,
                                                            w_channel=w_channel,
                                                            in_channels=max_channel, 
                                                            out_channels=max_channel,
                                                            device=device)
                self.to_rgb.append(self._to_rgb(max_channel))
            else:
                self.style[f'res{resolution}'] = StyleBlock(resolution, 
                                                            w_channel=w_channel,
                                                            in_channels=max_channel//np.power(2, resolution_idx-4), 
                                                            out_channels=max_channel//np.power(2, resolution_idx-3),
                                                            device=device)
                self.to_rgb.append(self._to_rgb(max_channel//np.power(2, resolution_idx-3)))
                
        self.fade_in = FadeIn()
        
    def forward(self, w, level):
        # level是指第level个block 4：0，8：1...
        resolution = np.power(2, level + 2)
        x = None
        for block_idx, block in enumerate(self.style.values()):
            if block_idx < level:
                x = block([w[:, block_idx * 2, :], w[:, block_idx * 2 + 1, :]], 
                          x, 
                          [self.noises[block_idx * 2], self.noises[block_idx * 2 + 1]])
                
        if level == 0:
            x = self.style[f'res{resolution}']((w[:, level * 2, :], w[:, level * 2 + 1, :]), 
                                                x, 
                                               (self.noises[level * 2], self.noises[level * 2 + 1]))
            
            x = self.to_rgb[level](x)
            return x
        
        residual = x.clone()
        
        residual = F.interpolate(x, scale_factor=2)
        residual = self.to_rgb[level-1](residual)
        x = self.style[f'res{resolution}']((w[:, level * 2, :], w[:, level * 2 + 1, :]), 
                                            x, 
                                           (self.noises[level * 2], self.noises[level * 2 + 1]))
        x = self.to_rgb[level](x)
        return self.fade_in(residual, x)

    
    def _generate_noise(self, num_resolution, device='cpu'):
        for layer_idx in range(1, (num_resolution) * 2 + 1):
            self.noises.append(torch.randn([1, 1, 2 ** ((layer_idx + 1) // 2 + 1), 2 ** ((layer_idx + 1) // 2 + 1)], device=device))
    
    def _to_rgb(self, in_channels):
        return nn.Conv2d(in_channels, 3, 1)

## Discriminator

#### DiscriminatorBlock

In [16]:
class DiscriminatorBlock(nn.Module):
    def __init__(self, resolution, in_channels=512, out_channels=512):
        super().__init__()
        assert resolution >= 4
        self.resolution = resolution
        
        if resolution == 4:
            self.model = nn.Sequential(
                MinibatchStdDev(), 
                nn.Conv2d(in_channels+1, out_channels, kernel_size=4),
                nn.LeakyReLU(0.2),
                nn.Flatten(),  # 全连接之前将张量展平
                nn.Linear(out_channels, out_channels),
                nn.LeakyReLU(0.2),
                nn.Linear(out_channels, 1)
            )
        else:
            self.model = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 
                nn.LeakyReLU(0.2),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2),
                nn.AvgPool2d(2)
            )
        
    def forward(self, x):
        return self.model(x)

In [17]:
class Discriminator(nn.Module):
    def __init__(self, num_resolution):
        super().__init__()
        self.num_resolution = num_resolution
        max_channel = 512
        
        self.model, self.from_rgb = nn.ModuleDict(), nn.ModuleList()
        for resolution_idx in range(num_resolution - 1, -1, -1):
            resolution = np.power(2, (resolution_idx + 2))
            if 4 <= resolution <= 32:
                self.model[f'resolution{resolution}'] = DiscriminatorBlock(resolution, max_channel, max_channel)
                self.from_rgb.append(self._from_rgb(max_channel))
            else:
                self.model[f'resolution{resolution}'] = DiscriminatorBlock(resolution, 
                                                                           max_channel//np.power(2, resolution_idx-3), 
                                                                           max_channel//np.power(2, resolution_idx-4))
                self.from_rgb.append(self._from_rgb(max_channel//np.power(2, resolution_idx-3)))
        
        self.fade_in = FadeIn()
        
    def forward(self, x, level):
        resolution = np.power(2, level + 2)
        residual = x.clone()
        x = self.from_rgb[self.num_resolution-level-1](x)
        
        if level == 0:
            x = self.model[f'resolution{resolution}'](x)
            return x.view(-1)
        
        residual = F.avg_pool2d(residual, kernel_size=2, stride=2)
        residual = self.from_rgb[self.num_resolution-level](residual)
        x = self.model[f'resolution{resolution}'](x)
        x = self.fade_in(residual, x)
        for block_idx, block in enumerate(self.model.values()):
            if block_idx >= (self.num_resolution - level):
                x = block(x)
        return x.view(-1)
    
    
    def _from_rgb(self, out_channels):
        return nn.Conv2d(3, out_channels=out_channels, kernel_size=1)
    

## Loss

In [90]:
class Loss:
    def __init__(self, G, D):
        self.G = G
        self.D = D
        
        
class WGAN(Loss):
    def __init__(self, G, D, wgan_epsilon=0.001, wgan_lambda=10.0, wgan_target=1.0, device='cpu'):
        super().__init__(G, D)
        self.device = device
        self.wgan_epsilon = wgan_epsilon
        self.wgan_lambda = wgan_lambda
        self.wgan_target = wgan_target
        if device.type == 'cuda':
            self.G = self.G.module
            self.D = self.D.module

    def G_wgan(self, latent_code, level):
        # Loss_G = -D(G(z))
        w = self.G.mapping(latent_code)
        fake_images = self.G.synthesis(w, level)
        fake_out = self.D(fake_images, level)
        loss = -torch.mean(fake_out)
        return loss, fake_images
    
    def D_wgan(self, latent_code, real_images, level):
        # Loss_D = D(G(z)) - D(x) + ε·D(x)^2
        w = self.G.mapping(latent_code)
        fake_images = self.G.synthesis(w, level)
        real_out = self.D(real_images, level)
        fake_out = self.D(fake_images, level)
        loss = torch.mean(fake_out) - torch.mean(real_out) + self.wgan_epsilon * torch.square(real_out)
        return loss, fake_images
    
    def D_wgan_gp(self, latent_code, real_images, level):
        # Loss_D = D(G(z)) - D(x) + η·(||∇T||-1)^2 + ε·D(x)^2
        # D(G(z)) - D(x) + ε·D(x)^2
        wgan_loss, fake_images = self.D_wgan(latent_code, real_images, level)
        
        # η·(||∇T||-1)^2  梯度惩罚
        alpha = torch.rand(latent_code.shape[0], 1, 1, 1).uniform_(-1, 1).to(self.device)
        interpolates = real_images.data * alpha + (1 - alpha) * fake_images.data  # 真实分布与生成分布之间的插值
        interpolates.requires_grad = True
        interpolates_pred = self.D(interpolates, level)
        gradients = torch.autograd.grad(outputs=interpolates_pred.sum(), inputs=interpolates, create_graph=True)[0] # 求梯度
        slopes = torch.sqrt(torch.sum(torch.square(gradients), dim=(1, 2, 3)))# 取模
        gradient_penalty = (self.wgan_lambda / self.wgan_target ** 2) * torch.mean((slopes - self.wgan_target)**2)
        return (wgan_loss + gradient_penalty).mean()

## Train

In [112]:
class TrainProcessing:
    def __init__(self, 
                 dataset,
                 epoch = 800_000,
                 latent_code_dim = 512,
                 num_resolution = 5,
                 batch_size = 16,
                 num_gpu = 4,
                 beta1 = 0.5,
                 init_size = 4,
                 max_size = 64,
                 wgan_epsilon=0.001, 
                 wgan_lambda=10.0, 
                 wgan_target=1.0
                ):
        self.dataset = dataset
        
        # 图像尺寸训练范围
        self.init_size = init_size
        self.max_size = max_size
        
        # 超参数设置
        self.epoch = epoch
        self.latent_code_dim = latent_code_dim
        self.num_resolution = num_resolution
        self.batch_size = batch_size
        
        # gpu
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.num_gpu = num_gpu
        IS_PARALLEL = True if num_gpu > 1 else False

        self.generator = StyledGenerator(num_resolution=num_resolution, w_channel=latent_code_dim, device=self.device)
        self.discriminator = Discriminator(num_resolution=num_resolution)
        
        if IS_PARALLEL and torch.cuda.device_count() > 1:
            self.generator = nn.DataParallel(self.generator, device_ids=range(num_gpu)).cuda()
            self.discriminator = nn.DataParallel(self.discriminator, device_ids=range(num_gpu)).cuda()
        
#         self.generator = self.generator.to(self.device)
#         self.discriminator = self.discriminator.to(self.device)

        # 优化器
        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0002, betas=(beta1, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(beta1, 0.999))

        # 初始化损失函数
        self.loss = WGAN(self.generator, self.discriminator, device=self.device)
        
    
    def train(self):
        
        level = int(math.log2(self.init_size)) - 2
        resolution = 4 * 2 ** level
        
        train_loader = self._sample_data(self.dataset, resolution, self.batch_size)
        data_loader = iter(train_loader)
        
        loss_dict = {
            'disc_loss_val': 0,
            'gen_loss_val': 0,
        }
        
        used_sample = 0
        phase = 60_000
        
        pbar = tqdm.tqdm(range(self.epoch))
        
        max_level = int(math.log2(self.max_size)) - 2
        final_progress = False
        
        for iteration in pbar:
            alpha = min(1, 1 / phase * (used_sample + 1))
            
            if used_sample > phase * 2:
                used_sample = 0
                level += 1

                if level > max_level:
                    level = max_level
                    final_progress = True
                else:
                    self._update_alpha(1)
                    
                ckpt_level = level

                resolution = 4 * 2 ** level

                train_loader = self._sample_data(self.dataset, resolution, self.batch_size)
                data_loader = iter(train_loader)

                self._save_model(ckpt_level)

            try:
                data = next(data_loader)
            except StopIteration:
                data_loader = iter(train_loader)
                data = next(data_loader)
                
            real_images = data.to(self.device)
            used_sample += real_images.shape[0]

            # 训练判别器
            self._requires_grad(self.generator, False)
            self._requires_grad(self.discriminator, True)

            latent = torch.randn(self.batch_size, self.latent_code_dim, 1, 1, device=self.device)
            # 计算损失
            d_loss = self.loss.D_wgan_gp(latent, real_images, level)
            self.discriminator.zero_grad()
            d_loss.backward()
            self.discriminator_optimizer.step()
            
            loss_dict['disc_loss_val'] = d_loss.item()

            # 训练生成器
            self._requires_grad(self.generator, True)
            self._requires_grad(self.discriminator, False)

            latent = torch.randn(self.batch_size, self.latent_code_dim, 1, 1, device=self.device)
            g_loss, _ = self.loss.G_wgan(latent, level)
            
            loss_dict['gen_loss_val'] = g_loss.item()

            self.generator.zero_grad()
            g_loss.backward()
            self.generator_optimizer.step()
            
            self._requires_grad(self.generator, False)
            self._requires_grad(self.discriminator, True)
            
            # 设置print内容
            state_msg = (
                f'Size: {4 * 2 ** level}; G: {loss_dict["gen_loss_val"]:.3f}; D: {loss_dict["disc_loss_val"]:.3f};'
                f' Alpha: {alpha:.5f}'
            )
            pbar.set_description(state_msg)
    
    def _requires_grad(self, model, flag=True):
        for p in model.parameters():
            p.requires_grad = flag
            
    def _sample_data(self, dataset, resolution, batch_size):
        dataset.resolution = resolution
        train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)
        
        return train_loader
    
    def _update_alpha(self, alpha):
        if self.device.type == 'cuda':
            self.generator.module.synthesis.fade_in.alpha = alpha
            self.discriminator.module.fade_in.alpha = alpha
        else:
            self.generator.synthesis.fade_in.alpha = alpha
            self.discriminator.module.fade_in.alpha = alpha
        
    def _save_model(self, ckpt_level):
        if self.num_gpu > 0:
            g = self.generator.module
            d = self.discriminator.module
        else:
            g = self.generator
            d = self.discriminator
        
        torch.save(
            {
                'generator': g.state_dict(),
                'discriminator': d.state_dict(),
                'g_optimizer': self.generator_optimizer.state_dict(),
                'd_optimizer': self.discriminator_optimizer.state_dict(),
            },
            f'checkpoint/train_step-{ckpt_level}.model',
        )


In [93]:
image_transforms = transforms.Compose([
#         transforms.Resize(64),
#         transforms.CenterCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = AnimeFace(root_folder='/amax/data/LHao/dataset/', transform=image_transforms)

In [None]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'
trainer = TrainProcessing(dataset)
trainer.train()

Size: 64; G: -35841.098; D: 948020.375; Alpha: 1.00000:   4%|▍         | 34430/800000 [1:09:47<85:11:04,  2.50it/s]        