**Генеративно-состязательные сети (GAN)**

Евгений Борисов borisov.e@solarl.ru

In [None]:
#  Gulli Antonio, Pal Sujit. Deep Learning with Keras -- Packt Publishing, 2017

In [1]:
import numpy as np

import IPython

# from PIL import Image
# import math


In [2]:
from tensorflow.keras.models import Sequential

from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D 
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Reshape

from tensorflow.keras.optimizers import SGD

from tensorflow.keras.datasets import mnist

from tensorflow.keras import backend as K

from tensorflow.keras.utils import plot_model

In [3]:
import tensorflow as tf

# Silence TensorFlow's INFO/WARNING log output.
# `tf.logging` exists only in TF 1.x (AttributeError under TF 2.x);
# `tf.get_logger()` (TF >= 1.14 and 2.x) is the portable replacement.
if hasattr(tf, 'get_logger'):
    tf.get_logger().setLevel('ERROR')
else:
    tf.logging.set_verbosity(tf.logging.ERROR)

---

In [4]:
# Keras returns ((train images, train labels), (test images, test labels)).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Display the shapes of all four arrays.
tuple(arr.shape for arr in (X_train, y_train, X_test, y_test))

((60000, 28, 28), (60000,), (10000, 28, 28), (10000,))

In [5]:
# Image height and width, read from the training tensor (N, rows, cols).
img_rows, img_cols = X_train.shape[1:3]

img_rows, img_cols

(28, 28)

In [6]:
# Map pixel values [0, 255] linearly onto [-1, 1], the range of the generator's tanh output.
X_train = (np.asarray(X_train, dtype=np.float32) - 127.5) / 127.5

In [7]:
# Insert a singleton channel axis right after the batch axis: (N, H, W) -> (N, 1, H, W).
new_shape = (X_train.shape[0], 1, *X_train.shape[1:])
new_shape

(60000, 1, 28, 28)

In [8]:
# Store the training set channels-first; batches are converted to channels-last in the training loop.
X_train = np.reshape(X_train, new_shape)
X_train.shape

(60000, 1, 28, 28)

---

In [9]:
def generator_model():
    """Build the generator: a 100-d noise vector -> a 28x28x1 image in [-1, 1].

    A dense projection is reshaped to a 7x7x128 feature map, followed by two
    (UpSampling2D x2 -> 5x5 Conv2D) stages: 7x7 -> 14x14 -> 28x28.
    The final layer uses tanh, so outputs lie in [-1, 1].
    """
    # Both convolutions share padding / activation / layout settings.
    conv_kwargs = dict(padding='same', activation='tanh', data_format='channels_last')
    return Sequential([
        Dense(1024, input_shape=(100, ), activation='tanh'),
        Dense(128 * 7 * 7),
        BatchNormalization(),
        Activation('tanh'),
        Reshape((7, 7, 128), input_shape=(7 * 7 * 128,)),
        # 7x7 -> 14x14
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5), **conv_kwargs),
        # 14x14 -> 28x28, a single output channel
        UpSampling2D(size=(2, 2)),
        Conv2D(1, (5, 5), **conv_kwargs),
    ])

----

In [10]:
def discriminator_model():
    """Build the discriminator: a 28x28x1 image -> probability that it is real.

    Two 5x5 Conv2D + 2x2 max-pool stages feed a 1024-unit dense layer and a
    single sigmoid unit. Hidden activations are tanh. Real images are
    labelled 1 and generated ones 0 during training.
    """
    layers = [
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1),
               activation='tanh', data_format='channels_last'),
        MaxPooling2D(pool_size=(2, 2)),
        # No explicit padding here, so Conv2D's default applies.
        Conv2D(128, (5, 5), activation='tanh', data_format='channels_last'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(1024, activation='tanh'),
        Dense(1, activation='sigmoid'),
    ]
    return Sequential(layers)

---

In [11]:
import os

