## A customized DQN implementation

A small-scaled DQN implementation for fast prototyping and playing with

In [1]:
import Pkg;
# uncomment the following if you have not installed them
# Pkg.add("ReinforcementLearning");
# Pkg.add("Flux");
# Pkg.add("StableRNGs");
# Pkg.add("Distributions");
using Flux: InvDecay;
using ReinforcementLearning;
using StableRNGs;
using Flux;
using Flux.Losses;
using Distributions;

In [2]:
# set random seed and env
seed = 245;
rng = StableRNG(seed);
env = CartPoleEnv(; T = Float32);
ns, na = length(state_space(env)), length(action_space(env));

In [6]:
# model, policy, loss, update step
learner = Chain(
    Dense(ns, 128, relu; init = glorot_uniform(rng)),
    Dense(128, 128, relu; init = glorot_uniform(rng)),
    Dense(128, na; init = glorot_uniform(rng)),
) |> gpu;
optimizer = ADAM();

function GreedyPolicy(states, learner)
    logits = learner(states)
    actions = mapslices(argmax, logits, dims=1)
    return actions
end

#TODO: gradually decay the epsilon
function EpsilonGreedyPolicy(states, learner)
    ϵ = 0.01
    random_number = rand(Uniform(0,1))
    if random_number > ϵ
        action = GreedyPolicy(states, learner)
    else
        action = rand(rng, 1:2)
    end
    return action
end

function value_loss(batch)
    # TODO: improve the inefficient loss calculation
    num_sample = length(batch["actions"])
    loss = 0
    γ = 0.95
    q_values = learner(batch["states"])
    next_values = findmax(learner(batch["next_states"]); dims=1)[1]
    for i in 1:length(batch["actions"])
        target = Flux.Zygote.ignore() do
            batch["rewards"] + γ*next_values
        end
        loss = loss + mse(q_values[batch["actions"][i],i], target)
    end
    return loss/num_sample
end

function update_learner(learner, batch)
    grad = Flux.gradient(Flux.params(learner)) do
        value_loss(batch)
    end
    Flux.update!(optimizer, Flux.params(learner), grad)
end

update_learner (generic function with 1 method)

In [5]:
# interact with env to collect data and do the update steps
policy = EpsilonGreedyPolicy
stop_criterion = StopAfterEpisode(15000)

# ss = nothing
# aa = nothing
# rr = nothing
# nst = nothing
# bb = nothing
# aa_idx = nothing

step_counter = 0
while true
    reset!(env)
    episode_reward = 0
    #policy(PRE_EPISODE_STAGE)
    states = Array(state(env))
    actions = Array{Int32}(undef, 1, 1)
    rewards = Array{Float64}(undef, 1, 1)
    
    while !is_terminated(env)
        #env |> policy |> env
        action = policy(state(env), learner)[1]
        step_counter = step_counter +1
        
        #policy(PRE_ACT_STAGE, env, action)
        env(action)
        #println("action:", action)
        
        states = [states state(env)]
        actions = [actions action]
        rewards = [rewards reward(env)]
        episode_reward += reward(env)

        #policy(POST_ACT_STAGE, env)
        stop_criterion(policy, env) && return
    end
    # end of an episode
    # processing the data
    next_states = states[:,2:end]
    states = states[:,1:end-1]
    rewards = rewards[:,2:end]
    actions = actions[:,2:end]
    action_index = [(0,0)]
    for i = 1:length(actions)
        action_index = [action_index (actions[i],i)]
    end
    action_index = action_index[:,2:end]
    batch = Dict("states"=>states, "actions"=>actions, "rewards"=>rewards,
                 "next_states"=>next_states, "action_mask"=>action_index)
    
    println(episode_reward)
#     ss = states
#     aa = actions
#     rr = rewards
#     nst = next_states
#     aa_idx = action_index
#     bb = batch
    
    # update steps
    for i = 1:3
        update_learner(learner, batch)
    end
    #policy(POST_EPISODE_STAGE)
end

