-
Notifications
You must be signed in to change notification settings - Fork 4
/
SPGradientGenerator.java
58 lines (43 loc) · 1.48 KB
/
SPGradientGenerator.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/*
* Copyright (c) 2020 Georgios Damaskinos
* All rights reserved.
* @author Georgios Damaskinos <georgios.damaskinos@gmail.com>
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
package apps;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import apps.cppNN.CppNNGradientGenerator;
import apps.lr.LRGradientGenerator;
import coreComponents.GradientGenerator;
/**
 * Application-level {@link GradientGenerator}: a thin adapter that forwards every call
 * to one concrete backend instance held in {@code gen}.
 *
 * <p>The backend is currently {@code CppNNGradientGenerator}.
 * TODO: modify with the Application's SPGradientGenerator.
 * To switch backends, change both the declared type and the initializer of {@code gen}.
 * Candidate backends referenced by this file:
 * <ul>
 *   <li>{@code SimpleSimpleCNNGradientGenerator}</li>
 *   <li>{@code Dl4jGradientGenerator}</li>
 *   <li>{@code MLPGradientGenerator}</li>
 *   <li>{@code LRGradientGenerator}</li>
 * </ul>
 */
public class SPGradientGenerator implements GradientGenerator {

    /** Backend that receives every forwarded call. Package-private and reassignable. */
    CppNNGradientGenerator gen = new CppNNGradientGenerator();

    /** Returns the backend currently stored in {@code gen}; the field is read on each call. */
    private CppNNGradientGenerator backend() {
        return this.gen;
    }

    /**
     * Passes {@code input} to the backend's {@code fetch}.
     * NOTE(review): presumably this reads the model and/or mini-batch from the Kryo stream;
     * the exact semantics live in the backend and are not visible here.
     *
     * @param input Kryo input stream handed through unchanged
     */
    public void fetch(Input input) {
        backend().fetch(input);
    }

    /**
     * Passes {@code output} to the backend's {@code computeGradient}.
     * NOTE(review): presumably the backend serializes its gradient into this stream;
     * verify against the backend implementation.
     *
     * @param output Kryo output stream handed through unchanged
     */
    public void computeGradient(Output output) {
        backend().computeGradient(output);
    }

    /**
     * @return whatever the backend's {@code getSize()} reports
     */
    @Override
    public int getSize() {
        return backend().getSize();
    }

    /**
     * @return the backend's mini-batch fetch time (units defined by the backend, not shown here)
     */
    @Override
    public double getFetchMiniBatchTime() {
        return backend().getFetchMiniBatchTime();
    }

    /**
     * @return the backend's model fetch time (units defined by the backend, not shown here)
     */
    @Override
    public double getFetchModelTime() {
        return backend().getFetchModelTime();
    }

    /**
     * @return the backend's gradient-computation time (units defined by the backend, not shown here)
     */
    @Override
    public double getComputeGradientsTime() {
        return backend().getComputeGradientsTime();
    }
}