In [60]:
import torch
from torch.nn.parameter import Parameter
import torch.optim as optim

import numpy as np

from mpc import mpc
from mpc.mpc import QuadCost

from IPython.core import ultratb

from mpc.dynamics import AffineDynamics
from tqdm import tqdm

## Parameter Initialisation

In [61]:
n_batch, n_state, n_ctrl, T = 24, 3, 3, 10
n_sc = n_state + n_ctrl
device = 'cpu'
u_lower = torch.tensor([-0.5,-0.5,-0.5], dtype=torch.float32)
u_lower = u_lower.repeat(T, n_batch, 1)
u_upper = torch.tensor([0.5,0.5,0.5], dtype=torch.float32)
u_upper = u_upper.repeat(T, n_batch, 1)

In [62]:
goal_state = torch.Tensor([2,1,-1])
goal_weights = torch.ones(n_state)*10
px = -(goal_weights)*goal_state
p = torch.cat((px, torch.zeros(n_ctrl)))
p = p.unsqueeze(0).repeat(T, n_batch, 1)

ctr_penalty = 0.1
q = torch.cat([goal_weights, torch.ones(n_ctrl)*ctr_penalty]).to(device)
print(q)
Q = torch.diag(q).unsqueeze(0).unsqueeze(0).repeat(
        T, n_batch, 1, 1
).to(device)
A = torch.tensor([[1.01, 0.01, 0],[0.01, 1.01, 0.01],[0, 0.01, 1.01]]).to(device)
B = torch.eye(3).to(device)

# Initialise Parameters
weight_est, ctrl_est = Parameter(torch.randn(size=(3,))*0.1+1), Parameter(torch.randn(size=(3,))*0.1+1)

tensor([10.0000, 10.0000, 10.0000,  0.1000,  0.1000,  0.1000])


In [63]:
print(weight_est.sum().item())
print(weight_est)

2.8732266426086426
Parameter containing:
tensor([1.0178, 0.8498, 1.0057], requires_grad=True)


In [64]:
print(q)
print(torch.cat([weight_est, ctrl_est]))

tensor([10.0000, 10.0000, 10.0000,  0.1000,  0.1000,  0.1000])
tensor([1.0178, 0.8498, 1.0057, 0.9706, 0.9906, 0.9783],
       grad_fn=<CatBackward0>)


## Loss Function Definition

In [94]:
def get_loss(x_init : torch.Tensor, q_est : torch.Tensor, r_est : torch.Tensor) -> torch.Tensor:

        # Expert 
        x_true, u_true, objs_true = mpc.MPC(
            n_state, n_ctrl, T,
            u_lower=u_lower, u_upper=u_upper, 
            lqr_iter=100,
            verbose=-1,
            exit_unconverged=False,
            detach_unconverged=False,
            backprop=False,
            n_batch=n_batch,
        )(x_init, QuadCost(Q, p), AffineDynamics(A=A, B=B))

        # Learner

        # Construct cost matrices from ctrl and state penalty
        # Weights and penalties are identical for each state so 
        # We only need to optimize over two scalar variables "weight_est" and "ctrl_est"

        q = torch.cat([q_est, r_est])
        Q_est = torch.diag(q).unsqueeze(0).unsqueeze(0).repeat(
                T, n_batch, 1, 1
        ).to(device)
        px = -(q_est)*goal_state
        p_est = torch.cat((px, torch.zeros(n_ctrl)))
        p_est = p_est.unsqueeze(0).repeat(T, n_batch, 1)    

        # Roll out MPC with estimated cost function
        x_pred, u_pred, objs_pred = mpc.MPC(
            n_state, n_ctrl, T,
            u_lower=u_lower, u_upper=u_upper, 
            lqr_iter=100,
            verbose=-1,
            backprop=False,
            exit_unconverged=False,
            detach_unconverged=False,
            n_batch=n_batch,
        )(x_init, QuadCost(Q_est, p_est), AffineDynamics(A=A, B=B))

        # Get MSE of trajectory
        criterion = torch.nn.MSELoss()
        traj_loss = criterion(input=u_pred, target=u_true)
        # traj_loss = torch.mean((u_true - u_pred)**2) #+ torch.mean((x_true - x_pred)**2)
        return traj_loss

## Training

In [95]:
opt = optim.RMSprop((weight_est, ctrl_est), lr=1e-2)
pbar = tqdm(range(50), ncols=120)

for i in pbar:
    x_init = torch.randn(n_batch,n_state)

    loss = get_loss(x_init, weight_est, ctrl_est)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Used to checj the difference in ratio of ctrl cost and state cost
    model_loss = np.abs(100 - weight_est.sum().item() / ctrl_est.sum().item())

    pbar.set_description(f'Loss = {loss.item():.10f}, Model Loss = {model_loss:.2f}')


  0%|                                                                                            | 0/50 [00:00<?, ?it/s]


TypeError: LQRStep.<locals>.LQRStepFn.backward() takes from 3 to 5 positional arguments but 7 were given

In [83]:
loss

tensor(2609.1292, grad_fn=<MeanBackward0>)

In [67]:
ctrl_est

Parameter containing:
tensor([0.0149, 0.0153, 0.0151], requires_grad=True)

In [68]:
weight_est

Parameter containing:
tensor([1.5018, 1.4442, 1.5167], requires_grad=True)