In [0]:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import *
import numpy as np
import tensorflow as tf
from keras.optimizers import *
import math
import os
from PIL import Image

# mnist.load_data() hands back two (images, labels) pairs: train first, then test.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
del train_split, test_split

In [0]:
def generator_model(latent_dim=100):
  """Build the DCGAN generator that maps a noise vector to a 28x28x1 image.

  Args:
    latent_dim: length of the input noise vector. The default of 100 matches
      the uniform noise drawn in ``train``.

  Returns:
    An uncompiled ``Sequential`` model with these stages:
    Dense(1024) + tanh, then Dense(128*7*7) + BatchNorm + tanh, then reshape to
    7x7x128, then two (2x upsample, 5x5 'same' conv) stages. The last stage has
    a single tanh channel, so output pixels lie in [-1, 1].
  """
  model = Sequential()
  # `units` is the first positional argument. The Keras-1 `output_dim=` keyword
  # used previously is rejected by current Keras versions.
  model.add(Dense(1024, input_dim=latent_dim))
  model.add(Activation('tanh'))
  model.add(Dense(128 * 7 * 7))

  model.add(BatchNormalization())
  model.add(Activation('tanh'))

  # input_shape is only meaningful on the first layer of a Sequential model,
  # so it is not repeated on this reshape.
  model.add(Reshape((7, 7, 128)))

  model.add(UpSampling2D(size=(2, 2)))  # 7x7 -> 14x14

  model.add(Conv2D(64, (5, 5), padding='same'))
  model.add(Activation('tanh'))
  model.add(UpSampling2D(size=(2, 2)))  # 14x14 -> 28x28

  model.add(Conv2D(1, (5, 5), padding='same'))
  model.add(Activation('tanh'))
  return model

In [0]:
def discriminator_model():
  """Build the CNN that scores a 28x28x1 image with a single sigmoid output.

  It has two (5x5 conv, tanh, 2x2 max-pool) stages with 64 and then 128
  filters, then a 1024-unit tanh dense layer, then one sigmoid unit. In
  ``train`` real images are labelled 1 and generated ones 0, so the output is
  read as "probability the image is real".
  """
  stack = [
      # Feature extractor: stage 1 (64 filters) fixes the input shape.
      Conv2D(64, (5, 5), input_shape=(28, 28, 1)),
      Activation('tanh'),
      MaxPooling2D(pool_size=(2, 2)),
      # Stage 2 (128 filters).
      Conv2D(128, (5, 5)),
      Activation('tanh'),
      MaxPooling2D(pool_size=(2, 2)),
      # Classifier head.
      Flatten(),
      Dense(1024),
      Activation('tanh'),
      Dense(1),
      Activation('sigmoid'),
  ]
  return Sequential(stack)

In [0]:
def generator_containing_discriminator(g,d):
  """Chain generator ``g`` and discriminator ``d`` into one noise -> score model.

  ``d.trainable`` is switched off (mutating ``d`` in place) before it joins the
  stack. The intent is that compiling the combined model updates only ``g``.
  """
  d.trainable = False
  return Sequential([g, d])

In [0]:
# Turn each label array into a single column of shape (N, 1).
y_train, y_test = (labels.reshape(labels.shape[0], 1) for labels in (y_train, y_test))

In [0]:
# Pack each split as an (images, labels) pair, the form train() unpacks.
_train, _test = (x_train, y_train), (x_test, y_test)

