In [0]:
# target
# Name of the image category to train on. It is substituted into the dataset
# path and into every model / optimizer file name defined below.
target = 'horse'

# hyper params
iteration = 100000   # total number of training iterations
batchsize = 64       # mini-batch size of the training iterator
initial_scale = 1    # passed as `wscale` to the Generator / Discriminator

# extension params (all expressed in iterations)
snapshot_interval = iteration//10
display_interval = iteration//100
update_interval = display_interval   # refresh period of the progress bar
log_interval = iteration//1000
tweet_interval = display_interval    # falsy value disables tweeting

# model params
latent_size = 128    # width of the latent vectors fed to the generator
sa_gamma = 1.        # final self-attention mixing weight
gen_noise = 5e-2     # noise factor passed to the Generator
sa_endpoint = 50000  # iteration at which the self-attention ramp ends

# learning controller
learning_rate = 1e-4
grad_clip = None     # falsy -> no gradient-clipping hook
grad_decay = 1e-5    # weight-decay rate; falsy -> no weight-decay hook
load_weight = True   # load generator weights before training
load_opt = False     # load optimizer states before training
load_dis = True      # load discriminator weights before training
save_opt = False     # snapshot optimizer states during training
save_dis = True      # snapshot the discriminator during training
IMG_SIZE = 16
IMG_SHAPE = (IMG_SIZE, IMG_SIZE)
GPU = 0              # CUDA device id

# file names
mount = './'
OUT = '{}/Drive_sync/result/'.format(mount)
dataset = '{}/Drive_sync/picture/{}_pic/**/*'.format(mount, target)

gen_name = '{}_gen'.format(target)
dis_name = '{}_dis'.format(target)

opt_gen_name = 'opt_{}_gen'.format(target)
# File name for the discriminator optimizer state. It is available under two
# names: `opt_dis_name` follows the naming of the other file-name constants,
# and `opt_horse_dis_name` is the older spelling kept as an alias.
opt_dis_name = 'opt_{}_dis'.format(target)
opt_horse_dis_name = opt_dis_name

In [0]:
import glob
import io
import math
import os
import random
import uuid

import numpy
from PIL import Image, ImageOps, ImageChops, ImageFilter
import matplotlib.pyplot as plt
import cupy

import chainer
from chainer import training, backend, Variable
from chainer.training import extensions
import chainer.functions as F
import chainer.links as L
import chainer.backends.cuda
import chainer.link_hooks as LH

from IPython.display import clear_output
import tweepy

import twitter_api_key

7.4.0


--------------------------------------------------------------------------------
CuPy (cupy) version 6.0.0 may not be compatible with this version of Chainer.
Please consider installing the supported version by running:
  $ pip install 'cupy>=6.3.0,<7.0.0'

See the following page for more details:
  https://docs-cupy.chainer.org/en/latest/install.html
--------------------------------------------------------------------------------

  requirement=requirement, help=help))


In [0]:
def make_optimizer_Adam(model, alpha=1e-4, beta1=0.5, clip=None, decay=None):
    """Create an Adam optimizer bound to ``model``.

    Args:
        model: Link whose parameters the optimizer will update.
        alpha: Adam step size.
        beta1: Adam first-moment decay rate.
        clip: Gradient-clipping threshold; a falsy value adds no clipping hook.
        decay: Weight-decay rate; a falsy value adds no weight-decay hook.

    Returns:
        The configured ``chainer.optimizers.Adam`` instance.
    """
    adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
    adam.setup(model)

    # Collect the requested hooks first; clipping is registered before decay.
    hooks = []
    if clip:
        hooks.append(chainer.optimizer_hooks.GradientClipping(clip))
    if decay:
        hooks.append(chainer.optimizer_hooks.WeightDecay(decay))
    for hook in hooks:
        adam.add_hook(hook)

    return adam

In [0]:
def gaussian(size):
    """Draw a float32 Gaussian sample of shape ``size`` on the GPU.

    The sample is produced by ``F.gaussian`` with an all-zero mean array and
    an all-one second argument.

    NOTE(review): Chainer documents the second argument of ``F.gaussian`` as
    the *log*-variance, so an array of ones yields variance ``e`` rather than
    1. Confirm whether a unit-variance sample was intended before changing
    it, since saved weights were trained against this distribution.
    """
    mean = cupy.zeros(size, dtype=cupy.float32)
    ln_var = cupy.ones(size, dtype=cupy.float32)
    return F.gaussian(mean, ln_var)
    
def zeropad(x, ch):
    """Append zero channels to ``x`` so that axis 1 has size ``ch``.

    ``x`` must have four axes (the pad specification has four entries); only
    the trailing side of axis 1 is padded.
    """
    missing = ch - x.shape[1]
    pad_width = ((0, 0), (0, missing), (0, 0), (0, 0))
    return F.pad(x, pad_width, 'constant', constant_values=0)

def gap(x):
    """Average-pool ``x`` with a window equal to its last two axes."""
    window = x.shape[-2:]
    return F.average_pooling_2d(x, window)
    
