In [1]:
import os

import tensorflow as tf
import numpy as np
from PIL import Image
from chainer.functions import caffe

In [2]:
# Helper functions for preparing the input images.
def load_image(path, size=None):
    """Load an image file as a preprocessed, batch-of-one TF constant.

    Args:
        path: image file path; ``~`` is expanded.
        size: optional ``(width, height)`` passed to PIL ``Image.resize``.
            When None the image is used at its original size.

    Returns:
        ``tf.constant`` of shape (1, height, width, 3), float32, after
        ``transform_for_train`` (channel reorder and mean shift).
    """
    # The context manager releases the file handle even for formats where
    # PIL may keep it open after decoding.
    with Image.open(os.path.expanduser(path)) as src:
        img = src.convert("RGB")
    if size is not None:
        img = img.resize(tuple(size), Image.BILINEAR)
    # convert("RGB") guarantees exactly three channels; prepend the batch axis.
    batch = np.asarray(img, dtype=np.float32)[np.newaxis]
    return tf.constant(transform_for_train(batch))


def transform_for_train(img):
    """
    Convert an RGB pixel array into the network's input convention.

    The pretrained weights (VGG etc.) expect BGR channel order, so the last
    axis is reversed. A constant 120 is then subtracted as an approximation of
    the ImageNet mean colour.
    """
    offset = 120
    bgr = img[..., ::-1]
    return bgr - offset


def transform_from_train(img):
    """
    Inverse of transform_for_train: turn a BGR, mean-shifted array back into
    RGB pixel values.

    Args:
        img: array whose last axis is the channel axis, e.g. (H, W, 3) or
            (N, H, W, 3).

    Returns:
        Array of the same shape with the channel axis reversed, 120 added back
        and values clipped to [0, 255]. Callers (e.g. save_image) cast to
        uint8 themselves.
    """
    # Reverse the last (channel) axis, exactly as transform_for_train does.
    # The former ``[:, :, ::-1]`` reversed the width axis for 4-D batches.
    data = img[..., ::-1] + 120
    # NOTE(review): could return uint8 directly; kept as-is so callers
    # control the final cast.
    return data.clip(0, 255)

In [3]:
class Conv:
    """Convolution layer whose weights are copied from a Chainer link as TF constants."""

    def __init__(self, chainer_conv):
        """
        Args:
            chainer_conv: object exposing ``W.data`` (kernel) and ``b.data``
                (bias) arrays, e.g. a layer of a Chainer CaffeFunction.
        """
        # Reorder the kernel axes to TF's conv2d filter layout
        # (height, width, in_channels, out_channels).
        kernel = np.transpose(chainer_conv.W.data, [2, 3, 1, 0])
        self.W = tf.constant(kernel)
        self.b = tf.constant(chainer_conv.b.data)

    def __call__(self, x, stride=1, activation_fn=tf.nn.relu, padding="SAME"):
        """Apply conv2d + bias, then ``activation_fn`` unless it is falsy."""
        strides = [1, stride, stride, 1]
        y = tf.nn.conv2d(x, self.W, strides=strides, padding=padding) + self.b
        if not activation_fn:
            return y
        return activation_fn(y)

def avg_pool(x, ksize, stride, padding="SAME"):
    """Square average pooling with a ``ksize`` window and ``stride`` step on both spatial axes."""
    window = [1, ksize, ksize, 1]
    steps = [1, stride, stride, 1]
    return tf.nn.avg_pool(x, ksize=window, strides=steps, padding=padding)
    
def load_caffemodel(caffemodel):
    """Load a Caffe model through Chainer and return a conv-layer factory.

    Args:
        caffemodel: path to the ``.caffemodel`` file.

    Returns:
        A function ``layer_name -> Conv`` wrapping the named layer of the
        loaded model. NOTE: every call builds a new ``Conv`` (and new
        ``tf.constant`` weight tensors), even for a name requested before.
    """
    print("load model... %s" % caffemodel)
    net = caffe.CaffeFunction(caffemodel)

    def make_conv(layer_name):
        return Conv(getattr(net, layer_name))

    return make_conv

