In [1]:
include("Drone2.jl")
using .Drone2: DroneMDP, DroneAct, DroneState, render, gen, isterminal, reset!, act!
using StaticArrays
using LinearAlgebra
using Statistics
using Random
using Distributions: Normal, logpdf
using Flux
using Flux: params, gradient, update!, gradient, Optimise, Adam, mse
using ElectronDisplay

### PPO

#### Feed Forward NN

In [2]:

struct ActorCritic
    actor::Chain
    critic::Chain
end

function ActorCritic(state_size::Int, action_size::Int)
    actor = Chain(
        Dense(state_size, 64, sigmoid),
        Dense(64, action_size * 2),
        x -> [x[1],x[2], softplus.(x[3]),softplus.(x[4])]
    )

    critic = Chain(
        Dense(state_size, 64, sigmoid),
        Dense(64, 1)
    )

    return ActorCritic(actor, critic)
end

ActorCritic

In [3]:
function forward(ac::ActorCritic, states)
    if isa(states, DroneState)
        states = [states]
    end
    results = fill([zeros(2), zeros(2)], length(states))
    values = zeros(length(states), 1)
    for i in 1:length(states)
        actor_out = ac.actor(states[i])  
        value = ac.critic(states[i])[1]
        action_mean = [actor_out[1]; actor_out[2]]
        action_std = [actor_out[3]; actor_out[4]]
        results[i] = [action_mean, action_std]
        values[i] = value
    end
    return results, values
end 

forward (generic function with 1 method)

In [4]:
function learn(ac::ActorCritic,m::DroneMDP, total_timesteps::Int64,time_steps_per_batch::Int64, max_timesteps_per_episode::Int64, n_updates_per_iteration::Int64,clip=0.2 )
    t_so_far = 0
    while t_so_far < total_timesteps
        batch_states, batch_actions, batch_log_probs, batch_rtgs, batch_lens = rollout(ac, m, time_steps_per_batch, max_timesteps_per_episode)
        
        V,_ = evaluate(ac, batch_states, batch_actions)

        A_k = advantage(batch_rtgs, V)
        for _ in 1:n_updates_per_iteration
            # Epoch code
            actor_loss_fn(ac,batch_states,batch_actions,batch_log_probs,A_k,clip)


            actor_opt = Adam(lr)
            actor_gs = gradient(() -> actor_loss )
            update!(Adam(0.005), params(ac.actor), actor_gs)

            critic_opt = Adam(lr)
            critic_gs = gradient(() -> critic_loss_fn(ppo_network, batch_obs, batch_rtgo), params(ppo_network.critic.model))
            update!(critic_opt, params(ppo_network.critic.model), critic_gs)


end

Base.Meta.ParseError: ParseError:
# Error @ /Users/andres/Documents/UMD/Spring 2024/ENAE 788Z/Project/MDP-Drone/PPO2.ipynb:23:4

end
#  └ ── Expected `end`

In [5]:
function actor_loss_fn(batch_states,batch_actions,batch_log_probs)
    _,curr_log_probs = evaluate(ac, batch_states, batch_actions)
    ratios = exp.(curr_log_probs-batch_log_probs)
    surr1 = ratios.*A_k
    surr2 = clamp.(ratios,1-clip,1+clip).*A_k
    actor_loss = mean(-min(surr1,surr2))
    return actor_loss
end

actor_loss_fn (generic function with 1 method)

In [6]:
function rollout(ac::ActorCritic, m::DroneMDP, time_steps_per_batch::Int64, max_timesteps_per_episode::Int64)
    batch_states = []
    batch_actions = []
    batch_log_probs =[]
    batch_rews = []
    batch_lens = []

    t = 0
    while t < time_steps_per_batch
        ep_rews =[]
        s = reset!(m)
        done = false
        ep_t_temp = 0
        for ep_t in 1:max_timesteps_per_episode
            t += 1
            push!(batch_states,s)
            action, log_prob = get_action(ac::ActorCritic, m::DroneMDP)
            s, rew, done = act!(m,action)
            render(m)
            push!(ep_rews,rew)
            push!(batch_actions,action)
            push!(batch_log_probs,log_prob)
            ep_t_temp =ep_t
            if done
                break
            end
        end
        push!(batch_lens,ep_t_temp+1)
        push!(batch_rews,ep_rews)
    end

    batch_rtgs = rewards2go(m,batch_rews)
    
    return   batch_states, batch_actions, batch_log_probs, batch_rtgs, batch_lens
end;

In [7]:
function get_action(ac::ActorCritic, m::DroneMDP)
    result = forward(ac, m.drone)
    action_mean = result[1][1][1] 
    action_std = result[1][1][2]
    params = rand.(Normal.(action_mean, action_std))
    v = clamp(params[1], -m.v_max, m.v_max)
    omega = clamp(params[2], -m.om_max, m.om_max)
    action = DroneAct(v,omega)
    log_prob = sum(logpdf.(Normal.(action_mean, action_std), action))
    return action, log_prob
