---
title: Tensorflow Gan 实现 Estimator
tags: 小书匠,Tensorflow,Tensorflow1,gan,estimator
grammar_cjkRuby: true
# renderNumberedHeading: true
---

[toc!]

# 使用 TensorFlow Estimator 实现 GAN

这个实现与大多数实现不同：在原始的 GAN 论文中，更新完 Discriminator 之后，Generator 需要重新采样噪声并重新计算梯度，而很多实现省略了重新采样这一步。

In [9]:
import tensorflow as tf
print(tf.__version__)
import io
from tensorflow.keras.layers import Dense
import functools
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
tf.logging.set_verbosity(tf.logging.INFO)

1.15.0


In [10]:
def plot(samples):
    """Render generated samples as greyscale images on a 4x4 grid.

    Each element of ``samples`` is reshaped to (28, 28) and drawn in the grid
    cell that matches its position in the sequence.

    Returns:
        The matplotlib figure that holds the grid.
    """
    figure = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for index, image in enumerate(samples):
        axes = plt.subplot(grid[index])
        # Hide axis decorations so only the image itself is visible.
        plt.axis('off')
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        axes.set_aspect('equal')
        pixels = image.reshape(28, 28)
        plt.imshow(pixels, cmap='Greys_r')
    return figure

In [11]:
def sample_Z(shape):
    """Return a tensor of the requested shape drawn with ``tf.random.normal``."""
    noise = tf.random.normal(shape=shape)
    return noise

In [12]:
# Note: pixel values are scaled into the range [0, 1] here.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale by 1/255 and flatten every image into a 784-long vector.
# The explicit float32 cast matters: NumPy true division produces float64,
# whereas tf.random.normal (used for the generator noise) defaults to float32,
# so without the cast real and generated batches would have different dtypes.
x_train = (x_train / 255.0).reshape(-1, 28 * 28).astype(np.float32)
x_test = (x_test / 255.0).reshape(-1, 28 * 28).astype(np.float32)

In [13]:
def model_fn(features, labels, mode):
    """Estimator ``model_fn`` that builds a fully connected GAN.

    Every training step first updates the discriminator and then, using a
    freshly sampled noise batch, updates the generator.

    Args:
        features: Batch of real samples. It is fed to the discriminator as
            the "true" input, its shape is reused as the shape of the
            generator noise, and it is reshaped to (-1, 28, 28, 1) for the
            image summary.
        labels: Unused.
        mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec``:
        * PREDICT: predictions ``x_fake`` and ``D_logits_fake``.
        * EVAL: discriminator loss plus the metrics ``acc_fake`` and
          ``acc_true``.
        * TRAIN: generator loss and the combined train op.

    Raises:
        ValueError: If ``mode`` is none of PREDICT, EVAL or TRAIN.
    """
    x_true = features
    Generator = tf.keras.Sequential([
        Dense(128, activation='relu'),
        # The final sigmoid keeps generated values in [0, 1], the same range
        # as x.
        Dense(784, activation='sigmoid'),
    ])

    Discriminator = tf.keras.Sequential([
        Dense(128, activation='relu'),
        Dense(1)
    ])

    input_shape = tf.shape(x_true)
    x_fake = Generator(sample_Z(input_shape))
    D_logits_fake = Discriminator(x_fake)
    D_logits_true = Discriminator(x_true)

    tf.summary.image("fake", tf.reshape(x_fake, (-1, 28, 28, 1)))
    tf.summary.image("true", tf.reshape(x_true, (-1, 28, 28, 1)))

    # Prediction needs neither losses nor optimizers, so return before
    # building them.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"x_fake": x_fake, "D_logits_fake": D_logits_fake}
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions)

    D_loss_positive = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(D_logits_true), logits=D_logits_true))
    D_loss_negative = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(D_logits_fake), logits=D_logits_fake))
    D_loss = D_loss_positive + D_loss_negative

    if mode == tf.estimator.ModeKeys.EVAL:
        metrics = {
            "acc_fake": tf.metrics.accuracy(labels=tf.zeros_like(D_logits_fake),
                                            predictions=tf.cast(tf.nn.sigmoid(D_logits_fake) > 0.5, tf.int32)),
            "acc_true": tf.metrics.accuracy(labels=tf.ones_like(D_logits_true),
                                            predictions=tf.cast(tf.nn.sigmoid(D_logits_true) > 0.5, tf.int32)),
        }
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=D_loss,
                                          eval_metric_ops=metrics)

    if mode != tf.estimator.ModeKeys.TRAIN:
        raise ValueError("Unsupported mode: {!r}".format(mode))

    # The discriminator step deliberately gets no global_step: two losses are
    # minimized per training step, and letting both optimizers increment the
    # global step would make it advance twice per step.
    D_train_op = tf.train.AdamOptimizer().minimize(
        D_loss,
        var_list=Discriminator.trainable_variables)

    # Optimize D first, then G.
    with tf.control_dependencies([D_train_op]):
        # Draw a fresh noise batch for the generator update instead of
        # reusing the batch the discriminator was just trained on.
        new_D_logits_fake = Discriminator(Generator(sample_Z(input_shape)))
        G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(new_D_logits_fake),
            logits=new_D_logits_fake))
        # Handing the real global step to this optimizer makes the increment
        # part of the generator update itself, so it happens exactly once per
        # training step and only after the generator variables were updated.
        train_op = tf.train.AdamOptimizer().minimize(
            G_loss,
            var_list=Generator.trainable_variables,
            global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=G_loss,
                                      train_op=train_op)

