Skip to content

Commit

Permalink
fix kl tests
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Jan 30, 2019
1 parent 997ec22 commit 7a7f192
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 20 deletions.
38 changes: 24 additions & 14 deletions src/demo_linear.jl
Expand Up @@ -80,16 +80,16 @@ function demo_linear_kl(;kwargs...)
lims = [] #ones(m,1)*[-1 1]*.6

T = 1000 # horizon
x = ones(n) # initial state
x0 = ones(n) # initial state
u = .1*randn(m,T) # initial controls

# optimization problem
N = T+1
fx = A
fu = B
cxx = Q
cxu = zeros(size(B))
cuu = R
fx = repeat(A,1,1,T)
fu = repeat(B,1,1,T)
cxx = repeat(Q,1,1,T)
cxu = repeat(zeros(size(B)),1,1,T)
cuu = repeat(R,1,1,T)
function lin_dyn_df(x,u,Q,R)
u[isnan.(u)] .= 0
cx = Q*x
Expand All @@ -102,25 +102,35 @@ function demo_linear_kl(;kwargs...)
xnew = A*x + B*u
return xnew
end
lin_dyn_fT(x,Q) = 0.5*sum(x.*(Q*x))
f(x,u,i) = lin_dyn_f(x,u,A,B,Q,R)
costfun(x,u) = 0.5*(sum(x.*(Q*x),1) + sum(u.*(R*u),1))[:]
df(x,u) = lin_dyn_df(x,u,Q,R)
dyn = (x,u,i) -> lin_dyn_f(x,u,A,B,Q,R)
costf = (x,u) -> 0.5*(sum(x.*(Q*x),dims=1) + sum(u.*(R*u),dims=1))[:]
diffdyn = (x,u) -> lin_dyn_df(x,u,Q,R)

# Simulate the state trajectory obtained by applying the control sequence `u`
# (one column per time step) from the initial state `x0`.
#
# Returns an `n × T` matrix whose first column is `x0` and whose column `t+1` is
# `dyn(column t, u[:, t], t)`. Only the first `T-1` columns of `u` are used.
#
# The trajectory is stored under its own local name rather than `x`: the
# enclosing function also has a local `x`, and assigning to `x` here would make
# this closure capture and overwrite that outer variable (forcing it to be boxed).
function rollout(u)
    xtraj = zeros(n, T)
    xtraj[:, 1] = x0
    for t in 1:T-1
        xtraj[:, t+1] = dyn(xtraj[:, t], u[:, t], t)
    end
    return xtraj
end
x = rollout(u)
model = LinearTimeVaryingModelsBase.SimpleLTVModel(repeat(A,1,1,N),repeat(B,1,1,N),false)
# plotFn(x) = plot(squeeze(x,2)')
traj = GaussianPolicy(Float64,T,n,m)
# run the optimization
local Vx, Vxx, cost, otrace, totalcost
outercosts = zeros(5)
@time for iter = 1:5
cost0 = 0.5*sum(x.*(Q*x)) + 0.5*sum(u.*(R*u))
x, u, traj, Vx, Vxx, cost, otrace = iLQGkl(f,costfun,df, x, u, traj, model; cost=cost0, lims=lims,kwargs...);
x, u, traj, Vx, Vxx, cost, otrace = iLQGkl(dyn,costf,diffdyn, x, traj, model; cost=cost0, lims=lims,kwargs...);
totalcost = get(otrace, :cost)[2]
outercosts[iter] = sum(totalcost)
println("Outer loop: Cost = ", sum(cost))
end



plotstuff_linear(x,u,totalcost,outercosts)
totalcost = get(otrace, :cost)[2]
plotstuff_linear(x,u,[cost],min.(totalcost,400))
# plotstuff_linear(x,u,totalcost,outercosts)
x, u, traj, Vx, Vxx, cost, otrace
end
2 changes: 1 addition & 1 deletion src/iLQGkl.jl
Expand Up @@ -24,7 +24,7 @@ To solve the maximum entropy problem, use controller `controller(xi,i) = u[:,i]
"""
function iLQGkl(dynamics,costfun,derivs, x0, traj_prev, model;
constrain_per_step = false,
kl_step = 0,
kl_step = 1,
lims = [],
tol_fun = 1e-7,
tol_grad = 1e-4,
Expand Down
8 changes: 3 additions & 5 deletions src/klutils.jl
Expand Up @@ -101,15 +101,14 @@ end



"""
    entropy(traj::GaussianPolicy)

Return the differential entropy of the Gaussian policy `traj`, averaged over its
`traj.T` time steps.

For each step `t` the slice `traj.Σ[:,:,t]` is used as the covariance of an
`traj.m`-dimensional Gaussian, whose entropy is
`logdet(Σ)/2 + m*log(2π*ℯ)/2`. The constant term must include Euler's number `ℯ`;
omitting it underestimates the entropy by `m/2`.
"""
function entropy(traj::GaussianPolicy)
    # Time-average of the covariance-dependent part of the entropy.
    mean_half_logdet = mean(logdet(traj.Σ[:, :, t]) / 2 for t in 1:traj.T)
    # Dimension-dependent constant of the Gaussian differential entropy.
    return mean_half_logdet + traj.m * log(2π * ℯ) / 2
end

"""
new_η, satisfied, divergence = calc_η(xnew,xold,sigmanew,η, traj_new, traj_prev, kl_step)
This function calculates the step size.
"""
function calc_η(xnew,xold,sigmanew,ηbracket, traj_new, traj_prev, kl_step::Number)
kl_step > 0 || (return (1., true,0))

kl_step > 0 || (return (ηbracket, true,0))
divergence = kl_div_wiki(xnew,xold,sigmanew, traj_new, traj_prev) |> mean
constraint_violation = divergence - kl_step
# Convergence check - constraint satisfaction.
Expand All @@ -131,8 +130,7 @@ function calc_η(xnew,xold,sigmanew,ηbracket, traj_new, traj_prev, kl_step::Num
end

function calc_η(xnew,xold,sigmanew,ηbracket, traj_new, traj_prev, kl_step::AbstractVector)
any(kl_step .> 0) || (return (1., true,0))

any(kl_step .> 0) || (return (ηbracket, true,0))
divergence = kl_div_wiki(xnew,xold,sigmanew, traj_new, traj_prev)
if !isa(kl_step,AbstractVector)
divergence = mean(divergence)
Expand Down

0 comments on commit 7a7f192

Please sign in to comment.