In [18]:
class BaseModel:
    """
    Abstract base class for the feature-extraction models.

    Subclasses set the ``default_*`` class attributes and implement
    ``__call__(x)`` returning a list of feature tensors.
    """
    default_caffemodel = None
    default_alpha = None
    default_beta = None

    def __init__(self, caffemodel=None, alpha=None, beta=None):
        """
        Args:
            caffemodel: model path; ``default_caffemodel`` is used when falsy.
            alpha: per-layer content-loss weights; ``default_alpha`` when falsy.
            beta: per-layer style-loss weights; ``default_beta`` when falsy.
        """
        if not caffemodel:
            caffemodel = self.default_caffemodel
        # self.conv(name) returns a Conv wrapping the named layer.
        self.conv = load_caffemodel(caffemodel)
        self.alpha = alpha if alpha else self.default_alpha
        self.beta = beta if beta else self.default_beta
    
        
class NIN(BaseModel):
    """
    Feature extractor built on the Network-in-Network (NIN) caffemodel.
    """
    default_caffemodel = "nin_imagenet.caffemodel"
    default_alpha = [0., 0., 1., 1.]
    default_beta = [1., 1., 1., 1.]

    def __call__(self, x):
        """Return the four NIN feature maps used by the style-transfer losses.

        Args:
            x: image batch tensor; presumably NHWC float32 as produced by
                load_image — TODO confirm.

        Returns:
            ``[x0, x1, x2, x3]``: ReLU outputs of conv1, conv2, conv3 and
            conv4-1024.
        """
        x0 = self.conv("conv1")(x, stride=4, padding="VALID")

        y1 = self.conv("cccp2")(self.conv("cccp1")(x0), activation_fn=None)
        pool1 = avg_pool(tf.nn.relu(y1), ksize=3, stride=2)
        x1 = self.conv("conv2")(pool1, stride=1)

        y2 = self.conv("cccp4")(self.conv("cccp3")(x1), activation_fn=None)
        pool2 = avg_pool(tf.nn.relu(y2), ksize=3, stride=2)
        x2 = self.conv("conv3")(pool2, stride=1)

        y3 = self.conv("cccp6")(self.conv("cccp5")(x2), activation_fn=None)
        pool3 = avg_pool(tf.nn.relu(y3), ksize=3, stride=2)

        # The pretrained network is only used as a fixed feature extractor, so
        # its dropout must behave as at inference time (identity). An active
        # tf.nn.dropout re-samples a random mask on every sess.run for each of
        # the content, style and generated images, making these losses noisy.
        x3 = self.conv("conv4-1024")(pool3)

        return [x0, x1, x2, x3]


class VGG(BaseModel):
    """
    Feature extractor built on the VGG-16 (ILSVRC) caffemodel.
    """
    default_caffemodel = "VGG_ILSVRC_16_layers.caffemodel"
    default_alpha = [0., 0., 1., 1.]
    default_beta = [1., 1., 1., 1.]

    def __call__(self, x):
        """Return the pre-ReLU outputs of conv1_2, conv2_2, conv3_3 and conv4_3.

        Within each block every conv except the last is ReLU-activated. The
        block's last conv output is returned linear, and is ReLU'd and
        2x2/stride-2 average-pooled to feed the next block. Block 4 is not pooled.
        """
        def run_block(h, layer_names):
            # All but the last layer use Conv's default ReLU activation.
            for name in layer_names[:-1]:
                h = self.conv(name)(h)
            return self.conv(layer_names[-1])(h, activation_fn=None)

        def pool(y):
            # NOTE(review): the original author asked whether max pooling was
            # intended here instead of average pooling.
            return avg_pool(tf.nn.relu(y), ksize=2, stride=2)

        y1 = run_block(x, ["conv1_1", "conv1_2"])
        y2 = run_block(pool(y1), ["conv2_1", "conv2_2"])
        y3 = run_block(pool(y2), ["conv3_1", "conv3_2", "conv3_3"])
        y4 = run_block(pool(y3), ["conv4_1", "conv4_2", "conv4_3"])

        return [y1, y2, y3, y4]

