In [1]:
import tensorflow as tf
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers, Input, Model
import time

from IPython import display

In [2]:
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    ``AttrDict(a=1).a == 1`` and ``d.b = 2`` also sets ``d['b']``, because the
    instance ``__dict__`` is bound to the mapping itself.
    """

    def __init__(self, *args, **kwargs):
        # Forward to dict.__init__ so AttrDict(mapping) / AttrDict(key=value)
        # actually populate the mapping.
        super().__init__(*args, **kwargs)
        # Attribute namespace and item namespace are the same object.
        self.__dict__ = self

    def print_(self):
        """Print every option whose value is truthy as an aligned two-column table."""
        print(f"{'=' * 80}\n{'Opts'.center(80)}\n{'-' * 80}")
        for key, value in self.items():
            if value:
                # Convert to str first: objects such as lists do not accept a
                # width/alignment format spec and would raise TypeError.
                print('{:>30}: {:<30}'.format(str(key), str(value)).center(80))
        print('=' * 80)

In [3]:
# Options for the model definitions below, accessible as attributes
# (e.g. ``args.nfc``). print_() lists only the entries with truthy values.
args = AttrDict()
args_dict = {
    # Input geometry: models build Input(shape=(image_size_1, image_size_2, image_channel)).
    'image_size_1': 28,
    'image_size_2': 28,
    'image_channel': 1,
    # Conv stack: first conv has max(nfc, min_nfc) filters, halving per body
    # layer (floored at min_nfc); num_layer counts head + body + tail convs.
    'nfc': 32,
    'min_nfc': 5,
    'kernel_size': 3,
    'num_layer': 5,
    # Training / logging options (not referenced by the model classes below).
    # Each key must appear only once: a duplicate silently keeps the last value.
    'batch_size': 32,
    'g_conv_dim': 32,
    'g_noise_size': 100,
    'd_conv_dim': 64,
    'train_iters': 10000,
    'opt_lr': 0.0003,
    'opt_beta1': 0.5,
    'opt_beta2': 0.999,
    'sample_dir': 'samples_gan',
    'log_step': 200,
    'sample_every': 200,
}

args.update(args_dict)
args.print_()

                                      Opts                                      
--------------------------------------------------------------------------------
                           image_size_1: 28                                     
                           image_size_2: 28                                     
                          image_channel: 1                                      
                                    nfc: 32                                     
                                min_nfc: 5                                      
                            kernel_size: 3                                      
                              num_layer: 5                                      
                             batch_size: 32                                     
                             g_conv_dim: 32                                     
                           g_noise_size: 100                                    
                            

In [18]:
class generator(tf.keras.Model):
    """Fully-convolutional generator that adds its output to a second image ``y``.

    Architecture (every conv: stride 1, "same" padding, N(0, 0.02) kernel init):

    * head: Conv(max(nfc, min_nfc)) -> BatchNorm -> LeakyReLU(0.2)
    * body: ``num_layer - 2`` blocks of
      Conv(max(int(nfc / 2**(i+1)), min_nfc)) -> BatchNorm -> LeakyReLU(0.2)
    * tail: Conv(image_channel) with tanh activation

    Args:
        args: option mapping with attribute access; reads ``nfc``, ``min_nfc``,
            ``kernel_size``, ``num_layer``, ``image_channel``, ``image_size_1``
            and ``image_size_2``.
    """

    @staticmethod
    def _conv(filters, kernel_size, **kwargs):
        """Build a stride-1, "same"-padded Conv2D with a fresh N(0, 0.02) kernel initializer."""
        return layers.Conv2D(
            filters=filters,
            padding="same",
            kernel_size=kernel_size,
            strides=1,
            kernel_initializer=tf.random_normal_initializer(mean=0.0, stddev=0.02, seed=None),
            **kwargs)

    def __init__(self, args):
        super(generator, self).__init__()
        # Remember the (H, W, C) input shape so model()/summary() do not read a
        # notebook-global `args`.
        self._image_shape = (args.image_size_1, args.image_size_2, args.image_channel)

        self.conv_first = self._conv(max(args.nfc, args.min_nfc), args.kernel_size)
        self.batch_first = layers.BatchNormalization()
        self.relu_first = layers.LeakyReLU(alpha=0.2)

        # Filter count halves at each body block, never dropping below min_nfc.
        self.body_layers = []
        for i in range(args.num_layer - 2):
            n_filters = max(int(args.nfc / pow(2, i + 1)), args.min_nfc)
            self.body_layers.append(self._conv(n_filters, args.kernel_size))
            self.body_layers.append(layers.BatchNormalization())
            self.body_layers.append(layers.LeakyReLU(alpha=0.2))

        self.end_layer = self._conv(args.image_channel, args.kernel_size, activation='tanh')

    def call(self, x, y):
        """Run the conv stack on ``x`` and add ``y`` (centre-cropped to match).

        Inputs use Keras' channels-last layout ``(batch, H, W, C)``, as built by
        ``model()``; ``y`` must have ``image_channel`` channels.
        """
        x = self.conv_first(x)
        x = self.batch_first(x)
        x = self.relu_first(x)
        for layer in self.body_layers:
            x = layer(x)
        x = self.end_layer(x)
        # Centre-crop y over the spatial axes (1 = H, 2 = W), never the channel
        # axis. With "same" padding throughout the crop is normally zero.
        crop_h = (y.shape[1] - x.shape[1]) // 2
        crop_w = (y.shape[2] - x.shape[2]) // 2
        y = y[:, crop_h:y.shape[1] - crop_h, crop_w:y.shape[2] - crop_w, :]
        return x + y

    def model(self):
        """Wrap ``call`` in a functional Model over explicit Inputs.

        This lets ``summary()`` list per-layer output shapes for this subclassed
        model (workaround referenced in
        https://github.com/tensorflow/tensorflow/issues/25036).
        """
        x = Input(shape=self._image_shape)
        y = Input(shape=self._image_shape)
        return Model(inputs=[x, y], outputs=self.call(x, y))

    def summary(self, *summary_args, **summary_kwargs):
        """Print the functional-model summary; arguments go to ``Model.summary``."""
        self.model().summary(*summary_args, **summary_kwargs)

In [21]:
G = generator(args)

In [23]:
G.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 28, 28, 32)   320         input_5[0][0]                    
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 28, 28, 32)   128         conv2d_20[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU)      (None, 28, 28, 32)   0           batch_normalization_16[0][0]     
____________________________________________________________________________________________

In [59]:
class discriminator(tf.keras.Model):
    """Convolutional discriminator producing one linear score per image.

    Architecture (every conv: stride 1, "same" padding, N(0, 0.02) kernel init):

    * head: Conv(max(nfc, min_nfc)) -> BatchNorm -> LeakyReLU(0.2)
    * body: ``num_layer - 2`` blocks of
      Conv(max(int(nfc / 2**(i+1)), min_nfc)) -> BatchNorm -> LeakyReLU(0.2)
    * tail: Conv(1, tanh) -> Flatten -> Dense(1) (no activation)

    Because of Flatten + Dense, the Dense weights are sized for
    ``image_size_1 x image_size_2`` inputs.

    Args:
        args: option mapping with attribute access; reads ``nfc``, ``min_nfc``,
            ``kernel_size``, ``num_layer``, ``image_channel``, ``image_size_1``
            and ``image_size_2``.
    """

    @staticmethod
    def _conv(filters, kernel_size, **kwargs):
        """Build a stride-1, "same"-padded Conv2D with a fresh N(0, 0.02) kernel initializer."""
        return layers.Conv2D(
            filters=filters,
            padding="same",
            kernel_size=kernel_size,
            strides=1,
            kernel_initializer=tf.random_normal_initializer(mean=0.0, stddev=0.02, seed=None),
            **kwargs)

    def __init__(self, args):
        super(discriminator, self).__init__()
        # Remember the (H, W, C) input shape so model()/summary() do not read a
        # notebook-global `args`.
        self._image_shape = (args.image_size_1, args.image_size_2, args.image_channel)

        self.conv_first = self._conv(max(args.nfc, args.min_nfc), args.kernel_size,
                                     input_shape=self._image_shape,
                                     bias_initializer='zeros')
        self.batch_first = layers.BatchNormalization()
        self.relu_first = layers.LeakyReLU(alpha=0.2)

        # Filter count halves at each body block, never dropping below min_nfc.
        self.body_layers = []
        for i in range(args.num_layer - 2):
            n_filters = max(int(args.nfc / pow(2, i + 1)), args.min_nfc)
            self.body_layers.append(self._conv(n_filters, args.kernel_size))
            self.body_layers.append(layers.BatchNormalization())
            self.body_layers.append(layers.LeakyReLU(alpha=0.2))

        # NOTE(review): tanh bounds the features fed to the Dense head; many
        # GAN critics use no activation here — confirm this is intentional.
        self.conv_last = self._conv(1, args.kernel_size, activation='tanh')
        self.output_1 = layers.Flatten()
        self.output_2 = layers.Dense(1)

    def call(self, x):
        """Score a channels-last ``(batch, H, W, C)`` batch; returns shape ``(batch, 1)``."""
        x = self.conv_first(x)
        x = self.batch_first(x)
        x = self.relu_first(x)
        for layer in self.body_layers:
            x = layer(x)
        x = self.conv_last(x)
        x = self.output_1(x)
        x = self.output_2(x)
        return x

    def model(self):
        """Wrap ``call`` in a functional Model over an explicit Input.

        This lets ``summary()`` list per-layer output shapes for this subclassed
        model (workaround referenced in
        https://github.com/tensorflow/tensorflow/issues/25036).
        """
        x = Input(shape=self._image_shape)
        return Model(inputs=[x], outputs=self.call(x))

    def summary(self, *summary_args, **summary_kwargs):
        """Print the functional-model summary; arguments go to ``Model.summary``."""
        self.model().summary(*summary_args, **summary_kwargs)

In [60]:
D = discriminator(args)

In [46]:
D.summary()

Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_7 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 28, 28, 32)        320       
_________________________________________________________________
batch_normalization_28 (Batc (None, 28, 28, 32)        128       
_________________________________________________________________
leaky_re_lu_28 (LeakyReLU)   (None, 28, 28, 32)        0         
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 28, 28, 16)        4624      
_________________________________________________________________
batch_normalization_29 (Batc (None, 28, 28, 16)        64        
_________________________________________________________________
leaky_re_lu_29 (LeakyReLU)   (None, 28, 28, 16)        0   

In [47]:
def init_models(args):
    """Build a fresh discriminator/generator pair from the option mapping ``args``.

    Returns:
        tuple: ``(netD, netG)`` — the discriminator followed by the generator.
    """
    # Generator is built first, as before, so auto-generated layer names keep
    # their order.
    netG = generator(args)
    netD = discriminator(args)

    # Both networks draw conv kernels from an unseeded N(0, 0.02) initializer,
    # so every call yields new but similarly-scaled weights.
    return netD, netG