end

get_action (generic function with 1 method)

In [8]:
function rewards2go(m::DroneMDP, batch_rews)
    batch_rtg = []
    for ep_rews in reverse(batch_rews)
        discounted_reward = 0
        for rew in  reverse(ep_rews)
            discounted_reward = rew + m.discount*discounted_reward
            pushfirst!(batch_rtg,discounted_reward)
        end
    end
    return  batch_rtg
end


rewards2go (generic function with 1 method)

In [9]:
# function evaluate(ac::ActorCritic, batch_states, batch_actions)
#    # Compute V
#    V = [ac.critic(state)[1] for state in batch_states]

#    # Compute log_probs
#    log_probs = [0.0 for _ in 1:length(batch_states)]
#    for i in 1:length(batch_states)
#        result = forward(ac, batch_states[i])
#        action_mean = result[1][1][1] 
#        action_std = result[1][1][2]
#        log_probs[i] = sum(logpdf.(Normal.(action_mean, action_std), batch_actions[i]))
#    end

#    return V, log_probs
# end

function evaluate(ac::ActorCritic, batch_states, batch_actions)
    # Compute V
    V = [ac.critic(state)[1] for state in batch_states]
 
    # Compute log_probs
    log_probs = Float64[]
    for i in 1:length(batch_states)
        result = forward(ac, batch_states[i])
        action_mean = result[1][1][1] 
        action_std = result[1][1][2]
        v = Normal(action_mean[1], action_std[1])
        om = Normal(action_mean[2], action_std[2])
        push!(log_probs, logpdf(v, batch_actions[i][1]) + logpdf(om, batch_actions[i][2]))
    end
 
    return V, log_probs
end

evaluate (generic function with 1 method)

In [10]:
function advantage(batch_rtgs, V)
    A = batch_rtgs-V
    return normalize(A)
end

advantage (generic function with 1 method)

In [11]:
ac = ActorCritic(3,2)
m = DroneMDP()
clip = 0.02

0.02

In [12]:
batch_states, batch_actions, batch_log_probs, batch_rtgs, batch_lens = rollout(ac, m, 100, 1000)

objc[4685]: Class WebSwapCGLLayer is implemented in both /System/Library/Frameworks/WebKit.framework/Versions/A/Frameworks/WebCore.framework/Versions/A/Frameworks/libANGLE-shared.dylib (0x25e232270) and /Users/andres/.julia/artifacts/12f3018147190ddc494f686e5fbefe8d84f16efb/Julia.app/Contents/Frameworks/Electron Framework.framework/Versions/A/Libraries/libGLESv2.dylib (0x117071348). One of the two will be used. Which one is undefined.


(Any[Float32[25.0, 25.0, 0.0], Float32[24.969805, 24.983505, 0.5], Float32[24.828938, 24.764118, 1.0], Float32[24.826172, 24.725117, 1.5], Float32[24.862434, 24.645884, 2.0], Float32[24.87662, 24.635288, 2.5], Float32[24.871964, 24.635952, 3.0], Float32[24.979961, 24.676407, 3.5], Float32[25.079311, 24.791437, 4.0], Float32[25.146011, 25.10075, 4.5]  …  Float32[23.635872, 28.321894, 1.1702712], Float32[23.660135, 28.078781, 1.6702712], Float32[23.81212, 27.856373, 2.1702712], Float32[24.065924, 27.727028, 2.6702712], Float32[24.204824, 27.731012, 3.1702712], Float32[24.341633, 27.756327, 3.3245606], Float32[24.47638, 27.865953, 3.8245606], Float32[24.62398, 27.958414, 3.7012196], Float32[24.6836, 28.064707, 4.2012196], Float32[24.686981, 28.367535, 4.7012196]], Any[Float32[-0.068813615, 1.0], Float32[-0.521437, 1.0], Float32[-0.07819696, 1.0], Float32[-0.17427537, 1.0], Float32[-0.035409898, 1.0], Float32[0.009405222, 1.0], Float32[-0.23065285, 1.0], Float32[-0.30398887, 1.0], Float32[

In [13]:
V,_ = evaluate(ac, batch_states, batch_actions)
A_k = advantage(batch_rtgs, V)
actor_loss_fn(batch_states,batch_actions,batch_log_probs)

0.05862566503990246

In [14]:
function log_prob(mean, std, state)
    dist = Normal(mean, std)
    return logpdf(dist, state)
end

log_prob (generic function with 1 method)

In [15]:
actor_gs = gradient(() -> actor_loss_fn(batch_states,batch_actions,batch_log_probs), params(ac.actor))
update!(actor_opt, params(ac.actor), actor_gs)


ErrorException: Mutating arrays is not supported -- called push!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations
