In [1]:
using Statistics
using DataStructures: CircularDeque
using Flux, Flux.Optimise
using Images.ImageCore
using Flux: onehotbatch, onecold, crossentropy, Momentum, params, ADAM
using Flux: loadmodel!
using Base.Iterators: partition
using CUDA
using StatsBase: sample
using BSON: @save, @load
using JLD2
using Serialization
using Plots;




In [2]:
# Hyperparameters of the replay-based training loop.
# Declared `const` so the functions that read them (`update_memory!`, `train!`)
# see concretely typed values instead of untyped globals.
const MAX_REPLAY_MEMORY_SIZE = 1000  # maximum number of transitions kept in the replay buffer
const MINIBATCH_SIZE = 100           # transitions sampled per call to `train!`
const DISCOUNT = 0.9                 # weight on the target network's best next-state value
const UPDATE_TARGET_EVERY = 10       # counter value that must be exceeded before syncing the target

10

In [3]:
"""
    Memory(sₜ, aₜ, rₙ, sₙ, done)

One transition stored in the replay buffer.

The struct is mutable, so individual fields (e.g. `done`) can be reassigned
after construction. All fields are untyped (`Any`).
"""
mutable struct Memory
    sₜ::Any    # input fed to the main model when computing current Q-values
    aₜ::Any    # row index of the Q-value overwritten with the training target
    rₙ::Any    # scalar added to the discounted future value in the target
    sₙ::Any    # input fed to the target model when computing future Q-values
    done::Any  # when true, the target is `rₙ` alone (no bootstrapped term)
end

In [4]:
"""
    model()

Build the network used for both the main and the target model.

The freshly constructed network maps 550 inputs through dense layers of width
256, 128 and 64 (sigmoid activations) to 4 outputs followed by `softmax`.
If a file named `mymodel.bson` exists in the current working directory, the
object stored in it under the name `m` is loaded and returned instead.
"""
function model()
    # The activation is passed to each `Dense` layer so it is broadcast over the
    # layer output. A bare `σ` entry in the `Chain` would be called on the whole
    # array, which NNlib rejects with a "use broadcasting" error.
    m = Chain(
        Dense(550, 256, σ),
        Dense(256, 128, σ),
        Dense(128, 64, σ),
        Dense(64, 4),
        # NOTE(review): softmax limits every output to (0, 1), while `train!`
        # regresses these outputs onto reward-scale targets — confirm intended.
        softmax
    )
    if isfile("mymodel.bson")
        # `@load` rebinds the local `m` to the saved object.
        @load "mymodel.bson" m
    end
    return m
end

model (generic function with 1 method)

In [5]:
"""
    update_memory!(mem::Memory, replay_mem::Array{Memory})

Append `mem` to `replay_mem`. When the buffer already holds
`MAX_REPLAY_MEMORY_SIZE` (or more) entries, the oldest entry is removed first.
Returns `replay_mem`.
"""
function update_memory!(mem::Memory, replay_mem::Array{Memory})
    # Evict the front element only when there is no room left.
    length(replay_mem) < MAX_REPLAY_MEMORY_SIZE || popfirst!(replay_mem)
    return push!(replay_mem, mem)
end

update_memory! (generic function with 1 method)

In [6]:
"""
    update_weights!(target_model, main_model)

Overwrite every parameter array of `target_model` in place with the
corresponding parameter array of `main_model`. Parameters are paired in the
order returned by `params`. Returns `nothing`.
"""
function update_weights!(target_model, main_model)
    sources = params(main_model)
    destinations = params(target_model)
    foreach(zip(sources, destinations)) do (src, dst)
        dst .= src
    end
    return nothing
end

update_weights! (generic function with 1 method)

In [7]:
# Mean-squared error between the global `main_model`'s output for `x` and the
# target `y`.
function loss(x, y)
    prediction = main_model(x)
    return Flux.Losses.mse(prediction, y)
end

# Optimiser shared by all training steps (ADAM with its default settings).
opt = ADAM()

