## A customized DQN implementation

A small-scaled DQN implementation for fast prototyping and playing with

In [5]:
import Pkg;
#uncomment the following if you have not installed them
# Pkg.add("ReinforcementLearning");
# Pkg.add("Flux");
# Pkg.add("StableRNGs");
# Pkg.add("Distributions");
# Pkg.add("UnicodePlots");
# Pkg.add("Zygote")
using Flux: InvDecay;
using ReinforcementLearning;
using StableRNGs;
using Flux;
using Flux.Losses;
using Distributions;
using UnicodePlots:lineplot, lineplot!
using Statistics
using Zygote

In [6]:
# set random seed and env
seed = 245;
rng = StableRNG(seed);
env = CartPoleEnv(; T = Float32);
ns, na = length(state_space(env)), length(action_space(env));

In [60]:
function GreedyPolicy(states, learner)
    logits = learner(states)
    actions = mapslices(argmax, logits, dims=1)
    return actions
end

function EpsilonGreedyPolicy(states, learner, t_current, t_max)
    ϵ_min = 0.005
    ϵ = max(1-t_current/t_max, ϵ_min)
    random_number = rand(Uniform(0,1))
    if random_number > ϵ
        action = GreedyPolicy(states, learner)
    else
        action = rand(rng, 1:2)
    end
    return action
end

function gather(q_values,actions)
    num_samples = size(actions,2)
    q_filtered = Array{Float64}(undef,1,1)
    for i = 1:num_samples
        q_filtered = [q_filtered q_values[actions[:,i],i]]
    end
    q_filtered = q_filtered[:,2:end]
    return q_filtered
end

function value_loss(learner, batch)
    # TODO: improve the inefficient loss calculation
    num_sample = size(batch["actions"],2)
    loss = 0
    γ = 0.96
    q_values = learner(batch["states"])
    next_values = findmax(learner(batch["next_states"]); dims=1)[1]
################################################
# implementation 1 (work)
#     target = Flux.Zygote.ignore() do
#         batch["rewards"] + γ*next_values
#     end
    target = batch["rewards"] + γ*next_values
    for i = 1:size(batch["actions"],2)
        loss = loss + mse(q_values[batch["actions"][i],i], target[i])
    end
    return loss/num_sample
###############################################
# implementation 2 (does not work)
#     q_values = gather(q_values,batch["actions"])
#     target = Flux.Zygote.ignore() do
#         batch["rewards"] + γ*next_values
#     end
#     loss = mse(q_values, target)
#     return loss
###############################################
# implementation 3 (does not work)
#     next_q_unfiltered = learner(batch["next_states"])
#     target_unfiltered = Flux.Zygote.ignore() do
#         broadcast(+, batch["rewards"], γ * next_q_unfiltered)
#     end
#     loss_unfiltered = mse(q_values,target_unfiltered;agg=identity)
#     for i = 1:num_sample
#         loss = loss + loss_unfiltered[batch["actions"][i],i]
#     end
#     println(loss)
#     return loss / num_sample
end

function update_learner(learner, batch)
    grad = Flux.gradient(() -> value_loss(learner, batch), Flux.params(learner))
    Flux.update!(optimizer, Flux.params(learner), grad)
end

update_learner (generic function with 1 method)

In [61]:
# model, policy, loss, update step
learner = Chain(
    Dense(ns, 128, relu; init = glorot_uniform(rng)),
    Dense(128, 128, relu; init = glorot_uniform(rng)),
    Dense(128, na; init = glorot_uniform(rng)),
);
optimizer = ADAM();

In [62]:
# interact with env to collect data and do the update steps
policy = EpsilonGreedyPolicy
stop_criterion = StopAfterEpisode(15000)
total_rewards = Array{Float64}(undef, 1, 1)
step_counter = 0
max_step = 1e5

