Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
lpapailiou committed Jan 28, 2021
1 parent 4f381ed commit 31961df
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ Add following snippets to your ``pom.xml`` file to import the library:
<dependency>
<groupId>neuralnetwork</groupId>
<artifactId>neural-network-repo</artifactId>
<version>2.3</version>
<version>2.4</version>
</dependency>
</dependencies>

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>neuralnetwork</groupId>
<artifactId>neural-network-repo</artifactId>
<version>2.3</version>
<version>2.4</version>
<build>
<plugins>
<plugin>
Expand Down
17 changes: 11 additions & 6 deletions src/main/java/neuralnet/NeuralNetwork.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class NeuralNetwork implements Serializable {
private double initialLearningRate;
private double learningRate;
private double momentum;
private int iteration_count;
private int iterationCount;

static {
URL path = NeuralNetwork.class.getClassLoader().getResource("neuralnetwork.properties");
Expand Down Expand Up @@ -150,7 +150,7 @@ public List<Double> learn(double[] inputNodes, double[] expectedOutputNodes) {
Matrix target = Matrix.fromArray(rectifier, expectedOutputNodes);

// backward propagate to adjust weights in layers
iteration_count++;
iterationCount++;
Matrix error = null;
for (int i = steps.size()-1; i >= 0; i--) {
if (error == null) {
Expand Down Expand Up @@ -219,7 +219,12 @@ public NeuralNetwork mutate() {
*/
public NeuralNetwork copy() {
NeuralNetwork neuralNetwork = new NeuralNetwork(inputLayerNodes, layers);
neuralNetwork.setRectifier(this.rectifier).setLearningRateDescent(this.learningRateDescent).setLearningRate(this.learningRate).setMomentum(this.momentum);
neuralNetwork.rectifier = this.rectifier;
neuralNetwork.learningRateDescent = this.learningRateDescent;
neuralNetwork.momentum = this.momentum;
neuralNetwork.initialLearningRate = this.initialLearningRate;
neuralNetwork.learningRate = this.learningRate;
neuralNetwork.iterationCount = this.iterationCount;
return neuralNetwork;
}

Expand Down Expand Up @@ -295,8 +300,8 @@ public double getLearningRate() {
* Decreases the current learning rate according to the chosen LearningRateDescent function.
*/
public void decreaseLearningRate() {
this.learningRate = learningRateDescent.decrease(initialLearningRate, momentum, iteration_count);
iteration_count++;
this.learningRate = learningRateDescent.decrease(initialLearningRate, momentum, iterationCount);
iterationCount++;
}

/**
Expand Down Expand Up @@ -389,7 +394,7 @@ public String toString() {

@Override
public int hashCode() {
return Integer.parseInt(learningRate + "" + iteration_count + layers.size());
return Integer.parseInt(learningRate + "" + iterationCount + layers.size());
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/util/LearningRateDecreaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ void descentTest(LearningRateDescent lrd) {
System.out.println("testing with: " + lrd.getDescription());
int iter = 250;
for (int i = 0; i < iter; i++) {
lr = lrd.decrease(lrinit, 0.005, i);
lr = lrd.decrease(lr, 0.005, i);
}
System.out.println("learning rate " + lr + " after " + iter + " iterations.\n");
}
Expand Down

0 comments on commit 31961df

Please sign in to comment.