# Every artifact of this notebook (model diagrams, sample grids, weights) is
# written under result/; create it up front so plot_model and later cells
# don't fail on a missing directory.
os.makedirs('result', exist_ok=True)

generator = generator_model()
plot_model(generator, to_file='result/model-generator.png', show_shapes=True, show_layer_names=True)
# IPython.display.Image('result/model-generator.png')

In [12]:
discriminator = discriminator_model()
# Render the architecture diagram (layer names and tensor shapes) to a PNG.
plot_model(discriminator,
           to_file='result/model-discriminator.png',
           show_layer_names=True,
           show_shapes=True)
# IPython.display.Image('result/model-discriminator.png')

In [13]:
# Stacked model noise -> generator -> discriminator; it is used only to train the generator.
discriminator_on_generator = Sequential()
discriminator_on_generator.add(generator)
# Freeze the discriminator before it joins the stacked model, so the generator
# step does not update discriminator weights. The order matters: the flag is
# still False when the stacked model is compiled in the next cell.
discriminator.trainable = False
discriminator_on_generator.add(discriminator)

plot_model( discriminator_on_generator, 
            to_file='result/model-discriminator_on_generator.png', 
            show_shapes=True, 
            show_layer_names=True )
# IPython.display.Image('result/model-discriminator_on_generator.png')

In [14]:
# The generator is only trained through discriminator_on_generator below; it is
# used standalone just for predict(), so this compile is not needed for training.
generator.compile(loss='binary_crossentropy', optimizer="SGD")

# Compiled while discriminator.trainable is still False (set in the previous
# cell), so this model's updates are meant to change only generator weights.
# NOTE(review): `lr` is the legacy keyword; newer Keras versions expect
# `learning_rate` -- confirm against the installed TensorFlow version.
discriminator_on_generator.compile(
    loss='binary_crossentropy', 
    optimizer=SGD(lr=0.0005, momentum=0.9, nesterov=True)
  )

# Unfreeze and compile the standalone discriminator with its own optimizer,
# so it is trainable when trained directly on real/fake batches.
discriminator.trainable = True
discriminator.compile(
    loss='binary_crossentropy', 
    optimizer=SGD(lr=0.0005, momentum=0.9, nesterov=True)
  )


---

In [15]:
import math 
from PIL import Image

def _tile_images(images):
    """Arrange a stack of 2-D images into one near-square 2-D grid.

    Parameters
    ----------
    images : ndarray of shape (N, H, W), N >= 1.

    Returns
    -------
    ndarray of shape (rows * H, cols * W) with the dtype of *images*, where
    cols = floor(sqrt(N)) and rows = ceil(N / cols). Image k is placed in cell
    (k // cols, k % cols); unused cells stay zero.

    Raises
    ------
    ValueError if *images* is empty.
    """
    num = images.shape[0]
    if num == 0:
        raise ValueError('combine_images: empty batch, nothing to tile')
    cols = int(math.sqrt(num))
    rows = int(math.ceil(float(num) / cols))
    h, w = images.shape[1:3]
    grid = np.zeros((rows * h, cols * w), dtype=images.dtype)
    for index, img in enumerate(images):
        i, j = divmod(index, cols)
        grid[i * h:(i + 1) * h, j * w:(j + 1) * w] = img
    return grid


def combine_images(generated_images, e, b):
    """Save a batch of generated images as one grid PNG: result/<e>_<b>.png.

    Parameters
    ----------
    generated_images : ndarray, (N, H, W, C) channels-last as produced by the
        generator, or (N, H, W). Only the first channel is drawn. Values are
        expected in [-1, 1] (the generator ends in tanh) and are mapped to 0..255.
    e, b : epoch and batch indices; used only to build the file name.

    Returns
    -------
    The tiled grid before scaling to 0..255 (still in the generator's range).

    Raises
    ------
    ValueError if the batch is empty.
    """
    import os

    images = np.asarray(generated_images)
    if images.ndim == 4:
        # Channels-last batch: select channel 0. (Reshaping to (N, C, H, W)
        # instead of selecting a channel is only correct when C == 1.)
        images = images[..., 0]
    grid = _tile_images(images)

    # tanh range [-1, 1] -> pixel range [0, 255]
    pixels = grid * 127.5 + 127.5
    os.makedirs('result', exist_ok=True)
    Image.fromarray(pixels.astype(np.uint8)).save('result/' + str(e) + "_" + str(b) + ".png")
    return grid

