In [15]:
using Statistics
using DataStructures: CircularDeque
using Flux, Flux.Optimise
using Images.ImageCore
using Flux: onehotbatch, onecold, crossentropy, Momentum, params, ADAM
using Flux: loadmodel!
using Base.Iterators: partition
using CUDA
using StatsBase: sample
using Plots;

In [2]:
MAX_REPLAY_MEMORY_SIZE = 1000
MINIBATCH_SIZE = 10
DISCOUNT = 0.9
UPDATE_TARGET_EVERY = 10

0.9

In [3]:
struct Memory
    sₜ
    aₜ
    rₙ
    sₙ
    done
end

In [4]:
function model()
    Chain(
        Dense(10, 120),
        Dense(120, 84),
        Dense(84, 5),
    softmax)
end

model (generic function with 1 method)

In [5]:
function update_memory!(mem::Memory, replay_mem::Array{Memory})
    if length(replay_mem) < MAX_REPLAY_MEMORY_SIZE 
        push!(replay_mem, mem)
    else
        popfirst!(replay_mem)
        push!(replay_mem, mem)
    end
end

update_memory! (generic function with 1 method)

In [6]:
loss(x, y) = Flux.Losses.mse(m(x), y)
opt = ADAM()

function train!(replay_mem::Array{Memory}, main_model, target_model, terminal_state, target_update_counter)
    minibatch = sample(replay_memory, MINIBATCH_SIZE, replace=false)
    current_states = [transition.sₜ for transition in minibatch] |> gpu
    current_qs_list = main_model(current_states)

    new_current_states = [transition.sₙ for transition in minibatch] |> gpu
    future_qs_list = target_model(new_current_states)

    X = []
    y = []
    for (i, data_point) in enumerate(minibatch)
        sₜ, aₜ, rₙ, sₙ, done = data_point.sₜ, data_point.aₜ, data_point.rₙ, data_point.sₙ, data_point.done
        if done
            new_q = rₙ
        else
            new_q = rₙ + DISCOUNT * maximum(future_qs_list[i])
        end

        current_qs = current_qs_list[i]
        current_qs[action] = new_q

        push!(X, current_state)
        push!(y, current_qs)
    end

    ps = Flux.params(m)
    data = Flux.Data.DataLoader((X, Y),shuffle=true);
    Flux.train!(loss, ps, data, opt)

    if terminal_state
        target_update_counter += 1
    end

    if target_update_counter > UPDATE_TARGET_EVERY
        target_model = Flux.loadmodel!(target_model, main_model);
        target_update_counter = 0
    end

end

train! (generic function with 1 method)

In [7]:
main_model = model() |> gpu
target_model = model() |> gpu;
target_model = Flux.loadmodel!(target_model, main_model);

In [8]:
replay_memory = Memory[]

Memory[]

In [10]:
target_update_counter = 0

0

In [11]:
for i in 1:1000
    update_memory!(Memory(rand(1:1000), rand(1:1000), rand(1:1000), rand(1:1000), true), replay_memory)
end

In [12]:
a = train!(replay_memory, main_model, target_model) 

5-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 0.0
 0.0
 0.0
 0.0
 1.0

In [None]:
rand(0:2)

ADAM(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [None]:
import Pkg; Pkg.add("DataStructures")