fit(DataSetIterator) without pretrain #4279

Merged
merged 2 commits into eclipse:master on Nov 13, 2017

Conversation

@clavvis

commented Nov 10, 2017

What changes were proposed in this pull request?

The MultiLayerNetwork method fit(DataSetIterator) currently performs both pretraining and backpropagation within a single epoch. This is not suitable.

Training a network usually requires several epochs of pretraining only, followed by several epochs of backpropagation only.

I suggest:

  • do pretraining for a DataSetIterator in the method pretrain(DataSetIterator)
  • do only backpropagation in the method fit(DataSetIterator)

In the pretrain(DataSetIterator) method I also added an onEpochStart(this) call for the registered TrainingListeners.
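To illustrate the difference, here is a minimal sketch using a toy stand-in class. `ToyNetwork`, `fitOld`, and `PretrainThenFitSketch` are hypothetical names invented for this example and are not part of DL4J; the sketch only records the order in which the two training phases would run.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a network; NOT the DL4J MultiLayerNetwork class.
class ToyNetwork {
    final List<String> log = new ArrayList<>();

    // One epoch of layerwise pretraining.
    void pretrain() { log.add("pretrain"); }

    // Behaviour proposed in this PR: fit does backpropagation only.
    void fit() { log.add("backprop"); }

    // Behaviour before this PR: one call does pretraining and backpropagation.
    void fitOld() { log.add("pretrain"); log.add("backprop"); }
}

class PretrainThenFitSketch {
    // Old behaviour inside an external epoch loop: the phases interleave.
    static List<String> oldStyle(int epochs) {
        ToyNetwork net = new ToyNetwork();
        for (int i = 0; i < epochs; i++) net.fitOld();
        return net.log;
    }

    // Proposed behaviour: all pretraining epochs first, then backprop epochs.
    static List<String> newStyle(int pretrainEpochs, int backpropEpochs) {
        ToyNetwork net = new ToyNetwork();
        for (int i = 0; i < pretrainEpochs; i++) net.pretrain();
        for (int i = 0; i < backpropEpochs; i++) net.fit();
        return net.log;
    }

    public static void main(String[] args) {
        System.out.println(oldStyle(2));    // [pretrain, backprop, pretrain, backprop]
        System.out.println(newStyle(2, 2)); // [pretrain, pretrain, backprop, backprop]
    }
}
```

With the old behaviour, a caller looping over epochs gets pretraining and backpropagation interleaved; with the proposed split, the caller controls how many epochs each phase gets.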

How was this patch tested?

I checked the code manually.

@agibsonccc

left a comment

@AlexDBlack if I remember right, this makes sense. Just making sure on this one.

@huitseeker requested review from AlexDBlack and eraly on Nov 11, 2017

@AlexDBlack
Contributor

left a comment

The change itself looks fine. I agree that doing both pretrain and backprop on the one fit(DataSetIterator) call can be confusing for users (i.e., they won't at all get what they want if they do a loop externally for multiple epochs).

Even better would be to add a javadoc comment on the behaviour for the fit methods (with an {@link ...} to pretrain and pretrainLayer methods, basically saying "this method doesn't do layerwise pretraining, use methods X and Y instead")
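A rough sketch of what such a Javadoc note might look like is below. This is an assumption about the wording, not the text that was actually merged, and `NetworkSketch` and the empty `DataSetIterator` interface are hypothetical stand-ins so that the example compiles on its own.

```java
// Hypothetical stand-in type so the sketch compiles without DL4J.
interface DataSetIterator {}

// Hypothetical stand-in class; NOT the real MultiLayerNetwork.
class NetworkSketch {
    int pretrainCalls = 0;
    int fitCalls = 0;

    /**
     * Perform layerwise pretraining for one pass over the iterator.
     */
    public void pretrain(DataSetIterator iter) { pretrainCalls++; }

    /**
     * Perform layerwise pretraining on a single layer.
     */
    public void pretrainLayer(int layerIdx, DataSetIterator iter) { pretrainCalls++; }

    /**
     * Fit the network with backpropagation for one pass over the iterator.
     * <p>
     * Note: this method does not do layerwise pretraining. For pretraining, use
     * {@link #pretrain(DataSetIterator)} or
     * {@link #pretrainLayer(int, DataSetIterator)} instead.
     */
    public void fit(DataSetIterator iter) { fitCalls++; }
}
```

The `{@link ...}` tags render as clickable references in the generated documentation, which points users from fit directly to the pretraining methods.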

ComputationGraph needs to be similarly changed:
https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java#L893-L895

https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java#L1004-L1006

@clavvis

Author

commented Nov 13, 2017

I changed ComputationGraph in the same way and added Javadoc.

@AlexDBlack merged commit 8dc6371 into eclipse:master on Nov 13, 2017

3 participants