Is there a way to save and load GAN models without losing the optimizer state? #10806
Comments
If you are defining a custom optimizer then you need to import its definition. Or you can define it again when you call load_model, and compile with that optimizer. The model won't lose its weights. |
Thanks, I was thinking of the state of the optimizer, e.g. if you are using learning rate decay or momentum. If you recompile the model, the weights are fine, but I lose the optimizer state. |
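A minimal, hedged sketch of what is meant here, assuming standalone Keras 2.x as in the script in the next comment (the toy model is arbitrary): recompiling preserves the layer weights but attaches a fresh optimizer, discarding its accumulated state.

```python
import numpy as np
from keras import Input, Model, losses, optimizers
from keras.layers import Dense

# Build and compile a tiny model, then train one step so the optimizer
# accumulates state (Adam's iteration count and moment estimates).
x = Input(shape=(4,))
model = Model(x, Dense(1)(x))
model.compile(loss=losses.mean_squared_error, optimizer=optimizers.Adam())
model.train_on_batch(np.random.rand(8, 4), np.random.rand(8, 1))
print(len(model.optimizer.get_weights()))  # > 0: the optimizer now has state

# Recompiling keeps the layer weights but replaces the optimizer,
# so the accumulated state is gone until training resumes.
weights_before = model.get_weights()
model.compile(loss=losses.mean_squared_error, optimizer=optimizers.Adam())
print(len(model.optimizer.get_weights()))  # 0 again
assert all(np.array_equal(a, b) for a, b in zip(weights_before, model.get_weights()))
```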
Further, this script demonstrates that for CGANs (I'm assuming it also applies to the non-conditional variants), saving and loading does not preserve the discriminator's state.

```python
import numpy as np
from keras import Input, Model, losses, optimizers
from keras.engine.saving import load_model
from keras.layers import Dense, concatenate

# Arbitrary constants
N_LATENT = 100
N_FEATURES = 100
N_FLAGS = 7
N_ROWS = 100

print("Building inputs")
noise = Input(shape=(N_LATENT,), name="noise")
flags = Input(shape=(N_FLAGS,), name="flags")
features = Input(shape=(N_FEATURES,), name="features")

print("Discriminator")
d = concatenate([features, flags])
d = Dense(52, activation='relu')(d)
d = Dense(52, activation='relu')(d)
d_out = Dense(1, name='d_out')(d)
D = Model([features, flags], d_out, name="D")
D.compile(
    loss=losses.binary_crossentropy,
    optimizer=optimizers.Adadelta(),
)
D.summary()

print("Generator")
g = concatenate([features, noise])
g = Dense(52, activation='relu')(g)
g = Dense(52, activation='relu')(g)
g_out = Dense(N_FLAGS, activation='sigmoid', name='g_out')(g)
G = Model([features, noise], g_out, name="G")
G.summary()

print("GAN")
# Freeze the discriminator's layers inside the combined model
for l in D.layers:
    l.trainable = False
# D expects [features, flags], so feed it the real features and the generated flags
gan_out = D([features, G([features, noise])])
GAN = Model([features, noise], gan_out)
GAN.compile(
    loss=losses.binary_crossentropy,
    optimizer=optimizers.Adadelta(),
)
GAN.summary()

features = np.random.normal(0, 1, (N_ROWS, N_FEATURES))
noise = np.random.normal(0, 1, (N_ROWS, N_LATENT))
flags = np.random.uniform(0, 1, (N_ROWS, N_FLAGS))
ones = np.ones((N_ROWS, 1))

# Save and discard the models
D.save('./D')
G.save('./G')
GAN.save('./GAN')
del D
del G
del GAN

# Reload them from disk
print("D")
D = load_model('./D')
D.summary()
print("G")
G = load_model('./G')
G.summary()
print("GAN")
GAN = load_model('./GAN')
GAN.summary()
```

After running that script, examine the output of D.summary(): all parameters are marked as non-trainable, and along with that, the optimizer has been reset. Are there any fixes for this? |
Running into the same issue. Also found a similar issue here: #9589 |
Solved it by upgrading Keras to 2.2.4 and using pickle. I was having the issue with a pix2pix model. |
@ismailsimsek Do you have a quick example I can test? Also, I created a related question on StackOverflow: https://stackoverflow.com/questions/52463551/how-do-you-train-gans-using-multiple-gpus-with-keras |
I'm also struggling to resolve this error. Is there a way to easily save and load a GAN without losing information? |
I think I managed to finally solve this issue after much frustration and eventually switching to
What's interesting is that this isn't the case with
Another thing to keep in mind is that
So in the context of GANs, one needs to ensure the following:
One can carry out saving and loading as follows:
You'll notice that it's necessary to switch the
It's also important to note that this example will only work on |
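A rough, hedged sketch of the kind of approach this comment seems to describe (toggling the discriminator's trainable flag so that each saved file records a consistent state, using tf.keras; the model definitions, shapes, and file names below are placeholders, not the original code):

```python
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers

# Minimal stand-in models so the save/load pattern is runnable end to end;
# the shapes, layer sizes, and file names are placeholders.
G = models.Sequential([layers.Dense(8, activation='relu', input_shape=(4,)),
                       layers.Dense(4, activation='sigmoid')], name='G')
D = models.Sequential([layers.Dense(8, activation='relu', input_shape=(4,)),
                       layers.Dense(1, activation='sigmoid')], name='D')
D.compile(loss='binary_crossentropy', optimizer=optimizers.Adam())

D.trainable = False                     # freeze the discriminator inside the combined model
GAN = models.Sequential([G, D], name='GAN')
GAN.compile(loss='binary_crossentropy', optimizer=optimizers.Adam())

# Saving: flip the trainable flag so each file records a consistent state --
# the standalone discriminator as trainable, the combined model with it frozen.
D.trainable = True
D.save('discriminator.h5')
G.save('generator.h5')
D.trainable = False
GAN.save('gan.h5')

# Loading
D = tf.keras.models.load_model('discriminator.h5')
G = tf.keras.models.load_model('generator.h5')
GAN = tf.keras.models.load_model('gan.h5')
```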
@bradsheppard can you confirm that your suggested method works? |
@bradsheppard's solution worked for me (currently using TensorFlow 2.1.0). |
@bradsheppard's solution solved the issue of the model needing to be recompiled; however, after resuming, the discriminator (for fake images) and combined model loss values hover around 0. Also, the generated fake images do not get better (maybe because they do not change at all)!

In the first sample image, the GAN (generator, discriminator, and the combined model) was trained for 5 epochs (each epoch has 128 batches of images), with all models created from scratch. Then the models were loaded using @bradsheppard's solution, and training continued for another 5 epochs. (The x-axis is iterations.) In the second sample image, all models were created from scratch and trained for 10 epochs continuously, without loading them from saved models. (The x-axis is iterations.)

My TensorFlow version is 2.5, and I am using Google Colaboratory to perform these tests. |
Thanks @bradsheppard, your method worked like a charm :) |
@sinajhrm - I was experiencing the same issue. Try applying @bradsheppard's method to just your discriminator and generator, then build a fresh GAN from your original code before resuming training. For my use case, training for 10 epochs and then resuming with this method for an additional 10 epochs gave image output and metrics similar to training for 20 epochs from scratch. |
@jroback - I will try your method ASAP. However, what do you mean by "loading a fresh gan from my original code"? Do you mean that I should create a new combined model using the loaded discriminator and generator and then compile it with a freshly defined optimizer? (If I understand your method correctly, this way I will lose the optimizer state of the previous combined model.) |
@sinajhrm - Yes, try defining a new GAN from your loaded generator and discriminator. I should caveat this by saying that I'm very new to GANs and data science, so there may be something I'm missing. But judging by the image output and the metrics, training appears to have resumed. |
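A hedged sketch of what this suggestion might look like (the file paths are placeholders, and a single-input generator is assumed for brevity; as noted above, the freshly compiled GAN gets a new optimizer, so its previous state is lost):

```python
from tensorflow import keras

# Load the saved generator and discriminator (placeholder paths)...
G = keras.models.load_model('generator.h5')
D = keras.models.load_model('discriminator.h5')

# ...then rebuild the combined model from scratch around the loaded parts.
D.trainable = False
z = keras.Input(shape=G.input_shape[1:])
GAN = keras.Model(z, D(G(z)))
GAN.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam())
```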
I don't like reopening old issues, but I tried using @bradsheppard's method and it doesn't work as expected. When the models are first created, the discriminator weights are shared between the GAN and the discriminator model, so when the discriminator is updated during training, the GAN uses the latest state of those weights.

However, when the models are loaded again, they are loaded separately, so the link between the GAN and the discriminator is broken: when the discriminator is updated, the (non-trainable) copy of its weights inside the GAN is not updated, and never will be after loading. I've tested this by looking at the weights themselves, and they are indeed identical in the GAN at every epoch, even though they should change whenever the discriminator is trained.

I'm not sure what exactly, but another solution is needed here. At least this is what happened for me, so I'd be interested to know whether it worked differently for others on a different version. |
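A hedged sketch of the kind of check being described (D and GAN are assumed to be the separately loaded discriminator and combined model, 'D' is assumed to be the name of the nested discriminator sub-model, and x_batch/y_batch stand in for one real discriminator training batch):

```python
import numpy as np

# Compare the standalone discriminator's weights with the copy of the
# discriminator nested inside the combined model.
d_inside_gan = GAN.get_layer('D')

before = [w.copy() for w in d_inside_gan.get_weights()]
D.train_on_batch(x_batch, y_batch)   # one update to the standalone discriminator
after = d_inside_gan.get_weights()

changed = any(not np.allclose(b, a) for b, a in zip(before, after))
print("Discriminator weights inside the GAN changed:", changed)
# With freshly built models this prints True (the weights are shared); with
# models loaded from separate files it prints False, i.e. the link is broken.
```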
@mattgrayling did you solve the problem? I also used the mentioned approach and found that my pictures don't change at all, which means no training is happening. |
@tstrych - try using the Checkpoint and CheckpointManager classes. These TensorFlow built-ins solved my issue with saving GAN models and continuing the training process. You can also define a custom callback for saving your models at a desired frequency! Note that when I say "these classes solved my issue," I am basing this on my observation of the loss values and the generated fake images; I did not monitor the generator or discriminator weights. |
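A rough sketch of the custom-callback idea mentioned here (it assumes a tf.train.CheckpointManager named manager already wraps the generator, discriminator, and their optimizers, as in the Checkpoint sketch further down; the save frequency is arbitrary):

```python
import tensorflow as tf

class GanCheckpointCallback(tf.keras.callbacks.Callback):
    """Save a checkpoint every `every_n_epochs` epochs via a CheckpointManager."""

    def __init__(self, manager, every_n_epochs=5):
        super().__init__()
        self.manager = manager
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every_n_epochs == 0:
            path = self.manager.save()
            print(f"Saved checkpoint for epoch {epoch + 1}: {path}")
```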
@mattgrayling is there a way to solve this issue? We are also facing a similar issue. |
I would echo @sinajhrm in suggesting the TensorFlow Checkpoint and CheckpointManager classes. I was able to use these to save the model and resume training with the links between the weights preserved. (This led to another issue I still haven't been able to solve: with this approach, model training does not seem to work very well for me in general. But I think that is a separate problem with my setup, as other people have been able to use this successfully.) |
Agree with @mattgrayling and @sinajhrm on using Checkpoint. I did the below:
Then, to load a specific checkpoint and use the model:
|
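A rough, hedged sketch of what such a Checkpoint/CheckpointManager setup can look like (the model and optimizer names and the checkpoint directory are placeholders, not the commenter's original code):

```python
import tensorflow as tf

# Assumes generator, discriminator and their optimizers already exist
# (placeholder names; substitute your own models and optimizers).
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=gen_optimizer,
    discriminator_optimizer=disc_optimizer,
)
manager = tf.train.CheckpointManager(checkpoint, directory='./ckpts', max_to_keep=3)

# During training: save periodically (e.g. at the end of each epoch).
manager.save()

# To resume later: rebuild the same objects, then restore.
checkpoint.restore(manager.latest_checkpoint)   # the latest checkpoint, or
# checkpoint.restore('./ckpts/ckpt-7')          # a specific one
```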
I know one option is using the TensorFlow Checkpoint and CheckpointManager classes to save and resume training the GAN. I am looking for a way to save the generator (G), the discriminator (D) and their optimizer states such that I can also manipulate the model weights (for G and D). One use case for me is averaging the weights of different generators in federated learning for GANs. Does someone know a way? |
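On the weight-manipulation part, a hedged sketch of averaging the weights of several same-architecture generators (generators and global_generator are placeholder names; this does not address preserving the optimizer state):

```python
import numpy as np

def average_weights(models):
    """Element-wise average of the weights of models sharing one architecture."""
    weight_lists = [m.get_weights() for m in models]
    return [np.mean(layer_weights, axis=0) for layer_weights in zip(*weight_lists)]

# e.g. push the averaged weights back into a global generator:
# global_generator.set_weights(average_weights(generators))
```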
This worked perfectly for me! |
@sinajhrm Did you fix the problem of the combined GAN model loss going to zero when resuming training? |
Since the discriminators are included in the GAN and they also need to be used separately during training, how do you save and load GANs? Currently I save the generators and discriminators separately and recompile the GAN for each training session, but I lose the optimizer state this way. Previously you could extract the optimizer state, but that was removed a few releases ago.