In [1]:
using Revise

In [2]:
using SQLite
using MsgPack
using Flux
using Flux.NNlib



# Import training data from DB

In [20]:
include("gamegenerator.jl")



GameGenerator

In [21]:
db, _, _ = GameGenerator.open_db()

(SQLite.DB("games.sqlite"), SQLite.Stmt(SQLite.DB("games.sqlite"), Ptr{Void} @0x0000000005680008), SQLite.Stmt(SQLite.DB("games.sqlite"), Ptr{Void} @0x00000000062d7cb8))

In [88]:
positions = SQLite.query(db, "select positions.*, games.outcome from positions, games where positions.game_id=games.id order by random() limit 10000");

In [89]:
positions[:board_state_u] = convert(Vector{Checkers.State}, GameGenerator.unpack_state.(positions[:board_state]));
positions[:mcts_probs_u] = convert(Vector{Vector{Float32}}, MsgPack.unpack.(positions[:mcts_probs]));
positions[:mcts_moves_u] = convert(Vector{Vector{Checkers.Move}}, GameGenerator.unpack_moves.(positions[:mcts_moves]));
positions[:state_tensor] = Checkers.NN.state_to_tensor.(positions[:board_state_u]);
positions[:move_tensor] = [Checkers.NN.moves_to_tensor(p[:mcts_probs_u], p[:mcts_moves_u]) for p in DataFrames.eachrow(positions)];

# Model

## Game outcome predictor

In [90]:
?Conv

search: [1mC[22m[1mo[22m[1mn[22m[1mv[22m [1mc[22m[1mo[22m[1mn[22m[1mv[22m [1mc[22m[1mo[22m[1mn[22m[1mv[22m2 [1mC[22m[1mo[22m[1mn[22m[1mv[22m2D [1mc[22m[1mo[22m[1mn[22m[1mv[22m2d [1mc[22m[1mo[22m[1mn[22m[1mv[22mert de[1mc[22m[1mo[22m[1mn[22m[1mv[22m [1mC[22m[1mo[22m[1mn[22mj[1mV[22mector [1mc[22m[1mo[22mde_[1mn[22mati[1mv[22me



```
Conv(size, in=>out)
Conv(size, in=>out, relu)
```

Standard convolutional layer. `size` should be a tuple like `(2, 2)`. `in` and `out` specify the number of input and output channels respectively.

Data should be stored in WHCN order. In other words, a 100×100 RGB image would be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad` and `stride`.


In [109]:
model = Chain(
    Conv((3, 3), 8=>16, relu, pad=(1, 1)),
    Conv((3, 3), 16=>32, relu, pad=(1, 1)),
    Conv((3, 3), 32=>64, relu, pad=(1, 1)),
    Conv((3, 3), 64=>128, relu, pad=(1, 1)),
    Conv((3, 3), 128=>128, relu, pad=(1, 1)),
    Conv((3, 3), 128=>128, relu, pad=(1, 1)),
    Conv((3, 3), 128=>128, relu, pad=(1, 1)),
    Conv((3, 3), 128=>128, relu, pad=(1, 1)),
    Conv((3, 3), 128=>4, relu, pad=(1, 1)),
    x -> reshape(x, :, size(x, 4)),
    softmax
)

Chain(Conv((3, 3), 8=>16, NNlib.relu), Conv((3, 3), 16=>32, NNlib.relu), Conv((3, 3), 32=>64, NNlib.relu), Conv((3, 3), 64=>128, NNlib.relu), Conv((3, 3), 128=>128, NNlib.relu), Conv((3, 3), 128=>128, NNlib.relu), Conv((3, 3), 128=>128, NNlib.relu), Conv((3, 3), 128=>128, NNlib.relu), Conv((3, 3), 128=>4, NNlib.relu), #35, NNlib.softmax)

In [110]:
loss(x, y) = Flux.crossentropy(model(x), reshape(y, :, size(y, 4)))

loss (generic function with 1 method)

In [111]:
get_batch(positions, start_idx, n) = (cat(4, positions[1:30, :state_tensor]...), cat(4, positions[1:30, :move_tensor]...))

get_batch (generic function with 1 method)

In [112]:
minibatches = [get_batch(positions, i, 1024) for i=1:1024:length(positions[1])];

In [113]:
x, y = minibatches[1];

In [114]:
loss(x, y)

4.852030290257223 (tracked)

In [118]:
opt = ADAM(params(model))

(::#71) (generic function with 1 method)

In [119]:
Flux.@epochs 50 Flux.train!(loss, minibatches, opt, cb=Flux.throttle(() -> println("Loss: ", loss(x, y)), 5))

[1m[36mINFO: [39m[22m[36mEpoch 1
[39m

Loss: 4.505379580477796 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 2
[39m

Loss: 3.017158944261936 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 3
[39m

Loss: 2.8558893517749344 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 4
[39m

Loss: 2.77312034060319 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 5
[39m

Loss: 2.739573279845636 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 6
[39m

Loss: 2.7232940901866356 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 7
[39m

Loss: 2.7173160663836424 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 8
[39m

Loss: 2.7146632929544965 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 9
[39m

Loss: 2.713467884216215 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 10
[39m

Loss: 2.7127790484088004 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 11
[39m

Loss: 2.712398184711141 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 12
[39m

Loss: 2.7123748890418984 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 13
[39m

Loss: 2.7121404090922065 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 14
[39m

Loss: 2.712111559826774 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 15
[39m

Loss: 2.7118171267564786 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 16
[39m

Loss: 2.711761580874466 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 17
[39m

Loss: 2.716936752133669 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 18
[39m

Loss: 2.7148063619981118 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 19
[39m

Loss: 2.7126888084835477 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 20
[39m

Loss: 2.712094997307827 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 21
[39m

Loss: 2.7117440482846233 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 22
[39m

Loss: 2.711578209670665 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 23
[39m

Loss: 2.711916541516564 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 24
[39m

Loss: 2.714184028979663 (tracked)
Loss: 2.712458751166721 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 25
[39m

Loss: 2.7120454295457828 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 26
[39m

Loss: 2.711680744980778 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 27
[39m

Loss: 2.711506573677109 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 28
[39m

Loss: 2.7114764969122875 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 29
[39m

Loss: 2.7124672555117746 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 30
[39m

Loss: 2.7131582119757183 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 31
[39m

Loss: 2.7127648074832544 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 32
[39m

Loss: 2.7119762874475914 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 33
[39m

Loss: 2.7116381247160937 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 34
[39m

Loss: 2.711537300094332 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 35
[39m

Loss: 2.711506335482012 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 36
[39m

Loss: 2.7114573912218067 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 37
[39m

Loss: 2.711473759283172 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 38
[39m

Loss: 2.715189250234892 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 39
[39m

Loss: 2.7116677890001197 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 40
[39m

Loss: 2.711693070412793 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 41
[39m

Loss: 2.711537953067622 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 42
[39m

Loss: 2.7114202274520998 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 43
[39m

Loss: 2.713854784703106 (tracked)
Loss: 2.714112906437473 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 44
[39m

Loss: 2.7124877597360455 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 45
[39m

Loss: 2.712930515240696 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 46
[39m

Loss: 2.711682133777005 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 47
[39m

Loss: 2.711565728604496 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 48
[39m

Loss: 2.7114281745133346 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 49
[39m

Loss: 2.711400543161475 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 50
[39m

Loss: 2.711409223894782 (tracked)
