/
solver.go
100 lines (86 loc) · 2.17 KB
/
solver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package main
import (
"math"
"math/rand"
"github.com/frickiericker/learn-go/07-magic_square/matrix"
"github.com/frickiericker/learn-go/07-magic_square/seq"
)
// MagicSquareSolver holds the state of a local search for a magic square:
// the candidate arrangement, the target line sum and the random source used
// for every random choice the solver makes.
type MagicSquareSolver struct {
	// square is the current candidate; methods mutate its cells in place.
	square *matrix.Square
	// magic is the sum every row, column and diagonal should reach.
	magic float64
	// rnd supplies randomness for shuffling and neighbor selection.
	rnd *rand.Rand
}
// NewMagicSquareSolver creates a solver for a size×size square. The cells
// initially hold 1, 2, ..., size² in the order of the square's backing data,
// the target line sum is the magic constant for that size, and rnd is kept
// as the source of all later random decisions.
func NewMagicSquareSolver(size int, rnd *rand.Rand) *MagicSquareSolver {
	solver := &MagicSquareSolver{
		square: matrix.NewSquare(size),
		magic:  float64(magicConstant(size)),
		rnd:    rnd,
	}
	// Data() aliases the square's storage, so writing here fills the square.
	cells := solver.square.Data()
	for idx := range cells {
		cells[idx] = float64(idx + 1)
	}
	return solver
}
// Get returns the solver's current square. The pointer refers to the
// solver's own state, not a copy, so later Randomize or SearchNeighbor calls
// are visible through it.
func (solver *MagicSquareSolver) Get() *matrix.Square {
	current := solver.square
	return current
}
// Randomize rearranges the square's cells with seq.Shuffle, drawing
// randomness from the solver's own rnd.
func (solver *MagicSquareSolver) Randomize() {
	cells := seq.NewFloat64Slice(solver.square.Data())
	seq.Shuffle(cells, solver.rnd)
}
// evaluate returns the solver's loss. It sums |lineSum - magic| over every
// row, every column, the main diagonal and the anti-diagonal, in that order.
// The loss is 0 exactly when every one of those lines hits the target sum.
func (solver *MagicSquareSolver) evaluate() float64 {
	sq, target := solver.square, solver.magic
	deviation := func(sum float64) float64 {
		return math.Abs(sum - target)
	}

	var total float64
	for k := 0; k < sq.Size(); k++ {
		total += deviation(sumRow(sq, k))
		total += deviation(sumCol(sq, k))
	}
	total += deviation(sumDiag(sq))
	total += deviation(sumAntidiag(sq))
	return total
}
// searchNeighborSlack is how much a random swap may increase the loss and
// still be kept by SearchNeighbor. Accepting slightly worse moves presumably
// lets the search drift out of shallow local minima rather than only ever
// descending. NOTE(review): value taken from the original hard-coded 3.
const searchNeighborSlack = 3

// SearchNeighbor performs one local-search step. It picks two cells at
// random and swaps them. If the swap raises the loss by more than
// searchNeighborSlack it is undone; otherwise it is kept. The return value is
// the loss of the square as left after the step.
//
// A square with no cells has nothing to swap. In that case the current loss
// is returned instead of letting rand.Intn(0) panic.
func (solver *MagicSquareSolver) SearchNeighbor() float64 {
	n := solver.square.Size()
	if n < 1 {
		return solver.evaluate()
	}
	data := solver.square.Data()
	c1 := solver.rnd.Intn(n * n)
	c2 := solver.rnd.Intn(n * n)
	lossBefore := solver.evaluate()
	if c1 == c2 {
		// Swapping a cell with itself changes nothing; skip re-evaluating.
		return lossBefore
	}

	data[c1], data[c2] = data[c2], data[c1]
	lossAfter := solver.evaluate()
	if lossAfter-searchNeighborSlack > lossBefore {
		// The move is too much worse: revert it.
		data[c1], data[c2] = data[c2], data[c1]
		return lossBefore
	}
	return lossAfter
}
// sumRow returns the sum of the values in the given row of square,
// accumulated from the first column to the last.
func sumRow(square *matrix.Square, row int) float64 {
	var total float64
	for c, cols := 0, square.Cols(); c < cols; c++ {
		total += square.Get(row, c)
	}
	return total
}
// sumCol returns the sum of the values in the given column of square,
// accumulated from the first row to the last.
func sumCol(square *matrix.Square, col int) float64 {
	var total float64
	for r, rows := 0, square.Rows(); r < rows; r++ {
		total += square.Get(r, col)
	}
	return total
}
// sumDiag returns the sum of the main-diagonal cells (k, k) of square,
// accumulated from the top-left corner down.
func sumDiag(square *matrix.Square) float64 {
	var total float64
	n := square.Size()
	for k := 0; k < n; k++ {
		total += square.Get(k, k)
	}
	return total
}
// sumAntidiag returns the sum of the anti-diagonal cells of square, running
// from the top-right corner (0, size-1) down to the bottom-left (size-1, 0).
func sumAntidiag(square *matrix.Square) float64 {
	var total float64
	last := square.Size() - 1
	for row := 0; row <= last; row++ {
		total += square.Get(row, last-row)
	}
	return total
}
// magicConstant returns n·(n²+1)/2, the common line sum of an n×n magic square
// built from 1..n². The product is always even, so the division is exact.
func magicConstant(n int) int {
	cells := n * n
	return n * (cells + 1) / 2
}