In [1]:
using Flux
using Zygote
using PyCall
using DataStructures
using StatsBase
using Printf
using LinearAlgebra
using Distributions
using DistributionsAD


In [17]:
hparams = Dict([
    ("lr", 1e-4),
    ("env", "Pendulum-v0"),
    ("policy_layer1", 30),
    ("policy_layer2", 30),
    ("value_layer1", 30),
    ("value_layer2", 30),
    ("q_layer1", 30),
    ("q_layer2", 30),
    ("activation", tanh),
    ("target_update", 1e-4),
    ("entropy_incentive", 1e-2),
    ("l2_reg", 1e-4),
    ("batch_size", 64),
    ("discount_factor", 0.99),
    ("buffer_size", 100000),
    ("epochs", 100),
    ("steps_per_epoch", 2000),
    ("train_steps_per_iter", 2)
])



Dict{String,Any} with 18 entries:
  "value_layer2"         => 30
  "lr"                   => 0.0001
  "policy_layer2"        => 30
  "discount_factor"      => 0.99
  "train_steps_per_iter" => 2
  "batch_size"           => 64
  "steps_per_epoch"      => 2000
  "env"                  => "Pendulum-v0"
  "buffer_size"          => 100000
  "l2_reg"               => 0.0001
  "q_layer1"             => 30
  "target_update"        => 0.0001
  "policy_layer1"        => 30
  "value_layer1"         => 30
  "activation"           => tanh
  "epochs"               => 100
  "entropy_incentive"    => 0.01
  "q_layer2"             => 30

In [3]:
gym = pyimport("gym")
env = gym.make(hparams["env"])


STATE_SPACE = length(env.observation_space.low)
ACTION_SPACE = length(env.action_space.low)
TARGET_UPDATE = hparams["target_update"]
ENTROPY_INCENTIVE = hparams["entropy_incentive"]
L2_REG = hparams["l2_reg"]
BATCH_SIZE = hparams["batch_size"]
GAMMA = hparams["discount_factor"]
EPOCHS = hparams["epochs"]
STEPS_PER_EPOCH = hparams["steps_per_epoch"]
TRAIN_STEPS_PER_ITER = hparams["train_steps_per_iter"]

2

In [4]:
optim = Flux.ADAM(hparams["lr"])

ADAM(0.0001, (0.9, 0.999), IdDict{Any,Any}())

In [5]:
memory = CircularBuffer{Tuple{Array{Float32,1}, Array{Float32, 1}, Float32, Array{Float32,1}, Bool}}(hparams["buffer_size"])

0-element CircularBuffer{Tuple{Array{Float32,1},Array{Float32,1},Float32,Array{Float32,1},Bool}}

In [6]:
value = Chain(
    Dense(STATE_SPACE, hparams["value_layer1"], hparams["activation"]),
    LayerNorm(hparams["value_layer1"]),
    Dense(hparams["value_layer1"], hparams["value_layer2"], hparams["activation"]),
    LayerNorm(hparams["value_layer2"]),
    Dense(hparams["value_layer2"], 1),
)

value_target = deepcopy(value)

Chain(Dense(3, 30, tanh), LayerNorm(30), Dense(30, 30, tanh), LayerNorm(30), Dense(30, 1))

In [7]:
critics = map((_) -> Chain(
    Dense(STATE_SPACE+ACTION_SPACE, hparams["q_layer1"], hparams["activation"]),
    LayerNorm(hparams["q_layer1"]),
    Dense(hparams["q_layer1"], hparams["q_layer2"], hparams["activation"]),
    LayerNorm(hparams["q_layer2"]),
    Dense(hparams["q_layer2"], 1),
), 1:2)


# Single value mode
function forward_critic(models, state::Array{T,1}, action::Array{T,1}) where {T}
    pred = map((model) -> model(vcat(state, action)), models)
    minimum(Flux.stack(pred, 1))
