**генеративно-состязательные модели (GAN)** 

Евгений Борисов borisov.e@solarl.ru

In [1]:
#  Gulli Antonio, Pal Sujit. Deep Learning with Keras -- Packt Publishing, 2017

In [2]:
import numpy as np

import IPython

# from PIL import Image
# import math


In [3]:
from tensorflow.keras.models import Sequential

from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D 
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Reshape

from tensorflow.keras.optimizers import SGD

from tensorflow.keras.datasets import mnist

from tensorflow.keras import backend as K

from tensorflow.keras.utils import plot_model

In [4]:
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

---

In [5]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train.shape, y_train.shape,X_test.shape, y_test.shape

((60000, 28, 28), (60000,), (10000, 28, 28), (10000,))

In [6]:
img_rows, img_cols = X_train.shape[1], X_train.shape[2]

img_rows, img_cols

(28, 28)

In [7]:
X_train = (X_train.astype(np.float32) - 127.5) / 127.5

In [8]:
new_shape = ( X_train.shape[0], 1 ) + X_train.shape[1:] 
new_shape

(60000, 1, 28, 28)

In [9]:
X_train = X_train.reshape( new_shape )
X_train.shape    

(60000, 1, 28, 28)

---

In [10]:
def generator_model():
    model = Sequential()
    model.add(Dense(1024, input_shape=(100, ), activation='tanh'))
    model.add(Dense(128 * 7 * 7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((7, 7, 128), input_shape=(7 * 7 * 128,)))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (5, 5), 
                     padding='same',
                     activation='tanh',
                     data_format='channels_last'))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(1, (5, 5), 
                     padding='same', 
                     activation='tanh',
                     data_format='channels_last'))
    return model

----

In [11]:
def discriminator_model():
    model = Sequential()
    model.add(Conv2D(64, (5, 5),
                    padding='same',
                    input_shape=(28, 28, 1),
                    activation='tanh',
                    data_format='channels_last'
    ))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (5, 5),
                    activation='tanh',
                    data_format='channels_last'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(1024, activation='tanh'))
    model.add(Dense(1, activation='sigmoid'))
    return model

---

In [12]:
generator = generator_model()
plot_model(generator,to_file='result/model-generator.png',show_shapes=True, show_layer_names=True)
# IPython.display.Image('result/model-generator.png')

In [13]:
discriminator = discriminator_model()
plot_model( discriminator, to_file='result/model-discriminator.png', show_shapes=True, show_layer_names=True )
# IPython.display.Image('result/model-discriminator.png')

In [14]:
discriminator_on_generator = Sequential()
discriminator_on_generator.add(generator)
discriminator.trainable = False
discriminator_on_generator.add(discriminator)

plot_model( discriminator_on_generator, 
            to_file='result/model-discriminator_on_generator.png', 
            show_shapes=True, 
            show_layer_names=True )
# IPython.display.Image('result/model-discriminator_on_generator.png')

In [15]:
generator.compile(loss='binary_crossentropy', optimizer="SGD")

discriminator_on_generator.compile(
    loss='binary_crossentropy', 
    optimizer=SGD(lr=0.0005, momentum=0.9, nesterov=True)
  )

discriminator.trainable = True
discriminator.compile(
    loss='binary_crossentropy', 
    optimizer=SGD(lr=0.0005, momentum=0.9, nesterov=True)
  )


---

In [16]:
import math 
from PIL import Image