In [19]:
# TODO: tidy this helper up.
def inner_product_matrix(y):
    """Per-example Gram matrix of a 4-D feature map, normalised by its size.

    Args:
        y: tensor whose static shape ``get_shape().as_list()`` unpacks into
            exactly four known dimensions (batch, height, width, channels).

    Returns:
        Tensor of shape (batch, channels, channels): ``F^T F`` where ``F`` is
        ``y`` flattened to (batch, height*width, channels), divided by
        ``height * width * channels``.
    """
    _, rows, cols, channels = y.get_shape().as_list()
    positions = rows * cols
    features = tf.reshape(y, [-1, positions, channels])
    gram = tf.matmul(features, features, adjoint_a=True)
    return gram / (positions * channels)
# TODO: version-dependent (the specific dependency was not recorded).


class Generator:
    """Build the style-transfer graph and optimise a generated image.

    A ``tf.Variable`` image is optimised so that its ``base_model`` features
    match the content features of ``img_orig`` (weighted by ``config.lam`` and
    the per-layer ``alpha``). It must also match the Gram matrices of
    ``img_style`` (weighted by the per-layer ``beta``).
    """

    def __init__(self, base_model, img_orig, img_style, config):
        """Create the loss tensors, the training op and the clipping op.

        Args:
            base_model: callable mapping an image tensor to a list of feature
                tensors, exposing per-layer ``alpha`` and ``beta`` lists.
            img_orig: content image tensor.
            img_style: style image tensor.
            config: provides ``output_shape``, ``lam`` and ``optimizer``; an
                optional ``seed`` attribute seeds the initial noise.
        """
        # Target features: content activations and style Gram matrices.
        mids_orig = base_model(img_orig)
        prods_style = [inner_product_matrix(y) for y in base_model(img_style)]

        # The optimised image starts as uniform noise in [-20, 20).
        # NOTE(review): the original author wondered whether rank 4 is needed.
        self.img_gen = tf.Variable(
            tf.random_uniform(config.output_shape, -20, 20,
                              seed=getattr(config, "seed", None)))
        mids = base_model(self.img_gen)

        n_layers = len(mids)
        self.loss = []   # per-layer combined loss
        self.loss1 = []  # per-layer weighted content loss (loss1 * alpha)
        self.loss2 = []  # per-layer style loss divided by the number of layers

        for mid, mid_orig, prod_style, alpha, beta in zip(
                mids, mids_orig, prods_style, base_model.alpha, base_model.beta):
            # Content loss: squared feature difference normalised by tensor size.
            content_size = np.prod(mid.get_shape().as_list())
            loss1 = config.lam * tf.nn.l2_loss(mid - mid_orig) / content_size
            # Style loss: squared Gram-matrix difference normalised by its size.
            gram_size = np.prod(prod_style.get_shape().as_list())
            loss2 = beta * tf.nn.l2_loss(inner_product_matrix(mid) - prod_style) / gram_size

            content_term = loss1 * alpha
            style_term = loss2 / n_layers
            # Layers with alpha == 0 contribute only their style term.
            loss = content_term + style_term if alpha != 0.0 else style_term
            self.loss.append(loss)
            self.loss1.append(content_term)
            self.loss2.append(style_term)

        self.total_loss = tf.add_n(self.loss)
        self.total_train = config.optimizer.minimize(self.total_loss)
        # Keep pixels inside the range transform_for_train can produce:
        # 0 - 120 .. 255 - 120 (the previous upper bound 136 was one too high).
        clipped = tf.clip_by_value(self.img_gen, -120., 135.)
        self.clip = tf.assign(self.img_gen, clipped)

    def generate(self, config):
        """Run the optimisation and periodically save the generated image.

        Args:
            config: provides ``iteration`` and ``save_path`` (a ``%``-format
                string taking the step number). Optional attributes:
                ``save_interval`` (default 50) and ``log_loss`` (default False)
                to print per-layer losses at each save.
        """
        save_every = getattr(config, "save_interval", 50)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print("start")
            for i in range(config.iteration):
                # Separate calls: within one sess.run([...]) TF gives no
                # ordering between the two ops, so the clip could read the
                # pre-step value and overwrite the optimizer's update.
                sess.run(self.total_train)
                sess.run(self.clip)
                step = i + 1
                if step % save_every == 0:
                    if getattr(config, "log_loss", False):
                        l, l1, l2 = sess.run([self.loss, self.loss1, self.loss2])
                        print("%d| loss: %f, loss1: %f, loss2: %f" % (step, sum(l), sum(l1), sum(l2)))
                        for l_, l1_, l2_ in zip(l, l1, l2):
                            print("\tloss: %f, loss1: %f, loss2: %f" % (l_, l1_, l2_))
                    self.save_image(sess, config.save_path % step)
            # Persist the final iterate when it did not fall on a save boundary.
            if config.iteration % save_every != 0:
                self.save_image(sess, config.save_path % config.iteration)

    def save_image(self, sess, path):
        """Write the first image of the generated batch to ``path``."""
        data = sess.run(self.img_gen)[0]
        data = transform_from_train(data)
        img = Image.fromarray(data.astype(np.uint8))
        print("save %s" % path)
        img.save(path)

