In [1]:
include("Drone2.jl")
using .Drone2: DroneMDP, DroneAct, DroneState, render, gen, isterminal
using LinearAlgebra
using Random
using Statistics: mean
using StaticArrays
using Distributions: Normal, logpdf
using Flux
using Flux: params, gradient, update!

### PPO (Proximal Policy Optimization)

In [4]:

"""
    ActorCritic(actor, critic)

Pair a policy network (`actor`) with a state-value network (`critic`).

The fields are parametric so each instance stores its concrete `Chain` types.
`Flux.@functor` is declared so that `Flux.params(ac)` recurses into both chains
and collects their trainable arrays. In Flux versions where custom structs are
not functors by default, `params` would otherwise return nothing to optimise.
"""
struct ActorCritic{A<:Chain,C<:Chain}
    actor::A
    critic::C
end

Flux.@functor ActorCritic

"""
    ActorCritic(state_size::Int, action_size::Int)

Build an actor-critic pair for a `state_size`-dimensional input and an
`action_size`-dimensional continuous action.

The actor's final layer emits `2 * action_size` values:
- The first `action_size` are returned unchanged and are used as Gaussian means.
- The remaining `action_size` are passed through `softplus` to keep them positive,
  and are used as standard deviations.

The critic emits a single value estimate.

Throws `ArgumentError` if either size is not positive.
"""
function ActorCritic(state_size::Int, action_size::Int)
    state_size > 0 || throw(ArgumentError("state_size must be positive, got $state_size"))
    action_size > 0 || throw(ArgumentError("action_size must be positive, got $action_size"))

    # Slice by `action_size` (rather than fixed indices) so the head is correct
    # for any action dimension; built non-mutatingly so Zygote can differentiate it.
    split_head(x) = vcat(x[1:action_size], softplus.(x[action_size+1:end]))

    actor = Chain(
        Dense(state_size, 64, sigmoid),
        Dense(64, 2 * action_size),
        split_head,
    )

    critic = Chain(
        Dense(state_size, 64, sigmoid),
        Dense(64, 1),
    )

    return ActorCritic(actor, critic)
end

ActorCritic

In [5]:
"""
    forward(ac::ActorCritic, state::DroneState)

Evaluate both networks of `ac` on `state`.

Returns `([action_mean, action_std], value)`:
- `[action_mean, action_std]` is a two-element vector holding the actor's mean
  vector and its standard-deviation vector.
- `value` is the first entry of the critic's output.

The actor output is split into two equal halves (means first), so this works
for any action dimension.
"""
function forward(ac::ActorCritic, state::DroneState)
    actor_out = ac.actor(state)
    half = length(actor_out) ÷ 2
    action_mean = actor_out[1:half]
    action_std = actor_out[half+1:end]
    value = ac.critic(state)[1]
    # Nested (not concatenated) so callers can destructure
    # `(action_mean, action_std), value = forward(ac, s)`.
    return [action_mean, action_std], value
end

forward (generic function with 1 method)

In [6]:
"""
    log_prob(ac::ActorCritic, states, actions)

Return a vector whose `i`-th entry is the log-density of `actions[i]` under the
independent-per-dimension Gaussian policy that `ac` produces for `states[i]`.
Each entry is the sum of the per-dimension `Normal` log-densities.

The result is built with `map` rather than by writing into a preallocated array.
This function is differentiated with Zygote (inside `ppo_loss`), and Zygote
cannot differentiate array mutation.

Throws `DimensionMismatch` if `states` and `actions` differ in length.
"""
function log_prob(ac::ActorCritic, states::Vector{DroneState}, actions::Vector{DroneAct})
    length(states) == length(actions) || throw(DimensionMismatch(
        "got $(length(states)) states but $(length(actions)) actions"))

    return map(states, actions) do s, a
        (action_mean, action_std), _ = forward(ac, s)
        sum(logpdf.(Normal.(action_mean, action_std), a))
    end
end


log_prob (generic function with 1 method)

In [7]:
"""
    simulation(m::DroneMDP, s0::DroneState, max_steps=Inf; model=ac)

Roll out the stochastic policy of `model` in `m` from `s0`. The rollout stops
when `isterminal(m, s)` holds or after `max_steps` steps.

Returns `(states, actions, rewards, values)`. For step `i`:
- `states[i]` is the state the action was chosen in.
- `actions[i]` is the action sampled from the policy's Gaussian.
- `rewards[i]` is the reward returned by `gen`.
- `values[i]` is the critic's estimate for `states[i]`.

Keeping each action paired with the state it was taken in is what
`log_prob`/`ppo_loss` require.

`model` defaults to the global `ac` (looked up at call time); pass it explicitly
to avoid relying on that global.
"""
function simulation(m::DroneMDP, s0::DroneState, max_steps=Inf; model::ActorCritic=ac)
    states = DroneState[]
    actions = DroneAct[]
    rewards = Float64[]   # assumes `gen` returns a real-valued reward — confirm against Drone2
    values = Float64[]
    s = s0
    t = 0
    while !isterminal(m, s) && t < max_steps
        # Query the policy at the *current* state.
        (action_mean, action_std), value = forward(model, s)
        a = DroneAct(rand.(Normal.(action_mean, action_std)))
        push!(states, s)
        push!(actions, a)
        push!(values, value)
        s, r = gen(m, s, a)
        push!(rewards, r)
        t += 1
    end
    return states, actions, rewards, values
