# Non-stationary texture synthesis using adversarial expansions

In [0]:
import os
from google.colab import drive

# Working directory on Google Drive; relative paths used later in the
# notebook (e.g. 'dataset/', 'out/') are resolved against it after the chdir.
DIR = "/content/gdrive/My Drive/sandbox/Chainer_NonStationaryTextureSynthesis/"
# Texture image filenames to train on, processed in this order by the training loop.
IMG_NAMES = ["spiral.jpg", "003.jpg"]

# Mount Google Drive, then make DIR the current working directory.
drive.mount('/content/gdrive')
os.chdir(DIR)

Mounted at /content/gdrive


In [0]:
#!pip install chainer-6.4.0.tar.gz
#!pip install 'cupy-cuda100>=6.4.0,<7.0.0'

In [0]:
from PIL import Image

import numpy as np
from numpy.lib.stride_tricks import as_strided
import matplotlib
import cv2
matplotlib.use('Agg')

import chainer
from chainer import cuda
import chainer.functions as F
import chainer.links as L
from chainer import Variable
from chainer.training import extensions

# Log the Chainer / NumPy / CuPy / cuDNN versions in use, for reproducibility.
chainer.print_runtime_info()

Platform: Linux-4.14.137+-x86_64-with-Ubuntu-18.04-bionic
Chainer: 5.4.0
NumPy: 1.17.3
CuPy:
  CuPy Version          : 5.4.0
  CUDA Root             : /usr/local/cuda
  CUDA Build Version    : 10000
  CUDA Driver Version   : 10010
  CUDA Runtime Version  : 10000
  cuDNN Build Version   : 7301
  cuDNN Version         : 7301
  NCCL Build Version    : 2402
  NCCL Runtime Version  : 2402
iDeep: 2.0.0.post3


## Residual Block

In [0]:
class ResBlock(chainer.Chain):
    """Channel-preserving residual block.

    Computes ``x + BN(conv(act(BN(conv(x)))))`` with 3x3, stride-1, pad-1
    convolutions; the batch-normalization steps are skipped when ``bn`` is
    False (the BN links are still created).

    Args:
        ch: Number of input (and output) channels.
        bn: Whether to apply batch normalization after each convolution.
        activation: Nonlinearity applied between the two convolutions.
    """

    def __init__(self, ch, bn=True, activation=F.relu):
        super(ResBlock, self).__init__()
        self.use_bn = bn
        self.activation = activation
        with self.init_scope():
            self.c1 = L.Convolution2D(ch, ch, 3, 1, 1)
            self.c2 = L.Convolution2D(ch, ch, 3, 1, 1)
            self.bn1 = L.BatchNormalization(ch)
            self.bn2 = L.BatchNormalization(ch)

    def forward(self, x):
        residual = self.c1(x)
        if self.use_bn:
            residual = self.bn1(residual)
        residual = self.activation(residual)

        residual = self.c2(residual)
        if self.use_bn:
            residual = self.bn2(residual)

        # Identity skip connection.
        return residual + x

## CBR (Convolution + Batch Normalization + Activation, with optional upsampling, noise and dropout)

In [0]:
# Gaussian noise injected into the discriminator during training.
def add_noise(h, sigma=0.2):
    """Return ``h`` plus N(0, sigma^2) noise in training mode, else ``h`` unchanged.

    Args:
        h: Variable (or array) to perturb.
        sigma: Standard deviation of the added noise.
    """
    if chainer.config.train:
        xp = cuda.get_array_module(h.data)
        noise = xp.random.randn(*h.data.shape)
        return h + sigma * noise
    return h

# Convolution geometry for each CBR ``sample`` mode, as (ksize, stride, pad).
#   'down'            : stride-2 conv, halves height/width (rounded down).
#   'none', 'none-N'  : stride-1 convs that keep the spatial size, except
#                       'none-4', which shrinks each spatial dim by one.
#   'up'              : size-preserving 3x3 conv; CBR applies 2x unpooling
#                       before it.
ksp = dict([
    ('down', (4, 2, 1)),
    ('none-9', (9, 1, 4)),
    ('none-7', (7, 1, 3)),
    ('none-5', (5, 1, 2)),
    ('none', (3, 1, 1)),
    ('up', (3, 1, 1)),
    ('none-4', (4, 1, 1)),
    ('none-1', (1, 1, 0)),
])

