Skip to content

Commit

Permalink
Merge pull request #1 from cshenton/xnes
Browse files Browse the repository at this point in the history
options, examples
  • Loading branch information
cshenton committed Mar 28, 2018
2 parents c99dacd + 652bb11 commit 90778fd
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 58 deletions.
42 changes: 24 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,39 @@ Just `go get github.com/cshenton/opt` to install.

## Basic Usage

Opt uses evolution strategies optimisers under the hood, which enables a simple API.
First, `Search()` against an optimiser to get a test point, then evaluate its score how
you see fit, then `Show()` that score with the test seed to the optimiser. Keep going
until you converge.
Opt uses [natural evolution strategies](http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf)
optimisers under the hood, which enables a simple API. First, `Search()` against an optimiser
to get a test point, then evaluate its score how you see fit, then `Show()` that score with
the test seed to the optimiser. Keep going until you converge.

```go
package main

import (
"github.com/cshenton/opt"
"github.com/cshenton/opt/bench"
"fmt"

"github.com/cshenton/opt"
"github.com/cshenton/opt/bench"
)

func main() {
n := 1000
o := opt.NewSNES(...)

for i := 0; i < n; i++ {
point, seed := o.Search()
// Minimise the rastrigin function
score := -bench.Rastrigin(point)
o.Show(score, seed)
}

final, _ := o.Search()
fmt.Println(final)
n := 10000

op := opt.DefaultOptions
op.LearningRate = 0.5
o := opt.NewSNES(10, op)

for i := 0; i < n; i++ {
point, seed := o.Search()
// Minimise the 10-dimensional sphere function
score := -bench.Sphere(point)
o.Show(score, seed)
}

final, _ := o.Search()
fmt.Println(final)
}

```

