In [None]:
import os, math
from PIL import Image, ImageDraw, ImageFont
from tensorflow.keras.datasets import cifar10, mnist
import matplotlib.pyplot as plt
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, BatchNormalization 
from tensorflow.keras.layers import Conv2D, Flatten, Activation, Concatenate
from tensorflow.keras.layers import Reshape, Conv2DTranspose, LeakyReLU
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import mse, SparseCategoricalCrossentropy

In [None]:
# Seed every random number generator this notebook draws from so that a
# Restart & Run All reproduces the same results. Seeding TensorFlow alone is
# not enough: the generator noise further down comes from np.random.
SEED = 42
random.seed(SEED)         # Python standard-library RNG
np.random.seed(SEED)      # NumPy global RNG (used by np.random.uniform)
tf.random.set_seed(SEED)  # TensorFlow global seed

In [None]:
# Load the MNIST digits through Keras. The test labels are discarded; the
# cells shown here only use the training images and labels.
(x_train, y_train), (x_test, _) = mnist.load_data()

In [None]:
# Add a trailing channel axis and divide the pixel values by 255 in one step.
# NOTE: this cell is not idempotent - re-running it divides by 255 again.
x_train = x_train.reshape(-1, 28, 28, 1) / 255

In [None]:
# Sanity check: after the reshape the array should be (num_samples, 28, 28, 1).
x_train.shape

(60000, 28, 28, 1)

In [None]:
# Side length of one image, read from the second axis of a single sample.
image_size = np.shape(x_train[0])[1]
image_size

28

In [None]:
class ConvTransBlock(tf.keras.layers.Layer):
    """BatchNormalization -> ReLU -> Conv2DTranspose building block.

    Args:
        filters: number of output filters of the transposed convolution.
        kernel_size: kernel size of the transposed convolution.
        strides: stride of the transposed convolution ('same' padding is used).
    """

    def __init__(self, filters, kernel_size, strides):
        super().__init__()
        self.bn = BatchNormalization()
        self.act = Activation(activation='relu')
        self.conv2D_trans = Conv2DTranspose(filters=filters,
                                            kernel_size=kernel_size,
                                            strides=strides,
                                            padding='same')

    def call(self, inputs, training=False):
        """Run the block on ``inputs``.

        Args:
            inputs: input tensor.
            training: forwarded to BatchNormalization so the layer is told
                explicitly whether it runs in training or inference mode,
                instead of the flag being accepted and then ignored.

        Returns:
            The output of the transposed convolution.
        """
        x = self.bn(inputs, training=training)
        x = self.act(x)
        return self.conv2D_trans(x)


In [None]:
class Generator(tf.keras.models.Model):
    """Image generator: Dense -> Reshape -> ConvTransBlock stack -> sigmoid.

    Args:
        filters: sequence of filter counts, one ConvTransBlock per entry. The
            first entry also sets the channel count of the reshaped Dense
            output.
        kernel_size: kernel size shared by every ConvTransBlock.
        resize_img: side length of the square feature map produced by the
            Dense + Reshape stem.
    """

    def __init__(self, filters, kernel_size, resize_img):
        super().__init__()
        self.dense1 = Dense(resize_img * resize_img * filters[0])
        self.reshape = Reshape([resize_img, resize_img, filters[0]])
        # The first two blocks use stride 2 (4x upsampling overall); every
        # later block uses stride 1 and keeps the spatial size.
        self.conv2dtrans = [
            ConvTransBlock(_filter, kernel_size, 2 if i <= 1 else 1)
            for i, _filter in enumerate(filters)
        ]
        self.act = Activation("sigmoid")

    def call(self, inputs, training=False):
        """Map ``inputs`` to an image tensor squashed by a sigmoid.

        Args:
            inputs: input tensor fed to the first Dense layer.
            training: forwarded to every ConvTransBlock (previously accepted
                but never passed on).

        Returns:
            The sigmoid-activated output of the last ConvTransBlock.
        """
        x = self.dense1(inputs)
        x = self.reshape(x)
        for conv in self.conv2dtrans:
            x = conv(x, training=training)
        return self.act(x)

    def model(self, input_shape=(1, 6272)):
        """Wrap the generator in a functional Model (e.g. for ``summary()``).

        Args:
            input_shape: per-sample input shape, excluding the batch axis.
                Defaults to the previously hard-coded ``(1, 6272)``; pass the
                shape of the latent vector to mirror how the generator is
                actually fed.

        Returns:
            A ``Model`` from a symbolic ``Input`` to ``self.call`` of it.
        """
        x = Input(shape=input_shape)
        return Model(inputs=x, outputs=self.call(x))