class CBR(chainer.Chain):
    """Convolution followed by optional BatchNorm, noise, dropout and activation.

    When ``sample == 'up'`` the input is first upsampled 2x with
    ``F.unpooling_2d``. The convolution geometry comes from ``ksp[sample]``
    and its weights are drawn from N(0, 0.02^2).

    Args:
        ch0: Input channels.
        ch1: Output channels.
        bn: Apply batch normalization after the convolution.
        sample: Key into ``ksp`` selecting kernel size / stride / padding.
        activation: Callable applied last; a falsy value disables it.
        dropout: Apply ``F.dropout`` (default ratio) before the activation.
        noise: Apply ``add_noise`` (training mode only) after normalization.
    """

    def __init__(self, ch0, ch1, bn=True, sample='down', activation=F.relu, dropout=False, noise=False):
        super(CBR, self).__init__()
        self.use_bn = bn
        self.activation = activation
        self.dropout = dropout
        self.sample = sample
        self.noise = noise
        ksize, stride, pad = ksp[sample]
        with self.init_scope():
            self.c = L.Convolution2D(
                ch0, ch1, ksize, stride, pad,
                initialW=chainer.initializers.Normal(0.02))
            if bn:
                self.bn = L.BatchNormalization(ch1)

    def forward(self, x):
        out = x
        if self.sample == "up":
            out = F.unpooling_2d(out, 2, 2, 0, cover_all=False)
        out = self.c(out)

        if self.use_bn:
            out = self.bn(out)
        if self.noise:
            out = add_noise(out)
        if self.dropout:
            out = F.dropout(out)
        if self.activation:
            out = self.activation(out)
        return out

## Generator

In [0]:
class Generator(chainer.ChainList):
    """Expansion generator built from CBR and ResBlock links.

    The pipeline is:
      - two size-preserving CBRs;
      - two stride-2 downsampling CBRs (3 -> 256 channels);
      - six 256-channel residual blocks;
      - three unpool-and-convolve CBRs (2x upsampling each);
      - a final 7x7 CBR with tanh producing 3 channels.

    For an (N, 3, H, W) input with H and W divisible by 4 the output is
    (N, 3, 2H, 2W).
    """

    def __init__(self):
        super(Generator, self).__init__()
        with self.init_scope():
            encoder = [
                CBR(3, 3, bn=True, sample='none-7'),
                CBR(3, 64, bn=True, sample='none'),
                CBR(64, 128, bn=True, sample='down'),
                CBR(128, 256, bn=True, sample='down'),
            ]
            residual_blocks = [ResBlock(256, bn=True) for _ in range(6)]
            decoder = [
                CBR(256, 512, bn=True, sample='none'),
                CBR(512, 256, bn=True, sample='up'),
                CBR(256, 128, bn=True, sample='up'),
                CBR(128, 64, bn=True, sample='up'),
                CBR(64, 3, bn=True, sample='none-7', activation=F.tanh),
            ]
            for link in encoder + residual_blocks + decoder:
                self.add_link(link)

    def __call__(self, x):
        out = x
        for link in self:
            out = link(out)
        return out

## Discriminator

In [0]:
class Discriminator(chainer.ChainList):
    """Fully convolutional discriminator returning a spatial map of logits.

    Layer stack:
      - ``n_down_layers`` stride-2 CBRs with leaky ReLU; the first has no
        BatchNorm, and channels start at 64, doubling per layer up to 512;
      - a 4x4 stride-1 CBR to 512 channels;
      - a 1x1 CBR to a single logit channel with no activation.

    Noise from ``add_noise`` (training mode only) is added to the input and
    inside every layer. With the defaults the layers are
    3->64->128->256->512->512->1.

    Args:
        in_ch: Number of channels of the input images.
        n_down_layers: Number of stride-2 downsampling layers.
    """

    def __init__(self, in_ch=3, n_down_layers=4):
        super(Discriminator, self).__init__()
        with self.init_scope():
            layers = []
            ch_in, ch_out = in_ch, 64
            for i in range(n_down_layers):
                layers.append(CBR(ch_in, ch_out, bn=(i > 0), sample='down',
                                  activation=F.leaky_relu, noise=True))
                ch_in, ch_out = ch_out, min(ch_out * 2, 512)
            layers.append(CBR(ch_in, 512, bn=True, sample='none-4',
                              activation=F.leaky_relu, noise=True))
            layers.append(CBR(512, 1, bn=False, sample='none-1',
                              activation=None, noise=True))

            for layer in layers:
                self.add_link(layer)

    def __call__(self, x):
        x = add_noise(x)
        for f in self.children():
            x = f(x)
        return x

