In [2]:
from tensorflow.python.keras.layers.merge import Concatenate
from tensorflow.python.keras.layers.core import Lambda
from tensorflow.keras.models import Model

import tensorflow as tf

def make_parallel(model, gpu_count):
    """Replicate ``model`` across ``gpu_count`` GPUs using data parallelism.

    Each input batch is split along axis 0 into ``gpu_count`` contiguous
    slices. Slice ``i`` is fed to ``model`` under ``tf.device('/gpu:i')``.
    The same ``model`` object is called for every tower, so all towers share
    its weights. The per-tower outputs are concatenated back along axis 0
    under ``tf.device('/cpu:0')``. The result therefore has one row per input
    row, in the original order.

    NOTE(review): in TF 2.x, device scopes used while *building* a functional
    model may not be honored at execution time; ``tf.distribute.MirroredStrategy``
    is the supported multi-GPU mechanism — confirm placement if it matters.

    Args:
        model: Keras ``Model`` to replicate.
        gpu_count: number of towers (GPUs) to split each batch across; must be >= 1.

    Returns:
        A new Keras ``Model`` with the same inputs as ``model``. Its outputs
        are the per-tower outputs merged along the batch axis.

    Raises:
        ValueError: if ``gpu_count`` is less than 1.
    """
    if gpu_count < 1:
        raise ValueError('gpu_count must be >= 1, got %r' % (gpu_count,))

    def get_slice(data, idx, parts):
        """Return slice ``idx`` of ``parts`` taken along axis 0 of ``data``.

        Every slice has ``batch // parts`` rows except the last one. The last
        slice also takes the ``batch % parts`` leftover rows, so no sample is
        dropped when the batch size is not divisible by ``parts``.
        """
        shape = tf.shape(data)
        batch = shape[:1]            # 1-element tensor holding the batch size
        step = batch // parts
        start_row = step * idx
        # idx/parts are Python ints (passed via the Lambda ``arguments``), so
        # this branch is decided when the graph is built, not per batch.
        rows = batch - start_row if idx == parts - 1 else step
        start = tf.concat([start_row, tf.zeros_like(shape[1:])], axis=0)
        size = tf.concat([rows, shape[1:]], axis=0)
        return tf.slice(data, start, size)

    # One list per model output, collecting that output from every tower.
    outputs_all = [[] for _ in model.outputs]

    # Place a copy of the model on each GPU, each getting a slice of the batch.
    for i in range(gpu_count):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                # Slice each input into the piece processed on this GPU.
                inputs = []
                for x in model.inputs:
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_n = Lambda(get_slice, output_shape=input_shape,
                                     arguments={'idx': i, 'parts': gpu_count})(x)
                    inputs.append(slice_n)

                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]

                # Save all the outputs for merging back together later.
                for out_idx, output in enumerate(outputs):
                    outputs_all[out_idx].append(output)

    # Merge outputs on CPU. Concatenate needs at least two inputs, so a
    # single tower's output is used as-is.
    with tf.device('/cpu:0'):
        merged = [
            outputs[0] if len(outputs) == 1 else Concatenate(axis=0)(outputs)
            for outputs in outputs_all
        ]
        return Model(inputs=model.inputs, outputs=merged)

In [3]:
import os
import numpy as np
import tensorflow.keras.optimizers
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from utils_dir import utils
import tensorflow.keras.backend as K
import tensorflow.keras.layers as KL
import tensorflow.python.keras.engine as KE
import tensorflow.keras.models as KM
GPU_COUNT = 2
BATCH_SIZE = 64        # samples per training step, split across GPU_COUNT towers
STEPS_PER_EPOCH = 50
EPOCHS = 10
NUM_CLASSES = 10       # MNIST digits 0-9

# Root directory of the project
ROOT_DIR = os.path.abspath("../")

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Single-device ("template") CNN for 28x28x1 MNIST images.
input_shape = (28, 28, 1)
img_input = KL.Input(shape=input_shape, name="input_image")
x = KL.Conv2D(32, (3, 3), padding="same", name="conv1")(img_input)
x = KL.Activation('relu')(x)
x = KL.Conv2D(64, (3, 3), padding="same", name="conv2")(x)
x = KL.Activation('relu')(x)
x = KL.MaxPooling2D(pool_size=(2, 2), name="pool1")(x)
x = KL.Flatten(name="flat1")(x)
x = KL.Dense(128, activation='relu', name="dense1")(x)
x = KL.Dense(NUM_CLASSES, activation='softmax', name="dense2")(x)

model = KM.Model(inputs=img_input, outputs=x)

# Load MNIST data: add a trailing channel axis and scale pixels to [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype('float32') / 255
x_test = np.expand_dims(x_test, -1).astype('float32') / 255

print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

datagen = ImageDataGenerator()

# Keep a handle on the single-device model. The parallel wrapper calls this
# same model object, so it shares its weights. Weights can therefore be
# saved from `template_model` without the slicing/merging layers.
template_model = model
model = make_parallel(template_model, GPU_COUNT)

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=tensorflow.keras.optimizers.Adadelta(), metrics=['accuracy'])

model.summary()

# Train. `fit_generator` is deprecated; `fit` accepts the generator directly.
model.fit(
    datagen.flow(x_train, y_train, batch_size=BATCH_SIZE),
    steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS, verbose=1,
    validation_data=(x_test, y_test))

x_train shape: (60000, 28, 28, 1)
x_test shape: (10000, 28, 28, 1)
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_image (InputLayer)        [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 28, 28, 1)    0           input_image[0][0]                
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 28, 28, 1)    0           input_image[0][0]                
__________________________________________________________________________________________________
model (Model)                   (None, 10)           1625866     lambda[0][0]                     
                         

<tensorflow.python.keras.callbacks.History at 0x7f69a47f8550>