Skip to content

Commit

Permalink
Merge pull request #339 from domino14/welch
Browse files Browse the repository at this point in the history
use welch's t-test for stopping condition
  • Loading branch information
domino14 committed Jul 8, 2024
2 parents 17c4e36 + 3c9729e commit df47064
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 59 deletions.
6 changes: 4 additions & 2 deletions montecarlo/montecarlo.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ type StoppingCondition int

const (
StopNone StoppingCondition = iota
Stop90
Stop95
Stop98
Stop99
Stop999
)

type InferenceMode int
Expand Down Expand Up @@ -648,8 +650,8 @@ func (s *Simmer) EquityStats() string {
fmt.Fprintf(&ss, "%-20s%-9s%-16s%-16s\n", "Play", "Score", "Win%", "Equity")

for _, play := range s.plays {
wpStats := fmt.Sprintf("%.3f±%.3f", 100.0*play.winPctStats.Mean(), 100.0*play.winPctStats.StandardError(stats.Z99))
eqStats := fmt.Sprintf("%.3f±%.3f", play.equityStats.Mean(), play.equityStats.StandardError(stats.Z99))
wpStats := fmt.Sprintf("%.2f±%.2f", 100.0*play.winPctStats.Mean(), 100.0*stats.Z99*play.winPctStats.StandardError())
eqStats := fmt.Sprintf("%.2f±%.2f", play.equityStats.Mean(), stats.Z99*play.equityStats.StandardError())
ignore := ""
if play.ignore {
ignore = "❌"
Expand Down
38 changes: 25 additions & 13 deletions montecarlo/stopping_condition.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package montecarlo

import (
"math"
"sort"

"github.com/domino14/word-golib/tilemapping"
Expand Down Expand Up @@ -36,6 +37,7 @@ func (s *Simmer) shouldStop(iterationCount uint64,
// count ignored plays
ignoredPlays := 0
bottomUnignoredWinPct := 0.0
bottomUnignoredSerr := 0.0
for i := range c {
c[i] = plays[i]
c[i].RLock()
Expand Down Expand Up @@ -66,6 +68,7 @@ func (s *Simmer) shouldStop(iterationCount uint64,
c[i].RLock()
if !c[i].ignore {
bottomUnignoredWinPct = c[i].winPctStats.Mean()
bottomUnignoredSerr = c[i].winPctStats.StandardError()
c[i].RUnlock()
break
}
Expand All @@ -79,23 +82,28 @@ func (s *Simmer) shouldStop(iterationCount uint64,

var ci float64
switch sc {
case Stop90:
ci = stats.Z90
case Stop95:
ci = stats.Z95
case Stop98:
ci = stats.Z98
case Stop99:
ci = stats.Z99
case Stop999:
ci = stats.Z999
}
tiebreakByEquity := false
tentativeWinner := c[0]
tentativeWinner.RLock()
μ := tentativeWinner.winPctStats.Mean()
e := tentativeWinner.winPctStats.StandardError(ci)
if μ <= MinReasonableWProb {
e := tentativeWinner.winPctStats.StandardError()
if passTest(MinReasonableWProb, 0, μ, e, ci) {
// If the top play by win % has basically no win chance, tiebreak the whole
// thing by equity.
tiebreakByEquity = true
} else if μ >= (1-MinReasonableWProb) && bottomUnignoredWinPct >= (1-MinReasonableWProb) {
} else if passTest(μ, e, 1-MinReasonableWProb, 0, ci) &&
passTest(bottomUnignoredWinPct, bottomUnignoredSerr, 1-MinReasonableWProb, 0, ci) {
// If the top play by win % has basically no losing chance, check if the bottom
// play also has no losing chance
tiebreakByEquity = true
Expand Down Expand Up @@ -124,7 +132,7 @@ func (s *Simmer) shouldStop(iterationCount uint64,
Msg("tiebreaking by equity, re-determining tentative winner")
}
μ = tentativeWinner.equityStats.Mean()
e = tentativeWinner.equityStats.StandardError(ci)
e = tentativeWinner.equityStats.StandardError()
log.Debug().Msg("stopping-condition-tiebreak-by-equity")
}

Expand All @@ -137,13 +145,13 @@ func (s *Simmer) shouldStop(iterationCount uint64,
continue
}
μi := p.winPctStats.Mean()
ei := p.winPctStats.StandardError(ci)
ei := p.winPctStats.StandardError()
if tiebreakByEquity {
μi = p.equityStats.Mean()
ei = p.equityStats.StandardError(ci)
ei = p.equityStats.StandardError()
}
p.RUnlock()
if passTest(μ, e, μi, ei) {
if passTest(μ, e, μi, ei, ci) {
p.Ignore()
newIgnored++
} else if iterationCount > SimilarPlaysIterationsCutoff {
Expand All @@ -163,12 +171,16 @@ func (s *Simmer) shouldStop(iterationCount uint64,
return false
}

// passTest: determine if a random variable X > Y with the given
// confidence level; return true if X > Y.
func passTest(μ, e, μi, ei float64) bool {
// Z := zVal(μ, v, μi, vi)
// X > Y if (μ - e) > (μi + ei)
return (μ - e) > (μi + ei)
// passTest: determine if a random variable X > Y with the given z-score; return true if X > Y.
// μ and e are the mean and standard error of variable X
// μi, ei are the mean and standard error of variable Y
func passTest(μ, e, μi, ei, z float64) bool {
sediff := math.Sqrt(e*e + ei*ei)
if sediff == 0 {
return true
}
zcalc := (μ - μi) / sediff
return zcalc > z
}

func materiallySimilar(p1, p2 *SimmedPlay, pcache map[string]bool) bool {
Expand Down
26 changes: 6 additions & 20 deletions montecarlo/stopping_condition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,15 @@ package montecarlo
import (
"testing"

"github.com/domino14/macondo/stats"
"github.com/matryer/is"
)

// func TestPassZTest(t *testing.T) {
// is := is.New(t)
// is.Equal(passTest(450, 10000, 460, 10000, Stop95), false)
// is.True(passTest(450, 10, 400, 5, Stop95))
// is.True(passTest(450, 10, 400, 5, Stop99))
// is.Equal(passTest(450, 10, 450, 5, Stop95), false)
// // 53% win chances with a stdev of 0.01 beats 50% win chances with a stdev of 0.01
// // at the 95% confidence level, but not at the 99% confidence level.
// is.True(passTest(0.53, 0.0001, 0.50, 0.0001, Stop95))
// is.Equal(passTest(0.53, 0.0001, 0.50, 0.0001, Stop99), false)
// }

// func TestZVal(t *testing.T) {
// is := is.New(t)
// is.Equal(zValStdev(10, 5, 10, 2), float64(0))
// is.True(math.Abs(zValStdev(450, 100, 460, 100)-(0.07071)) < 0.0001)
// }

func TestPassTest(t *testing.T) {
is := is.New(t)
is.True(passTest(30, 1, 27.9, 1))
is.True(!passTest(30, 1, 28.0, 1))
is.True(passTest(30, 1, 27.2, 1, stats.Z95))
is.True(!passTest(30, 1, 29.0, 1, stats.Z95))

is.True(!passTest(100, 5, 90, 4, stats.Z99))
is.True(passTest(30, 1, 25, 1, stats.Z999))
}
41 changes: 41 additions & 0 deletions scripts/repeat_sim.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
local macondo = require("macondo")


-- does not work on Windows, I think:
function sleep(n)
os.execute("sleep " .. tonumber(n))
end

function dump(o)
if type(o) == 'table' then
local s = '{ '
for k,v in pairs(o) do
if type(k) ~= 'number' then k = '"'..k..'"' end
s = s .. '['..k..'] = ' .. dump(v) .. ','
end
return s .. '} '
else
return tostring(o)
end
end


local plays = {}
start_time = os.time()

for i=1,100 do
macondo.load('cgp 14Q/13GI/9A1B1U1/9C1O1C1/9TUPIK1/9I5/4R4V3OE/3JIBED1E3uH/2LID1WOOSH2T1/3V1A4I2G1/3EFT1cANNULAE/3YAR4T2Z1/4XI3PENNED/5A3ER2DO/7TANSY2L AEINOST/ 378/316 0 lex NWL23;')
local elite_play = macondo.elite_play()
if plays[elite_play] == nil then
plays[elite_play] = 1
else
plays[elite_play] = plays[elite_play] + 1
end
print("WINNER ".. elite_play)
end

end_time = os.time()
elapsed_time = os.difftime(end_time, start_time)

print(dump(plays))
print("elapsed_time = " .. elapsed_time)
64 changes: 50 additions & 14 deletions shell/script.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
package shell

import (
"context"
"errors"
"strings"
"time"

"github.com/rs/zerolog/log"
lua "github.com/yuin/gopher-lua"
)

var exports = map[string]lua.LGFunction{
"load": load,
"set": set,
"gen": gen,
"sim": sim,
"turn": turn,
"gid": gid,
"endgame": endgame,
"busy": busy,
"elite_play": elitePlay,
}

func getShell(L *lua.LState) *ShellController {
shell := L.GetGlobal("macondo_shell")
ud, ok := shell.(*lua.LUserData)
Expand All @@ -21,7 +35,7 @@ func getShell(L *lua.LState) *ShellController {
return sc
}

func Set(L *lua.LState) int {
func set(L *lua.LState) int {
lv := L.ToString(1)
sc := getShell(L)
r, err := sc.set(&shellcmd{
Expand All @@ -38,7 +52,7 @@ func Set(L *lua.LState) int {
return 1
}

func Load(L *lua.LState) int {
func load(L *lua.LState) int {
lv := L.ToString(1)
sc := getShell(L)
r, err := sc.load(&shellcmd{
Expand All @@ -55,7 +69,7 @@ func Load(L *lua.LState) int {
return 1
}

func Gen(L *lua.LState) int {
func gen(L *lua.LState) int {
lv := L.ToString(1)
sc := getShell(L)
r, err := sc.generate(&shellcmd{
Expand All @@ -70,7 +84,7 @@ func Gen(L *lua.LState) int {
return 1
}

func Turn(L *lua.LState) int {
func turn(L *lua.LState) int {
lv := L.ToString(1)
sc := getShell(L)
r, err := sc.turn(&shellcmd{
Expand All @@ -85,7 +99,7 @@ func Turn(L *lua.LState) int {
return 1
}

func Gid(L *lua.LState) int {
func gid(L *lua.LState) int {
sc := getShell(L)
r, err := sc.gid(&shellcmd{
cmd: "gid",
Expand All @@ -98,7 +112,20 @@ func Gid(L *lua.LState) int {
return 1
}

func Endgame(L *lua.LState) int {
func elitePlay(L *lua.LState) int {
sc := getShell(L)
sc.botCtx, sc.botCtxCancel = context.WithTimeout(context.Background(), time.Second*time.Duration(60))
defer sc.botCtxCancel()
m, err := sc.elitebot.BestPlay(sc.botCtx)
if err != nil {
log.Err(err).Msg("error with eliteplay")
return 0
}
L.Push(lua.LString(m.ShortDescription()))
return 1
}

func endgame(L *lua.LState) int {
lv := L.ToString(1)
sc := getShell(L)
cmd, err := extractFields("endgame " + lv)
Expand All @@ -115,7 +142,7 @@ func Endgame(L *lua.LState) int {
return 1
}

func Sim(L *lua.LState) int {
func sim(L *lua.LState) int {
lv := L.ToString(1)
sc := getShell(L)
cmd, err := extractFields("sim " + lv)
Expand All @@ -132,6 +159,19 @@ func Sim(L *lua.LState) int {
return 1
}

func busy(L *lua.LState) int {
sc := getShell(L)
L.Push(lua.LBool(sc.solving()))
return 1
}

func Loader(L *lua.LState) int {
mod := L.SetFuncs(L.NewTable(), exports)

L.Push(mod)
return 1
}

func (sc *ShellController) script(cmd *shellcmd) (*Response, error) {
if cmd.args == nil {
return nil, errors.New("need arguments for script")
Expand All @@ -145,17 +185,13 @@ func (sc *ShellController) script(cmd *shellcmd) (*Response, error) {
L := lua.NewState()
defer L.Close()

L.PreloadModule("macondo", Loader)

lsc := L.NewUserData()
lsc.Value = sc

L.SetGlobal("macondo_shell", lsc)
L.SetGlobal("macondo_gen", L.NewFunction(Gen))
L.SetGlobal("macondo_load", L.NewFunction(Load))
L.SetGlobal("macondo_gid", L.NewFunction(Gid))
L.SetGlobal("macondo_set", L.NewFunction(Set))
L.SetGlobal("macondo_turn", L.NewFunction(Turn))
L.SetGlobal("macondo_endgame", L.NewFunction(Endgame))
L.SetGlobal("macondo_sim", L.NewFunction(Sim))

if len(cmd.args) > 1 {
table := L.NewTable()
joinedStr := strings.Join(cmd.args[1:], " ")
Expand Down
8 changes: 7 additions & 1 deletion shell/sim.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ func (sc *ShellController) handleSim(args []string, options CmdOptions) error {
return err
}
switch sci {
case 90:
stoppingCondition = montecarlo.Stop90
case 95:
stoppingCondition = montecarlo.Stop95
case 98:
stoppingCondition = montecarlo.Stop98
case 99:
stoppingCondition = montecarlo.Stop99
case 999:
stoppingCondition = montecarlo.Stop999
default:
return errors.New("only allowed values are 95, 98, and 99 for stopping condition")
return errors.New("only allowed values are 90, 95, 98, 99, and 999 for stopping condition")
}
case "opprack":
knownOppRack = options.String(opt)
Expand Down Expand Up @@ -186,6 +190,8 @@ func (sc *ShellController) simControlArguments(args []string) error {
sc.showMessage(sc.simmer.ScoreDetails())
case "show":
sc.showMessage(sc.simmer.EquityStats())
case "winner":
sc.showMessage(sc.simmer.WinningPlay().Move().ShortDescription())
case "continue":
if sc.simmer.IsSimming() {
return errors.New("there is an ongoing simulation")
Expand Down
9 changes: 3 additions & 6 deletions stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,9 @@ func (s *Statistic) Last() float64 {
return s.last
}

// StandardError returns the standard error of the statistic, multiplied
// by a factor. i.e. if you pass in 1.96, that would be a 95% confidence
// interval, 2.58 is a 99% confidence interval. (see math for more details)
// 1 is just 68% or 1 stdev.
func (s *Statistic) StandardError(m float64) float64 {
return m * math.Sqrt(s.Variance()/float64(s.totalIterations))
// StandardError returns the standard error of the statistic.
func (s *Statistic) StandardError() float64 {
return math.Sqrt(s.Variance() / float64(s.totalIterations))
}

func (s *Statistic) Iterations() int {
Expand Down
Loading

0 comments on commit df47064

Please sign in to comment.