## Updater

In [0]:
def make_optimizer(model, alpha=0.0002, beta1=0.5):
    """Create an Adam optimizer for ``model`` with L2 weight decay of 1e-4.

    Args:
        model: Link whose parameters are optimized.
        alpha: Adam learning rate.
        beta1: Adam first-moment decay rate.

    Returns:
        The configured ``chainer.optimizers.Adam`` instance.
    """
    adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
    adam.setup(model)
    decay = chainer.optimizer_hooks.WeightDecay(0.0001)
    adam.add_hook(decay, 'hook_dec')
    return adam

In [0]:
# Generator/discriminator on the GPU, each with its own Adam optimizer.
# NOTE(review): if the training cell uses these globals directly, every image
# in IMG_NAMES continues from the previous image's weights and optimizer state.
gen, dis = Generator().to_gpu(), Discriminator().to_gpu()
opt_gen, opt_dis = make_optimizer(gen), make_optimizer(dis)

In [0]:
# VGG19 layers whose Gram matrices define the style loss: conv1_1 ... conv5_1.
style_layers = ['conv{}_1'.format(i) for i in range(1, 6)]
# Per-layer style weights: approximately 1000 / (c * c) for
# c = 64, 128, 256, 512, 512, rounded to three decimals.
style_weights = [0.244, 0.061, 0.015, 0.004, 0.004]

