A very simple multilayer neural network with backpropagation and genetic-algorithm support. It includes:
- Multilayer networks.
- Backpropagation method.
- Methods for genetic algorithms.
- Network visualizer.
- Simple matrix class.
- Useful static math methods.
- Three example scenes.
If you don't need the example scenes, copy only the NetworkScripts folder into your project's Assets folder.
The following line creates a neural network with two inputs, one hidden layer of three neurons, and one output:
NeuralNetwork nn = new NeuralNetwork(2, 1, 3, 1);
Getting the network prediction:
float[] inputs = { 0, 1 };
float[] output = nn.FeedForward(inputs);
Debug.Log(output[0]); // network prediction (a number between 0 and 1)
You can train your network like this:
// Training method from the XOR example
// supervised learning
void Train()
{
    float[] inputs = new float[2];
    float[] targets = new float[1];
    // training 5k times
    for (int i = 0; i < 5000; i++)
    {
        // randomizing data
        switch (Random.Range(0, 4))
        {
            case 0:
                inputs[0] = 0;
                inputs[1] = 0;
                targets[0] = 0;
                break;
            case 1:
                inputs[0] = 0;
                inputs[1] = 1;
                targets[0] = 1;
                break;
            case 2:
                inputs[0] = 1;
                inputs[1] = 0;
                targets[0] = 1;
                break;
            default:
                inputs[0] = 1;
                inputs[1] = 1;
                targets[0] = 0;
                break;
        }
        nn.TrainNetwork(inputs, targets);
    }
    // printing network predictions after training
    Debug.Log("[0, 0] -> " + nn.FeedForward(new float[] { 0, 0 })[0]);
    Debug.Log("[0, 1] -> " + nn.FeedForward(new float[] { 0, 1 })[0]);
    Debug.Log("[1, 0] -> " + nn.FeedForward(new float[] { 1, 0 })[0]);
    Debug.Log("[1, 1] -> " + nn.FeedForward(new float[] { 1, 1 })[0]);
}
If you want to visualize your network, you can use:
// public void DrawNetwork | Parameters: (NeuralNetwork network, int size, int layerGap, Color neuronColor, Color connectionStrong, Color connectionWeak, Color background)
public NetworkVisualizer visualizer;
visualizer.DrawNetwork(nn, 400, 5, Color.cyan, Color.red, Color.blue, new Color(1, 1, 1, 0.3f));
Some useful extra methods:
// public static NeuralNetwork Crossover | Parameters: (NeuralNetwork nn1, NeuralNetwork nn2, float mutationPercent)
NeuralNetwork parent1;
NeuralNetwork parent2;
NeuralNetwork child = NeuralNetwork.Crossover(parent1, parent2, 5);
// public static float GetDistBetweenPoints | Parameters: (float x1, float y1, float x2, float y2)
float distance = StaticMath.GetDistBetweenPoints(hit.point.x, hit.point.y, transform.position.x, transform.position.y);
// public static float GetAngleBetweenPoints | Parameters: (float x1, float y1, float x2, float y2)
float angle = StaticMath.GetAngleBetweenPoints(hit.point.x, hit.point.y, transform.position.x, transform.position.y);
// public static float Remap | Parameters: (float value, float from1, float to1, float from2, float to2)
float newValue = StaticMath.Remap(50, 0, 100, 0, 1); // newValue = 0.5f
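The Remap example above (50 in the range 0 to 100 becoming 0.5 in the range 0 to 1) is consistent with a plain linear rescale. The sketch below shows that computation under this assumption. `RemapSketch` is a hypothetical stand-in written for illustration, not the library's actual `StaticMath.Remap`, which may handle edge cases differently.

```csharp
public static class RemapSketch
{
    // Hypothetical stand-in for StaticMath.Remap (assumption: linear mapping).
    // Maps value from the range [from1, to1] onto the range [from2, to2].
    // The result is undefined when from1 == to1 (division by zero).
    public static float Remap(float value, float from1, float to1, float from2, float to2)
    {
        float t = (value - from1) / (to1 - from1); // relative position in the source range
        return from2 + t * (to2 - from2);          // same relative position in the target range
    }
}
```

With the arguments from the example, `RemapSketch.Remap(50f, 0f, 100f, 0f, 1f)` evaluates to 0.5f. A rescale like this is handy for squeezing raw sensor values (distances, angles) into the 0 to 1 range before passing them to `FeedForward`.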

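The `Crossover` method shown earlier produces one child from two parents, but a genetic algorithm also needs a rule for picking those parents. The sketch below shows one simple option, top-two selection, where every child in the next generation comes from the two fittest members. This scheme is chosen for brevity and is not necessarily what the example scenes do. `GenerationSketch`, `NextGeneration`, and both delegates are hypothetical names; the type is left generic so the sketch compiles without the library.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class GenerationSketch
{
    // Builds the next generation from the two fittest members of the current one.
    // T would be NeuralNetwork in this project. The fitness and crossover
    // delegates are supplied by the caller (assumptions, not library API).
    public static List<T> NextGeneration<T>(
        IReadOnlyList<T> population,
        Func<T, float> fitness,
        Func<T, T, T> crossover)
    {
        if (population.Count < 2)
            throw new ArgumentException("Need at least two members to cross over.");

        // Rank by fitness, highest first, and keep the top two as parents.
        List<T> ranked = population.OrderByDescending(fitness).ToList();
        T parent1 = ranked[0];
        T parent2 = ranked[1];

        // Refill the population with children of those two parents.
        var next = new List<T>(population.Count);
        for (int i = 0; i < population.Count; i++)
            next.Add(crossover(parent1, parent2));
        return next;
    }
}
```

In this project the crossover delegate could be `(a, b) => NeuralNetwork.Crossover(a, b, 5)`, and the fitness delegate could return whatever score your scene tracks per network, such as distance travelled or time survived.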