## Target ref = (1,0,0)

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from controllers import DHDeePC as DDeePC
from tqdm import tqdm

import os
import io
import base64
import tempfile
from IPython.display import HTML

from mpc import mpc
from mpc.mpc import GradMethods, QuadCost
import mpc.util as eutil
from mpc.env_dx import pendulum, cartpole

%matplotlib inline
%reload_ext autoreload

# Data collection

In [2]:
params = torch.tensor((10., 1., 1.))
dx = pendulum.PendulumDx(params, simple=True).to(device='mps')
th = torch.tensor([np.pi])
thdot = torch.tensor([0])
xinit = torch.stack((torch.cos(th), torch.sin(th), thdot), dim=1)
x = xinit
u_init = None

ud = []
yd = []
T = 100

t_dir = tempfile.mkdtemp()
print('Tmp dir: {}'.format(t_dir))

for i in range(T):
    u = torch.randn((1,))*0.1
    x = dx(x,u.unsqueeze(0))
    ud = np.append(ud, u)
    yd = np.append(yd, x)

    fig, axs = plt.subplots(1, 1, figsize=(5,5))
    # axs = axs.reshape(-1)
    dx.get_frame(x, ax=axs)
    axs.get_xaxis().set_visible(False)
    axs.get_yaxis().set_visible(False)
    fig.tight_layout()
    fig.savefig(os.path.join(t_dir, '{:03d}.png'.format(i)))
    plt.close(fig)

Tmp dir: /var/folders/bv/3kttr09s6dsg653szk2tbhlh0000gn/T/tmp3ixihd9i


In [3]:
vid_fname = 'pendulum-collection.mp4'

if os.path.exists(vid_fname):
    os.remove(vid_fname)
    
cmd = 'ffmpeg -r 16 -f image2 -i {}/%03d.png -vcodec libx264 -crf 25  -pix_fmt yuv420p {}'.format(
    t_dir, vid_fname
)

os.system(cmd)
print('Saving video to: {}'.format(vid_fname))

video = io.open(vid_fname, 'r+b').read()
encoded = base64.b64encode(video)
HTML(data='''<video alt="test" controls>
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii')))

ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers
  built with clang version 14.0.6
  configuration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1674566267822/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1674566267822/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable

Saving video to: pendulum-collection.mp4


frame=  100 fps=0.0 q=-1.0 Lsize=       8kB time=00:00:06.06 bitrate=  11.4kbits/s speed=28.4x    
video:6kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 30.942654%
[libx264 @ 0x127e36600] frame I:1     Avg QP:11.00  size:   252
[libx264 @ 0x127e36600] frame P:32    Avg QP:22.36  size:    93
[libx264 @ 0x127e36600] frame B:67    Avg QP:26.46  size:    40
[libx264 @ 0x127e36600] consecutive B-frames:  4.0% 16.0% 12.0% 68.0%
[libx264 @ 0x127e36600] mb I  I16..4: 95.6%  3.4%  1.0%
[libx264 @ 0x127e36600] mb P  I16..4:  0.5%  0.1%  0.2%  P16..4:  0.4%  0.2%  0.1%  0.0%  0.0%    skip:98.6%
[libx264 @ 0x127e36600] mb B  I16..4:  0.1%  0.0%  0.0%  B16..8:  0.7%  0.1%  0.0%  direct: 0.0%  skip:99.0%  L0:49.7% L1:47.1% BI: 3.2%
[libx264 @ 0x127e36600] 8x8 transform intra:6.0% inter:57.8%
[libx264 @ 0x127e36600] coded y,uvDC,uvAC intra: 10.5% 0.0% 0.0% inter: 0.1% 0.0% 0.0%
[libx264 @ 0x127e36600] i16 v,h,dc,p: 93%  7%  0%  0%
[libx264 @ 0x127e36600] i8 v,h,dc,ddl

In [4]:
ref = torch.Tensor(
    [[-4.3711e-08,  1.0000e+00,  1.0000e+00],
        [-1.0232e-01,  9.9475e-01,  2.0500e+00],
        [-2.5447e-01,  9.6708e-01,  3.0961e+00],
        [-4.4697e-01,  8.9455e-01,  4.1214e+00],
        [-6.5787e-01,  7.5313e-01,  5.0923e+00],
        [-8.4651e-01,  5.3237e-01,  5.8282e+00],
        [-9.6845e-01,  2.4921e-01,  6.1907e+00],
        [-9.9787e-01, -6.5240e-02,  6.3430e+00],
        [-9.2823e-01, -3.7200e-01,  6.3176e+00],
        [-7.7268e-01, -6.3479e-01,  6.1314e+00],
        [-5.5847e-01, -8.2953e-01,  5.8105e+00],
        [-3.1748e-01, -9.4826e-01,  5.3893e+00],
        [-7.7469e-02, -9.9699e-01,  4.9106e+00],
        [ 1.4233e-01, -9.8982e-01,  4.4071e+00],
        [ 3.3179e-01, -9.4335e-01,  3.9078e+00],
        [ 4.8803e-01, -8.7283e-01,  3.4325e+00],
        [ 6.1270e-01, -7.9032e-01,  2.9928e+00],
        [ 7.0983e-01, -7.0437e-01,  2.5959e+00],
        [ 7.8417e-01, -6.2054e-01,  2.2421e+00],
        [ 8.4033e-01, -5.4207e-01,  1.9306e+00],
        [ 8.8234e-01, -4.7061e-01,  1.6585e+00],
        [ 9.1356e-01, -4.0672e-01,  1.4224e+00],
        [ 9.3662e-01, -3.5034e-01,  1.2185e+00],
        [ 9.4961e-01, -3.1345e-01,  7.8225e-01],
        [ 9.6203e-01, -2.7295e-01,  8.4716e-01],
        [ 9.7201e-01, -2.3492e-01,  7.8642e-01],
        [ 9.7951e-01, -2.0138e-01,  6.8744e-01],
        [ 9.8504e-01, -1.7235e-01,  5.9112e-01],
        [ 9.8908e-01, -1.4738e-01,  5.0590e-01],
        [ 9.9203e-01, -1.2597e-01,  4.3232e-01],
        [ 9.9419e-01, -1.0763e-01,  3.6925e-01],
        [ 9.9576e-01, -9.1943e-02,  3.1533e-01],
        [ 9.9691e-01, -7.8529e-02,  2.6926e-01],
        [ 9.9775e-01, -6.7064e-02,  2.2991e-01],
        [ 9.9836e-01, -5.7268e-02,  1.9630e-01],
        [ 9.9880e-01, -4.8900e-02,  1.6760e-01],
        [ 9.9913e-01, -4.1754e-02,  1.4308e-01],
        [ 9.9936e-01, -3.5651e-02,  1.2215e-01],
        [ 9.9954e-01, -3.0439e-02,  1.0429e-01],
        [ 9.9966e-01, -2.5989e-02,  8.9035e-02],
        [ 9.9975e-01, -2.2189e-02,  7.6027e-02],
        [ 9.9982e-01, -1.8944e-02,  6.4911e-02],
        [ 9.9987e-01, -1.6174e-02,  5.5418e-02],
        [ 9.9990e-01, -1.3809e-02,  4.7302e-02],
        [ 9.9993e-01, -1.1789e-02,  4.0392e-02],
        [ 9.9995e-01, -1.0065e-02,  3.4479e-02],
        [ 9.9996e-01, -8.5938e-03,  2.9435e-02],
        [ 9.9997e-01, -7.3370e-03,  2.5137e-02],
        [ 9.9998e-01, -6.2639e-03,  2.1462e-02],
        [ 9.9999e-01, -5.3478e-03,  1.8323e-02],
        [ 9.9999e-01, -4.5656e-03,  1.5643e-02],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.],
        [1.,0.,0.]]
)

In [5]:
Tini = 1
m = 1
p = 3
Tf = 30

u_constraints = np.ones(Tf)*dx.upper
y_constraints = np.kron(np.ones(Tf), np.array([1,1,20]))
r = torch.kron(torch.ones(Tf), torch.Tensor([0.001]))
q = torch.kron(torch.ones(Tf), torch.Tensor([1,1,0.1]))

controller = DDeePC(
    ud=ud, yd=yd, u_constraints=u_constraints, y_constraints=y_constraints, n_batch=1,
    T=T, N=Tf, m=m, p=p, Tini=Tini, linear=False, stochastic=False, r=r, q=q, lam_g1=torch.Tensor([100]), lam_g2=torch.Tensor([100])
)



In [6]:
params = [param for param in controller.parameters()]
params

[Parameter containing:
 tensor([[ 0.0759,  0.0435,  0.0093,  ..., -0.0143, -0.1849,  0.0445],
         [ 0.0435,  0.0093,  0.1614,  ..., -0.1849,  0.0445,  0.0021],
         [ 0.0093,  0.1614,  0.0460,  ...,  0.0445,  0.0021, -0.1937],
         ...,
         [-0.0247, -0.0116,  0.0126,  ...,  0.0366, -0.0149,  0.0078],
         [-0.0116,  0.0126,  0.0679,  ..., -0.0149,  0.0078, -0.0932],
         [ 0.0126,  0.0679, -0.1162,  ...,  0.0078, -0.0932,  0.0909]],
        requires_grad=True),
 Parameter containing:
 tensor([[-1.0000e+00, -1.0000e+00, -1.0000e+00,  ..., -9.9988e-01,
          -9.9987e-01, -9.9985e-01],
         [-5.6949e-04, -1.4440e-03, -2.3345e-03,  ..., -1.5323e-02,
          -1.6338e-02, -1.7074e-02],
         [ 1.1388e-02,  1.7490e-02,  1.7810e-02,  ...,  5.9526e-02,
           2.0301e-02,  1.4719e-02],
         ...,
         [-9.9993e-01, -9.9994e-01, -9.9995e-01,  ..., -9.9999e-01,
          -9.9997e-01, -9.9995e-01],
         [ 1.1749e-02,  1.0814e-02,  1.0344e-02,  

In [7]:
def save_frame(t_dir, x, i):
    fig, axs = plt.subplots(1, 1, figsize=(5,5))
    # axs = axs.reshape(-1)
    dx.get_frame(x, ax=axs)
    axs.get_xaxis().set_visible(False)
    axs.get_yaxis().set_visible(False)
    fig.tight_layout()
    fig.savefig(os.path.join(t_dir, '{:03d}.png'.format(i)))
    plt.close(fig)

In [8]:
import time
epochs = 100
traj_len = 100
th = torch.tensor([np.pi/2])
thdot = torch.tensor([1])
ref = torch.reshape(ref, (-1,))

opt = optim.Adam(controller.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

t_dir = tempfile.mkdtemp()
print('Tmp dir: {}'.format(t_dir))

for j in range(epochs):

    yini = torch.stack((torch.cos(th), torch.sin(th), thdot), dim=1)
    uini = torch.Tensor([1.])
    traj = yini
    pbar = tqdm(range(traj_len))
    
    for i in pbar:
        input, output , t = controller(ref=ref[p*i:p*(Tf+i)], y_ini=yini, u_ini=uini)
        uini = input[0,:m]
        yini = dx(yini, uini.unsqueeze(0))
        traj = torch.cat((traj, yini), axis=1)
        
        if j > 0 : pbar.set_description(f'loss = {loss.item():.5f}, episode = {j}, time to backprop = {tl:.2f}')
        if j == epochs-1:
            save_frame(t_dir, yini, i)
    
    tl = time.time()
    loss = criterion(target=ref[:traj_len*p], input=traj[0][:traj_len*p])
    opt.zero_grad()
    loss.backward()
    opt.step()
    tl = time.time() - tl

Tmp dir: /var/folders/bv/3kttr09s6dsg653szk2tbhlh0000gn/T/tmphc0ni5td


100%|██████████| 100/100 [20:25<00:00, 12.26s/it]
loss = 1.74716, episode = 1, time to backprop = 7.65: 100%|██████████| 100/100 [00:46<00:00,  2.16it/s]
loss = 3.56840, episode = 2, time to backprop = 18.00: 100%|██████████| 100/100 [00:29<00:00,  3.39it/s]
loss = 1.78794, episode = 3, time to backprop = 17.86: 100%|██████████| 100/100 [00:16<00:00,  5.94it/s]
loss = 1.69619, episode = 4, time to backprop = 20.10: 100%|██████████| 100/100 [00:13<00:00,  7.16it/s]
loss = 1.66724, episode = 5, time to backprop = 21.26: 100%|██████████| 100/100 [00:14<00:00,  6.73it/s]
loss = 1.63396, episode = 6, time to backprop = 23.39: 100%|██████████| 100/100 [00:14<00:00,  6.84it/s]
loss = 1.61152, episode = 7, time to backprop = 23.57: 100%|██████████| 100/100 [00:17<00:00,  5.61it/s]
loss = 1.59208, episode = 8, time to backprop = 25.60: 100%|██████████| 100/100 [00:18<00:00,  5.27it/s]


KeyboardInterrupt: 

In [None]:
params = [param for param in controller.parameters()]
params

[Parameter containing:
 tensor([[-0.4128, -0.0045, -0.3667,  ...,  0.2694,  0.5155,  0.4483],
         [-0.0601,  0.3409,  0.6240,  ..., -0.3258, -0.5235, -0.5920],
         [ 0.6157,  0.5936,  0.4500,  ..., -0.4773, -0.5751, -0.6719],
         ...,
         [ 0.3386,  0.6559,  0.4238,  ..., -0.6103, -0.5911, -0.3220],
         [ 0.2538,  0.3372,  0.5954,  ..., -0.7194, -0.2946, -0.3509],
         [ 0.7486,  0.4213,  0.2401,  ..., -0.2905, -0.5166, -0.7060]],
        requires_grad=True),
 Parameter containing:
 tensor([[-1.6033, -1.6030, -1.5999,  ..., -0.3980, -0.3991, -0.3994],
         [ 0.5862,  0.5321,  0.6140,  ..., -0.5871, -0.5615, -0.6064],
         [ 0.9484,  0.7414,  0.5909,  ..., -0.8734, -0.6218, -0.5908],
         ...,
         [-0.4006, -0.4008, -0.3994,  ..., -1.5994, -1.6003, -1.6006],
         [-0.3770, -0.4464, -0.6447,  ...,  0.4444,  0.5665,  0.5609],
         [-0.3528, -0.5836, -0.6529,  ...,  0.4131,  0.5431,  0.5214]],
        requires_grad=True)]

In [None]:
vid_fname = 'pendulum-notrain.mp4'

if os.path.exists(vid_fname):
    os.remove(vid_fname)
    
cmd = 'ffmpeg -r 16 -f image2 -i {}/%03d.png -vcodec libx264 -crf 25  -pix_fmt yuv420p {}'.format(
    t_dir, vid_fname
)

os.system(cmd)
print('Saving video to: {}'.format(vid_fname))

video = io.open(vid_fname, 'r+b').read()
encoded = base64.b64encode(video)
HTML(data='''<video alt="test" controls>
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii')))

ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers
  built with clang version 14.0.6
  configuration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1674566267822/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1674566267822/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable

Saving video to: pendulum-notrain.mp4


frame=  100 fps=0.0 q=-1.0 Lsize=      19kB time=00:00:06.06 bitrate=  25.5kbits/s speed=31.9x    
video:17kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 11.194715%
[libx264 @ 0x153632d60] frame I:1     Avg QP:16.27  size:   538
[libx264 @ 0x153632d60] frame P:46    Avg QP:22.12  size:   246
[libx264 @ 0x153632d60] frame B:53    Avg QP:26.19  size:    92
[libx264 @ 0x153632d60] consecutive B-frames: 23.0% 16.0%  9.0% 52.0%
[libx264 @ 0x153632d60] mb I  I16..4: 53.7% 45.1%  1.2%
[libx264 @ 0x153632d60] mb P  I16..4:  0.1%  0.5%  0.3%  P16..4:  0.9%  0.7%  0.3%  0.0%  0.0%    skip:97.2%
[libx264 @ 0x153632d60] mb B  I16..4:  0.0%  0.1%  0.0%  B16..8:  1.8%  0.5%  0.0%  direct: 0.0%  skip:97.6%  L0:41.7% L1:43.2% BI:15.1%
[libx264 @ 0x153632d60] 8x8 transform intra:49.0% inter:17.1%
[libx264 @ 0x153632d60] coded y,uvDC,uvAC intra: 6.1% 0.0% 0.0% inter: 0.2% 0.0% 0.0%
[libx264 @ 0x153632d60] i16 v,h,dc,p: 79% 19%  2%  0%
[libx264 @ 0x153632d60] i8 v,h,dc,dd

In [None]:
A = np.array([[1,2],[3,4],[5,6]])
A_ = np.linalg.pinv(A)
A @ A_ @ A

array([[1., 2.],
       [3., 4.],
       [5., 6.]])

: 