In [0]:
def combine_images(generated_images):
    """Tile a batch of images into one near-square 2-D mosaic.

    Args:
        generated_images: array of shape (N, H, W, C), of which only channel 0
            is used, or (N, H, W) for single-channel images.

    Returns:
        2-D array of shape (rows*H, cols*W) with the input dtype, where
        cols = floor(sqrt(N)) and rows = ceil(N / cols). Images are placed
        row-major and unused trailing cells stay zero. An empty batch returns
        a (0, 0) array instead of raising ZeroDivisionError.
    """
    num = generated_images.shape[0]
    if num == 0:
        return np.zeros((0, 0), dtype=generated_images.dtype)
    tile_h, tile_w = generated_images.shape[1:3]
    width = int(math.sqrt(num))
    height = (num + width - 1) // width  # integer ceil(num / width)
    image = np.zeros((height * tile_h, width * tile_w),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i, j = divmod(index, width)  # exact row/column, no float division
        # 4-D input keeps only the first channel; 3-D input is already 2-D per tile.
        tile = img if img.ndim == 2 else img[:, :, 0]
        image[i * tile_h:(i + 1) * tile_h, j * tile_w:(j + 1) * tile_w] = tile
    return image

In [0]:
def train(train, test, batch_size, epochs=30, sample_dir='./GAN'):
  """Train the DCGAN, alternating one discriminator step and one generator step.

  Args:
    train: ``(x_train, y_train)`` pair. Only the images are used. They are
      rescaled from [0, 255] to [-1, 1] to match the generator's tanh output.
    test: ``(x_test, y_test)`` pair. It is unpacked but not used by training.
    batch_size: number of real images, and of generated images, per step. A
      trailing partial batch is skipped.
    epochs: number of passes over the training images (default 30).
    sample_dir: directory where a grid of generated digits is saved every 100
      batches as ``<epoch>_<batch>.png``. It is created if missing.

  Prints the discriminator and generator loss after every batch.
  """
  (x_train, y_train) = train
  (x_test, y_test) = test  # unused; unpacked to keep the (x, y) contract
  x_train = (x_train.astype(np.float32) - 127.5) / 127.5
  # The discriminator takes (N, 28, 28, 1); add the channel axis once, up front.
  x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

  d = discriminator_model()
  g = generator_model()
  d_on_g = generator_containing_discriminator(g, d)

  d_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)
  g_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)

  # g itself is only used for predict(); compiling it is kept from the original.
  g.compile(loss='binary_crossentropy', optimizer='SGD')
  # d was frozen when the combined model was built, so this model trains g.
  d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)

  # The combined model above trains the generator, so make the discriminator
  # trainable again before compiling it for its own updates.
  d.trainable = True
  d.compile(loss='binary_crossentropy', optimizer=d_optim)

  # Without this, the first sample save fails with FileNotFoundError.
  os.makedirs(sample_dir, exist_ok=True)

  # Targets as numpy arrays. Newer Keras / tf.keras may interpret a plain Python
  # list as one target per model output rather than one label per sample.
  d_labels = np.concatenate([np.ones(batch_size, dtype=np.float32),
                             np.zeros(batch_size, dtype=np.float32)])
  g_labels = np.ones(batch_size, dtype=np.float32)

  n_batches = x_train.shape[0] // batch_size
  for epoch in range(epochs):
    print("Epoch is", epoch)

    for index in range(n_batches):
      # Discriminator step: real images labelled 1, generated images labelled 0.
      noise = np.random.uniform(-1, 1, size=(batch_size, 100))
      image_batch = x_train[index * batch_size:(index + 1) * batch_size]
      fake_images = g.predict(noise, verbose=0)

      if index % 100 == 0:
        grid = combine_images(fake_images)
        grid = grid * 127.5 + 127.5  # map tanh range [-1, 1] back to [0, 255]
        Image.fromarray(grid.astype(np.uint8)).save(
            os.path.join(sample_dir, '{}_{}.png'.format(epoch, index)))

      x = np.concatenate((image_batch, fake_images))
      d_loss = d.train_on_batch(x, d_labels)
      print('batch:', index, ', d_loss:', d_loss)

      # Generator step: push d_on_g to score fresh fakes as real (label 1).
      noise = np.random.uniform(-1, 1, (batch_size, 100))
      # NOTE(review): Keras reads `trainable` at compile time, so this toggle is
      # presumably a no-op after compile(); kept to preserve original behaviour.
      d.trainable = False
      g_loss = d_on_g.train_on_batch(noise, g_labels)
      d.trainable = True
      print('batch:', index, ', g_loss:', g_loss)
      
      

In [0]:
# Kick off training: each step pairs 128 real digits with 128 generated ones.
train(train=_train, test=_test, batch_size=128)