When I try to update the parameters of a model created with Keras by running a train_op built in TensorFlow, the parameters are not updated.
Below is a minimal working example:
# Test for mixing Keras and TF for MLP
from __future__ import print_function
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import SGD
import tensorflow.keras.backend as K
import tensorflow as tf
import Optimizers as Opt
# from spsa import SimultaneousPerturbationOptimizer # Replace with this
import numpy as np

K.clear_session()
tf.reset_default_graph()

# ---------------- Create a dataset -------------------------
n_train = 1000
n_test = 100
y_train = np.random.randint(2, size=(n_train, 2))
y_test = np.random.randint(2, size=(n_test, 2))
x_train = 0.5 * np.random.randn(n_train, 2) + y_train
x_test = 0.5 * np.random.randn(n_test, 2) + y_test

# --------------- Create model using Keras -------------------
with tf.name_scope("Model"):
    inputs = Input(shape=(2,))
    l1 = Dense(4, activation='relu')(inputs)
    l2 = Dense(4, activation='relu')(l1)
    predictions = Dense(2, activation='softmax')(l2)
    model = Model(inputs=inputs, outputs=predictions)
    model.summary()

# --------------- Create place holder and loss with TF------
with tf.name_scope("Model"):
    y_ph_tf = tf.placeholder(tf.float32, shape=[None, 2])
    loss = tf.losses.softmax_cross_entropy(y_ph_tf, predictions)

params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "Model")
# trainer = tf.train.GradientDescentOptimizer(learning_rate=0.5) # This works
trainer = Opt.SimultaneousPerturbationOptimizer(a=0.05, c=0.05, alpha=0.99, gamma=0.40) # This don't
_train = trainer.minimize(loss)

sess = K.get_session()
sess.run(tf.global_variables_initializer())

print(">>>>>> PRE")
print(sess.run(params[:2])) # Print
for i in range(20):
    l, _ = sess.run([loss, _train], feed_dict={inputs: x_train, y_ph_tf: y_train})
print(">>>>>> POST")
print(sess.run(params[:2]))
This produces the output below:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 2)                 0
_________________________________________________________________
dense (Dense)                (None, 4)                 12
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 20
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 10
=================================================================
Total params: 42
Trainable params: 42
Non-trainable params: 0
_________________________________________________________________
SPSA: a = 0.050 c = 0.050 alpha = 0.990 gamma = 0.400
42 parameters
WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.
>>>>>> PRE
[array([[ 3.8956022e-01, -3.0612946e-01, 3.6983490e-02, -2.1815300e-04],
[ 6.0597062e-01, 1.7913413e-01, -6.0641050e-01, -2.9886746e-01]],
dtype=float32), array([0., 0., 0., 0.], dtype=float32)]
>>>>>> POST
[array([[ 3.8956022e-01, -3.0612946e-01, 3.6983490e-02, -2.1815300e-04],
[ 6.0597062e-01, 1.7913413e-01, -6.0641050e-01, -2.9886746e-01]],
dtype=float32), array([0., 0., 0., 0.], dtype=float32)]
After the 20 update steps (the "POST" output), the parameter values are identical to the "PRE" values.
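The script prints only `params[:2]`, i.e. the kernel and bias of the first Dense layer, so a quick way to confirm that none of the 42 parameters moved is to snapshot all of them before and after the loop and compare. Below is a minimal sketch of such a check; `max_abs_change` is a hypothetical helper (not part of TensorFlow or Keras), written in plain Python so that it accepts nested lists as well as the NumPy arrays returned by `sess.run(params)`.

```python
def max_abs_change(before, after):
    """Return the largest absolute element-wise difference between two
    equally shaped nested sequences (lists, tuples, or NumPy arrays)."""

    def flatten(x):
        # Anything that cannot be iterated is treated as a scalar leaf.
        try:
            items = iter(x)
        except TypeError:
            yield float(x)
            return
        for item in items:
            for leaf in flatten(item):
                yield leaf

    a = list(flatten(before))
    b = list(flatten(after))
    if len(a) != len(b):
        raise ValueError("snapshots have different numbers of elements")
    return max((abs(p - q) for p, q in zip(a, b)), default=0.0)


# Toy snapshots with the same nesting as sess.run(params[:2]):
# one 2x4 kernel and one bias vector of length 4.
pre = [[[0.39, -0.31, 0.04, 0.00], [0.61, 0.18, -0.61, -0.30]],
       [0.0, 0.0, 0.0, 0.0]]
post_unchanged = [[[0.39, -0.31, 0.04, 0.00], [0.61, 0.18, -0.61, -0.30]],
                  [0.0, 0.0, 0.0, 0.0]]
post_changed = [[[0.39, -0.31, 0.04, 0.00], [0.61, 0.18, -0.61, -0.30]],
                [0.0, 0.5, 0.0, 0.0]]

print(max_abs_change(pre, post_unchanged))
print(max_abs_change(pre, post_changed))
```

In the script above this would presumably be used as `pre = sess.run(params)` before the loop and `post = sess.run(params)` after it; a result of exactly zero across all variables would point at the train op rather than at the printing.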
I am using TensorFlow 1.9.
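For comparison, here is a rough pure-Python sketch of a textbook SPSA step on a toy quadratic, using the same settings as the SPSA line in the output (a = 0.05, c = 0.05, alpha = 0.99, gamma = 0.40). It is not the `Optimizers` module from this report: the gain schedules `a / (k + 1) ** alpha` and `c / (k + 1) ** gamma`, the function names, and the toy loss are all assumptions made for illustration. The point is only that with these settings a standard SPSA update should move the parameters within 20 steps.

```python
import random


def spsa_minimize(loss, theta, a=0.05, c=0.05, alpha=0.99, gamma=0.40,
                  steps=20, seed=0):
    """Textbook SPSA loop (illustrative only, assumed gain schedules)."""
    rng = random.Random(seed)
    theta = list(theta)
    for k in range(steps):
        a_k = a / (k + 1) ** alpha  # step-size gain (assumed schedule)
        c_k = c / (k + 1) ** gamma  # perturbation gain (assumed schedule)
        # Rademacher perturbation: every component is +1 or -1.
        delta = [rng.choice((-1.0, 1.0)) for _ in theta]
        plus = [t + c_k * d for t, d in zip(theta, delta)]
        minus = [t - c_k * d for t, d in zip(theta, delta)]
        # Two loss evaluations give one estimate of the whole gradient:
        # g_i = (L(theta + c_k * delta) - L(theta - c_k * delta)) / (2 * c_k * delta_i)
        diff = loss(plus) - loss(minus)
        theta = [t - a_k * diff / (2.0 * c_k * d)
                 for t, d in zip(theta, delta)]
    return theta


def quadratic(theta):
    """Toy loss with its minimum at the origin."""
    return sum(t * t for t in theta)


start = [1.0, -2.0, 0.5]
end = spsa_minimize(quadratic, start)
print("start:", start, "loss:", quadratic(start))
print("end:  ", end, "loss:", quadratic(end))
```

Note that in a graph-mode implementation the two perturbed loss evaluations and the final assignment to the variables all have to be part of the op returned by `minimize`; if the assignment is not reachable from that op, `sess.run` will not execute it.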
v-i-s-h changed the title from "Do not update weights when tried to mix with Keras" to "Not updating weights when tried to mix with Keras" on Apr 14, 2019