Is there a way to save and load GAN models without losing the optimizer state? #10806

Closed
emilwallner opened this issue Jul 30, 2018 · 24 comments

Comments

@emilwallner

Since the discriminators are included in the GAN and they also need to be used separately during training - how do you save and load GANs? Currently, I save the generators and discriminators separately and recompile the GAN for each training episode, but I lose the optimizer state this way. Previously you could extract the optimizer state, but that was removed a few releases ago.

@jadevaibhav

If you are defining a custom optimizer, you need to import its definition. Alternatively, you can define it again when you call load_model, and compile with that optimizer. The model won't lose its weights.

@emilwallner
Author

Thanks - I mean the state of the optimizer, e.g. if you are using learning rate decay or momentum. If I recompile, the weights are fine, but I lose the optimizer state.
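
A rough sketch of a possible workaround (an assumption on my part, not an established recipe: it relies on optimizer.get_weights() / set_weights() being available in the Keras version used, and gan, dummy_inputs and dummy_targets are placeholders):

import pickle

# --- before tearing the combined model down ---
gan.save_weights('gan_weights.h5')                  # model weights
with open('gan_opt.pkl', 'wb') as f:                # optimizer slot variables
    pickle.dump(gan.optimizer.get_weights(), f)

# --- after rebuilding and recompiling the GAN ---
# Optimizer variables are created lazily, so run one throwaway batch first,
# then restore the optimizer state, and finally reload the model weights
# (which also undoes the throwaway update).
gan.train_on_batch(dummy_inputs, dummy_targets)
with open('gan_opt.pkl', 'rb') as f:
    gan.optimizer.set_weights(pickle.load(f))
gan.load_weights('gan_weights.h5')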

@thoward27

Further, this script demonstrates that for CGANs (I'm assuming it also applies to the non-conditional variants), saving and loading does not preserve the discriminator's state.

import numpy as np
from keras import Input, Model, losses, optimizers
from keras.engine.saving import load_model
from keras.layers import Dense, concatenate

# Arbitrary constants
N_LATENT = 100
N_FEATURES = 100
N_FLAGS = 7
N_ROWS = 100

print("Building inputs")
noise = Input(shape=(N_LATENT,), name="noise")
flags = Input(shape=(N_FLAGS,), name="flags")
features = Input(shape=(N_FEATURES,), name="features")

print("Discriminator")
d = concatenate([features, flags])
d = Dense(52, activation='relu')(d)
d = Dense(52, activation='relu')(d)
d_out = Dense(1, name='d_out')(d)
D = Model([features, flags], d_out, name="D")
D.compile(
    loss=losses.binary_crossentropy,
    optimizer=optimizers.Adadelta(),
)
D.summary()

print("Generator")
g = concatenate([features, noise])
g = Dense(52, activation='relu')(g)
g = Dense(52, activation='relu')(g)
g_out = Dense(7, activation='sigmoid', name='g_out')(g)
G = Model([features, noise], g_out, name="G")
G.summary()

print("GAN")
for l in D.layers:
    l.trainable = False
gan_out = D([G([features, noise]), features])
GAN = Model([features, noise], gan_out)
GAN.compile(
    loss=losses.binary_crossentropy,
    optimizer=optimizers.Adadelta(),
)
GAN.summary()

# Placeholder data (note: these reassignments shadow the Input tensors above
# and are not used in the rest of the script)
features = np.random.normal(0, 1, (N_ROWS, N_FEATURES))
noise = np.random.normal(0, 1, (N_ROWS, N_LATENT))
flags = np.random.uniform(0, 1, (N_ROWS, N_FLAGS))
ones = np.ones((N_ROWS, 1))

# Save
D.save('./D')
G.save('./G')
GAN.save('./GAN')

del D
del G
del GAN

print("D")
D = load_model('./D')
D.summary()

print("G")
G = load_model('./G')
G.summary()

print("GAN")
GAN = load_model('./GAN')
GAN.summary()

After running that script, examine the output of D.summary(): all parameters are marked as non-trainable, and on top of that, the optimizer state has been reset.
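
A quick way to see this concretely (my own check, not part of the script above; note the script never trains, so to actually observe the optimizer reset you would first need something like D.train_on_batch([features, flags], ones) before D.save('./D')):

print([layer.trainable for layer in D.layers])   # all False after load_model
print(len(D.optimizer.get_weights()))            # 0 -> the Adadelta accumulators are gone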

Are there any fixes for this?

@ismailsimsek

ismailsimsek commented Oct 5, 2018

Running into the same issue; also found a similar issue here: #9589

@ismailsimsek

Solved it by upgrading Keras to 2.2.4 and using pickle. I was having the issue with a pix2pix model.

@emilwallner
Author

@ismailsimsek Do you have a quick example I can test?

Also, I created a related question on StackOverflow: https://stackoverflow.com/questions/52463551/how-do-you-train-gans-using-multiple-gpus-with-keras

@ChosunOne

I'm also struggling to resolve this error. Is there a way to easily save and load a GAN without losing information?

@bradsheppard

I think I managed to finally solve this issue after much frustration and eventually switching to tensorflow.keras. I'll summarize.

keras doesn't seem to respect model.trainable when re-loading a model. So if you have a model with an inner submodel with submodel.trainable = False, when you attempt to reload the model at a later point and call model.summary(), you will notice that all layers are trainable, and you then get that optimizer-state warning when loading the model.

What's interesting is that this isn't the case with tensorflow.keras. In that library, if you set submodel.trainable = False and reload the model later, you'll notice that model.summary() does in fact report quite a large number of non-trainable parameters.

Another thing to keep in mind is that submodel.trainable behaves differently for training versus saving the model. For training, whatever trainable is set to prior to calling model.compile is what is respected. However, when calling model.save(), all that matters is what trainable is set to at the time of saving (it doesn't care what trainable was when you compiled).

So in the context of GANs, one needs to ensure the following:

  1. Dump keras and switch to tensorflow.keras.
  2. Assuming gan is the combined model, and generator and discriminator are the submodels, the models can be constructed as follows:
# Imports for the snippets below (tensorflow.keras, per point 1); INPUT_SIZE is
# a constant you define yourself for the latent dimension.
from tensorflow.keras.models import Sequential, Model, save_model, load_model
from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import Adam


def create_generator():
    generator = Sequential()
    ...

    return generator


def create_discriminator():
    discriminator = Sequential()
    ...
    return discriminator


def create_gan(generator, discriminator):
    discriminator.trainable = False

    gan_input = Input(shape=(INPUT_SIZE,))
    generator_output = generator(gan_input)
    gan_output = discriminator(generator_output)

    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))

    gan.summary()
    discriminator.trainable = True
    return gan

One can carry out saving and loading as follows:

def save(gan, generator, discriminator):
    discriminator.trainable = False
    save_model(gan, 'gan')
    discriminator.trainable = True
    save_model(generator, 'generator')
    save_model(discriminator, 'discriminator')


def load():
    discriminator = load_model('discriminator')
    generator = load_model('generator')
    gan = load_model('gan')
    gan.summary()
    discriminator.summary()
    generator.summary()

    return gan, generator, discriminator

You'll notice that it's necessary to switch the trainable state of the discriminator prior to saving the gan, in order to ensure the discriminator part of it isn't trainable and is therefore properly re-loaded later. It doesn't matter whether or not it was set during compile.

It's also important to note that this example will only work on tensorflow.keras and not keras. Let me know if this helps or if anyone finds any issues.

@AloneTogetherY

@bradsheppard can you confirm that your suggested method works?

@hewletm

hewletm commented Feb 13, 2020

@bradsheppard 's solution worked for me (currently using TensorFlow 2.1.0)

@sinajhrm

@bradsheppard's solution solved the issue of the model needing to be recompiled; however, the discriminator loss (for fake images) and the combined model loss go to around 0. Also, the generated fake images do not get better (maybe because they do not change at all)!

In the first sample image, the gan model (generator, discriminator, and the combined model) is trained for 5 epochs (each epoch has 128 batches of images). In this step, all models were created from scratch. Then, the models were loaded using the solution provided by @bradsheppard, and the training process continued for another 5 epochs. (The X-axis is iterations.)
(Image: losses_plots_test_1)

In the second sample image, all models were created from scratch and trained for 10 epochs continuously without loading them from saved models. (The X-axis is iterations)
(Image: losses_plots_test_2)

My TensorFlow version is 2.5, and I am using Google Colaboratory to perform the mentioned tests.

@LaurinHerbsthofer

Thanks @bradsheppard , your method worked like a charm :)

@jroback

jroback commented Aug 10, 2021

@sinajhrm - I was experiencing the same issue. Try implementing the @bradsheppard method on just your discriminator and generator. Then build a fresh gan from your original code before resuming training (rough sketch below). For my use, training for 10 epochs and then resuming with this method for an additional 10 epochs gave image output and metrics similar to training for 20 from scratch.
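
Roughly, that means something like this (a sketch on my side; it assumes create_gan is the same constructor shown earlier in this thread, and the file names are arbitrary):

from tensorflow.keras.models import load_model

# Load only the submodels; their weights and optimizer states come back with them.
discriminator = load_model('discriminator')
generator = load_model('generator')

# Rebuild the combined model from the loaded submodels, exactly as it was built
# originally. The gan optimizer is recreated from scratch, but the generator and
# discriminator weights inside it stay linked to the loaded submodels.
gan = create_gan(generator, discriminator)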

@sinajhrm

@jroback - I will try your method ASAP. However, what do you mean by "loading a fresh gan from my original code"? Do you mean that I should create a new combined model using the loaded discriminator and generator and then compile it with a freshly defined optimizer? (If I am correct about your method, then this way I will lose the optimizer state of the previous combined model.)

@jroback

jroback commented Aug 11, 2021

@sinajhrm - Yes, try defining a new gan from your loaded generator and discriminator. I should caveat by saying that I'm very new to GANs and data science so there may be something I'm missing. But in evaluating the image output and the metrics, training appears to have resumed.

@mattgrayling

mattgrayling commented Nov 10, 2021

I don't like reopening old issues, but I tried the @bradsheppard method and it doesn't work as expected. When the models are first created, the discriminator weights are shared between the gan and discriminator models, so when the discriminator is updated during training, the gan model uses the latest state of those weights.

However, when loading the models back in, because they are loaded separately, the link between the gan model and the discriminator model is broken: when the discriminator is updated, the non-trainable weights in the discriminator portion of the gan model are not updated and will never be updated after loading (a quick check is sketched below). I've tested this by looking at the weights themselves, and they are indeed the same in each epoch in the gan model even though they should be updated when the discriminator is trained.

I'm not sure exactly what, but another solution is needed for this. At least this is what happened for me, so I'd be interested to know if it worked differently for others on a different version.
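
The check I used, roughly (a sketch; the layer indexing and the x_batch / y_batch placeholders are assumptions for the construction shown earlier in this thread):

import numpy as np

# After loading the three models separately, train the discriminator for one step
# and check whether the discriminator submodel embedded in the gan (assumed here
# to be the last layer of the combined model) sees the update.
d_before = gan.layers[-1].get_weights()[0].copy()
discriminator.train_on_batch(x_batch, y_batch)   # x_batch / y_batch are placeholder data
d_after = gan.layers[-1].get_weights()[0]
print(np.allclose(d_before, d_after))            # True -> the shared-weight link is broken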

@tstrych

tstrych commented Dec 15, 2021

@mattgrayling did you solve the problem? I also used the mentioned approach and found that my pictures don't change at all, which means no training is happening.

@sinajhrm

@tstrych - try using the Checkpoint and CheckpointManager classes. These TensorFlow built-in classes solved my issue with saving GAN models and continuing the training process (minimal sketch below). You can also define a custom callback to save your models at a desired frequency! Note that when I say "these classes solved my issue," I am saying this based on my observation of the loss values and the generated fake images; I did not monitor the generator or discriminator weights.
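
For reference, a minimal sketch of the idea (the generator / discriminator / optimizer variable names and the checkpoint directory are placeholders):

import tensorflow as tf

# Track everything whose state should survive a restart: both models and both
# optimizers (the optimizer slot variables are what usually get lost).
ckpt = tf.train.Checkpoint(generator=generator,
                           discriminator=discriminator,
                           g_optimizer=g_optimizer,
                           d_optimizer=d_optimizer)
manager = tf.train.CheckpointManager(ckpt, './gan_ckpts', max_to_keep=3)

# Restore the latest checkpoint if one exists (restore(None) is a no-op).
ckpt.restore(manager.latest_checkpoint)

# ... inside the training loop, e.g. once per epoch ...
manager.save()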

@bhaveshneekhra

@mattgrayling is there a way to solve this issue? We are also facing a similar issue.

@mattgrayling

I would echo @sinajhrm in suggesting the Checkpoint and CheckpointManager TensorFlow classes. I was able to use these to get the model to save and resume training with the links between the weights preserved. (This led to another issue I still haven't been able to solve: with this approach the model training does not seem to work very well in general for me, but I think that is a separate issue with my setup, as other people have been able to use this successfully.)

@nahidalam

Agree with @mattgrayling and @sinajhrm on using Checkpoint.

I did the following:

from tensorflow.keras import callbacks

checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_freq='epoch',
    period=10,  # save every 10 epochs ('period' is a deprecated alias, kept from my original code)
)

history = gan_model.fit(..., epochs=EPOCHS,
                        verbose=2,
                        callbacks=[model_checkpoint_callback]).history

Then, to load a specific checkpoint and use the model:

weight_file = "model_checkpoints/cyclegan_checkpoints.023"
gan_model.load_weights(weight_file).expect_partial()
print("Weights loaded successfully")

@bhaveshneekhra

I know one option is to use the Checkpoint and CheckpointManager TensorFlow classes to save and resume training the GAN.

I am looking for a way to save the Generator (G), the Discriminator (D), and their optimizer states such that I can also manipulate the model weights (for G and D). One use case for me is averaging the weights of different Generators in federated learning for GANs.

Does someone know a way?
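
For the weight-manipulation part specifically, what I have in mind is something like this (a rough sketch; the generator variables are placeholders, and carrying the optimizer state across is the part I'm still missing):

import numpy as np

def average_generators(generators):
    # Element-wise average of the weight tensors of several generators with
    # identical architectures (a naive FedAvg-style aggregation).
    weight_lists = [g.get_weights() for g in generators]
    return [np.mean(np.stack(ws, axis=0), axis=0) for ws in zip(*weight_lists)]

# averaged = average_generators([gen_client_1, gen_client_2, gen_client_3])
# global_generator.set_weights(averaged)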

@JanOlucha

> @sinajhrm - I was experiencing the same issue. Try implementing the @bradsheppard method on just your discriminator and generator. Then load a fresh gan from your original code before resuming training. For my use, training for 10 epochs, then resuming using this method for an additional 10 epochs provided similar image output results and metrics to training for 20 from scratch.

This worked perfectly for me!

@anis-hentit

anis-hentit commented Jan 25, 2024

@sinajhrm Did you fix the problem of the combined gan model loss going to zero when resuming training?
