Skip to content

Commit

Permalink
update to the latest PyTorch version
Browse files Browse the repository at this point in the history
  • Loading branch information
bamos committed Jul 17, 2022
1 parent 1b73ca9 commit bf19cd4
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
17 changes: 10 additions & 7 deletions mpc/lqr_step.py
Expand Up @@ -157,7 +157,7 @@ def lqr_backward(ctx, C, c, F, f):
Kt_T.bmm(qt_u.unsqueeze(2)).squeeze(2) + \
Kt_T.bmm(Qt_uu).bmm(kt.unsqueeze(2)).squeeze(2)

return Ks, ks, LqrBackOut(n_total_qp_iter=n_total_qp_iter)
return Ks, ks, n_total_qp_iter


# @profile
Expand Down Expand Up @@ -210,7 +210,7 @@ def lqr_forward(ctx, x_init, C, c, F, f, Ks, ks):
I = ub > ub_limit
ub[I] = ub_limit if isinstance(lb_limit, float) else ub_limit[I]
# TODO(eugenevinitsky) why do we need to do this here?
new_ut = util.eclamp(new_ut, lb.double(), ub.double())
new_ut = util.eclamp(new_ut, lb, ub)
new_u.append(new_ut)

new_xut = torch.cat((new_xt, new_ut), dim=1)
Expand Down Expand Up @@ -273,12 +273,13 @@ def get_bound(side, t):

class LQRStepFn(Function):
# @profile
@staticmethod
def forward(ctx, x_init, C, c, F, f=None):
if no_op_forward:
ctx.save_for_backward(
x_init, C, c, F, f, current_x, current_u)
ctx.current_x, ctx.current_u = current_x, current_u
return current_x, current_u, None, None
return current_x, current_u

if delta_space:
# Taylor-expand the objective to do the backward pass in
Expand All @@ -295,17 +296,19 @@ def forward(ctx, x_init, C, c, F, f=None):
f_back = None
else:
assert False

ctx.current_x = current_x
ctx.current_u = current_u

Ks, ks, back_out = lqr_backward(ctx, C, c_back, F, f_back)
Ks, ks, n_total_qp_iter = lqr_backward(ctx, C, c_back, F, f_back)
new_x, new_u, for_out = lqr_forward(ctx,
x_init, C, c, F, f, Ks, ks)
ctx.save_for_backward(x_init, C, c, F, f, new_x, new_u)

return new_x, new_u, back_out, for_out
return new_x, new_u, torch.Tensor([n_total_qp_iter]), \
for_out.costs, for_out.full_du_norm, for_out.mean_alphas

@staticmethod
def backward(ctx, dl_dx, dl_du, temp=None, temp2=None):
start = time.time()
x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors
Expand Down Expand Up @@ -403,4 +406,4 @@ def backward(ctx, dl_dx, dl_du, temp=None, temp2=None):
backward_time = time.time()-start
return dx_init, dC, dc, dF, df

return LQRStepFn.apply
return LQRStepFn.apply
38 changes: 17 additions & 21 deletions mpc/mpc.py
Expand Up @@ -261,9 +261,8 @@ def forward(self, x_init, cost, dx):
C, c, _ = self.approximate_cost(
x, util.detach_maybe(u), cost, diff=False)

x, u, _lqr, back_out, for_out = self.solve_lqr_subproblem(
x_init, C, c, F, f, cost, dx, x, u)
# back_out, for_out = _lqr.back_out, _lqr.for_out
x, u, n_total_qp_iter, costs, full_du_norm, mean_alphas = \
self.solve_lqr_subproblem(x_init, C, c, F, f, cost, dx, x, u)
n_not_improved += 1

assert x.ndimension() == 3
Expand All @@ -273,31 +272,31 @@ def forward(self, x_init, cost, dx):
best = {
'x': list(torch.split(x, split_size_or_sections=1, dim=1)),
'u': list(torch.split(u, split_size_or_sections=1, dim=1)),
'costs': for_out.costs,
'full_du_norm': for_out.full_du_norm,
'costs': costs,
'full_du_norm': full_du_norm,
}
else:
for j in range(n_batch):
if for_out.costs[j] <= best['costs'][j] + self.best_cost_eps:
if costs[j] <= best['costs'][j] + self.best_cost_eps:
n_not_improved = 0
best['x'][j] = x[:,j].unsqueeze(1)
best['u'][j] = u[:,j].unsqueeze(1)
best['costs'][j] = for_out.costs[j]
best['full_du_norm'][j] = for_out.full_du_norm[j]
best['costs'][j] = costs[j]
best['full_du_norm'][j] = full_du_norm[j]

