Does Keras support using multiple GPUs? #2436

Closed
henry0312 opened this issue Apr 21, 2016 · 162 comments

@henry0312
Contributor

Theano has supported multiple GPUs since v0.8.0.
(cf. Using multiple GPUs — Theano 0.8.0 documentation)
Does Keras also support using multiple GPUs?

For example, can I run the below task?

  1. Learn a sequential model A on gpu0
  2. Learn a sequential model B on gpu1
  3. Merge A and B on gpu0
@fchollet
Member

Yes, you can run Keras models on multiple GPUs. This is only possible with the TensorFlow backend for the time being, because the Theano feature is still rather new. We are looking at adding support for multi-GPU in Theano in the near future (it should be fairly straightforward).

With the TensorFlow backend, you can achieve this the same way as you would in pure TensorFlow: by using the with tf.device(d) scope when defining Keras layers.
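
For illustration, a minimal sketch of what that kind of device placement could look like with the Keras 2 functional API (the layer names and sizes here are made up, and merging on the CPU is just one possible choice):

import tensorflow as tf
from keras.layers import Input, Dense, concatenate
from keras.models import Model

inp = Input(shape=(128,))

# place one branch of the graph on each GPU
with tf.device('/gpu:0'):
    branch_a = Dense(64, activation='relu', name='branch_a')(inp)
with tf.device('/gpu:1'):
    branch_b = Dense(64, activation='relu', name='branch_b')(inp)

# combine the branches and the output head on the CPU
with tf.device('/cpu:0'):
    merged = concatenate([branch_a, branch_b])
    out = Dense(1, activation='sigmoid', name='head')(merged)

model = Model(inputs=inp, outputs=out)
model.compile(optimizer='rmsprop', loss='binary_crossentropy')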

@henry0312
Contributor Author

We are looking at adding support for multi-gpu in Theano in the near future (it should be fairly straightforward).

I'm looking forward to it 😃
Thank you.

@lemuriandezapada

tf.device() scope?
Can you expand on this?
I haven't seen it in the API.

@jeffzhengye
Contributor

Any example of using multiple GPUs with TF?

@phalexo

phalexo commented Apr 23, 2016

Hm. Theano has libgpuarray, which allows one to push shared variables to different devices. This will not do all the work for you of recombining weight matrices but with a little effort you could use multiple GPUs.
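
For reference, the basic pattern from the Theano documentation linked above is roughly the following, assuming two contexts have been mapped to devices via THEANO_FLAGS="contexts=dev0->cuda0;dev1->cuda1" (a sketch, not a full multi-GPU training loop):

import numpy as np
import theano
import theano.tensor as T

# push each shared variable to a specific device at creation time
w0 = theano.shared(np.random.random((1024, 1024)).astype('float32'), target='dev0')
w1 = theano.shared(np.random.random((1024, 1024)).astype('float32'), target='dev1')

# theano.function inserts the transfers needed to combine the per-device results
f = theano.function([], [T.dot(w0, w0).sum(), T.dot(w1, w1).sum()])
print(f())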

@nouiz
Contributor

nouiz commented Apr 25, 2016

There is Platoon, a project on top of Theano for data parallelism. It should be easy to use. We currently focus more on data parallelism than model parallelism in Theano, but both are possible.

Fred

@fchollet
Member

I have looked into Platoon and it seemed like it was pretty much compatible with Keras out of the box, except for a couple of lines of code. Easy to adapt, in any case...


@phalexo

phalexo commented Apr 25, 2016

The way libgpuarray works is by mapping variables to different GPUs; the compiled function then automatically generates code to transfer data between GPUs as needed.

@jtanios

jtanios commented Jun 22, 2016

I have looked into Platoon and it seemed like it was pretty much compatible
with Keras out of the box except for a couple lines of code. Easy to adapt,
in any case...

What's the priority of adding multi-GPU support for the Theano backend?

@phalexo

phalexo commented Jun 23, 2016

I think it would expand the user base for Keras. I have several Titan Xs in the same box. Please take a look at libgpuarray as well.

@tetmin

tetmin commented Jul 26, 2016

How does this actually work in TensorFlow? There is a brief tutorial here: http://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html. I understand the concept of running the model replicas on separate GPU devices and then merging the weights, but how do we actually run this? Instead of model.fit, do we call merged.fit on the result of the merged models?

@pengpaiSH

@tetmin I have the same confusion as you. Although the blog shows how to run prediction with a model on different GPUs, it is still unclear how to train the same model across different GPUs on a single machine, i.e. I need data parallelism and don't know how to implement it in Keras with TensorFlow as the backend.

@rudaoshi

Agreed with @pengpaiSH and @tetmin. Hope there will be more details.

@pengpaiSH

pengpaiSH commented Aug 23, 2016

@rudaoshi Well, I know this may not be proper to suggest since we are in the Keras community, and personally I am a big, big, big fan of Keras! We know TensorFlow can utilize multiple GPUs by averaging gradients computed across different devices; however, I am hoping Keras could provide a simple and unified API (in Keras style) that lets me focus on the big picture and hides those I/O and parallel-computing details. For the time being, in order to make good use of multiple GPUs, I am writing my deep learning programs with MXNet, where I only specify the GPU IDs and the library does everything needed under the hood.

@WenchenLi

@fchollet I saw your blog post on multi-GPU training, thanks for pointing out the way to do it. But I would really appreciate it if, say, model.fit() had a gpus=n option. I'm willing to implement my own version of that; may I ask for suggestions? Or I'm willing to contribute multi-GPU training within Keras with more of the details abstracted away from end users. Thanks in advance!

@pengpaiSH

@WenchenLi +1, gpus=0,1,2... is exactly what I need!

@acrosson

@WenchenLi did you create a PR for the multi-GPU abstraction?

@anewlearner

Hope someone can contribute multi-GPU training within Keras. Thanks in advance.

I have two GPUs. I did not do anything to set which GPU would be used for training, but when I used nvidia-smi to check memory, I found almost all of the memory on both GPUs was in use. I thought only one GPU would be used.

@acrosson

@anewlearner apparently this is the intended functionality of TF.
Use export CUDA_VISIBLE_DEVICES="0".

See tensorflow/tensorflow#5066 for details

Looking forward to a simplified version of multi-GPU :)
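
The same restriction can also be applied from inside a script, as long as the environment variable is set before TensorFlow is initialized (a minimal sketch):

import os

# expose only the first GPU to TensorFlow; must happen before TF is imported/initialized
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tensorflow as tf
from tensorflow.python.client import device_lib

print(device_lib.list_local_devices())  # should list only /gpu:0 (plus the CPU)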

@jonilaserson

For data parallelization in Keras, you can use this approach:

import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Input, merge
from keras.layers.core import Lambda


def slice_batch(x, n_gpus, part):
    sh = K.shape(x)
    L = sh[0] / n_gpus
    if part == n_gpus - 1:
        return x[part*L:]
    return x[part*L:(part+1)*L]


def to_multi_gpu(model, n_gpus=2):
    with tf.device('/cpu:0'):
        x = Input(model.input_shape[1:], name=model.input_names[0])

    towers = []
    for g in range(n_gpus):
        with tf.device('/gpu:' + str(g)):
            slice_g = Lambda(slice_batch, lambda shape: shape,
                             arguments={'n_gpus': n_gpus, 'part': g})(x)
            towers.append(model(slice_g))

        with tf.device('/cpu:0'):
            merged = merge(towers, mode='concat', concat_axis=0)

    return Model(input=[x], output=merged)

To use it, just take any model and set model = to_multi_gpu(model).

model.fit() and model.predict() should work without any change.


@jtanios

jtanios commented Oct 27, 2016

@jonilaserson , looks great! Does this work with the Theano backend or only TF?

@pengpaiSH

@jonilaserson Could you please provide more detailed comments for the code?
For example, what's the purpose of slice_g? And what do the towers actually do? Thank you!

@anewlearner

I tested the code provided by @jonilaserson and got an error:
merged = merge(towers, mode='concat', concat_axis=0)
Exception: A Merge should only be applied to a list of layers with at least 2 elements. Found: [<keras.engine.training.Model object at 0x7f9c1c3123d0>]

@pengpaiSH

@anewlearner Have you solved the problem you ran into before?

@jonilaserson

@Carol

There was an indentation error in the code I posted.
The [with tf.device('/cpu:0')] block should be outside the loop.

Here is a piece of code that should work:

import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Input, merge
from keras.layers.core import Lambda


def slice_batch(x, n_gpus, part):
    """Divide the input batch into [n_gpus] slices, and obtain slice no. [part].

    i.e. if len(x)=10, then slice_batch(x, 2, 1) will return x[5:].
    """
    sh = K.shape(x)
    L = sh[0] / n_gpus
    if part == n_gpus - 1:
        return x[part*L:]
    return x[part*L:(part+1)*L]


def to_multi_gpu(model, n_gpus=2):
    """Given a keras [model], return an equivalent model which parallelizes
    the computation over [n_gpus] GPUs.

    Each GPU gets a slice of the input batch, applies the model on that slice
    and later the outputs of the models are concatenated to a single tensor,
    hence the user sees a model that behaves the same as the original.
    """
    with tf.device('/cpu:0'):
        x = Input(model.input_shape[1:], name=model.input_names[0])

    towers = []
    for g in range(n_gpus):
        with tf.device('/gpu:' + str(g)):
            slice_g = Lambda(slice_batch, lambda shape: shape,
                             arguments={'n_gpus': n_gpus, 'part': g})(x)
            towers.append(model(slice_g))

    with tf.device('/cpu:0'):
        merged = merge(towers, mode='concat', concat_axis=0)

    return Model(input=[x], output=merged)

To use just take any model and set model = to_multi_gpu(model).

model.fit() and model.predict() should work without any change.

Example:

from keras.layers.convolutional import Convolution2D
from keras.layers.core import Activation
import numpy as np


def get_model():
    x = Input((96, 96, 1), name="input1")
    output = Convolution2D(64, 5, 5, border_mode='same', name="conv1")(x)
    output = Activation('relu', name="relu1")(output)
    # [More layers...]
    model = Model(input=x, output=output)
    model.compile(optimizer='rmsprop', loss='mse')
    return model


model = get_model()
model = to_multi_gpu(model)

x = np.random.rand(1000, 96, 96, 1)
y = model.predict(x, verbose=True)


@pengpaiSH

pengpaiSH commented Nov 1, 2016

@jonilaserson Thank you for the update! Would you please comment on this code snippet:

for g in range(n_gpus):
    with tf.device('/gpu:' + str(g)):
        slice_g = Lambda(slice_batch, lambda shape: shape,
                         arguments={'n_gpus': n_gpus, 'part': g})(x)
        towers.append(model(slice_g))

@anewlearner

anewlearner commented Nov 1, 2016

@jonilaserson
Thanks for sharing your code. It works. :)
I tested the code to compare the time cost between one GPU and two GPUs.
When I used two GPUs (the same type of GPU here), the speed was slower than I expected. Does the switching between CPU and GPU affect the speed?
My test results are as follows.

Two gpus

97650/682307 [===>..........................] - ETA: 1933s - loss: 0.3320 - acc: 0.8320
188593/682307 [=======>......................] - ETA: 1654s - loss: 0.2354 - acc: 0.8904
279093/682307 [===========>..................] - ETA: 1348s - loss: 0.1936 - acc: 0.9140

One gpu

97650/682307 [===>..........................] - ETA: 2669s - loss: 0.3488 - acc: 0.8266
188593/682307 [=======>......................] - ETA: 2239s - loss: 0.2431 - acc: 0.8880
279093/682307 [===========>..................] - ETA: 1844s - loss: 0.2004 - acc: 0.9116

@pengpaiSH

I think you should compile the model, otherwise you'll get an error: you must compile the model before training/testing.

@alsrgv
Contributor

alsrgv commented Oct 5, 2017

FYI - we just added an example of data-parallel distributed training with Keras using Horovod - https://github.com/uber/horovod/blob/master/examples/keras_mnist.py. It works both for multiple GPUs within the server, and across servers. Hope it helps.
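
For anyone skimming, the wiring in that example boils down to roughly this (a condensed sketch with a toy model; see the linked keras_mnist.py for the authoritative version, and launch it with mpirun, e.g. mpirun -np 4 python train.py):

import numpy as np
import keras
import tensorflow as tf
import horovod.keras as hvd

hvd.init()

# pin each worker process to a single GPU
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
keras.backend.set_session(tf.Session(config=config))

# a toy model and random data, just to show the wiring
model = keras.models.Sequential([keras.layers.Dense(10, activation='softmax', input_shape=(784,))])
x_train = np.random.rand(1024, 784).astype('float32')
y_train = keras.utils.to_categorical(np.random.randint(10, size=1024), 10)

# scale the learning rate by the number of workers and wrap the optimizer
opt = hvd.DistributedOptimizer(keras.optimizers.SGD(lr=0.01 * hvd.size()))
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

# broadcast initial weights from rank 0 so every worker starts identical
callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]
model.fit(x_train, y_train, batch_size=128, epochs=2,
          callbacks=callbacks, verbose=1 if hvd.rank() == 0 else 0)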

@michelleowen

I used the code from @jonilaserson and it works. However, it seems that multi-GPU training converges more slowly compared to a single GPU. Has anyone else observed the same?

@alsrgv
Contributor

alsrgv commented Oct 5, 2017

@michelleowen you typically want to adjust the learning rate to the total # of GPUs across all the servers - here's an example of very simple scaling. Facebook published a paper with a more sophisticated strategy that works for a large number of GPUs.
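
In other words, something along these lines (the base rate and GPU count below are placeholders, not recommended values):

from keras.optimizers import SGD

n_gpus = 4      # total number of GPUs across all servers (placeholder)
base_lr = 0.01  # learning rate tuned for a single GPU (placeholder)

# linear scaling rule: a larger effective batch gets a proportionally larger learning rate
opt = SGD(lr=base_lr * n_gpus, momentum=0.9)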

@michelleowen

@alsrgv, thank you. This is very helpful. I will do some experiments to see how it works in my case.

@fernandoandreotti

fernandoandreotti commented Oct 18, 2017

I guess the function previously mentioned by @avolkov1 is finally coming into Keras:
https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py

@bzamecnik
Contributor

@fernandoandreotti Yes and no. It's a cleaned-up variant of the function from kuza55. It has nice documentation and grabs the list of devices via device_lib instead of CUDA_VISIBLE_DEVICES. On the other hand, it's missing some stuff from avolkov1: slicing on the CPU, and save/load of the original serial model's parameters. Since there's no wrapper class, the latter is not strictly necessary, but it might at least be documented.

@fernandoandreotti

fernandoandreotti commented Nov 3, 2017

Keras v2.0.9 now includes it (release notes). Despite the improvements that could still be made, I guess this issue should be closed.

@pGit1

pGit1 commented Nov 9, 2017 via email

@fchollet
Member

fchollet commented Nov 9, 2017

Yes: https://keras.io/utils/#multi_gpu_model

You can also check out Horovod, which seems nice.
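
Usage is essentially what the linked docs show; a condensed sketch:

import numpy as np
import tensorflow as tf
from keras.applications import Xception
from keras.utils import multi_gpu_model

# instantiate the template model on the CPU so its weights are hosted there
with tf.device('/cpu:0'):
    model = Xception(weights=None, input_shape=(299, 299, 3), classes=10)

# replicate it on 2 GPUs; each replica gets a slice of every batch
parallel_model = multi_gpu_model(model, gpus=2)
parallel_model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

x = np.random.rand(64, 299, 299, 3)
y = np.random.rand(64, 10)
parallel_model.fit(x, y, epochs=1, batch_size=32)

# save via the template model, not the parallel one
model.save('my_model.h5')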

@fchollet fchollet closed this as completed Nov 9, 2017
@ViaFerrata

Is there any intention of making it work with CNTK too?

@nbansal90

nbansal90 commented Dec 29, 2017

@avolkov1 @jonilaserson Is there an issue with saving models using ModelCheckpoint with a multi_gpu_model? I actually used a few other callbacks and they worked fine, but ModelCheckpoint is the one which fails to save the model and throws an error after an epoch.

CODE

class MyCallBack(keras.callbacks.Callback):
    def __init__(self, callbacks, model):
        super().__init__()
        self.callback = callbacks
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        self.callback.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        self.callback.on_epoch_end(epoch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.callback.on_batch_end(batch, logs=logs)

    def on_batch_begin(self, batch, logs=None):
        self.callback.on_batch_begin(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.callback.set_model(self.model)
        self.callback.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.callback.on_train_end(logs=logs)


parallel_model = multi_gpu_model(model, gpus=2)
parallel_model.compile(loss='categorical_crossentropy',
                       optimizer=Adam(lr=lr_schedule(0)),
                       metrics=['accuracy'])

# Setting up callbacks, during fitting of the model
filename = 'model_train_new.csv'
filepath = os.path.join(save_dir, model_name)
checkpoint = ModelCheckpoint(filepath=filepath, monitor='val_acc', verbose=1,
                             save_best_only=True)
cbk3 = MyCallBack(checkpoint, model)
callbacks = [cbk3]

# Adding data augmentation provided by the Keras module
datagen = ImageDataGenerator(featurewise_center=False, samplewise_center=False,
                             featurewise_std_normalization=False,
                             samplewise_std_normalization=False,
                             zca_whitening=False, rotation_range=0,
                             width_shift_range=0.1, height_shift_range=0.1,
                             horizontal_flip=True, vertical_flip=False)

datagen.fit(x_train)
steps_per_epoch = int(np.ceil(x_train.shape[0] / float(batch_size)))
model_info = parallel_model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                                          steps_per_epoch=steps_per_epoch,
                                          validation_data=(x_test, y_test),
                                          epochs=epochs, verbose=1, workers=4,
                                          callbacks=callbacks)

@pGit1

pGit1 commented Jan 2, 2018

@nbansal90

I had this same problem. ModelCheckpoint will not work with a multi-GPU model. You can change the parameter save_weights_only to True and this will work fine; HOWEVER, if you then want to do inference on a SINGLE GPU, the model will not load the weights properly even if you load the checkpointed weights by name.

@fchollet

Kind of an urgent question: is there a way to train on multiple GPUs but save the weights in such a way that I can do inference on only a single GPU? I am not sure how to get this to work properly, as model.load_weights('/weights_path', by_name=True) does not work. I have to re-instantiate the network as a multi-GPU model to properly load the weights. I may be missing something simple, though.

@fercook

fercook commented Jan 2, 2018

mmmm since it's urgent, maybe a dirty patch will do: couldn't you save the weights as matrices and then load them directly into the weights of the layers of a new (single GPU) model?

edit: saving/loading the weights of the example from the docs doesn't work? https://keras.io/utils/#multi_gpu_model

@pGit1

pGit1 commented Jan 2, 2018

@fercook

Thanks for the quick response. I believe I have tried that. My weights were saved via the ModelCheckpoint callback for a multi-GPU model.

When I re-instantiate the model, I cannot load the weights into my single-GPU model because I get an error stating that I am trying to load weights into a model with one layer when it expects four layers (4 is the number of GPUs I was using).

edit:

edit: saving/loading the weights of the example from the docs doesn't work? https://keras.io/utils/#multi_gpu_model

That is correct. It does not work, although I haven't tried the CPU device scope. Will try and let you know. I've only used the ModelCheckpoint callback with save_weights_only=True and model.load_weights.

@fercook

fercook commented Jan 2, 2018

Did you double check that you are saving with the template model, not the multi_gpu one?

From the docs:

On model saving

To save the multi-gpu model, use `.save(fname)` or `.save_weights(fname)`
with the template model (the argument you passed to `multi_gpu_model`),
rather than the model returned by `multi_gpu_model`.

edit: sorry I just re-read that you are saving through the callback...how are you doing that? Is each GPU saving a different file (or overwriting it)?
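
One workaround along the lines discussed here is to checkpoint the template model from a small custom callback, since it shares its weights with the parallel model (a hypothetical sketch, not an existing Keras class):

from keras.callbacks import Callback


class TemplateCheckpoint(Callback):
    """Save the template (single-GPU) model's weights at the end of every epoch,
    even though fit() is being called on the multi_gpu_model replica."""

    def __init__(self, template_model, filepath):
        super(TemplateCheckpoint, self).__init__()
        self.template_model = template_model
        self.filepath = filepath

    def on_epoch_end(self, epoch, logs=None):
        # the template model shares weights with the parallel model,
        # so this writes the up-to-date trained weights
        self.template_model.save_weights(self.filepath)

Passing an instance of this to the parallel model's fit() should produce a weights file that loads cleanly into the single-GPU model.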

@avolkov1

avolkov1 commented Jan 2, 2018

@pGit1 Take a look at my example:
https://github.com/avolkov1/keras_experiments/blob/master/examples/cifar/cifar10_cnn_mgpu.py

Run like this to save weights:

python ./examples/cifar/cifar10_cnn_mgpu.py --epochs=3 --mgpu --checkpt --aug

Can then run again and it will load the checkpoint file and continue training. This will work with single GPU also.

CUDA_VISIBLE_DEVICES=0 python ./examples/cifar/cifar10_cnn_mgpu.py --epochs=3 --mgpu --checkpt --aug

I have a slightly different implementation for multi-GPU, but you can use the multi-GPU implementation from Keras. Just wrap it in a class to use the non-multi-GPU model for saving and loading weights.
https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/_multigpu.py#L129

The essence of the wrapper class for saving/loading weights is:

    def __getattribute__(self, attrname):
        '''Override load and save methods to be used from the serial-model. The
        serial-model holds references to the weights in the multi-gpu model.
        '''
        # return Model.__getattribute__(self, attrname)
        if 'load' in attrname or 'save' in attrname:
            return getattr(self._smodel, attrname)

        return super(ModelMGPU, self).__getattribute__(attrname)

This works with fit_generator.

@pGit1

pGit1 commented Jan 2, 2018

@fercook

Since the ModelCheckpoint is only saving the weights it may be overwriting it.

@avolkov1

Thank you! I'll take a look!!

@pGit1

pGit1 commented Jan 2, 2018

@fercook

I've confirmed that the example from the docs will not work with the ModelCheckpoint callback either.
FYI: my callback code -

best_wts_callback = callbacks.ModelCheckpoint(mod_wt_path, save_weights_only=True, save_best_only=True)

@avolkov1

Your example seems like it may work, but I'm having trouble thinking of a simple example of how to use it. Your guidance would be much appreciated.

Is something like this feasible?

.
.
.
.
.
# model topology instantiation above
ser_model = keras.models.Model(inputs=x, outputs=out)
parallel_model = avolkov1.make_parallel(serial_model=ser_model,
                                        gdev_list=['/gpu:0', '/gpu:1', '/gpu:2', '/gpu:3'],
                                        ps_device='/cpu:0',
                                        model_class=avolkov1.ModelMGPU)

#callback to save best weights
mod_wt_path = './best_weights.hdf5'
best_wts_callback = callbacks.ModelCheckpoint(mod_wt_path, save_weights_only=True, save_best_only=True)

parallel_model.fit(X, y, callbacks=[best_wts_callback])

#Now I want to infer on single GPU so I load saved weights ??
ser_model.load_weights(mod_wt_path)

ser_model.predict(X_holdout)

Would something like this work? Actually I need a more exact version of what would actually work.

THANK YOU!

EDIT:

Looking at your CIFAR-10 example, it looks like something like this would work. I'm in a crunch, so I don't want to embark on the above journey if I am missing something glaring.

@pGit1

pGit1 commented Jan 2, 2018

@avolkov1

In general I think these lines from the docstring in your code explain it all:

'''Override load and save methods of the multi-gpu model. The load and
save should correspond to the serial model's load and save.'''

In general, one should easily be able to train in parallel on multiple GPUs, use callbacks to save weights during the parallel run, and load those saved weights back into the serial model that was parallelized in the first place (without having to re-instantiate the serial model as a parallel model). I think your code allows one to train on 8 GPUs but then load the weights and infer on one. It should perhaps be an option in the >=2.0.9 implementation? Training with keras.utils.multi_gpu_model() works great and definitely provides a speed-up. It just doesn't play nice with ModelCheckpoint or weight saving/loading.

@avolkov1

avolkov1 commented Jan 2, 2018

@pGit1 Yeah, what you have there should work. Or you can use keras.utils.multi_gpu_model to create a wrapper class:

from keras import Model
from keras.utils import multi_gpu_model


class ModelMGPU(Model):
    def __init__(self, ser_model, gpus):
        pmodel = multi_gpu_model(ser_model, gpus)
        self.__dict__.update(pmodel.__dict__)
        self._smodel = ser_model

    def __getattribute__(self, attrname):
        '''Override load and save methods to be used from the serial-model. The
        serial-model holds references to the weights in the multi-gpu model.
        '''
        # return Model.__getattribute__(self, attrname)
        if 'load' in attrname or 'save' in attrname:
            return getattr(self._smodel, attrname)

        return super(ModelMGPU, self).__getattribute__(attrname)

Then you can use your example above with this new class.

# model topology instantiation above
ser_model = keras.models.Model(inputs=x, outputs=out)
parallel_model = ModelMGPU(ser_model , 4)

#callback to save best weights
mod_wt_path = './best_weights.hdf5'
best_wts_callback = callbacks.ModelCheckpoint(mod_wt_path, save_weights_only=True, save_best_only=True)

# compile the parallel model prior to fit
parallel_model.fit(X, y, callbacks=[best_wts_callback])

#Now I want to infer on single GPU so I load saved weights ??
ser_model.load_weights(mod_wt_path)

# I think you might have to compile the serial model prior to predict
ser_model.predict(X_holdout)

@pGit1

pGit1 commented Jan 2, 2018

@avolkov1

THANK YOU!! Your code works. To test I bypassed multi-gpu-model altogether.
I used raw code from https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/_multigpu.py#L129.

After training on a simple dummy data set, I call a function that returns two models (serial and parallel) and choose only the serial model. Keep in mind that during training I call the fit function with the parallel model, not the serial model. I also feed my best-weights callback to the parallel model during training.

Once this is done, I load the learned weights into the serial model and get the expected results without any errors. I am not entirely sure why this works, but it does. I confirmed multi-GPU training and single-GPU inference. Now I am going to clean up my code to do something like you outlined above.

Thanks again for your help!!

EDIT: The cleaned-up version where you wrap the multi_gpu_model in a class works flawlessly. This is definitely my preferred method. Thanks again for all of your help. Your code is an extremely valuable contribution.

@nicolefinnie

nicolefinnie commented Dec 25, 2018

EDIT on Jan 11, 2019
@avolkov1, I found the problem after I reported that I tried your approach and hit an issue with TensorFlow Keras 1.11: the mistake I made was to save the entire model with save_weights_only=False. As a result, the weights in the model are saved in a messed-up order that the Keras code cannot read.

I've tried to customize ModelCheckpoint; however, the optimizer states are not saved correctly and I'm unable to resume training properly. I'd suggest saving the template model every N epochs instead of using the checkpoint, and calling fit() every N epochs to resume training. It's the most mundane way, but I think it's the safest way to preserve the model's and optimizer's weights.
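
A minimal sketch of that "train in chunks, save the template model" workflow (model, parallel_model and the training arrays are assumed to exist already):

epochs_per_chunk = 5
n_chunks = 10

for chunk in range(n_chunks):
    # train the parallel model for a few epochs...
    parallel_model.fit(x_train, y_train, batch_size=256, epochs=epochs_per_chunk)
    # ...then snapshot the template model, which shares the trained weights
    model.save('template_after_{}_epochs.h5'.format((chunk + 1) * epochs_per_chunk))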

@oeminaga

oeminaga commented Jan 20, 2019

@fchollet @pGit1 @nicolefinnie @avolkov1, I solved the problem in the following way. I changed some lines in the core Keras code (particularly in topology.py or network.py, and callbacks.py). Here, I just modified the following code.
Reminder: you need to replace 'save_weights_to_hdf5_group' with 'saving.save_weights_to_hdf5_group(f, layers)' if you use a recent version of Keras.

Callbacks.py:

class ModelCheckpoint(Callback):
    """Save the model after every epoch.

    `filepath` can contain named formatting options,
    which will be filled the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).

    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.

    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1, multi_gpu_mode=False, name_of_model=None):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.name_of_model = name_of_model  # Usually model_1, you can check the name by calling summary after running multi_gpu_model
        self.multi_gpu_mode = multi_gpu_mode
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True,
                                                    multiple_gpu=self.multi_gpu_mode,
                                                    name_of_model=self.name_of_model)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve from %0.5f' %
                                  (epoch + 1, self.monitor, self.best))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True, multiple_gpu=self.multi_gpu_mode)
                else:
                    self.model.save(filepath, overwrite=True)

Topology.py/network.py:

def save_weights(self, filepath, overwrite=True, multiple_gpu=False, name_of_model=""):
    """Dumps all layer weights to a HDF5 file.

    The weight file has:
        - `layer_names` (attribute), a list of strings
            (ordered names of model layers).
        - For every layer, a `group` named `layer.name`
            - For every such layer group, a group attribute `weight_names`,
                a list of strings
                (ordered names of weights tensor of the layer).
            - For every weight in the layer, a dataset
                storing the weight value, named after the weight tensor.

    # Arguments
        filepath: String, path to the file to save the weights to.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.

    # Raises
        ImportError: If h5py is not available.
    """
    if h5py is None:
        raise ImportError('`save_weights` requires h5py.')
    # If file exists and should not be overwritten:
    if not overwrite and os.path.isfile(filepath):
        proceed = ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return
    with h5py.File(filepath, 'w') as f:
        if multiple_gpu:
            layers = self.get_layer(name_of_model)
            layers = layers.layers
            save_weights_to_hdf5_group(f, layers)
        else:
            save_weights_to_hdf5_group(f, self.layers)
        f.flush()