end
# Batch mode
function forward_critic(models, state::Array{T,2}, action::Array{T,2}) where {T}
    pred = map((model) -> dropdims(model(vcat(state, action)), dims=1), models)
    minimum(Flux.stack(pred, 1), dims=1)
end

forward_critic (generic function with 2 methods)

In [8]:
policy = Chain(
    Dense(STATE_SPACE, hparams["policy_layer1"], hparams["activation"]),
    LayerNorm(hparams["policy_layer1"]),
    Dense(hparams["policy_layer1"], hparams["policy_layer2"], hparams["activation"]),
    LayerNorm(hparams["policy_layer2"]),
    Dense(hparams["policy_layer2"], ACTION_SPACE*2),
)

function forward_policy(model, state::Array{T, 2}) where {T}
    x = policy(state)
    mean = x[1:ACTION_SPACE,:]
    std = exp.(x[ACTION_SPACE+1:end,:])
    return mean, std
end
function forward_policy(model, state::Array{T, 1}) where {T}
    x = policy(state)
    mean = x[1:ACTION_SPACE]
    std = exp.(x[ACTION_SPACE+1:end])
    return mean, std
end

function forward_policy_dist(model, state::Array{T, 1}) where {T}
    MvNormal(forward_policy(model, state)...)
end
function forward_policy_dist(model, state::Array{T, 2}) where {T}
    mean, std = forward_policy(model, state)
    map((x) -> MvNormal(x[1], x[2]), zip(eachcol(mean), eachcol(std)))
end
function forward_policy_logprob(model, state, action)
    dists = forward_policy_dist(policy, state)
    logprob = Flux.batch(map((x) -> logpdf(x[1], x[2]), zip(dists, eachcol(action))))
end
function logprob(mean, std, action)
    logpdf(TuringDiagMvNormal(mean, std), action)
end
function logprobs(mean, std, actions)
    Flux.unsqueeze(map((i) -> logprob(mean[:,i], std[:,i], actions[:,i]), 1:size(actions)[2]), 1)
end

logprobs (generic function with 1 method)

In [9]:
function update_value_target!(value_target, value)
    for (p, p_target) in zip(Flux.params(value), Flux.params(value_target))
       p_target .= (1-TARGET_UPDATE) .* p_target .+ TARGET_UPDATE .* p
    end
end

function update_value!(value, critics, policy, optim, states::Array{T}) where {T}
    parameters = Flux.params(value)
    outer_loss = 0
    
    dist = forward_policy_dist(policy, states)
    actions = T.(Flux.batch([rand(d) for d in dist]))
    p = Flux.batch([logpdf(d, a) for (d,a) in zip(dist, eachcol(actions))])
    q = forward_critic(critics, states, actions)
    target = q .- ENTROPY_INCENTIVE * p
    
    #println(mean(target), mean(q), mean(p))
    
    grads = gradient(parameters) do
        loss = Flux.huber_loss(value(states), target) + L2_REG * sum(norm, parameters)
        outer_loss += loss
        return loss
    end

    Flux.update!(optim, parameters, grads)
    outer_loss
end

update_value! (generic function with 1 method)

In [10]:
function update_critic!(critic, value_target, optim, states, actions, rewards, next_states, deaths)
    target = rewards .+ GAMMA .* (.!deaths) .* value_target(next_states)
    parameters = Flux.params(critic)
    outer_loss = 0
    
    grads = gradient(parameters) do
        loss = Flux.huber_loss(critic(vcat(states, actions)), target)
        outer_loss += loss
        return loss
    end
    
    Flux.update!(optim, parameters, grads)
    outer_loss
end

update_critic! (generic function with 1 method)