def combine_images(generated_images,e,b):
    generated_images = generated_images.reshape(generated_images.shape[0], 
                                                generated_images.shape[3], 
                                                generated_images.shape[1],
                                                generated_images.shape[2])
    num = generated_images.shape[0]
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))
    shape = generated_images.shape[2:]
    image = np.zeros((height * shape[0], width * shape[1]),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = int(index / width)
        j = index % width
        image[i * shape[0]:(i + 1) * shape[0], j * shape[1]:(j + 1) * shape[1]] = \
            img[0, :, :]
        
    image = image * 127.5 + 127.5
    Image.fromarray( image.astype(np.uint8)).save('result/'+str(e)+"_"+str(b)+".png")
                    
#     return image

---

In [17]:
BATCH_SIZE=100
N_EPOCH = 50
N_EX = X_train.shape[0]
N_BATCH = int( np.floor(N_EX/BATCH_SIZE) )

N_EX, N_BATCH

(60000, 600)

In [18]:
# пара нейросетей генератор (создаёт "подделку") и дискриминатор (распознаёт подделку)

# одна эпоха обучения состоит из следующих шагов

# 1. генератор создаёт пакет "подделок" из случайного шума

# 2. дискриминатор имеет пакет "настоящих" образцов 
#    и обучается распознавать фальшивки на наборе фальшивые/настоящие

# 3. собираем конвейер генератор-дискриминатор
#    фиксируем веса дискриминатора (т.е. дискриминатор в конвейере не обучаем)
#    и обучаем генератор "обманывать" дискриминатор

In [None]:
%%time

for e in range(N_EPOCH):
    print('epoch: ', e+1,'/',N_EPOCH)
    
    for b in range(N_BATCH):
        print('\tbatch %d/%d, loss '%(b+1,N_BATCH))

        noise = np.random.uniform(-1, 1, (BATCH_SIZE,100))
                
        X_batch = X_train[ b*BATCH_SIZE: (b+1)*BATCH_SIZE ]
        
        X_batch = X_batch.reshape(X_batch.shape[0], 
                                  X_batch.shape[2], 
                                  X_batch.shape[3],
                                  X_batch.shape[1])
        
        generated_images = generator.predict(noise, verbose=0)
        
        if b % 200 == 0: combine_images(generated_images,e,b)
            
        X = np.concatenate((X_batch, generated_images))
        y = [1] * BATCH_SIZE + [0] * BATCH_SIZE

        discriminator.trainable = True
        d_loss = discriminator.train_on_batch(X, y)
        print('\t\t%f : discriminator'%(d_loss))
        
        noise = np.random.uniform(-1, 1, (BATCH_SIZE,100))
            
        discriminator.trainable = False
        g_loss = discriminator_on_generator.train_on_batch( noise, [1]*BATCH_SIZE)
        
        print('\t\t%f : generator'%(g_loss))


epoch:  1 / 50
	batch 1/600, loss 
		0.674214 : discriminator
		0.695865 : generator
	batch 2/600, loss 
		0.665910 : discriminator
		0.692729 : generator
	batch 3/600, loss 
		0.649831 : discriminator
		0.691227 : generator
	batch 4/600, loss 
		0.628737 : discriminator
		0.684904 : generator
	batch 5/600, loss 
		0.618773 : discriminator
		0.675910 : generator
	batch 6/600, loss 
		0.595239 : discriminator
		0.674965 : generator
	batch 7/600, loss 
		0.578777 : discriminator
		0.663497 : generator
	batch 8/600, loss 
		0.567072 : discriminator
		0.657537 : generator
	batch 9/600, loss 
		0.545138 : discriminator
		0.648260 : generator
	batch 10/600, loss 
		0.532253 : discriminator
		0.642359 : generator
	batch 11/600, loss 
		0.511220 : discriminator
		0.635515 : generator
	batch 12/600, loss 
		0.499900 : discriminator
		0.635243 : generator
	batch 13/600, loss 
		0.503922 : discriminator
		0.618490 : generator
	batch 14/600, loss 
		0.495158 : discriminator
		0.624012 : generator


		0.278136 : discriminator
		0.725919 : generator
	batch 117/600, loss 
		0.274592 : discriminator
		0.722634 : generator
	batch 118/600, loss 
		0.265097 : discriminator
		0.750862 : generator
	batch 119/600, loss 
		0.276635 : discriminator
		0.704134 : generator
	batch 120/600, loss 
		0.267461 : discriminator
		0.709990 : generator
	batch 121/600, loss 
		0.278855 : discriminator
		0.684701 : generator
	batch 122/600, loss 
		0.295659 : discriminator
		0.694319 : generator
	batch 123/600, loss 
		0.270229 : discriminator
		0.705061 : generator
	batch 124/600, loss 
		0.304911 : discriminator
		0.650763 : generator
	batch 125/600, loss 
		0.297139 : discriminator
		0.672349 : generator
	batch 126/600, loss 
		0.311641 : discriminator
		0.657628 : generator
	batch 127/600, loss 
		0.311665 : discriminator
		0.655520 : generator
	batch 128/600, loss 
		0.352207 : discriminator
		0.640378 : generator
	batch 129/600, loss 
		0.351566 : discriminator
		0.653970 : generator
	batch 130/600

		0.263280 : discriminator
		1.273918 : generator
	batch 231/600, loss 
		0.222802 : discriminator
		1.260997 : generator
	batch 232/600, loss 
		0.256325 : discriminator
		1.220755 : generator
	batch 233/600, loss 
		0.259325 : discriminator
		1.230774 : generator
	batch 234/600, loss 
		0.270970 : discriminator
		1.245402 : generator
	batch 235/600, loss 
		0.283643 : discriminator
		1.182996 : generator
	batch 236/600, loss 
		0.311991 : discriminator
		1.198752 : generator
	batch 237/600, loss 
		0.335903 : discriminator
		1.202134 : generator
	batch 238/600, loss 
		0.355296 : discriminator
		1.042035 : generator
	batch 239/600, loss 
		0.343438 : discriminator
		1.084668 : generator
	batch 240/600, loss 
		0.318235 : discriminator
		1.017372 : generator
	batch 241/600, loss 
		0.269323 : discriminator
		0.962296 : generator
	batch 242/600, loss 
		0.317159 : discriminator
		0.949825 : generator
	batch 243/600, loss 
		0.356125 : discriminator
		0.981008 : generator
	batch 244/600

		0.460011 : discriminator
		1.053665 : generator
	batch 345/600, loss 
		0.493492 : discriminator
		1.072288 : generator
	batch 346/600, loss 
		0.452722 : discriminator
		1.049549 : generator
	batch 347/600, loss 
		0.478477 : discriminator
		1.065630 : generator
	batch 348/600, loss 
		0.506373 : discriminator
		1.037872 : generator
	batch 349/600, loss 
		0.487683 : discriminator
		1.042506 : generator
	batch 350/600, loss 
		0.451793 : discriminator
		1.075109 : generator
	batch 351/600, loss 
		0.433889 : discriminator
		1.090687 : generator
	batch 352/600, loss 
		0.472976 : discriminator
		1.042720 : generator
	batch 353/600, loss 
		0.489242 : discriminator
		1.071852 : generator
	batch 354/600, loss 
		0.455571 : discriminator
		1.046857 : generator
	batch 355/600, loss 
		0.497278 : discriminator
		1.006548 : generator
	batch 356/600, loss 
		0.467477 : discriminator
		1.057740 : generator
	batch 357/600, loss 
		0.490329 : discriminator
		1.060494 : generator
	batch 358/600

		0.466227 : discriminator
		1.102820 : generator
	batch 459/600, loss 
		0.440815 : discriminator
		1.151174 : generator
	batch 460/600, loss 
		0.475936 : discriminator
		1.074120 : generator
	batch 461/600, loss 
		0.488920 : discriminator
		1.127325 : generator
	batch 462/600, loss 
		0.480668 : discriminator
		1.119215 : generator
	batch 463/600, loss 
		0.460242 : discriminator
		1.125335 : generator
	batch 464/600, loss 
		0.451077 : discriminator
		1.082648 : generator
	batch 465/600, loss 
		0.499257 : discriminator
		1.107407 : generator
	batch 466/600, loss 
		0.490692 : discriminator
		1.088282 : generator
	batch 467/600, loss 
		0.490822 : discriminator
		1.158533 : generator
	batch 468/600, loss 
		0.472460 : discriminator
		1.083891 : generator
	batch 469/600, loss 
		0.458921 : discriminator
		1.173739 : generator
	batch 470/600, loss 
		0.490963 : discriminator
		1.195328 : generator
	batch 471/600, loss 
		0.484264 : discriminator
		1.152205 : generator
	batch 472/600

		0.474717 : discriminator
		1.094177 : generator
	batch 573/600, loss 
		0.461266 : discriminator
		1.152609 : generator
	batch 574/600, loss 
		0.478659 : discriminator
		1.116560 : generator
	batch 575/600, loss 
		0.496445 : discriminator
		1.112268 : generator
	batch 576/600, loss 
		0.481540 : discriminator
		1.172390 : generator
	batch 577/600, loss 
		0.479420 : discriminator
		1.185177 : generator
	batch 578/600, loss 
		0.465198 : discriminator
		1.169878 : generator
	batch 579/600, loss 
		0.450747 : discriminator
		1.150996 : generator
	batch 580/600, loss 
		0.518279 : discriminator
		1.122290 : generator
	batch 581/600, loss 
		0.451242 : discriminator
		1.086304 : generator
	batch 582/600, loss 
		0.495545 : discriminator
		1.072810 : generator
	batch 583/600, loss 
		0.458818 : discriminator
		1.112829 : generator
	batch 584/600, loss 
		0.511489 : discriminator
		1.087535 : generator
	batch 585/600, loss 
		0.465012 : discriminator
		1.169517 : generator
	batch 586/600

		0.471122 : discriminator
		1.110843 : generator
	batch 88/600, loss 
		0.460108 : discriminator
		1.132337 : generator
	batch 89/600, loss 
		0.509589 : discriminator
		1.123423 : generator
	batch 90/600, loss 
		0.459965 : discriminator
		1.122247 : generator
	batch 91/600, loss 
		0.440090 : discriminator
		1.179094 : generator
	batch 92/600, loss 
		0.440246 : discriminator
		1.156657 : generator
	batch 93/600, loss 
		0.464248 : discriminator
		1.180976 : generator
	batch 94/600, loss 
		0.528588 : discriminator
		1.172051 : generator
	batch 95/600, loss 
		0.498023 : discriminator
		1.125641 : generator
	batch 96/600, loss 
		0.493950 : discriminator
		1.139174 : generator
	batch 97/600, loss 
		0.488705 : discriminator
		1.069798 : generator
	batch 98/600, loss 
		0.424578 : discriminator
		1.043646 : generator
	batch 99/600, loss 
		0.446473 : discriminator
		1.102633 : generator
	batch 100/600, loss 
		0.429099 : discriminator
		1.179060 : generator
	batch 101/600, loss 
		0.

		0.548478 : discriminator
		1.092301 : generator
	batch 202/600, loss 
		0.499616 : discriminator
		1.133327 : generator
	batch 203/600, loss 
		0.485386 : discriminator
		1.057681 : generator
	batch 204/600, loss 
		0.441201 : discriminator
		1.074000 : generator
	batch 205/600, loss 
		0.419731 : discriminator
		1.164963 : generator
	batch 206/600, loss 
		0.414122 : discriminator
		1.285647 : generator
	batch 207/600, loss 
		0.451805 : discriminator
		1.229335 : generator
	batch 208/600, loss 
		0.429918 : discriminator
		1.346868 : generator
	batch 209/600, loss 
		0.497571 : discriminator
		1.289963 : generator
	batch 210/600, loss 
		0.509944 : discriminator
		1.196969 : generator
	batch 211/600, loss 
		0.442430 : discriminator
		1.165591 : generator
	batch 212/600, loss 
		0.432973 : discriminator
		1.171485 : generator
	batch 213/600, loss 
		0.418939 : discriminator
		1.266271 : generator
	batch 214/600, loss 
		0.431546 : discriminator
		1.246609 : generator
	batch 215/600

		0.402087 : discriminator
		1.289355 : generator
	batch 316/600, loss 
		0.426862 : discriminator
		1.246315 : generator
	batch 317/600, loss 
		0.408972 : discriminator
		1.238870 : generator
	batch 318/600, loss 
		0.441500 : discriminator
		1.171164 : generator
	batch 319/600, loss 
		0.403576 : discriminator
		1.267251 : generator
	batch 320/600, loss 
		0.410275 : discriminator
		1.234967 : generator
	batch 321/600, loss 
		0.411007 : discriminator
		1.284242 : generator
	batch 322/600, loss 
		0.459607 : discriminator
		1.256859 : generator
	batch 323/600, loss 
		0.421579 : discriminator
		1.241879 : generator
	batch 324/600, loss 
		0.560712 : discriminator
		1.131594 : generator
	batch 325/600, loss 
		0.552038 : discriminator
		1.113792 : generator
	batch 326/600, loss 
		0.461839 : discriminator
		1.042209 : generator
	batch 327/600, loss 
		0.408751 : discriminator
		1.203072 : generator
	batch 328/600, loss 
		0.411496 : discriminator
		1.312048 : generator
	batch 329/600

		0.339780 : discriminator
		1.474287 : generator
	batch 430/600, loss 
		0.349028 : discriminator
		1.423392 : generator
	batch 431/600, loss 
		0.361508 : discriminator
		1.362440 : generator
	batch 432/600, loss 
		0.355429 : discriminator
		1.367233 : generator
	batch 433/600, loss 
		0.344296 : discriminator
		1.423564 : generator
	batch 434/600, loss 
		0.348525 : discriminator
		1.372782 : generator
	batch 435/600, loss 
		0.355861 : discriminator
		1.348232 : generator
	batch 436/600, loss 
		0.326806 : discriminator
		1.355095 : generator
	batch 437/600, loss 
		0.401593 : discriminator
		1.362191 : generator
	batch 438/600, loss 
		0.360216 : discriminator
		1.275792 : generator
	batch 439/600, loss 
		0.373178 : discriminator
		1.225763 : generator
	batch 440/600, loss 
		0.360803 : discriminator
		1.313017 : generator
	batch 441/600, loss 
		0.396836 : discriminator
		1.353163 : generator
	batch 442/600, loss 
		0.394460 : discriminator
		1.292282 : generator
	batch 443/600

		0.305662 : discriminator
		1.474466 : generator
	batch 544/600, loss 
		0.335370 : discriminator
		1.445462 : generator
	batch 545/600, loss 
		0.363043 : discriminator
		1.399876 : generator
	batch 546/600, loss 
		0.352064 : discriminator
		1.527990 : generator
	batch 547/600, loss 
		0.326426 : discriminator
		1.552000 : generator
	batch 548/600, loss 
		0.349837 : discriminator
		1.581346 : generator
	batch 549/600, loss 
		0.383738 : discriminator
		1.330416 : generator
	batch 550/600, loss 
		0.384327 : discriminator
		1.337694 : generator
	batch 551/600, loss 
		0.333826 : discriminator
		1.451649 : generator
	batch 552/600, loss 
		0.335650 : discriminator
		1.437503 : generator
	batch 553/600, loss 
		0.318320 : discriminator
		1.464434 : generator
	batch 554/600, loss 
		0.329916 : discriminator
		1.549680 : generator
	batch 555/600, loss 
		0.308172 : discriminator
		1.605498 : generator
	batch 556/600, loss 
		0.394294 : discriminator
		1.293538 : generator
	batch 557/600

		0.389242 : discriminator
		1.499697 : generator
	batch 59/600, loss 
		0.370767 : discriminator
		1.521825 : generator
	batch 60/600, loss 
		0.368733 : discriminator
		1.692722 : generator
	batch 61/600, loss 
		0.442017 : discriminator
		1.304070 : generator
	batch 62/600, loss 
		0.396508 : discriminator
		1.372012 : generator
	batch 63/600, loss 
		0.343885 : discriminator
		1.470808 : generator
	batch 64/600, loss 
		0.362819 : discriminator
		1.490175 : generator
	batch 65/600, loss 
		0.419668 : discriminator
		1.378344 : generator
	batch 66/600, loss 
		0.345401 : discriminator


In [None]:
generator.save_weights('result/generator', True)
discriminator.save_weights('result/discriminator', True)

---

In [None]:
# def generate(BATCH_SIZE, nice=False):
#     generator = generator_model()
#     generator.compile(loss='binary_crossentropy', optimizer='SGD')
#     generator.load_weights('generator')
#     if nice:
#         discriminator = discriminator_model()
#         discriminator.compile(loss='binary_crossentropy', optimizer='SGD')
#         discriminator.load_weights('discriminator')
#         noise = np.zeros((BATCH_SIZE * 20, 100))
                
#         for i in range(BATCH_SIZE * 20): noise[i, :] = np.random.uniform(-1, 1, 100)
        
#         generated_images = generator.predict(noise, verbose=1)
#         d_pret = discriminator.predict(generated_images, verbose=1)
        
#         index = np.arrange(0, BATCH_SIZE * 20)
#         index.resize((BATCH_SIZE * 20, 1))
#         pre_with_index = list(np.append(d_pret, index, axis=1))
#         pre_with_index.sort(key=lambda x: x[0], reverse=True)
#         nice_images = np.zeros((BATCH_SIZE, 1) + (generated_images.shape[2:]), dtype=np.float32)
        
#         for i in range(int(BATCH_SIZE)):
#             idx = int(pre_with_index[i][1])
#             nice_images[i, 0, :, :] = generated_images[idx, 0, :, :]
            
#         image = combine_images(nice_images)
        
#     else:
#         noise = np.zeros((BATCH_SIZE, 100))
#         for i in range(BATCH_SIZE): noise[i, :] = np.random.uniform(-1, 1, 100)
#         generated_images = generator.predict(noise, verbose=1)
#         image = combine_images(generated_images)
        
        
#     image = image * 127.5 + 127.5
#     Image.fromarray(image.astype(np.uint8)).save('geneted_image.png')

In [None]:
# Epoch time calculate 296.2202084370001

---

In [None]:
# plot_model(generator,to_file='model-generator.png',show_shapes=True, show_layer_names=True)
# IPython.display.Image('model-generator.png')

In [None]:
# plot_model( discriminator, to_file='model-discriminator.png', show_shapes=True, show_layer_names=True )
# IPython.display.Image('model-discriminator.png')

In [None]:
# plot_model( discriminator_on_generator, to_file='model-discriminator_on_generator.png', 
#                             show_shapes=True, show_layer_names=True )
# IPython.display.Image('model-discriminator_on_generator.png')

In [None]:
# discriminator = discriminator_model()
# generator = generator_model()

# discriminator_on_generator = generator_containing_discriminator(generator, discriminator)

# d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
# g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
# generator.compile(loss='binary_crossentropy', optimizer="SGD")

# discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)
# discriminator.trainable = True
# discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

In [None]:
# from IPython.display import SVG
# from keras.utils.vis_utils import model_to_dot
# SVG(model_to_dot(generator).create(prog='dot', format='svg'))