-
Notifications
You must be signed in to change notification settings - Fork 5
/
EarlyStoppingGraphFeaturizedTrainer.java
35 lines (29 loc) · 1.47 KB
/
EarlyStoppingGraphFeaturizedTrainer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
package org.genericsystem.cv.nn;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
 * Early-stopping trainer for transfer learning on a {@link ComputationGraph}, where only the unfrozen
 * part of the network is trained.
 *
 * <p>The base {@link EarlyStoppingGraphTrainer} is handed {@code transferLearningHelper.unfrozenGraph()}
 * as its network, and every training step is delegated to
 * {@code TransferLearningHelper.fitFeaturized(...)} instead of fitting the network directly.
 * NOTE(review): because of that delegation, the training iterator is presumably expected to yield data
 * already passed through the helper's frozen layers (e.g. via {@code TransferLearningHelper.featurize});
 * confirm against callers.
 */
public class EarlyStoppingGraphFeaturizedTrainer extends EarlyStoppingGraphTrainer {

	/** Helper owning the frozen/unfrozen split; all fitting is routed through it. */
	private final TransferLearningHelper transferLearningHelper;

	/**
	 * Creates a trainer over a {@link DataSetIterator}. No early-stopping listener is supplied
	 * ({@code null} is passed to the base trainer).
	 *
	 * @param esConfig               early-stopping configuration handed to the base trainer
	 * @param transferLearningHelper helper whose unfrozen graph is trained
	 * @param train                  training data iterator (presumably featurized data; see class note)
	 * @throws NullPointerException if {@code transferLearningHelper} is null
	 */
	public EarlyStoppingGraphFeaturizedTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, TransferLearningHelper transferLearningHelper,
			DataSetIterator train) {
		super(esConfig, transferLearningHelper.unfrozenGraph(), train, null);
		this.transferLearningHelper = transferLearningHelper;
	}

	/**
	 * Creates a trainer over a {@link MultiDataSetIterator}. No early-stopping listener is supplied
	 * ({@code null} is passed to the base trainer).
	 *
	 * @param esConfig               early-stopping configuration handed to the base trainer
	 * @param transferLearningHelper helper whose unfrozen graph is trained
	 * @param train                  training data iterator (presumably featurized data; see class note)
	 * @throws NullPointerException if {@code transferLearningHelper} is null
	 */
	public EarlyStoppingGraphFeaturizedTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, TransferLearningHelper transferLearningHelper,
			MultiDataSetIterator train) {
		super(esConfig, transferLearningHelper.unfrozenGraph(), train, null);
		this.transferLearningHelper = transferLearningHelper;
	}

	/** Fits one {@link DataSet} minibatch by delegating to the helper's featurized fit. */
	@Override
	protected void fit(DataSet ds) {
		transferLearningHelper.fitFeaturized(ds);
	}

	/**
	 * Fits one {@link MultiDataSet} minibatch by delegating to the helper's featurized fit, after
	 * converting it to the concrete ND4J type that {@code fitFeaturized} takes.
	 */
	@Override
	protected void fit(MultiDataSet mds) {
		transferLearningHelper.fitFeaturized(toConcreteMultiDataSet(mds));
	}

	/**
	 * Returns {@code mds} as a concrete {@code org.nd4j.linalg.dataset.MultiDataSet}.
	 *
	 * <p>Concrete instances (and {@code null}) are passed through unchanged, as the original plain cast
	 * did. Any other implementation of the interface is rebuilt from its feature, label and mask arrays,
	 * where a blind cast would throw {@link ClassCastException}.
	 *
	 * @param mds minibatch supplied by the base trainer
	 * @return the same minibatch typed as the concrete ND4J class
	 */
	private static org.nd4j.linalg.dataset.MultiDataSet toConcreteMultiDataSet(MultiDataSet mds) {
		if (mds == null || mds instanceof org.nd4j.linalg.dataset.MultiDataSet) {
			return (org.nd4j.linalg.dataset.MultiDataSet) mds;
		}
		return new org.nd4j.linalg.dataset.MultiDataSet(mds.getFeatures(), mds.getLabels(), mds.getFeaturesMaskArrays(),
				mds.getLabelsMaskArrays());
	}
}