In [11]:
function update_policy!(policy, critics, optim, states::Array{T}) where {T}
    parameters = Flux.params(policy)
    outer_loss = 0
    batch_size = size(states)[2]
    unitnoise = T.(rand(Normal(), (ACTION_SPACE, batch_size)))
    
    grads = gradient(parameters) do
        mean, std = forward_policy(policy, states)
        actions = unitnoise .* std .+ mean
        q = forward_critic(critics, states, actions)
        is = ENTROPY_INCENTIVE * logprobs(mean, std, actions)
        loss = sum(is .- q) / batch_size
        outer_loss += loss
        return loss
    end
    
    Flux.update!(optim, parameters, grads)
    outer_loss
end

update_policy! (generic function with 1 method)

In [12]:
function batch(memory)
    batch = vcat([sample(memory) for _ in 1:BATCH_SIZE])
    # Destructure the batch
    state, action, reward, next_state, death = [getindex.(batch, i) for i in 1:5]
    state = Flux.batch(state)
    next_state = Flux.batch(next_state)
    action = Flux.batch(action)
    return state, action, reward, next_state, death
end

batch (generic function with 1 method)

In [15]:
for epoch in 1:EPOCHS
    total_reward = 0
    total_deaths = 0
    total_v_loss = 0
    total_q1_loss = 0
    total_q2_loss = 0
    total_policy_loss = 0
    total_iterations = 0
    state = Float32.(env.reset())
    for i in 1:STEPS_PER_EPOCH
        action = rand(forward_policy_dist(policy, state)) # Act
        action = clamp(action, env.action_space.low, env.action_space.high)
        next_state, reward, death, _ = env.step(action) # Advance the env

        # Convert to Float32
        next_state = Float32.(next_state)
        reward = Float32(reward)

        push!(memory, (state, action, reward, next_state, death))
        total_reward += reward
        if death
            state = env.reset()
            total_deaths += 1
        else
            state = next_state
        end
        
        if length(memory) > BATCH_SIZE
            for i in 1:TRAIN_STEPS_PER_ITER
                states, actions, rewards, next_states, deaths = batch(memory)
                total_v_loss += update_value!(value, critics, policy, optim, states)
                total_q1_loss += update_critic!(critics[1], value_target, optim, states, actions, rewards, next_states, deaths)
                total_q2_loss += update_critic!(critics[2], value_target, optim, states, actions, rewards, next_states, deaths)
                total_policy_loss += update_policy!(policy, critics, optim, states)
                update_value_target!(value_target, value)
                total_iterations += 1
            end
        end
    end
    
    println(@sprintf("I: %d, r: %f, v: %f, q1: %f, q2: %f, p: %f", 
            epoch,
            total_reward/(total_deaths+1),
            total_v_loss/total_iterations, 
            total_q1_loss/total_iterations, 
            total_q2_loss/total_iterations, 
            total_policy_loss/total_iterations))
    flush(stdout)
end

I: 1, r: -601.944580, v: 0.108978, q1: 2.885548, q2: 2.887355, p: 4.719818
I: 2, r: -901.188904, v: 0.432230, q1: 2.852739, q2: 2.855324, p: 3.707046
I: 3, r: -684.379639, v: 0.409662, q1: 2.711770, q2: 2.716033, p: 3.869687
I: 4, r: -505.849304, v: 0.465882, q1: 2.774301, q2: 2.777215, p: 3.967479
I: 5, r: -833.089539, v: 0.457431, q1: 2.757236, q2: 2.759940, p: 4.043257
I: 6, r: -761.569580, v: 0.450687, q1: 2.659956, q2: 2.662855, p: 4.220744
I: 7, r: -645.099609, v: 0.448799, q1: 2.594251, q2: 2.595930, p: 4.325452
I: 8, r: -571.018677, v: 0.455499, q1: 2.671499, q2: 2.672621, p: 4.318037
I: 9, r: -755.985779, v: 0.413188, q1: 2.635601, q2: 2.636891, p: 4.340734
I: 10, r: -582.337952, v: 0.340242, q1: 2.580169, q2: 2.580998, p: 4.404067