class DCGANUpdater(chainer.training.updaters.StandardUpdater):
    """Alternating discriminator / generator updates for adversarial expansion.

    Must be constructed with ``models=(gen, dis)``; all other arguments are
    forwarded to ``StandardUpdater``. ``update_core`` looks up optimizers
    named ``'gen'`` and ``'dis'``.

    Each batch element is an ``(img_k, img_2k)`` pair. The generator expands
    ``img_k`` and is trained to match ``img_2k`` with three losses:
      - an adversarial loss;
      - an L1 loss;
      - a VGG19 Gram-matrix (style) loss.
    """

    # Per-channel BGR means subtracted from [0, 255] images before VGG19
    # (the same values Chainer's VGG preprocessing uses).
    _VGG_BGR_MEAN = (103.939, 116.779, 123.68)

    def __init__(self, *args, **kwargs):
        self.gen, self.dis = kwargs.pop('models')
        # Pretrained VGG19 used only as a fixed feature extractor; it is not
        # attached to any optimizer.
        self.vgg19 = L.VGG19Layers()
        self.vgg19.to_gpu()
        super(DCGANUpdater, self).__init__(*args, **kwargs)

    def loss_dis(self, dis, y_fake, y_real):
        """Logistic discriminator loss: mean of softplus(-D(real)) and softplus(D(fake)).

        Each term is summed over the logit map and divided by the batch size.
        """
        batchsize = len(y_fake)
        # Penalty for real patches scored as fake.
        loss_real = F.sum(F.softplus(-y_real)) / batchsize
        # Penalty for generated patches scored as real.
        loss_fake = F.sum(F.softplus(y_fake)) / batchsize
        loss = loss_real / 2. + loss_fake / 2.
        chainer.report({'loss': loss}, dis)
        chainer.report({'loss/false_fake': loss_real}, dis)
        chainer.report({'loss/false_real': loss_fake}, dis)
        return loss

    def gram_matrix(self, y):
        """Per-sample Gram matrix of a (b, ch, h, w) feature map -> (b, ch, ch).

        Normalized by ``(2 * ch * h * w) ** 2``.
        """
        xp = self.gen.xp
        b, ch, h, w = y.shape
        features = F.reshape(y, (b, ch, w * h))
        return F.matmul(features, features, transb=True) / xp.float32(2 * ch * w * h) ** 2

    def _vgg_input(self, x):
        """Map a [-1, 1] BGR batch (N, 3, H, W) to mean-subtracted [0, 255] BGR.

        Accepts a raw array or a Variable; for a Variable the conversion stays
        on the computational graph, so gradients can flow back through it.
        """
        xp = self.vgg19.xp
        mean = xp.asarray(self._VGG_BGR_MEAN, dtype=xp.float32).reshape(1, 3, 1, 1)
        # Broadcast explicitly so Variable arithmetic sees matching shapes.
        return (x + 1) * 127.5 - xp.broadcast_to(mean, x.shape)

    def loss_gen(self, gen, y_fake, x_real, x_fake):
        """Generator loss: adversarial + alpha * L1 + beta * style.

        The style term compares Gram matrices of VGG19 activations (layers in
        ``style_layers`` weighted by ``style_weights``) of the real and
        generated 2k patches. The patches are fed to VGG19 at native
        resolution. They stay in BGR order: the dataset loads images with
        OpenCV (BGR), which is the order VGG19 expects.
        """
        with chainer.using_config('train', False):
            # The real patch only provides targets; no graph is needed.
            with chainer.using_config('enable_backprop', False):
                x_real_feat = self.vgg19(self._vgg_input(x_real.array), layers=style_layers)
            # Keep the graph for the generated patch so the style loss actually
            # back-propagates through VGG19 into the generator.
            x_fake_feat = self.vgg19(self._vgg_input(x_fake), layers=style_layers)

        loss_style = 0
        for layer, w in zip(style_layers, style_weights):
            loss_style += w * F.mean_squared_error(
                self.gram_matrix(F.relu(x_real_feat[layer])),
                self.gram_matrix(F.relu(x_fake_feat[layer])),
            )

        batchsize = len(y_fake)
        loss_L1 = F.mean_absolute_error(x_real, x_fake)
        # Non-saturating adversarial loss: softplus(-D(G(x))).
        loss_adv = F.sum(F.softplus(-y_fake)) / batchsize

        alpha = 100
        beta = 1

        loss = loss_adv + alpha * loss_L1 + beta * loss_style

        chainer.report({'loss': loss}, gen)
        chainer.report({'loss/adv': loss_adv}, gen)
        chainer.report({'loss/L1': alpha * loss_L1}, gen)
        chainer.report({'loss/style': beta * loss_style}, gen)
        return loss

    def update_core(self):
        gen_optimizer = self.get_optimizer('gen')
        dis_optimizer = self.get_optimizer('dis')
        batch = next(self.get_iterator('main'))
        gen, dis = self.gen, self.dis
        xp = gen.xp

        # Stack the (k patch, 2k patch) pairs into two float32 batches on the
        # models' device.
        batch_k = xp.asarray(np.stack([pair[0] for pair in batch]).astype(np.float32, copy=False))
        batch_2k = xp.asarray(np.stack([pair[1] for pair in batch]).astype(np.float32, copy=False))

        # x_real: 2k x 2k patches cropped from the input image.
        x_real = Variable(batch_2k)
        y_real = dis(x_real)

        # x_fake: 2k x 2k patches generated from the k x k patches.
        x_fake = gen(batch_k)
        y_fake = dis(x_fake)

        dis_optimizer.update(self.loss_dis, dis, y_fake, y_real)
        gen_optimizer.update(self.loss_gen, gen, y_fake, x_real, x_fake)
        # VGG19 is never optimized; drop the parameter gradients the style
        # loss left on it.
        self.vgg19.cleargrads()

## Dataset

