
Set "d_container.trainable = True" after training all_model #13

Closed
bis-carbon opened this issue Mar 28, 2019 · 4 comments
Closed

Set "d_container.trainable = True" after training all_model #13

bis-carbon opened this issue Mar 28, 2019 · 4 comments

Comments

bis-carbon commented Mar 28, 2019

Thank you for the great work. Don't you think `d_container.trainable` should be set back to `True` after training `all_model`?
Something like this:

```python
d_container.trainable = True
all_model.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[1.0, alpha], optimizer=optimizer)
```
neka-nat (Owner) commented

Thank you for your opinion.
I think keeping trainable set to False here is in fact correct.

I've attached the pseudocode from the paper.
all_model is used at line 10, where the discriminator is fixed and only the completion network is trained.
The discriminator update is done at line 8, and that uses d_model in my code.
[Attached image: the paper's training algorithm pseudocode]
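
For concreteness, here is a minimal sketch of the setup being described, assembled from the snippets quoted elsewhere in this thread rather than copied verbatim from the repo: d_container is frozen before all_model is compiled, so all_model only updates the completion network, while the discriminator itself is updated through the separately compiled d_model.

```python
# Sketch based on snippets from this thread (not the repo's exact code):
# freeze the discriminator container, then build and compile the combined model.
# Because the flag is False at compile time, all_model only trains the
# completion network; the discriminator is trained separately via d_model.
d_container.trainable = False
all_model = Model([org_img, mask, in_pts],
                  [cmp_out, d_container([cmp_out, in_pts])])
all_model.compile(loss=['mse', 'binary_crossentropy'],
                  loss_weights=[1.0, alpha],
                  optimizer=optimizer)
```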

bis-carbon (Author) commented Mar 28, 2019

I agree that d_container.trainable = False, but once you make the discriminator non-trainable, you won't be able to train it on the following batches. The discriminator has to keep being trained as long as t > Tc, and to do that I think we need to set d_container.trainable = True again after the completion network is trained. Correct me if I am wrong, and thank you for your quick response.

The algorithm in the paper is something like this:

```python
if n < tc:
    ''' Train completion network '''
elif n < tc + td:
    ''' Train discriminator network '''
else:
    ''' Train both completion and discriminator '''
```

What I am suggesting is something like this:

```python
if n >= tc + td:
    d_container.trainable = False
    all_model = Model([org_img, mask, in_pts], [cmp_out, d_container([cmp_out, in_pts])])
    all_model.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[1.0, alpha], optimizer=optimizer)
    g_loss = all_model.train_on_batch([inputs, masks, points], [inputs, valid])
    g_loss = g_loss[0] + alpha * g_loss[1]

    # the following lines make the discriminator trainable again for the next batch
    d_container.trainable = True
    all_model.compile(loss=['mse', 'binary_crossentropy'], loss_weights=[1.0, alpha], optimizer=optimizer)
```

neka-nat (Owner) commented

The following link may be helpful for your point:
https://stackoverflow.com/questions/45154180/how-to-dynamically-freeze-weights-after-compiling-model-in-keras

The trainable flag is fixed into the model at compile time, so changes to the flag after compiling will not affect the already-compiled model.
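
As a concrete illustration of that point, here is a minimal, self-contained sketch (generic layer names, standalone Keras 2.x as used at the time, not code from this repo) showing that flipping trainable after compiling the combined model has no effect until a recompile:

```python
# Minimal sketch of Keras' compile-time freezing (illustrative, not from this repo).
import numpy as np
from keras.models import Sequential, Model
from keras.layers import Dense, Input

generator = Sequential([Dense(4, input_shape=(4,))])
discriminator = Sequential([Dense(1, activation='sigmoid', input_shape=(4,))])
discriminator.compile(loss='binary_crossentropy', optimizer='adam')  # trainable on its own

discriminator.trainable = False                  # freeze BEFORE compiling the combined model
z = Input(shape=(4,))
combined = Model(z, discriminator(generator(z)))
combined.compile(loss='binary_crossentropy', optimizer='adam')  # frozen state captured here

discriminator.trainable = True                   # flipping the flag afterwards...
x, y = np.random.rand(8, 4), np.ones((8, 1))
w_before = discriminator.get_weights()[0].copy()
combined.train_on_batch(x, y)                    # ...does not unfreeze it inside `combined`
assert np.allclose(w_before, discriminator.get_weights()[0])  # discriminator unchanged
```

This is why setting d_container.trainable = True after the all_model.compile call would have no effect without recompiling, and why keeping two separately compiled models (d_model and all_model) already gives the behaviour the paper's algorithm asks for.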

bis-carbon (Author) commented

Thank you, that clarifies my question.
