In [1]:
import torch

from mpc import mpc
from mpc.mpc import QuadCost, LinDx, GradMethods
from mpc.env_dx import cartpole
import torch
from torch.autograd import Function, Variable
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
import numpy as np
import numpy.random as npr
from mpc import util
import matplotlib.pyplot as plt

import os
import io
import base64
import tempfile
from IPython.display import HTML

from tqdm import tqdm

%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# Scratch check: draw a random 50x4 sample, show the extremes of its second
# column, and report whether any entry of that column has magnitude >= 2.
arr = torch.randn(50, 4)
second_col = arr[:, 1]
print(second_col.max())
print(second_col.min())
if (second_col.abs() >= 2).any():
    print('violated')

tensor(2.3808)
tensor(-2.5610)
violated


In [4]:
class CartpoleDx(nn.Module):
    """Discrete-time cart-pole dynamics and quadratic objective definition.

    The state is the 4-vector ``(x, dx, th, dth)`` and the control is a single
    force value. ``forward`` advances the state by one step of length
    ``self.dt``; ``get_true_obj`` returns the diagonal/linear terms of the
    quadratic cost built from ``goal_state``, ``goal_weights`` and
    ``ctrl_penalty``.

    Args:
        params: optional length-4 tensor ``(gravity, masscart, masspole,
            length)``. Defaults to ``(9.8, 1.0, 0.1, 0.5)``.
    """

    def __init__(self, params=None):
        super().__init__()

        self.n_state = 4
        self.n_ctrl = 1

        # model parameters
        if params is None:
            # gravity, masscart, masspole, length
            self.params = Variable(torch.Tensor((9.8, 1.0, 0.1, 0.5)))
        else:
            self.params = params
        assert len(self.params) == 4
        self.force_mag = 100.

        # Thresholds are stored for reference; nothing in this class reads them.
        self.theta_threshold_radians = np.pi  # previously 12 * 2 * np.pi / 360
        self.x_threshold = 2.4
        self.max_velocity = 10

        # Integration step used by forward().
        self.dt = 0.05

        # Control bounds exposed to the controller.
        self.lower = -self.force_mag
        self.upper = self.force_mag

        # index: 0  1   2   3
        # state: x  dx  th  dth
        self.goal_state = torch.Tensor([0., 0., 0., 0.])
        self.goal_weights = torch.Tensor([100, 100, 100, 100])
        self.ctrl_penalty = 0.01

        # Solver settings read by the calling script.
        self.mpc_eps = 1e-4
        self.linesearch_decay = 0.5
        self.max_linesearch_iter = 2

    def forward(self, state, u):
        """Advance the system by one time step.

        Args:
            state: tensor of shape ``(n_batch, 4)`` or ``(4,)`` holding
                ``(x, dx, th, dth)``.
            u: tensor of shape ``(n_batch, 1)`` or ``(1,)``; it is clamped to
                ``[-force_mag, force_mag]`` before use.

        Returns:
            Tensor of shape ``(n_batch, 4)`` with the next state. A 1-D input
            is promoted to a batch of one and the result keeps that batch
            dimension (it is not squeezed back).
        """
        squeeze = state.ndimension() == 1

        if squeeze:
            state = state.unsqueeze(0)
            u = u.unsqueeze(0)

        # Keep the physical parameters on the same device as the state,
        # whichever device that is (not only CUDA).
        if self.params.device != state.device:
            self.params = self.params.to(state.device)
        gravity, masscart, masspole, length = torch.unbind(self.params)
        total_mass = masspole + masscart
        polemass_length = masspole * length

        u = torch.clamp(u[:, 0], -self.force_mag, self.force_mag)

        x, dx, th, dth = torch.unbind(state, dim=1)
        cos_th, sin_th = torch.cos(th), torch.sin(th)

        cart_in = (u + polemass_length * dth**2 * sin_th) / total_mass
        th_acc = (gravity * sin_th - cos_th * cart_in) / \
                 (length * (4./3. - masspole * cos_th**2 /
                                     total_mass))
        xacc = cart_in - polemass_length * th_acc * cos_th / total_mass

        # Positions are updated with the velocities from the start of the step.
        x = x + self.dt * dx
        dx = dx + self.dt * xacc
        th = th + self.dt * dth
        dth = dth + self.dt * th_acc

        state = torch.stack((
            x, dx, th, dth
        ), 1)

        return state

    def get_frame(self, state, ax=None):
        """Draw the pole for a single state as a line segment.

        Args:
            state: tensor with exactly 4 elements ``(x, dx, th, dth)``.
            ax: optional matplotlib axes to draw on; a new 6x6 figure is
                created when omitted.

        Returns:
            ``(fig, ax)`` for the axes that were drawn on.
        """
        state = util.get_data_maybe(state.view(-1))
        assert len(state) == 4
        x, _, th, _ = torch.unbind(state)
        cos_th, sin_th = torch.cos(th), torch.sin(th)
        _, _, _, length = torch.unbind(self.params)
        # Pole tip offset from the cart position.
        th_x = sin_th*length
        th_y = cos_th*length

        if ax is None:
            fig, ax = plt.subplots(figsize=(6,6))
        else:
            fig = ax.get_figure()
        ax.plot((x,x+th_x), (0, th_y), color='k')
        ax.set_xlim((-length*2, length*2))
        ax.set_ylim((-length*2, length*2))
        return fig, ax

    def get_true_obj(self):
        """Return the quadratic-cost terms ``(q, p)`` over ``(state, control)``.

        ``q`` concatenates ``goal_weights`` with ``ctrl_penalty`` for each
        control dimension; ``p`` concatenates the state linear term with zeros
        for the control.
        """
        q = torch.cat((
            self.goal_weights,
            self.ctrl_penalty*torch.ones(self.n_ctrl)
        ))
        assert not hasattr(self, 'mpc_lin')
        # NOTE(review): the linear term scales the goal by sqrt(goal_weights)
        # rather than goal_weights; this is immaterial while goal_state is all
        # zeros -- confirm the intended form before using a non-zero goal.
        px = -torch.sqrt(self.goal_weights)*self.goal_state #+ self.mpc_lin
        p = torch.cat((px, torch.zeros(self.n_ctrl)))
        return Variable(q), Variable(p)

In [5]:
# Dynamics model used both by the controller and to simulate the plant.
dx = CartpoleDx().to(device='mps')

# Batch size, number of closed-loop steps, and MPC planning horizon.
n_batch = 1
T = 70
mpc_T = 20

def uniform(shape, low, high):
    """Draw a tensor of the given shape with entries spread uniformly between
    ``low`` and ``high`` (a rescaled ``torch.rand`` sample)."""
    span = high - low
    return low + span * torch.rand(shape)

# Seed the RNG so the initial angle and the action noise -- and therefore the
# trajectories exported later -- are reproducible across kernel restarts.
torch.manual_seed(0)

# Initial state (x, xdot, th, thdot): everything zero except a small random
# angle. Sized by n_batch so the cell keeps working if the batch size changes.
th = uniform(n_batch, -0.01, 0.01)
thdot = torch.zeros(n_batch)
x = torch.zeros(n_batch)
xdot = torch.zeros(n_batch)
xinit = torch.stack((x, xdot, th, thdot), dim=1)
x = xinit
u_init = None

# Quadratic cost terms, tiled over the planning horizon and the batch.
q, p = dx.get_true_obj()
Q = torch.diag(q).unsqueeze(0).unsqueeze(0).repeat(
    mpc_T, n_batch, 1, 1
)
p = p.unsqueeze(0).repeat(mpc_T, n_batch, 1)

# Per-step frames are written here as PNG files.
t_dir = tempfile.mkdtemp()
print('Tmp dir: {}'.format(t_dir))

controller = mpc.MPC(
        dx.n_state, dx.n_ctrl, mpc_T,
        u_init=u_init,
        u_lower=dx.lower, u_upper=dx.upper,
        lqr_iter=50,
        verbose=0,
        exit_unconverged=False,
        detach_unconverged=False,
        backprop=False,
        linesearch_decay=dx.linesearch_decay,
        max_linesearch_iter=dx.max_linesearch_iter,
        grad_method=GradMethods.AUTO_DIFF,
        eps=1e-2,
    ).to(device='mps')

print(xinit.shape)

# One axes per frame.
n_row, n_col = 1, 1

action_history = []
state_history = []
for i in tqdm(range(T)):
    nominal_states, nominal_actions, nominal_objs = controller(x, QuadCost(Q, p), dx)

    # Apply the first planned action with a small amount of noise.
    next_action = nominal_actions[0] + torch.randn(1,)*0.001
    action_history.append(next_action)

    # Warm-start the next solve with the shifted plan. The controller is
    # built once outside the loop, so the new initial guess has to be
    # assigned to it explicitly; otherwise it is computed and never used.
    u_init = torch.cat(
        (nominal_actions[1:].detach(), torch.zeros(1, n_batch, dx.n_ctrl)),
        dim=0,
    )
    u_init[-2] = u_init[-3]
    controller.u_init = u_init

    # Simulate the plant one step with the applied action.
    x = dx(x, next_action)
    state_history.append(x)

    fig, axs = plt.subplots(n_row, n_col, figsize=(3*n_col,3*n_row))

    for r in range(n_batch):
        dx.get_frame(x[r], ax=axs)
        axs.get_xaxis().set_visible(False)
        axs.get_yaxis().set_visible(False)
    fig.tight_layout()
    fig.savefig(os.path.join(t_dir, '{:03d}.png'.format(i)))
    # Close each figure so T open figures do not accumulate in memory.
    plt.close(fig)

# Stack to (T, n_batch) actions and (T, n_batch, n_state) states.
action_history = torch.stack(action_history).detach()[:,:,0]
state_history = torch.stack(state_history).detach()

Tmp dir: /tmp/tmprk5zey5f
torch.Size([1, 4])


100%|██████████| 70/70 [00:04<00:00, 14.56it/s]


In [6]:
# Drop the batch axis (dimension 1) and convert the histories to NumPy arrays.
# The cell overwrites its own inputs, so each conversion is guarded: on a
# re-run the names already hold arrays and are left untouched instead of
# raising.
if torch.is_tensor(state_history):
    state_history = state_history.squeeze(1).detach().cpu().numpy()
if torch.is_tensor(action_history):
    action_history = action_history.squeeze(1).detach().cpu().numpy()
action_history.shape

(70,)

In [7]:
from numpy import savetxt

# Write the state and action histories as comma-separated files one
# directory above the notebook.
for csv_path, history in (
    ('../cartpole_yd.csv', state_history),
    ('../cartpole_ud.csv', action_history),
):
    savetxt(csv_path, history, delimiter=',')