In [1]:
%load_ext autoreload
# %reload_ext autoreload
%autoreload 2
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import os
import io
import base64
import tempfile
from IPython.display import HTML
from numpy import loadtxt

from controllers import DDeePC
from controller_utils import CartpoleDx, sample_initial_signal
from tqdm import tqdm

from mpc import mpc
from mpc.mpc import GradMethods, QuadCost
import mpc.util as eutil


%matplotlib inline


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# torch.set_num_threads(8)

## Data Collection


In [3]:
Tini = 4
m = 1
p = 4
Tf = 20
T = (m+1)*(Tini + Tf + p) + 14
n_batch = 16
# torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'mps'
print(device)
ud = loadtxt('../badcartpole_ud.csv', delimiter=',')
yd = loadtxt('../badcartpole_yd.csv', delimiter=',')
yd = yd.reshape(T*p,)
noise = np.diag(np.kron(np.ones(T), np.array([0.01, 0.01, 0.005, 0.01]))) @ np.random.randn(*yd.shape)
yd = yd + noise
dx = CartpoleDx().to(device)
def uniform(shape, low, high):
    r = high-low
    return torch.rand(shape)*r+low

cuda


In [4]:
u_constraints = np.ones(Tf)*dx.upper
y_constraints = np.kron(np.ones(Tf), np.array([0.25, 1, 0.1, 1]))
r = torch.ones(m)*0.01
q = torch.ones(p)*100
lam_g1 = torch.Tensor([20])
lam_g2 = torch.Tensor([20])

controller = DDeePC(
    ud=ud, yd=yd, u_constraints=u_constraints, y_constraints=y_constraints,
    Tini=Tini, T=T, N=Tf, m=m, p=p, n_batch=n_batch, device=device,
    linear=False, stochastic=True, q=q, r=r
).to(device)

for param in controller.parameters():
    print(param)

Parameter containing:
tensor([51.9858], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([0.0009], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([0.4251], device='cuda:0', requires_grad=True)


In [5]:
episodes = 50
ref = torch.zeros(size=(n_batch,p))
perfect = torch.kron(torch.ones(episodes+Tini), ref).to(device)
ref = torch.kron(torch.ones(Tf), ref).to(device)
n_row = np.sqrt(n_batch)
n_col = n_row
opt = torch.optim.Rprop(controller.parameters(), lr=0.01, step_sizes=(1e-3,1e2))
criterion = torch.nn.HuberLoss()

epochs = 100
pbar = tqdm(range(epochs))
cum_loss = []
done = False

for j in pbar:

    # uini, yini = sample_initial_signal(Tini=Tini, m=m, p=p, batch=n_batch, ud=ud, yd=yd)
    uini = torch.zeros(size=(n_batch, Tini*m)).to(device)
    th = uniform((n_batch), -0.01, 0.01)
    # x = uniform((n_batch), -0.001, 0.001)
    # thdot = uniform((n_batch), -0.001, 0.001)
    # xdot = uniform((n_batch), -0.001, 0.001)
    yini = torch.stack((torch.Tensor([0]), torch.Tensor([0]), th, torch.Tensor([0])), dim=1).repeat(1,Tini).to(device)
    traj = yini
    count = 0

    t_dir = tempfile.mkdtemp()

    for i in range(episodes):
        count = i
        yini = traj[:,-p*Tini:]
        
        noise_weight = torch.diag(torch.kron(torch.ones(Tini), torch.Tensor([0.01, 0.01, 0.005, 0.01]))).repeat(n_batch, 1, 1)
        noise = torch.randn(n_batch,p*Tini,1)
        noise = torch.bmm(noise_weight, noise).squeeze(-1).to(device)

        u_pred, _ = controller(ref=ref, y_ini=yini+noise, u_ini=uini)
        input = u_pred[:,:m]
        y = dx(yini[:,-p:], input)

        traj = torch.cat((traj, y), axis=1)
        
        if torch.any(torch.abs(yini[:,-2]) >= 0.2) or torch.any(torch.abs(yini[:,1]) >= 0.25):
            break

        if i >= episodes -1 : 
            done = True

    # if count < episodes-1:
    #     remainder = torch.ones(n_batch, (episodes-count-1)*p).to(device)
    #     traj = torch.cat((traj, remainder), axis=1)

    # loss = criterion(target=perfect, input=traj)
    cum_loss.append(loss.item())
    opt.zero_grad()
    loss.backward()
    opt.step()

    pbar.set_description(f'l={loss.item():.3f}, ly={controller.lam_y.data.item():.3f},\
 lg1={controller.lam_g1.data.item():.3f}, lg2={controller.lam_g2.data.item():.3f}, ep={count}\
 q={controller.q.data}, r={controller.r.data.item():.3f}')

l=1.760, ly=4097.963, lg1=-32.707, lg2=0.021, ep=39 q=tensor([10., 10., 10., 10.], device='cuda:0'), r=0.100:  86%|████████▌ | 86/100 [25:08<02:51, 12.22s/it]

In [1]:
plt.plot(range(len(cum_loss)), cum_loss)
plt.ylabel("Loss")
plt.xlabel("epoch")
plt.grid('on')
plt.show()

NameError: name 'plt' is not defined

In [None]:
for param in controller.parameters():
    print(param)

Parameter containing:
tensor([12.0913], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([-99.9990], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([0.5743], device='cuda:0', requires_grad=True)


In [None]:
vid_fname = 'cartpole_initial.mp4'
if os.path.exists(vid_fname):
    os.remove(vid_fname)
cmd = 'ffmpeg -r 16 -f image2 -i {}/%03d.png -vcodec libx264 -crf 25 -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" -pix_fmt yuv420p {}'.format(
    t_dir, vid_fname
)
os.system(cmd)
# print('Saving video to: {}'.format(vid_fname))

video = io.open(vid_fname, 'r+b').read()
encoded = base64.b64encode(video)
HTML(data='''<video alt="test" controls>
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii')))

ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/opt/conda/conda-bld/ffmpeg_1597178665428/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libsw

FileNotFoundError: [Errno 2] No such file or directory: 'cartpole_initial.mp4'