if self.verbose > 0:
util.table_log('lqr', (
('iter', i),
('mean(cost)', torch.mean(best['costs']).item(), '{:.4e}'),
('||full_du||_max', max(for_out.full_du_norm).item(), '{:.2e}'),
# ('||alpha_du||_max', max(for_out.alpha_du_norm), '{:.2e}'),
('||full_du||_max', max(full_du_norm).item(), '{:.2e}'),
# ('||alpha_du||_max', max(alpha_du_norm), '{:.2e}'),
# TODO: alphas and total_qp_iters here are for the current
# iterate, not the best
('mean(alphas)', for_out.mean_alphas.item(), '{:.2e}'),
('total_qp_iters', back_out.n_total_qp_iter),
('mean(alphas)', mean_alphas.item(), '{:.2e}'),
('total_qp_iters', n_total_qp_iter),
))

if max(for_out.full_du_norm) < self.eps or \
if max(full_du_norm) < self.eps or \
n_not_improved > self.not_improved_lim:
break

Expand All @@ -316,7 +315,7 @@ def forward(self, x_init, cost, dx):
else:
C, c, _ = self.approximate_cost(x, u, cost, diff=True)

x, u, _, _, _ = self.solve_lqr_subproblem(
x, u = self.solve_lqr_subproblem(
x_init, C, c, F, f, cost, dx, x, u, no_op_forward=True)

if self.detach_unconverged:
Expand All @@ -328,7 +327,7 @@ def forward(self, x_init, cost, dx):
print("LQR Warning: All examples did not converge to a fixed point.")
print("Detaching and *not* backpropping through the bad examples.")

I = for_out.full_du_norm < self.eps
I = full_du_norm < self.eps
Ix = Variable(I.unsqueeze(0).unsqueeze(2).expand_as(x)).type_as(x.data)
Iu = Variable(I.unsqueeze(0).unsqueeze(2).expand_as(u)).type_as(u.data)
x = x*Ix + x.clone().detach()*(1.-Ix)
Expand Down Expand Up @@ -359,9 +358,7 @@ def solve_lqr_subproblem(self, x_init, C, c, F, f, cost, dynamics, x, u,
no_op_forward=no_op_forward,
)
e = Variable(torch.Tensor())
x, u, back_out, for_out = _lqr(x_init, C, c, F, f if f is not None else e)

return x, u, _lqr, back_out, for_out
return _lqr(x_init, C, c, F, f if f is not None else e)
else:
nsc = self.n_state + self.n_ctrl
_n_state = nsc
Expand Down Expand Up @@ -443,10 +440,9 @@ def solve_lqr_subproblem(self, x_init, C, c, F, f, cost, dynamics, x, u,
back_eps=self.back_eps,
no_op_forward=no_op_forward,
)
x, u, back_out, for_out = _lqr(_x_init, _C, _c, _F, _f)
x, *rest = _lqr(_x_init, _C, _c, _F, _f)
x = x[:,:,self.n_ctrl:]

return x, u, _lqr, back_out, for_out
return [x] + rest

def approximate_cost(self, x, u, Cf, diff=True):
with torch.enable_grad():
Expand Down
3 changes: 2 additions & 1 deletion tests/test_mpc.py
Expand Up @@ -946,12 +946,13 @@ def test_memory():
test_lqr_linear_unbounded()
test_lqr_linear_bounded()
test_lqr_linear_bounded_delta()
# test_lqr_cuda_singleton()
test_lqr_backward_cost_linear_dynamics_unconstrained()
test_lqr_backward_cost_linear_dynamics_constrained()
test_lqr_backward_cost_affine_dynamics_module_constrained()
test_lqr_backward_cost_nn_dynamics_module_constrained()
test_lqr_backward_cost_nn_dynamics_module_constrained_slew()
test_lqr_linearization()
test_lqr_slew_rate()

# test_lqr_cuda_singleton()
# test_memory()

0 comments on commit bf19cd4

Please sign in to comment.