In [47]:
include("Drone2.jl")
using .Drone2: DroneMDP, DroneAct, DroneState, render, gen, isterminal, reset!, act!
using StaticArrays
using LinearAlgebra
using Statistics
using Random
using Distributions: Normal, logpdf
using Flux
using Flux: params, gradient, update!, gradient, Optimise, Adam, mse, train
using ElectronDisplay



LoadError: LoadError: ArgumentError: Package LinearAlgebra not found in current path, maybe you meant `import/using ..LinearAlgebra`.
- Otherwise, run `import Pkg; Pkg.add("LinearAlgebra")` to install the LinearAlgebra package.
in expression starting at /Users/andres/Documents/UMD/Spring 2024/ENAE 788Z/Project/MDP-Drone/Drone2.jl:1

### PPO

#### Feed Forward NN

In [48]:

struct ActorCritic
    actor::Chain
    critic::Chain
end

function ActorCritic(state_size::Int, action_size::Int)
    actor = Chain(
        Dense(state_size, 64, sigmoid),
        Dense(64, action_size * 2),
        x -> [x[1],x[2], softplus.(x[3]),softplus.(x[4])]
    )

    critic = Chain(
        Dense(state_size, 64, sigmoid),
        Dense(64, 1)
    )

    return ActorCritic(actor, critic)
end

ActorCritic

In [49]:
function forward(ac::ActorCritic, states)
    if isa(states, DroneState)
        states = [states]
    end
    results = Vector{Tuple{Vector{Float64}, Vector{Float64}}}(undef, length(states))
    values = zeros(length(states), 1)
    for i in 1:length(states)
        actor_out = ac.actor(states[i])  
        value = ac.critic(states[i])[1]
        action_mean = [actor_out[1]; actor_out[2]]
        action_std = [actor_out[3]; actor_out[4]]
        results[i] = (action_mean, action_std)
        values[i] = value
    end
    return results, values
end 


forward (generic function with 1 method)

In [50]:
function learn(ac::ActorCritic,m::DroneMDP, total_timesteps::Int64,time_steps_per_batch::Int64, max_timesteps_per_episode::Int64, n_updates_per_iteration::Int64,clip=0.2 )
    t_so_far = 0
    while t_so_far < total_timesteps
        batch_states, batch_actions, batch_log_probs, batch_rtgs, batch_lens = rollout(ac, m, time_steps_per_batch, max_timesteps_per_episode)
        
        V,_ = evaluate(ac, batch_states, batch_actions)

        A_k = advantage(batch_rtgs, V)
        for _ in 1:n_updates_per_iteration
            # Epoch code
            actor_loss_fn(ac,batch_states,batch_actions,batch_log_probs,A_k,clip)


            actor_opt = Adam(lr)
            actor_gs = gradient(() -> actor_loss )
            update!(Adam(0.005), params(ac.actor), actor_gs)

            critic_opt = Adam(lr)
            critic_gs = gradient(() -> critic_loss_fn(ppo_network, batch_obs, batch_rtgo), params(ppo_network.critic.model))
            update!(critic_opt, params(ppo_network.critic.model), critic_gs)


end

Base.Meta.ParseError: ParseError:
# Error @ /Users/andres/Documents/UMD/Spring 2024/ENAE 788Z/Project/MDP-Drone/PPO2.ipynb:23:4

end
#  └ ── Expected `end`

In [51]:
function actor_loss_fn(batch_states,batch_actions,batch_log_probs)
    _,curr_log_probs = evaluate(ac, batch_states, batch_actions)
    ratios = exp.(curr_log_probs-batch_log_probs)
    surr1 = ratios.*A_k
    surr2 = clamp.(surr1,1-clip,1+clip)
    actor_loss = mean(-min(surr1,surr2))
    return actor_loss
end

actor_loss_fn (generic function with 2 methods)

In [52]:
function rollout(ac::ActorCritic, m::DroneMDP, time_steps_per_batch::Int64, max_timesteps_per_episode::Int64)
    batch_states = []
    batch_actions = []
    batch_log_probs =[]
    batch_rews = []
    batch_lens = []

    t = 0
    while t < time_steps_per_batch
        ep_rews =[]
        s = reset!(m)
        done = false
        ep_t_temp = 0
        for ep_t in 1:max_timesteps_per_episode
            t += 1
            push!(batch_states,s)
            action, log_prob = get_action(ac::ActorCritic, m::DroneMDP)
            s, rew, done = act!(m,action)
            render(m)
            push!(ep_rews,rew)
            push!(batch_actions,action)
            push!(batch_log_probs,log_prob)
            ep_t_temp =ep_t
            if done
                break
            end
        end
        push!(batch_lens,ep_t_temp+1)
        push!(batch_rews,ep_rews)
    end

    batch_rtgs = rewards2go(m,batch_rews)
    
    return   batch_states, batch_actions, batch_log_probs, batch_rtgs, batch_lens
