In [1]:
from keras import layers
from keras.models import Model, Sequential
from keras.utils import plot_model

Using TensorFlow backend.


In [2]:
def res_block(y, nb_channels, _strides=(1, 1), _project_shortcut=False):
    """Apply a two-convolution residual block to the tensor ``y``.

    Main path: 3x3 conv (with ``_strides``) -> batch norm -> LeakyReLU ->
    3x3 conv (stride 1) -> batch norm. The shortcut is added to the main
    path and a final LeakyReLU is applied to the sum.

    Args:
        y: input tensor fed to both the main path and the shortcut.
        nb_channels: number of filters for every convolution in the block.
        _strides: strides of the first main-path convolution and of the
            shortcut projection.
        _project_shortcut: when True, the shortcut goes through a 1x1
            convolution + batch norm even if ``_strides`` is ``(1, 1)``.

    Returns:
        The output tensor of the block.
    """
    shortcut = y

    y = layers.Conv2D(nb_channels, kernel_size=(3, 3), strides=_strides, padding='same')(y)
    y = layers.BatchNormalization()(y)
    y = layers.LeakyReLU()(y)

    y = layers.Conv2D(nb_channels, kernel_size=(3, 3), strides=(1, 1), padding='same')(y)
    # The normalisation layer must be applied to the tensor; calling it with
    # no argument (as the previous code did) fails before the block is built.
    y = layers.BatchNormalization()(y)

    # Project the shortcut whenever asked to, or whenever the main path was
    # strided, so both branches are produced with the same strides and filters.
    # NOTE(review): if the input channel count differs from nb_channels the
    # caller presumably has to pass _project_shortcut=True - confirm.
    if _project_shortcut or _strides != (1, 1):
        shortcut = layers.Conv2D(nb_channels, kernel_size=(1, 1), strides=_strides, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    y = layers.add([shortcut, y])
    y = layers.LeakyReLU()(y)

    return y

In [3]:
def conv_net(x, nb_channels, _strides=(1, 1)):
    """Stack two 3x3 ReLU convolutions (32 then 64 filters) on ``x``.

    Args:
        x: input tensor.
        nb_channels: accepted for signature compatibility; never read here,
            the filter counts are fixed at 32 and 64.
        _strides: strides used by both convolutions.

    Returns:
        The tensor produced by the second convolution.
    """
    for width in (32, 64):
        conv = layers.Conv2D(
            width,
            kernel_size=(3, 3),
            strides=_strides,
            padding='same',
            activation='relu',
        )
        x = conv(x)
    return x

In [4]:
def post_net(y, nb_channels, _strides=(1, 1)):
    """Map ``y`` through three 3x3 convolutions: 64 -> 32 -> 3 filters.

    The first two convolutions use ReLU; the last one is linear.

    Args:
        y: input tensor.
        nb_channels: accepted for signature compatibility; never read here,
            the filter counts are fixed.
        _strides: strides used by all three convolutions.

    Returns:
        The tensor produced by the final 3-filter convolution.
    """
    stages = ((64, 'relu'), (32, 'relu'), (3, 'linear'))
    for width, act in stages:
        conv = layers.Conv2D(
            width,
            kernel_size=(3, 3),
            strides=_strides,
            padding='same',
            activation=act,
        )
        y = conv(y)
    return y

In [5]:
# Training / test data: uniform random arrays of shape (100, 64, 64, 3)
import numpy as np
# Arrays are drawn in the order x1, x2, y - first for train, then for test.
x1_train, x2_train, y_train = (
    np.random.random(size=(100, 64, 64, 3)) for _ in range(3)
)
x1_test, x2_test, y_test = (
    np.random.random(size=(100, 64, 64, 3)) for _ in range(3)
)

def make_trainable(net, val):
    """Set the ``trainable`` flag on ``net`` and on each entry of ``net.layers``.

    Args:
        net: object exposing a ``trainable`` attribute and a ``layers`` iterable.
        val: value assigned to every ``trainable`` attribute.
    """
    # Flag the container first, then every layer it holds.
    net.trainable = val
    for sub_layer in net.layers:
        setattr(sub_layer, 'trainable', val)

In [6]:
# Generator: two 64x64x3 image inputs, each passed through its own conv_net
# call; the two feature maps are concatenated and post_net reduces the result
# to a 3-channel output.
img_a = layers.Input(shape=(64, 64, 3))
img_b = layers.Input(shape=(64, 64, 3))
feature_a = conv_net(img_a, 3)
feature_b = conv_net(img_b, 3)
merge = layers.concatenate([feature_a, feature_b])
aif = post_net(merge, 128)
# The keyword is `outputs` (plural); the singular `output` spelling is a
# deprecated form.
gen = Model(inputs=[img_a, img_b], outputs=[aif])
# NOTE(review): post_net ends in a linear activation, yet the loss is
# binary_crossentropy with an accuracy metric - confirm this pairing is
# intended for an image-valued output.
gen.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
gen.summary()
plot_model(gen, to_file='generator.png')

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 64, 64, 32)   896         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 32)   896         input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (

  import sys


In [7]:
# Symbolic generator output for the two image inputs.
image_fake = gen([img_a, img_b])

# Discriminator: two 3x3 ReLU convolutions, flatten, then a single sigmoid unit.
dis = Sequential([
    layers.Conv2D(32, kernel_size=(3, 3), padding='same', activation='relu'),
    layers.Conv2D(64, kernel_size=(3, 3), padding='same', activation='relu'),
    layers.Flatten(),
    layers.Dense(1),
    layers.Activation('sigmoid'),
])

# Calling the discriminator on the generator output gives the tensor that the
# stacked model uses as its output.
pred_prob = dis(image_fake)

dis.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
dis.summary()
plot_model(dis, to_file='discriminator.png')

# Mark the discriminator and all of its layers as non-trainable.
make_trainable(dis, False)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_8 (Conv2D)            (None, 64, 64, 32)        896       
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 64, 64, 64)        18496     
_________________________________________________________________
flatten_1 (Flatten)          (None, 262144)            0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 262145    
_________________________________________________________________
activation_1 (Activation)    (None, 1)                 0         
Total params: 281,537
Trainable params: 281,537
Non-trainable params: 0
_________________________________________________________________


