Skip to content
This repository has been archived by the owner on Aug 3, 2023. It is now read-only.

Commit

Permalink
fix: struct pointers as map keys
Browse files Browse the repository at this point in the history
  • Loading branch information
maolonglong committed Oct 16, 2021
1 parent 9a222a7 commit 0205d7d
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 13 deletions.
48 changes: 48 additions & 0 deletions counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2021 go-mcts. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package mcts

import (
"fmt"
)

// Struct pointers as map keys is not work correctly.
// see https://abhinavg.net/posts/pointers-as-map-keys/
//
// use fmt.Sprintf("%v", key) as map keys
type counter struct {
m map[string]*entry
}

type entry struct {
key interface{}
count float64
}

func newCounter() *counter {
return &counter{make(map[string]*entry)}
}

func (c *counter) incr(key interface{}, count float64) {
s := fmt.Sprintf("%v", key)
if ent, ok := c.m[s]; ok {
ent.count += count
} else {
c.m[s] = &entry{key, 1}
}
}

func (c *counter) get(key interface{}) float64 {
if ent, ok := c.m[fmt.Sprintf("%v", key)]; ok {
return ent.count
}
return 0
}

func (c *counter) rng(f func(key interface{}, count float64)) {
for _, ent := range c.m {
f(ent.key, ent.count)
}
}
32 changes: 32 additions & 0 deletions counter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2021 go-mcts. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package mcts

import (
"testing"

"github.com/stretchr/testify/assert"
)

type someStructPointer struct {
s string
}

func newPointer(s string) *someStructPointer {
return &someStructPointer{s}
}

func TestCounter(t *testing.T) {
m := make(map[*someStructPointer]int)
m[newPointer("abc")]++
m[newPointer("abc")]++
assert.Equal(t, 2, len(m))
assert.Equal(t, 0, m[newPointer("abc")])

c := newCounter()
c.incr(newPointer("abc"), 1)
c.incr(newPointer("abc"), 1)
assert.Equal(t, float64(2), c.get(newPointer("abc")))
}
126 changes: 126 additions & 0 deletions examples/tictactoe/tictactoe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright 2021 go-mcts. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package tictactoe

import (
"math/rand"

"github.com/go-mcts/mcts"
)

var (
_ mcts.Move = (*move)(nil)
_ mcts.State = (*state)(nil)
)

type move struct {
x int
y int
v int
}

type state struct {
playerToMove int
board [3][3]int
}

func (s *state) PlayerToMove() int {
return s.playerToMove
}

func (s *state) HasMoves() bool {
return s.getResult(s.playerToMove) == -1
}

func (s *state) GetMoves() []mcts.Move {
moves := make([]mcts.Move, 0)
if s.getResult(s.playerToMove) == -1 {
for i := 0; i < 3; i++ {
for j := 0; j < 3; j++ {
if s.board[i][j] == 0 {
m := &move{
x: i,
y: j,
v: s.playerToMove,
}
if s.playerToMove == 1 {
m.v = 1
} else {
m.v = -1
}
moves = append(moves, m)
}
}
}
}
return moves
}

func (s *state) DoMove(mctsMove mcts.Move) {
m := mctsMove.(*move)
if m.x < 0 || m.y < 0 || m.x > 2 || m.y > 2 || s.board[m.x][m.y] != 0 {
panic("illegal move")
}
s.board[m.x][m.y] = m.v
s.playerToMove = 3 - s.playerToMove
}

func (s *state) DoRandomMove(rd *rand.Rand) {
moves := s.GetMoves()
s.DoMove(moves[rd.Intn(len(moves))])
}

func (s *state) GetResult(currentPlayerToMove int) float64 {
if result := s.getResult(currentPlayerToMove); result == -1 {
panic("game is not over")
} else {
return result
}
}