def noise_injection(x, k):
    """Return a Gaussian sample centred on ``x``.

    The second argument handed to ``F.gaussian`` is ``instance_var(x) * k``.

    NOTE(review): ``F.gaussian`` interprets its second argument as the
    log-variance, while the value passed here is a (scaled) variance.
    Confirm whether ``log(instance_var(x) * k)`` was intended.
    """
    ln_var = instance_var(x) * k
    return F.gaussian(x, ln_var)
    
def instance_var(x):
    """Per-sample, per-channel variance of ``x``, broadcast to ``x.shape``.

    All axes after the first two are flattened; the mean and the mean squared
    deviation are taken over that flattened axis. The result is expanded with
    one extra trailing axis before broadcasting, so ``x`` is expected to have
    four axes.
    """
    full_shape = x.shape
    flat = F.reshape(x, full_shape[:2]+(-1,))
    mean = F.mean(flat, axis=2, keepdims=True)
    squared_dev = (flat - mean)**2
    variance = F.mean(squared_dev, axis=2, keepdims=True)
    return F.broadcast_to(variance[:,:,:,None], full_shape)

def upsample(x):
    """Apply ``F.depth2space`` to ``x`` with a block size of 2."""
    block = 2
    return F.depth2space(x, block)

def instance_normalization(x):
    """Normalise every (sample, channel) slice of ``x`` with ``F.normalize``.

    The axes after the first two are flattened, ``F.normalize`` (L2
    normalisation) is applied along the flattened axis, and the original
    shape is restored. No mean is subtracted in this function.
    """
    original_shape = x.shape
    flat = F.reshape(x, original_shape[:2]+(-1,))
    unit = F.normalize(flat, axis=2)
    return F.reshape(unit, original_shape)

In [0]:
class Conv(chainer.Chain):
    """Leaky-ReLU followed by a spectrally normalised 2-D convolution.

    Args:
        ch: Number of output channels (input channels are inferred lazily).
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        wscale: Scale handed to the ``HeNormal`` weight initializer.
    """

    def __init__(self, ch, kernel=1, stride=1, padding=0, wscale=1.):
        super(Conv, self).__init__()

        initializer = chainer.initializers.HeNormal(wscale)
        conv = L.Convolution2D(None, ch, kernel, stride, padding, initialW=initializer)
        # add_hook returns the link itself, so attaching the hook first and
        # registering afterwards yields the same child link.
        conv.add_hook(LH.SpectralNormalization())

        with self.init_scope():
            self.c = conv

    def __call__(self, x):
        """Activate ``x`` with leaky ReLU, then convolve it."""
        return self.c(F.leaky_relu(x))