end;

In [53]:
function get_action(ac::ActorCritic, m::DroneMDP)
    result = forward(ac, m.drone)
    action_mean = result[1][1][1] 
    action_std = result[1][1][2]
    params = rand.(Normal.(action_mean, action_std))
    v = clamp(params[1], -m.v_max, m.v_max)
    omega = clamp(params[2], -m.om_max, m.om_max)
    action = DroneAct(v,omega)
    log_prob = sum(logpdf.(Normal.(action_mean, action_std), action))
    return action, log_prob
end

get_action (generic function with 1 method)

In [54]:
function rewards2go(m::DroneMDP, batch_rews)
    batch_rtg = []
    for ep_rews in reverse(batch_rews)
        discounted_reward = 0
        for rew in  reverse(ep_rews)
            discounted_reward = rew + m.discount*discounted_reward
            pushfirst!(batch_rtg,discounted_reward)
        end
    end
    return  batch_rtg
end


rewards2go (generic function with 1 method)

In [55]:
function log_prob(ac::ActorCritic,batch_states,batch_actions)
    log_probs = zeros(length(batch_states))
    for i in 1:length(batch_states)
        result = forward(ac, batch_states[i])
        action_mean = result[1][1][1] 
        action_std = result[1][1][2]
        log_prob = sum(logpdf.(Normal.(action_mean, action_std), batch_actions[i]))
        log_probs[i] =  log_prob
    end
    return log_probs
end

log_prob (generic function with 1 method)

In [56]:
function evaluate(ac::ActorCritic, batch_states, batch_actions)
   V = zeros(length(batch_states))
   for i in 1:length(batch_states)
      V[i] = ac.critic(batch_states[i])[1]
   end
   log_probs = log_prob(ac,batch_states,batch_actions)
   return V, log_probs
end

evaluate (generic function with 1 method)

In [57]:
function advantage(batch_rtgs, V)
    A = batch_rtgs-V
    return normalize(A)
end

advantage (generic function with 1 method)

In [58]:
ac = ActorCritic(3,2)
m = DroneMDP()
clip = 0.02

0.02

In [59]:
batch_states, batch_actions, batch_log_probs, batch_rtgs, batch_lens = rollout(π, m, 100, 1000)

(Any[Float32[25.0, 25.0, 0.0], Float32[24.408594, 25.211927, 5.9390965], Float32[23.861311, 25.408241, 5.938776], Float32[23.37253, 25.958403, 5.438776], Float32[23.256216, 26.463379, 4.938776], Float32[23.15758, 27.111712, 4.86337], Float32[22.786339, 27.599062, 5.36337], Float32[22.281893, 27.824223, 5.86337], Float32[21.637169, 28.178728, 5.7804527], Float32[21.066261, 28.60882, 5.637546]  …  Float32[9.938222, 10.040588, 0.9311141], Float32[9.414979, 9.799912, 0.43111408], Float32[8.970353, 9.202337, 0.9311141], Float32[8.337469, 8.91123, 0.43111408], Float32[7.7758946, 8.949976, 6.214299], Float32[7.219337, 9.305838, 5.714299], Float32[6.76973, 9.815548, 5.435217], Float32[6.5367723, 10.662055, 4.9809403], Float32[6.2947435, 11.316176, 5.066774], Float32[5.6538653, 11.874214, 5.566774]], Any[Float32[-1.2564622, -0.6881776], Float32[-1.1628556, -0.00064091414], Float32[-1.471847, -1.0], Float32[-1.0363995, -1.0], Float32[-1.3115875, -0.15081252], Float32[-1.225285, 1.0], Float32[-1.

In [60]:
V,_ = evaluate(π, batch_states, batch_actions)
A_k = advantage(batch_rtgs, V)
actor_loss_fn(batch_states,batch_actions,batch_log_probs)

0.010881271502685403

In [61]:
data =  (batch_states,batch_actions,batch_log_probs)
actor_loss_fn(batch_states,batch_actions,batch_log_probs)

0.010881271502685403