func (s *state) getResult(currentPlayerToMove int) float64 {
zero := 0

for i := 0; i < 3; i++ {
row, col := 0, 0
for j := 0; j < 3; j++ {
if s.board[i][j] == 0 {
zero++
}
row += s.board[i][j]
col += s.board[j][i]
}

if row == 3 || row == -3 || col == 3 || col == -3 {
if s.playerToMove == currentPlayerToMove {
return 1
}
return 0
}
}

tl := s.board[0][0] + s.board[1][1] + s.board[2][2]
tr := s.board[0][2] + s.board[1][1] + s.board[2][0]

if tl == 3 || tr == 3 || tl == -3 || tr == -3 {
if s.playerToMove == currentPlayerToMove {
return 1
}
return 0
}

if zero == 0 {
return 0.5
}

return -1
}

func (s *state) Clone() mcts.State {
return &state{
playerToMove: s.playerToMove,
board: s.board,
}
}
43 changes: 43 additions & 0 deletions examples/tictactoe/tictactoe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2021 go-mcts. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package tictactoe

import (
"testing"

"github.com/go-mcts/mcts"
"github.com/stretchr/testify/assert"
)

func TestTicTacToe(t *testing.T) {
rootState := &state{
playerToMove: 1,
board: [3][3]int{
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
},
}
mctsMove := mcts.ComputeMove(rootState, mcts.MaxIterations(20000), mcts.Verbose(true))
m := mctsMove.(*move)
assert.Equal(t, 1, m.x)
assert.Equal(t, 1, m.y)
assert.Equal(t, 1, m.v)

rootState = &state{
playerToMove: 1,
board: [3][3]int{
{0, 0, 0},
{0, 1, 0},
{0, -1, 0},
},
}
mctsMove = mcts.ComputeMove(rootState, mcts.Verbose(true))
m = mctsMove.(*move)
assert.Equal(t, 1, m.v)

assert.True(t, m.x == 0 && (m.y == 0 || m.y == 2) ||
m.x == 2 && (m.y == 0 || m.y == 2))
}
2 changes: 2 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ func Goroutines(number int) Option {
}

// MaxIterations maximum number of iterations, default is 10000
//
// iter < 0: not limit
func MaxIterations(iter int) Option {
return func(o *Options) {
o.MaxIterations = iter
Expand Down
27 changes: 14 additions & 13 deletions uct.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,41 +91,42 @@ func ComputeMove(rootState State, opts ...Option) Move {
}()
}

visits := make(map[Move]int)
wins := make(map[Move]float64)
visits := newCounter()
wins := newCounter()
gamePlayed := 0
for i := 0; i < options.Goroutines; i++ {
root := <-rootFutures
gamePlayed += root.visits
for _, c := range root.children {
visits[c.move] += c.visits
wins[c.move] += c.wins
visits.incr(c.move, float64(c.visits))
wins.incr(c.move, c.wins)
}
}

bestScore := float64(-1)
var bestMove Move
for move, v := range visits {
w := wins[move]
expectedSuccessRate := (w + 1) / (float64(v) + 2)
visits.rng(func(key interface{}, v float64) {
move := key.(Move)
w := wins.get(move)
expectedSuccessRate := (w + 1) / (v + 2)
if expectedSuccessRate > bestScore {
bestMove = move
bestScore = expectedSuccessRate
}

if options.Verbose {
Debugf("Move: %v (%2d%% visits) (%2d%% wins)",
move, int(100.0*float64(v)/float64(gamePlayed)+0.5), int(100.0*w/float64(v)+0.5))
move, int(100.0*v/float64(gamePlayed)+0.5), int(100.0*w/v+0.5))
}
}
})

if options.Verbose {
bestWins := wins[bestMove]
bestVisits := visits[bestMove]
bestWins := wins.get(bestMove)
bestVisits := visits.get(bestMove)
Debugf("Best: %v (%2d%% visits) (%2d%% wins)",
bestMove,
int(100.0*float64(bestVisits)/float64(gamePlayed)+0.5),
int(100.0*bestWins/float64(bestVisits)+0.5),
int(100.0*bestVisits/float64(gamePlayed)+0.5),
int(100.0*bestWins/bestVisits+0.5),
)

now := time.Now()
Expand Down

0 comments on commit 0205d7d

Please sign in to comment.