In [0]:
class PatchDataset(chainer.dataset.DatasetMixin):
    """Random (k x k source, 2k x 2k target) patch pairs from one texture image.

    The image is read once with OpenCV (BGR channel order) and kept as
    float32. ``get_example`` ignores its index and returns a fresh random
    pair on every call. The pair is a 2k x 2k target crop plus a k x k source
    crop lying entirely inside it, both scaled to [-1, 1] in CHW layout.

    Args:
        image_path: Path of the texture image.
        crop_size: Source patch side length ``k``; the image must be at least
            ``2 * k`` pixels high and wide.

    Raises:
        IOError: If OpenCV cannot read ``image_path``.
        ValueError: If ``crop_size`` is not an int with
            ``1 <= crop_size`` and ``2 * crop_size <= min(height, width)``.
    """

    def __init__(self, image_path, crop_size=128):
        image = cv2.imread(image_path)
        # cv2.imread returns None (rather than raising) for unreadable files.
        if image is None:
            raise IOError('cannot read image: {}'.format(image_path))
        self.image = image.astype('float32')
        #self.image = cv2.pyrDown(cv2.pyrDown(self.image))
        #cv2.imwrite(image_path+'small.jpg', self.image.astype('uint8'))
        self.k = crop_size
        r, c, ch = self.image.shape
        if not (isinstance(self.k, int) and 1 <= self.k and 2 * self.k <= min(r, c)):
            raise ValueError(
                'crop_size must be an int with 1 <= 2 * crop_size <= min(height, width); '
                'got {} for a {}x{} image'.format(crop_size, r, c))
        # Number of distinct top-left positions for k x k and 2k x 2k crops.
        self.n_patch_of_k = (r - self.k + 1) * (c - self.k + 1)
        self.n_patch_of_2k = (r - 2 * self.k + 1) * (c - 2 * self.k + 1)

    def __len__(self):
        return self.n_patch_of_k

    def get_example(self, i):
        """Return a random ``(img_k, img_2k)`` pair; ``i`` is ignored."""
        r, c, _ = self.image.shape
        k = self.k

        # Top-left corner of the 2k x 2k target patch. randint's upper bound is
        # exclusive: +1 makes the last valid offset reachable and keeps the
        # call valid when a side is exactly 2k pixels.
        r0_2k = np.random.randint(0, r - 2 * k + 1)
        c0_2k = np.random.randint(0, c - 2 * k + 1)
        # Top-left corner of the k x k source patch anywhere inside the target
        # patch, i.e. offsets 0..k (inclusive) from the target's corner.
        r0_k = np.random.randint(r0_2k, r0_2k + k + 1)
        c0_k = np.random.randint(c0_2k, c0_2k + k + 1)

        img_k = self.preprocess(self.image[r0_k:r0_k + k, c0_k:c0_k + k, :])
        img_2k = self.preprocess(self.image[r0_2k:r0_2k + 2 * k, c0_2k:c0_2k + 2 * k, :])
        return img_k, img_2k

    def preprocess(self, img):
        """HWC image in [0, 255] -> float32 CHW array in [-1, 1]."""
        img = img.astype("f")
        img = img / 127.5 - 1
        img = img.transpose((2, 0, 1))
        return img

    def postprocess(self, img):
        """CHW array in [-1, 1] -> HWC uint8 image in [0, 255] (values clipped)."""
        img = (img + 1) * 127.5
        img = np.clip(img, 0, 255)
        img = img.astype(np.uint8)
        img = img.transpose((1, 2, 0))
        return img
        
        

## Let's train

### params

In [0]:
n_iter = 100000  # total number of training iterations (not epochs)
batchsize = 1  # minibatch size
snapshot_interval = 1000  # iterations between trainer/model snapshots
imsave_interval = 1000  # iterations between preview images
display_interval = 100  # iterations between log entries
plot_interval = 100  # iterations between loss-plot updates
gpu_id = 0  # device id passed to the updater
seed = 0  # NumPy seed used when rendering preview images

in_dir = 'dataset/'  # directory containing the images listed in IMG_NAMES
tmp_train_snapshot = 'last_snapshot.npz'  # trainer snapshot used to resume training

