In [3]:
import argparse
import math
import os

import numpy as np
from keras.datasets import mnist
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.convolutional import UpSampling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from keras.optimizers import SGD, Adam
from PIL import Image

In [4]:
def generator_model():
    """Build the (uncompiled) DCGAN generator.

    The network takes a 100-dimensional input vector, passes it through two
    dense layers (1024 and 128*7*7 units), reshapes the result to
    (7, 7, 128), and then applies two rounds of 2x upsampling followed by a
    5x5 'same'-padded convolution. Every stage uses a tanh activation and the
    final convolution has a single output channel.

    Returns:
        A ``Sequential`` model; the caller is responsible for compiling it.
    """
    layers = [
        # Dense stem: 100 -> 1024 -> 128*7*7.
        Dense(units=1024, input_dim=100),
        Activation('tanh'),
        Dense(128*7*7),
        BatchNormalization(),
        Activation('tanh'),
        # Turn the flat vector into a 7x7 map with 128 channels.
        Reshape((7, 7, 128), input_shape=(128*7*7,)),
        # First upsample + convolution stage (64 filters).
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5), padding='same'),
        Activation('tanh'),
        # Second upsample + convolution stage (single output channel).
        UpSampling2D(size=(2, 2)),
        Conv2D(1, (5, 5), padding='same'),
        Activation('tanh'),
    ]
    net = Sequential()
    for layer in layers:
        net.add(layer)
    return net

In [5]:
def discriminator_model():
    """Build the (uncompiled) DCGAN discriminator.

    The network accepts inputs of shape (28, 28, 1) and applies two
    convolution / tanh / 2x2 max-pooling stages (64 then 128 filters, 5x5
    kernels), flattens the result, and finishes with a 1024-unit tanh dense
    layer and a single sigmoid output unit.

    Returns:
        A ``Sequential`` model; the caller is responsible for compiling it.
    """
    return Sequential([
        # Stage 1: 'same'-padded convolution, then halve the resolution.
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Stage 2: convolution with the layer's default padding, then pool again.
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Classifier head ending in one sigmoid unit.
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dense(1),
        Activation('sigmoid'),
    ])

In [6]:
def combine_images(generated_images):
    """Tile a batch of images into one 2-D grid image.

    The grid has ``floor(sqrt(n))`` columns and as many rows as are needed to
    hold all ``n`` images; tiles are placed row by row in batch order and any
    unused cells stay zero. Only channel 0 of each image is copied.

    Args:
        generated_images: array indexed as (n, height, width, channel).

    Returns:
        A 2-D array of shape (rows*height, cols*width) with the same dtype as
        the input. An empty batch (n == 0) yields an empty (0, 0) array
        instead of failing with a division by zero.
    """
    num = generated_images.shape[0]
    if num == 0:
        # No tiles: width would be 0 below, so return an empty canvas.
        return np.zeros((0, 0), dtype=generated_images.dtype)

    tile_h, tile_w = generated_images.shape[1:3]
    cols = int(math.sqrt(num))
    # Integer ceiling division; avoids a float round-trip.
    rows = (num + cols - 1) // cols

    image = np.zeros((rows*tile_h, cols*tile_w), dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        row, col = divmod(index, cols)
        image[row*tile_h:(row+1)*tile_h, col*tile_w:(col+1)*tile_w] = img[:, :, 0]
    return image

In [7]:
def train(BATCH_SIZE, epochs=100):
    """Train the DCGAN on the MNIST training images.

    Each step draws uniform noise in [-1, 1], generates a batch of fake
    images, updates the discriminator on real + fake images, and then updates
    the generator through the stacked generator->discriminator model.

    Side effects:
        * Every 100 steps a grid of generated images is written to
          ``./img/<epoch>_<index>.png`` (the directory is created if missing).
        * Every 10 steps the weights are written to the files ``generator``
          and ``discriminator`` in the working directory, overwriting them.
        * Both losses are printed after every step.

    Args:
        BATCH_SIZE: number of real images (and of generated images) per step.
        epochs: number of passes over the training set. Defaults to 100, the
            value that was previously hard-coded.
    """
    # Load MNIST; only the training images are used.
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Rescale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    # Add the channel axis; the sample count comes from the data, not a constant.
    X_train = np.reshape(X_train, (X_train.shape[0], 28, 28, 1))

    # Models - Adam optimizer and binary cross-entropy loss.
    d = discriminator_model()
    d_opt = Adam(lr=2e-4, beta_1=0.5)  # presumably the DCGAN-paper settings
    # Alternative setting (from the "Hajimete no GAN" article): Adam(lr=1e-5, beta_1=0.1)
    d.compile(loss='binary_crossentropy', optimizer=d_opt)
    # Freeze d only after compiling it, so that it is frozen inside the stacked
    # model compiled below. Keras reports this as a "Discrepancy between
    # trainable weights and collected trainable" warning.
    # NOTE(review): relies on Keras capturing trainability at compile time -
    # confirm for the installed version.
    d.trainable = False
    g = generator_model()
    dcgan = Sequential([g, d])
    g_opt = Adam(lr=2e-4, beta_1=0.5)
    dcgan.compile(loss='binary_crossentropy', optimizer=g_opt)

    # Saving a sample grid fails if the target directory does not exist.
    image_dir = "./img"
    os.makedirs(image_dir, exist_ok=True)

    num_batches = X_train.shape[0] // BATCH_SIZE
    # Targets are constant, so build them once: 1 = real, 0 = generated.
    d_labels = np.concatenate((np.ones((BATCH_SIZE, 1), dtype=np.float32),
                               np.zeros((BATCH_SIZE, 1), dtype=np.float32)))
    g_labels = np.ones((BATCH_SIZE, 1), dtype=np.float32)

    # Training loop.
    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            # noise: (BATCH_SIZE, 100) | generator input, uniform in [-1, 1]
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            # image_batch: (BATCH_SIZE, 28, 28, 1)
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            # generated_images: (BATCH_SIZE, 28, 28, 1)
            generated_images = g.predict(noise, verbose=0)
            # Periodically save a grid of the current generator output.
            if index % 100 == 0:
                image = combine_images(generated_images)
                # Map [-1, 1] back to [0, 255] for saving as 8-bit.
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    os.path.join(image_dir, "%d_%d.png" % (epoch, index)))
            # Update D on real images followed by generated images.
            X = np.concatenate((image_batch, generated_images))
            d_loss = d.train_on_batch(X, d_labels)
            print("batch %d d_loss : %f" % (index, d_loss))
            # Update G: fresh noise, with "real" as the target label.
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            g_loss = dcgan.train_on_batch(noise, g_labels)
            print("batch %d g_loss : %f" % (index, g_loss))
            # Periodically checkpoint the weights (True = overwrite).
            if index % 10 == 9:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)

