Skip to content

Commit

Permalink
Update CNN optimizer for new Keras API
Browse files Browse the repository at this point in the history
  • Loading branch information
duncanwp committed Jun 12, 2023
1 parent 8021be3 commit c7a097f
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion esem/__init__.py
Expand Up @@ -213,6 +213,7 @@ def cnn_model(training_params, training_data, data_processors=None,
from tensorflow.keras.layers import Dense, Input, Reshape, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.backend import floatx
import numpy as np

Expand Down Expand Up @@ -249,7 +250,13 @@ def cnn_model(training_params, training_data, data_processors=None,

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.compile(optimizer=optimizer(learning_rate=learning_rate, decay=decay), loss=loss)

lr_schedule = ExponentialDecay(
initial_learning_rate=learning_rate,
decay_steps=10000,
decay_rate=decay)

decoder.compile(optimizer=optimizer(learning_rate=lr_schedule), loss=loss)

model = KerasModel(decoder)

Expand Down

0 comments on commit c7a097f

Please sign in to comment.