def out_generated_image(gen, dataset, seed, dst_path):
    """Build a trainer extension that saves ``gen``'s expansion of the whole image.

    On each trigger the full ``dataset.image`` is fed to ``gen`` with
    ``train=False``. The result is written to
    ``dst_path + 'output_itr{iteration:0>8}.jpg'``.

    Args:
        gen: Generator used for rendering.
        dataset: Object providing ``image``, ``preprocess`` and ``postprocess``
            (e.g. ``PatchDataset``).
        seed: NumPy seed applied while rendering.
        dst_path: Output directory prefix (expected to end with a separator).
    """
    @chainer.training.make_extension()
    def make_image(trainer):
        # Seed NumPy for the render, then restore the previous global state.
        # Training crops are drawn from np.random, so re-seeding without
        # restoring would replay the same crop sequence after every preview.
        rng_state = np.random.get_state()
        np.random.seed(seed)
        try:
            img = gen.xp.asarray(dataset.preprocess(dataset.image))
            # Inference only: test-mode config and no backprop graph kept.
            with chainer.using_config('train', False):
                with chainer.using_config('enable_backprop', False):
                    x = gen(img[None])
        finally:
            np.random.set_state(rng_state)

        x = dataset.postprocess(chainer.backends.cuda.to_cpu(x.array)[0])

        os.makedirs(dst_path, exist_ok=True)
        preview_path = dst_path + 'output_itr{:0>8}.jpg'.format(trainer.updater.iteration)
        cv2.imwrite(preview_path, x)
    return make_image

### extensions & run

In [0]:
%%capture
for imname in IMG_NAMES:
    in_img = in_dir + imname
    out_dir = 'out/' + imname + '_result/'
    out_img_dir = out_dir + 'img/'

    # Fresh models and optimizers for every texture, so each image is learned
    # from scratch instead of continuing from the previous image's weights
    # and Adam state. A saved snapshot (below) still restores them on resume.
    gen = Generator().to_gpu()
    dis = Discriminator().to_gpu()
    opt_gen = make_optimizer(gen)
    opt_dis = make_optimizer(dis)

    patchDataset = PatchDataset(in_img)

    train_iter = chainer.iterators.SerialIterator(patchDataset, batchsize)

    updater = DCGANUpdater(models=(gen, dis), iterator=train_iter,
                        optimizer={ 'gen': opt_gen, 'dis': opt_dis }, device=gpu_id)
    trainer = chainer.training.Trainer(updater, (n_iter, 'iteration'), out=out_dir)

    trainer.extend(
        extensions.snapshot(filename=tmp_train_snapshot),
        trigger=(snapshot_interval, 'iteration')
    )
    trainer.extend(
        extensions.snapshot_object(gen, 'gen.npz'),
        trigger=(snapshot_interval, 'iteration')
    )
    trainer.extend(
        extensions.snapshot_object(dis, 'dis.npz'),
        trigger=(snapshot_interval, 'iteration')
    )
    trainer.extend(
        extensions.LogReport(
            trigger=(display_interval, 'iteration'),
            log_name='log.json'
        )
    )

    # On Colab, stdout output is better disabled so the cell refreshes cleanly.
    # trainer.extend(
    #     extensions.PrintReport(['iteration', 'gen/loss', 'dis/loss']),
    #     trigger=(display_interval, 'iteration')
    # )
    # trainer.extend(extensions.ProgressBar(update_interval=plot_interval))

    trainer.extend(
        extensions.PlotReport(
            ['gen/loss/adv', 'dis/loss'],
            x_key='iteration', trigger=(plot_interval, 'iteration'),
            file_name='adv_loss_plot.png'
        )
    )

    trainer.extend(
        extensions.PlotReport(
            ['gen/loss/adv', 'gen/loss/L1', 'gen/loss/style'],
            x_key='iteration', trigger=(plot_interval, 'iteration'),
            file_name='gen_loss_plot.png'
        )
    )

    trainer.extend(
        extensions.PlotReport(
            ['dis/loss/false_fake', 'dis/loss/false_real'],
            x_key='iteration', trigger=(plot_interval, 'iteration'),
            file_name='dis_loss_plot.png'
        )
    )

    trainer.extend(
        out_generated_image(gen, patchDataset, seed, out_img_dir),
        trigger=(imsave_interval, 'iteration')
    )

    # Resume from the last trainer snapshot for this image, if one exists.
    if os.path.exists(out_dir+tmp_train_snapshot):
        chainer.serializers.load_npz(out_dir+tmp_train_snapshot, trainer)

    trainer.run()