"""
    train!(replay_mem, main_model, target_model, terminal_state, target_update_counter)

Fit `main_model` on one random minibatch drawn from `replay_mem` and sync
`target_model` to it when the update counter exceeds `UPDATE_TARGET_EVERY`.

# Arguments
- `replay_mem::Array{Memory}`: buffer of stored transitions; `MINIBATCH_SIZE`
  of them are sampled without replacement.
- `main_model`: network being trained (updated in place through its parameters).
- `target_model`: network used to evaluate the next states; overwritten with
  `main_model`'s parameters when the counter exceeds `UPDATE_TARGET_EVERY`.
- `terminal_state`: when `true`, the counter is incremented by one.
- `target_update_counter`: current value of the caller's sync counter.

# Returns
The updated `target_update_counter` (reset to `0` after a sync). The counter is
passed by value, so the caller must store the returned value for the periodic
sync to ever trigger.
"""
function train!(replay_mem::Array{Memory}, main_model, target_model, terminal_state, target_update_counter)
    # Sample from the buffer that was passed in, not from a global.
    minibatch = sample(replay_mem, MINIBATCH_SIZE, replace=false)

    # One column per sampled transition.
    current_states = hcat((transition.sₜ for transition in minibatch)...)
    current_qs_list = main_model(current_states)

    new_current_states = hcat((transition.sₙ for transition in minibatch)...)
    future_qs_list = target_model(new_current_states)

    # For each transition, take the main model's current output column and
    # replace only the entry of the action taken with the new target value.
    target_columns = map(enumerate(minibatch)) do (i, transition)
        new_q = if transition.done
            transition.rₙ
        else
            transition.rₙ + DISCOUNT * maximum(future_qs_list[:, i])
        end
        current_qs = current_qs_list[:, i]  # slicing copies, so the model output is untouched
        current_qs[transition.aₜ, :] .= new_q
        current_qs
    end
    y = hcat(target_columns...)

    # The loss closes over the `main_model` argument so that the model being
    # differentiated is the same one whose parameters are handed to Flux.
    batch_loss(inputs, targets) = Flux.Losses.mse(main_model(inputs), targets)

    ps = Flux.params(main_model)
    # `current_states` already holds the inputs in minibatch order.
    data = Flux.Data.DataLoader((current_states, y), shuffle=true)
    Flux.train!(batch_loss, ps, data, opt)

    if terminal_state
        target_update_counter += 1
    end

    if target_update_counter > UPDATE_TARGET_EVERY
        update_weights!(target_model, main_model)
        target_update_counter = 0
    end

    return target_update_counter
end

train! (generic function with 1 method)

In [8]:
# Build the network that is trained and the network used for target values,
# moving each through Flux's `gpu`.
main_model = gpu(model())
target_model = gpu(model());
# Start the target network with the same parameter values as the main network.
update_weights!(target_model, main_model);

In [None]:
# Empty replay buffer; transitions are appended with `update_memory!`.
replay_memory = Vector{Memory}()

In [None]:
# Counter compared against UPDATE_TARGET_EVERY inside `train!` to decide when
# the target model is synced.
target_update_counter = 0

In [None]:
# Fill the replay buffer with 1000 randomly generated, non-terminal transitions.
foreach(1:1000) do _
    state = gpu(rand(550, 1))
    action = gpu(rand(1:4))
    reward = gpu(rand(-10:10))
    next_state = gpu(rand(550, 1))
    is_terminal = gpu(false)
    update_memory!(Memory(state, action, reward, next_state, is_terminal), replay_memory)
end

In [None]:
# Show how many transitions the replay buffer currently holds.
length(replay_memory)

In [None]:
# Iteration number embedded in the replay-memory file name when serializing.
iteration = 100

In [None]:
# Write the whole replay buffer to "replay_memory_<iteration>.dat" using the
# Serialization standard library.
serialize("replay_memory_$iteration.dat", replay_memory)

In [9]:
# Read back a serialized replay buffer.
# NOTE(review): the file name is hard-coded to iteration 10 instead of using
# `iteration`, and the result is not assigned to a variable — confirm intended.
# NOTE(review): the KeyError recorded for this cell names a package key
# (`Interstate`); presumably the file stores types from a package that is not
# loaded in this session — verify and load it before deserializing.
deserialize("replay_memory_10.dat")

KeyError: KeyError: key Interstate [d1dc07d0-456e-4a45-b97a-3837634689d3] not found

In [None]:
# Draw MINIBATCH_SIZE distinct transitions (sampling without replacement) from
# the replay buffer; the trailing `;` suppresses the cell output.
minibatch = sample(replay_memory, MINIBATCH_SIZE, replace=false);

In [None]:
# Flag the most recently stored transition as terminal.
replay_memory[end].done = true

In [None]:
# Put the newest replay-buffer entry into slot MINIBATCH_SIZE of the minibatch.
minibatch[MINIBATCH_SIZE] = last(replay_memory)

In [None]:
# Display the minibatch for inspection.
minibatch