Skip to content

Commit

Permalink
adding network
Browse files Browse the repository at this point in the history
  • Loading branch information
kelindar committed Oct 3, 2020
1 parent 43c4b97 commit bc75f4e
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 10 deletions.
4 changes: 2 additions & 2 deletions binary/genome.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func (g Genome) Mutate() {
g[i] = randByte()
}

// Make creates a function for a random genome string
func Make(length int) func() evolve.Genome {
// New creates a function for a random genome string
func New(length int) evolve.Genesis {
return func() evolve.Genome {
v := make(Genome, length)
crand.Read(v)
Expand Down
2 changes: 1 addition & 1 deletion binary/genome_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestEvolve(t *testing.T) {
}

fit := fitnessFor(target)
pop := evolve.New(population, fit, binary.Make(len(target)))
pop := evolve.New(population, fit, binary.New(len(target)))

// Evolve
i, last := 0, ""
Expand Down
2 changes: 1 addition & 1 deletion evolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestEvolve(t *testing.T) {
}

fit := fitnessFor(target)
pop := evolve.New(population, fit, binary.Make(len(target)))
pop := evolve.New(population, fit, binary.New(len(target)))

// Evolve
i, last := 0, ""
Expand Down
5 changes: 1 addition & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,4 @@ module github.com/kelindar/evolve

go 1.14

require (
github.com/kelindar/rand v1.0.1
github.com/stretchr/testify v1.6.1
)
require github.com/stretchr/testify v1.6.1
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kelindar/rand v1.0.1 h1:PfCe86AM7iprR8lWOj4AWiYcdbT+b3iaLmFMRxjjihc=
github.com/kelindar/rand v1.0.1/go.mod h1:Ps9zsneYaqEckQIYsC/6I/VkYv+zVEGeUXs/WQIGpWw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
Expand Down
120 changes: 120 additions & 0 deletions neural/graph.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.

package neural

import (
"math"
"sort"
"sync/atomic"
)

var serial uint32

// Next generates a next sequence number.
func next() uint32 {
return atomic.AddUint32(&serial, 1)
}

// ----------------------------------------------------------------------------------

// Node represents a neuron in the network
type neuron struct {
Serial uint32 // The innovation serial number
Conns []synapse // The incoming connections
value float64 // The output value (for activation)
}

// makeNeuron creates a new neuron.
func makeNode() neuron {
return neuron{
Serial: next(),
}
}

// Value returns the value for the neuron
func (n *neuron) Value() float64 {
if n.value != 0 || len(n.Conns) == 0 {
return n.value
}

// Sum of the weighted inputs to the neuron
s := 0.0
for _, c := range n.Conns {
if c.Active {
s += c.Weight * c.From.Value()
}
}

// Keep the value to avoid recalculating
n.value = sigmoid(s)
return n.value
}

// connected checks whether the two neurons are connected or not.
func (n *neuron) connected(neuron *neuron) bool {
return searchNode(n, neuron) || searchNode(neuron, n)
}

// Sigmod activation function.
func sigmoid(x float64) float64 {
return 1.0 / (1 + math.Exp(-x))
}

// searchNode searches whether incoming connections of "to" contain a "from" neuron.
func searchNode(from, to *neuron) bool {
x := from.Serial
i := sort.Search(len(to.Conns), func(i int) bool {
return to.Conns[i].From.Serial >= x
})
return i < len(to.Conns) && to.Conns[i].From == from
}

// ----------------------------------------------------------------------------------

// Nodes represents a set of neurons
type neurons []neuron

// makeNodes creates a new neuron array.
func makeNodes(count int) neurons {
arr := make(neurons, 0, count)
for i := 0; i < count; i++ {
arr = append(arr, makeNode())
}
return arr
}

// ----------------------------------------------------------------------------------

// Synapse represents a synapse for the NEAT network.
type synapse struct {
Serial uint32 // The innovation serial number
Weight float64 // The weight of the connection
Active bool // Whether the connection is enabled or not
From, To *neuron // The neurons of the connection
}

// ID returns a unique key for the edge.
func (c *synapse) ID() uint64 {
return (uint64(c.To.Serial) << 32) | (uint64(c.From.Serial) & 0xffffffff)
}

// ----------------------------------------------------------------------------------

// sortedByNode represents a connection list which is sorted by neuron ID
type sortedByNode []synapse

// Len returns the number of connections.
func (c sortedByNode) Len() int {
return len(c)
}

// Less compares two connections in the slice.
func (c sortedByNode) Less(i, j int) bool {
return c[i].ID() < c[j].ID()
}

// Swap swaps two connections
func (c sortedByNode) Swap(i, j int) {
c[i], c[j] = c[j], c[i]
}
29 changes: 29 additions & 0 deletions neural/graph_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.

package neural

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestConnected(t *testing.T) {
n := makeNodes(2)
n0, n1 := &n[0], &n[1]

// Disjoint
assert.False(t, n0.connected(n1))
assert.False(t, n1.connected(n0))

// Connect
n1.Conns = append(n1.Conns, synapse{
From: n0,
To: n1,
})

// Connected
assert.True(t, n0.connected(n1))
assert.True(t, n1.connected(n0))
}
120 changes: 120 additions & 0 deletions neural/network.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.

package neural

import (
"math"
"sort"

"github.com/kelindar/evolve"
)

// Network represents a neural network.
type Network struct {
input neurons
hidden neurons
output neurons
conns []synapse
}

// New creates a new neural network.
func New(in, out int) *Network {
nn := &Network{
input: makeNodes(in + 1),
output: makeNodes(out),
conns: make([]synapse, 0, 256),
}

// Bias neuron
nn.input[in].value = 1.0
return nn
}

// Predict activates the network
func (n *Network) Predict(input, output []float64) []float64 {
if output == nil {
output = make([]float64, len(n.output))
}

// Set the values for the input neurons
for i, v := range input {
n.input[i].value = v
}

// Clean the hidden neurons values
for i := range n.hidden {
n.hidden[i].value = 0
}

// Retrieve values and sum up exponentials
sum := 0.0
for i, neuron := range n.output {
v := math.Exp(neuron.Value())
output[i] = v
sum += v
}

// Normalize
for i := range output {
output[i] /= sum
}
return output
}

// sort sorts the connections depending on the neuron and assigns connection slices
// to the appropriate neurons for activation.
func (n *Network) sort() {
if len(n.conns) == 0 {
return
}

// Sort by neuron ID
sort.Sort(sortedByNode(n.conns))

// Assign connection slices to neurons
prev, lo := n.conns[0].To, 0
curr, hi := n.conns[0].To, 0
for i, conn := range n.conns {
curr, hi = conn.To, i
if prev != curr {
prev.Conns = n.conns[lo:hi]
prev, lo = curr, hi
}
}

// Last neuron
prev.Conns = n.conns[lo : hi+1]
}

// connect connects two neurons together.
func (n *Network) connect(from, to *neuron, weight float64) {
defer n.sort() // Keep sorted
n.conns = append(n.conns, synapse{
Serial: next(), // Innovation number
From: from, // Left neuron
To: to, // Right neuron
Weight: weight, // Weight for the connection
Active: true, // Default to active
})
}

// Mutate mutates the network.
func (n *Network) Mutate() {
defer n.sort() // Keep sorted

}

func (n *Network) Crossover(p1, p2 evolve.Genome) {

}

// Equal checks whether the connection is equal to another connection
/*func (c *conn) Equal(other *conn) bool {
return c.From == other.From && c.To == other.To
}*/

// https://github.com/Luecx/NEAT/tree/master/vid%209/src

// https://sausheong.github.io/posts/how-to-build-a-simple-artificial-neural-network-with-go/
// https://stats.stackexchange.com/questions/459491/how-do-i-use-matrix-math-in-irregular-neural-networks-generated-from-neuroevolut
57 changes: 57 additions & 0 deletions neural/network_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.

package neural

import (
"testing"

"github.com/stretchr/testify/assert"
)

func BenchmarkPredict(b *testing.B) {
b.Run("2x2", func(b *testing.B) {
nn := make2x2()
in := []float64{1, 0}
out := []float64{0, 0}

b.ResetTimer()
b.ReportAllocs()
for n := 0; n < b.N; n++ {
nn.Predict(in, out)
}
})

}

func TestPredict(t *testing.T) {
nn := make2x2()
i0 := &nn.input[0]
i1 := &nn.input[1]
o0 := &nn.output[0]
o1 := &nn.output[1]

// must be connected
assert.True(t, i0.connected(o0))
assert.True(t, i1.connected(o0))
assert.True(t, i0.connected(o0))
assert.False(t, i1.connected(o1))

r := nn.Predict([]float64{0.5, 1}, nil)
assert.Equal(t, []float64{0.5216145455966438, 0.4783854544033563}, r)
}

// make2x2 creates a 2x2 tiny network
func make2x2() *Network {
nn := New(2, 2)
i0 := &nn.input[0]
i1 := &nn.input[1]
o0 := &nn.output[0]
o1 := &nn.output[1]

// connect inputs to output
nn.connect(i0, o0, .5)
nn.connect(i1, o0, .5)
nn.connect(i0, o1, .75)
return nn
}

0 comments on commit bc75f4e

Please sign in to comment.