Converting convolution kernels from Theano to TensorFlow and vice versa

ColaColin edited this page Jun 13, 2017 · 6 revisions

If you want to load pre-trained weights that include convolutions (layers such as Conv1D and Conv2D, known as Convolution1D and Convolution2D in Keras 1), be mindful of this: Theano and TensorFlow implement convolution in different ways. Theano performs a true convolution, which flips the kernel, whereas TensorFlow actually implements correlation, much like Caffe. Convolution kernels trained with Theano (resp. TensorFlow) therefore need to be converted before being used with TensorFlow (resp. Theano). Here's how.
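To see why a conversion is needed at all, the NumPy sketch below (a toy 1D example with made-up values, not Keras code) compares a true convolution, as Theano computes it, with a correlation, as TensorFlow computes it. The two disagree for an asymmetric kernel, but correlating with the reversed kernel reproduces the convolution exactly; that reversal is what the conversion amounts to.

```python
import numpy as np

# Toy 1D signal and a deliberately asymmetric kernel (made-up values).
x = np.array([1.0, 2.0, 0.0, -1.0, 3.0])
k = np.array([1.0, 0.5, -2.0])

# True convolution (Theano-style): the kernel is reversed as it slides over x.
conv = np.convolve(x, k, mode='valid')

# Correlation (TensorFlow/Caffe-style): the kernel slides over x as stored.
corr = np.correlate(x, k, mode='valid')

# Correlating with the reversed kernel reproduces the convolution exactly.
corr_flipped = np.correlate(x, k[::-1], mode='valid')

print(conv)                             # differs from corr
print(corr)
print(np.allclose(conv, corr_flipped))  # True
```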

From Theano to TensorFlow

The Keras backend should be TensorFlow in this case.
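One way to select the backend, assuming multi-backend Keras 2, is the KERAS_BACKEND environment variable, which has to be set before Keras is first imported in the process; the "backend" field of ~/.keras/keras.json is the persistent alternative. A minimal sketch:

```python
import os

# Must run before the first `import keras` anywhere in this process.
# Use 'theano' instead for the TensorFlow-to-Theano direction.
os.environ['KERAS_BACKEND'] = 'tensorflow'

# import keras  # Keras would now start with the TensorFlow backend.
```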

First, load the Theano-trained weights into your TensorFlow model:

model.load_weights('my_weights_theano.h5')

Then, iterate over the model's layers and collect one conversion op per convolution kernel:

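from keras import backend as K
from keras.utils.conv_utils import convert_kernel
import tensorflow as tf

# Keras 2 class names of the convolution layers to convert
# (Convolution2D is an alias of Conv2D; dilation is a Conv2D argument in Keras 2).
CONV_LAYERS = ['Conv1D', 'Conv2D', 'Conv3D']

ops = []
for layer in model.layers:
    if layer.__class__.__name__ in CONV_LAYERS:
        original_w = K.get_value(layer.kernel)  # layer.W in Keras 1
        converted_w = convert_kernel(original_w)
        # Queue the assignment; nothing is written until the ops are run.
        ops.append(tf.assign(layer.kernel, converted_w).op)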

Then run all the conversion ops at once in the global Keras session (a single run call is faster than converting the kernels one at a time):

K.get_session().run(ops)

and save the converted weights:

model.save_weights('my_weights_tensorflow.h5')
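Roughly speaking, convert_kernel reverses the kernel along each spatial axis and leaves the input- and output-channel axes alone. The NumPy sketch below imitates that flip and checks it against small reference implementations of both operations. flip_spatial_axes, correlate2d_valid and convolve2d_valid are hypothetical helpers, not Keras functions, and the code assumes the Keras 2 kernel layout (rows, cols, input channels, output channels), using a single channel for simplicity.

```python
import numpy as np


def flip_spatial_axes(kernel):
    """Reverse every spatial axis of a conv kernel, keeping the last two axes
    (input channels, output channels) as they are.

    Assumes a channels-last layout such as (rows, cols, in, out) for 2D.
    """
    kernel = np.asarray(kernel)
    if kernel.ndim < 3:
        raise ValueError('expected a conv kernel with at least 3 axes')
    n_spatial = kernel.ndim - 2
    index = (slice(None, None, -1),) * n_spatial + (slice(None), slice(None))
    return kernel[index].copy()


def correlate2d_valid(image, kernel):
    """'Valid' 2D cross-correlation of one channel (as TensorFlow computes it)."""
    kh, kw = kernel.shape
    out = np.zeros((image.shape[0] - kh + 1, image.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def convolve2d_valid(image, kernel):
    """'Valid' 2D true convolution of one channel (as Theano computes it),
    written out from the definition rather than via a flip."""
    kh, kw = kernel.shape
    out = np.zeros((image.shape[0] - kh + 1, image.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            for a in range(kh):
                for b in range(kw):
                    out[i, j] += image[i + kh - 1 - a, j + kw - 1 - b] * kernel[a, b]
    return out


rng = np.random.default_rng(0)
image = rng.standard_normal((5, 6))
# Hypothetical Theano-trained 3x3 kernel with 1 input and 1 output channel.
theano_kernel = rng.standard_normal((3, 3, 1, 1))
tf_kernel = flip_spatial_axes(theano_kernel)

theano_out = convolve2d_valid(image, theano_kernel[:, :, 0, 0])
tf_out = correlate2d_valid(image, tf_kernel[:, :, 0, 0])
print(np.allclose(theano_out, tf_out))  # True
```

The flipped kernel gives TensorFlow's correlation the same output as Theano's convolution with the original kernel, which is exactly what the conversion has to preserve.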

From TensorFlow to Theano

The Keras backend should be Theano in this case.

First, load the TensorFlow-trained weights into your Theano model:

model.load_weights('my_weights_tensorflow.h5')

Then, iterate over the model's layers and convert each convolution kernel on the fly:

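from keras import backend as K
from keras.utils.conv_utils import convert_kernel

# Keras 2 class names of the convolution layers to convert
# (Convolution2D is an alias of Conv2D).
CONV_LAYERS = ['Conv1D', 'Conv2D', 'Conv3D']

for layer in model.layers:
    if layer.__class__.__name__ in CONV_LAYERS:
        original_w = K.get_value(layer.kernel)  # layer.W in Keras 1
        converted_w = convert_kernel(original_w)
        # Write the flipped kernel straight back into the layer.
        K.set_value(layer.kernel, converted_w)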

and save the converted weights:

model.save_weights('my_weights_theano.h5')

That's it! Note that this K.set_value loop also works with the TensorFlow backend, but it is slower than the batched tf.assign method outlined first.
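Because reversing the spatial axes twice gives back the original kernel, the same convert_kernel call serves both directions. The sketch below illustrates this in plain NumPy; the layers list is a hypothetical stand-in for model.layers, and flip_spatial_axes and convert_model_kernels are hypothetical helpers that re-implement the flip, not Keras functions.

```python
import numpy as np


def flip_spatial_axes(kernel):
    """Reverse all axes except the last two (input/output channels)."""
    kernel = np.asarray(kernel)
    n_spatial = kernel.ndim - 2
    index = (slice(None, None, -1),) * n_spatial + (slice(None), slice(None))
    return kernel[index].copy()


# Hypothetical stand-in for model.layers: (layer class name, kernel) pairs.
layers = [
    ('Conv2D', np.arange(18, dtype=float).reshape(3, 3, 1, 2)),
    ('Dense', np.arange(6, dtype=float).reshape(2, 3)),
]
CONV_LAYERS = {'Conv1D', 'Conv2D', 'Conv3D'}


def convert_model_kernels(layers):
    """Flip the kernels of convolution layers only; other weights pass through."""
    converted = []
    for name, kernel in layers:
        if name in CONV_LAYERS:
            kernel = flip_spatial_axes(kernel)
        converted.append((name, kernel))
    return converted


tf_layers = convert_model_kernels(layers)      # e.g. Theano -> TensorFlow
round_trip = convert_model_kernels(tf_layers)  # TensorFlow -> Theano again
print(all(np.array_equal(a[1], b[1]) for a, b in zip(layers, round_trip)))  # True
```

Converting twice is therefore a no-op, so make sure each weight file is converted exactly once.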
