In [1]:
import numpy as np
from keras.layers import Dense, Input, Flatten, Reshape
from keras.losses import mse, mae, binary_crossentropy
from keras.models import Model
from keras.optimizers import Adam

def settrainable(model, toset):
    """Set the ``trainable`` flag on ``model`` and on every layer beneath it.

    Layers that themselves expose a non-empty ``layers`` attribute (i.e. nested
    sub-models) are walked recursively, so calling this on a composite model
    also updates the layers inside each sub-model instead of only flipping the
    flag on the sub-model object.

    Parameters
    ----------
    model : object with a ``layers`` iterable and a ``trainable`` attribute
        The model whose layers should be updated.
    toset : bool
        Value assigned to every ``trainable`` attribute that is visited.

    Returns
    -------
    None
    """
    for layer in model.layers:
        if getattr(layer, "layers", None):
            # Nested sub-model: descend so its inner layers are updated too.
            # The recursive call also sets the flag on the sub-model itself.
            settrainable(layer, toset)
        else:
            layer.trainable = toset
    model.trainable = toset

# Architecture settings shared by both sub-models.
input_shape = (1024, 1)
layers = 4
latent = 1024


def build_dense_model(shape, n_layers, width):
    """Build and compile a Flatten -> n_layers x Dense(relu) -> Reshape model.

    Parameters
    ----------
    shape : tuple
        Per-sample input shape. The output is reshaped back to this shape, so
        ``width`` must equal the number of elements in ``shape``.
    n_layers : int
        Number of Dense layers to stack.
    width : int
        Number of units in every Dense layer.

    Returns
    -------
    tuple
        ``(input_tensor, output_tensor, compiled_model)``.
    """
    model_input = Input(shape=shape)
    hidden = Flatten()(model_input)
    for _ in range(n_layers):
        hidden = Dense(width, activation='relu')(hidden)
    model_output = Reshape(shape)(hidden)
    model = Model(model_input, model_output)
    model.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy")
    return model_input, model_output, model


# Two structurally identical, independently initialised models.
inputs, outputs, model1 = build_dense_model(input_shape, layers, latent)
model1.summary()

inputs2, outputs2, model2 = build_dense_model(input_shape, layers, latent)
model2.summary()

# Composite model: model1 feeds model2. model2 is frozen *after* its own
# compile and *before* model3 is compiled; the cells below assert that fitting
# model3 leaves model2's weights unchanged while fitting model2 directly still
# updates them.
settrainable(model1, True)
settrainable(model2, False)
outputs3 = model2(model1(inputs))
model3 = Model(inputs, outputs3)
model3.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy")
model3.summary()

Using TensorFlow backend.


Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 1024, 1)           0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
dense_3 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
resh

In [2]:
# The composite model must wrap the very same layers as the standalone models:
# compare the first weight array of layer index 2 (the first Dense layer) of
# each sub-model with its standalone counterpart.
assert np.array_equal(model3.layers[1].layers[2].get_weights()[0],
                      model1.layers[2].get_weights()[0])
assert np.array_equal(model3.layers[2].layers[2].get_weights()[0],
                      model2.layers[2].get_weights()[0])

# Snapshot the pre-training weights so later cells can detect any change.
wm1 = model3.layers[1].layers[2].get_weights()[0]
wm2 = model3.layers[2].layers[2].get_weights()[0]
wm2_2 = model2.layers[2].get_weights()[0]
assert np.array_equal(wm2_2, wm2)

In [3]:
# Three independent uniform [0, 1) arrays, drawn in the order
# input -> intermediate -> output, each of shape (10000, 1024, 1).
input_data, intermediate_data, output_data = (
    np.random.uniform(low=0, high=1, size=(10000, 1024, 1)) for _ in range(3)
)

In [4]:
# Train the composite model end to end for two epochs.
model3.fit(x=input_data, y=output_data, epochs=2)

Instructions for updating:
Use tf.cast instead.
Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fb1d4306a58>

In [5]:
# After fitting model3 the sub-models inside it must still mirror the
# standalone models weight for weight.
assert np.array_equal(model3.layers[1].layers[2].get_weights()[0],
                      model1.layers[2].get_weights()[0])
assert np.array_equal(model3.layers[2].layers[2].get_weights()[0],
                      model2.layers[2].get_weights()[0])
# model1 was trainable, so its weights must differ from the snapshot ...
assert not np.array_equal(model3.layers[1].layers[2].get_weights()[0], wm1)
# ... while the frozen model2 must be exactly where it started.
assert np.array_equal(model3.layers[2].layers[2].get_weights()[0], wm2)

In [6]:
# Fit model2 on its own. It was compiled before being frozen, and the next
# cell asserts that its weights do change here.
model2.fit(x=input_data, y=output_data, epochs=2)

  'Discrepancy between trainable weights and collected trainable'


Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fb1d43ca668>

In [7]:
# model3's second sub-model and the standalone model2 must still agree ...
assert np.array_equal(model3.layers[2].layers[2].get_weights()[0],
                      model2.layers[2].get_weights()[0])
# ... and fitting model2 directly must have moved its weights off the snapshot.
assert not np.array_equal(model3.layers[2].layers[2].get_weights()[0], wm2)

In [9]:
# Snapshot model2's first weight array of layer index 2 after it was fitted
# directly; a later cell compares against this to confirm that fitting model3
# again leaves model2 untouched.
wm2_3 = model2.layers[2].get_weights()[0]

In [10]:
# Fit the composite model again, after model2 has been trained on its own.
model3.fit(x=input_data, y=output_data, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fb170271518>

In [12]:
# Fitting model3 must not have changed model2's weights since the snapshot.
assert np.array_equal(wm2_3, model2.layers[2].get_weights()[0])
# The sub-models inside model3 must still match the standalone models
# (checked here on layer index 3).
assert np.array_equal(model3.layers[1].layers[3].get_weights()[0],
                      model1.layers[3].get_weights()[0])
assert np.array_equal(model3.layers[2].layers[3].get_weights()[0],
                      model2.layers[3].get_weights()[0])