# 📚 5.4 Bearing Fruit: Image Generation with a Variational Autoencoder

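Variational autoencoders (VAEs) were introduced by Kingma et al. (2013) and Rezende et al. (2014) and are widely used in image generation, reinforcement learning, natural language processing, and other fields. Given an input x, a VAE learns an approximation of the conditional distribution P(z|x) of the latent variable z. Once this distribution has been learned, sampling z from it and decoding the samples produces different outputs.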

<img src="https://tianchi-public.oss-cn-hangzhou.aliyuncs.com/public/files/forum/161829959015020561618299587169.png"/>

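The input $x$ passes through the encoder network $q_\phi(z \mid x)$, which outputs the mean and variance of the latent variable $z$. A value of $z$ is then sampled with the reparameterization trick and fed to the decoder network to obtain the distribution $p_\theta(x \mid z)$. Finally, the loss is computed from the objective function and the parameters are optimized.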
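The text does not spell out the formula it refers to. Assuming it is the standard VAE objective (the negative evidence lower bound) with a standard normal prior $p(z) = \mathcal{N}(0, I)$ and a diagonal Gaussian encoder, which is consistent with the loss computed in the complete code of section 2.5, it can be written as:

$$
\mathcal{L}(\theta, \phi; x) = -\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
$$

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_{j=1}^{d} \big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\big)
$$

Here $d$ is the latent dimension (`z_dim` in the code), and $\mu_j$ and $\log \sigma_j^2$ are the two encoder outputs. The first term measures reconstruction error; the second pulls the encoder distribution toward the prior.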

# 1. Objectives

Understand the VAE framework, write the VAE code, and be able to train the model on images from your own dataset.

# 2. Hands-On Practice

In [4]:
import tensorflow as tf

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio

from IPython import display

## 2.1 Dataset

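We use the Fashion MNIST dataset, which contains grayscale images of 10 categories of clothing, shoes, bags, and so on. Each image is 28x28 pixels, and there are 70,000 images in total: 60,000 in the training set and 10,000 in the test set. In the usual preview grid of the dataset, each row shows one category. Apart from the image content, Fashion MNIST has the same setup as MNIST, so in most cases it can directly replace MNIST in existing training code without further changes. Because Fashion MNIST images are harder to classify than MNIST digits, the dataset is useful for testing somewhat more complex algorithms.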

In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from tensorflow.keras import Sequential, layers
import sys

In [6]:
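#   Hyperparameters
z_dim = 10          # dimension of the latent vector z
h_dim = 20          # not used by the model below (its hidden width is fixed at 128)
batchsz = 512
learn_rate = 1e-3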

(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
print('x_train.shape:', x_train.shape)
train_db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)
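To make the preprocessing concrete without downloading anything, here is a minimal NumPy sketch of the two steps the tutorial applies to each batch: scaling pixels to [0, 1] and flattening 28x28 images to 784-dimensional vectors. The helper names and the random stand-in images are illustrative assumptions, not part of the original code. It also counts the batches per epoch, assuming the final partial batch is kept (the default for `Dataset.batch`).

```python
import math
import numpy as np

def normalize_and_flatten(images_uint8):
    """Scale uint8 pixels to [0, 1] floats and flatten [b, 28, 28] -> [b, 784]."""
    x = images_uint8.astype(np.float32) / 255.0
    return x.reshape(x.shape[0], -1)

def num_batches(n_samples, batch_size):
    """Number of batches per epoch when the last, smaller batch is kept."""
    return math.ceil(n_samples / batch_size)

# Random pixels standing in for a few Fashion MNIST images (NOT real data).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)

flat = normalize_and_flatten(fake_images)
print(flat.shape, flat.dtype)

# 60000 training images with batchsz = 512
steps_per_epoch = num_batches(60000, 512)
last_batch_size = 60000 - (steps_per_epoch - 1) * 512
print(steps_per_epoch, last_batch_size)
```

With 118 steps per epoch, a `step % 100 == 0` check fires only at steps 0 and 100, which matches the two log lines per epoch in the training output of section 2.5.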

## 2.2 Building the Model

<img src="https://tianchi-public.oss-cn-hangzhou.aliyuncs.com/public/files/forum/161830827523983751618308264494.png"/>

### 2.2.1 Building the Encoder

The encoder module is the part shown in the figure below.

<img src="https://tianchi-public.oss-cn-hangzhou.aliyuncs.com/public/files/forum/161830910638287101618309105937.png"/>

### 2.2.2 Building the Decoder

Compare against the model diagram to find the corresponding part yourself.
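Since the figures are not reproduced in text, the following NumPy sketch traces how the tensor shapes change through the encoder and decoder. The layer sizes are copied from the `VAE` class in this tutorial; the random weights and the helper names (`dense`, `apply`, `relu`) are stand-ins invented for illustration, so only the shapes are meaningful here, not the values.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(in_dim, out_dim):
    """Random (weights, bias) pair standing in for a trained layers.Dense."""
    return rng.normal(0.0, 0.05, size=(in_dim, out_dim)), np.zeros(out_dim)

def apply(layer, x):
    """Affine transform x @ W + b."""
    weights, bias = layer
    return x @ weights + bias

def relu(a):
    return np.maximum(a, 0.0)

# Layer sizes copied from the VAE class in this tutorial.
fc1, fc2, fc3 = dense(784, 128), dense(128, 10), dense(128, 10)  # encoder
fc4, fc5 = dense(10, 128), dense(128, 784)                       # decoder

x = rng.random((4, 784))                    # a batch of 4 flattened images
h = relu(apply(fc1, x))                     # [4, 784] -> [4, 128]
mu, log_var = apply(fc2, h), apply(fc3, h)  # two heads: [4, 128] -> [4, 10] each
z = mu                                      # use the mean as z here; sampling is a separate step
logits = apply(fc5, relu(apply(fc4, z)))    # [4, 10] -> [4, 128] -> [4, 784]

for name, arr in [("x", x), ("h", h), ("mu", mu), ("log_var", log_var), ("logits", logits)]:
    print(f"{name:8s}{arr.shape}")
```

Note that the encoder has two output heads of the same size, one for the mean and one for the log-variance, and that the decoder returns logits: a sigmoid is still needed to turn them into pixel intensities.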

### 2.2.3 Assembling the Model: Overall Framework

In [5]:
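#   Hyperparameters
z_dim = 10          # dimension of the latent vector z
h_dim = 20          # not used by the model below (its hidden width is fixed at 128)
batchsz = 512
learn_rate = 1e-3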
class VAE(keras.Model):
    def __init__(self):
        super(VAE,self).__init__()

        #   Encoder
        self.fc1 = layers.Dense(128)
        self.fc2 = layers.Dense(z_dim)      #       outputs the mean
        self.fc3 = layers.Dense(z_dim)      #       outputs the log-variance

        #   Decoder
        self.fc4 = layers.Dense(128)
        self.fc5 = layers.Dense(784)

    def encoder(self,x):
        h = tf.nn.relu(self.fc1(x))
        #   mean of q(z|x)
        mu = self.fc2(h)
        #   log-variance of q(z|x) (the layer predicts log sigma^2, not sigma^2)
        log_var = self.fc3(h)

        return mu,log_var

    def decoder(self,z):

        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)

        return out


    def call(self, inputs, training=None, mask=None):
        #   [b,784] => [b,z_dim], [b,z_dim]
        mu,log_var = self.encoder(inputs)

        #   Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
        eps = tf.random.normal(tf.shape(log_var))
        std = tf.exp(0.5 * log_var)
        z = mu + std * eps

        x_hat = self.decoder(z)
        return x_hat,mu,log_var
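
The sampling step in `call` is the reparameterization trick: instead of sampling z directly from N(mu, sigma^2), it samples noise eps from N(0, I) and computes z = mu + sigma * eps, so gradients can flow through mu and log_var. The NumPy sketch below is only an illustration; the function name `reparameterize` and the chosen numbers are assumptions. It checks that the samples have the intended mean and standard deviation.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Draw z ~ N(mu, exp(log_var)) as a deterministic function of noise eps."""
    eps = rng.standard_normal(mu.shape)  # eps ~ N(0, I), independent of the parameters
    std = np.exp(0.5 * log_var)          # sigma = sqrt(exp(log_var))
    return mu + std * eps

rng = np.random.default_rng(0)
n = 100_000
mu = np.full((n, 1), 2.0)                # target mean 2.0
log_var = np.full((n, 1), np.log(0.25))  # target variance 0.25, i.e. sigma = 0.5

z = reparameterize(mu, log_var, rng)
print(z.mean(), z.std())                 # should be close to 2.0 and 0.5

# exp(0.5 * log_var) and exp(log_var) ** 0.5 are the same quantity.
same = np.allclose(np.exp(0.5 * log_var), np.exp(log_var) ** 0.5)
print(same)
```

Predicting the log-variance rather than the variance keeps sigma positive for any real-valued network output.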


## 2.3 Training the Network

First, take a look at the model:

In [8]:
my_model = VAE()
my_model.build(input_shape=(4,784))
my_model.summary()

Model: "vae_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_11 (Dense)             multiple                  100480    
_________________________________________________________________
dense_12 (Dense)             multiple                  1290      
_________________________________________________________________
dense_13 (Dense)             multiple                  1290      
_________________________________________________________________
dense_14 (Dense)             multiple                  1408      
_________________________________________________________________
dense_15 (Dense)             multiple                  101136    
Total params: 205,604
Trainable params: 205,604
Non-trainable params: 0
_________________________________________________________________
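
The parameter counts in the summary can be checked by hand: a dense layer with `in_dim` inputs and `out_dim` outputs has `in_dim * out_dim` weights plus `out_dim` biases. The short sketch below redoes that arithmetic; the names `fc1` to `fc5` follow the `VAE` class and are assumed to correspond, in order, to `dense_11` through `dense_15` in the summary.

```python
def dense_params(in_dim, out_dim):
    """Weights plus biases of a fully connected layer."""
    return in_dim * out_dim + out_dim

# (name, input size, output size) for each layer of the VAE class
layer_dims = [
    ("fc1", 784, 128),  # encoder hidden layer
    ("fc2", 128, 10),   # mean head
    ("fc3", 128, 10),   # log-variance head
    ("fc4", 10, 128),   # decoder hidden layer
    ("fc5", 128, 784),  # decoder output (logits)
]

counts = {name: dense_params(i, o) for name, i, o in layer_dims}
total = sum(counts.values())

for name, count in counts.items():
    print(name, count)
print("total", total)
```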


## 2.4 Image Generation

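Image generation uses only the decoder network. First, a latent vector is sampled from the prior distribution 𝒩(0, 𝐼); it is then passed through the decoder to obtain an image vector, which is finally reshaped into an image matrix.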

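The saved images can be found under the path shown below. `1_label.jpg` contains the real images, `1_pre.jpg` the predicted (reconstructed) images, and `1_random.jpg` the images generated from randomly sampled normal latent vectors.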

<img src="https://tianchi-public.oss-cn-hangzhou.aliyuncs.com/public/files/forum/161831058872214681618310584152.png"/>
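
The generation steps described in this section (sample from the prior, decode, reshape) and the 10x10 mosaic layout of the saved files can be sketched in NumPy as follows. This is only an illustration: `fake_decoder` is a hypothetical stand-in (a random projection followed by a sigmoid) for the trained `decoder`, and `tile_images` mirrors the tiling logic of `my_save_img` in section 2.5 without writing any file.

```python
import numpy as np

def tile_images(images, grid=10, size=28):
    """Place the first grid*grid images row by row into one (grid*size, grid*size) array."""
    canvas = np.zeros((grid * size, grid * size), dtype=np.float32)
    for index, img in enumerate(images[: grid * grid]):
        row, col = divmod(index, grid)
        canvas[row * size:(row + 1) * size, col * size:(col + 1) * size] = img
    return canvas

def fake_decoder(z, rng):
    """Hypothetical stand-in for the trained decoder: random projection + sigmoid."""
    weights = rng.normal(0.0, 1.0, size=(z.shape[1], 784))
    return 1.0 / (1.0 + np.exp(-(z @ weights)))

rng = np.random.default_rng(0)
z = rng.standard_normal((100, 10))                 # 100 latent vectors from the prior N(0, I)
images = fake_decoder(z, rng).reshape(-1, 28, 28)  # [100, 784] -> [100, 28, 28]
mosaic = tile_images(images)
print(images.shape, mosaic.shape)
```

Image `k` lands in row `k // 10` and column `k % 10` of the mosaic, so image 10 starts the second row.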

## 2.5 Complete Code

In [13]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from tensorflow.keras import Sequential, layers
import sys


def my_save_img(data,name):
    #   Tile the first 100 images (28x28 each) into a 10x10 grid and save it as one JPEG
    os.makedirs('VAEimage', exist_ok=True)      #   create the output directory if it is missing
    save_img_path = 'VAEimage/{}.jpg'.format(name)
    new_img = np.zeros((280,280))
    for index,each_img in enumerate(data[:100]):
        row_start = (index // 10) * 28
        col_start = (index % 10) * 28
        new_img[row_start:row_start+28,col_start:col_start+28] = each_img

    plt.imsave(save_img_path,new_img)

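#   Hyperparameters
z_dim = 10          # dimension of the latent vector z
h_dim = 20          # not used by the model below (its hidden width is fixed at 128)
batchsz = 512
learn_rate = 1e-3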

(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
print('x_train.shape:', x_train.shape)
train_db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

class VAE(keras.Model):
    def __init__(self):
        super(VAE,self).__init__()

        #   Encoder
        self.fc1 = layers.Dense(128)
        self.fc2 = layers.Dense(z_dim)      #       outputs the mean
        self.fc3 = layers.Dense(z_dim)      #       outputs the log-variance

        #   Decoder
        self.fc4 = layers.Dense(128)
        self.fc5 = layers.Dense(784)

    def encoder(self,x):
        h = tf.nn.relu(self.fc1(x))
        #   mean of q(z|x)
        mu = self.fc2(h)
        #   log-variance of q(z|x) (the layer predicts log sigma^2, not sigma^2)
        log_var = self.fc3(h)

        return mu,log_var

    def decoder(self,z):

        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)

        return out


    def call(self, inputs, training=None, mask=None):
        #   [b,784] => [b,z_dim], [b,z_dim]
        mu,log_var = self.encoder(inputs)

        #   Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
        eps = tf.random.normal(tf.shape(log_var))
        std = tf.exp(0.5 * log_var)
        z = mu + std * eps

        x_hat = self.decoder(z)
        return x_hat,mu,log_var

my_model = VAE()
# my_model.build(input_shape=(4,784))
opt = tf.optimizers.Adam(learn_rate)

for epoch in range(100):
    for step,x in enumerate(train_db):
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            x_hat,mu,log_var = my_model(x)

            #   Reconstruction loss: per-pixel sigmoid cross-entropy, summed and averaged over the batch
            # rec_loss = tf.losses.binary_crossentropy(x, x_hat, from_logits=True)
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_hat)
            rec_loss = tf.reduce_sum(rec_loss)/x.shape[0]

            #   KL divergence between N(mu, var) and the prior N(0, 1)
            kl_div = -0.5 * (log_var + 1 - mu ** 2 - tf.exp(log_var))
            kl_div = tf.reduce_sum(kl_div) / x.shape[0]

            #   Total loss: sum of the two terms
            my_loss = rec_loss + 1. * kl_div

        grads = tape.gradient(my_loss, my_model.trainable_variables)
        opt.apply_gradients(zip(grads, my_model.trainable_variables))

        if step % 100 == 0:
            print(epoch,step,float(my_loss),'kl div:',float(kl_div),'rec loss:',float(rec_loss))

    #   Evaluation, once per epoch
    #   Generate images from random z using only the decoder
    z = tf.random.normal((batchsz,z_dim))
    logits = my_model.decoder(z)
    x_hat = tf.sigmoid(logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.
    my_save_img(x_hat,'{}_random'.format(epoch))

    #   Reconstruct one batch of test images
    x = next(iter(test_db))
    my_save_img(x, '{}_label'.format(epoch))
    x = tf.reshape(x, [-1, 784])
    x_hat_logits, _, _ = my_model(x)
    x_hat = tf.sigmoid(x_hat_logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.
    my_save_img(x_hat, '{}_pre'.format(epoch))

x_train.shape: (60000, 28, 28)
0 0 548.4443359375 kl div: 2.5494801998138428 rec loss: 545.8948364257812
0 100 307.4501037597656 kl div: 15.489289283752441 rec loss: 291.9608154296875
1 0 299.75299072265625 kl div: 15.903350830078125 rec loss: 283.8496398925781
1 100 275.8286437988281 kl div: 16.12905502319336 rec loss: 259.6995849609375
2 0 271.4725036621094 kl div: 16.0245361328125 rec loss: 255.44796752929688
2 100 268.268798828125 kl div: 15.50701904296875 rec loss: 252.7617950439453
3 0 264.3460388183594 kl div: 15.03890323638916 rec loss: 249.30712890625
3 100 262.34039306640625 kl div: 15.260851860046387 rec loss: 247.07952880859375
4 0 253.7904510498047 kl div: 15.42512321472168 rec loss: 238.36532592773438
4 100 257.94842529296875 kl div: 15.146361351013184 rec loss: 242.80206298828125
5 0 252.0976104736328 kl div: 14.786120414733887 rec loss: 237.31149291992188
5 100 258.8625793457031 kl div: 14.585002899169922 rec loss: 244.277587890625
6 0 253.73342895507812 kl div: 15.0369
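
In each log line, the first number is the total loss, which equals `rec loss` plus `kl div`. To see what the two terms compute, here is a NumPy sketch of the same formulas. It is an illustration under assumptions, not the TensorFlow implementation: the function names are invented here, the inputs are random, and the cross-entropy uses the numerically stable form max(l, 0) - l * y + log(1 + exp(-|l|)).

```python
import numpy as np

def sigmoid_cross_entropy_with_logits(labels, logits):
    """Stable per-element cross-entropy between labels in [0, 1] and sigmoid(logits)."""
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, exp(log_var)) || N(0, 1) ), element-wise."""
    return -0.5 * (1.0 + log_var - mu ** 2 - np.exp(log_var))

def vae_loss(x, logits, mu, log_var):
    """Sum each term over all elements, then average over the batch."""
    batch = x.shape[0]
    rec = sigmoid_cross_entropy_with_logits(x, logits).sum() / batch
    kl = kl_to_standard_normal(mu, log_var).sum() / batch
    return rec + kl, rec, kl

rng = np.random.default_rng(0)
x = rng.random((4, 784))             # stand-in pixel intensities in [0, 1]
logits = rng.normal(size=(4, 784))   # stand-in decoder outputs
mu = rng.normal(size=(4, 10))
log_var = rng.normal(size=(4, 10))

total, rec, kl = vae_loss(x, logits, mu, log_var)
print(total, rec, kl)

# The KL term vanishes exactly when the encoder output equals the prior.
print(kl_to_standard_normal(np.zeros(3), np.zeros(3)))
```

The KL term is never negative, and it is zero only when mu = 0 and log_var = 0, so it acts as a penalty for moving the encoder distribution away from the prior.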

# 3. Homework

1. Design a deeper VAE of your own and use it to reconstruct images from your own dataset.

2. Make sure you understand how the input and output dimensions change through the network.