---

In [16]:
BATCH_SIZE = 100
N_EPOCH = 50
N_EX = X_train.shape[0]
# Only whole batches are used; a trailing partial batch (if any) is dropped.
N_BATCH = N_EX // BATCH_SIZE

N_EX, N_BATCH

(60000, 600)

In [None]:
# пара нейросетей: генератор (создаёт "подделки") и дискриминатор (распознаёт подделки)

# каждый шаг обучения (обработка одного пакета) состоит из следующих этапов:

# 1. генератор создаёт пакет "подделок" из случайного шума

# 2. к подделкам добавляется пакет "настоящих" образцов из обучающей выборки,
#    и дискриминатор обучается отличать настоящие (метка 1) от фальшивых (метка 0)

# 3. собираем конвейер генератор-дискриминатор,
#    фиксируем веса дискриминатора (т.е. дискриминатор в конвейере не обучаем)
#    и обучаем генератор "обманывать" дискриминатор (целевая метка 1 для подделок)

In [17]:
%%time

for e in range(N_EPOCH):
    print('epoch: ', e+1,'/',N_EPOCH)
    
    for b in range(N_BATCH):
        print('\tbatch %d/%d '%(b+1,N_BATCH))

        noise = np.random.uniform(-1, 1, (BATCH_SIZE,100))
                
        X_batch = X_train[ b*BATCH_SIZE: (b+1)*BATCH_SIZE ]
        
        X_batch = X_batch.reshape(X_batch.shape[0], 
                                  X_batch.shape[2], 
                                  X_batch.shape[3],
                                  X_batch.shape[1])
        
        generated_images = generator.predict(noise, verbose=0)
        
        if b % 200 == 0: combine_images(generated_images,e,b)
            
        X = np.concatenate((X_batch, generated_images))
        y = [1] * BATCH_SIZE + [0] * BATCH_SIZE
        
        d_loss = discriminator.train_on_batch(X, y)
        print('\t\td_loss : %f'%(d_loss))
        
        noise = np.random.uniform(-1, 1, (BATCH_SIZE,100))
            
        discriminator.trainable = False
        g_loss = discriminator_on_generator.train_on_batch( noise, [1]*BATCH_SIZE)
        
        discriminator.trainable = True
        print('\t\tg_loss : %f'%(g_loss))
        
    

