## A customized DQN implementation

A small-scale DQN implementation (trained on CartPole) for fast prototyping and experimentation.

In [13]:
import Pkg;
# uncomment the following if you have not installed them
# Pkg.add("ReinforcementLearning");
# Pkg.add("Flux");
# Pkg.add("StableRNGs");
# Pkg.add("Distributions");
# Pkg.add("UnicodePlots")
using Flux: InvDecay;
using ReinforcementLearning;
using StableRNGs;
using Flux;
using Flux.Losses;
using Distributions;
using UnicodePlots:lineplot, lineplot!
using Statistics

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m    Updating[22m[39m `~/.julia/environments/v1.7/Project.toml`
 [90m [b8865327] [39m[92m+ UnicodePlots v2.10.3[39m
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Manifest.toml`


In [2]:
# set random seed and env
seed = 245;
rng = StableRNG(seed);
# Hand the seeded RNG to the environment too. CartPoleEnv otherwise falls back to
# its default RNG for `reset!`, so episodes would differ between runs despite `seed`.
env = CartPoleEnv(; T = Float32, rng = rng);
# state dimension and number of discrete actions (network input / output sizes)
ns, na = length(state_space(env)), length(action_space(env));

In [46]:
# model, policy, loss, update step
# Q-network: ns state features -> two ReLU hidden layers -> one Q-value per action.
# A single `glorot_uniform(rng)` initializer is shared by all layers; it draws from
# the seeded `rng` in layer order.
learner = let width = 128, init = glorot_uniform(rng)
    Chain(
        Dense(ns, width, relu; init = init),
        Dense(width, width, relu; init = init),
        Dense(width, na; init = init),
    ) |> gpu
end;
optimizer = ADAM();

"""
    GreedyPolicy(states, learner)

Return, for each column of `states`, the index of the largest entry of
`learner(states)` in that column (ties resolve to the first maximum).

A single state vector yields a 1-element vector; an `ns×N` batch yields a
`1×N` matrix of action indices.
"""
function GreedyPolicy(states, learner)
    q = learner(states)
    # CartesianIndex of the column-wise maximum; its first coordinate is the row,
    # i.e. the action index.
    best = argmax(q; dims = 1)
    return map(ci -> ci[1], best)
end

"""
    EpsilonGreedyPolicy(states, learner, t_current, t_max;
                        ϵ_min = 0.005, rand_gen = rng, n_actions = na)

Choose actions ϵ-greedily.

With ϵ = `max(1 - t_current / t_max, ϵ_min)`, the exploration rate decays linearly
from 1 at `t_current = 0` and is clamped below by `ϵ_min`. With probability `1 - ϵ`,
return `GreedyPolicy(states, learner)`. Otherwise, return a uniformly random action
from `1:n_actions` for each column of `states`.

The random branch has the same shape as `GreedyPolicy`'s output: a 1-element vector
for a single state vector, or a `1×N` matrix for an `ns×N` batch.

# Keywords
- `ϵ_min`: lower bound on the exploration rate.
- `rand_gen`: RNG used for all random draws (defaults to the notebook's seeded `rng`).
- `n_actions`: number of discrete actions (defaults to the notebook's `na`).
"""
function EpsilonGreedyPolicy(states, learner, t_current, t_max;
                             ϵ_min = 0.005, rand_gen = rng, n_actions = na)
    ϵ = max(1 - t_current / t_max, ϵ_min)
    # Draw from the seeded RNG (previously the unseeded global RNG) so runs are reproducible.
    if rand(rand_gen) > ϵ
        return GreedyPolicy(states, learner)
    end
    # One random action per column, laid out like GreedyPolicy's result.
    random_actions = rand(rand_gen, 1:n_actions, 1, size(states, 2))
    return ndims(states) == 1 ? vec(random_actions) : random_actions
end

"""
    value_loss(batch; model = learner, γ = 0.96)

Return the mean squared TD error of `model` over the transitions in `batch`.

For transition `i` (column `i`), the prediction is `model(states)[actions[i], i]`.
The target is `rewards[i] + γ * maximum(model(next_states)[:, i])`. The bootstrap
term is dropped when the optional `batch["terminals"][i]` is `true`; without that
key, no transition is treated as terminal.

Targets are built inside `Flux.Zygote.ignore`, so gradients flow only through
`model(batch["states"])`.

`batch` is a `Dict` with keys:
- `"states"`, `"next_states"`: one column per transition.
- `"actions"`: action index per transition (any shape; read in linear order).
- `"rewards"`: reward per transition (any shape; read in linear order).
- `"terminals"` (optional): `Bool` flag per transition.
"""
function value_loss(batch; model = learner, γ = 0.96)
    actions = vec(batch["actions"])
    q_values = model(batch["states"])
    # Gather Q(s_i, a_i) for all columns in one indexing operation instead of a
    # scalar loop.
    q_taken = q_values[CartesianIndex.(actions, eachindex(actions))]
    target = Flux.Zygote.ignore() do
        # The next-state forward pass only feeds the (constant) target, so it is
        # kept out of the differentiated graph.
        next_max = vec(maximum(model(batch["next_states"]); dims = 1))
        bootstrap = γ .* next_max
        if haskey(batch, "terminals")
            # No bootstrapping past the end of an episode.
            bootstrap = bootstrap .* .!vec(batch["terminals"])
        end
        # Element-wise target per transition (the old loop compared every Q-value
        # against the whole target row).
        vec(batch["rewards"]) .+ bootstrap
    end
    return mse(q_taken, target)
end

"""
    update_learner(learner, batch)

Take one optimisation step: differentiate `value_loss(batch)` with respect to the
parameters of `learner`, then apply the global `optimizer` to them in place.
Returns whatever `Flux.update!` returns.

NOTE(review): `value_loss(batch)` evaluates the global `learner`. Gradients are
therefore only non-zero when the `learner` passed here is that same object.
"""
function update_learner(learner, batch)
    ps = Flux.params(learner)
    gs = Flux.gradient(() -> value_loss(batch), ps)
    return Flux.update!(optimizer, ps, gs)
end

update_learner (generic function with 1 method)

In [None]:
# interact with env to collect data and do the update steps
policy = EpsilonGreedyPolicy
# NOTE(review): `stop_criterion` is not consulted below; training stops after `max_step` steps.
stop_criterion = StopAfterEpisode(15000)

"""
    collect_episode(env, policy, learner, step_counter, max_step)

Reset `env` and run it until `is_terminated(env)`. Each action is
`policy(state(env), learner, step_counter, max_step)[1]`.

Returns `(batch, episode_reward, step_counter)`, where `step_counter` has been
advanced by the number of steps taken.

`batch` stores one column per transition:
- `"states"` and `"next_states"`: `ns×T`.
- `"actions"`, `"rewards"`, `"terminals"`, `"action_mask"`: each `1×T`.
- `"terminals"` is `true` only for the episode's last transition.
- `"action_mask"` holds `(action, column)` tuples.
"""
function collect_episode(env, policy, learner, step_counter, max_step)
    reset!(env)
    # `Array(x)` allocates a copy, so stored states never alias the env's own buffer.
    visited = [Array(state(env))]
    actions = Int[]
    rewards = Float64[]
    episode_reward = 0.0
    while !is_terminated(env)
        action = policy(state(env), learner, step_counter, max_step)[1]
        step_counter += 1
        env(action)
        r = reward(env)
        # push! is amortized O(1). The previous `[states state(env)]` re-copied the
        # whole history on every step (O(T^2) per episode).
        push!(visited, Array(state(env)))
        push!(actions, action)
        push!(rewards, r)
        episode_reward += r
    end

    n = length(actions)
    all_states = reduce(hcat, visited)
    batch = Dict(
        "states" => all_states[:, 1:n],
        "next_states" => all_states[:, 2:(n + 1)],
        "actions" => reshape(actions, 1, n),
        "rewards" => reshape(rewards, 1, n),
        # The last transition of the episode is flagged terminal. If `is_terminated`
        # also fires on a time limit, that truncation is flagged terminal as well.
        "terminals" => reshape([i == n for i in 1:n], 1, n),
        "action_mask" => reshape([(actions[i], i) for i in 1:n], 1, n),
    )
    return batch, episode_reward, step_counter
end

# total_rewards[1] is a NaN placeholder (instead of uninitialized memory), so
# `total_rewards[2:end]` holds exactly one score per finished episode.
total_rewards = [NaN]
step_counter = 0
max_step = 5e5

while true
    # Explicit so the counter updates the global in non-interactive scripts too.
    global step_counter
    batch, episode_reward, step_counter =
        collect_episode(env, policy, learner, step_counter, max_step)
    push!(total_rewards, episode_reward)
    step_counter >= max_step && break # stop criterion: max steps

    # An empty episode would give a NaN loss; skip the update instead.
    isempty(batch["actions"]) && continue
    # update steps
    for _ in 1:3
        update_learner(learner, batch)
    end
end

In [None]:
# plot the episodes: one point per finished episode, skipping the leading
# placeholder entry of `total_rewards`
p = lineplot(
    total_rewards[2:end];
    title = "Total reward per episode",
    xlabel = "Episode",
    ylabel = "Score",
)