In [8]:
def generate(BATCH_SIZE, nice=False):
    """Generate a grid of images from saved weights and write it to disk.

    Loads generator weights from the file ``generator`` and saves the tiled
    result as ``generated_image.png``.

    Args:
        BATCH_SIZE: number of images in the output grid.
        nice: when true, generate ``BATCH_SIZE*20`` candidates, score them
            with the discriminator (weights loaded from ``discriminator``),
            and keep only the ``BATCH_SIZE`` highest-scoring ones. When
            false, ``BATCH_SIZE`` images are generated and used directly.
    """
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')

    if not nice:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        samples = g.predict(noise, verbose=1)
        image = combine_images(samples)
    else:
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')

        candidate_count = BATCH_SIZE*20
        noise = np.random.uniform(-1, 1, (candidate_count, 100))
        candidates = g.predict(noise, verbose=1)
        scores = d.predict(candidates, verbose=1)

        # Candidate indices ordered by discriminator score, best first;
        # ties keep their original relative order (stable sort).
        ranked = sorted(range(candidate_count),
                        key=lambda k: scores[k][0],
                        reverse=True)

        # Copy channel 0 of each of the top BATCH_SIZE candidates.
        best = np.zeros((BATCH_SIZE,) + candidates.shape[1:3] + (1,), dtype=np.float32)
        for slot, source in enumerate(ranked[:BATCH_SIZE]):
            best[slot, :, :, 0] = candidates[source, :, :, 0]
        image = combine_images(best)

    # Map [-1, 1] back to [0, 255] before converting to 8-bit.
    image = image*127.5+127.5
    Image.fromarray(image.astype(np.uint8)).save("generated_image.png")

In [9]:
# Start DCGAN training with batches of 100 images. The recorded output below
# comes from a run that was interrupted by hand during the first epoch.
train(BATCH_SIZE=100)

Epoch is 0
Number of batches 600


  'Discrepancy between trainable weights and collected trainable'


batch 0 d_loss : 0.703093
batch 0 g_loss : 0.404721
batch 1 d_loss : 0.462444
batch 1 g_loss : 0.442370
batch 2 d_loss : 0.435681
batch 2 g_loss : 0.799561
batch 3 d_loss : 0.362454
batch 3 g_loss : 1.609449
batch 4 d_loss : 0.287261
batch 4 g_loss : 2.684033
batch 5 d_loss : 0.235744
batch 5 g_loss : 3.574362
batch 6 d_loss : 0.205323
batch 6 g_loss : 4.239281
batch 7 d_loss : 0.206319
batch 7 g_loss : 4.551203
batch 8 d_loss : 0.218894
batch 8 g_loss : 5.005606
batch 9 d_loss : 0.236090
batch 9 g_loss : 5.646854
batch 10 d_loss : 0.278685
batch 10 g_loss : 5.220964
batch 11 d_loss : 0.279124
batch 11 g_loss : 4.333030
batch 12 d_loss : 0.317399
batch 12 g_loss : 3.749338
batch 13 d_loss : 0.216050
batch 13 g_loss : 3.509508
batch 14 d_loss : 0.203335
batch 14 g_loss : 3.140157
batch 15 d_loss : 0.171105
batch 15 g_loss : 3.497346
batch 16 d_loss : 0.189665
batch 16 g_loss : 3.635487
batch 17 d_loss : 0.155715
batch 17 g_loss : 3.963567
batch 18 d_loss : 0.178260
batch 18 g_loss : 4.1

KeyboardInterrupt: 

In [26]:
# Render one image grid from the saved weights. With nice=True, 100*20
# candidates are generated and only the 100 that the discriminator scores
# highest are kept.
generate(BATCH_SIZE=100, nice=True)

