In [9]:
import torch
from torch.autograd import Function, Variable
from torch.nn.parameter import Parameter
import torch.optim as optim

import numpy as np
import numpy.random as npr

from mpc import mpc
from mpc.mpc import GradMethods, QuadCost, LinDx

import sys
from IPython.core import ultratb

import time
import os
import shutil
import pickle as pkl
import collections
from mpc.dynamics import AffineDynamics
import argparse
import setproctitle
from tqdm import tqdm

In [24]:
n_batch, n_state, n_ctrl, T = 12, 3, 3, 10
n_sc = n_state + n_ctrl
device = 'cpu'
u_lower = torch.tensor([-0.5,-0.5,-0.5], dtype=torch.float32)
u_lower = u_lower.repeat(T, n_batch, 1)
u_upper = torch.tensor([0.5,0.5,0.5], dtype=torch.float32)
u_upper = u_upper.repeat(T, n_batch, 1)

In [29]:
goal_state = torch.Tensor([2,1,-1])
goal_weights = torch.ones(n_state)*10
px = -(goal_weights)*goal_state
p = torch.cat((px, torch.zeros(n_ctrl)))
p = p.unsqueeze(0).repeat(T, n_batch, 1)

ctr_penalty = 0.1
q = torch.cat([goal_weights, torch.ones(n_ctrl)*ctr_penalty]).to(device)
print(q)
Q = torch.diag(q).unsqueeze(0).unsqueeze(0).repeat(
        T, n_batch, 1, 1
).to(device)
A = torch.tensor([[1.01, 0.01, 0],[0.01, 1.01, 0.01],[0, 0.01, 1.01]]).to(device)
B = torch.eye(3).to(device)
weight_est, ctrl_est = Parameter(torch.randn(size=(1,))+1), Parameter(torch.randn(size=(1,))+1)

tensor([10.0000, 10.0000, 10.0000,  0.1000,  0.1000,  0.1000])


In [30]:
print(q)
print(torch.cat([torch.ones(n_state)*weight_est, torch.ones(n_ctrl)*ctrl_est]))

tensor([10.0000, 10.0000, 10.0000,  0.1000,  0.1000,  0.1000])
tensor([2.8567, 2.8567, 2.8567, 0.3694, 0.3694, 0.3694],
       grad_fn=<CatBackward0>)


In [31]:
def get_loss(x_init : torch.Tensor, q_est : torch.Tensor, r_est : torch.Tensor) -> torch.Tensor:

        # Expert 
        x_true, u_true, objs_true = mpc.MPC(
            n_state, n_ctrl, T,
            u_lower=u_lower, u_upper=u_upper, 
            lqr_iter=100,
            verbose=-1,
            exit_unconverged=False,
            detach_unconverged=False,
            backprop=False,
            n_batch=n_batch,
        )(x_init, QuadCost(Q, p), AffineDynamics(A=A, B=B))

        # Learner

        # Construct cost matrices from ctrl and state penalty
        # Weights and penalties are identical for each state so 
        # We only need to optimize over two scalar variables "weight_est" and "ctrl_est"

        q = torch.cat([torch.ones(n_state)*q_est, torch.ones(n_ctrl)*r_est])
        Q_est = torch.diag(q).unsqueeze(0).unsqueeze(0).repeat(
                T, n_batch, 1, 1
        ).to(device)
        px = -(torch.ones(n_state)*q_est)*goal_state
        p_est = torch.cat((px, torch.zeros(n_ctrl)))
        p_est = p_est.unsqueeze(0).repeat(T, n_batch, 1)    

        # Roll out MPC with estimated cost function
        x_pred, u_pred, objs_pred = mpc.MPC(
            n_state, n_ctrl, T,
            u_lower=u_lower, u_upper=u_upper, 
            lqr_iter=100,
            verbose=-1,
            backprop=False,
            exit_unconverged=False,
            detach_unconverged=False,
            n_batch=n_batch,
        )(x_init, QuadCost(Q_est, p_est), AffineDynamics(A=A, B=B))

        # Get MSE of trajectory
        traj_loss = torch.mean((u_true - u_pred)**2) + torch.mean((x_true - x_pred)**2)

        return traj_loss

In [32]:
opt = optim.RMSprop((weight_est, ctrl_est), lr=1e-2)
model_loss = []
traj_loss = []
for i in tqdm(range(100)):
    x_init = torch.randn(n_batch,n_state)

    loss = get_loss(x_init, weight_est, ctrl_est)
    traj_loss.append(loss)
    opt.zero_grad()
    loss.backward()
    opt.step()

    model_loss.append((10 - weight_est)**2 + (1 - ctrl_est)**2)


100%|██████████| 100/100 [00:21<00:00,  4.57it/s]


In [33]:
traj_loss

[tensor(0.0002, grad_fn=<AddBackward0>),
 tensor(0.0001, grad_fn=<AddBackward0>),
 tensor(6.9959e-05, grad_fn=<AddBackward0>),
 tensor(3.8982e-05, grad_fn=<AddBackward0>),
 tensor(2.3731e-05, grad_fn=<AddBackward0>),
 tensor(1.7700e-05, grad_fn=<AddBackward0>),
 tensor(9.4328e-06, grad_fn=<AddBackward0>),
 tensor(4.9492e-06, grad_fn=<AddBackward0>),
 tensor(2.4493e-06, grad_fn=<AddBackward0>),
 tensor(1.5519e-06, grad_fn=<AddBackward0>),
 tensor(9.9384e-07, grad_fn=<AddBackward0>),
 tensor(5.1008e-07, grad_fn=<AddBackward0>),
 tensor(3.9073e-07, grad_fn=<AddBackward0>),
 tensor(2.4036e-07, grad_fn=<AddBackward0>),
 tensor(1.2575e-07, grad_fn=<AddBackward0>),
 tensor(4.8172e-08, grad_fn=<AddBackward0>),
 tensor(2.3396e-08, grad_fn=<AddBackward0>),
 tensor(1.2757e-08, grad_fn=<AddBackward0>),
 tensor(7.4299e-09, grad_fn=<AddBackward0>),
 tensor(3.1111e-09, grad_fn=<AddBackward0>),
 tensor(7.3259e-10, grad_fn=<AddBackward0>),
 tensor(7.9785e-10, grad_fn=<AddBackward0>),
 tensor(5.2330e-10

In [34]:
ctrl_est

Parameter containing:
tensor([0.0308], requires_grad=True)

In [35]:
weight_est

Parameter containing:
tensor([3.0772], requires_grad=True)