In [1]:
using Flux
using Zygote
using PyCall
using DataStructures
using StatsBase
using Printf
using LinearAlgebra
using Distributions
using DistributionsAD
using Test
using BSON


In [2]:
layer_size = 32

hparams = Dict([
    ("lr", 1e-3),
    ("env", "CartPoleContinuousBulletEnv-v0"),
    ("policy_layer1", layer_size),
    ("policy_layer2", layer_size),
    ("value_layer1", layer_size),
    ("value_layer2", layer_size),
    ("q_layer1", layer_size),
    ("q_layer2", layer_size),
    ("activation", swish),
    ("target_update", 1e-3),
    ("entropy_incentive", 0.2),
    ("l2_reg", 1e-4),
    ("batch_size", 100),
    ("discount_factor", 0.99),
    ("buffer_size", 1000000),
    ("epochs", 100),
    ("steps_per_epoch", 2000),
    ("train_steps_per_iter", 2)
])



Dict{String,Any} with 18 entries:
  "value_layer2"         => 32
  "lr"                   => 0.001
  "policy_layer2"        => 32
  "discount_factor"      => 0.99
  "train_steps_per_iter" => 2
  "batch_size"           => 100
  "steps_per_epoch"      => 2000
  "env"                  => "CartPoleContinuousBulletEnv-v0"
  "buffer_size"          => 1000000
  "l2_reg"               => 0.0001
  "q_layer1"             => 32
  "target_update"        => 0.001
  "policy_layer1"        => 32
  "value_layer1"         => 32
  "activation"           => swish
  "epochs"               => 100
  "entropy_incentive"    => 0.2
  "q_layer2"             => 32

In [3]:
pyimport("pybullet_envs")
gym = pyimport("gym")
env = gym.make(hparams["env"])


STATE_SPACE = length(env.observation_space.low)
ACTION_SPACE = length(env.action_space.low)
ACTION_HIGH = env.action_space.high
ACTION_LOW = env.action_space.low
TARGET_UPDATE = hparams["target_update"]
ENTROPY_INCENTIVE = hparams["entropy_incentive"]
L2_REG = hparams["l2_reg"]
BATCH_SIZE = hparams["batch_size"]
GAMMA = hparams["discount_factor"]
EPOCHS = hparams["epochs"]
STEPS_PER_EPOCH = hparams["steps_per_epoch"]
TRAIN_STEPS_PER_ITER = hparams["train_steps_per_iter"]

pybullet build time: Jul  8 2020 18:24:12


2

In [4]:
optim = Flux.ADAM(hparams["lr"])
dtype = Float64

Float64

In [5]:
# Our squashed policy has a support of -1,1
# Project actions to/from that support
function action_to_support(action)
    halfspan = (ACTION_HIGH .- ACTION_LOW) ./ 2
    low_end = ACTION_LOW ./ halfspan
    action ./ halfspan .- low_end .- 1
end

function support_to_action(action)
    halfspan = (ACTION_HIGH .- ACTION_LOW) ./ 2
    low_end = ACTION_LOW ./ halfspan
    (action .+ 1 .+ low_end) .* halfspan
end

support_to_action (generic function with 1 method)

In [6]:
begin
    local lol1 = Flux.batch([rand(ACTION_SPACE) for _ in 1:100])
    @test isapprox(lol1, support_to_action(action_to_support(lol1)))
end


