In [171]:
from keras.layers import Input, Dense, Conv2D, Add, Dot, Lambda, Conv2DTranspose, Dot, Activation, Reshape, BatchNormalization, UpSampling2D, AveragePooling2D, GlobalAveragePooling2D, Multiply, LeakyReLU, Flatten, MaxPool2D 
from keras.models import Model
import keras.backend as K

In [210]:
# Author: Jose Sepulveda
# Description: This is a keras implementation of spectral normalization.
#              This was proposed in this paper: https://arxiv.org/pdf/1802.05957.pdf
#


from keras import backend as K


# Stochastic Gradient Descent with Spectral Normalization:
#   1) Initialize a random vector u, initialized from an isotropic distribution.
#   2) Use the Power iteration method with this vector u on the matrix of wieghts
#      to obtain two approximations of eigenvectors.
#   3) Calculate the spectral norm of the wieghts matrix.
#   4) Update wieghts using vanilla SGD using the spectral norm of the wieghts matrix.

def spectral_norm(w):
    """
        Input: tensor of wieghts
        Output: SN tensor of wieghts
    """
    def l2_norm(v):
        return K.sum(v ** 2) ** 0.5

    w_dim = w.shape.as_list()[-1]
    # Initialize random vector u
    u = K.random_normal(shape=[1, w_dim])

    # We need to flatten the wieghts
    w_flat = K.reshape(w, [-1, w_dim])

    # Power iteration method
    v = K.dot(u, K.transpose(w_flat))
    v = v / l2_norm(v)
    u = K.dot(v, w_flat)
    u = u / l2_norm(u)

    # Calculate the SN of W
    sigma = K.dot(K.dot(v, w_flat), K.transpose(u))
    w_sn = w_flat / sigma

    # Update wieghts
    w_sn = K.reshape(w_sn, w.shape.as_list())
    return w_sn


In [211]:
def ResBlockDown(input_shape, channel_size, channel_multiplier=1, name=None):
    # Resblock architecture
    # 1 BatchNorm 
    # 2 ReLU activation
    # 3 Conv layer
    # 4 BatchNorm
    # 5 ReLU activation
    # 6 Conv layer
    # 7 Sum with input 
    
    #FIRST BLOCK
    #input layer
    input_layer = Input(shape=input_shape)
    
    # BatchNorm - needs to be conditional
    resblock = BatchNormalization()(input_layer)
    
    # Relu
    resblock = Activation('relu')(resblock)
    
    # Convolution size 3 filter as per paper
    # Need to spectrally normalize here somehow
    resblock = Conv2D(channel_size * channel_multiplier, 3, padding='same', strides=2, kernel_regularizer=spectral_norm)(resblock)
    #SECOND BLOCK
    
    # BatchNorm - needs to be conditional
    resblock = BatchNormalization()(resblock)
    # Relu
    resblock = Activation('relu')(resblock)
    # Convolution size 3 filter as per paper
    # Need to spectrally normalize here somehow
    resblock = Conv2D(channel_size * channel_multiplier, 3, padding='same', kernel_regularizer=spectral_norm)(resblock)
    # Downsample
    #resblock = AveragePooling2D()(resblock)
    
    # Time for the shortcut connection!
    
    shortcut_identity = Conv2D(channel_size * channel_multiplier, 3, padding='same', strides=2, kernel_regularizer=spectral_norm)(input_layer)
    #shortcut_identity = AveragePooling2D()(shortcut_identity)
    
    output_layer = Add()([shortcut_identity, resblock])
    
    return Model(input_layer, output_layer, name=name)
    
    
    
    
    

In [245]:
def ResBlockUp(input_shape, channel_size, channel_multiplier=1, name=None):
    # Resblock architecture
    # 1 BatchNorm 
    # 2 ReLU activation
    # 3 Conv layer
    # 4 BatchNorm
    # 5 ReLU activation
    # 6 Conv layer
    # 7 Sum with input 
    
    #FIRST BLOCK
    #input layer
    input_layer = Input(shape=input_shape)
    
    # BatchNorm - needs to be conditional
    resblock = BatchNormalization()(input_layer)
    
    # Relu
    resblock = Activation('relu')(resblock)
    
    # Convolution size 3 filter as per paper
    # Need to spectrally normalize here somehow
    resblock = Conv2DTranspose(channel_size * channel_multiplier, 3, padding='same', strides=2, kernel_regularizer=spectral_norm)(resblock)
    
    #SECOND BLOCK
    
    # BatchNorm - needs to be conditional
    resblock = BatchNormalization()(resblock)
    
    # Relu
    resblock = Activation('relu')(resblock)
    
    # Convolution size 3 filter as per paper
    # Need to spectrally normalize here somehow
    resblock = Conv2DTranspose(channel_size * channel_multiplier, 3, padding='same', kernel_regularizer=spectral_norm)(resblock)
    
    # Downsample
    #resblock = AveragePooling2D()(resblock)
    
    # Time for the shortcut connection!
    
    shortcut_identity = Conv2DTranspose(channel_size * channel_multiplier, 1, padding='same', strides=2, kernel_regularizer=spectral_norm)(input_layer)
    #shortcut_identity = AveragePooling2D()(shortcut_identity)
    output_layer = Add()([shortcut_identity, resblock])
    
    return Model(input_layer, output_layer, name=name)
    
    

In [213]:
def ResBlock(input_shape, channel_size, channel_multiplier=1, name=None):
     # Resblock architecture
    # 1 BatchNorm 
    # 2 ReLU activation
    # 3 Conv layer
    # 4 BatchNorm
    # 5 ReLU activation
    # 6 Conv layer
    # 7 Sum with input 
    
    #FIRST BLOCK
    #input layer
    input_layer = Input(shape=input_shape)
    
    # BatchNorm - needs to be conditional
    resblock = BatchNormalization()(input_layer)
    
    # Relu
    resblock = Activation('relu')(resblock)
    
    # Convolution size 3 filter as per paper
    # Need to spectrally normalize here somehow
    resblock = Conv2D(channel_size * channel_multiplier, 3, padding='same', kernel_regularizer=spectral_norm)(resblock)
    
    #SECOND BLOCK
    
    # BatchNorm - needs to be conditional
    resblock = BatchNormalization()(input_layer)
    
    # Relu
    resblock = Activation('relu')(resblock)
    
    # Convolution size 3 filter as per paper
    # Need to spectrally normalize here somehow
    resblock = Conv2D(channel_size * channel_multiplier, 3, padding='same', kernel_regularizer=spectral_norm)(resblock)
    
    
    # Time for the shortcut connection!
    
    shortcut_identity = Conv2D(channel_size * channel_multiplier, 1, padding='same', kernel_regularizer=spectral_norm)(input_layer)
    
    output_layer = Add()([shortcut_identity, resblock])
    
    return Model(input_layer, output_layer, name=name)
    
    
    
    
    
    

In [214]:


def SelfAttentionBlock(input_shape, name=None):
    # f = conv
    channels = input_shape[-1]
    input_layer = Input(shape=input_shape)
    f = Conv2D(channels // 8, 1, padding='same')(input_layer)
    # f = maxpooling
    f = MaxPool2D(pool_size=2, strides=2, padding='same')(f)
    
    g = Conv2D(channels // 8, 1, padding='same')(input_layer)
    
    h = Conv2D(channels // 2, 1, padding='same')(input_layer)
    h = MaxPool2D(pool_size=2, strides=2, padding='same')(h)
    
    
    g = Lambda(lambda input1: K.reshape(input1, shape=[-1,1]))(g)
    f = Lambda(lambda input1: K.reshape(input1, shape=[-1,1]))(f)
    s = Dot(-1)([g, f])
    beta = Activation('softmax')(s)

    h = Lambda(lambda input1: K.reshape(input1, shape=[-1,1]))(h)
    o = Dot(-1)([beta, h])
    
    gamma = Conv2D(channels, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal')(input_layer)
    #gamma = Reshape((-1, channels // 2))(g)
    #print(gamma.shape)
    a, x, y ,z = input_layer.shape
    #o = K.reshape(o, shape=[x,y,z, channels // 2])
    o = Lambda(lambda input1: K.reshape(input1, shape=[x,y,z, channels // 2]))(o)
    o = Conv2D(channels, kernel_size=1, strides=1)(o)
  
    Wz_yi = Multiply()([gamma, o])
    output_layer = Add()([Wz_yi, input_layer])
    
    return Model(input_layer, output_layer, name=name)
    
    # g = conv
    # h = conv
    # h = maxpooling

In [218]:

def GlobalSumPooling2D(name=None):
    return Lambda(lambda inputs: K.sum(inputs, axis=[1, 2]), name=name)
        
# Discriminator test
def build_discriminator(channel_multiplier=64):
    model_input = Input(shape=(128,128,3), name="D_input")
    resblockdown1 = ResBlockDown(input_shape=(128,128,3),channel_size=1, channel_multiplier=channel_multiplier, name='D_resblock_down_1')
    h = resblockdown1(model_input)
    selfattentionblock = SelfAttentionBlock(input_shape=(64,64,64), name='D_self_attention_block')
    h = selfattentionblock(h)
    # Non local block should be here
    resblockdown2 = ResBlockDown(input_shape=(64,64,64),channel_size=2, channel_multiplier=channel_multiplier, name='D_resblock_down_2')
    h = resblockdown2(h)
    resblockdown4 = ResBlockDown(input_shape=(32,32,128),channel_size=4, channel_multiplier=channel_multiplier, name='D_resblock_down_4')
    h = resblockdown4(h)
    resblockdown8 = ResBlockDown(input_shape=(16,16,256),channel_size=8, channel_multiplier=channel_multiplier, name='D_resblock_down_8')
    h = resblockdown8(h)
    resblockdown16 = ResBlockDown(input_shape=(8,8,512),channel_size=16, channel_multiplier=channel_multiplier, name='D_resblock_down_16')
    h = resblockdown16(h)
    resblock16 = ResBlock(input_shape=(4,4,1024),channel_size=16, channel_multiplier=channel_multiplier, name='D_resblock_16')
    h = resblock16(h)
    h = Activation('relu', name="D_relu")(h)
    h = GlobalSumPooling2D(name="D_global_sum_pooling_2D")(h)
    model_output = Dense(1, name="D_dense")(h)
    model = Model(model_input, model_output, name="Discriminator")
    return model
build_discriminator().summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
D_input (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
D_resblock_down_1 (Model)    (None, 64, 64, 64)        40780     
_________________________________________________________________
D_self_attention_block (Mode (64, 64, 64, 64)          9328      
_________________________________________________________________
D_resblock_down_2 (Model)    multiple                  296064    
_________________________________________________________________
D_resblock_down_4 (Model)    multiple                  1181952   
_________________________________________________________________
D_resblock_down_8 (Model)    multiple                  4723200   
_________________________________________________________________
D_resblock_down_16 (Model)   multiple                  18883584  
__________

In [254]:
# only noise, non conditional
def build_generator(channel_multiplier=64):
    model_input = Input(shape=(128,), name="G_input")
    h = Dense(4*4*16*channel_multiplier, name="G_dense")(model_input)
    h = Reshape((4,4,16*channel_multiplier))(h)
    resblockup16 = ResBlockUp(input_shape=(4,4,1024), channel_size=16, channel_multiplier=channel_multiplier, name="G_resblock_up_16")
    h = resblockup16(h)
    resblockup8 = ResBlockUp(input_shape=(8,8,1024), channel_size=8, channel_multiplier=channel_multiplier, name="G_resblock_up_8")
    h = resblockup8(h)
    resblockup4 = ResBlockUp(input_shape=(16,16,512), channel_size=4, channel_multiplier=channel_multiplier, name="G_resblock_up_4")
    h = resblockup4(h)
    resblockup2 = ResBlockUp(input_shape=(32,32,256), channel_size=2, channel_multiplier=channel_multiplier, name="G_resblock_up_2")
    h = resblockup2(h)
#     need to debug the following
#     selfattentionblock = SelfAttentionBlock(input_shape=(64,64,128), name='G_self_attention_block')
#     h = selfattentionblock(h)
    resblockup1 = ResBlockUp(input_shape=(64,64,128), channel_size=1, channel_multiplier=channel_multiplier, name="G_resblock_up_1")
    h = resblockup1(h)
    h = BatchNormalization()(h)
    h = Activation('relu')(h)
    model_output = model_output
    
    return Model(model_input, model_output, name="Generator")


In [255]:
build_generator().summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
G_input (InputLayer)         (None, 128)               0         
_________________________________________________________________
G_dense (Dense)              (None, 16384)             2113536   
_________________________________________________________________
reshape_18 (Reshape)         (None, 4, 4, 1024)        0         
_________________________________________________________________
G_resblock_up_16 (Model)     (None, 8, 8, 1024)        19934208  
_________________________________________________________________
G_resblock_up_8 (Model)      (None, 16, 16, 512)       7609856   
_________________________________________________________________
G_resblock_up_4 (Model)      (None, 32, 32, 256)       1904384   
_________________________________________________________________
G_resblock_up_2 (Model)      (None, 64, 64, 128)       477056    
__________