In [0]:
class SelfAttentionBlock(chainer.Chain):
    """Residual self-attention over the spatial positions of a feature map.

    Three ``Conv`` projections (``cf``, ``cg`` with ``ch//8`` channels and
    ``ch`` with ``ch`` channels) are computed from the input. The softmax of
    ``cf^T cg`` weights the ``ch`` projection, and the result is added to the
    input scaled by ``gamma``.

    Args:
        ch: Channel count of the input feature map.
        wscale: Initializer scale forwarded to the ``Conv`` projections.
        gamma: Initial weight of the attention branch.
    """

    def __init__(self, ch, wscale=1., gamma=1.):
        super(SelfAttentionBlock, self).__init__()

        self.gamma = gamma
        with self.init_scope():
            self.cf = Conv(ch//8, wscale=wscale)
            self.cg = Conv(ch//8, wscale=wscale)
            self.ch = Conv(ch, wscale=wscale)

    @staticmethod
    def _flatten_spatial(t):
        """Collapse every axis after the first two into a single axis."""
        return F.reshape(t, t.shape[:2]+(-1,))

    def __call__(self, x):
        """Return ``x`` plus ``gamma`` times the attention-weighted features."""
        proj_f = self._flatten_spatial(self.cf(x))
        proj_g = self._flatten_spatial(self.cg(x))
        proj_h = self._flatten_spatial(self.ch(x))

        # Position-by-position scores, normalised over the last axis.
        scores = F.batch_matmul(proj_f, proj_g, transa=True)
        weights = F.softmax(scores, axis=-1)

        attended = F.batch_matmul(proj_h, weights, transb=True)
        attended = F.reshape(attended, x.shape)
        return F.add(x, attended*self.gamma)

    def set_gamma(self, gamma):
        """Replace the weight applied to the attention branch."""
        self.gamma = gamma

In [0]:
class InceptionBlock(chainer.Chain):
    """Four parallel ``Conv`` branches merged by a final 1x1 ``Conv``.

    The branches use kernel sizes 1, 3, 5 and 7 (with paddings 0, 1, 2 and 3)
    and produce ``in_ch``, ``2*in_ch``, ``2*in_ch`` and ``3*in_ch`` channels.
    Their outputs are concatenated and projected to ``out_ch`` channels.

    Args:
        in_ch: Base channel count of the branches.
        out_ch: Channel count of the block output.
        wscale: Initializer scale forwarded to every ``Conv``.
    """

    def __init__(self, in_ch, out_ch, wscale=1.):
        super(InceptionBlock, self).__init__()

        with self.init_scope():
            self.layer_11 = Conv(in_ch, wscale=wscale)
            self.layer_33 = Conv(in_ch*2, 3, 1, 1, wscale=wscale)
            self.layer_55 = Conv(in_ch*2, 5, 1, 2, wscale=wscale)
            self.layer_77 = Conv(in_ch*3, 7, 1, 3, wscale=wscale)

            self.c = Conv(out_ch, wscale=wscale)

    def __call__(self, x):
        """Run all branches on ``x``, concatenate them and project."""
        branches = (self.layer_11, self.layer_33, self.layer_55, self.layer_77)
        merged = F.concat(tuple(branch(x) for branch in branches))
        return self.c(merged)

In [0]:
class Dense(chainer.Chain):
    """Leaky-ReLU followed by a spectrally normalised linear layer.

    Args:
        ch: Number of output units (input size is inferred lazily).
        wscale: Scale handed to the ``HeNormal`` weight initializer.
    """

    def __init__(self, ch, wscale=1.):
        super(Dense, self).__init__()

        initializer = chainer.initializers.HeNormal(wscale)
        linear = L.Linear(None, ch, initialW=initializer)
        # add_hook returns the link itself, so the registered child is the
        # hooked linear layer.
        linear.add_hook(LH.SpectralNormalization())

        with self.init_scope():
            self.l = linear

    def __call__(self, x):
        """Activate ``x`` with leaky ReLU, then apply the linear layer."""
        return self.l(F.leaky_relu(x))

In [0]:
class Affine(chainer.Chain):
    """Stack of eight ``Dense`` layers applied one after another.

    Seven layers of width ``mid_ch`` are followed by an output layer of width
    ``w_ch``.

    Args:
        mid_ch: Width of the seven intermediate layers.
        w_ch: Width of the output layer.
        wscale: Initializer scale forwarded to every ``Dense``.
    """

    def __init__(self, mid_ch=64, w_ch=128, wscale=1.):
        super(Affine, self).__init__()

        with self.init_scope():
            self.l1 = Dense(mid_ch, wscale=wscale)
            self.l2 = Dense(mid_ch, wscale=wscale)
            self.l3 = Dense(mid_ch, wscale=wscale)
            self.l4 = Dense(mid_ch, wscale=wscale)
            self.l5 = Dense(mid_ch, wscale=wscale)
            self.l6 = Dense(mid_ch, wscale=wscale)
            self.l7 = Dense(mid_ch, wscale=wscale)
            self.l_out = Dense(w_ch, wscale=wscale)

    def __call__(self, x):
        """Feed ``x`` through ``l1`` ... ``l7`` and then ``l_out``."""
        pipeline = (
            self.l1, self.l2, self.l3, self.l4,
            self.l5, self.l6, self.l7, self.l_out,
        )
        h = x
        for layer in pipeline:
            h = layer(h)
        return h

In [0]:
class AdaIN(chainer.Chain):
    """Normalise a feature map, then scale and shift it per channel.

    The scale and shift are produced from a style vector ``w`` by two
    ``Dense`` layers with ``ch`` outputs each.

    Args:
        ch: Channel count of the feature map being modulated.
        wscale: Initializer scale forwarded to both ``Dense`` layers.
    """

    def __init__(self, ch, wscale=1.):
        super(AdaIN, self).__init__()

        with self.init_scope():
            self.average_convert = Dense(ch, wscale=wscale)
            self.bias_convert = Dense(ch, wscale=wscale)

    def __call__(self, x, w):
        """Return ``instance_normalization(x) * a(w) + b(w)``."""
        normalized = instance_normalization(x)

        def per_channel(vec):
            # Append two singleton axes and expand to the feature-map shape.
            return F.broadcast_to(F.reshape(vec, vec.shape+(1, 1)), normalized.shape)

        scale = per_channel(self.average_convert(w))
        shift = per_channel(self.bias_convert(w))
        return normalized * scale + shift

In [0]:
class MinibatchDiscrimination(chainer.Chain):
    """Append cross-sample similarity statistics to every feature vector.

    Each sample is projected to a ``(kernel, ch)`` matrix. For every pair of
    samples the L1 distance along the ``ch`` axis is turned into
    ``exp(-distance)``, and these values are summed over the other samples.
    The resulting ``kernel`` values are concatenated to the input features.

    Args:
        kernel: Number of similarity values appended per sample.
        ch: Size of the axis over which the L1 distance is summed.
        wscale: Scale handed to the ``HeNormal`` weight initializer.
    """

    def __init__(self, kernel, ch, wscale=1.):
        super(MinibatchDiscrimination, self).__init__()
        self.kernel = kernel
        self.ch = ch

        with self.init_scope():
            initializer = chainer.initializers.HeNormal(wscale)
            self.t = L.Linear(None, self.kernel*self.ch, initialW=initializer).add_hook(LH.SpectralNormalization())

    def __call__(self, x):
        """Return ``x`` with ``kernel`` extra columns of batch statistics."""
        n = x.shape[0]

        projected = F.reshape(self.t(x), (n, self.kernel, self.ch))
        projected = F.expand_dims(projected, 3)              # (n, kernel, ch, 1)
        swapped = F.transpose(projected, (3, 1, 2, 0))       # (1, kernel, ch, n)
        lhs, rhs = F.broadcast(projected, swapped)           # (n, kernel, ch, n)

        l1_dist = F.sum(F.absolute_error(lhs, rhs), axis=2)  # (n, kernel, n)

        # A large penalty on the diagonal drives each sample's similarity to
        # itself to exp(-1e6), i.e. effectively removes it from the sum.
        diagonal = cupy.eye(n, dtype=cupy.float32).reshape((n, 1, n))
        self_mask = F.broadcast_to(diagonal, l1_dist.shape)
        similarity = F.exp(-(l1_dist + 1e6 * self_mask))

        stats = F.sum(similarity, axis=2)
        return F.concat((x, stats))

In [0]:
class Discriminator(chainer.Chain):
    """Multi-resolution discriminator whose depth depends on ``IMG_SIZE``.

    For every resolution stage enabled by the module-level ``IMG_SIZE`` the
    network owns a 1x1 "from image" convolution (``c_k``), an
    ``InceptionBlock`` and a stride-2 ``Conv`` (``resize_k``). The last stage
    (``c_5`` / ``inception5`` / ``sa5``) is always present. The pooled
    features pass through ``MinibatchDiscrimination`` and a final ``Dense``.

    Args:
        out_ch: Number of outputs of the final ``Dense`` layer.
        k: Divisor applied to a stage's width to obtain the base channel
            count of its ``InceptionBlock``.
        wscale: Initializer scale used for every layer.
        sa_gamma: Initial ``gamma`` of the self-attention blocks.
    """

    def __init__(self, out_ch=2, k=8, wscale=1., sa_gamma=1.):
        super(Discriminator, self).__init__()
        
        with self.init_scope():
            # One initializer instance is shared by all "from image" convolutions.
            w = chainer.initializers.HeNormal(wscale)
            if IMG_SIZE >= 256:
                self.c_0 = L.Convolution2D(3, 32, 1, 1, 0, initialW=w).add_hook(LH.SpectralNormalization())
                self.inception0 = InceptionBlock(32//k, 32, wscale=wscale)
                self.resize1 = Conv(64, 4, 2, 1, wscale=wscale)

            if IMG_SIZE >= 128:
                self.c_1 = L.Convolution2D(3, 64, 1, 1, 0, initialW=w).add_hook(LH.SpectralNormalization())
                self.inception1 = InceptionBlock(64//k, 64, wscale=wscale)
                self.resize2 = Conv(128, 4, 2, 1, wscale=wscale)
            
            if IMG_SIZE >= 64:
                self.c_2 = L.Convolution2D(3, 128, 1, 1, 0, initialW=w).add_hook(LH.SpectralNormalization())
                self.inception2 = InceptionBlock(128//k, 128, wscale=wscale)
                self.resize3 = Conv(256, 4, 2, 1, wscale=wscale)

            if IMG_SIZE >= 32:
                self.c_3 = L.Convolution2D(3, 256, 1, 1, 0, initialW=w).add_hook(LH.SpectralNormalization())
                self.inception3 = InceptionBlock(256//k, 256, wscale=wscale)
                self.resize4 = Conv(512, 4, 2, 1, wscale=wscale)

            if IMG_SIZE >= 16:
                self.c_4 = L.Convolution2D(3, 512, 1, 1, 0, initialW=w).add_hook(LH.SpectralNormalization())
                self.inception4 = InceptionBlock(512//k, 512, wscale=wscale)
                self.sa4 = SelfAttentionBlock(512, wscale=wscale, gamma=sa_gamma)
                self.resize5 = Conv(1024, 4, 2, 1, wscale=wscale)

            # Final stage: always registered, regardless of IMG_SIZE.
            self.c_5 = L.Convolution2D(3, 1024, 1, 1, 0, initialW=w).add_hook(LH.SpectralNormalization())
            self.inception5 = InceptionBlock(1024//k, 1024, wscale=wscale)
            self.sa5 = SelfAttentionBlock(1024, wscale=wscale, gamma=sa_gamma)

            self.minibatch_discrimination = MinibatchDiscrimination(64, 16, wscale=wscale)

            self.l_out = Dense(out_ch, wscale=wscale)
    
    def __call__(self, x):
        """Score a batch of images.

        Every enabled stage adds a projection of the (average-pooled) input
        image to the running feature map before processing it. The pooling
        kernel is ``IMG_SIZE // stage_resolution``.

        Args:
            x: Image batch with 3 channels (the ``c_k`` convolutions are
                declared with 3 input channels).

        Returns:
            Output of the final ``Dense`` layer.
        """

        if IMG_SIZE >= 256:
            h = self.c_0(x)
            h = self.inception0(h)
            h = self.resize1(h)
        else:
            # No higher-resolution features yet; the first enabled stage
            # below then starts from its own image projection alone.
            h = 0
        
        if IMG_SIZE >= 128:
            h = h + self.c_1(F.average_pooling_2d(x, IMG_SIZE//128))
            h = self.inception1(h)
            h = self.resize2(h)
        
        if IMG_SIZE >= 64:
            h = h + self.c_2(F.average_pooling_2d(x, IMG_SIZE//64))
            h = self.inception2(h)
            h = self.resize3(h)
        
        if IMG_SIZE >= 32:
            h = h + self.c_3(F.average_pooling_2d(x, IMG_SIZE//32))
            h = self.inception3(h)
            h = self.resize4(h)
        
        if IMG_SIZE >= 16:
            h = h + self.c_4(F.average_pooling_2d(x, IMG_SIZE//16))
            h = self.inception4(h)
            h = self.sa4(h)
            h = self.resize5(h)

        h = h + self.c_5(F.average_pooling_2d(x, IMG_SIZE//8))
        h = self.inception5(h)
        h = self.sa5(h)
        
        # Collapse the spatial axes, then flatten to (batch, features).
        h = gap(h)
        h = self.minibatch_discrimination(F.reshape(h, (h.shape[0], -1)))
        h = self.l_out(h)
        return h

    def set_gamma(self, gamma):
        """Update the self-attention weight of one attention block.

        Only ``IMG_SIZE == 8`` (``sa5``) and ``IMG_SIZE == 16`` (``sa4``) are
        handled; for any other size this method does nothing.
        NOTE(review): confirm that leaving the other block (and larger image
        sizes) untouched is intended.
        """
        if IMG_SIZE == 8:
            self.sa5.set_gamma(gamma)
        if IMG_SIZE == 16:
            self.sa4.set_gamma(gamma)

In [0]:
class ResBlock(chainer.Chain):
    """Two residual Inception + AdaIN stages plus a 3-channel output head.

    Stage 1 maps the input to ``ch`` channels; stage 2 widens it to
    ``ch*3//2`` channels. In each stage the skip path is zero-padded along
    the channel axis to match the Inception output before the two are added.

    Args:
        ch: Channel count after the first stage.
        k: Divisor used to derive the base channel count of the Inception
            blocks.
        wscale: Initializer scale forwarded to all sub-layers.
    """

    def __init__(self, ch, k=8, wscale=1.):
        super(ResBlock, self).__init__()
        self.ch = ch
        in_ch = self.ch//k

        with self.init_scope():
            self.inception_1 = InceptionBlock(in_ch, self.ch, wscale=wscale)
            self.adain_1 = AdaIN(self.ch, wscale=wscale)

            self.inception_2 = InceptionBlock(in_ch*3//2, self.ch*3//2, wscale=wscale)
            self.adain_2 = AdaIN(self.ch*3//2, wscale=wscale)

            self.c = Conv(3, wscale=wscale)

    def __call__(self, x, w1, w2, noise=None):
        """Run both stages and the output head.

        Args:
            x: Input feature map with at most ``ch`` channels.
            w1: Style vector modulating the first AdaIN.
            w2: Style vector modulating the second AdaIN.
            noise: Factor passed to ``noise_injection``; a falsy value
                disables noise injection.

        Returns:
            Tuple ``(h, out)``: the ``ch*3//2``-channel feature map and the
            3-channel output produced from it.
        """
        _h = noise_injection(x, noise) if noise else x
        _h = self.inception_1(_h)
        h = zeropad(x, self.ch)
        h = h + _h
        h = self.adain_1(h, w1)

        _h = noise_injection(h, noise) if noise else h
        _h = self.inception_2(_h)
        h = zeropad(h, self.ch*3//2)
        h = h + _h
        # Fix: the second stage previously reused w1, leaving the w2 argument
        # unused. Weights trained with the old behaviour saw the same style
        # vector in both stages, so outputs of such checkpoints will differ.
        h = self.adain_2(h, w2)

        out = self.c(h)
        return h, out

In [0]:
class Generator(chainer.Chain):
    """Image generator built from an 8x8 base block plus upscaling stages.

    The input is projected to a ``(1024, 8, 8)`` feature map and processed by
    ``res5``. Every further stage enabled by the module-level ``IMG_SIZE``
    upsamples the features, runs another ``ResBlock`` and adds that block's
    3-channel output to the 2x-unpooled running image. Style vectors for the
    blocks come from ``affine`` applied to fresh ``gaussian`` samples of
    width ``latent_size``.

    Args:
        wscale: Initializer scale for the projection and the ``ResBlock``s.
        noise: Noise factor forwarded to every ``ResBlock`` call.
    """

    # (minimum IMG_SIZE, attribute name, ResBlock width) of each stage that
    # follows the always-present 8x8 base block, in execution order.
    _UPSCALE_STAGES = (
        (16, 'res4', 512),
        (32, 'res3', 256),
        (64, 'res2', 128),
        (128, 'res1', 64),
        (256, 'res0', 32),
    )

    def __init__(self, wscale=1., noise=None):
        super(Generator, self).__init__()
        self.noise = noise

        with self.init_scope():
            initializer = chainer.initializers.HeNormal(wscale)
            self.txbtm = L.Linear(None, 1024*8*8, initialW=initializer).add_hook(LH.SpectralNormalization())
            self.affine = Affine()

            self.res5 = ResBlock(1024, wscale=wscale)
            for min_size, attr, width in self._UPSCALE_STAGES:
                if IMG_SIZE >= min_size:
                    setattr(self, attr, ResBlock(width, wscale=wscale))

    def _sample_styles(self, n):
        """Draw two independent style vectors for a batch of ``n`` samples."""
        first = self.affine(gaussian((n, latent_size)))
        second = self.affine(gaussian((n, latent_size)))
        return first, second

    def __call__(self, x):
        """Generate an image batch from the latent batch ``x``.

        Returns:
            ``tanh`` of the accumulated 3-channel output.
        """
        n = x.shape[0]

        h = F.reshape(self.txbtm(x), (-1, 1024, 8, 8))
        w1, w2 = self._sample_styles(n)
        h, out = self.res5(h, w1, w2, self.noise)

        for min_size, attr, _width in self._UPSCALE_STAGES:
            if IMG_SIZE < min_size:
                continue
            h = upsample(h)
            w1, w2 = self._sample_styles(n)
            h, stage_out = getattr(self, attr)(h, w1, w2, self.noise)
            out = F.unpooling_2d(out, 2, cover_all=False) + stage_out

        return F.tanh(out)

In [0]:
class PretrainUpdater(chainer.training.updaters.StandardUpdater):
    """GAN updater with a linear ramp of the discriminator's attention weight.

    Args:
        end_point: Iteration at which the ramp reaches ``sa_gamma``. A falsy
            value (e.g. 0) disables the ramp.
        sa_gamma: Final value of the attention weight.
        models: Keyword argument holding the ``(generator, discriminator)``
            pair; it is removed from ``kwargs`` before they are passed on.
        *args, **kwargs: Forwarded to ``StandardUpdater``.
    """

    def __init__(self, end_point, sa_gamma, *args, **kwargs):
        self.end_point = end_point
        self.sa_gamma = sa_gamma
        self.gen, self.dis = kwargs.pop('models')
        super(PretrainUpdater, self).__init__(*args, **kwargs)

    def set_gamma(self, gamma):
        """Forward ``gamma`` to the discriminator's ``set_gamma``."""
        self.dis.set_gamma(gamma)

    def loss_gan(self, model, loss):
        """Report ``loss`` under ``model`` and return it unchanged.

        Used as the loss function handed to ``optimizer.update``.
        """
        chainer.report({'loss': loss}, model)
        return loss

    def gan_update(self, x_real, gen, dis, gen_optimizer, dis_optimizer):
        """Perform one discriminator step followed by one generator step.

        The discriminator is pushed towards +1 on real and -1 on generated
        images (mean squared error, averaged over both terms); the generator
        is pushed towards a discriminator output of 0.
        """
        y_real = dis(x_real)
        dis_loss_real = F.mean((y_real-1)**2)

        x_fake = gen(gaussian((x_real.shape[0], latent_size)))
        y_fake = dis(x_fake)
        dis_loss_fake = F.mean((y_fake+1)**2)
        dis_optimizer.update(self.loss_gan, dis, (dis_loss_real+dis_loss_fake)*0.5)

        gen_loss = F.mean(y_fake**2)
        gen_optimizer.update(self.loss_gan, gen, gen_loss)

    def update_core(self):
        """Advance training by one mini-batch."""
        gen_optimizer = self.get_optimizer(gen_name)
        dis_optimizer = self.get_optimizer(dis_name)

        # Linear ramp 0 -> sa_gamma over the first `end_point` iterations.
        # The truthiness check avoids a ZeroDivisionError at iteration 0 when
        # end_point is 0; gamma then keeps its construction-time value.
        if self.end_point and self.iteration <= self.end_point:
            gamma = self.sa_gamma*self.iteration/self.end_point
            self.set_gamma(gamma)

        gen, dis = self.gen, self.dis
        target_iter = self.get_iterator('main')

        target_batch = target_iter.next()
        # Map pixel values from [0, 255] to [-1, 1].
        real_target = Variable(self.converter(target_batch, device=self.device)) /255. *2. -1.
        self.gan_update(real_target, gen, dis, gen_optimizer, dis_optimizer)

In [0]:
def out_generated_image(gen):
    """Build a trainer extension that displays eight generated images.

    The first latent vector is all zeros, the remaining seven come from
    ``gaussian``. The images are shown in a 2x4 matplotlib grid after the
    previous cell output has been cleared.

    Args:
        gen: Generator called with a ``(8, latent_size)`` latent batch.

    Returns:
        The extension function.
    """
    @chainer.training.make_extension()
    def make_image(trainer):
        clear_output()

        zero_latent = cupy.zeros((1, latent_size), dtype=cupy.float32)
        with chainer.using_config('train', False):
            latents = F.vstack((zero_latent, gaussian((7, latent_size))))
            batch = gen(latents)

        # Channel-first -> channel-last, then move the data to host memory.
        batch = F.reshape(batch, (-1, 3)+IMG_SHAPE)
        batch = F.transpose(batch, (0, 2, 3, 1))
        images = chainer.backends.cuda.to_cpu(batch.array)

        plt.figure(figsize=(16, 8))

        for position, img in enumerate(images, start=1):
            axes = plt.subplot(2, 4, position)
            axes.axis('off')
            # Rescale from [-1, 1] to [0, 255] for display.
            axes.imshow(Image.fromarray(numpy.uint8((img+1.)/2. *255.)))

        plt.show()
    return make_image

In [0]:
def image_upload(image_array, api):
    """Encode one image as JPEG in memory and upload it through ``api``.

    Args:
        image_array: Image data with values in [-1, 1]; it is rescaled to
            [0, 255] and cast to ``uint8`` before encoding.
            (Presumably channel-last, as produced by the caller's transpose.)
        api: Object exposing ``media_upload(filename=..., file=...)``.

    Returns:
        The ``media_id`` attribute of the upload result.
    """
    bin_io = io.BytesIO()
    img = Image.fromarray(numpy.uint8((image_array+1.)/2. *255.))
    img.save(bin_io, format='JPEG')
    # `save` leaves the stream positioned at the end of the data. Rewind so
    # the upload reads the JPEG from its first byte even if the client does
    # not seek the file object itself.
    bin_io.seek(0)
    result = api.media_upload(filename='{}_generated_{}.jpg'.format(target, uuid.uuid4()), file=bin_io)
    return result.media_id

In [0]:
def post_generated_image(gen, api):
    """Build a trainer extension that tweets four generated images.

    The images are generated in test mode, enlarged to 256x256 by unpooling
    with factor ``256//IMG_SIZE``, uploaded one by one and attached to a
    status made of the hash-tag lines below.

    Args:
        gen: Generator called with a ``(4, latent_size)`` latent batch.
        api: Twitter API object used for media upload and status update.

    Returns:
        The extension function.
    """
    @chainer.training.make_extension()
    def post_image(trainer):
        with chainer.using_config('train', False):
            generated = F.unpooling_2d(gen(gaussian((4, latent_size))), 256//IMG_SIZE, cover_all=False)

        generated = F.transpose(F.reshape(generated, (-1, 3, 256, 256)), (0, 2, 3, 1))
        generated = chainer.backends.cuda.to_cpu(generated.array)

        try:
            img_ids = [image_upload(img, api) for img in generated]
            hash_tags = ['AIでペガサスを作る',
                        '#makeing{}'.format(target),
                        '#nowlearning...',
                        '#AI',
                        '#人工知能',
                        '#DeepLearning',
                        '#GAN']

            api.update_status(
                status='\n'.join(hash_tags),
                media_ids=img_ids
                )
        except Exception as err:
            # Posting is best-effort and must not abort training, but the
            # failure is reported instead of being silently discarded.
            # Catching Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit still stop the run.
            print('post_generated_image: failed to post images: {!r}'.format(err))
    return post_image

In [0]:
# Instantiate both networks with the configured initializer scale.
gen = Generator(
    wscale=initial_scale,
    noise=gen_noise,
)
dis = Discriminator(
    out_ch=1,
    wscale=initial_scale,
    sa_gamma=sa_gamma,
)

In [0]:
# Make the configured CUDA device current, then move both models onto it.
# The value of the last expression is what the notebook displays as output.
chainer.backends.cuda.get_device_from_id(GPU).use()
gen.to_gpu()
dis.to_gpu()

<__main__.Discriminator at 0x7fb34eea54d0>

In [0]:
# Optionally restore saved weights. strict=False lets the load proceed when
# the snapshot and the current model do not contain exactly the same entries.
if load_weight:
    chainer.serializers.load_npz('{}{}.npz'.format(OUT, gen_name), gen, strict=False)
if load_dis:
    chainer.serializers.load_npz('{}{}.npz'.format(OUT, dis_name), dis, strict=False)

In [0]:
# One Adam optimizer per network, both configured from the same settings.
opt_gen = make_optimizer_Adam(model=gen, alpha=learning_rate, clip=grad_clip, decay=grad_decay)
opt_dis = make_optimizer_Adam(model=dis, alpha=learning_rate, clip=grad_clip, decay=grad_decay)

In [0]:
# Optionally restore saved optimizer states. The discriminator's optimizer is
# only restored when the discriminator weights are loaded as well.
if load_opt:
    chainer.serializers.load_npz(OUT+opt_gen_name+'.npz', opt_gen, strict=False)
if load_opt and load_dis:
    # Fix: this previously read `opt_dis_name`, a name that is never assigned;
    # the configuration cell defines the file name as `opt_horse_dis_name`.
    chainer.serializers.load_npz(OUT+opt_horse_dis_name+'.npz', opt_dis, strict=False)

In [0]:
def img_convert(img_array):
    """Resize one dataset image to ``IMG_SHAPE`` with a random mirror.

    Args:
        img_array: Channel-first image array (C, H, W) with values that fit
            into ``uint8``.

    Returns:
        A float32 array of shape ``(3,) + IMG_SHAPE``; the image is mirrored
        horizontally with probability 0.5.
    """
    pixels = numpy.uint8(img_array.transpose(1, 2, 0))
    # Fix: a single-channel image arrives as (H, W, 1), which
    # Image.fromarray rejects. Drop the channel axis so it is read as a
    # grayscale image and converted to RGB below.
    if pixels.shape[2] == 1:
        pixels = pixels[:, :, 0]

    img = Image.fromarray(pixels)
    img = img.convert('RGB').resize(IMG_SHAPE)

    if random.random() > 0.5:
        img = ImageOps.mirror(img)

    img_array = numpy.asarray(img, dtype=numpy.float32).transpose(2, 0, 1)
    return img_array

In [0]:
# Collect the training images. The recursive `**/*` pattern also matches
# directories, which ImageDataset cannot open as images, so keep only
# regular files.
target_files = [
    path
    for path in glob.glob(dataset, recursive=True)
    if os.path.isfile(path)
]
print('{} contains {} image files'
      .format(dataset, len(target_files)))

# Load images lazily and apply the resize / random-mirror transform on access.
target_img_dataset = chainer.datasets.ImageDataset(paths=target_files)
target_trans_dataset = chainer.datasets.TransformDataset(target_img_dataset, img_convert)

.//Drive_sync/picture/horse_pic/**/* contains 1431 image files


In [0]:
# Shuffled mini-batch iterator over the transformed image dataset.
target_iter = chainer.iterators.SerialIterator(
    dataset=target_trans_dataset,
    batch_size=batchsize,
    shuffle=True,
)

In [0]:
# Wire models, optimizers and the data iterator into the GAN updater.
# The optimizers are registered under the model names the updater looks up.
updater = PretrainUpdater(
    models=(gen, dis),
    optimizer={
        gen_name: opt_gen,
        dis_name: opt_dis,
    },
    iterator={'main': target_iter},
    device=GPU,
    sa_gamma=sa_gamma,
    end_point=sa_endpoint,
)

In [0]:
# Train for `iteration` iterations and write results below OUT.
trainer = training.Trainer(
    updater,
    stop_trigger=(iteration, 'iteration'),
    out=OUT,
)

In [0]:
def _iteration_trigger(interval):
    """Return ``interval`` as an ``(n, 'iteration')`` trigger tuple.

    A value that is already a tuple is returned unchanged, so running this
    cell more than once does not wrap the triggers repeatedly.
    """
    if isinstance(interval, tuple):
        return interval
    return (interval, 'iteration')


# Convert the plain iteration counts into trigger tuples.
snapshot_interval = _iteration_trigger(snapshot_interval)
display_interval = _iteration_trigger(display_interval)
log_interval = _iteration_trigger(log_interval)
if tweet_interval:
    tweet_interval = _iteration_trigger(tweet_interval)

In [0]:
# Create the Twitter client only when tweeting is enabled; the credentials
# come from the local `twitter_api_key` module.
if tweet_interval:
    auth = tweepy.OAuthHandler(
        twitter_api_key.CONSUMER_KEY,
        twitter_api_key.CONSUMER_SECRET,
    )
    auth.set_access_token(
        twitter_api_key.ACCESS_TOKEN_KEY,
        twitter_api_key.ACCESS_TOKEN_SECRET,
    )
    api = tweepy.API(auth)

In [0]:
# Logging, preview images and progress bar.
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(out_generated_image(gen), trigger=display_interval)
trainer.extend(extensions.ProgressBar(update_interval=update_interval))
if tweet_interval:
    trainer.extend(post_generated_image(gen, api), trigger=tweet_interval)

# Periodic snapshots: the generator always, the rest according to the flags.
trainer.extend(extensions.snapshot_object(gen, gen_name+'.npz'), trigger=snapshot_interval)

if save_dis:
    trainer.extend(extensions.snapshot_object(dis, dis_name+'.npz'), trigger=snapshot_interval)

if save_opt:
    trainer.extend(extensions.snapshot_object(opt_gen, opt_gen_name+'.npz'), trigger=snapshot_interval)

if save_dis and save_opt:
    # Fix: this previously read `opt_dis_name`, a name that is never assigned;
    # the configuration cell defines the file name as `opt_horse_dis_name`.
    trainer.extend(extensions.snapshot_object(opt_dis, opt_horse_dis_name+'.npz'), trigger=snapshot_interval)

# Print the losses reported by the updater for both models.
trainer.extend(extensions.PrintReport([
    'epoch', 'iteration', gen_name+'/loss', dis_name+'/loss'
    ]), trigger=display_interval)

In [0]:
# Start the training loop; it runs until the trainer's stop trigger fires.
trainer.run()