New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fit(DataSetIterator) without pretrain #4279

Merged
merged 2 commits into from Nov 13, 2017

Conversation

Projects
None yet
3 participants
@clavvis
Copy link
Contributor

clavvis commented Nov 10, 2017

What changes were proposed in this pull request?

MultiLayerNetwork method fit( DataSetIterator) do in one epoch pretraining and backpropagation too. It is not suitable.

Training NN needs some epochs only pretraining and later some epochs only backpropagation.

I suggest:

  • do pretraning for DataSetIterator in method pretrain(DataSetIterator)
  • in method fit(DataSetIterator) do only backpropagation

In pretrain(DataSetIterator) method I added for TrainingListener's set onEpochStart(this).

How was this patch tested?

I checked manually code.

@agibsonccc
Copy link
Member

agibsonccc left a comment

@AlexDBlack if I remember right this makes sense..making sure on this one.

@huitseeker huitseeker requested review from AlexDBlack and eraly Nov 11, 2017

@AlexDBlack
Copy link
Member

AlexDBlack left a comment

The change itself looks fine. I agree that doing both pretrain and backprop on the one fit(DataSetIterator) call can be confusing for users (i.e., they won't at all get what they want if they do a loop externally for multiple epochs).

Even better would be to add a javadoc comment on the behaviour for the fit methods (with an {@link ...} to pretrain and pretrainLayer methods, basically saying "this method doesn't do layerwise pretraining, use methods X and Y instead")

ComputationGraph needs to be similarly changed:
https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java#L893-L895

https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java#L1004-L1006

@clavvis

This comment has been minimized.

Copy link
Contributor

clavvis commented Nov 13, 2017

I changed ComputationGraph and added java doc.

@AlexDBlack AlexDBlack merged commit 8dc6371 into deeplearning4j:master Nov 13, 2017

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment