In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
from PIL import Image, ImageFile
from pickle import load, dump
import cv2
import time
import argparse
# Pillow setting: allow image files with truncated data to be loaded.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
class PixelwiseNormalization(nn.Module):
    """Rescale every feature vector to unit root-mean-square along dim 1."""

    def pixel_norm(self, x):
        """Return ``x`` multiplied by the reciprocal RMS of its dim-1 entries.

        A small epsilon (1e-8) is added to the mean square before taking the
        reciprocal square root so that an all-zero vector stays finite.
        """
        epsilon = 1e-8
        mean_square = torch.mean(x * x, 1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + epsilon)
        return x * inverse_rms

    def forward(self, x):
        """Apply :meth:`pixel_norm` to ``x``."""
        return self.pixel_norm(x)

# Adaptive Instance Normalization
class AdaIn(nn.Module):
    """Instance-normalise a feature map, then scale and shift each channel
    with values predicted from a style vector by a single linear layer.

    Args:
        n_channel: Number of channels of the feature map being modulated.
        dim_latent: Size of the style vector fed to the linear layer.
    """

    def __init__(self, n_channel, dim_latent):
        super().__init__()
        self.norm = nn.InstanceNorm2d(n_channel)
        self.transform = nn.Linear(dim_latent, 2 * n_channel)
        # The first half of the linear output is used as the scale and the
        # second half as the shift (see ``forward``): start from scale bias 1
        # and shift bias 0.
        initial_bias = self.transform.bias.data
        initial_bias[:n_channel] = 1
        initial_bias[n_channel:] = 0

    def forward(self, image, style):
        """Return ``norm(image) * scale + shift`` with scale/shift from ``style``."""
        affine = self.transform(style)
        # Add two trailing singleton axes so the values broadcast over H and W.
        affine = affine.unsqueeze(2)
        affine = affine.unsqueeze(3)
        scale, shift = torch.chunk(affine, 2, dim=1)
        normalised = self.norm(image)
        return normalised * scale + shift