while true
    reset!(env)
    episode_reward = 0
    states = Array(state(env))
    actions = Array{Int32}(undef, 1, 1)
    rewards = Array{Float64}(undef, 1, 1)
    
    while !is_terminated(env)
        #env |> policy |> env
        action = policy(state(env), learner, step_counter, max_step)[1]
        step_counter = step_counter +1
        env(action)
        
        states = [states state(env)]
        actions = [actions action]
        rewards = [rewards reward(env)]
        episode_reward += reward(env)
        #stop_criterion(policy, env) && return # stop criterion: max episodes
    end
    # end of an episode
    # processing the data
    next_states = states[:,2:end]
    states = states[:,1:end-1]
    rewards = rewards[:,2:end]
    actions = actions[:,2:end]
    action_index = [(0,0)]
    for i = 1:length(actions)
        action_index = [action_index (actions[i],i)]
    end
    action_index = action_index[:,2:end]
    # TODO: use named tuple instead of dictionary for performance
    batch = Dict("states"=>states, "actions"=>actions, "rewards"=>rewards,
                 "next_states"=>next_states, "action_mask"=>action_index)
    
    total_rewards = [total_rewards episode_reward]
    step_counter >= max_step && break # stop criterion: max steps
    
    # update steps
    for i = 1:3
        update_learner(learner, batch)
    end
end