epoch:  1 / 2
	batch 0/600 :
		d_loss : 0.636213
		g_loss : 0.611590
	batch 1/600 :
		d_loss : 0.630389
		g_loss : 0.606158
	batch 2/600 :
		d_loss : 0.622371
		g_loss : 0.605547
	batch 3/600 :
		d_loss : 0.608237
		g_loss : 0.595105
	batch 4/600 :
		d_loss : 0.602331
		g_loss : 0.589922
	batch 5/600 :
		d_loss : 0.587552
		g_loss : 0.585451
	batch 6/600 :
		d_loss : 0.575067
		g_loss : 0.576894
	batch 7/600 :
		d_loss : 0.559080
		g_loss : 0.569569
	batch 8/600 :
		d_loss : 0.548740
		g_loss : 0.566607
	batch 9/600 :
		d_loss : 0.539093
		g_loss : 0.564356
	batch 10/600 :
		d_loss : 0.529142
		g_loss : 0.557946
	batch 11/600 :
		d_loss : 0.519090
		g_loss : 0.549618
	batch 12/600 :
		d_loss : 0.505874
		g_loss : 0.543344
	batch 13/600 :
		d_loss : 0.501682
		g_loss : 0.535263
	batch 14/600 :
		d_loss : 0.490986
		g_loss : 0.535388
	batch 15/600 :
		d_loss : 0.483176
		g_loss : 0.533013
	batch 16/600 :
		d_loss : 0.474605
		g_loss : 0.532395
	batch 17/600 :
		d_loss : 0.467965
		g_loss

		g_loss : 0.876663
	batch 146/600 :
		d_loss : 0.269768
		g_loss : 0.859038
	batch 147/600 :
		d_loss : 0.287984
		g_loss : 0.849095
	batch 148/600 :
		d_loss : 0.283117
		g_loss : 0.843749
	batch 149/600 :
		d_loss : 0.250156
		g_loss : 0.874851
	batch 150/600 :
		d_loss : 0.258476
		g_loss : 0.871568
	batch 151/600 :
		d_loss : 0.287803
		g_loss : 0.848797
	batch 152/600 :
		d_loss : 0.263549
		g_loss : 0.866828
	batch 153/600 :
		d_loss : 0.264666
		g_loss : 0.902513
	batch 154/600 :
		d_loss : 0.294396
		g_loss : 0.889120
	batch 155/600 :
		d_loss : 0.291606
		g_loss : 0.881883
	batch 156/600 :
		d_loss : 0.259963
		g_loss : 0.886060
	batch 157/600 :
		d_loss : 0.269073
		g_loss : 0.951225
	batch 158/600 :
		d_loss : 0.281228
		g_loss : 0.940088
	batch 159/600 :
		d_loss : 0.259789
		g_loss : 0.931023
	batch 160/600 :
		d_loss : 0.237209
		g_loss : 0.967233
	batch 161/600 :
		d_loss : 0.244771
		g_loss : 1.012649
	batch 162/600 :
		d_loss : 0.250420
		g_loss : 0.983948
	batch 163/

		g_loss : 1.082245
	batch 290/600 :
		d_loss : 0.615820
		g_loss : 1.027936
	batch 291/600 :
		d_loss : 0.601648
		g_loss : 0.995055
	batch 292/600 :
		d_loss : 0.604752
		g_loss : 0.935809
	batch 293/600 :
		d_loss : 0.490806
		g_loss : 0.972350
	batch 294/600 :
		d_loss : 0.460302
		g_loss : 0.903096
	batch 295/600 :
		d_loss : 0.523202
		g_loss : 0.989310
	batch 296/600 :
		d_loss : 0.579042
		g_loss : 0.985777
	batch 297/600 :
		d_loss : 0.564293
		g_loss : 1.007486
	batch 298/600 :
		d_loss : 0.614795
		g_loss : 1.075486
	batch 299/600 :
		d_loss : 0.645301
		g_loss : 0.961097
	batch 300/600 :
		d_loss : 0.588656
		g_loss : 0.934539
	batch 301/600 :
		d_loss : 0.591972
		g_loss : 0.891299
	batch 302/600 :
		d_loss : 0.559203
		g_loss : 0.927131
	batch 303/600 :
		d_loss : 0.545016
		g_loss : 0.949591
	batch 304/600 :
		d_loss : 0.505462
		g_loss : 1.050755
	batch 305/600 :
		d_loss : 0.535378
		g_loss : 1.135633
	batch 306/600 :
		d_loss : 0.626749
		g_loss : 1.056668
	batch 307/

		g_loss : 1.065430
	batch 434/600 :
		d_loss : 0.486592
		g_loss : 1.080622
	batch 435/600 :
		d_loss : 0.479098
		g_loss : 1.082066
	batch 436/600 :
		d_loss : 0.531412
		g_loss : 1.096329
	batch 437/600 :
		d_loss : 0.497933
		g_loss : 1.082321
	batch 438/600 :
		d_loss : 0.505504
		g_loss : 1.090189
	batch 439/600 :
		d_loss : 0.492974
		g_loss : 1.083596
	batch 440/600 :
		d_loss : 0.517305
		g_loss : 1.091376
	batch 441/600 :
		d_loss : 0.503565
		g_loss : 1.097980
	batch 442/600 :
		d_loss : 0.513594
		g_loss : 1.032978
	batch 443/600 :
		d_loss : 0.521153
		g_loss : 1.006919
	batch 444/600 :
		d_loss : 0.519764
		g_loss : 0.980504
	batch 445/600 :
		d_loss : 0.525912
		g_loss : 1.030356
	batch 446/600 :
		d_loss : 0.496765
		g_loss : 1.076075
	batch 447/600 :
		d_loss : 0.513397
		g_loss : 1.048709
	batch 448/600 :
		d_loss : 0.539676
		g_loss : 1.062925
	batch 449/600 :
		d_loss : 0.524139
		g_loss : 1.021136
	batch 450/600 :
		d_loss : 0.534990
		g_loss : 1.036230
	batch 451/

		g_loss : 1.144554
	batch 578/600 :
		d_loss : 0.465078
		g_loss : 1.074068
	batch 579/600 :
		d_loss : 0.526373
		g_loss : 1.061939
	batch 580/600 :
		d_loss : 0.469136
		g_loss : 1.084676
	batch 581/600 :
		d_loss : 0.538561
		g_loss : 1.087216
	batch 582/600 :
		d_loss : 0.488354
		g_loss : 1.144050
	batch 583/600 :
		d_loss : 0.539775
		g_loss : 1.165471
	batch 584/600 :
		d_loss : 0.489643
		g_loss : 1.160810
	batch 585/600 :
		d_loss : 0.493427
		g_loss : 1.107819
	batch 586/600 :
		d_loss : 0.383450
		g_loss : 1.181093
	batch 587/600 :
		d_loss : 0.483376
		g_loss : 1.145735
	batch 588/600 :
		d_loss : 0.482141
		g_loss : 1.087341
	batch 589/600 :
		d_loss : 0.433959
		g_loss : 1.188792
	batch 590/600 :
		d_loss : 0.433694
		g_loss : 1.132528
	batch 591/600 :
		d_loss : 0.432462
		g_loss : 1.192999
	batch 592/600 :
		d_loss : 0.453009
		g_loss : 1.236513
	batch 593/600 :
		d_loss : 0.466961
		g_loss : 1.365376
	batch 594/600 :
		d_loss : 0.481783
		g_loss : 1.370348
	batch 595/

		d_loss : 0.526856
		g_loss : 1.053181
	batch 124/600 :
		d_loss : 0.452376
		g_loss : 1.069252
	batch 125/600 :
		d_loss : 0.442334
		g_loss : 1.131252
	batch 126/600 :
		d_loss : 0.443236
		g_loss : 1.181431
	batch 127/600 :
		d_loss : 0.486197
		g_loss : 1.085723
	batch 128/600 :
		d_loss : 0.457543
		g_loss : 1.110626
	batch 129/600 :
		d_loss : 0.519760
		g_loss : 1.098306
	batch 130/600 :
		d_loss : 0.463135
		g_loss : 1.168026
	batch 131/600 :
		d_loss : 0.490911
		g_loss : 1.218385
	batch 132/600 :
		d_loss : 0.527339
		g_loss : 1.132286
	batch 133/600 :
		d_loss : 0.474795
		g_loss : 1.133690
	batch 134/600 :
		d_loss : 0.422113
		g_loss : 1.151781
	batch 135/600 :
		d_loss : 0.420762
		g_loss : 1.198435
	batch 136/600 :
		d_loss : 0.544212
		g_loss : 1.081521
	batch 137/600 :
		d_loss : 0.495414
		g_loss : 0.984389
	batch 138/600 :
		d_loss : 0.477504
		g_loss : 1.078502
	batch 139/600 :
		d_loss : 0.567524
		g_loss : 0.921495
	batch 140/600 :
		d_loss : 0.538425
		g_loss : 

		d_loss : 0.445513
		g_loss : 1.294396
	batch 268/600 :
		d_loss : 0.386118
		g_loss : 1.322204
	batch 269/600 :
		d_loss : 0.354003
		g_loss : 1.373997
	batch 270/600 :
		d_loss : 0.388254
		g_loss : 1.257684
	batch 271/600 :
		d_loss : 0.442104
		g_loss : 1.212788
	batch 272/600 :
		d_loss : 0.467484
		g_loss : 1.120745
	batch 273/600 :
		d_loss : 0.411409
		g_loss : 1.135812
	batch 274/600 :
		d_loss : 0.428525
		g_loss : 1.186591
	batch 275/600 :
		d_loss : 0.410652
		g_loss : 1.170102
	batch 276/600 :
		d_loss : 0.363837
		g_loss : 1.347790
	batch 277/600 :
		d_loss : 0.430638
		g_loss : 1.221706
	batch 278/600 :
		d_loss : 0.400316
		g_loss : 1.227871
	batch 279/600 :
		d_loss : 0.401672
		g_loss : 1.332464
	batch 280/600 :
		d_loss : 0.388152
		g_loss : 1.351259
	batch 281/600 :
		d_loss : 0.444038
		g_loss : 1.302296
	batch 282/600 :
		d_loss : 0.462713
		g_loss : 1.180591
	batch 283/600 :
		d_loss : 0.458939
		g_loss : 1.137012
	batch 284/600 :
		d_loss : 0.459324
		g_loss : 

		d_loss : 0.363347
		g_loss : 1.301967
	batch 412/600 :
		d_loss : 0.399096
		g_loss : 1.258415
	batch 413/600 :
		d_loss : 0.439028
		g_loss : 1.061147
	batch 414/600 :
		d_loss : 0.440559
		g_loss : 1.122073
	batch 415/600 :
		d_loss : 0.447579
		g_loss : 1.284515
	batch 416/600 :
		d_loss : 0.452113
		g_loss : 1.204184
	batch 417/600 :
		d_loss : 0.460048
		g_loss : 1.048798
	batch 418/600 :
		d_loss : 0.431599
		g_loss : 1.148374
	batch 419/600 :
		d_loss : 0.420311
		g_loss : 1.391090
	batch 420/600 :
		d_loss : 0.378346
		g_loss : 1.441105
	batch 421/600 :
		d_loss : 0.387172
		g_loss : 1.314939
	batch 422/600 :
		d_loss : 0.397500
		g_loss : 1.342092
	batch 423/600 :
		d_loss : 0.448075
		g_loss : 1.055198
	batch 424/600 :
		d_loss : 0.447621
		g_loss : 1.057822
	batch 425/600 :
		d_loss : 0.431164
		g_loss : 1.281144
	batch 426/600 :
		d_loss : 0.421348
		g_loss : 1.437783
	batch 427/600 :
		d_loss : 0.402552
		g_loss : 1.374926
	batch 428/600 :
		d_loss : 0.326520
		g_loss : 

		d_loss : 0.431627
		g_loss : 1.364025
	batch 556/600 :
		d_loss : 0.401744
		g_loss : 1.561125
	batch 557/600 :
		d_loss : 0.348475
		g_loss : 1.463173
	batch 558/600 :
		d_loss : 0.317716
		g_loss : 1.687151
	batch 559/600 :
		d_loss : 0.355458
		g_loss : 1.594731
	batch 560/600 :
		d_loss : 0.370540
		g_loss : 1.293244
	batch 561/600 :
		d_loss : 0.359726
		g_loss : 1.677788
	batch 562/600 :
		d_loss : 0.406396
		g_loss : 1.386723
	batch 563/600 :
		d_loss : 0.363872
		g_loss : 1.373227
	batch 564/600 :
		d_loss : 0.394011
		g_loss : 1.384315
	batch 565/600 :
		d_loss : 0.423221
		g_loss : 1.215156
	batch 566/600 :
		d_loss : 0.409411
		g_loss : 1.447092
	batch 567/600 :
		d_loss : 0.404160
		g_loss : 1.689867
	batch 568/600 :
		d_loss : 0.383920
		g_loss : 1.482435
	batch 569/600 :
		d_loss : 0.349533
		g_loss : 1.409469
	batch 570/600 :
		d_loss : 0.381724
		g_loss : 1.177940
	batch 571/600 :
		d_loss : 0.389124
		g_loss : 1.405726
	batch 572/600 :
		d_loss : 0.331354
		g_loss : 