In [20]:
def generate_model(model_name, **args):
    """Instantiate the feature-extraction model named ``model_name``.

    Args:
        model_name: ``'nin'`` or ``'vgg'``.
        **args: forwarded to the model constructor (``caffemodel``, ``alpha``,
            ``beta``).

    Raises:
        ValueError: for an unrecognised name. Previously None was returned,
            which only failed later and far from the cause.
    """
    if model_name == 'nin':
        return NIN(**args)
    if model_name == 'vgg':
        return VGG(**args)
    raise ValueError("unknown model name: %r (expected 'nin' or 'vgg')" % (model_name,))

In [21]:
# Settings.
# (The model itself is loaded in the next cell.)


class Config:
    """Settings for one style-transfer run.

    NOTE: ``output_shape``, ``save_path`` and ``optimizer`` are computed once,
    when this class body runs. Overriding ``batch_size``/``width``/``height``,
    ``output_dir`` or ``lr`` afterwards does not update them.
    """
    # --- model and input images ---
    model = "nin"  # set to "vgg" to use the VGG-16 feature extractor
    original_image = "./tmp/cat.png"
    style_image = "./tmp/gogh.png"
    # If True the style image is used without resizing (slower to start).
    no_resize_style = False

    # --- generated image geometry ---
    batch_size = 1
    width = 100
    height = 100
    output_shape = [batch_size, height, width, 3]

    # --- optimisation ---
    iteration = 5000
    lam = 0.005  # content-loss weight
    lr = 4.0
    optimizer = tf.train.AdamOptimizer(lr)

    # --- output files ---
    output_dir = "_output"
    save_path = os.path.expanduser(os.path.join(output_dir, "%05d.png"))


In [22]:
# Image generation: build the config, make sure the output directory exists,
# then load the feature extractor selected by config.model.
# (generate_model('nin') / generate_model('vgg') load a specific one directly.)
config = Config()
os.makedirs(config.output_dir, exist_ok=True)
model = generate_model(model_name=config.model)

load model... nin_imagenet.caffemodel


In [23]:
# Content image at the output size; the style image too unless no_resize_style is set.
img_orig = load_image(config.original_image, size=[config.width, config.height])
img_style = load_image(
    config.style_image,
    size=None if config.no_resize_style else [config.width, config.height],
)


In [24]:
# Build the optimisation graph, then run it; images are written to config.save_path.
generator = Generator(model, img_orig, img_style, config=config)

generator.generate(config=config)


start
save _output/00050.png
save _output/00100.png
save _output/00150.png
save _output/00200.png
save _output/00250.png
save _output/00300.png
save _output/00350.png
save _output/00400.png
save _output/00450.png
save _output/00500.png
save _output/00550.png
save _output/00600.png
save _output/00650.png
save _output/00700.png
save _output/00750.png
save _output/00800.png
save _output/00850.png
save _output/00900.png
save _output/00950.png
save _output/01000.png
save _output/01050.png
save _output/01100.png
save _output/01150.png
save _output/01200.png
save _output/01250.png
save _output/01300.png
save _output/01350.png
save _output/01400.png
save _output/01450.png
save _output/01500.png
save _output/01550.png
save _output/01600.png
save _output/01650.png
save _output/01700.png
save _output/01750.png
save _output/01800.png
save _output/01850.png
save _output/01900.png
save _output/01950.png
save _output/02000.png
save _output/02050.png
save _output/02100.png
save _output/02150.png
save 