In [63]:
# plot the episodes
p = lineplot(total_rewards[2:end], title="Total reward per episode", xlabel="Episode", ylabel="Score")

             ⠀⠀⠀⠀⠀⠀⠀⠀⠀[97;1mTotal reward per episode[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀ 
             [38;5;8m┌────────────────────────────────────────┐[0m 
         [38;5;8m200[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⢸[0m⠀⠀[38;5;2m⣷[0m[38;5;2m⡇[0m[38;5;2m⣇[0m[38;5;2m⢸[0m[38;5;2m⢸[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⡏[0m[38;5;2m⠁[0m[38;5;8m│[0m [38;5;8m[0m
            [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⢸[0m[38;5;2m⢠[0m[38;5;2m⢸[0m[38;5;2m⣿[0m[38;5;2m⣇[0m[38;5;2m⣿[0m[38;5;2m⢸[0m[38;5;2m⢸[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⡇[0m⠀[38;5;8m│[0m [38;5;8m[0m
            [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⢸[0m[38;5;2m⢸[0m[38;5;2m⢸[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⢸[0m[38;5;2m⢸[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⡇[0m⠀[38;5;8m│[0m [38;5;8m[0m
  

### Adding experience replay buffer

We have already implemented a basic DQN agent using on-policy data (in the first update repetition) with epsilon greedy exploration, the policy can already pick up effectively. However, due to the violation of some standard machine learning assumptions (samples are not i.i.d. and drawn from a stationary distribution), the learned policy can be unstable ('catastrophic forgetting'). Now we add an huge experience replay buffer to make our data distribution more 'stationary'.

In [20]:
learner_buffer = Chain(
    Dense(ns, 128, relu; init = glorot_uniform(rng)),
    Dense(128, 128, relu; init = glorot_uniform(rng)),
    Dense(128, na; init = glorot_uniform(rng)),
);
optimizer = ADAM();

In [21]:
function load_buffer(buffer, batch, buffer_size)
    buffer["states"] = [buffer["states"] batch["states"]]
    buffer["actions"] = [buffer["actions"] batch["actions"]]
    buffer["next_states"] = [buffer["next_states"] batch["next_states"]]
    buffer["rewards"] = [buffer["rewards"] batch["rewards"]]
    
    # trim the old data if the buffer is full
    if size(buffer["states"],2) >= buffer_size
        offset = size(buffer["states"],2) - buffer_size
        offset = convert(Int32, offset) + 1 # avoid 0-index
        buffer["states"] = buffer["states"][:,offset:end]
        buffer["actions"] = buffer["actions"][:,offset:end]
        buffer["next_states"] = buffer["next_states"][:,offset:end]
        buffer["rewards"] = buffer["rewards"][:,offset:end]
    end
    return buffer
end

function sample_buffer(buffer, batch_size)
    idx = rand(rng, 1:size(buffer["states"],2), batch_size)
    batch = Dict("states"=>Array{Float64}(undef, ns, 1), "actions"=>Array{Int32}(undef, 1, 1), 
              "rewards"=>Array{Float64}(undef, 1, 1), "next_states"=>Array{Float64}(undef, ns, 1))    
    batch["states"] = buffer["states"][:,idx]
    batch["actions"] = buffer["actions"][:,idx]
    batch["next_states"] = buffer["next_states"][:,idx]
    batch["rewards"] = buffer["rewards"][:,idx]
    
    # to include the newest trajectories
    batch["states"] = [batch["states"] buffer["states"][:,end-200:end]]
    batch["actions"] = [batch["actions"] buffer["actions"][:,end-200:end]]
    batch["next_states"] = [batch["next_states"] buffer["next_states"][:,end-200:end]]
    batch["rewards"] = [batch["rewards"] buffer["rewards"][:,end-200:end]]
    
    return batch
end

function learn_from_batch(learner_buffer, buffer, batch_size)
    batch = sample_buffer(buffer, batch_size)
    # update steps
    for i = 1:3
        update_learner(learner_buffer, batch)
    end
end

learn_from_batch (generic function with 1 method)

# TODO: debugging here

In [24]:
# interact with env to collect data and do the update steps
policy = EpsilonGreedyPolicy
#stop_criterion = StopAfterEpisode(15000)

ss = nothing
aa = nothing
rr = nothing
nst = nothing
bb = nothing
offset = nothing
total_rewards = Array{Float64}(undef, 1, 1)
step_counter = 0
max_step = 2e5
batch_size = 2048
buffer_size = 1e5
update_interval = 200
buffer = Dict("states"=>Array{Float64}(undef, ns, 1), "actions"=>Array{Int32}(undef, 1, 1), 
              "rewards"=>Array{Float64}(undef, 1, 1), "next_states"=>Array{Float64}(undef, ns, 1))
while true
    reset!(env)
    episode_reward = 0
    states = Array(state(env))
    actions = Array{Int32}(undef, 1, 1)
    rewards = Array{Float64}(undef, 1, 1)
    
    
    while !is_terminated(env)
        #env |> policy |> env
        action = policy(state(env), learner_buffer, step_counter, max_step)[1]
        step_counter = step_counter +1
        env(action)
        
        states = [states state(env)]
        actions = [actions action]
        rewards = [rewards reward(env)]
        episode_reward += reward(env)
        #stop_criterion(policy, env) && return # stop criterion: max episodes
        
        # every `update_interval` steps do the update step
        if mod(step_counter, update_interval) == 0 && size(buffer["states"], 2) > batch_size
            learn_from_batch(learner_buffer, buffer, batch_size)
        end        
    end
    # end of an episode
    # processing the data
    next_states = states[:,2:end]
    states = states[:,1:end-1]
    rewards = rewards[:,2:end]
    actions = actions[:,2:end]
    episode = Dict("states"=>states, "actions"=>actions, "rewards"=>rewards,
                 "next_states"=>next_states)    
    buffer = load_buffer(buffer, episode, buffer_size)
    
    total_rewards = [total_rewards episode_reward]
    step_counter >= max_step && break # stop criterion: max steps

    ss = states
    aa = actions
    rr = rewards
    nst = next_states
    bb = episode
end

In [25]:
# plot the episodes
p = lineplot(total_rewards[2:end], title="Total reward per episode", xlabel="Episode", ylabel="Score")

            ⠀⠀⠀⠀⠀⠀⠀⠀⠀[97;1mTotal reward per episode[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀ 
            [38;5;8m┌────────────────────────────────────────┐[0m 
         [38;5;8m90[0m [38;5;8m│[0m[38;5;2m⢠[0m[38;5;2m⡄[0m[38;5;2m⣷[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m[38;5;2m⢸[0m[38;5;2m⣿[0m[38;5;2m⣿[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m[38;5;2m⢸[0m[38;5;2m⣿[0m[38;5;2m⣿[0m⠀⠀⠀⠀⠀[38;5;2m⢠[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m[38;5;2m⣾[0m[38;5;2m⣿[0m[38;5;2m⣿[0m⠀⠀⠀⠀⠀[38;5;2m⢸[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⢰[0m⠀⠀⠀⠀[38;5;2m⢸[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m[38;5;2m⣿[0m[38;5;2m⣿[0m[38;5;2m⣿