Transfer learning & Fine-tuning #17
Now my code is:

Building:

```python
class Noah_transfer(BayesianCNNBase):
    ...
```

Training:

```python
noah_transfer = Noah_transfer()
```

Both `model` and `model_prediction` can be printed by `summary()`, but it raises an error during training. It seems that the labels haven't been taken into the training: `model_2` (which is `model` in this code) received only one input (which seems to be `x_train`).
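For illustration, here is a minimal plain-Keras reconstruction of this failure mode (all shapes are hypothetical, and the two-input structure is an assumption about how `BayesianCNNBase` training models take label errors as a second input): a model built with two inputs passes `summary()` but fails in `fit()` when only one input is fed.

```python
import numpy as np
from tensorflow import keras

x_in = keras.Input(shape=(100,), name="input")
err_in = keras.Input(shape=(3,), name="labels_err")   # hypothetical second input
hidden = keras.layers.Dense(16, activation="relu")(x_in)
pred = keras.layers.Dense(3)(hidden)
# keep the second input connected to the graph, as a loss-related branch would be
out = keras.layers.Add()([pred, keras.layers.Dense(3, use_bias=False)(err_in)])
model = keras.Model(inputs=[x_in, err_in], outputs=out)
model.summary()                       # summary() works fine

model.compile(optimizer="adam", loss="mse")
x_train = np.random.rand(8, 100).astype("float32")
y_train = np.random.rand(8, 3).astype("float32")
model.fit(x_train, y_train)           # ValueError: model expects 2 inputs, got 1
```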
Sorry for the late reply. I have added a function, `transfer_weights()`, as a first step to solve your issue. Here is an example:

```python
from astroNN.models import ApogeeBCNN

# a model trained on the original survey
bneuralnet = ApogeeBCNN()
bneuralnet.fit(xdata, ydata)
# another astroNN model
bneuralnet2 = ApogeeBCNN()
# just to initialize the model with the correct input and output shape
bneuralnet2.max_epochs = 1
bneuralnet2.fit(xdata_another_survey, ydata_another_survey)
# transfer all the weights except layers with incompatible shape
bneuralnet2.transfer_weights(bneuralnet)
# training for real, the middle part of the model is not trainable
bneuralnet2.max_epochs = 60
bneuralnet2.fit(xdata_another_survey, ydata_another_survey)
# now bneuralnet2 is your new astroNN model, transferred to another survey with the same architecture as the original
```
Thank you for your reply. The two of us seem to have different ideas: your way is to transfer the weights of the base model, while mine is to transfer the whole base model. The `transfer_weights()` function is a clever and effective way to do transfer learning, and it should be enough for me for now.
There are still some bugs. When the output layer of my transferred model and the base model have the same number of nodes, the summary says that all of my parameters are non-trainable, but the weights of the transferred model's output layer should be trainable.
Yes, it does seem that supposedly non-trainable parameters still get trained somehow. I am still investigating what is going on, but most likely I need to set them to be non-trainable before compiling the model.

As for the output layer, the current strategy is to transfer all weights with compatible shape (i.e. if the shapes of a layer's weights are the same in both models, transfer those weights). I think what you want is to train only the input layer? Or you can force a different output shape so that the output layers won't get transferred (e.g. train on T_eff and Log(g) for one survey and fe_h for another survey, so the output shapes are different). There could also be a case where you have a small overlap between two surveys; then you could use the spectra from survey B but train only the input layer, with labels from the original survey A.

Regarding your questions from a few days ago, what do you mean by the training step going wrong? And yes, splicing/adding layers probably requires more work, but it's not undoable per se; we need to make the simplest case work correctly first...
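For concreteness, here is a minimal sketch of that shape-compatibility rule applied to plain Keras models (an illustration only, not astroNN's actual `transfer_weights()` implementation; the function name and `freeze` flag are hypothetical):

```python
def transfer_compatible_weights(source_model, target_model, freeze=True):
    """Copy weights layer-by-layer when shapes match; optionally freeze them."""
    for src_layer, tgt_layer in zip(source_model.layers, target_model.layers):
        src_w = src_layer.get_weights()
        tgt_w = tgt_layer.get_weights()
        # only copy when every weight array in the layer has a matching shape
        if len(src_w) == len(tgt_w) and all(
                s.shape == t.shape for s, t in zip(src_w, tgt_w)):
            tgt_layer.set_weights(src_w)
            if freeze:
                # must happen before compile(), otherwise the flag is ignored
                tgt_layer.trainable = False
```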
Thank you for your patience and your reply. The training step failure a few days ago happened because of model splicing, but as you said, we should make the simplest case work first, so let's talk about that later.
I think I have fixed the issue of weights still being trained even after being set to non-trainable.
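For reference, the Keras behavior behind this (a general Keras fact, not specific to astroNN) is that the `trainable` flag is captured when the model is compiled, so freezing has to happen before `compile()`:

```python
from tensorflow import keras

inputs = keras.Input(shape=(8,))
hidden = keras.layers.Dense(4, name="frozen_dense")(inputs)
outputs = keras.layers.Dense(1)(hidden)
model = keras.Model(inputs, outputs)

model.get_layer("frozen_dense").trainable = False  # freeze BEFORE compiling
model.compile(optimizer="adam", loss="mse")
model.summary()  # frozen_dense's parameters now count as non-trainable
# Flipping trainable after compile() has no effect until compile() is called again.
```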
Thank you for all the effort; it works now.
Hi, Henry. I've got a well-trained astroNN model, but I want to do some transfer learning to make it adaptable to another survey. What I've done is remove the top dense layer of the base model and build a new dense layer (see the sketch below), but now it can only be treated like an ordinary Keras model. By the way, the base model itself is a custom model under the parent class `BayesianCNNBase`.
I'm wondering:
Thank you!
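Here is a rough plain-Keras sketch of the head-swap approach described in this post (all shapes and names are hypothetical stand-ins; in practice `base` would be the trained astroNN model's underlying Keras model):

```python
from tensorflow import keras

# stand-in for the trained base model
inputs = keras.Input(shape=(100,))
hidden = keras.layers.Dense(32, activation="relu")(inputs)
old_head = keras.layers.Dense(3, name="old_output")(hidden)
base = keras.Model(inputs, old_head)

# chop off the top dense layer and attach a fresh head for the new survey
features = base.layers[-2].output
new_head = keras.layers.Dense(5, name="new_output")(features)  # 5 = hypothetical new label count
transfer_model = keras.Model(inputs=base.input, outputs=new_head)
# transfer_model is now an ordinary Keras model: BayesianCNNBase's training and
# prediction machinery no longer applies to it, which is the limitation described above.
```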