<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#构建条件变分编码器" data-toc-modified-id="构建条件变分编码器-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>构建条件变分编码器</a></span></li><li><span><a href="#转换one-hot编码" data-toc-modified-id="转换one-hot编码-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>转换one-hot编码</a></span></li><li><span><a href="#训练网络" data-toc-modified-id="训练网络-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>训练网络</a></span></li><li><span><a href="#绘制网络结构" data-toc-modified-id="绘制网络结构-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>绘制网络结构</a></span></li><li><span><a href="#提取解码器部分作为生成模型" data-toc-modified-id="提取解码器部分作为生成模型-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>提取解码器部分作为生成模型</a></span></li></ul></div>

# 构建条件变分编码器

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda, Layer, concatenate
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.datasets import mnist
from keras.utils import np_utils

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# Create a TF1 session limited to half of the GPU memory
# (per_process_gpu_memory_fraction) and register it as Keras' session.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))
set_session(tf.Session(config=config))

# ---- Hyperparameters ----
batch_size = 32  # samples per gradient step
original_dim = 576  # input dimension (features per sample)
latent_dim = 10  # dimension of the latent code z
intermediate_dim = 256  # units in the first hidden Dense layer
epochs = 200  # training epochs (previously 50 here while training actually used 200)
epsilon_std = 1.0  # std-dev of the Gaussian noise used in `sampling`
num_classes = 5  # width of the one-hot condition label

# ---- Encoder: the input sample is concatenated with its one-hot label ----
x = Input(shape=(original_dim, ))
label = Input(shape=(num_classes, ))
inputs = concatenate([x, label])
new_layer = Dense(intermediate_dim, activation='relu')(inputs)
h = Dense(128, activation='relu')(new_layer)
# Two heads parameterize the approximate posterior: mean and log-variance of z.
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)


def sampling(args):
    """Sample z with the reparameterization trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors, expected to have the
            same shape (both encoder heads output ``latent_dim`` units).

    Returns:
        ``z_mean + exp(z_log_var / 2) * epsilon``, where ``epsilon`` is
        Gaussian noise with mean 0 and std-dev ``epsilon_std`` (module-level).
    """
    z_mean, z_log_var = args
    # Take the noise shape from z_mean itself rather than (batch, latent_dim),
    # so the function does not silently depend on the global latent_dim.
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon


# Draw z from the encoder's posterior inside the graph. ("output_shape" isn't
# necessary with the TensorFlow backend.)
z = Lambda(sampling, output_shape=(latent_dim, ))([z_mean, z_log_var])

# Decoder layers are instantiated once and kept as named objects so the same
# (trained) layers can be reused later by the standalone generator.
decoder_h = Dense(intermediate_dim, activation='relu')
decoded_p = Dense(128, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')

# Condition the decoder by appending the one-hot label to the sampled code,
# then run it through the three decoder layers in order.
z_input = concatenate([z, label])
h_decoded = decoder_h(z_input)
p_decoded = decoded_p(h_decoded)
x_decoded_mean = decoder_mean(p_decoded)


# Custom loss layer
class CustomVariationalLayer(Layer):
    """Pass-through layer whose only purpose is to attach the CVAE loss.

    Called on ``[x, x_decoded_mean]``: it registers reconstruction + KL loss
    with ``add_loss`` and returns ``x`` unchanged. The output itself is not
    used, which is why the model is compiled with ``loss=None``.
    """

    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean):
        """Return the batch mean of (reconstruction loss + KL divergence).

        NOTE: the KL term reads the module-level ``z_mean`` / ``z_log_var``
        encoder tensors; they are not passed in as arguments.
        """
        # Scaling by original_dim presumably turns Keras' per-feature mean
        # binary cross-entropy into a per-sample sum.
        reconstruction = original_dim * metrics.binary_crossentropy(
            x, x_decoded_mean)
        # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over z.
        kl_divergence = -0.5 * K.sum(
            1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(reconstruction + kl_divergence)

    def call(self, inputs):
        """Register the loss and hand back the first input unchanged."""
        x, x_decoded_mean = inputs[0], inputs[1]
        self.add_loss(self.vae_loss(x, x_decoded_mean), inputs=inputs)
        # We won't actually use the output.
        return x


# Assemble the trainable CVAE. Its loss is attached inside the graph by
# CustomVariationalLayer (via add_loss), so compile() takes no loss function.
y = CustomVariationalLayer()([x, x_decoded_mean])
vae = Model(inputs=[x, label], outputs=y)
vae.compile(optimizer='adam', loss=None)

# 转换one-hot编码

In [None]:
# One-hot encode the integer class labels so they match the model's `label`
# Input. The class count is read from that Input instead of being repeated.
# The guard keeps the cell idempotent: re-running it must not re-encode an
# array that is already one-hot.
# NOTE(review): x_label must be loaded by an earlier cell not shown here.
n_label_classes = K.int_shape(label)[-1]
if np.ndim(x_label) == 1 or np.shape(x_label)[-1] != n_label_classes:
    x_label = np_utils.to_categorical(x_label, num_classes=n_label_classes)

# 训练网络

In [None]:
# Train the CVAE on (sample, one-hot label) pairs. No targets are passed
# because the loss is attached inside the model (compiled with loss=None).
# NOTE(review): x_train / x_label come from an earlier data-loading cell not
# shown here; fail fast if they don't match the model before a long run.
assert np.shape(x_train)[1:] == (original_dim, ), np.shape(x_train)
assert np.shape(x_label)[1:] == K.int_shape(label)[1:], np.shape(x_label)
assert len(x_train) == len(x_label), (len(x_train), len(x_label))

history = vae.fit(x=[x_train, x_label],
                  shuffle=True,
                  epochs=epochs,  # use the config constant instead of a literal
                  batch_size=batch_size)

# 绘制网络结构

In [None]:
from keras.utils import plot_model
from keras.utils.vis_utils import model_to_dot
from IPython.display import SVG

# Render the CVAE graph twice with the same options: saved to disk as an
# image file, and displayed inline as SVG (last expression of the cell).
_plot_opts = dict(show_layer_names=False, show_shapes=True)
plot_model(vae, to_file='CVAE-model.tif', **_plot_opts)
SVG(model_to_dot(vae, **_plot_opts).create(prog='dot', format='svg'))

# 提取解码器部分作为生成模型

In [None]:
# Standalone generator: (latent code, one-hot label) -> decoded sample.
# It reuses the decoder layer objects (decoder_h, decoded_p, decoder_mean)
# that were trained as part of `vae`.
decoder_input = Input(shape=(latent_dim, ))
# Take the label width from the CVAE's own `label` Input rather than repeating
# a literal, so the generator cannot drift from the trained model.
label_input = Input(shape=K.int_shape(label)[1:])

merge_label = concatenate([decoder_input, label_input])
_h_decoded = decoder_h(merge_label)
_p_decoded = decoded_p(_h_decoded)
_x_decoded_mean = decoder_mean(_p_decoded)
generator = Model([decoder_input, label_input], _x_decoded_mean)