Skip to content

Commit

Permalink
save some allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Nov 28, 2020
1 parent efc2045 commit e038b5a
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions src/backward_pass.jl
Expand Up @@ -2,8 +2,8 @@ choleskyvectens(a,b) = permutedims(sum(a.*b,1), [3 2 1])

macro setupQTIC()
quote
m = size(u,1)
n,_,N = size(fx)
m = size(u,1)
n,_,N = size(fx)

@assert size(cx) == (n, N) "size(cx) should be (n, N)"
@assert size(cu) == (m, N) "size(cu) should be (m, N)"
Expand All @@ -18,15 +18,15 @@ macro setupQTIC()
Quui = Array{T}(undef,m,m,N)
dV = [0., 0.]

Vx[:,N] = cx[:,N]
@views Vx[:,N] .= cx[:,N]
Vxx[:,:,N] = cxx
Quu[:,:,N] = cuu
diverge = 0
end |> esc
end

macro end_backward_pass()
quote
@views quote
QuF = Qu
if isempty(lims) || lims[1,1] > lims[1,2]
# debug("# no control limits: Cholesky decomposition, check for non-PD")
Expand Down Expand Up @@ -61,14 +61,19 @@ macro end_backward_pass()
end
end
# debug("# update cost-to-go approximation")
dV = dV + [k_i'Qu; .5*k_i'Quu[:,:,i]*k_i]
Vx[:,i] = Qx + K_i'Quu[:,:,i]*k_i + K_i'Qu + Qux'k_i
Vxx[:,:,i] = Qxx + K_i'Quu[:,:,i]*K_i + K_i'Qux + Qux'K_i
Vxx[:,:,i] = .5*(Vxx[:,:,i] + Vxx[:,:,i]')
Quuki = Quu[:,:,i]*k_i
kiQuuki = dot(k_i, Quuki)
KiQuuki = (K_i'Quuki)
KiQuuKi = K_i'Quu[:,:,i]*K_i
dV .= dV .+ [dot(k_i,Qu); .5*kiQuuki]
Vx[:,i:i] .= Qx .+ KiQuuki .+ K_i'Qu .+ Qux'k_i
Vxx[:,:,i] .= Qxx .+ KiQuuKi .+ K_i'Qux .+ Qux'K_i
Vxx[:,:,i] .+= Vxx[:,:,i]'
Vxx[:,:,i] ./= 2

# debug("# save controls/gains")
k[:,i] = k_i
K[:,:,i] = K_i
k[:,i] .= vec(k_i)
K[:,:,i] .= K_i

end |> esc
end
Expand Down Expand Up @@ -209,7 +214,7 @@ function back_pass(cx,cu,cxx::AbstractArray{T,3},cxu,cuu,fx::AbstractArray{T,3},
return diverge, GaussianPolicy(N,n,m,K,k,Quui,Quu), Vx, Vxx,dV
end

function back_pass(cx,cu,cxx::AbstractArray{T,2},cxu,cuu,fx::AbstractMatrix{T},fu,λ,regType,lims,x,u) where T # cost quadratic and cost and LTI dynamics
@views function back_pass(cx,cu,cxx::AbstractArray{T,2},cxu,cuu,fx::AbstractMatrix{T},fu,λ,regType,lims,x,u) where T # cost quadratic and cost and LTI dynamics

m,N = size(u)
n = size(fx,1)
Expand All @@ -226,7 +231,7 @@ function back_pass(cx,cu,cxx::AbstractArray{T,2},cxu,cuu,fx::AbstractMatrix{T},f
Quui = Array{T}(undef,m,m,N)
dV = [0., 0.]

Vx[:,N] = cx[:,N]
Vx[:,N] = cx[:,end]
Vxx[:,:,N] = cxx
Quu[:,:,N] = cuu

Expand Down

0 comments on commit e038b5a

Please sign in to comment.