In [1]:
Threads.nthreads()

4

In [2]:
using Flux
using Zygote
using PyCall
using DataStructures
using StatsBase
using Printf
using LinearAlgebra

In [3]:
gym = pyimport("gym")
env = gym.make("CartPole-v1")

PyObject <TimeLimit<CartPoleEnv<CartPole-v1>>>

In [107]:
STATE_SPACE = length(env.observation_space.low)
ACTION_SPACE = 2
ACTIONS = collect(0:(ACTION_SPACE-1))
EPSILON_START = 0.50
EPSILON_END = 0.05
EPOCHS = 50
STEPS_PER_EPOCH = 2000
BUFFER_SIZE = 20000
BATCH_SIZE = 64
DISCOUNT = 0.99
L2_REG = 1e-3
TARGET_UPDATE = 5e-4
LEARNING_RATE = 1e-3
LAYER1_SIZE = 20
LAYER2_SIZE = 30

30

In [108]:
memory = CircularBuffer{Tuple{Array{Float32,1}, Int64, Float32, Array{Float32,1}, Bool}}(BUFFER_SIZE)

0-element CircularBuffer{Tuple{Array{Float32,1},Int64,Float32,Array{Float32,1},Bool}}

In [109]:
model = Chain(
    Dense(STATE_SPACE + ACTION_SPACE, LAYER1_SIZE, tanh),
    LayerNorm(LAYER1_SIZE),
    Dense(LAYER1_SIZE, LAYER2_SIZE, tanh),
    LayerNorm(LAYER2_SIZE),
    Dense(LAYER2_SIZE, 1),
)

target_model = deepcopy(model)

function forward(model, state, action::Int)
    model(vcat(state, Flux.onehot(action, ACTIONS)))
end

function forward(model, state, action::AbstractArray)
    model(vcat(state, Flux.onehotbatch(action, ACTIONS)))
end

function q_values(model, state)
    tmp = map((a) -> forward(model, state, ndims(state) == 1 ? a : fill(a, size(state)[2])), ACTIONS)
    reduce(vcat, tmp)
end

function valueest(model, state::AbstractArray)
    return maximum(q_values(model, state), dims=1)
end

# Act after an epsilon-greedy strategy
function act(model, state, epsilon)
    if epsilon != 0 && rand() <= epsilon
        return rand(ACTIONS)
    end
    tmp = ACTIONS[map((idx) -> idx[1], argmax(q_values(model, state), dims=1))]
    ndims(tmp) == 1 ? tmp[1] : tmp
end

act (generic function with 1 method)

In [110]:
optim = ADAM(LEARNING_RATE)

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

In [111]:
function grad_descend!(model, parameters, target_model, memory)
    outer_loss = 0
    # Sample BATCH_SIZE elements from the replay buffer
    batch = memory[1:BATCH_SIZE]
    sample!(memory, batch, replace=false)
    # Destructure the batch
    state, action, reward, next_state, death = [getindex.(batch, i) for i in 1:5]
    state = hcat(state...)
    next_state = hcat(next_state...)

    # Perform one step of grad descend
    Q_next = .!death .* valueest(target_model, next_state)
    target = reward .+ DISCOUNT .* Q_next
    
    grads = gradient(parameters) do
        Q_pred = forward(model, state, action)
        loss = Flux.huber_loss(Q_pred, target) + L2_REG * sum(norm, parameters)
        outer_loss += loss
        return loss
    end

    Flux.update!(optim, parameters, grads)
        
    for (p, p_target) in zip(parameters, Flux.params(target_model))
       p_target .= (1-TARGET_UPDATE) .* p_target .+ TARGET_UPDATE .* p
    end
    
    return outer_loss, mean(target)
end

grad_descend! (generic function with 1 method)

In [112]:
function train_epoch!(memory, model, target_model, env, epoch)
    state = Float32.(env.reset())
    total_reward = 0
    total_deaths = 0
    epsilon = EPSILON_START * (1 - epoch/EPOCHS) + EPSILON_END * (epoch/EPOCHS)
    
    parameters = Flux.params(model)
    total_loss = 0.0
    total_q = 0.0
    total_q_target = 0.0
    
   
    for step in 1:STEPS_PER_EPOCH
        action = act(model, state, epsilon) # Act
        next_state, reward, death, _ = env.step(action) # Advance the env
        
        # Convert to Float32
        next_state = Float32.(next_state)
        reward = Float32(reward)
        
        push!(memory, (state, action, reward, next_state, death)) # Save that memory
        total_reward += reward
        if death
            state = env.reset()
            total_deaths += 1
        else
            state = next_state
        end
        
        if length(memory) > BATCH_SIZE
            loss, q_target = grad_descend!(model, parameters, target_model, memory)
            total_loss += loss
            total_q_target += q_target
        end
    end
    return total_reward / (total_deaths + 1), total_loss / STEPS_PER_EPOCH, total_q_target / STEPS_PER_EPOCH
end

train_epoch! (generic function with 1 method)

In [113]:
for e in 1:EPOCHS
    loss, reward, target = train_epoch!(memory, model, target_model, env, e)
    println(@sprintf("Epoch %d, Loss %f, Reward %f, target %f", e, reward, loss, target))
    if reward > 200
        break
    end
end

Epoch 1, Loss 0.202699, Reward 14.705882, target 3.846654
Epoch 2, Loss 0.267166, Reward 17.391304, target 5.117557
Epoch 3, Loss 0.311139, Reward 22.222221, target 6.203709
Epoch 4, Loss 0.325515, Reward 45.454544, target 7.202230
Epoch 5, Loss 0.324456, Reward 51.282051, target 8.181096
Epoch 6, Loss 0.308244, Reward 38.461540, target 9.139338
Epoch 7, Loss 0.281239, Reward 35.714287, target 10.102285
Epoch 8, Loss 0.274610, Reward 28.571428, target 11.043230
Epoch 9, Loss 0.293690, Reward 24.691359, target 11.950754
Epoch 10, Loss 0.379699, Reward 22.471910, target 12.730969
Epoch 11, Loss 0.473040, Reward 19.047619, target 13.480741
Epoch 12, Loss 0.569868, Reward 18.348623, target 14.211634
Epoch 13, Loss 0.675977, Reward 18.867924, target 14.922304
Epoch 14, Loss 0.761607, Reward 19.230770, target 15.639124
Epoch 15, Loss 0.810820, Reward 21.739130, target 16.377488
Epoch 16, Loss 0.804513, Reward 27.397261, target 17.150451
Epoch 17, Loss 0.848867, Reward 14.184397, target 17.87

In [80]:
function test(model, env)
    total_reward = 0
    state = env.reset()
    for i in 1:STEPS_PER_EPOCH
        action = act(model, state, 0)
        state, reward, death, _ = env.step(action)
        total_reward += reward
        
        env.render()
        death && return total_reward, i
    end
end

test (generic function with 1 method)

In [98]:
test(target_model, env)

(129.0, 129)

In [29]:
env.reset()
s = 0
r = 0
d = 0
for _ in 1:5
    s,r,d = env.step(1)
end
valueest(model, s)

2-element Array{Float32,1}:
 10.167985
 10.290317