In [None]:
def leaky_conv(filters, kernel_size, strides):
    """Return a Sequential of LeakyReLU(alpha=0.2) then a 'same'-padded Conv2D.

    Args:
        filters: number of convolution filters.
        kernel_size: convolution kernel size.
        strides: convolution stride.
    """
    activation = LeakyReLU(alpha=0.2)
    convolution = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')
    return Sequential([activation, convolution])

In [None]:
class Discriminator(tf.keras.models.Model):
    """Label-conditioned discriminator.

    The label input is projected by a Dense layer to 28*28 values, reshaped to
    a (28, 28, 1) map and concatenated with the image on the channel axis. The
    result goes through a stack of ``leaky_conv`` stages, is flattened, and is
    scored by a single sigmoid unit.

    Args:
        filters: sequence of filter counts, one ``leaky_conv`` stage per
            entry. Every stage uses stride 2 except the last, which uses 1.
        kernel_size: kernel size shared by all stages.
    """

    def __init__(self, filters, kernel_size):
        super().__init__()
        self.leaky_convs = []
        for i, _filter in enumerate(filters):
            if i < len(filters) - 1:
                strides = 2
            else:
                strides = 1
            self.leaky_convs.append(leaky_conv(_filter, kernel_size, strides))
        self.flat = Flatten()
        self.dense = Dense(1, activation='sigmoid')
        self.input_dense = Dense(28*28)
        self.reshaped = Reshape((28,28,1))
        # Join image and label map along the channel (last) axis.
        self.contact = Concatenate(axis=-1)

    def call(self, inputs, training=False):
        """Score an ``(img, label)`` pair.

        Args:
            inputs: pair ``(img, label)``. Assumes ``img`` is
                (batch, 28, 28, 1) so it can be concatenated with the
                reshaped label map - TODO confirm against callers.
            training: accepted for Keras API compatibility; not used here.

        Returns:
            Output of the final sigmoid Dense unit.
        """
        img, label = inputs
        label = self.reshaped(self.input_dense(label))
        # The Concatenate layer is called on the list of tensors; the axis is
        # fixed at construction. The old `self.contact(axis=2)(...)` invoked
        # the layer without inputs and raised before any tensor was joined.
        x = self.contact([img, label])  # per-sample shape (28, 28, 2)
        for conv in self.leaky_convs:
            x = conv(x)
        x = self.flat(x)
        return self.dense(x)
          

In [None]:
# (Interactive `Reshape?` docstring lookup removed: it is IPython-only syntax
# left over from debugging and is not valid when the notebook runs as Python.)

In [None]:
# Hyperparameters
latent_size = 2              # width of each generator noise vector
batch_size = 64
train_steps = 40000
disc_lr = 2e-4               # discriminator learning rate
disc_decay = 6e-8            # discriminator optimizer decay
gen_lr = disc_lr / 4         # adversarial model: a quarter of the discriminator's rate
gen_decay = disc_decay / 4   # adversarial model: a quarter of the discriminator's decay

In [None]:
# Discriminator architecture settings.
disc_layers_filters = [32, 64, 128]  # one leaky_conv stage per entry
disc_kernel_size = 5

