-
Notifications
You must be signed in to change notification settings - Fork 14
/
NeuralNetwork.java
329 lines (262 loc) · 12.1 KB
/
NeuralNetwork.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
package basicneuralnetwork;
import basicneuralnetwork.activationfunctions.*;
import basicneuralnetwork.utilities.FileReaderAndWriter;
import basicneuralnetwork.utilities.MatrixUtilities;
import org.ejml.simple.SimpleMatrix;
import java.util.Arrays;
import java.util.Random;
/**
* Created by KimFeichtinger on 04.03.18.
*/
public class NeuralNetwork {
private ActivationFunctionFactory activationFunctionFactory = new ActivationFunctionFactory();
private Random random = new Random();
// Dimensions of the neural network
private int inputNodes;
private int hiddenLayers;
private int hiddenNodes;
private int outputNodes;
private SimpleMatrix[] weights;
private SimpleMatrix[] biases;
private double learningRate;
private String activationFunctionKey;
// Constructor
// Generate a new neural network with 1 hidden layer with the given amount of nodes in the individual layers
public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes) {
this(inputNodes, 1, hiddenNodes, outputNodes);
}
// Constructor
// Generate a new neural network with a given amount of hidden layers with the given amount of nodes in the individual layers
// Every hidden layer will have the same amount of nodes
public NeuralNetwork(int inputNodes, int hiddenLayers, int hiddenNodes, int outputNodes) {
this.inputNodes = inputNodes;
this.hiddenLayers = hiddenLayers;
this.hiddenNodes = hiddenNodes;
this.outputNodes = outputNodes;
initializeDefaultValues();
initializeWeights();
initializeBiases();
}
// Copy constructor
public NeuralNetwork(NeuralNetwork nn) {
this.inputNodes = nn.inputNodes;
this.hiddenLayers = nn.hiddenLayers;
this.hiddenNodes = nn.hiddenNodes;
this.outputNodes = nn.outputNodes;
this.weights = new SimpleMatrix[hiddenLayers + 1];
this.biases = new SimpleMatrix[hiddenLayers + 1];
for (int i = 0; i < nn.weights.length; i++) {
this.weights[i] = nn.weights[i].copy();
}
for (int i = 0; i < nn.biases.length; i++) {
this.biases[i] = nn.biases[i].copy();
}
this.learningRate = nn.learningRate;
this.activationFunctionKey = nn.activationFunctionKey;
}
private void initializeDefaultValues() {
this.setLearningRate(0.1);
// Sigmoid is the default ActivationFunction
this.setActivationFunction(ActivationFunction.SIGMOID);
}
private void initializeWeights() {
weights = new SimpleMatrix[hiddenLayers + 1];
// Initialize the weights between the layers and fill them with random values
for (int i = 0; i < weights.length; i++) {
if (i == 0) { // 1st weights that connects inputs to first hidden layer
weights[i] = SimpleMatrix.random64(hiddenNodes, inputNodes, -1, 1, random);
} else if (i == weights.length - 1) { // last weights that connect last hidden layer to output
weights[i] = SimpleMatrix.random64(outputNodes, hiddenNodes, -1, 1, random);
} else { // everything else
weights[i] = SimpleMatrix.random64(hiddenNodes, hiddenNodes, -1, 1, random);
}
}
}
private void initializeBiases() {
biases = new SimpleMatrix[hiddenLayers + 1];
// Initialize the biases and fill them with random values
for (int i = 0; i < biases.length; i++) {
if (i == biases.length - 1) { // bias for last layer (output layer)
biases[i] = SimpleMatrix.random64(outputNodes, 1, -1, 1, random);
} else {
biases[i] = SimpleMatrix.random64(hiddenNodes, 1, -1, 1, random);
}
}
}
// Guess method, input is a one column matrix with the input values
public double[] guess(double[] input) {
if (input.length != inputNodes){
throw new WrongDimensionException(input.length, inputNodes, "Input");
} else {
// Get ActivationFunction-object from the map by key
ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey);
// Transform array to matrix
SimpleMatrix output = MatrixUtilities.arrayToMatrix(input);
for (int i = 0; i < hiddenLayers + 1; i++) {
output = calculateLayer(weights[i], biases[i], output, activationFunction);
}
return MatrixUtilities.getColumnFromMatrixAsArray(output, 0);
}
}
public void train(double[] inputArray, double[] targetArray) {
if (inputArray.length != inputNodes) {
throw new WrongDimensionException(inputArray.length, inputNodes, "Input");
} else if (targetArray.length != outputNodes) {
throw new WrongDimensionException(targetArray.length, outputNodes, "Output");
} else {
// Get ActivationFunction-object from the map by key
ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey);
// Transform 2D array to matrix
SimpleMatrix input = MatrixUtilities.arrayToMatrix(inputArray);
SimpleMatrix target = MatrixUtilities.arrayToMatrix(targetArray);
// Calculate the values of every single layer
SimpleMatrix layers[] = new SimpleMatrix[hiddenLayers + 2];
layers[0] = input;
for (int j = 1; j < hiddenLayers + 2; j++) {
layers[j] = calculateLayer(weights[j - 1], biases[j - 1], input, activationFunction);
input = layers[j];
}
for (int n = hiddenLayers + 1; n > 0; n--) {
// Calculate error
SimpleMatrix errors = target.minus(layers[n]);
// Calculate gradient
SimpleMatrix gradients = calculateGradient(layers[n], errors, activationFunction);
// Calculate delta
SimpleMatrix deltas = calculateDeltas(gradients, layers[n - 1]);
// Apply gradient to bias
biases[n - 1] = biases[n - 1].plus(gradients);
// Apply delta to weights
weights[n - 1] = weights[n - 1].plus(deltas);
// Calculate and set target for previous (next) layer
SimpleMatrix previousError = weights[n - 1].transpose().mult(errors);
target = previousError.plus(layers[n - 1]);
}
}
}
// Generates an exact copy of a NeuralNetwork
public NeuralNetwork copy(){
return new NeuralNetwork(this);
}
// Merges the weights and biases of two NeuralNetworks and returns a new object
// Merge-ratio: 50:50 (half of the values will be from nn1 and other half from nn2)
public NeuralNetwork merge(NeuralNetwork nn){
return this.merge(nn, 0.5);
}
// Merges the weights and biases of two NeuralNetworks and returns a new object
// Everything besides the weights and biases will be the same
// of the object on which this method is called (Learning Rate, activation function, etc.)
// Merge-ratio: defined by probability
public NeuralNetwork merge(NeuralNetwork nn, double probability){
// Check whether the nns have the same dimensions
if(!Arrays.equals(this.getDimensions(), nn.getDimensions())){
throw new WrongDimensionException(this.getDimensions(), nn.getDimensions());
}else{
NeuralNetwork result = this.copy();
for (int i = 0; i < result.weights.length; i++) {
result.weights[i] = MatrixUtilities.mergeMatrices(this.weights[i], nn.weights[i], probability);
}
for (int i = 0; i < result.biases.length; i++) {
result.biases[i] = MatrixUtilities.mergeMatrices(this.biases[i], nn.biases[i], probability);
}
return result;
}
}
// Gaussian mutation with given probability, Slightly modifies values (weights + biases) with given probability
// Probability: number between 0 and 1
// Depending on probability more/ less values will be mutated (e.g. prob = 1.0: all the values will be mutated)
public void mutate(double probability) {
applyMutation(weights, probability);
applyMutation(biases, probability);
}
// Adds a randomly generated gaussian number to each element of a Matrix in an array of matrices
// Probability: determines how many values will be modified
private void applyMutation(SimpleMatrix[] matrices, double probability) {
for (SimpleMatrix matrix : matrices) {
for (int j = 0; j < matrix.getNumElements(); j++) {
if (random.nextDouble() < probability) {
double offset = random.nextGaussian() / 2;
matrix.set(j, matrix.get(j) + offset);
}
}
}
}
// Generic function to calculate one layer
private SimpleMatrix calculateLayer(SimpleMatrix weights, SimpleMatrix bias, SimpleMatrix input, ActivationFunction activationFunction) {
// Calculate outputs of layer
SimpleMatrix result = weights.mult(input);
// Add bias to outputs
result = result.plus(bias);
// Apply activation function and return result
return applyActivationFunction(result, false, activationFunction);
}
private SimpleMatrix calculateGradient(SimpleMatrix layer, SimpleMatrix error, ActivationFunction activationFunction) {
SimpleMatrix gradient = applyActivationFunction(layer, true, activationFunction);
gradient = gradient.elementMult(error);
return gradient.scale(learningRate);
}
private SimpleMatrix calculateDeltas(SimpleMatrix gradient, SimpleMatrix layer) {
return gradient.mult(layer.transpose());
}
// Applies an activation function to a matrix
// An object of an implementation of the ActivationFunction-interface has to be passed
// The function in this class will be to the matrix
private SimpleMatrix applyActivationFunction(SimpleMatrix input, boolean derivative, ActivationFunction activationFunction) {
// Applies either derivative of activation function or regular activation function to a matrix and returns the result
return derivative ? activationFunction.applyDerivativeOfActivationFunctionToMatrix(input)
: activationFunction.applyActivationFunctionToMatrix(input);
}
public void writeToFile() {
FileReaderAndWriter.writeToFile(this, null);
}
public void writeToFile(String fileName) {
FileReaderAndWriter.writeToFile(this, fileName);
}
public static NeuralNetwork readFromFile() {
return FileReaderAndWriter.readFromFile(null);
}
public static NeuralNetwork readFromFile(String fileName) {
return FileReaderAndWriter.readFromFile(fileName);
}
public String getActivationFunctionName() {
return activationFunctionKey;
}
public void setActivationFunction(String activationFunction) {
this.activationFunctionKey = activationFunction;
}
public void addActivationFunction(String key, ActivationFunction activationFunction){
activationFunctionFactory.addActivationFunction(key, activationFunction);
}
public double getLearningRate() {
return learningRate;
}
public void setLearningRate(double learningRate) {
this.learningRate = learningRate;
}
public int getInputNodes() {
return inputNodes;
}
public int getHiddenLayers() {
return hiddenLayers;
}
public int getHiddenNodes() {
return hiddenNodes;
}
public int getOutputNodes() {
return outputNodes;
}
public SimpleMatrix[] getWeights() {
return weights;
}
public void setWeights(SimpleMatrix[] weights) {
this.weights = weights;
}
public SimpleMatrix[] getBiases() {
return biases;
}
public void setBiases(SimpleMatrix[] biases) {
this.biases = biases;
}
public int[] getDimensions(){
return new int[]{inputNodes, hiddenLayers, hiddenNodes, outputNodes};
}
}