[32mProgress:  14%|█████▉                                   |  ETA: 0:01:14[39mm

9.0
9.0
8.0
9.0
9.0
9.0
8.0
9.0
8.0
8.0
9.0
8.0
9.0
8.0
7.0
9.0
9.0
8.0
9.0
9.0
9.0
9.0
9.0
7.0
8.0
8.0
9.0
9.0
9.0
9.0
9.0
8.0
8.0
8.0
8.0
8.0
9.0
9.0
9.0
9.0
7.0
10.0
9.0
8.0
8.0
9.0
8.0
9.0
8.0
9.0
10.0
9.0
9.0
8.0
9.0
9.0
9.0
8.0
9.0
9.0
8.0
8.0
9.0
9.0
10.0
9.0
8.0
7.0
9.0
8.0
9.0
9.0
8.0
8.0
9.0
8.0
8.0
8.0
9.0
7.0
7.0
9.0
8.0
9.0
8.0
9.0
8.0
9.0
7.0
7.0
7.0
9.0
9.0
8.0
9.0
8.0
9.0
8.0
9.0
8.0
9.0
9.0
9.0
8.0
9.0
9.0
9.0
7.0
7.0
7.0
8.0
10.0
9.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
9.0
8.0
7.0
8.0
9.0
9.0
7.0
7.0
8.0
9.0
9.0
7.0
9.0
10.0
8.0
7.0
8.0
10.0
8.0
8.0
7.0
8.0
9.0
7.0
8.0
9.0
8.0
9.0
9.0
9.0
8.0
9.0
9.0
9.0
7.0
8.0
8.0
9.0
8.0
8.0
9.0
9.0
9.0
7.0
7.0
9.0
9.0
8.0
9.0
9.0
8.0
8.0
8.0
8.0
9.0
8.0
8.0
9.0
9.0
10.0
8.0
10.0
10.0
9.0
9.0
9.0
9.0
8.0
8.0
8.0
7.0
8.0
9.0
7.0
9.0
8.0
8.0
9.0
9.0
8.0
8.0
9.0
8.0
8.0
9.0
7.0
7.0
8.0
8.0
8.0
7.0
7.0
9.0
9.0
9.0
9.0
9.0
8.0
8.0
9.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
10.0
9.0
11.0
11.0
10.0
9.0
10.0
8.0
8.0
9.0
7.0
8.0
9.0
9.0
9.0
9.0
8.0
7.0
8.0
8.

[32mProgress:  32%|█████████████                            |  ETA: 0:00:39[39m

30.0
42.0
53.0
47.0
35.0
36.0
43.0
31.0
31.0
39.0
33.0
37.0
31.0
35.0
36.0
27.0
26.0
29.0
40.0
31.0
30.0
33.0
30.0
40.0
43.0
47.0
33.0
40.0
47.0
33.0
28.0
70.0
42.0
37.0
38.0
46.0
27.0
40.0
43.0
36.0
29.0
32.0
47.0
44.0
59.0
33.0
37.0
55.0
30.0
31.0
41.0
38.0
35.0
45.0
43.0
31.0
31.0
32.0
39.0
34.0
20.0
40.0
68.0
36.0
40.0
40.0
49.0
50.0
55.0
33.0
30.0
41.0
43.0
37.0
54.0
55.0
29.0
40.0
33.0
52.0
54.0
54.0
38.0
50.0
30.0
27.0
49.0
58.0
48.0
43.0
46.0
65.0
128.0
47.0
45.0
39.0
28.0
19.0
14.0
11.0
10.0
11.0
12.0
22.0
13.0
12.0
11.0
27.0
38.0
13.0
11.0
12.0
16.0
29.0
53.0
40.0
24.0
32.0
41.0
29.0
23.0
27.0
33.0
42.0
22.0
47.0
41.0
30.0
37.0
28.0
26.0
25.0
26.0
29.0
31.0
40.0
29.0
32.0
25.0
24.0
40.0
40.0
35.0
29.0
32.0
35.0
35.0
37.0
30.0
23.0
23.0
26.0
32.0
21.0
28.0
51.0
36.0
33.0
36.0
40.0
22.0
51.0
29.0
29.0
32.0
46.0
50.0
32.0
48.0
24.0
29.0
67.0
29.0
27.0
43.0
29.0
22.0
69.0
46.0
34.0
29.0
45.0
59.0
52.0
41.0
33.0
28.0
29.0
47.0
39.0
29.0
50.0
78.0
39.0
53.0
34.0
32.0
29.0
30.0
46.0

[32mProgress:  46%|██████████████████▋                      |  ETA: 0:00:27[39m

9.0
8.0
7.0
9.0
7.0
9.0
10.0
8.0
8.0
10.0
9.0
8.0
8.0
8.0
9.0
9.0
8.0
9.0
8.0
7.0
7.0
11.0
9.0
8.0
9.0
11.0
9.0
9.0
9.0
10.0
8.0
8.0
8.0
8.0
8.0
9.0
8.0
7.0
7.0
8.0
8.0
9.0
9.0
8.0
8.0
8.0
9.0
7.0
9.0
8.0
8.0
8.0
10.0
7.0
8.0
9.0
9.0
9.0
7.0
9.0
10.0
9.0
9.0
8.0
9.0
11.0
9.0
8.0
8.0
9.0
13.0
9.0
10.0
11.0
9.0
8.0
8.0
8.0
9.0
9.0
9.0
7.0
9.0
9.0
8.0
7.0
9.0
8.0
8.0
9.0
9.0
9.0
8.0
8.0
7.0
8.0
8.0
9.0
7.0
9.0
9.0
9.0
14.0
19.0
14.0
69.0
9.0
9.0
9.0
7.0
9.0
74.0
9.0
9.0
9.0
9.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
7.0
9.0
9.0
9.0
8.0
9.0
7.0
8.0
14.0
9.0
8.0
8.0
8.0
8.0
9.0
8.0
9.0
9.0
8.0
7.0
9.0
7.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
9.0
8.0
9.0
9.0
8.0
7.0
8.0
9.0
8.0
10.0
50.0
15.0
10.0
9.0
8.0
9.0
8.0
9.0
11.0
9.0
9.0
8.0
9.0
8.0
9.0
10.0
9.0
9.0
9.0
9.0
9.0
9.0
8.0
9.0
9.0
7.0
9.0
10.0
9.0
8.0
8.0
10.0
8.0
7.0
8.0
9.0
7.0
8.0
9.0
9.0
8.0
11.0
8.0
8.0
8.0
9.0
7.0
9.0
8.0
7.0
7.0
7.0
8.0
8.0
8.0
9.0
8.0
9.0
8.0
8.0
7.0
9.0
7.0
9.0
8.0
9.0
7.0
8.0
8.0
8.0
8.0
9.0
9.0
7.0
9.0
9.0
9.0
9.0
8.0
9.0
8.0
8.

[32mProgress:  64%|██████████████████████████▍              |  ETA: 0:00:15[39m

9.0
8.0
9.0
9.0
8.0
7.0
8.0
8.0
8.0
7.0
12.0
13.0
14.0
9.0
8.0
9.0
10.0
7.0
9.0
9.0
7.0
31.0
8.0
9.0
9.0
9.0
8.0
8.0
9.0
8.0
9.0
9.0
9.0
8.0
7.0
9.0
8.0
8.0
9.0
8.0
8.0
9.0
9.0
8.0
11.0
11.0
9.0
9.0
8.0
8.0
9.0
9.0
10.0
57.0
8.0
8.0
7.0
9.0
9.0
8.0
9.0
9.0
17.0
41.0
26.0
16.0
7.0
8.0
8.0
7.0
7.0
9.0
9.0
10.0
9.0
9.0
8.0
9.0
9.0
7.0
9.0
9.0
8.0
8.0
8.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
8.0
7.0
9.0
9.0
8.0
7.0
8.0
9.0
9.0
8.0
8.0
8.0
7.0
8.0
8.0
7.0
8.0
8.0
9.0
9.0
9.0
8.0
9.0
7.0
7.0
9.0
9.0
9.0
11.0
8.0
9.0
7.0
8.0
9.0
48.0
9.0
8.0
9.0
9.0
8.0
9.0
9.0
7.0
9.0
7.0
8.0
9.0
8.0
7.0
8.0
9.0
7.0
8.0
9.0
8.0
9.0
9.0
7.0
8.0
8.0
9.0
9.0
8.0
9.0
8.0
9.0
8.0
9.0
9.0
8.0
7.0
8.0
8.0
8.0
9.0
9.0
9.0
11.0
8.0
9.0
8.0
9.0
8.0
9.0
9.0
8.0
9.0
9.0
8.0
8.0
8.0
9.0
9.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
8.0
9.0
9.0
8.0
8.0
8.0
9.0
10.0
8.0
9.0
7.0
9.0
10.0
8.0
8.0
8.0
8.0
10.0
9.0
9.0
8.0
9.0
8.0
8.0
9.0
8.0
8.0
8.0
8.0
9.0
8.0
7.0
8.0
8.0
7.0
8.0
9.0
7.0
11.0
8.0
9.0
7.0
7.0
8.0
8.0
9.0
7.0

[32mProgress:  79%|████████████████████████████████▎        |  ETA: 0:00:09[39m

8.0
10.0
9.0
9.0
8.0
9.0
7.0
8.0
9.0
9.0
8.0
8.0
8.0
8.0
8.0
9.0
9.0
9.0
9.0
8.0
7.0
7.0
9.0
9.0
9.0
9.0
9.0
9.0
8.0
9.0
9.0
9.0
9.0
8.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
9.0
8.0
8.0
8.0
9.0
9.0
9.0
8.0
7.0
8.0
8.0
9.0
9.0
7.0
8.0
9.0
8.0
8.0
9.0
8.0
8.0
8.0
8.0
9.0
8.0
9.0
10.0
8.0
8.0
7.0
9.0
9.0
8.0
7.0
9.0
7.0
8.0
8.0
9.0
9.0
9.0
8.0
9.0
9.0
9.0
9.0
8.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
9.0
9.0
8.0
9.0
8.0
8.0
9.0
7.0
9.0
8.0
9.0
9.0
9.0
9.0
9.0
9.0
7.0
8.0
9.0
8.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
9.0
9.0
9.0
9.0
9.0
9.0
9.0
10.0
8.0
8.0
8.0
8.0
9.0
8.0
9.0
9.0
8.0
8.0
10.0
9.0
9.0
9.0
10.0
8.0
8.0
8.0
10.0
8.0
8.0
7.0
9.0
16.0
8.0
9.0
9.0
9.0
8.0
9.0
10.0
8.0
9.0
8.0
10.0
10.0
7.0
10.0
7.0
11.0
8.0
8.0
9.0
7.0
9.0
8.0
12.0
14.0
14.0
10.0
9.0
9.0
7.0
10.0
7.0
9.0
8.0
17.0
9.0
9.0
8.0
9.0
10.0
9.0
9.0
9.0
8.0
8.0
7.0
10.0
8.0
8.0
8.0
9.0
8.0
7.0
8.0
9.0
8.0
9.0
8.0
9.0
7.0
8.0
9.0
8.0
7.0
8.0
8.0
10.0
9.0
9.0
9.0
9.0
9.0
8.0
7.0
7.0
9.0
8.0
9.0
8.0
9.0
9.0
8.0
9.0
8.0
8.0
9.0
7.0
8.0
10.0
9.0
8.0
9.

[32mProgress:  97%|███████████████████████████████████████▋ |  ETA: 0:00:01[39m

9.0
9.0
8.0
8.0
8.0
22.0
7.0
9.0
7.0
7.0
7.0
8.0
9.0
10.0
8.0
8.0
9.0
8.0
8.0
9.0
9.0
7.0
9.0
8.0
8.0
8.0
8.0
10.0
8.0
8.0
7.0
9.0
17.0
8.0
9.0
8.0
8.0
10.0
7.0
9.0
7.0
7.0
8.0
9.0
9.0
8.0
7.0
7.0
9.0
9.0
8.0
7.0
9.0
8.0
9.0
9.0
10.0
10.0
8.0
8.0
9.0
9.0
7.0
8.0
8.0
9.0
8.0
7.0
9.0
9.0
9.0
9.0
9.0
9.0
9.0
8.0
10.0
7.0
8.0
8.0
10.0
13.0
19.0
13.0
9.0
9.0
9.0
7.0
9.0
8.0
10.0
9.0
8.0
9.0
8.0
7.0
8.0
8.0
8.0
8.0
7.0
9.0
8.0
8.0
8.0
9.0
8.0
8.0
7.0
9.0
8.0
9.0
7.0
10.0
8.0
8.0
9.0
8.0
9.0
10.0
9.0
7.0
9.0
9.0
7.0
9.0
8.0
9.0
10.0
8.0
9.0
8.0
8.0
11.0
8.0
8.0
7.0
8.0
7.0
9.0
9.0
9.0
7.0
9.0
8.0
8.0
8.0
8.0
9.0
7.0
9.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
8.0
9.0
8.0
8.0
8.0
9.0
8.0
7.0
9.0
8.0
8.0
8.0
8.0
9.0
8.0
8.0
9.0
9.0
8.0
9.0
8.0
8.0
8.0
9.0
10.0
8.0
8.0
10.0
9.0
8.0
8.0
10.0
12.0
9.0
8.0
9.0
9.0
9.0
8.0
8.0
9.0
8.0
9.0
7.0
9.0
8.0
7.0
10.0
21.0
9.0
9.0
8.0
8.0
9.0
8.0
9.0
8.0
9.0
7.0
7.0
8.0
16.0
28.0
13.0
7.0
8.0
9.0
8.0
7.0
9.0
16.0
9.0
7.0
7.0
8.0
8.0
9.0
9.0
8.0
9.0
9.0
8.0
8.0
9.0
8.0
9

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:44[39m


160.0
89.0
73.0
145.0
170.0
138.0
75.0
134.0
101.0
134.0
120.0
102.0
92.0
129.0
200.0
78.0
135.0
125.0
166.0
137.0
95.0
128.0
153.0
98.0
87.0
84.0
118.0
110.0
90.0
114.0
188.0
127.0
151.0
140.0
104.0
119.0
146.0
113.0
125.0
76.0
138.0
162.0
81.0
126.0
109.0
94.0
161.0
87.0
151.0
132.0
127.0
107.0
147.0
141.0
128.0
131.0
136.0
155.0
138.0
133.0
109.0
103.0
137.0
149.0
126.0
76.0
116.0
85.0
119.0
79.0
118.0
133.0
95.0
113.0
101.0
157.0
134.0
70.0
119.0
105.0
100.0
113.0
144.0
97.0
87.0
90.0
83.0
86.0
118.0
87.0
98.0
77.0
140.0
99.0
90.0
135.0
123.0
90.0
85.0
84.0
113.0
96.0
93.0
84.0
139.0
115.0
157.0
128.0
138.0
119.0
131.0
100.0
147.0
88.0
104.0
112.0
82.0
81.0
86.0
146.0
111.0
108.0
76.0
107.0
83.0
159.0
84.0
86.0
89.0
79.0
102.0
140.0
122.0
79.0
113.0
106.0
82.0
147.0
140.0
118.0
143.0
139.0
123.0
110.0
94.0
99.0
185.0
139.0
160.0
80.0
143.0
146.0
100.0
93.0
107.0
141.0
122.0
137.0
126.0
83.0
129.0
130.0
121.0
119.0
139.0
127.0
157.0
124.0
170.0
97.0
146.0
107.0
148.0
91.0
142.0
120.