In [1]:
include("src/TakEnv.jl")
include("src/Encoder.jl")

Main.Encoder

In [2]:
using .TakEnv
using .TakEnv: FIELD_SIZE, FIELD_HEIGHT, possible_carries, possible_directions, stone_player, stone_type, get_stack_height, get_top_stone, flat, cap, stand, north, south, east, west, placement, carry, white, black
using StatProfilerHTML
using BenchmarkTools
using .Encoder
using Distributions

# Floating-point type used for MCTS statistics; `const` so the alias is a true
# module-level constant rather than an untyped global binding.
const Float = Float64

Float64

In [7]:
# Search hyperparameters: string keys with heterogeneous values, hence `Any`.
HParams = Dict{String,Any}
hparams = HParams(
    "d_puct" => 1.0,               # weight of the prior-driven exploration term in PUCT
    "batch_size" => 64,            # leaf searches performed per `search_batch` call
    "run_name" => "testrun",
    "exploration_factor" => 0.25,  # share of Dirichlet noise mixed into root priors
)


Dict{String,Any}

In [4]:
"""
    NodeStats

Per-node MCTS statistics, stored as parallel vectors indexed by action slot.

# Fields
- `actions`: the compressed actions available at this node. The search stores
  `compress_action.(enumerate_actions(...))` here, compares against the same encoding,
  and applies a chosen entry with `decompress_action`, so the element type is
  `CompressedAction` rather than the decoded `Action`.
- `visit_counts`: how many times each action has been selected.
- `values`: per-action value; presumably the sum of backed-up values (TODO confirm once
  backup is implemented).
- `avg_values`: per-action mean value, used as the exploitation term of the PUCT score.
- `probs`: prior probability of each action.
"""
mutable struct NodeStats
    actions::Vector{CompressedAction}
    visit_counts::Vector{Int}
    values::Vector{Float}
    avg_values::Vector{Float}
    probs::Vector{Float}
end

# Search tree storage: one `NodeStats` entry per expanded (compressed) board position.
MCTSStorage = Dict{CompressedBoard,NodeStats}


Dict{BitArray{1},NodeStats}

In [5]:
# Global search tree for the experiments below; it starts with no expanded nodes.
storage = MCTSStorage()


Dict{BitArray{1},NodeStats}()

In [8]:
"""
    is_leaf(storage, state)

Return `true` when `state` has not been expanded yet, i.e. `storage` holds no
statistics for it.
"""
is_leaf(storage::MCTSStorage, state::CompressedBoard) = !haskey(storage, state)

"""
    find_leaf(hparams, storage, start_state, player)

Descend the tree in `storage` from `start_state` (with `player` to move) until reaching
a state without statistics in `storage` (a leaf) or a position where the game ended.

At every expanded node, the legal action with the highest PUCT score
`avg_value + d_puct * prior * sqrt(total_visits) / (1 + visits)` is taken, where
`d_puct = hparams["d_puct"]`. Stored actions that are not legal for the player to move
get a score of `-Inf`.

At the root (the first node visited), the priors are mixed with Dirichlet noise:
`(1 - ε) * prior + ε * η`, with `ε = hparams["exploration_factor"]`,
`η ~ Dirichlet(α)` and `α = get(hparams, "dirichlet_alpha", 0.03)`.

Returns `(value, leaf_state, leaf_player, states, actions)`:
- `value`: `nothing` if the descent stopped at an unexpanded state. If the game ended,
  it is `0.0` for a draw. Otherwise it is the result from the point of view of the
  player who made the final move: `-1.0` if `leaf_player` won, `1.0` otherwise.
- `leaf_state`, `leaf_player`: the compressed final state and the player to move in it.
- `states`: every expanded state visited, in order (`leaf_state` is not included).
- `actions`: the compressed action taken from each state in `states`.
"""
function find_leaf(hparams::HParams, storage::MCTSStorage, start_state::CompressedBoard, player::Player)
    states = CompressedBoard[]
    actions = CompressedAction[]

    cur_state = start_state
    cur_player = player
    value = nothing

    while !is_leaf(storage, cur_state)
        push!(states, cur_state)

        stats = storage[cur_state]
        total_sqrt = sqrt(sum(stats.visit_counts))

        probs = stats.probs

        # Root only: perturb the priors with Dirichlet noise for exploration. Testing the
        # depth (instead of `cur_state == start_state`) keeps a repeated root position
        # deeper in the descent from being perturbed as well.
        if length(states) == 1
            exploration = hparams["exploration_factor"]
            alpha = get(hparams, "dirichlet_alpha", 0.03)
            noise = rand(Dirichlet(length(probs), alpha))
            probs = (1 - exploration) .* probs .+ exploration .* noise
        end

        # PUCT score: average value plus prior-weighted exploration bonus
        d_puct = hparams["d_puct"]
        score = stats.avg_values .+
                d_puct .* probs .* total_sqrt ./ (1 .+ stats.visit_counts)

        board = decompress_board(cur_state)

        # Stored actions that are illegal in this position must never be chosen;
        # a Set makes each membership test O(1)
        legal = Set(compress_action.(enumerate_actions(board, cur_player)))
        score[.!in.(stats.actions, Ref(legal))] .= -Inf

        # Take the best-scoring action and advance the game state
        cur_action = stats.actions[argmax(score)]
        push!(actions, cur_action)
        apply_action!(board, decompress_action(cur_action), cur_player)

        # Check the outcome for the player who just moved; `player` is only the root's
        # player. NOTE(review): assumes check_win's second argument is the mover.
        result = check_win(board, cur_player)
        cur_player = opponent_player(cur_player)
        cur_state = compress_board(board)

        if !isnothing(result)
            victory_type, player_won = result
            # `cur_player` is now the mover's opponent, so its win scores -1 for the mover
            if victory_type == draw
                value = 0.0
            elseif player_won == cur_player
                value = -1.0
            else
                value = 1.0
            end
            break
        end
    end

    return value, cur_state, cur_player, states, actions
