In [8]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_dim=100, activation='relu'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(1024, activation='relu'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(28 * 28 * 1, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
    return model

def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Flatten(input_shape=(28,28,1)))
    model.add(layers.Dense(512,  activation='relu'))
    model.add(layers.Dense(256,  activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

def build_gan(generator, discriminator):
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Nurodo, kad discrimanator modelis nesimokins. Tai padeda isvengti perteklinio mokymo
    discriminator.trainable = False
    gan_input = tf.keras.Input(shape=(100,))
    img = generator(gan_input)
    gan_output = discriminator(img)
    gan = tf.keras.Model(gan_input, gan_output)
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan

generator = build_generator()
discriminator =  build_discriminator()
gan = build_gan(generator, discriminator)

(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.astype('float32') - 127.5)  / 127.5
x_train = np.expand_dims(x_train, axis=3)

def train(epochs, batch_size=128, save_interval=50):
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # atsitiktinai pasirenkam indeksus is mokymu duomenu rinkinio
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        imgs = x_train[idx]

        noise = np.random.normal(0,1, (batch_size, 100))
        gen_imgs = generator.predict(noise)

        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        g_loss  = gan.train_on_batch(noise, valid)
        if epoch % save_interval == 0:
            print(f'epocha {epoch} | D loss: {d_loss[0]} | G loss: {g_loss} | accuracy ? ')
            save_img(epoch)

def save_img(epoch):
    #  rows and columns
    r,c = 5,5
    noise = np.random.normal(0,1, (r * c, 100))
    gen_imgs = generator.predict(noise)
    # atstatome  intensyvumo intervala i [0,1]
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(r,c)
    counter = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[counter, :, :, 0], cmap='gray')
            axs[i,j].axis('off')
            counter +=1
    fig.savefig(f'gan_images/epocha_{epoch}.png')
    plt.close()


train(epochs=30000, batch_size=64, save_interval=1000)

[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
epocha 0 | D loss: 0.6926432847976685 | G loss: [array(0.6938501, dtype=float32), array(0.6938501, dtype=float32), array(0.4375, dtype=float32)] | accuracy ? 
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 90ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0

KeyboardInterrupt: 