# Reinforcement Learning (DQN) on Raw Pixel Data

**Author:** Kyle Daruwalla

This is an attempt to recreate the PyTorch DQN tutorial in Julia using Flux. We will be working with the `CartPole-v0` environment, where the learning agent tries to balance a pole on a cart by pushing the cart left or right.

In [None]:
# resolve imports
using Plots, Images, ProgressMeter, IJulia
using Flux, CuArrays, Zygote
using Gym
using DataStructures: CircularBuffer
using Random: shuffle
using Distributions: sample

## Model Creation

First, we define our DQN model using the standard layers found [here](https://fluxml.ai/Flux.jl/stable/models/layers/). The model takes the difference between current and previous screen patches, then outputs the Q-values for each action: $Q(s, \text{left})$ and $Q(s, \text{right})$.

In [None]:
Flux.outdims(::BatchNorm, isize) = isize

function createdqn(height, width)
    ksize = (3, 3)
    ssize = 2
    lsizes = [1, 16, 32, 32] # channel sizes passing through conv layers
    
    # conv layers
    convs = [
        # 3x3 kernel applied to 1 channel input w/ 16 channel output
        Conv(ksize, lsizes[1] => lsizes[2], stride = ssize),
        BatchNorm(lsizes[2], relu),
        Conv(ksize, lsizes[2] => lsizes[3], stride = ssize),
        BatchNorm(lsizes[3], relu),
        Conv(ksize, lsizes[3] => lsizes[4], stride = ssize),
        BatchNorm(lsizes[4], relu)
    ]
    
    ffwidth, ffheight = foldl((i, m) -> Flux.outdims(m, i), convs; init = (width, height))
    ffsize = Int(ffwidth * ffheight * lsizes[4])
    Chain(
        convs...,
        x -> reshape(x, :, size(x, 4)),
        Dense(ffsize, 2)
    ) |> gpu
end
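
To make the size bookkeeping concrete, the sketch below reproduces by hand what the `foldl` over `Flux.outdims` computes, assuming unpadded convolutions (the `Conv` default). The helper names `convout` and `flatsize` and the 60-pixel height are illustrative assumptions, not part of the original code.

```julia
# Output length of one unpadded convolution along a single dimension.
convout(n; k = 3, s = 2) = fld(n - k, s) + 1

# Apply the same convolution `nlayers` times, then flatten with `channels` feature maps.
function flatsize(width, height; nlayers = 3, channels = 32)
    for _ in 1:nlayers
        width, height = convout(width), convout(height)
    end
    return width * height * channels
end

convout(40)        # 19
flatsize(40, 60)   # 4 * 6 * 32 = 768
```

With three stride-2 layers, a 40-pixel width shrinks to 19, then 9, then 4, which is why the final `Dense` layer needs its input size computed rather than hard-coded.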

## Input Processing

Gym.jl returns a $c \times h \times w$ array. We need to transform this into a $w \times h \times c$ array, which is how Flux interprets image arrays. We'll also crop the image, since most of the pixels are just white background. Lastly, we resize the image, because the original size of $400 \times 600$ would consume too much memory for training.

In [None]:
include("showenv.jl")
reqwidth = 40
        
# Instantiate the env
env = make("CartPole-v0", :rgb, true)
reset!(env)
screen, _ = showenv(env, reqwidth)
width, height, nchannels, _ = size(screen)
println("width = $(width), height = $(height), channels = $(nchannels)")
colorview(Gray, permutedims(cpu(screen[:, :, 1, 1]), [2, 1]))
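
The axis reordering described above can be sketched with `permutedims`. The array here is random stand-in data with hypothetical small dimensions. The actual cropping and resizing live in `showenv.jl`, which is not shown, so this illustrates only the layout change.

```julia
# Stand-in for a frame in c × h × w layout (3 channels, 4 rows, 6 columns).
frame = rand(Float32, 3, 4, 6)

# Reverse the axes to get w × h × c.
whc = permutedims(frame, (3, 2, 1))

# Append a trailing batch dimension, giving w × h × c × 1 as Flux's Conv layers expect.
batched = reshape(whc, size(whc)..., 1)

size(whc)      # (6, 4, 3)
size(batched)  # (6, 4, 3, 1)
```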

## Memory and Replay

Typically, as an episode progresses, all the state transitions are stored in a trace. The NN implementing the policy is trained in batches over this trace. Here, we define a convenient way of storing each transition, which consists of
1. **state**: the current state
2. **action**: the action taken
3. **reward**: the reward received
4. **nextstate**: the state transitioned to
5. **done**: whether the episode ended with this transition

In [None]:
struct Transition
    state :: CuArray{Float32}
    action :: Int64
    reward :: Float64
    nextstate :: CuArray{Float32}
    done :: Bool
end

We also define the concept of a trace as a "memory". This is simply a circular buffer for storing transitions. You can `push!` onto a memory as well as randomly `sample` from it.

In [None]:
tracelength = 10000
memory = CircularBuffer{Transition}(tracelength)
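
As a small illustration of how such a memory behaves, the sketch below fills a buffer of capacity 3 with integers instead of transitions. To keep the dependencies minimal, it swaps in `shuffle` from `Random` for the random draw rather than the `sample` function used in this notebook.

```julia
using DataStructures: CircularBuffer
using Random: shuffle

buf = CircularBuffer{Int}(3)   # capacity of 3 items
for x in 1:5
    push!(buf, x)              # once full, each push overwrites the oldest item
end

collect(buf)                   # [3, 4, 5]
length(buf)                    # 3

# Draw two distinct items uniformly at random (sampling without replacement).
batch = shuffle(collect(buf))[1:2]
```

Because old transitions are overwritten automatically, the memory always holds the most recent `tracelength` transitions without any manual cleanup.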

## Training

To train the model, we define two DQNs. The policy net computes $Q(s, a)$, and we choose actions based on it. For stability, we also keep a target net, which is used to compute $V(s_{t + 1}) = \max_a Q(s_{t + 1}, a)$. Only the policy net is trained, once per transition. The target net is never trained directly; it copies the policy net's weights at the end of an episode, on the schedule set in the training loop.

In [None]:
function copyweights!(source, sink)
    Flux.loadparams!(sink, Flux.params(source))
end

policynet = createdqn(height, width)
targetnet = createdqn(height, width)
copyweights!(policynet, targetnet)

The loss function is the Huber loss of the temporal-difference error $\delta$, averaged over a batch $B$:
$$L(B) = \frac{1}{|B|} \sum_{\delta \in B} \ell(\delta)$$

$$\ell(\delta) = \begin{cases}
    \frac{1}{2} \delta^2 & |\delta| \leq 1 \\
    |\delta| - \frac{1}{2} & \text{otherwise}
\end{cases}$$

$$\delta = Q(s, a) - (r(s) + \gamma V(s_{t + 1}))$$

In [None]:
const γ = 0.999

# helper functions to get the Q and V values for transitions
function Q(transition::Transition; policy)
    q = policy(transition.state)[transition.action]
    return q
end
function V(transition::Transition; target)
    r = transition.reward
    v = transition.done ? 0f0 : maximum(target(transition.nextstate))
    r + γ * v
end

huber(δ) = sum(map(x -> abs(x) <= 1 ? 0.5 * x^2 : abs(x) - 0.5, δ)) / length(δ)
l(q, v) = huber(q .- v)
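
A worked example with made-up numbers may help connect the formulas to the code. The names `discount`, `huber1`, and `tdtarget` are illustrative and operate on plain scalars and vectors rather than on `Transition` objects.

```julia
discount = 0.999

# Huber loss for a single temporal-difference error δ.
huber1(δ) = abs(δ) <= 1 ? 0.5 * δ^2 : abs(δ) - 0.5

# TD target: reward plus discounted value of the next state (zero if the episode ended).
tdtarget(r, qnext, done) = r + discount * (done ? 0.0 : maximum(qnext))

q      = 1.5                                # Q(s, a) from the policy net (made-up value)
target = tdtarget(1.0, [0.2, 1.0], false)   # 1.0 + 0.999 * 1.0 = 1.999
δ      = q - target                         # about -0.499
huber1(δ)                                   # quadratic branch, since |δ| <= 1
huber1(-2.0)                                # linear branch: 2.0 - 0.5 = 1.5
```

The quadratic branch behaves like a squared error for small $\delta$, while the linear branch limits the influence of large errors on the gradient.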

We will optimize with RMSProp using the default rates.

In [None]:
optim = RMSProp()

## RL Policy

We use a standard $\epsilon$-greedy policy. With probability $\epsilon$, we perform a random action, and with probability $1 - \epsilon$, we perform the action dictated by the policy net. $\epsilon$ starts at 0.9 and decays exponentially toward 0.05 as the iteration count grows.

In [None]:
const ϵ_start = 0.9
const ϵ_end = 0.05
const ϵ_decay = 200
function selectaction(state, iter; policy) :: Int64
    e = rand()
    ϵ = ϵ_end + (ϵ_start - ϵ_end) * exp(-1. * iter / ϵ_decay)
    if e > ϵ
        Flux.onecold(policy(state), [1, 2])[1]
    else
        rand([1, 2])
    end
end
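
To get a feel for how quickly exploration fades, the sketch below evaluates the same decay formula at a few iteration counts. The function name `epsilon` and its keyword arguments are illustrative; the default values mirror the constants defined above.

```julia
# Exploration rate after `iter` environment steps (same formula as in `selectaction`).
epsilon(iter; start = 0.9, stop = 0.05, decay = 200) =
    stop + (start - stop) * exp(-iter / decay)

epsilon(0)      # ≈ 0.9
epsilon(200)    # ≈ 0.363
epsilon(2000)   # ≈ 0.05
```

After roughly ten multiples of `ϵ_decay` steps, the agent acts almost entirely on the policy net's predictions.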

## Model Execution

We now train our model by running through multiple episodes.

In [None]:
batchsize = 128
nepisodes = 500
iter = 0
progbar = Progress(nepisodes)
epruntime = []
loss = []
currloss = 0
target_update_rate = 20
target_update_init = target_update_rate
target_update_decay = 5
update_idx = 1 + target_update_rate
plot_update_rate = 100
for i in 1:nepisodes
    reset!(env)

    last_screen, _ = showenv(env, reqwidth)
    curr_screen, _ = showenv(env, reqwidth)
    state = curr_screen .- last_screen
    
    T = 0
    done = false
    while !done
        # take an action based on the policy net
        action = selectaction(state, iter; policy = policynet)
        _, reward, done, _ = step!(env, action)
        
        # push most recent transition onto memory trace
        last_screen = curr_screen
        curr_screen, _ = showenv(env, reqwidth)
        nextstate = curr_screen - last_screen
        transition = Transition(state, action, reward, nextstate, done)
        push!(memory, transition)
        state = nextstate
        
        # only train if the trace is at least batch size
        if length(memory) >= batchsize
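            # sample a batch worth of transitions
            batch = sample(memory, batchsize, replace = false)
            # targets come from the target net and are held fixed during the update
            vs = cat(V.(batch; target = targetnet)..., dims=2)

            # compute the Q-values inside the gradient call so that Zygote can
            # differentiate the loss with respect to the policy net weights
            weights = Flux.params(policynet)
            batchloss = () -> l(cat(Q.(batch; policy = policynet)..., dims=2), vs)
            gs = Zygote.gradient(batchloss, weights)

            # clip each gradient entry to [-1, 1] before the optimizer step
            for w in weights
                gs[w] === nothing && continue
                Flux.Optimise.update!(optim, w, clamp.(gs[w], -1, 1))
            end
            currloss = cpu(batchloss())
            push!(loss, currloss)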
        end
        
        T += 1
        iter += 1
    end
    
    push!(epruntime, T)
    
    # update target net based on policy net; once the update interval has
    # decayed to zero, copy the weights after every episode
    if target_update_rate <= 0
        copyweights!(policynet, targetnet)
    elseif i == update_idx
        copyweights!(policynet, targetnet)
        target_update_rate -= target_update_decay
        update_idx = i + target_update_rate
    end
    
    if i < 100
        display(plot(plot(loss, title="Loss", xlabel="Iteration", ylabel="Loss", label=""),
                     plot(epruntime, title="Total Reward", xlabel="Episode #", ylabel="Reward", label=""),
                layout=2, size=(1200, 400)))
        next!(progbar)
        IJulia.clear_output(true)
        IJulia.flush_all()
    elseif i % plot_update_rate == 0
        IJulia.clear_output(true)
        display(plot(plot(loss, title="Loss", xlabel="Iteration", ylabel="Loss", label=""),
                     plot(epruntime, title="Total Reward", xlabel="Episode #", ylabel="Reward", label=""),
                layout=2, size=(1200, 400)))
        next!(progbar)
        IJulia.flush_all()
    else
        next!(progbar)
        IJulia.flush_all()
    end
end
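
The decaying target-update schedule is easier to see in isolation. The sketch below replays only the bookkeeping from the loop, with no networks involved, and lists the episodes at which a weight copy would occur. It assumes the copy happens after every episode once the interval has shrunk to zero. The function name `updateepisodes` is illustrative.

```julia
# Episodes at which the target net would be refreshed under a decaying interval.
# Assumes the refresh happens every episode once the interval has shrunk to zero.
function updateepisodes(nepisodes; rate = 20, decay = 5)
    updates = Int[]
    nextupdate = 1 + rate
    for i in 1:nepisodes
        if rate <= 0
            push!(updates, i)
        elseif i == nextupdate
            push!(updates, i)
            rate -= decay
            nextupdate = i + rate
        end
    end
    return updates
end

updateepisodes(55)   # [21, 36, 46, 51, 52, 53, 54, 55]
```

The gaps between copies shrink from 20 episodes to 15, 10, and 5, after which the target net tracks the policy net every episode.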

## Results

Here, we plot the runtime of each episode, along with its average over the last 100 episodes. In principle, the runtime should increase as the agent learns to keep the pole balanced for longer.

In [None]:
# trailing average over the most recent 100 episodes (fewer at the start)
avgruntime = [(1 / min(n, 100)) * sum(epruntime[max(1, n - 99):n]) for n in eachindex(epruntime)]
plot(epruntime, title="Episode Runtime", xlabel="Episode #", ylabel="Runtime (iterations)", label="Runtime")
plot!(avgruntime, label="Avg. Runtime")
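
The trailing average can be checked on a tiny input. The sketch below uses a hypothetical helper `trailingmean` with a window of 3 so the arithmetic is easy to follow by hand.

```julia
# Trailing mean over at most `window` of the most recent values.
function trailingmean(xs, window)
    return [sum(xs[max(1, n - window + 1):n]) / min(n, window) for n in eachindex(xs)]
end

trailingmean([1, 2, 3, 4, 5], 3)   # [1.0, 1.5, 2.0, 3.0, 4.0]
```

The first entries average over fewer values because a full window is not yet available.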

In [None]:
filename_base = "dqn-raw-tu$(target_update_init)-$(target_update_decay)-b$(batchsize)-t$(nepisodes)"
savefig(filename_base * "-runtime.png")
savefig(plot(loss, title="Loss", xlabel="Iteration", ylabel="Loss", label=""), filename_base * "-loss.png")

In [None]:
using DelimitedFiles

open(filename_base * "-runtime.csv", "w") do io
    writedlm(io, epruntime, ',')
end
open(filename_base * "-loss.csv", "w") do io
    writedlm(io, loss, ',')
end
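
To reload saved results later, `readdlm` from the same standard library reverses `writedlm`. The sketch below round-trips a few made-up values through a temporary file rather than the files written above.

```julia
using DelimitedFiles

path = tempname() * ".csv"            # throwaway file for the demonstration
open(path, "w") do io
    writedlm(io, [12, 30, 45], ',')   # one value per line
end

# readdlm returns a 3 × 1 matrix; vec flattens it back into a vector.
restored = vec(readdlm(path, ',', Int))
rm(path)
```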