end

find_leaf (generic function with 2 methods)

In [9]:
# Smoke test: with an empty tree, the descent stops immediately at the empty-board root.
let root = compress_board(empty_board())
    find_leaf(hparams, storage, root, white::Player)
end

(nothing, Bool[0, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0], white, BitArray{1}[], Int64[])

In [15]:
"""
    search_batch(hparams, storage, start_state, player, model)

Run one batch of `hparams["batch_size"]` MCTS descents from `start_state` (with
`player` to move) and expand the non-terminal leaves they reach.

Terminal leaves (win, loss or draw, as reported by `find_leaf`) are queued for backup
with their game value. Each distinct non-terminal leaf is expanded once per batch by
inserting a fresh `NodeStats` into `storage`. That entry holds the leaf's legal actions,
zeroed counts and values, and placeholder priors.

Work in progress:
- `model` is not called yet: priors are random Dirichlet draws and leaf values are
  `rand()` placeholders.
- `backup_queue` is collected but never applied, so no statistics are updated.

Returns `nothing`.
"""
function search_batch(hparams::HParams, storage::MCTSStorage, start_state, player, model)
    # (value, path states, path actions) awaiting backup — backup not implemented yet
    backup_queue = Tuple{Float64,Vector{CompressedBoard},Vector{CompressedAction}}[]
    # Leaves (and players to move) to evaluate in one batched `model` call (TODO)
    expand_states = CompressedBoard[]
    expand_players = Player[]
    # (leaf state, path states, path actions, legal compressed actions at the leaf)
    expand_queue = Tuple{
        CompressedBoard,Vector{CompressedBoard},Vector{CompressedAction},Vector{CompressedAction}
    }[]
    # Several descents usually end at the same leaf; a Set deduplicates in O(1)
    planned = Set{CompressedBoard}()

    for _ in 1:hparams["batch_size"]
        value, leaf_state, leaf_player, states, actions =
            find_leaf(hparams, storage, start_state, player)
        if !isnothing(value)
            # Game over: the value is known without expansion
            push!(backup_queue, (value, states, actions))
        elseif !(leaf_state in planned)
            # Legal actions for the player to move at the leaf, as `find_leaf` computes them
            leaf_board = decompress_board(leaf_state)
            possible_actions = compress_action.(enumerate_actions(leaf_board, leaf_player))

            push!(planned, leaf_state)
            push!(expand_states, leaf_state)
            push!(expand_players, leaf_player)
            push!(expand_queue, (leaf_state, states, actions, possible_actions))
        end
    end

    # TODO get values and probs via `model` for `expand_states` / `expand_players`
    for (leaf_state, states, actions, possible_actions) in expand_queue
        n = length(possible_actions)
        storage[leaf_state] = NodeStats(
            possible_actions,
            zeros(Int, n),             # visit counts
            zeros(Float, n),           # values
            zeros(Float, n),           # average values
            rand(Dirichlet(n, 0.1)),   # priors, TODO replace with model predictions
        )
        push!(backup_queue, (rand(), states, actions))  # TODO replace with model value
    end

    return nothing
end

search_batch (generic function with 2 methods)

In [16]:
# Smoke test on a fresh tree, with a stand-in "model" that returns random outputs.
let random_model = x -> rand(size(x))
    search_batch(hparams, MCTSStorage(), compress_board(empty_board()), white::Player, random_model)
end

In [23]:
# Scratch check: Float64 zeros given for `visit_counts` are converted to the Int field.
let n = 3
    NodeStats([], zeros(n), zeros(n), zeros(n), []).visit_counts
end

3-element Array{Int64,1}:
 0
 0
 0

In [52]:
x = rand(10)
y = [x[1:3]; zeros(3)]

# Overwrite every entry of y that does not occur in x
y[.!in.(y, Ref(x))] .= -1
y[2] = -Inf
y

6-element Array{Float64,1}:
   0.10210144791730258
 -Inf
   0.7351115019489811
  -1.0
  -1.0
  -1.0

In [42]:
# Position of the first entry of y equal to 0 (`nothing` if there is none)
findfirst(isequal(0), y)

4