In [None]:
# Generator architecture settings.
gen_layers_filter = [128, 64, 32, 1]  # one ConvTransBlock per entry
gen_kernel_size = 3
# Starting feature-map side: the generator's two stride-2 blocks upsample it
# by 4x overall.
gen_resize_img = image_size // 4

In [None]:
# Stand-alone generator instance.
# NOTE(review): Gan constructs its own Generator internally, so this instance
# does not appear to be the one trained through `gan` - confirm it is needed.
gen = Generator(filters=gen_layers_filter,
                kernel_size=gen_kernel_size,
                resize_img=gen_resize_img)

In [None]:
# Build the discriminator and compile it with binary cross-entropy, RMSprop
# and an accuracy metric.
disc = Discriminator(filters=disc_layers_filters, kernel_size=disc_kernel_size)
disc.compile(
    optimizer=RMSprop(learning_rate=disc_lr, decay=disc_decay),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [None]:
class Gan(tf.keras.models.Model):
    """Adversarial stack: an internal Generator feeding the module-level ``disc``."""

    def __init__(self):
        super().__init__()
        # The generator settings are read from the notebook-level globals.
        self.gen = Generator(gen_layers_filter, gen_kernel_size, gen_resize_img)

    def call(self, inputs):
        """Generate images from noise and score them with the discriminator.

        Args:
            inputs: 3-tuple ``(noise, real_img, label)``. ``real_img`` is
                unpacked but not used, and only ``noise`` is fed to the
                generator (the label goes to the discriminator alone).

        Returns:
            Output of ``disc`` on ``(generated image, label)``.
        """
        noise, _real_img, label = inputs
        generated = self.gen(noise)
        return disc((generated, label))
        

# Instantiate the adversarial model and compile it with the generator-side
# learning rate and decay.
gan = Gan()
gan.compile(
    optimizer=RMSprop(learning_rate=gen_lr, decay=gen_decay),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)



        

# def build_gan():
#     gen_input = Input(shape=(latent_size + 10,))   
#     # fake_img = Generator(gen_layers_filter, gen_kernel_size, gen_resize_img)(gen_input)
#     # disc_input = Input((10,))
#     # output = disc((fake_img, disc_input))
#     gan = Model(inputs=gen_input, outputs=output)  
#     gan.summary()
#     gan.compile(loss='binary_crossentropy',
#                 optimizer=RMSprop(learning_rate=gen_lr, decay=gen_decay),
#                 metrics=['accuracy'])            
#     return gan   


# gan = build_gan()

In [None]:
# One uniform noise vector of width latent_size per training image.
noise = np.random.uniform(low=-1, high=1, size=(len(x_train), latent_size))
noise.shape

(60000, 2)

In [None]:
# One-hot encode the digit labels; the depth is the largest label plus one.
one_hot = tf.one_hot(indices=y_train, depth=y_train.max() + 1)

In [None]:
# All-zero target vector with one entry per training sample.
fake_labels = np.zeros(len(y_train))

In [None]:
# Sanity check: all four arrays must agree on the leading (sample) dimension.
noise.shape, one_hot.shape, fake_labels.shape, x_train.shape

((60000, 2), TensorShape([60000, 10]), (60000,), (60000, 28, 28, 1))

In [None]:
# Gan.call unpacks its input as (noise, real_img, label), so the tuple has to
# be passed in that order. Passing x_train first fed the images into the
# generator in place of the noise vectors.
# NOTE(review): batch_size=1 is kept as written even though `batch_size` is
# defined with the hyperparameters - confirm which value is intended.
gan.fit((noise, x_train, one_hot), fake_labels, batch_size=1)

In [None]:
# Push one probe sample through the layers tracked by `gan` and print each
# output shape. The generator consumes latent vectors, so the probe must be
# latent_size wide; 6272 (= 7 * 7 * 128) is the width of the generator's first
# Dense *output*, not of its input.
X = tf.random.uniform((1, latent_size))
for layer in gan.layers:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)

ValueError: ignored