class Noise(nn.Module):
    """Add Gaussian noise to a feature map, scaled by one learnable weight.

    A single-channel noise map is drawn per sample and broadcast over the
    channel axis. The weight is initialised to zero, so a freshly built layer
    returns its input unchanged.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image):
        """Return ``image + weight * noise`` with fresh standard-normal noise."""
        dims = image.shape
        noise_shape = (dims[0], 1, dims[2], dims[3])
        # new_empty keeps the dtype and device of ``image``.
        noise = image.new_empty(noise_shape).normal_()
        return image + self.weight * noise

class GeneratorBlock(nn.Module):
    """One synthesis stage of the generator.

    The stage runs ``conv1 -> noise -> AdaIN -> LeakyReLU`` and then a
    stride-2 transposed convolution (which doubles the spatial size) followed
    by ``noise -> AdaIN -> LeakyReLU``. A 1x1 ``toRGB`` convolution is
    available to turn the features into ``num_channels`` image channels.

    Args:
        input_nc: Channels of the incoming feature map.
        output_nc: Channels produced by both convolutions.
        num_channels: Channels of the RGB output of ``toRGB``.
        dim_latent: Size of the style vector consumed by the AdaIN layers.
        first: When true, ``conv1`` uses a 4x4 kernel with padding 3 (a 1x1
            input becomes 4x4); otherwise a size-preserving 3x3 kernel.
    """

    def __init__(self, input_nc, output_nc, num_channels, dim_latent, first=False):
        super().__init__()

        kernel, pad = (4, 3) if first else (3, 1)
        self.conv1 = nn.Conv2d(input_nc, output_nc, kernel_size=kernel, stride=1, padding=pad)
        self.noise1 = Noise()
        self.adain1 = AdaIn(output_nc, dim_latent)
        self.activate1 = nn.LeakyReLU(0.2, inplace=True)

        # upsample
        self.conv2 = nn.ConvTranspose2d(output_nc, output_nc, kernel_size=4, stride=2, padding=1)
        self.noise2 = Noise()
        self.adain2 = AdaIn(output_nc, dim_latent)
        self.activate2 = nn.LeakyReLU(0.2, inplace=True)

        self.toRGB = nn.Conv2d(output_nc, num_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, image, style, last=False):
        """Run both sub-stages; convert to RGB via ``toRGB`` only when ``last``."""
        sub_stages = (
            (self.conv1, self.noise1, self.adain1, self.activate1),
            (self.conv2, self.noise2, self.adain2, self.activate2),
        )
        for conv, noise, adain, activate in sub_stages:
            image = activate(adain(noise(conv(image)), style))

        return self.toRGB(image) if last else image

class DiscriminatorBlock(nn.Module):
    """One discriminator stage: a 3x3 conv followed by a stride-2 4x4 conv
    (halving the spatial size), each with LeakyReLU(0.2).

    Args:
        input_nc: Channels of the incoming feature map.
        output_nc: Channels after the downsampling convolution.
        num_channels: Channels of the image accepted by ``fromRGB``.
        last: When true, one extra statistics channel is concatenated to the
            input before the convolutions (see
            ``minibatch_standard_deviation``), so the first conv takes
            ``input_nc + 1`` channels.
    """

    def __init__(self, input_nc, output_nc, num_channels, last=False):
        super().__init__()

        self.fromRGB = nn.Conv2d(num_channels, input_nc, kernel_size=1, stride=1, padding=0)

        # Only the final block receives the additional statistics channel.
        conv_in = input_nc + 1 if last else input_nc
        self.model = nn.Sequential(
            nn.Conv2d(conv_in, input_nc, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1),  # downsample
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.last = last

    def minibatch_standard_deviation(self, x):
        """Append one channel filled with the standard deviation of ``x``.

        The statistic is a single scalar computed over every element of ``x``
        (with 1e-8 added under the square root) and broadcast to the batch and
        spatial size of ``x``.
        """
        eps = 1e-8
        centred = x - x.mean()
        std = torch.sqrt((centred ** 2).mean() + eps)
        std_channel = std.expand(x.shape[0], 1, *x.shape[2:])
        return torch.cat([x, std_channel], dim=1)

    def forward(self, x, first=False):
        """Process ``x``; when ``first`` it is an image and passes ``fromRGB``."""
        features = self.fromRGB(x) if first else x
        if self.last:
            features = self.minibatch_standard_deviation(features)
        return self.model(features)

class MappingNetwork(nn.Module):
    """Map a latent vector to a style vector of the same size.

    The input is first passed through ``PixelwiseNormalization`` and then
    through ``num_depth`` pairs of ``Linear(dim_latent, dim_latent)`` and
    ``LeakyReLU(0.2)``.
    """

    def __init__(self, dim_latent, num_depth):
        super().__init__()

        layers = [PixelwiseNormalization()]
        for _ in range(num_depth):
            layers.extend((nn.Linear(dim_latent, dim_latent), nn.LeakyReLU(0.2)))

        self.module = nn.Sequential(*layers)

    def forward(self, x):
        """Return the style vector for latent ``x``."""
        return self.module(x)

class Generator(nn.Module):
    """Progressive style-based generator.

    A learned constant vector is fed through a stack of ``GeneratorBlock``
    stages; each stage is modulated by a style vector produced by a
    ``MappingNetwork``.

    Args:
        num_depth: Number of generator blocks to build.
        num_channels: Channels of the generated image.
        num_fmap: Callable mapping a stage index to a feature-map count.
        num_mapping: Number of linear layers in the mapping network.
        input_size: Size of the constant input and of the latent/style
            vectors; defaults to ``num_fmap(0)``.
    """

    def __init__(self, num_depth, num_channels, num_fmap, num_mapping, input_size=None):
        super().__init__()

        if input_size is None:
            self.input_size = num_fmap(0)
        else:
            self.input_size = input_size

        # Learned constant that seeds the synthesis stack (initialised to ones).
        self.constant_input = nn.Parameter(torch.ones((1, self.input_size), dtype=torch.float32))

        self.style = MappingNetwork(self.input_size, num_mapping)

        self.blocks = nn.ModuleList(
            [GeneratorBlock(self.input_size, num_fmap(1), num_channels, self.input_size, first=True)]
            + [GeneratorBlock(num_fmap(i), num_fmap(i + 1), num_channels, self.input_size) for i in range(1, num_depth)])

        self.activation = nn.Sigmoid()

        # Index of the last block that is run in ``forward``.
        self.depth = 0
        # Fade-in weight: below 1.0 the output is blended with the upsampled
        # RGB rendering of the previous block.
        self.alpha = 1.0

    def forward(self, styles, input_is_style=False):
        """Generate images from latent (or already mapped) style vectors.

        Args:
            styles: Sequence of tensors, one per block starting at block 0.
                If it is shorter than the number of blocks, the last entry is
                repeated for the remaining blocks.
            input_is_style: When false, every entry is first passed through
                the mapping network; when true, entries are used as-is.

        Returns:
            Tuple ``(rgb, styles)``: the sigmoid-activated image batch and the
            padded list of style vectors that was actually used.
        """
        if input_is_style:
            # Work on a copy: padding below must not extend the caller's list.
            styles = list(styles)
        else:
            styles = [self.style(z) for z in styles]

        missing = len(self.blocks) - len(styles)
        if missing > 0:
            styles.extend([styles[-1]] * missing)

        batch_size = styles[0].size(0)
        x = self.constant_input.expand(batch_size, self.input_size).unsqueeze(-1).unsqueeze(-1)

        # Block 0 emits RGB directly only when it is also the final block.
        rgb = x = self.blocks[0](x, styles[0], self.depth == 0)

        if self.depth > 0:
            for i in range(self.depth - 1):
                x = self.blocks[i+1](x, styles[i+1])
            rgb = self.blocks[self.depth](x, styles[self.depth], last=True)
            if self.alpha < 1.0:
                # ``x`` is the output of block ``depth - 1``; render it to RGB,
                # upsample 2x and cross-fade with the new block's output.
                prev_rgb = self.blocks[self.depth - 1].toRGB(x)
                prev_rgb = F.interpolate(prev_rgb, mode='bilinear', scale_factor=2, align_corners=True, recompute_scale_factor=True)
                rgb = (1 - self.alpha) * prev_rgb + self.alpha * rgb

        rgb = self.activation(rgb)

        return rgb, styles
    
class Discriminator(nn.Module):
    """Progressive discriminator built from ``DiscriminatorBlock`` stages.

    Blocks are stored from the highest-resolution stage to the lowest; the
    final block is constructed with ``last=True``. A 3x3 convolution reduces
    the final feature map to a single-channel score map.

    Args:
        num_depth: Number of blocks to build.
        num_channels: Channels of the input image.
        num_fmap: Callable mapping a stage index to a feature-map count.
    """

    def __init__(self, num_depth, num_channels, num_fmap):
        super().__init__()

        stages = [
            DiscriminatorBlock(num_fmap(i), num_fmap(i-1), num_channels)
            for i in range(num_depth, 1, -1)
        ]
        stages.append(DiscriminatorBlock(num_fmap(1), num_fmap(0), num_channels, last=True))
        self.blocks = nn.ModuleList(stages)

        # PatchGAN
        self.conv_last = nn.Conv2d(num_fmap(0), 1, kernel_size=3, stride=1, padding=1)

        # Number of blocks run after the entry block.
        self.depth = 0
        # Fade-in weight between the entry block and the downsampled input.
        self.alpha = 1.0

    def forward(self, x):
        """Score image batch ``x`` at the resolution selected by ``depth``."""
        depth, alpha = self.depth, self.alpha

        entry_block = self.blocks[-(depth + 1)]
        out = entry_block(x, first=True)

        if depth > 0 and alpha < 1.0:
            # Cross-fade with the half-resolution image fed straight into the
            # next block's fromRGB layer.
            half_res = F.interpolate(x, mode='bilinear', scale_factor=0.5, align_corners=True, recompute_scale_factor=True)
            skip = self.blocks[-depth].fromRGB(half_res)
            out = alpha * out + (1 - alpha) * skip

        for offset in range(depth, 0, -1):
            out = self.blocks[-offset](out)

        return self.conv_last(out)

In [None]:
def showImage(image):
    %matplotlib inline
    import matplotlib.pyplot as plt

    PIL = transforms.ToPILImage()
    ToTensor = transforms.ToTensor()

    img = PIL(image)
    fig = plt.figure(dpi=200)
    ax = fig.add_subplot(1, 1, 1) # (row, col, num)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.imshow(img)
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
# --- Configuration ---------------------------------------------------------
# Output resolution and number of layers in the mapping network.
image_size = 128
num_mapping = 8
# Number of generator blocks derived from the resolution (5 for 128 px).
max_depth = int(np.log2(image_size)) - 2

# CUDA availability is recorded, but the notebook is pinned to the CPU.
use_cuda = torch.cuda.is_available()
device = 'cpu'  # torch.device("cuda" if use_cuda else "cpu")
def num_fmap(stage):
    """Return the number of feature maps used at progressive stage ``stage``.

    The count starts at ``8 * image_size`` for stage 0, is halved for each
    further stage, and is capped at ``(2 * image_size) // 2``.
    ``image_size`` is read from the notebook's global namespace.
    """
    fmap_decay = 1.0
    doubled_size = image_size * 2
    cap = doubled_size // 2
    uncapped_base = doubled_size * 4
    shrink = 2.0 ** (stage * fmap_decay)
    return min(int(uncapped_base / shrink), cap)
# Width of the latent vectors sampled for the generator.
feed_dim = num_fmap(0)

# Build the generator, move it to `device`, and select block index 4 as the
# output stage before restoring the saved weights.
net = Generator(max_depth, 3, num_fmap, num_mapping)
net = net.to(device)
net.depth = 4
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
net.load_state_dict(
    torch.load('weight_G.pth', map_location=device)
)

In [None]:
# Sample one random latent vector and render it. Inference needs no gradients,
# so the forward pass runs under no_grad to avoid keeping an autograd graph
# alive in the `image` / `style` globals.
style_feeds = [torch.randn(1, feed_dim).to(device)]
with torch.no_grad():
    image, style = net(style_feeds)
showImage(image[0])

In [None]:
# Keep the current latent list as one endpoint of the interpolation below.
feeds1 = style_feeds

In [None]:
# Sample another random latent vector and render it. Inference needs no
# gradients, so the forward pass runs under no_grad to avoid keeping an
# autograd graph alive in the `image` / `style` globals.
style_feeds = [torch.randn(1, feed_dim).to(device)]
with torch.no_grad():
    image, style = net(style_feeds)
showImage(image[0])

In [None]:
# Keep the current latent list as the other endpoint of the interpolation below.
feeds2 = style_feeds

In [None]:
%matplotlib widget
fig = plt.figure(dpi=200)
ax = fig.add_subplot(1, 1, 1) # (row, col, num)
ax.set_xticks([])
ax.set_yticks([])
images = []
for i in range(300):
    l = i / 300.0
    x = l * feeds1[0] + (1 - l) * feeds2[0]
    image, _ = net([x])
    PIL = transforms.ToPILImage()
    img = PIL(image[0])
    images += [[plt.imshow(img, animated=True)]]
ani = animation.ArtistAnimation(fig, images, interval=10, repeat_delay=1000)
ani.save('anim.gif', writer="imagemagick")
plt.show()