In [14]:
# Directory handed to the Estimator as its model_dir.
model_dir = './model'
# Batch size passed to input_fn when the train/eval specs are built.
BATCH_SIZE = 128
# Value passed to input_fn as `epochs` (how often the dataset is repeated).
EPOCHS = 400
# NOTE(review): LR is not referenced anywhere in the visible code; the
# AdamOptimizer instances in model_fn are created with default arguments.
LR = 0.0001

def input_fn(x, epochs=1, batch_size=32, istrain=False):
    """Build a batched ``tf.data.Dataset`` over the slices of ``x``.

    When ``istrain`` is true the elements are first shuffled with a
    10000-element buffer. The data is then repeated ``epochs`` times and
    grouped into batches of ``batch_size``.
    """
    records = tf.data.Dataset.from_tensor_slices(x)
    if istrain:
        records = records.shuffle(10000)
    repeated = records.repeat(epochs)
    return repeated.batch(batch_size)

# Write a checkpoint every 400 training steps.
config = tf.estimator.RunConfig(save_checkpoints_steps=400)
estimator = tf.estimator.Estimator(
    model_fn=model_fn,  # a function is required here
    model_dir=model_dir,
    config=config
)


train_spec = tf.estimator.TrainSpec(
    input_fn=functools.partial(input_fn,
                               x_train,
                               batch_size=BATCH_SIZE,
                               istrain=True,
                               epochs=EPOCHS),
    max_steps=20000,
)

# Evaluate on exactly one pass over the test set. EvalSpec defaults to
# steps=100, so feeding it a repeated dataset makes every evaluation cover an
# arbitrary 100 batches instead of the test set once. With steps=None the
# evaluation runs until the single-epoch input is exhausted.
eval_spec = tf.estimator.EvalSpec(
    input_fn=functools.partial(input_fn,
                               x_test,
                               batch_size=BATCH_SIZE,
                               istrain=False,
                               epochs=1),
    steps=None,
    throttle_secs=5,
)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec,
    eval_spec
)

INFO:tensorflow:Using config: {'_model_dir': './model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 400, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x142d1e208>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally 

(None, None)

# References
- https://github.com/jiqizhixin/ML-Tutorial-Experiment/blob/ce316d55439859e8aaf10903a55b52066e20146c/Experiments/tf_GAN.ipynb
- http://localhost:8888/lab/tree/DL-Project/Gan/gan/Tensorflow1%20%E4%BD%BF%E7%94%A8%20Estimator%20%E5%AE%9E%E7%8E%B0%20Gan.ipynb