In [3]:
from operator import itemgetter
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# Training configuration for the baseline MNIST classifier.
SEED = 42
BATCH_SIZE = 32
SHUFFLE_BUFFER = 5000
EPOCHS = 10

# Seeds Python's `random`, NumPy and TensorFlow in one call so shuffling,
# weight initialisation and therefore the training curve are repeatable.
tf.keras.utils.set_random_seed(SEED)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Add a trailing channel axis and scale pixels from [0, 255] to [0, 1].
# Cast to float32: dividing uint8 by 255.0 yields float64, which doubles the
# memory held by `.cache()` while the model's weights are float32 anyway.
x_train = (x_train[..., np.newaxis] / 255.0).astype(np.float32)
x_test = (x_test[..., np.newaxis] / 255.0).astype(np.float32)

AUTOTUNE = tf.data.AUTOTUNE
mnist_train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .cache()                       # keep examples in memory after the first pass
    .shuffle(SHUFFLE_BUFFER)       # after cache(), so each epoch gets a new order
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)  # overlap input preparation with training
)

# Baseline: a single linear layer on the flattened pixels.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),  # one raw logit per digit class (no softmax)
])

model.compile(
    optimizer='adam',
    # The Dense layer emits logits, hence from_logits=True.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

history = model.fit(mnist_train_ds, epochs=EPOCHS)


def plot_loss(history, ax=None, ylim=(0, 1)):
  """Plot the per-epoch loss recorded by ``Model.fit``.

  Args:
    history: object whose ``history`` attribute is a dict containing a
      ``'loss'`` sequence (e.g. the ``History`` returned by ``Model.fit``).
      If it also contains ``'val_loss'``, that curve is drawn too.
    ax: matplotlib ``Axes`` to draw on. Defaults to the current axes
      (``plt.gca()``), matching the previous pyplot-state behaviour.
    ylim: y-axis limits. Pass ``None`` to let matplotlib autoscale, which is
      useful when the loss starts above 1.

  Returns:
    The ``Axes`` that was drawn on.
  """
  if ax is None:
    ax = plt.gca()
  metrics = history.history
  ax.plot(metrics['loss'], label='loss')
  # Only present when fit() was given validation data.
  if 'val_loss' in metrics:
    ax.plot(metrics['val_loss'], label='val_loss')
  if ylim is not None:
    ax.set_ylim(ylim)
  ax.set_xlabel('Epoch')
  ax.set_ylabel('loss')
  ax.legend()
  ax.grid(True)
  return ax


plot_loss(history)
plt.show()  # flush the loss curve as its own figure before the digit grid

# Predict on the full test set. Capping with `steps` would only yield
# steps * 32 rows, and any sample index beyond that would raise IndexError.
predict_ds = tf.data.Dataset.from_tensor_slices(x_test).batch(32)
result = model.predict(predict_ds)  # raw logits: the Dense(10) has no softmax
print(result.shape)

# Test-set indices to visualise. Listed once so the image shown and the
# prediction used for its title always refer to the same sample.
q = [1, 5, 6, 75, 100]
assert max(q) < len(result), 'sample index outside the predicted range'

# One row of thumbnails. figsize is (width, height), so width grows with len(q).
# squeeze=False keeps `axs` 2-D even when q has a single entry.
fig, axs = plt.subplots(1, len(q), figsize=(3 * len(q), 3), squeeze=False)
for ax, idx in zip(axs.flat, q):
  ax.imshow(x_test[idx].reshape(28, 28), cmap='gray')
  pred_y = np.argmax(result[idx])  # argmax of logits == argmax of softmax
  ax.set_title(f'pred {pred_y} / true {y_test[idx]}')
  ax.axis('off')


Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_2 (Dense)             (None, 12544)             1254400   
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 12544)             0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 256)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 128)      819200    
 nspose)                                                         
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 14, 14, 128)       0         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 64)       204800    
 ranspose)                                            

In [4]:
from tensorflow.keras import layers

def make_discriminator_model(input_shape=(28, 28, 1)):
    """Build the GAN discriminator as an uncompiled ``tf.keras.Sequential``.

    Architecture: two 4x4, stride-2, 'same'-padded ``Conv2D`` + ``LeakyReLU``
    stages (64 then 128 filters, so a 28x28 input becomes 14x14 then 7x7),
    flattened into a single ``Dense`` unit with a sigmoid activation.

    Args:
        input_shape: shape of one input image, ``(height, width, channels)``.
            Defaults to ``(28, 28, 1)``, the value previously hard-coded.

    Returns:
        A ``tf.keras.Sequential`` producing one score in (0, 1) per image.
        NOTE(review): presumably interpreted as P(image is real); confirm
        against the real/fake labels used in the GAN loss.
    """
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same',
                            input_shape=input_shape))
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    # BatchNormalization after each conv was tried and led to mode collapse
    # (author's observation); Dropout(0.3) is likewise disabled. Both are
    # intentionally omitted from this discriminator.

    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))
    return model


D = make_discriminator_model()
D.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_3 (Conv2D)           (None, 14, 14, 64)        1088      
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 conv2d_4 (Conv2D)           (None, 7, 7, 128)         131200    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 7, 7, 128)         0         
                                                                 
 flatten_1 (Flatten)         (None, 6272)              0         
                                                                 
 dense_3 (Dense)             (None, 1)                 6273      
                                                                 
Total params: 138,561
Trainable params: 138,561
Non-tr