In [36]:
using Random, Distributions, Printf
using LinearAlgebra: dot, ⋅
using Profile, BenchmarkTools

#bids = [10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33, 40, 41, 42, 43]
bids = [10, 11, 20, 21, 30, 31, 40, 41]

mutable struct Node{K<:Integer, V<:AbstractFloat}
    regretsum::Dict{K, V}
    strategy::Dict{K, V}
    strategysum::Dict{K, V}
end

function key(hand, history)
    handstr = join(hand)
    histstr = join(history)
    handstr * histstr
end

pos(x) = x > 0 ? x : zero(x)

issomething(s) = !isnothing(s)

Base.map(f, dict::AbstractDict) = Dict(k => f(v) for (k, v) in dict)

function normalize(xs)
    total = sum(values(xs))
    n = length(xs)
    total > 0 ? map(x -> x / total, xs) : map(_ -> 1.0 / n, xs)
end

parsebid(n) = (n ÷ 10, n % 10)

function getstrategy(node, realizationweight)
    node.strategy = map(pos, node.regretsum) |> normalize
    for (k, v) in node.strategy
        node.strategysum[k] += realizationweight * v
    end
    node.strategy
end
    
function getnode(policys, hand, history)
    k = key(hand, history)
    if !haskey(policys, k)
        dict()::Dict{UInt8, Float64} = Dict(k => 0.0 for k in actions(history))
        node = Node(dict(), dict(), dict())
        policys[k] = node
    end
    policys[k]
end

function terminal(history, hands)
    plays = length(history)
    if plays > 2
        if history[end] == 1
            (quant, rank) = parsebid(history[end-2])
            cnt = count(c -> c == rank, hands[1]) + count(c -> c == rank, hands[2])
            return cnt >= quant ? -1 : 1
        end
    end
    return nothing
end

function actions(history)
    n = length(history)
    if n > 2 && history[end-2] == 0 && history[end] == 0
        return [1]
    elseif n > 1 && history[end] == 0
        return  pushfirst!(filter(x -> x > history[end-1], bids), 1)
    elseif n > 0
        return pushfirst!(filter(x -> x > history[end], bids), 0)
    else
        return bids
    end
end
                    
function cfr(policys, hands, history, p1, p2)::Float64
    
    terminalutility = terminal(history, hands)
    if issomething(terminalutility)
        return terminalutility
    end
    
    player = length(history) % 2 + 1
    prob = [p1, p2][player]
    node = getnode(policys, hands[player], history)
    strategy = getstrategy(node, prob)
    util = Dict{UInt8, Float64}()
    for a in actions(history)
        nexthistory = copy(history)
        push!(nexthistory, a)
        util[a] = player == 1 ?
            -cfr(policys, hands, nexthistory, p1 * strategy[a], p2) :
            -cfr(policys, hands, nexthistory, p1, p2 * strategy[a])
    end
    nodeutil = 0.0
    regret = Dict()
    for (k, v) in strategy
        nodeutil += v * util[k]
        regret[k] = util[k] - nodeutil
        node.regretsum[k] += [p2, p1][player] * regret[k]
    end
    nodeutil
end

function train(n)
    policys = Dict{String, Node{UInt8, Float64}}()
    util = 0.0
    for _ in 1:n
        hands = [sort!(rand(0:1,2)), sort!(rand(0:1,2))]
        util += cfr(policys, hands, [], 1.0, 1.0)
    end
    println(util / n)
    ps = Dict{String, Dict{UInt8, Float64}}()
    for (k, v::Node{UInt8, Float64}) in policys
        s = v.strategysum
        ps[k] = normalize(s)
    end
    policys, ps
end

function displayresult(d)
    for (k, v) in d
        print(k, " => ")
        map(x -> @printf("%0.2f, ", x), v)
        println()
    end
end

displayresult (generic function with 1 method)

In [37]:
@time train(100);

0.11801737439163165
  2.418754 seconds (22.40 M allocations: 1.399 GiB, 26.75% gc time)


In [38]:
policys, p = train(100_000);

0.08931528333205331


In [None]:
print(p)
println("-----")
displayresult(sort(p))
println("-----")
print(policys)

In [51]:
p["0030"]

Dict{UInt8,Float64} with 4 entries:
  0x00 => 0.280437
  0x1f => 0.204794
  0x28 => 0.514759
  0x29 => 1.00815e-5