Want to use more workers on the same machine to speed up evaluations? Just do it,
Expand Down
2 changes: 1 addition & 1 deletion examples/basic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func main() {
o := opt.NewSNES(10, 10, 42, 0.01, false)
o := opt.NewSNES(10, opt.DefaultOptions)
t := time.Now()
p := 1.0
n := 0
Expand Down
2 changes: 1 addition & 1 deletion examples/multithread/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func main() {
o := opt.NewSNES(10, 10, 42, 0.01, false)
o := opt.NewSNES(10, opt.DefaultOptions)
w := 8
p := 1.0
done := make(chan int)
Expand Down
17 changes: 0 additions & 17 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,3 @@ type Float64Searcher interface {
// Informs the searcher of the score a particular seeded draw achieved
Show(score float64, seed int64)
}

// searchReq is the request sent in a Search call.
type searchReq struct {
respChan chan<- searchResp
}

// searchReq is the response received in a Search call.
type searchResp struct {
point []float64
seed int64
}

// showReq is the request sent is a Show call.
type showReq struct {
score float64
seed int64
}
2 changes: 1 addition & 1 deletion interfaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ package opt
import "testing"

func TestFloat64SearcherSNES(t *testing.T) {
s := NewSNES(10, 10, 3, 0.1, false)
s := NewSNES(10, DefaultOptions)
_ = Float64Searcher(s)
}
18 changes: 18 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package opt

// Options is a configuration struct which can holds options commond to several of
// the optimisers in the package.
type Options struct {
Adaptive bool // Whether to use an adaptive learning rate, if available
GenerationSize uint // Number of score evaluations per gradient update
LearningRate float64 // The learning rate
RandomSeed uint64 // Used to seed the source used to generate seeds for searches
}

// DefaultOptions is the recommended set of initial options for most optimisers.
var DefaultOptions = &Options{
Adaptive: true,
GenerationSize: 10,
LearningRate: 0.01,
RandomSeed: 24601,
}
23 changes: 12 additions & 11 deletions snes.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,28 @@ type SNES struct {

const initScale = 1e3

// NewSNES creates a SNES optimiser with the provided parameters and starts its run goroutine.
func NewSNES(len, size uint, seed int64, rate float64, adaptive bool) (s *SNES) {
scale := make([]float64, len)
// NewSNES creates a SNES optimiser over the d-dimensional real numbers, using the provided
// options for the optimiser.
func NewSNES(d uint, o *Options) (s *SNES) {
scale := make([]float64, d)
for i := range scale {
scale[i] = initScale
}

s = &SNES{
size: size,
size: o.GenerationSize,
showCount: 0,
searchCount: 0,
scores: make([]float64, size),
seeds: make([]int64, size),
scores: make([]float64, o.GenerationSize),
seeds: make([]int64, o.GenerationSize),

len: len,
loc: make([]float64, len),
len: d,
loc: make([]float64, d),
scale: scale,

rate: rate,
adaptive: adaptive,
source: rand.New(rand.NewSource(uint64(seed))),
rate: o.LearningRate,
adaptive: o.Adaptive,
source: rand.New(rand.NewSource(uint64(o.RandomSeed))),

Mutex: &sync.Mutex{},
}
Expand Down
31 changes: 22 additions & 9 deletions snes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
func TestNewSNES(t *testing.T) {
tt := []struct {
name string
len uint
dim uint
size uint
seed int64
seed uint64
rate float64
adaptive bool
}{
Expand All @@ -20,10 +20,16 @@ func TestNewSNES(t *testing.T) {

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
s := NewSNES(tc.len, tc.size, tc.seed, tc.rate, tc.adaptive)
o := &Options{
Adaptive: tc.adaptive,
GenerationSize: tc.size,
LearningRate: tc.rate,
RandomSeed: tc.seed,
}
s := NewSNES(tc.dim, o)

if s.len != tc.len {
t.Errorf("expected len %v, but got %v", tc.len, s.len)
if s.len != tc.dim {
t.Errorf("expected len %v, but got %v", tc.dim, s.len)
}
if s.size != tc.size {
t.Errorf("expected size %v, but got %v", tc.size, s.size)
Expand All @@ -39,7 +45,7 @@ func TestNewSNES(t *testing.T) {
}

func TestSNESSearch(t *testing.T) {
s := NewSNES(3, 10, 42, 0.1, false)
s := NewSNES(3, DefaultOptions)

point, seed := s.Search()

Expand All @@ -57,11 +63,16 @@ func TestSNESSearch(t *testing.T) {
}

func TestSNESSearchBlock(t *testing.T) {
s := NewSNES(3, 10, 42, 0.1, false)
s := NewSNES(3, DefaultOptions)
for i := 0; i < 10; i++ {
_, _ = s.Search()
}

type searchResp struct {
point []float64
seed int64
}

result := make(chan searchResp, 1)
timeout := make(chan bool, 1)

Expand All @@ -85,7 +96,9 @@ func TestSNESSearchBlock(t *testing.T) {
}

func TestSNESShow(t *testing.T) {
s := NewSNES(3, 2, 42, 0.1, false)
o := DefaultOptions
o.GenerationSize = 2
s := NewSNES(3, o)

s.Show(20, 42)
if s.showCount != 1 {
Expand Down Expand Up @@ -114,7 +127,7 @@ func TestSNESShow(t *testing.T) {

func TestSNESmakeNoise(t *testing.T) {
var seed int64 = 404
s := NewSNES(10, 10, 42, 0.1, false)
s := NewSNES(10, DefaultOptions)

n1 := s.makeNoise(seed)
n2 := s.makeNoise(seed)
Expand Down
64 changes: 64 additions & 0 deletions xnes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package opt

import (
"sync"

"golang.org/x/exp/rand"
)

// XNES is the Exponential Natural Evolution Strategies optimiser. It is an NES optimiser
// that uses a multinormal search distribution, taking advantage of a closed form computation
// of the fisher information matrix.
type XNES struct {
// Generation data
size uint
searchCount uint
showCount uint
scores []float64
seeds []int64

// Search distribution parameters
len uint
loc []float64
scale float64
shape []float64

// Search hyperparameters
rate float64
adaptive bool

// Noise source
source *rand.Rand

// Mutex
*sync.Mutex
}

// NewXNES creates an XNES optimiser over the d-dimensional real numbers, using the provided
// options for the optimiser.
// func NewXNES(d uint, o *Options) (s *SNES) {
// scale := make([]float64, d)
// for i := range scale {
// scale[i] = initScale
// }

// s = &SNES{
// size: o.GenerationSize,
// showCount: 0,
// searchCount: 0,
// scores: make([]float64, o.GenerationSize),
// seeds: make([]int64, o.GenerationSize),

// len: d,
// loc: make([]float64, d),
// scale: scale,
// //shape: IDENTITY MATRIX

// rate: o.LearningRate,
// adaptive: o.Adaptive,
// source: rand.New(rand.NewSource(uint64(o.RandomSeed))),

// Mutex: &sync.Mutex{},
// }
// return s
// }
1 change: 1 addition & 0 deletions xnes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package opt

0 comments on commit 90778fd

Please sign in to comment.