[32m[1mTest Passed[22m[39m

In [7]:
function initialize()
    
    typeswitch(T, x) = x
    typeswitch(T, x::Number) = T(x)
    typeswitch(T, x::AbstractArray) = T.(x)
    
    value = Chain(
        Dense(STATE_SPACE, hparams["value_layer1"], hparams["activation"]),
        #LayerNorm(hparams["value_layer1"]),
        Dense(hparams["value_layer1"], hparams["value_layer2"], hparams["activation"]),
        #LayerNorm(hparams["value_layer2"]),
        Dense(hparams["value_layer2"], 1),
    )
    value = Flux.fmap(x -> typeswitch(dtype, x), value)

    value_target = deepcopy(value)

    critics = map((_) -> Flux.fmap(x -> typeswitch(dtype, x), Chain(
        Dense(STATE_SPACE+ACTION_SPACE, hparams["q_layer1"], hparams["activation"]),
        #LayerNorm(hparams["q_layer1"]),
        Dense(hparams["q_layer1"], hparams["q_layer2"], hparams["activation"]),
        #LayerNorm(hparams["q_layer2"]),
        Dense(hparams["q_layer2"], 1),
    )), 1:2)

    policy = Chain(
        Dense(STATE_SPACE, hparams["policy_layer1"], hparams["activation"]),
        #LayerNorm(hparams["policy_layer1"]),
        Dense(hparams["policy_layer1"], hparams["policy_layer2"], hparams["activation"]),
        #LayerNorm(hparams["policy_layer2"]),
        Dense(hparams["policy_layer2"], ACTION_SPACE*2),
    )
    policy = Flux.fmap(x -> typeswitch(dtype, x), policy)

    
    memory = memory = CircularBuffer{Tuple{Array{dtype,1}, Array{dtype, 1}, dtype, Array{dtype,1}, Bool}}(hparams["buffer_size"])
    
    value, value_target, critics, policy, memory
end

initialize (generic function with 1 method)

In [8]:



# Single value mode
function forward_critic(models, state::Array{T,1}, action::Array{T,1}) where {T}
    pred = map((model) -> model(vcat(state, action)), models)
    minimum(Flux.stack(pred, 1))
end
# Batch mode
function forward_critic(models, state::Array{T,2}, action::Array{T,2}) where {T}
    pred = map((model) -> dropdims(model(vcat(state, action)), dims=1), models)
    minimum(Flux.stack(pred, 1), dims=1)
end

forward_critic (generic function with 2 methods)

In [112]:


function forward_policy(model, state::Array{T, 2}) where {T}
    x = policy(state)
    mu = x[1:ACTION_SPACE,:]
    sigma = exp.(x[ACTION_SPACE+1:end,:])
    return mu, sigma
end
function forward_policy(model, state::Array{T, 1}) where {T}
    x = policy(state)
    mu = x[1:ACTION_SPACE]
    sigma = exp.(x[ACTION_SPACE+1:end])
    return mu, sigma
end

# The SAC authors squash their action space
# See Appendix C in their paper
function squash(actions)
    tanh.(actions)
end
function policy_sample(mu::Array{T}, sigma::Array{T}) where {T}
    mu .+ sigma .* clamp.(T.(rand(Normal(), size(sigma))), -10, 10)
end
function policy_sample(model, state::Array)
    policy_sample(forward_policy(model, state)...)
end
function act(model, state)
    support_to_action(squash(policy_sample(model, state)))
end
function logprob(mu, sigma, action_unsquashed::Array{T}) where {T}
    # Add a small epsilon so we don't explode the loss
    logpdf(TuringDiagMvNormal(mu, sigma), action_unsquashed) - sum(log.(1 .- tanh.(action_unsquashed) .^ 2 .+ eps(T) ))
end
# function logprobs(mu, sigma, actions_unsquashed)
#     mapslices(
#         (x) -> logprob(x[1:ACTION_SPACE], x[ACTION_SPACE+1:ACTION_SPACE*2], x[ACTION_SPACE*2+1:end]), 
#         [mu;sigma;actions_unsquashed], dims=2
#     )
# end

# Copied from DistributionsAD
# function logprob_cuda(mu, sigma, x)
#     -(length(x) * log(2*pi) + 2 * sum(log.(sigma)) + sum(((x .- mu) ./ sigma) .^ 2)) / 2
# end

function logprobs_cuda(mu::AbstractArray{T, 2}, sigma::AbstractArray{T, 2}, x::AbstractArray{T, 2}) where {T}
    # A pretty intimidating function
    # The first half is the normal logpdf part, copied from DistributionsAD for a diagonal multivariate gaussian where the sums are respecting batching
    # The second part is needed because the function is squashed with tanh
    -(size(mu, 1) * log(2*pi) .+ 2 * sum(log.(sigma), dims=1) .+ sum(((x .- mu) ./ sigma) .^ 2, dims=1)) / 2 .- sum(log.(1 .- tanh.(x) .^ 2 .+ eps(T)), dims=1)
end

logprobs_cuda (generic function with 2 methods)

In [111]:
mu = rand(10)
sigma = rand(10) .^2
action = mu .+ sigma .* rand(Normal(), 10)

mus = Flux.batch([mu, mu, mu])
sigmas = Flux.batch([sigma, sigma, sigma])
actions = Flux.batch([action, action, action])

@test isapprox(logprob(mu, sigma, action), logprobs_cuda(mus, sigmas, actions)[3])

[32m[1mTest Passed[22m[39m

In [103]:
logprob(mu, sigma, action)

24.433451663275275

In [102]:
logprobs_cuda(mus, sigmas, actions)

1×3 Array{Float64,2}:
 24.4335  24.4335  24.4335

In [50]:
states = repeat(rand(STATE_SPACE), 1,100)
mu, sigma = forward_policy(policy, states)
a = policy_sample(mu, sigma)
y = logprobs(mu, sigma, a)


1×100 Array{Float64,2}:
 -0.557655  -0.735788  -0.741822  …  -0.530643  -0.734148  -0.570652

In [11]:
function update_value_target!(value_target, value)
    for (p, p_target) in zip(Flux.params(value), Flux.params(value_target))
        p_target .= (1-TARGET_UPDATE) .* p_target .+ TARGET_UPDATE .* p
        @assert !any(isnan.(p_target))
    end
end

function update_value!(value, critics, policy, optim, states::Array{T}) where {T}
    parameters = Flux.params(value)
    outer_loss = 0
    
    mu, sigma = forward_policy(policy, states)
    actions = policy_sample(mu, sigma)
    p = logprobs(mu, sigma, actions)
    q = forward_critic(critics, states, squash(actions))
    target = q .- ENTROPY_INCENTIVE .* p
    
    
    grads = gradient(parameters) do
        loss = Flux.mse(value(states), target)# + L2_REG * sum(norm, parameters)
        outer_loss += loss
        return loss
    end
    
    Flux.update!(optim, parameters, grads)
    @assert !any([any(isnan.(layer.W)) for layer in value])
    outer_loss
end

update_value! (generic function with 1 method)

In [23]:
function update_critic!(critic, value_target, optim, states, actions, rewards, next_states, deaths)
    parameters = Flux.params(critic)
    outer_loss = 0
    target = rewards .+ GAMMA .* (.!deaths) .* dropdims(value_target(next_states), dims=1)
    @assert !any(isnan.(target))
        
    grads = gradient(parameters) do
        loss = Flux.mse(critic(vcat(states, actions)), target)# + L2_REG * sum(norm, parameters)
        outer_loss += loss
        return loss
    end
    
    Flux.update!(optim, parameters, grads)
    
    @assert !any([any(isnan.(layer.W)) for layer in critic])
    outer_loss
end

update_critic! (generic function with 1 method)

In [13]:
function update_policy!(policy, critics, optim, states::Array{T}) where {T}
    parameters = Flux.params(policy)
    outer_loss = 0
    batch_size = size(states)[2]
    unitnoise = T.(rand(Normal(), (ACTION_SPACE, batch_size)))
    unitnoise = clamp.(unitnoise, -10, 10) # Don't allow crazy outliers
    
    grads = gradient(parameters) do
        mu, sigma = forward_policy(policy, states)
        actions = mu .+ sigma .* unitnoise
        q = forward_critic(critics, states, squash(actions))
        is = ENTROPY_INCENTIVE * logprobs(mu, sigma, actions)
        loss = sum(is .- q) / batch_size #+ L2_REG * sum(norm, parameters)
        outer_loss += loss
        return loss
    end
        
    Flux.update!(optim, parameters, grads)
    @assert !any([any(isnan.(layer.W)) for layer in policy])
    outer_loss
end

update_policy! (generic function with 1 method)

huhu


5.5002e-5

In [14]:
function batch(memory)
    batch = vcat([sample(memory) for _ in 1:BATCH_SIZE])
    
    # Destructure the batch
    state, action, reward, next_state, death = [getindex.(batch, i) for i in 1:5]
    state = Flux.batch(state)
    next_state = Flux.batch(next_state)
    action = Flux.batch(action)
    
    @assert !any(isnan.(state))
    @assert !any(isnan.(action))
    
    return state, action, reward, next_state, death
end

batch (generic function with 1 method)

In [15]:
function collect_experience!(env, memory, steps)
  state = dtype.(env.reset())
  total_reward = 0
  total_deaths = 0
  for _ in 1:steps
    action = dtype.(clamp(rand(MvNormal(ACTION_LOW, ACTION_HIGH)), ACTION_LOW, ACTION_HIGH))
    next_state, reward, death, _ = env.step(action) # Advance the env

    # Convert to dtype
    next_state = dtype.(next_state)
    reward = dtype(reward)
    total_reward += reward

    push!(memory, (state, action, reward, next_state, death))
    
    if death
        state = env.reset()
        total_deaths += 1
    end
  end
  total_reward / (total_deaths + 1)
end


collect_experience! (generic function with 1 method)

In [16]:

function train!(env, memory, optim, policy, critics, value, value_target)
  for epoch in 1:EPOCHS
    total_reward = 0
    total_deaths = 0
    total_v_loss = 0
    total_q1_loss = 0
    total_q2_loss = 0
    total_policy_loss = 0
    total_entropy = 0
    total_iterations = 0
    state = dtype.(env.reset())
    for i in 1:STEPS_PER_EPOCH
        action = act(policy, state) # Act
        next_state, reward, death, _ = env.step(action) # Advance the env

        # Convert to Float32
        next_state = dtype.(next_state)
        reward = dtype(reward)

        push!(memory, (state, action, reward, next_state, death))
        total_reward += reward
        if death
            state = env.reset()
            total_deaths += 1
        else
            state = next_state
        end

        if length(memory) > BATCH_SIZE
            for i in 1:1
                states, actions, rewards, next_states, deaths = batch(memory)
                actions = action_to_support(actions)
                total_v_loss += update_value!(value, critics, policy, optim, states)
                total_q1_loss += update_critic!(critics[1], value_target, optim, states, actions, rewards, next_states, deaths)
                total_q2_loss += update_critic!(critics[2], value_target, optim, states, actions, rewards, next_states, deaths)
                total_policy_loss += update_policy!(policy, critics, optim, states)
                update_value_target!(value_target, value)
                @assert(!isnan(total_v_loss))
                @assert(!isnan(total_q1_loss))
                @assert(!isnan(total_q2_loss))
                @assert(!isnan(total_policy_loss))
                total_iterations += 1
            end
        end
    end

    s,a,r,sn,d = batch(memory)
    v = mean(value(s))
    q1 = mean(critics[1](vcat(s, a)))
    q2 = mean(critics[2](vcat(s, a)))
    mu, sigma = forward_policy(policy, s)
    println(@sprintf("v: %f, q1: %f, q2: %f, e: %f", 
            v, 
            q1,
            q2,
            mean(sigma)))

    println(@sprintf("I: %d, r: %f, v: %f, q1: %f, q2: %f, p: %f",
            epoch,
            total_reward/(total_deaths+1),
            total_v_loss/total_iterations,
            total_q1_loss/total_iterations,
            total_q2_loss/total_iterations,
            total_policy_loss/total_iterations))
    flush(stdout)
  end
end


train! (generic function with 1 method)

In [17]:
value, value_target, critics, policy, memory = initialize()
collect_experience!(env, memory, 200)

10.526315789473685

In [25]:
train!(env, memory, optim, policy, critics, value, value_target)


v: 23.654447, q1: 23.426141, q2: 23.430552, e: 0.861843
I: 1, r: 24.096386, v: 0.021249, q1: 2000.219431, q2: 2000.684103, p: -23.577358
v: 23.295288, q1: 23.304172, q2: 23.346272, e: 0.869510
I: 2, r: 26.666667, v: 0.020617, q1: 2016.945288, q2: 2017.399796, p: -23.555325
v: 23.768756, q1: 23.635682, q2: 23.659207, e: 0.878101
I: 3, r: 27.027027, v: 0.027285, q1: 1983.293353, q2: 1983.859376, p: -23.569312
v: 23.518104, q1: 23.616610, q2: 23.668448, e: 0.874863
I: 4, r: 28.571429, v: 0.014384, q1: 2000.935017, q2: 2001.343588, p: -23.547004
v: 23.674016, q1: 23.589777, q2: 23.634162, e: 0.877567
I: 5, r: 26.666667, v: 0.018785, q1: 1949.967180, q2: 1950.388580, p: -23.610086
v: 23.548820, q1: 23.441905, q2: 23.467004, e: 0.876182
I: 6, r: 25.974026, v: 0.015728, q1: 1997.715843, q2: 1998.098964, p: -23.597711
v: 23.476372, q1: 23.493596, q2: 23.517159, e: 0.871869
I: 7, r: 27.397260, v: 0.024801, q1: 1988.017182, q2: 1988.453296, p: -23.612672
v: 23.872068, q1: 23.654609, q2: 23.66576

v: 23.527652, q1: 23.253805, q2: 23.216301, e: 0.880239
I: 61, r: 25.000000, v: 0.009378, q1: 1986.794766, q2: 1987.014500, p: -23.711093
v: 23.429945, q1: 23.453890, q2: 23.465711, e: 0.885897
I: 62, r: 25.316456, v: 0.019108, q1: 2037.714394, q2: 2037.934027, p: -23.695766
v: 23.902212, q1: 23.682968, q2: 23.667632, e: 0.876424
I: 63, r: 26.315789, v: 0.010739, q1: 2003.524029, q2: 2003.757434, p: -23.712395
v: 23.509436, q1: 23.643754, q2: 23.662922, e: 0.875773
I: 64, r: 23.809524, v: 0.008513, q1: 2020.509797, q2: 2020.720828, p: -23.709834
v: 23.587453, q1: 23.455529, q2: 23.443720, e: 0.872804
I: 65, r: 27.027027, v: 0.011002, q1: 2007.602858, q2: 2007.837710, p: -23.735484
v: 23.845513, q1: 23.780273, q2: 23.796526, e: 0.878015
I: 66, r: 22.471910, v: 0.011240, q1: 2021.515386, q2: 2021.730404, p: -23.730663
v: 23.659181, q1: 23.725662, q2: 23.743894, e: 0.867560
I: 67, r: 26.666667, v: 0.009698, q1: 2015.401577, q2: 2015.604108, p: -23.769540
v: 23.659844, q1: 23.634852, q2: 2

In [40]:
function test(env, policy)
    state = env.reset()
    total_reward = 0
    for _ in 1:5000
        mu, sigma = forward_policy(policy, state)
        state, reward, death, _ = env.step(mu)
        total_reward += reward
        env.render()
        if death
            break
        end
    end
    total_reward
end

test (generic function with 1 method)

In [41]:
test(env, policy)

7.0

In [None]:
@BSON.save "models.bson" Dict(
    :value => value,
    :value_target => value_target,
    :critics => critics,
    :policy => policy,
)

In [67]:
BATCH_SIZE = 10000
s,a,r,sa,d = batch(memory)

([-0.0593887521755381 -0.09837611675025143 … 0.011800205368475647 -0.19544651568152593; -0.7859106243879304 -1.4163999185471952 … 0.02471146324088716 -0.8055429254736141; -0.018139650155244755 0.1221037720307057 … -0.03026086376533901 0.029718731219978547; 0.4925133549958805 1.0070555566413861 … 0.012829738314209124 0.20123916341663578], [8.416118698940394 5.8421170146112615 … 4.956979129541407 8.013060997234167], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [-0.08039334550728447 -0.13071587094306414 … 0.00946937420570149 -0.21736059562231513; -1.0502296665873185 -1.6169877096406347 … -0.1165415581387079 -1.0957039970394593; -0.004989776396819065 0.1445471776105656 … -0.028073328105376015 0.036913824237685744; 0.6574936879212845 1.122170278992995 … 0.10937678299814976 0.35975465088535985], Bool[0, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 1])

In [75]:
mean(r)

1.0