In [8]:
# Stacked (adversarial) model: image inputs -> generator -> discriminator.
# The keyword is `outputs` (plural); the singular `output` spelling is a
# deprecated form.
am = Model(inputs=[img_a, img_b], outputs=[pred_prob])
am.summary()
am.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
plot_model(am, to_file='adversary.png')

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
model_1 (Model)                 (None, 64, 64, 3)    131907      input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 1)            281537      model_1[1][0]                    
Total para

  """Entry point for launching an IPython kernel.


In [9]:
# Fit the generator on the paired inputs, then generate images from those
# same inputs.
gen.fit(x=[x1_train, x2_train], y=y_train)
img_fake = gen.predict(x=[x1_train, x2_train])

Epoch 1/1


In [10]:
# Train the discriminator on the target images (label 1) followed by the
# generated images (label 0). Label counts are taken from the arrays so they
# stay aligned with X if the number of samples changes.
X = np.concatenate((y_train, img_fake))
y = np.concatenate((np.ones(len(y_train)), np.zeros(len(img_fake))))

# NOTE(review): dis and am are already compiled at this point; Keras documents
# that changes to `trainable` are picked up when compile() is called, so
# confirm these toggles have the intended effect.
make_trainable(dis, True)
dis.fit(X, y)

# Train the generator through the stacked model, with the discriminator
# flagged non-trainable, by asking it to label generated images as 1.
y2 = np.ones(len(x1_train))
make_trainable(dis, False)
am.fit([x1_train, x2_train], y2)

Epoch 1/1
Epoch 1/1


<keras.callbacks.History at 0x7f35b551d208>