Commit
Brandon Amos committed Oct 2, 2018
0 parents · commit 769419f
Showing 15 changed files with 2,667 additions and 0 deletions.
@@ -0,0 +1,3 @@
dist
*-info
__pycache__
@@ -0,0 +1,21 @@
# mpc.pytorch • [![PyPi][pypi-image]][pypi] [![License][license-image]][license]

[pypi-image]: https://img.shields.io/pypi/v/mpc.svg
[pypi]: https://pypi.python.org/pypi/mpc

[license-image]: http://img.shields.io/badge/license-Apache--2-blue.svg?style=flat
[license]: LICENSE

*A fast and differentiable model predictive control solver for PyTorch.
Crafted by <a href="https://bamos.github.io">Brandon Amos</a>,
Ivan Jimenez,
Jacob Sacks,
<a href='https://www.cc.gatech.edu/~bboots3/'>Byron Boots</a>,
and
<a href="https://zicokolter.com">J. Zico Kolter</a>.*

---

+ [More details are available on our project website here](http://locuslab.github.io/mpc.pytorch)
+ This code currently only supports PyTorch 0.3; an update to a newer
  version of PyTorch is coming soon.
Empty file.
@@ -0,0 +1,202 @@
import torch
from torch.autograd import Function, Variable
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter

from . import util

ACTS = {
    'sigmoid': F.sigmoid,
    'relu': F.relu,
    'elu': F.elu,
}

class NNDynamics(nn.Module):
    # Feedforward network dynamics model mapping (x_t, u_t) to x_{t+1},
    # optionally with a passthrough (skip) connection that adds x_t to
    # the output.
    def __init__(self, n_state, n_ctrl, hidden_sizes=[100],
                 activation='sigmoid', passthrough=True):
        super().__init__()

        self.passthrough = passthrough

        self.fcs = []
        in_sz = n_state + n_ctrl
        for out_sz in hidden_sizes + [n_state]:
            fc = nn.Linear(in_sz, out_sz)
            self.fcs.append(fc)
            in_sz = out_sz
        self.fcs = nn.ModuleList(self.fcs)

        assert activation in ACTS.keys()
        act_f = ACTS[activation]
        self.activation = activation
        self.acts = [act_f]*(len(self.fcs)-1) + [lambda x: x]  # Activation functions.

        self.Ws = [y.weight for y in self.fcs]
        self.zs = []  # Activations.

    def __getstate__(self):
        return (self.fcs, self.activation, self.passthrough)

    def __setstate__(self, state):
        super().__init__()
        if len(state) == 2:
            # TODO: Remove this soon, keeping for some old models.
            self.fcs, self.activation = state
            self.passthrough = True
        else:
            self.fcs, self.activation, self.passthrough = state

        act_f = ACTS[self.activation]
        self.acts = [act_f]*(len(self.fcs)-1) + [lambda x: x]  # Activation functions.
        self.Ws = [y.weight for y in self.fcs]

    def forward(self, x, u):
        # Promote single samples to a batch of size 1.
        x_dim, u_dim = x.ndimension(), u.ndimension()
        if x_dim == 1:
            x = x.unsqueeze(0)
        if u_dim == 1:
            u = u.unsqueeze(0)

        # Cache the intermediate activations for grad_input.
        self.zs = []
        z = torch.cat((x, u), 1)
        for act, fc in zip(self.acts, self.fcs):
            z = act(fc(z))
            self.zs.append(z)

        # Hack: Don't include the output.
        self.zs = self.zs[:-1]

        if self.passthrough:
            z += x

        if x_dim == 1:
            z = z.squeeze(0)

        return z

    def grad_input(self, x, u):
        # Analytically computes the batched Jacobians of the next state
        # with respect to x and u by chaining the layer weights through
        # the activations saved during the last forward pass.
        assert isinstance(x, Variable) == isinstance(u, Variable)
        diff = isinstance(x, Variable)

        x_dim, u_dim = x.ndimension(), u.ndimension()
        n_batch, n_state = x.size()
        _, n_ctrl = u.size()

        if not diff:
            Ws = [W.data for W in self.Ws]
            zs = [z.data for z in self.zs]
        else:
            Ws = self.Ws
            zs = self.zs

        assert len(zs) == len(Ws)-1
        grad = Ws[-1].repeat(n_batch, 1, 1)
        for i in range(len(zs)-1, -1, -1):
            n_out, n_in = Ws[i].size()

            if self.activation == 'relu':
                I = util.get_data_maybe(zs[i] <= 0.).unsqueeze(2).repeat(1, 1, n_in)
                Wi_grad = Ws[i].repeat(n_batch, 1, 1)
                Wi_grad[I] = 0.
            elif self.activation == 'sigmoid':
                d = zs[i]*(1.-zs[i])
                d = d.unsqueeze(2).expand(n_batch, n_out, n_in)
                Wi_grad = Ws[i].repeat(n_batch, 1, 1)*d
            else:
                assert False

            grad = grad.bmm(Wi_grad)

        R = grad[:, :, :n_state]
        S = grad[:, :, n_state:]

        if self.passthrough:
            # The skip connection adds an identity term to d x_{t+1} / d x_t.
            I = torch.eye(n_state).type_as(util.get_data_maybe(R)) \
                .unsqueeze(0).repeat(n_batch, 1, 1)

            if diff:
                I = Variable(I)

            R = R + I

        if x_dim == 1:
            R = R.squeeze(0)
            S = S.squeeze(0)

        return R, S

class CtrlPassthroughDynamics(nn.Module):
    # Wraps a dynamics model over x with an augmented state
    # tilde_x = [u; x] that passes the controls through.
    def __init__(self, dynamics):
        super().__init__()
        self.dynamics = dynamics

    def forward(self, tilde_x, u):
        tilde_x_dim, u_dim = tilde_x.ndimension(), u.ndimension()
        if tilde_x_dim == 1:
            tilde_x = tilde_x.unsqueeze(0)
        if u_dim == 1:
            u = u.unsqueeze(0)

        n_ctrl = u.size(1)
        x = tilde_x[:, n_ctrl:]
        xtp1 = self.dynamics(x, u)
        tilde_xtp1 = torch.cat((u, xtp1), dim=1)

        if tilde_x_dim == 1:
            tilde_xtp1 = tilde_xtp1.squeeze(0)

        return tilde_xtp1

    def grad_input(self, x, u):
        assert False, "Unimplemented"

class AffineDynamics(nn.Module):
    # Linear (affine) dynamics: x_{t+1} = A x_t + B u_t + c.
    def __init__(self, A, B, c=None):
        super(AffineDynamics, self).__init__()

        assert A.ndimension() == 2
        assert B.ndimension() == 2
        if c is not None:
            assert c.ndimension() == 1

        self.A = A
        self.B = B
        self.c = c

    def forward(self, x, u):
        if not isinstance(x, Variable) and isinstance(self.A, Variable):
            A = self.A.data
            B = self.B.data
            c = self.c.data if self.c is not None else 0.
        else:
            A = self.A
            B = self.B
            c = self.c if self.c is not None else 0.

        x_dim, u_dim = x.ndimension(), u.ndimension()
        if x_dim == 1:
            x = x.unsqueeze(0)
        if u_dim == 1:
            u = u.unsqueeze(0)

        z = x.mm(A.t()) + u.mm(B.t()) + c

        if x_dim == 1:
            z = z.squeeze(0)

        return z

    def grad_input(self, x, u):
        # The Jacobians of affine dynamics are just the (batched) A and B.
        n_batch = x.size(0)
        A, B = self.A, self.B
        A = A.unsqueeze(0).repeat(n_batch, 1, 1)
        B = B.unsqueeze(0).repeat(n_batch, 1, 1)
        if not isinstance(x, Variable) and isinstance(A, Variable):
            A, B = A.data, B.data
        return A, B
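
As a quick sanity check on these modules, here is a minimal sketch, not part of this commit, assuming a PyTorch version where plain tensors can be passed through nn.Modules (on PyTorch 0.3, wrap the inputs in Variable first): the analytic Jacobians returned by grad_input can be compared against a finite-difference estimate.

    # Hypothetical sanity check (not in this commit): compare grad_input
    # against central finite differences for both dynamics classes.
    import torch

    torch.manual_seed(0)
    n_batch, n_state, n_ctrl = 4, 3, 2

    # Affine case: grad_input returns A and B repeated over the batch.
    lin = AffineDynamics(torch.randn(n_state, n_state),
                         torch.randn(n_state, n_ctrl))
    # Network case: grad_input chains the saved layer activations.
    net = NNDynamics(n_state, n_ctrl, hidden_sizes=[16], activation='sigmoid')

    x = torch.randn(n_batch, n_state)
    u = torch.randn(n_batch, n_ctrl)

    eps = 1e-4
    for dyn in [lin, net]:
        dyn(x, u)  # Forward pass; NNDynamics caches its activations here.
        R, S = dyn.grad_input(x, u)
        # Finite-difference estimate of column 0 of d x_{t+1} / d x_t.
        e = torch.zeros(n_batch, n_state)
        e[:, 0] = eps
        fd = (dyn(x + e, u) - dyn(x - e, u)) / (2. * eps)
        print((fd - R[:, :, 0]).abs().max())  # Small, up to FD error.

Both prints should be near zero, since grad_input reconstructs exactly the linearization of the dynamics that a box-constrained LQR-style solver would consume.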
@@ -0,0 +1,136 @@
import torch
from torch.autograd import Function, Variable
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter

import numpy as np

from empc import util

import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

class CartpoleDx(nn.Module):
    def __init__(self, params=None):
        super().__init__()

        self.n_state = 5
        self.n_ctrl = 1

        # model parameters
        if params is None:
            # gravity, masscart, masspole, length
            self.params = Variable(torch.Tensor((9.8, 1.0, 0.1, 0.5)))
        else:
            self.params = params
        assert len(self.params) == 4
        self.force_mag = 100.

        self.theta_threshold_radians = np.pi  # 12 * 2 * np.pi / 360
        self.x_threshold = 2.4
        self.max_velocity = 10

        self.dt = 0.05

        self.lower = -self.force_mag
        self.upper = self.force_mag

        # 0   1     2        3      4
        # x   dx  cos(th)  sin(th)  dth
        self.goal_state = torch.Tensor([0., 0., 1., 0., 0.])
        self.goal_weights = torch.Tensor([0.1, 0.1, 1., 1., 0.1])
        self.ctrl_penalty = 0.001

        self.mpc_eps = 1e-4
        self.linesearch_decay = 0.5
        self.max_linesearch_iter = 2

    def forward(self, state, u):
        squeeze = state.ndimension() == 1

        if squeeze:
            state = state.unsqueeze(0)
            u = u.unsqueeze(0)

        if state.is_cuda and not self.params.is_cuda:
            self.params = self.params.cuda()
        gravity, masscart, masspole, length = torch.unbind(self.params)
        total_mass = masspole + masscart
        polemass_length = masspole * length

        u = torch.clamp(u[:, 0], -self.force_mag, self.force_mag)

        x, dx, cos_th, sin_th, dth = torch.unbind(state, dim=1)
        th = torch.atan2(sin_th, cos_th)

        cart_in = (u + polemass_length * dth**2 * sin_th) / total_mass
        th_acc = (gravity * sin_th - cos_th * cart_in) / \
                 (length * (4./3. - masspole * cos_th**2 / total_mass))
        xacc = cart_in - polemass_length * th_acc * cos_th / total_mass

        # Explicit Euler integration step.
        x = x + self.dt * dx
        dx = dx + self.dt * xacc
        th = th + self.dt * dth
        dth = dth + self.dt * th_acc

        state = torch.stack((
            x, dx, torch.cos(th), torch.sin(th), dth
        ), 1)

        if squeeze:
            state = state.squeeze(0)

        return state

    def get_frame(self, state):
        state = util.get_data_maybe(state.view(-1))
        assert len(state) == 5
        x, dx, cos_th, sin_th, dth = torch.unbind(state)
        th = np.arctan2(sin_th, cos_th)
        # Pole length from the model parameters (gravity, masscart,
        # masspole, length).
        length = util.get_data_maybe(self.params)[3]
        th_x = sin_th * length * 2
        th_y = cos_th * length * 2
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.plot((x, x + th_x), (0, th_y), color='k')
        ax.set_xlim((-5., 5.))
        ax.set_ylim((-2., 2.))
        return fig, ax

    def get_true_obj(self):
        q = torch.cat((
            self.goal_weights,
            self.ctrl_penalty * torch.ones(self.n_ctrl)
        ))
        assert not hasattr(self, 'mpc_lin')
        px = -torch.sqrt(self.goal_weights) * self.goal_state  # + self.mpc_lin
        p = torch.cat((px, torch.zeros(self.n_ctrl)))
        return Variable(q), Variable(p)

if __name__ == '__main__':
    # Roll out the dynamics with zero control and render each frame.
    dx = CartpoleDx()
    n_batch, T = 8, 50
    u = torch.zeros(T, n_batch, dx.n_ctrl)
    xinit = torch.zeros(n_batch, dx.n_state)
    th = 1.
    xinit[:, 2] = np.cos(th)
    xinit[:, 3] = np.sin(th)
    x = xinit
    for t in range(T):
        x = dx(x, u[t])
        fig, ax = dx.get_frame(x[0])
        fig.savefig('{:03d}.png'.format(t))
        plt.close(fig)

    # Stitch the frames into a video and clean up.
    vid_file = 'cartpole_vid.mp4'
    if os.path.exists(vid_file):
        os.remove(vid_file)
    cmd = ('/usr/bin/ffmpeg -loglevel quiet '
           '-r 32 -f image2 -i %03d.png -vcodec '
           'libx264 -crf 25 -pix_fmt yuv420p {}').format(vid_file)
    os.system(cmd)
    for t in range(T):
        os.remove('{:03d}.png'.format(t))
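
To make the cost terms concrete, here is a minimal sketch, hypothetical usage not part of this commit, that evaluates the diagonal quadratic objective 0.5 zᵀ diag(q) z + pᵀ z suggested by get_true_obj over the concatenated state-control vector z = [x; u]; the quadratic-form interpretation is an assumption about how a downstream solver would consume q and p.

    # Hypothetical usage (not in this commit): evaluate the cost implied by
    # get_true_obj at the upright goal state with zero control.
    import torch
    from torch.autograd import Variable

    dx = CartpoleDx()
    q, p = dx.get_true_obj()  # Diagonal weights and linear term over z = [x; u].

    x = torch.zeros(dx.n_state)
    x[2] = 1.  # cos(th) = 1: the upright goal configuration.
    z = Variable(torch.cat((x, torch.zeros(dx.n_ctrl))))

    cost = 0.5 * (q * z * z).sum() + (p * z).sum()
    print(cost)

The weights penalize deviation from goal_state in every coordinate and add a small ctrl_penalty on the control effort, matching the values set in __init__.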