end;

In [8]:

"""
    compute_advantage(r::Vector, values::Vector, discount::Real, lambda::Real)

Compute Generalized Advantage Estimates for one trajectory.

The TD residuals are `δ[t] = r[t] + discount * values[t+1] - values[t]`, with
`values[T+1]` taken as `0` (no bootstrap after the final step). The returned
advantages are `A[t] = δ[t] + discount * lambda * A[t+1]`, evaluated backwards
from `t = T`. An empty trajectory yields an empty vector.

Throws `DimensionMismatch` if `r` and `values` have different lengths.
"""
function compute_advantage(r::Vector, values::Vector, discount::Real, lambda::Real)
    T = length(r)
    length(values) == T || throw(DimensionMismatch(
        "rewards has length $T but values has length $(length(values))"))

    next_values = [values[2:end]; 0]
    δ = r .+ discount .* next_values .- values
    advantages = similar(δ)
    isempty(δ) && return advantages

    decay = discount * lambda
    # Accumulator starts with δ's element type so the loop stays type-stable.
    running_advantage = zero(eltype(δ))
    for t in T:-1:1
        running_advantage = δ[t] + decay * running_advantage
        advantages[t] = running_advantage
    end
    return advantages
end

compute_advantage (generic function with 1 method)

In [9]:
"""
    ppo_loss(ac, states, actions, rewards, values, discount, lambda, eps_clip;
             old_log_probs = nothing)

Return the clipped PPO loss (to be minimised) for one trajectory:
`-mean(min(A .* ratio, A .* clamp(ratio, 1-eps_clip, 1+eps_clip))) + 0.5 * value_loss`.

- `A` (advantages) comes from `compute_advantage(rewards, values, discount, lambda)`
  and is treated as a constant.
- `old_log_probs` are the log-probabilities of `actions` under the policy that
  collected the data.
  - When omitted, they are computed from the current `ac` with derivatives
    blocked. The ratio is then 1 in value but still carries the policy gradient.
  - Pass them explicitly when running several update epochs on the same batch;
    otherwise clipping can never take effect.
- `value_loss` is the mean squared error between the critic's *current*
  predictions for `states` and the constant targets `A .+ values`. Using the
  current predictions is what makes this term depend on the critic's parameters.
"""
function ppo_loss(ac::ActorCritic, states::Vector{DroneState}, actions::Vector{DroneAct},
                  rewards, values, discount, lambda, eps_clip; old_log_probs=nothing)
    nograd = Flux.ChainRulesCore.ignore_derivatives

    # Data-derived quantities are constants w.r.t. the parameters. Blocking
    # derivatives also keeps Zygote away from compute_advantage's in-place writes.
    advantages = nograd(() -> compute_advantage(rewards, values, discount, lambda))
    value_targets = nograd(() -> advantages .+ values)
    old_lp = isnothing(old_log_probs) ?
        nograd(() -> log_prob(ac, states, actions)) : old_log_probs

    new_lp = log_prob(ac, states, actions)

    ratio = exp.(new_lp .- old_lp)
    clipped_ratio = clamp.(ratio, 1 - eps_clip, 1 + eps_clip)
    surrogate_loss = min.(advantages .* ratio, advantages .* clipped_ratio)

    predicted_values = map(s -> ac.critic(s)[1], states)
    value_loss = mean((predicted_values .- value_targets) .^ 2)

    return -mean(surrogate_loss) + 0.5 * value_loss
end

ppo_loss (generic function with 1 method)

In [10]:
"""
    ppo_update!(ac, optimizer, states, actions, rewards, values, discount, lambda, eps_clip)

Take one optimiser step on the parameters of `ac`, using the gradient of
`ppo_loss` with respect to `Flux.params(ac)`.

Mutates the parameters of `ac` and the state of `optimizer`. Returns the loss
value evaluated before the step.
"""
function ppo_update!(ac::ActorCritic, optimizer, states, actions, rewards, values,
                     discount, lambda, eps_clip)
    ps = params(ac)
    loss, grads = Flux.Zygote.withgradient(ps) do
        ppo_loss(ac, states, actions, rewards, values, discount, lambda, eps_clip)
    end
    # Apply the computed gradients to the tracked parameters.
    Flux.Optimise.update!(optimizer, ps, grads)
    return loss
end


ppo_update! (generic function with 1 method)

In [11]:
m = DroneMDP()
ac = ActorCritic(4, 2)   # state_size = 4, action_size = 2
s0 = DroneState(5, 5, 0.0, false)

# Hyperparameters consumed by ppo_loss / ppo_update!.
discount = 0.99f0   # reward discount factor; required by ppo_loss (TODO: tune)
lambda = 0.99f0     # GAE lambda
eps_clip = 0.2f0    # PPO ratio clipping range

states1, actions1, rewards1, values1 = simulation(m, s0, 1000);