In [None]:
# Persist both networks' weights under result/, replacing any existing files.
generator.save_weights('result/generator', overwrite=True)
discriminator.save_weights('result/discriminator', overwrite=True)

In [None]:
# noise = np.zeros( (BATCH_SIZE, 100) )
# for i in range(BATCH_SIZE): noise[i, :] = np.random.uniform(-1, 1, 100)

---

In [None]:
# def generate(BATCH_SIZE, nice=False):
#     generator = generator_model()
#     generator.compile(loss='binary_crossentropy', optimizer='SGD')
#     generator.load_weights('generator')
#     if nice:
#         discriminator = discriminator_model()
#         discriminator.compile(loss='binary_crossentropy', optimizer='SGD')
#         discriminator.load_weights('discriminator')
#         noise = np.zeros((BATCH_SIZE * 20, 100))
                
#         for i in range(BATCH_SIZE * 20): noise[i, :] = np.random.uniform(-1, 1, 100)
        
#         generated_images = generator.predict(noise, verbose=1)
#         d_pret = discriminator.predict(generated_images, verbose=1)
        
#         index = np.arrange(0, BATCH_SIZE * 20)
#         index.resize((BATCH_SIZE * 20, 1))
#         pre_with_index = list(np.append(d_pret, index, axis=1))
#         pre_with_index.sort(key=lambda x: x[0], reverse=True)
#         nice_images = np.zeros((BATCH_SIZE, 1) + (generated_images.shape[2:]), dtype=np.float32)
        
#         for i in range(int(BATCH_SIZE)):
#             idx = int(pre_with_index[i][1])
#             nice_images[i, 0, :, :] = generated_images[idx, 0, :, :]
            
#         image = combine_images(nice_images)
        
#     else:
#         noise = np.zeros((BATCH_SIZE, 100))
#         for i in range(BATCH_SIZE): noise[i, :] = np.random.uniform(-1, 1, 100)
#         generated_images = generator.predict(noise, verbose=1)
#         image = combine_images(generated_images)
        
        
#     image = image * 127.5 + 127.5
#     Image.fromarray(image.astype(np.uint8)).save('geneted_image.png')

In [None]:
# Epoch time calculate 296.2202084370001

---

In [None]:
# plot_model(generator,to_file='model-generator.png',show_shapes=True, show_layer_names=True)
# IPython.display.Image('model-generator.png')

In [None]:
# plot_model( discriminator, to_file='model-discriminator.png', show_shapes=True, show_layer_names=True )
# IPython.display.Image('model-discriminator.png')

In [None]:
# plot_model( discriminator_on_generator, to_file='model-discriminator_on_generator.png', 
#                             show_shapes=True, show_layer_names=True )
# IPython.display.Image('model-discriminator_on_generator.png')

In [None]:
# discriminator = discriminator_model()
# generator = generator_model()

# discriminator_on_generator = generator_containing_discriminator(generator, discriminator)

# d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
# g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
# generator.compile(loss='binary_crossentropy', optimizer="SGD")

# discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)
# discriminator.trainable = True
# discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

In [None]:
# from IPython.display import SVG
# from keras.utils.vis_utils import model_to_dot
# SVG(model_to_dot(generator).create(prog='dot', format='svg'))