<a href="https://colab.research.google.com/github/eisbetterthanpi/JEPA/blob/main/JEPA_mpc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## mpc

### locuslab_mpc

In [None]:
# https://github.com/locuslab/mpc.pytorch


In [None]:
# @title util
# https://github.com/locuslab/mpc.pytorch/blob/master/mpc/util.py
import torch
from torch.autograd import Function, Variable
from torch.nn import Module
from torch.nn.parameter import Parameter
import operator

def jacobian(f, x, eps):
    """Central finite-difference Jacobian of ``f`` evaluated at ``x``.

    Args:
        f: callable mapping an n-vector to a vector (assumed 1-D output).
        x: 1-D tensor of length n, or a [1, n] tensor that is squeezed first.
        eps: finite-difference step size.

    Returns:
        J such that ``J[j, i] ~= d f_j / d x_i``.
    """
    if x.ndimension() == 2:
        assert x.size(0) == 1
        x = x.squeeze()
    n = len(x)
    base = x.data if isinstance(x, Variable) else x
    basis = Variable(torch.eye(n).type_as(base))
    # Each entry is the central difference along one coordinate direction.
    columns = [
        (f(x + eps*basis[k]) - f(x - eps*basis[k])) / (2.*eps)
        for k in range(n)
    ]
    return torch.stack(columns).transpose(0, 1)

def expandParam(X, n_batch, nDim):
    """Broadcast an un-batched parameter to a batch.

    Returns ``(X, False)`` when ``X`` is a scalar or already has ``nDim``
    dimensions. Returns ``(X expanded to [n_batch, *X.shape], True)`` when it
    has ``nDim - 1`` dimensions. Raises ``RuntimeError`` otherwise.
    """
    ndim = X.ndimension()
    if ndim == 0 or ndim == nDim:
        return X, False
    if ndim != nDim - 1:
        raise RuntimeError("Unexpected number of dimensions.")
    # expand() creates a view; no data is copied.
    return X.unsqueeze(0).expand(n_batch, *X.size()), True

def bdiag(d):
    """Build a batch of diagonal matrices: ``D[b] = diag(d[b])``.

    Args:
        d: 2-D tensor [n_batch, sz].

    Returns:
        Tensor [n_batch, sz, sz] with the same dtype and device as ``d``.
        The result stays differentiable w.r.t. ``d``.
    """
    assert d.ndimension() == 2
    # diag_embed writes d along the diagonal of the last two dims and keeps
    # d's dtype/device. It replaces a scatter through a uint8 mask (uint8
    # indexing is deprecated in PyTorch) into a tensor built from a
    # type-string cast.
    return torch.diag_embed(d)

def bger(x, y):
    """Batched outer product: ``out[b] = outer(x[b], y[b])`` for 2-D ``x``, ``y``."""
    column = x.unsqueeze(2)  # [B, n, 1]
    row = y.unsqueeze(1)     # [B, 1, m]
    return torch.bmm(column, row)

def bmv(X, y):
    """Batched matrix-vector product: ``out[b] = X[b] @ y[b]``."""
    as_column = y.unsqueeze(2)
    product = torch.bmm(X, as_column)
    return product.squeeze(2)

def bquad(x, Q):
    """Batched quadratic form: ``out[b] = x[b]^T Q[b] x[b]``, shape [B]."""
    row = x.unsqueeze(1)
    col = x.unsqueeze(2)
    quad = torch.bmm(torch.bmm(row, Q), col)  # [B, 1, 1]
    return quad.squeeze(1).squeeze(1)

def bdot(x, y):
    """Batched dot product: ``out[b] = <x[b], y[b]>``, shape [B]."""
    inner = x.unsqueeze(1).bmm(y.unsqueeze(2))  # [B, 1, 1]
    return inner[:, 0, 0]

def eclamp(x, lower, upper):
    """Clamp ``x`` elementwise into ``[lower, upper]`` *in place* and return it.

    Args:
        x: tensor to clamp. It is modified in place.
        lower, upper: either Python scalars (int or float), applied to every
            element, or tensors with exactly the same size as ``x``.

    Returns:
        ``x`` itself, after clamping.
    """
    # In-place!!
    # Treat ints like floats. Previously an int bound fell through to
    # `lower[I]` and raised a TypeError.
    lower_is_scalar = isinstance(lower, (int, float))
    upper_is_scalar = isinstance(upper, (int, float))
    if torch.is_tensor(lower):
        assert x.size() == lower.size()
    if torch.is_tensor(upper):
        assert x.size() == upper.size()
    I = x < lower
    x[I] = lower if lower_is_scalar else lower[I]
    I = x > upper
    x[I] = upper if upper_is_scalar else upper[I]
    return x

def get_data_maybe(x):
    """Return ``x.data`` for autograd tensors; return anything else unchanged."""
    if isinstance(x, Variable):
        return x.data
    return x

# Tags whose header row has already been printed by table_log.
_seen_tables = []


def table_log(tag, d):
    """Print one row of a pipe-delimited table.

    The header (the first element of each entry) is printed only the first
    time ``tag`` is seen.

    Args:
        tag: identifier of the table.
        d: sequence of ``(name, value)`` or ``(name, value, fmt)`` entries.
            When ``fmt`` is given, the cell is ``fmt.format(value)``;
            otherwise it is ``str(value)``.
    """
    # TODO: There's probably a better way to handle formatting here, or a better way altogether to replace this quick hack.
    def emit(cells):
        print('| ' + ' | '.join(cells) + ' |')

    if tag not in _seen_tables:
        emit(entry[0] for entry in d)
        _seen_tables.append(tag)

    cells = []
    for entry in d:
        assert len(entry) in [2,3]
        value = entry[1]
        cells.append(entry[2].format(value) if len(entry) == 3 else str(value))
    emit(cells)

def get_traj(T, u, x_init, dynamics):
    """Roll out the state trajectory produced by applying controls ``u``.

    Args:
        T: horizon length (number of states returned).
        u: controls indexed by time; ``u[t]`` is [n_batch, n_ctrl].
        x_init: initial state [n_batch, n_state].
        dynamics: a ``LinDx(F, f)`` (``x_{t+1} = F[t] [x_t; u_t] + f[t]``)
            or a callable ``dynamics(x, u) -> x_next``.

    Returns:
        Stacked states [T, n_batch, n_state], built from ``.data``, so the
        rollout is not tracked by autograd.
    """
    is_linear = isinstance(dynamics, LinDx)
    if is_linear:
        F = get_data_maybe(dynamics.F)
        f = get_data_maybe(dynamics.f)
        # An empty affine term means "no affine term", matching how the LQR
        # step's forward rollout checks `f.nelement() > 0`.
        if f is not None and f.nelement() == 0:
            f = None
        if f is not None:
            assert f.shape == F.shape[:3]
    x = [get_data_maybe(x_init)]
    # Only T-1 transitions are needed to produce T states.
    for t in range(T-1):
        xt = x[t]
        ut = get_data_maybe(u[t])
        if is_linear:
            new_x = bmv(F[t], torch.cat((xt, ut), 1))
            if f is not None:
                new_x += f[t]
        else:
            new_x = dynamics(Variable(xt), Variable(ut)).data
        x.append(new_x)
    return torch.stack(x, dim=0)

def get_cost(T, u, cost, dynamics=None, x_init=None, x=None):
    """Total trajectory cost per batch element, summed over time.

    Either ``x`` (states [T, n_batch, n_state]) or ``x_init`` must be given.
    With only ``x_init``, the states are rolled out via ``get_traj`` using
    ``dynamics``. ``cost`` is a ``QuadCost`` (stage cost
    ``0.5 tau^T C[t] tau + c[t]^T tau`` with ``tau = [x_t; u_t]``) or a
    callable applied to ``tau``.
    """
    assert x_init is not None or x is not None
    quadratic = isinstance(cost, QuadCost)
    if quadratic:
        C = get_data_maybe(cost.C)
        c = get_data_maybe(cost.c)
    if x is None:
        x = get_traj(T, u, x_init, dynamics)

    def stage_cost(t):
        xut = torch.cat((x[t], u[t]), 1)
        if quadratic:
            return 0.5*bquad(xut, C[t]) + bdot(xut, c[t])
        return cost(xut)

    per_step = torch.stack([stage_cost(t) for t in range(T)], dim=0)
    return torch.sum(per_step, dim=0)

def detach_maybe(x):
    """Detach ``x`` from autograd only if it requires grad; pass ``None`` through."""
    if x is not None and x.requires_grad:
        return x.detach()
    return x

def data_maybe(x):
    """Return ``x.data``, or ``None`` when ``x`` is ``None``."""
    return None if x is None else x.data



In [None]:
# @title pnqp
# https://github.com/locuslab/mpc.pytorch/blob/master/mpc/pnqp.py
import torch

# \@profile
def pnqp(H, q, lower, upper, x_init=None, n_iter=20):
    """Projected-Newton solver for a batch of box-constrained QPs.

    Solves, independently for each batch element,
        min_x 0.5 x^T H x + q^T x   s.t.   lower <= x <= upper.

    Args:
        H: Hessians [n_batch, n, n].
        q: linear terms [n_batch, n].
        lower, upper: bounds, given as floats or tensors shaped like ``q``
            (passed to ``eclamp``).
        x_init: optional warm start [n_batch, n]. It is copied, never
            modified. When omitted, the unconstrained minimiser (clamped) is
            used.
        n_iter: maximum number of projected-Newton iterations (must be >= 1).

    Returns:
        ``(x, H_free, If, i)``:
            x: solution [n_batch, n].
            H_free: for ``n == 1``, the masked Hessian tensor itself;
                otherwise the ``.lu()`` factorisation of the Hessian
                restricted to the free variables (clamped rows/cols zeroed,
                plus a tiny ridge).
            If: float mask [n_batch, n], 1 where the variable is free.
            i: index of the last iteration performed.

    Prints a warning (it does not raise) if some batch element's Newton step
    is still >= 1e-4 after ``n_iter`` iterations.

    Raises:
        ValueError: if ``n_iter < 1``, since no iterate would be produced.
    """
    if n_iter < 1:
        raise ValueError("pnqp requires n_iter >= 1, got {}".format(n_iter))
    GAMMA = 0.1  # Armijo sufficient-decrease threshold used by the line search.
    n_batch, n, _ = H.size()
    # Tiny ridge so the free-variable Hessian stays factorisable.
    pnqp_I = 1e-11*torch.eye(n).type_as(H).expand_as(H)

    def obj(x):
        return 0.5*bquad(x, H) + bdot(q, x)

    if x_init is None:
        if n == 1:
            x_init = -(1./H.squeeze(2))*q
        else:
            H_lu = H.lu()
            x_init = -q.unsqueeze(2).lu_solve(*H_lu).squeeze(2)  # Clamped in the x assignment.
    else:
        x_init = x_init.clone()  # Don't over-write the original x_init.
    x = eclamp(x_init, lower, upper)
    for i in range(n_iter):
        g = bmv(H, x) + q
        # Clamped set: variables at a bound whose gradient pushes further out.
        Ic = (((x == lower) & (g > 0)) | ((x == upper) & (g < 0))).float()
        If = 1-Ic
        # If is already a float tensor, so no device-specific cast is needed.
        Hff_I = bger(If, If)  # 1 where both the row and the column are free.
        not_Hff_I = 1-Hff_I
        g_ = g.clone()
        g_[Ic.bool()] = 0.
        H_ = H.clone()
        H_[not_Hff_I.bool()] = 0.0
        H_ += pnqp_I
        if n == 1:
            dx = -(1./H_.squeeze(2))*g_
        else:
            H_lu_ = H_.lu()
            dx = -g_.unsqueeze(2).lu_solve(*H_lu_).squeeze(2)
        J = torch.norm(dx, 2, 1) >= 1e-4
        m = J.sum().item()  # Number of active examples in the batch.
        if m == 0:
            return x, H_ if n == 1 else H_lu_, If, i
        alpha = torch.ones(n_batch).type_as(x)
        decay = 0.1
        max_armijo = GAMMA
        count = 0
        while max_armijo <= GAMMA and count < 10:
            # Crude way of making sure too much time isn't being spent doing the line search.
            # Per-example step sizes. Broadcasting avoids building an
            # n_batch x n_batch diag matrix and keeps rows independent.
            maybe_x = eclamp(x + alpha.unsqueeze(1)*dx, lower, upper)
            armijos = (GAMMA+1e-6)*torch.ones(n_batch).type_as(x)
            armijos[J] = (obj(x)-obj(maybe_x))[J]/bdot(g, x-maybe_x)[J]
            I = armijos <= GAMMA
            alpha[I] *= decay
            max_armijo = torch.max(armijos)
            count += 1
        x = maybe_x
    # TODO: Maybe change this to a warning.
    print("[WARNING] pnqp warning: Did not converge")
    return x, H_ if n == 1 else H_lu_, If, i



In [None]:
# @title dynamics
# https://github.com/locuslab/mpc.pytorch/blob/master/mpc/dynamics.py
import torch
from torch.autograd import Function, Variable
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter

# Hidden-layer nonlinearities available to NNDynamics, keyed by name.
ACTS = dict(
    sigmoid=torch.sigmoid,
    relu=F.relu,
    elu=F.elu,
)

class NNDynamics(nn.Module):
    """MLP dynamics model ``x_{t+1} = MLP([x_t, u_t])`` (plus ``x_t`` if ``passthrough``).

    Args:
        n_state: state dimension.
        n_ctrl: control dimension.
        hidden_sizes: widths of the hidden layers (any sequence of ints).
        activation: hidden-layer nonlinearity, a key of ``ACTS``. The output
            layer is linear.
        passthrough: if True, the network output is added to ``x_t``
            (a residual model).
    """

    def __init__(self, n_state, n_ctrl, hidden_sizes=(100,), activation='sigmoid', passthrough=True):
        super().__init__()
        self.passthrough = passthrough
        fcs = []
        in_sz = n_state+n_ctrl
        # Tuple default avoids a shared mutable default; lists still work.
        for out_sz in list(hidden_sizes) + [n_state]:
            fcs.append(nn.Linear(in_sz, out_sz))
            in_sz = out_sz
        self.fcs = nn.ModuleList(fcs)
        assert activation in ACTS.keys()
        act_f = ACTS[activation]
        self.activation = activation
        self.acts = [act_f]*(len(self.fcs)-1)+[lambda x:x]  # Activation functions.
        self.Ws = [y.weight for y in self.fcs]
        self.zs = []  # Hidden activations cached by forward() for grad_input().

    def __getstate__(self):
        # Only layers and config are serialised (presumably because the lambda
        # in self.acts cannot be pickled); the rest is rebuilt in __setstate__.
        return (self.fcs, self.activation, self.passthrough)

    def __setstate__(self, state):
        super().__init__()
        if len(state) == 2:
            # TODO: Remove this soon, keeping for some old models.
            self.fcs, self.activation = state
            self.passthrough = True
        else:
            self.fcs, self.activation, self.passthrough = state
        act_f = ACTS[self.activation]
        self.acts = [act_f]*(len(self.fcs)-1)+[lambda x:x]  # Activation functions.
        self.Ws = [y.weight for y in self.fcs]
        self.zs = []  # Restored so the attribute exists before the first forward().

    def forward(self, x, u):
        """Predict the next state.

        Accepts batched [n_batch, dim] inputs or single 1-D inputs. A 1-D
        ``x`` yields a 1-D result. Caches the hidden activations in
        ``self.zs`` for ``grad_input``.
        """
        x_dim, u_dim = x.ndimension(), u.ndimension()
        if x_dim == 1:
            x = x.unsqueeze(0)
        if u_dim == 1:
            u = u.unsqueeze(0)
        self.zs = []
        z = torch.cat((x, u), 1)
        for act, fc in zip(self.acts, self.fcs):
            z = act(fc(z))
            self.zs.append(z)
        # Hack: Don't include the output.
        self.zs = self.zs[:-1]
        if self.passthrough:
            z += x
        if x_dim == 1:
            z = z.squeeze(0)
        return z

    def grad_input(self, x, u):
        """Analytic Jacobians ``R = d x_{t+1} / d x_t`` and ``S = d x_{t+1} / d u_t``.

        Reuses the hidden activations cached by the most recent ``forward``
        call, so call it right after ``forward`` on the same inputs. Only the
        'relu' and 'sigmoid' activations are supported.

        Returns:
            ``(R, S)`` with shapes [n_batch, n_state, n_state] and
            [n_batch, n_state, n_ctrl]. The batch dim is dropped for a 1-D ``x``.
        """
        assert isinstance(x, Variable) == isinstance(u, Variable)
        diff = isinstance(x, Variable)
        x_dim, u_dim = x.ndimension(), u.ndimension()
        # Accept unbatched inputs like forward() does. Without this the size
        # unpacking below failed for 1-D x, making the final squeeze unreachable.
        if x_dim == 1:
            x = x.unsqueeze(0)
        if u_dim == 1:
            u = u.unsqueeze(0)
        n_batch, n_state = x.size()
        _, n_ctrl = u.size()
        if not diff:
            Ws = [W.data for W in self.Ws]
            zs = [z.data for z in self.zs]
        else:
            Ws = self.Ws
            zs = self.zs
        assert len(zs) == len(Ws)-1
        # Chain rule from the linear output layer back to the input.
        grad = Ws[-1].repeat(n_batch,1,1)
        for i in range(len(zs)-1, 0-1, -1):
            n_out, n_in = Ws[i].size()
            if self.activation == 'relu':
                # relu' is 0 wherever the activation output is <= 0.
                I = get_data_maybe(zs[i] <= 0.).unsqueeze(2).repeat(1,1,n_in)
                Wi_grad = Ws[i].repeat(n_batch,1,1)
                Wi_grad[I] = 0.
            elif self.activation == 'sigmoid':
                # sigmoid' expressed through its output: s * (1 - s).
                d = zs[i]*(1.-zs[i])
                d = d.unsqueeze(2).expand(n_batch, n_out, n_in)
                Wi_grad = Ws[i].repeat(n_batch,1,1)*d
            else:
                assert False
            grad = grad.bmm(Wi_grad)
        R = grad[:,:,:n_state]
        S = grad[:,:,n_state:]
        if self.passthrough:
            I = torch.eye(n_state).type_as(get_data_maybe(R)).unsqueeze(0).repeat(n_batch, 1, 1)
            if diff:
                I = Variable(I)
            R = R + I
        if x_dim == 1:
            R = R.squeeze(0)
            S = S.squeeze(0)
        return R, S

class CtrlPassthroughDynamics(nn.Module):
    """Wrap ``dynamics`` so the applied control is carried in the state.

    The augmented state is ``tilde_x = [u_prev, x]``. One step maps it to
    ``[u, dynamics(x, u)]``.
    """

    def __init__(self, dynamics):
        super().__init__()
        self.dynamics = dynamics

    def forward(self, tilde_x, u):
        """Advance the augmented state; 1-D inputs give a 1-D result."""
        single = tilde_x.ndimension() == 1
        if single:
            tilde_x = tilde_x.unsqueeze(0)
        if u.ndimension() == 1:
            u = u.unsqueeze(0)
        n_ctrl = u.size(1)
        # Drop the stored previous control, then step the wrapped dynamics.
        next_state = self.dynamics(tilde_x[:, n_ctrl:], u)
        stacked = torch.cat((u, next_state), dim=1)
        return stacked.squeeze() if single else stacked

    def grad_input(self, x, u):
        assert False, "Unimplemented"

class AffineDynamics(nn.Module):
    """Affine dynamics ``x_{t+1} = A x_t + B u_t + c``.

    Args:
        A: [n_state, n_state] matrix.
        B: [n_state, n_ctrl] matrix.
        c: optional [n_state] offset (treated as 0 when omitted).
    """

    def __init__(self, A, B, c=None):
        super().__init__()
        assert A.ndimension() == 2
        assert B.ndimension() == 2
        if c is not None:
            assert c.ndimension() == 1
        self.A, self.B, self.c = A, B, c

    def forward(self, x, u):
        """Next state for batched ([n_batch, dim]) or single 1-D inputs."""
        # Legacy Variable handling: use raw data when x is a plain tensor.
        strip = not isinstance(x, Variable) and isinstance(self.A, Variable)
        A = self.A.data if strip else self.A
        B = self.B.data if strip else self.B
        if self.c is None:
            c = 0.
        else:
            c = self.c.data if strip else self.c
        squeeze_out = x.ndimension() == 1
        if squeeze_out:
            x = x.unsqueeze(0)
        if u.ndimension() == 1:
            u = u.unsqueeze(0)
        z = x.mm(A.t()) + u.mm(B.t()) + c
        return z.squeeze(0) if squeeze_out else z

    def grad_input(self, x, u):
        """Jacobians (A, B), repeated along a leading batch dim of size ``x.size(0)``."""
        n_batch = x.size(0)
        A = self.A.unsqueeze(0).repeat(n_batch, 1, 1)
        B = self.B.unsqueeze(0).repeat(n_batch, 1, 1)
        if not isinstance(x, Variable) and isinstance(A, Variable):
            A, B = A.data, B.data
        return A, B



In [None]:
# @title lqr_step
# https://github.com/locuslab/mpc.pytorch/blob/master/mpc/lqr_step.py
# time-varying linear control (LQR) problem
import torch
from torch.autograd import Function, Variable
from torch.nn import Module
from torch.nn.parameter import Parameter
import numpy as np
import numpy.random as npr
from collections import namedtuple
import time

# Diagnostics from the LQR backward pass and the forward line-search rollout.
LqrBackOut = namedtuple('lqrBackOut', ['n_total_qp_iter'])
LqrForOut = namedtuple(
    'lqrForOut',
    ['objs', 'full_du_norm', 'alpha_du_norm', 'mean_alphas', 'costs'],
)

def LQRStep(n_state, n_ctrl, T,
            u_lower=None, u_upper=None,
            u_zero_I=None,
            delta_u=None,
            linesearch_decay=0.2,
            max_linesearch_iter=10,
            true_cost=None,
            true_dynamics=None,
            delta_space=True,
            current_x=None, current_u=None,
            verbose=0,
            back_eps=1e-3,
            no_op_forward=False):
    """A single step of the box-constrained iLQR solver.

    Builds and returns ``LQRStepFn.apply``, a differentiable function
    ``(x_init, C, c, F, f=None)``. Around the nominal trajectory
    ``(current_x, current_u)``, it runs one backward Riccati pass followed by
    a forward rollout with a backtracking line search. It returns
    ``(new_x, new_u, n_total_qp_iter, costs, full_du_norm, mean_alphas)``, or
    just ``(current_x, current_u)`` when ``no_op_forward`` is set.

    Required Args:
        n_state, n_ctrl, T: state and control dimensions, and horizon length.

    Arguments of the returned function (shapes read off the indexing below):
        x_init: The initial state [n_batch, n_state].
        C: quadratic cost [T, n_batch, n_state+n_ctrl, n_state+n_ctrl].
        c: linear cost [T, n_batch, n_state+n_ctrl].
        F: linearised dynamics; F[t] is [n_batch, n_state, n_state+n_ctrl]
            for t < T-1.
        f: optional affine dynamics term; None or an empty tensor means none.

    Optional Args:
        u_lower, u_upper: The lower- and upper-bounds on the controls.
            These can either be floats or shaped as [T, n_batch, n_ctrl]
            TODO: Better support automatic expansion of these.
        u_zero_I: optional boolean mask, indexed by time, of controls that
            are forced to zero.
        delta_u: optional limit on how far each control may move in this
            step (requires u_lower).
        linesearch_decay, max_linesearch_iter: step-size shrink factor and
            the maximum number of line-search trials.
        true_cost: QuadCost, or a callable on [x_t; u_t], used to score
            rollouts.
        true_dynamics: LinDx, or a callable (x, u) -> x_next, used to roll out.
        delta_space: must be True; the backward pass solves for control
            changes.
        current_x, current_u: nominal trajectory, indexed by time first.
        verbose: values > 1 print per-time-step QP iteration counts.
        back_eps: ``eps`` passed to the MPC solve in the backward pass.
        no_op_forward: if True, forward returns the nominal trajectory
            unchanged (presumably to attach gradients to a solution that was
            already computed).
    """
    # \@profile
    def lqr_backward(ctx, C, c, F, f):
        """Backward Riccati recursion.

        Returns ``(Ks, ks, n_total_qp_iter)``. The feedback gains Ks and
        feedforward terms ks are stored in *reverse* time order (index 0 is
        t = T-1).
        """
        u = ctx.current_u
        Ks = []
        ks = []
        prev_kt = None  # Warm start for the box-constrained QP.
        n_total_qp_iter = 0
        Vtp1 = vtp1 = None
        for t in range(T-1, -1, -1):
            if t == T-1:
                Qt = C[t]
                qt = c[t]
            else:
                Ft = F[t]
                Ft_T = Ft.transpose(1,2)
                Qt = C[t] + Ft_T.bmm(Vtp1).bmm(Ft)
                if f is None or f.nelement() == 0:
                    qt = c[t] + Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2)
                else:
                    ft = f[t]
                    qt = c[t] + Ft_T.bmm(Vtp1).bmm(ft.unsqueeze(2)).squeeze(2) + \
                        Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2)
            Qt_xx = Qt[:, :n_state, :n_state]
            Qt_xu = Qt[:, :n_state, n_state:]
            Qt_ux = Qt[:, n_state:, :n_state]
            Qt_uu = Qt[:, n_state:, n_state:]
            qt_x = qt[:, :n_state]
            qt_u = qt[:, n_state:]
            if u_lower is None:
                if n_ctrl == 1 and u_zero_I is None:
                    Kt = -(1./Qt_uu)*Qt_ux
                    kt = -(1./Qt_uu.squeeze(2))*qt_u
                elif u_zero_I is None:
                    # Pseudo-inverse (per batch element) tolerates a singular Qt_uu.
                    Qt_uu_inv = torch.stack(
                        [torch.pinverse(Qt_uu[i]) for i in range(Qt_uu.shape[0])])
                    Kt = -Qt_uu_inv.bmm(Qt_ux)
                    kt = bmv(-Qt_uu_inv, qt_u)
                else:
                    # Solve with zero constraints on the active controls: the
                    # masked rows/cols are removed from the system.
                    I = u_zero_I[t].float()
                    notI = 1-I
                    qt_u_ = qt_u.clone()
                    qt_u_[I.bool()] = 0
                    Qt_uu_ = Qt_uu.clone()
                    # I is already float, so no device-specific cast is needed.
                    Qt_uu_I = 1-bger(notI, notI)
                    Qt_uu_[Qt_uu_I.bool()] = 0.
                    # Tiny diagonal on the masked controls keeps the system non-singular.
                    Qt_uu_[bdiag(I).bool()] += 1e-8
                    Qt_ux_ = Qt_ux.clone()
                    Qt_ux_[I.unsqueeze(2).repeat(1,1,Qt_ux.size(2)).bool()] = 0.
                    if n_ctrl == 1:
                        Kt = -(1./Qt_uu_)*Qt_ux_
                        # Use the masked Qt_uu_ (as Kt does); the unmasked
                        # Qt_uu gives 0 * inf = NaN when a masked entry is 0.
                        kt = -(1./Qt_uu_.squeeze(2))*qt_u_
                    else:
                        Qt_uu_LU_ = Qt_uu_.lu()
                        Kt = -Qt_ux_.lu_solve(*Qt_uu_LU_)
                        kt = -qt_u_.unsqueeze(2).lu_solve(*Qt_uu_LU_).squeeze(2)
            else:
                assert delta_space
                # Bounds on the control *change*, relative to the nominal u.
                lb = get_bound('lower', t) - u[t]
                ub = get_bound('upper', t) - u[t]
                if delta_u is not None:
                    lb[lb < -delta_u] = -delta_u
                    ub[ub > delta_u] = delta_u
                kt, Qt_uu_free_LU, If, n_qp_iter = pnqp(
                    Qt_uu, qt_u, lb, ub,
                    x_init=prev_kt, n_iter=20)
                if verbose > 1:
                    print('  + n_qp_iter: ', n_qp_iter+1)
                n_total_qp_iter += 1+n_qp_iter
                prev_kt = kt
                # Feedback only acts on controls that are free (not clamped).
                Qt_ux_ = Qt_ux.clone()
                Qt_ux_[(1-If).unsqueeze(2).repeat(1,1,Qt_ux.size(2)).bool()] = 0
                if n_ctrl == 1:
                    # Bad naming, Qt_uu_free_LU isn't the LU in this case.
                    Kt = -((1./Qt_uu_free_LU)*Qt_ux_)
                else:
                    Kt = -Qt_ux_.lu_solve(*Qt_uu_free_LU)
            Kt_T = Kt.transpose(1,2)
            Ks.append(Kt)
            ks.append(kt)
            Vtp1 = Qt_xx + Qt_xu.bmm(Kt) + Kt_T.bmm(Qt_ux) + Kt_T.bmm(Qt_uu).bmm(Kt)
            vtp1 = qt_x + Qt_xu.bmm(kt.unsqueeze(2)).squeeze(2) + \
                Kt_T.bmm(qt_u.unsqueeze(2)).squeeze(2) + \
                Kt_T.bmm(Qt_uu).bmm(kt.unsqueeze(2)).squeeze(2)
        return Ks, ks, n_total_qp_iter

    # \@profile
    def lqr_forward(ctx, x_init, C, c, F, f, Ks, ks):
        """Roll out the new policy, shrinking per-example step sizes.

        A batch element's feedforward step is scaled down by
        ``linesearch_decay`` while its true cost exceeds the nominal cost, for
        at most ``max_linesearch_iter`` trials. ``C`` is only used for
        ``n_batch`` and dtype; rollouts use ``true_cost``/``true_dynamics``.
        """
        x = ctx.current_x
        u = ctx.current_u
        n_batch = C.size(1)
        old_cost = get_cost(T, u, true_cost, true_dynamics, x=x)
        current_cost = None
        alphas = torch.ones(n_batch).type_as(C)
        full_du_norm = None
        i = 0
        while (current_cost is None or
               (old_cost is not None and
                bool(torch.any(current_cost > old_cost)))) and \
                i < max_linesearch_iter:
            new_u = []
            new_x = [x_init]
            dx = [torch.zeros_like(x_init)]
            objs = []
            for t in range(T):
                # Gains were stored backwards in time by lqr_backward.
                t_rev = T-1-t
                Kt = Ks[t_rev]
                kt = ks[t_rev]
                new_xt = new_x[t]
                ut = u[t]
                dxt = dx[t]
                # Scale each example's feedforward row by its own alpha.
                # Broadcasting avoids an n_batch x n_batch diag matrix and stops
                # NaN/inf in one example from leaking into the others.
                new_ut = bmv(Kt, dxt) + ut + alphas.unsqueeze(1)*kt
                # Currently unimplemented:
                assert not ((delta_u is not None) and (u_lower is None))
                if u_zero_I is not None:
                    new_ut[u_zero_I[t]] = 0.
                if u_lower is not None:
                    lb = get_bound('lower', t)
                    ub = get_bound('upper', t)
                    if delta_u is not None:
                        lb_limit, ub_limit = lb, ub
                        lb = u[t] - delta_u
                        ub = u[t] + delta_u
                        I = lb < lb_limit
                        lb[I] = lb_limit if isinstance(lb_limit, float) else lb_limit[I]
                        I = ub > ub_limit
                        # Check the *upper* limit's type (previously tested lb_limit).
                        ub[I] = ub_limit if isinstance(ub_limit, float) else ub_limit[I]
                    # TODO(eugenevinitsky) why do we need to do this here?
                    new_ut = eclamp(new_ut, lb, ub)
                new_u.append(new_ut)
                new_xut = torch.cat((new_xt, new_ut), dim=1)
                if t < T-1:
                    if isinstance(true_dynamics, LinDx):
                        dyn_F, dyn_f = true_dynamics.F, true_dynamics.f
                        new_xtp1 = bmv(dyn_F[t], new_xut)
                        if dyn_f is not None and dyn_f.nelement() > 0:
                            new_xtp1 += dyn_f[t]
                    else:
                        new_xtp1 = true_dynamics(Variable(new_xt), Variable(new_ut)).data
                    new_x.append(new_xtp1)
                    dx.append(new_xtp1 - x[t+1])
                if isinstance(true_cost, QuadCost):
                    cost_C, cost_c = true_cost.C, true_cost.c
                    obj = 0.5*bquad(new_xut, cost_C[t]) + bdot(new_xut, cost_c[t])
                else:
                    obj = true_cost(new_xut)
                objs.append(obj)
            objs = torch.stack(objs)
            current_cost = torch.sum(objs, dim=0)
            new_u = torch.stack(new_u)
            new_x = torch.stack(new_x)
            if full_du_norm is None:
                # [T, n_batch, n_ctrl] -> [n_batch, T*n_ctrl] so each row holds
                # one example (transpose(1,2) here mixed examples together).
                full_du_norm = (u-new_u).transpose(0,1).contiguous().view(n_batch, -1).norm(2, 1)
            alphas[current_cost > old_cost] *= linesearch_decay
            i += 1
        # If the iteration limit is hit, some alphas are one step too small.
        alphas[current_cost > old_cost] /= linesearch_decay
        alpha_du_norm = (u-new_u).transpose(0,1).contiguous().view(
            n_batch, -1).norm(2, 1)
        return new_x, new_u, LqrForOut(
            objs, full_du_norm,
            alpha_du_norm,
            torch.mean(alphas),
            current_cost
        )

    def get_bound(side, t):
        """Return the 'lower' or 'upper' control bound at time ``t``."""
        if side == 'lower':
            v = u_lower
        elif side == 'upper':
            v = u_upper
        else:
            raise ValueError("side must be 'lower' or 'upper', got {!r}".format(side))
        # Float bounds apply at every time step; tensor bounds are indexed by time.
        if isinstance(v, float):
            return v
        return v[t]

    class LQRStepFn(Function):
        # \@profile
        @staticmethod
        def forward(ctx, x_init, C, c, F, f=None):
            if no_op_forward:
                ctx.save_for_backward(x_init, C, c, F, f, current_x, current_u)
                ctx.current_x, ctx.current_u = current_x, current_u
                return current_x, current_u
            if delta_space:
                # Taylor-expand the objective to do the backward pass in the delta space.
                # c_back[t] = C[t] tau_t + c[t] is the cost gradient at the
                # nominal point (assuming C[t] is symmetric).
                assert current_x is not None
                assert current_u is not None
                c_back = []
                for t in range(T):
                    xt = current_x[t]
                    ut = current_u[t]
                    xut = torch.cat((xt, ut), 1)
                    c_back.append(bmv(C[t], xut) + c[t])
                c_back = torch.stack(c_back)
                f_back = None
            else:
                assert False  # Only delta_space=True is implemented.
            ctx.current_x = current_x
            ctx.current_u = current_u
            Ks, ks, n_total_qp_iter = lqr_backward(ctx, C, c_back, F, f_back)
            new_x, new_u, for_out = lqr_forward(ctx,
                x_init, C, c, F, f, Ks, ks)
            ctx.save_for_backward(x_init, C, c, F, f, new_x, new_u)
            return new_x, new_u, torch.Tensor([n_total_qp_iter]), \
              for_out.costs, for_out.full_du_norm, for_out.mean_alphas

        @staticmethod
        def backward(ctx, dl_dx, dl_du, *unused_grads):
            # Autograd passes one gradient per forward output: 2 with
            # no_op_forward, otherwise 6. Only the new_x/new_u gradients
            # are used; the diagnostic outputs are ignored.
            x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors
            r = []
            for t in range(T):
                rt = torch.cat((dl_dx[t], dl_du[t]), 1)
                r.append(rt)
            r = torch.stack(r)
            # Controls sitting on a bound are held fixed in the backward solve.
            if u_lower is None:
                I = None
            else:
                I = (torch.abs(new_u - u_lower) <= 1e-8) | \
                    (torch.abs(new_u - u_upper) <= 1e-8)
            dx_init = Variable(torch.zeros_like(x_init))
            _mpc = MPC(n_state, n_ctrl, T,
                u_zero_I=I, u_init=None,
                lqr_iter=1,
                verbose=-1,
                n_batch=C.size(1),
                delta_u=None,
                exit_unconverged=False, # It's really bad if this doesn't converge.
                eps=back_eps,
            )
            dx, du, _ = _mpc(dx_init, QuadCost(C, -r), LinDx(F, None))
            dx, du = dx.data, du.data
            dxu = torch.cat((dx, du), 2)
            xu = torch.cat((new_x, new_u), 2)
            dC = torch.zeros_like(C)
            for t in range(T):
                xut = torch.cat((new_x[t], new_u[t]), 1)
                dxut = dxu[t]
                dCt = -0.5*(bger(dxut, xut) + bger(xut, dxut))
                dC[t] = dCt
            dc = -dxu
            # Multipliers of the dynamics constraints, by a backward recursion.
            lams = []
            prev_lam = None
            for t in range(T-1, -1, -1):
                Ct_xx = C[t,:,:n_state,:n_state]
                Ct_xu = C[t,:,:n_state,n_state:]
                ct_x = c[t,:,:n_state]
                xt = new_x[t]
                ut = new_u[t]
                lamt = bmv(Ct_xx, xt) + bmv(Ct_xu, ut) + ct_x
                if prev_lam is not None:
                    Fxt = F[t,:,:,:n_state].transpose(1, 2)
                    lamt += bmv(Fxt, prev_lam)
                lams.append(lamt)
                prev_lam = lamt
            lams = list(reversed(lams))
            dlams = []
            prev_dlam = None
            for t in range(T-1, -1, -1):
                dCt_xx = C[t,:,:n_state,:n_state]
                dCt_xu = C[t,:,:n_state,n_state:]
                drt_x = -r[t,:,:n_state]
                dxt = dx[t]
                dut = du[t]
                dlamt = bmv(dCt_xx, dxt) + bmv(dCt_xu, dut) + drt_x
                if prev_dlam is not None:
                    Fxt = F[t,:,:,:n_state].transpose(1, 2)
                    dlamt += bmv(Fxt, prev_dlam)
                dlams.append(dlamt)
                prev_dlam = dlamt
            dlams = torch.stack(list(reversed(dlams)))
            dF = torch.zeros_like(F)
            for t in range(T-1):
                xut = xu[t]
                lamt = lams[t+1]
                dxut = dxu[t]
                dlamt = dlams[t+1]
                dF[t] = -(bger(dlamt, xut) + bger(lamt, dxut))
            # f may be None (the forward default), which previously crashed here.
            if f is not None and f.nelement() > 0:
                _dlams = dlams[1:]
                assert _dlams.shape == f.shape
                df = -_dlams
            elif f is None:
                df = None
            else:
                df = torch.Tensor()
            dx_init = -dlams[0]
            return dx_init, dC, dc, dF, df
    return LQRStepFn.apply



In [None]:
# @title mpc
# https://github.com/locuslab/mpc.pytorch/blob/master/mpc/mpc.py
import torch
from torch.autograd import Function, Variable
from torch.nn import Module
from torch.nn.parameter import Parameter
import numpy as np
import numpy.random as npr
from collections import namedtuple
from enum import Enum
import sys

# Quadratic cost 0.5 * tau^T C tau + c^T tau with tau = [x; u]; both fields
# default to None.
QuadCost = namedtuple('QuadCost', 'C c', defaults=(None, None))
# Linear(ised) dynamics x_{t+1} = F [x_t; u_t] + f; both fields default to
# None (f = None means no affine term).
# `defaults=` is the supported way to set defaults: it also fills
# `_field_defaults`, which assigning `__new__.__defaults__` left empty.
LinDx = namedtuple('LinDx', 'F f', defaults=(None, None))

# How MPC obtains dynamics Jacobians (see the MPC docstring): AUTO_DIFF uses
# autograd, FINITE_DIFF uses finite differences, ANALYTIC uses the model's own
# Jacobian. ANALYTIC_CHECK is presumably ANALYTIC plus a consistency check.
GradMethods = Enum(
    'GradMethods',
    [('AUTO_DIFF', 1), ('FINITE_DIFF', 2), ('ANALYTIC', 3), ('ANALYTIC_CHECK', 4)],
    module=__name__,
)

class SlewRateCost(Module):
    """Hacky way of adding the slew rate penalty to costs.

    ``forward(tau)`` returns ``cost(tau[:, n_ctrl:]) + 0.5 * tau^T slew_C[0] tau``.
    The first ``n_ctrl`` entries of ``tau`` are dropped before calling the
    wrapped cost (presumably the previous control in a
    CtrlPassthroughDynamics-style augmented state; confirm with callers).
    """
    # TODO: It would be cleaner to update this to just use the slew rate penalty instead of # slew_C

    def __init__(self, cost, slew_C, n_state, n_ctrl):
        super().__init__()
        self.cost = cost
        self.slew_C = slew_C
        self.n_state = n_state
        self.n_ctrl = n_ctrl

    def forward(self, tau):
        stripped = tau[:, self.n_ctrl:]
        # The slew constraints are time-invariant, so only slew_C[0] is used.
        return self.cost(stripped) + 0.5 * bquad(tau, self.slew_C[0])

    def grad_input(self, x, u):
        raise NotImplementedError("Implement grad_input")


class MPC(Module):
    """A differentiable box-constrained iLQR solver.
    This provides a differentiable solver for the following box-constrained
    control problem with a quadratic cost (defined by C and c) and
    non-linear dynamics (defined by f):
        min_{tau={x,u}} sum_t 0.5 tau_t^T C_t tau_t + c_t^T tau_t
                        s.t. x_{t+1} = f(x_t, u_t)
                            x_0 = x_init
                            u_lower <= u <= u_upper
    This implements the Control-Limited Differential Dynamic Programming
    paper with a first-order approximation to the non-linear dynamics:
    https://homes.cs.washington.edu/~todorov/papers/TassaICRA14.pdf
    Some of the notation here is from Sergey Levine's notes:
    http://rll.berkeley.edu/deeprlcourse/f17docs/lecture_8_model_based_planning.pdf
    Required Args:
        n_state, n_ctrl, T
    Optional Args:
        u_lower, u_upper: The lower- and upper-bounds on the controls.
            These can either be floats or shaped as [T, n_batch, n_ctrl].
            Both must be given, or neither.
        u_zero_I: Detached and forwarded to LQRStep; its meaning is defined
            there, not in this class.
        u_init: The initial control sequence, useful for warm-starting:
            [T, n_batch, n_ctrl] (a 2-D [T, n_ctrl] value is broadcast over
            the batch).
        lqr_iter: The number of LQR iterations to perform.
        grad_method: The method to compute the Jacobian of the dynamics.
            GradMethods.ANALYTIC: Use a manually-defined Jacobian
                (the dynamics module's grad_input).
                + Fast and accurate, use this if possible
            GradMethods.AUTO_DIFF: Use PyTorch's autograd.
                + Slow
            GradMethods.FINITE_DIFF: Use naive finite differences
                + Inaccurate
            GradMethods.ANALYTIC_CHECK: Currently disabled (asserts False).
        delta_u (float): The amount each component of the controls
            is allowed to change in each LQR iteration.
        verbose (int):
            -1: No output or warnings
             0: Warnings
            1+: Detailed iteration info
        eps: Termination threshold, on the norm of the full control
             step (without line search)
        back_eps: `eps` value to use in the backwards pass.
        n_batch: May be necessary for now if it can't be inferred.
                 TODO: Infer, potentially remove this.
        linesearch_decay (float): Multiplicative decay factor for the
            line search.
        max_linesearch_iter (int): Maximum number of line-search steps
            (must be > 0). Setting it to 1 can be used to disable the line
            search, which helps on problems where the line search is harmful.
        exit_unconverged: Assert False if a fixed point is not reached.
            Only consulted when detach_unconverged is True.
        detach_unconverged: Detach examples from the graph that do
            not hit a fixed point so they are not differentiated through.
        backprop: Allow the solver to be differentiated through.
            NOTE(review): stored but not read anywhere in this class.
        slew_rate_penalty (float): Penalty term applied to
            ||u_t - u_{t+1}||_2^2 in the objective (QuadCost only; see
            solve_lqr_subproblem).
        prev_ctrl: The previous nominal control sequence to initialize
            the solver with.
        not_improved_lim: The number of iterations to allow that don't
            improve the objective before returning early.
        best_cost_eps: Absolute threshold for the best cost
            to be updated.
    Returns (from forward):
        (x, u, costs) for the best trajectory found per batch element."""
    def __init__(self, n_state, n_ctrl, T,
            u_lower=None, u_upper=None,
            u_zero_I=None, u_init=None,
            lqr_iter=10,
            grad_method=GradMethods.ANALYTIC,
            delta_u=None,
            verbose=0,
            eps=1e-7,
            back_eps=1e-7,
            n_batch=None,
            linesearch_decay=0.2,
            max_linesearch_iter=10,
            exit_unconverged=True,
            detach_unconverged=True,
            backprop=True,
            slew_rate_penalty=None,
            prev_ctrl=None,
            not_improved_lim=5,
            best_cost_eps=1e-4
    ):
        super().__init__()
        # Control bounds come as a pair: both given or both omitted.
        assert (u_lower is None) == (u_upper is None)
        assert max_linesearch_iter > 0
        self.n_state = n_state
        self.n_ctrl = n_ctrl
        self.T = T
        self.u_lower = u_lower
        self.u_upper = u_upper
        # Tensor bounds are detached; float bounds are kept as plain floats.
        if not isinstance(u_lower, float):
            self.u_lower = detach_maybe(self.u_lower)
        if not isinstance(u_upper, float):
            self.u_upper = detach_maybe(self.u_upper)
        self.u_zero_I = detach_maybe(u_zero_I)
        self.u_init = detach_maybe(u_init)
        self.lqr_iter = lqr_iter
        self.grad_method = grad_method
        self.delta_u = delta_u
        self.verbose = verbose
        self.eps = eps
        self.back_eps = back_eps
        self.n_batch = n_batch
        self.linesearch_decay = linesearch_decay
        self.max_linesearch_iter = max_linesearch_iter
        self.exit_unconverged = exit_unconverged
        self.detach_unconverged = detach_unconverged
        self.backprop = backprop
        self.not_improved_lim = not_improved_lim
        self.best_cost_eps = best_cost_eps
        self.slew_rate_penalty = slew_rate_penalty
        self.prev_ctrl = prev_ctrl

    # \@profile
    def forward(self, x_init, cost, dx):
        """Run the iLQR iterations starting from ``x_init``.

        Args:
            x_init: Initial states; asserted 2-D with size(0) == n_batch
                (presumably [n_batch, n_state]).
            cost: A QuadCost (C/c may omit the time and/or batch dimension;
                they are expanded here), or a cost Module/Function that is
                quadratized with ``approximate_cost`` every iteration.
            dx: A LinDx, or a dynamics Module/Function that is linearized
                with ``linearize_dynamics`` every iteration.

        Returns:
            (x, u, costs): 3-D state and control tensors with the batch on
            dim 1, holding the lowest-cost iterate kept per example, and
            the per-example costs of that iterate.

        On an un-inferable batch size or a malformed QuadCost this prints
        an error and calls sys.exit(-1).
        """
        # QuadCost.C: [T, n_batch, n_tau, n_tau]
        # QuadCost.c: [T, n_batch, n_tau]
        assert isinstance(cost, QuadCost) or isinstance(cost, Module) or isinstance(cost, Function)
        assert isinstance(dx, LinDx) or isinstance(dx, Module) or isinstance(dx, Function)
        # TODO: Clean up inferences, expansions, and assumptions made here.
        # Batch size: an explicit n_batch wins; otherwise it is read from a 4-D QuadCost.C.
        if self.n_batch is not None:
            n_batch = self.n_batch
        elif isinstance(cost, QuadCost) and cost.C.ndimension() == 4:
            n_batch = cost.C.size(1)
        else:
            print('MPC Error: Could not infer batch size, pass in as n_batch')
            sys.exit(-1)
        # if c.ndimension() == 2:
        #     c = c.unsqueeze(1).expand(self.T, n_batch, -1)
        # Broadcast a QuadCost to C: [T, n_batch, n_tau, n_tau] and c: [T, n_batch, n_tau].
        if isinstance(cost, QuadCost):
            C, c = cost
            if C.ndimension() == 2:
                # Add the time and batch dimensions.
                C = C.unsqueeze(0).unsqueeze(0).expand(
                    self.T, n_batch, self.n_state+self.n_ctrl, -1)
            elif C.ndimension() == 3:
                # Add the batch dimension.
                C = C.unsqueeze(1).expand(
                    self.T, n_batch, self.n_state+self.n_ctrl, -1)
            if c.ndimension() == 1:
                # Add the time and batch dimensions.
                c = c.unsqueeze(0).unsqueeze(0).expand(self.T, n_batch, -1)
            elif c.ndimension() == 2:
                # Add the batch dimension.
                c = c.unsqueeze(1).expand(self.T, n_batch, -1)
            if C.ndimension() != 4 or c.ndimension() != 3:
                print('MPC Error: Unexpected QuadCost shape.')
                sys.exit(-1)
            cost = QuadCost(C, c)
        assert x_init.ndimension() == 2 and x_init.size(0) == n_batch
        # Initial controls: zeros, or the warm start (a 2-D u_init is broadcast over the batch).
        if self.u_init is None:
            u = torch.zeros(self.T, n_batch, self.n_ctrl).type_as(x_init.data)
        else:
            u = self.u_init
            if u.ndimension() == 2:
                u = u.unsqueeze(1).expand(self.T, n_batch, -1).clone()
        u = u.type_as(x_init.data)
        if self.verbose > 0:
            print('Initial mean(cost): {:.4e}'.format(torch.mean(get_cost(self.T, u, cost, dx, x_init=x_init)).item()))
        # best: per-example lowest-cost iterate so far; x/u are kept as lists of
        # [T, 1, ...] slices so individual examples can be swapped in.
        best = None
        n_not_improved = 0
        for i in range(self.lqr_iter):
            u = Variable(detach_maybe(u), requires_grad=True)
            # Linearize the dynamics around the current trajectory.
            x = get_traj(self.T, u, x_init=x_init, dynamics=dx)
            if isinstance(dx, LinDx):
                F, f = dx.F, dx.f
            else:
                F, f = self.linearize_dynamics(x, detach_maybe(u), dx, diff=False)
            if isinstance(cost, QuadCost):
                C, c = cost.C, cost.c
            else:
                C, c, _ = self.approximate_cost(x, detach_maybe(u), cost, diff=False)
            x, u, n_total_qp_iter, costs, full_du_norm, mean_alphas = self.solve_lqr_subproblem(x_init, C, c, F, f, cost, dx, x, u)
            # Iterations since the last improvement; reset to 0 below when any example is accepted.
            n_not_improved += 1
            assert x.ndimension() == 3
            assert u.ndimension() == 3
            if best is None:
                best = {
                    'x': list(torch.split(x, split_size_or_sections=1, dim=1)),
                    'u': list(torch.split(u, split_size_or_sections=1, dim=1)),
                    'costs': costs,
                    'full_du_norm': full_du_norm,
                }
            else:
                # Per example, accept the new iterate if its cost is at most best + best_cost_eps.
                for j in range(n_batch):
                    if costs[j] <= best['costs'][j] + self.best_cost_eps:
                        n_not_improved = 0
                        best['x'][j] = x[:,j].unsqueeze(1)
                        best['u'][j] = u[:,j].unsqueeze(1)
                        best['costs'][j] = costs[j]
                        best['full_du_norm'][j] = full_du_norm[j]
            if self.verbose > 0:
                table_log('lqr', (
                    ('iter', i),
                    ('mean(cost)', torch.mean(best['costs']).item(), '{:.4e}'),
                    ('||full_du||_max', max(full_du_norm).item(), '{:.2e}'),
                    # ('||alpha_du||_max', max(alpha_du_norm), '{:.2e}'),
                    # TODO: alphas, total_qp_iters here is for the current iterate, not the best
                    ('mean(alphas)', mean_alphas.item(), '{:.2e}'),
                    ('total_qp_iters', n_total_qp_iter),
                ))
            # Stop once every example's full control step is below eps, or after more
            # than not_improved_lim consecutive iterations without an accepted improvement.
            if max(full_du_norm) < self.eps or \
               n_not_improved > self.not_improved_lim:
                break
        # Reassemble the best per-example trajectories.
        x = torch.cat(best['x'], dim=1)
        u = torch.cat(best['u'], dim=1)
        full_du_norm = best['full_du_norm']
        # Re-linearize / re-quadratize at the best trajectory with diff=True and run a
        # no-op LQR step; presumably this attaches the differentiable solution graph
        # (no_op_forward is handled inside LQRStep) - confirm there.
        if isinstance(dx, LinDx):
            F, f = dx.F, dx.f
        else:
            F, f = self.linearize_dynamics(x, u, dx, diff=True)
        if isinstance(cost, QuadCost):
            C, c = cost.C, cost.c
        else:
            C, c, _ = self.approximate_cost(x, u, cost, diff=True)
        x, u = self.solve_lqr_subproblem(x_init, C, c, F, f, cost, dx, x, u, no_op_forward=True)
        # Unconverged handling; note exit_unconverged is only checked inside this branch.
        if self.detach_unconverged:
            if max(best['full_du_norm']) > self.eps:
                if self.exit_unconverged:
                    assert False
                if self.verbose >= 0:
                    print("LQR Warning: All examples did not converge to a fixed point.")
                    print("Detaching and *not* backpropping through the bad examples.")
                # Mask over the batch (dim 1): 1 for converged examples, 0 otherwise.
                I = full_du_norm < self.eps
                Ix = Variable(I.unsqueeze(0).unsqueeze(2).expand_as(x)).type_as(x.data)
                Iu = Variable(I.unsqueeze(0).unsqueeze(2).expand_as(u)).type_as(u.data)
                # Same values; unconverged examples are replaced by detached copies.
                x = x*Ix + x.clone().detach()*(1.-Ix)
                u = u*Iu + u.clone().detach()*(1.-Iu)
        costs = best['costs']
        return (x, u, costs)

    def solve_lqr_subproblem(self, x_init, C, c, F, f, cost, dynamics, x, u, no_op_forward=False):
        """Build one LQRStep around the current (x, u) and apply it.

        Without a slew-rate penalty (or when ``cost`` is a Module) the
        problem is handed to LQRStep unchanged. Otherwise the state is
        augmented with the previous control, so tau_t = (u_{t-1}, x_t, u_t),
        and C, c, F, f are extended so the quadratic cost also contains
        0.5 * slew_rate_penalty * ||u_{t-1} - u_t||^2; the u_{t-1} block is
        stripped from the returned x.

        Returns whatever LQRStep returns: forward unpacks it as
        (x, u, n_total_qp_iter, costs, full_du_norm, mean_alphas), or as
        (x, u) when no_op_forward=True.
        """
        # NOTE(review): a Module cost takes this branch even if slew_rate_penalty is set;
        # in forward that combination already exits inside approximate_cost.
        if self.slew_rate_penalty is None or isinstance(cost, Module):
            _lqr = LQRStep(n_state=self.n_state, n_ctrl=self.n_ctrl, T=self.T,
                u_lower=self.u_lower, u_upper=self.u_upper,
                u_zero_I=self.u_zero_I,
                true_cost=cost,
                true_dynamics=dynamics,
                delta_u=self.delta_u,
                linesearch_decay=self.linesearch_decay,
                max_linesearch_iter=self.max_linesearch_iter,
                delta_space=True,
                current_x=x, current_u=u,
                back_eps=self.back_eps,
                no_op_forward=no_op_forward,
            )
            e = Variable(torch.Tensor())
            return _lqr(x_init, C, c, F, f if f is not None else e)
        else:
            # Augmented problem: state (u_{t-1}, x_t) of size n_state + n_ctrl,
            # tau_t = (u_{t-1}, x_t, u_t) of size n_state + 2 * n_ctrl.
            nsc = self.n_state + self.n_ctrl
            _n_state = nsc
            _nsc = _n_state + self.n_ctrl
            n_batch = C.size(1)
            # Slew blocks on (u_{t-1}, u_t): with the 0.5 tau^T C tau convention this adds
            # 0.5 * slew_rate_penalty * ||u_{t-1} - u_t||^2.
            _C = torch.zeros(self.T, n_batch, _nsc, _nsc).type_as(C)
            half_gamI = self.slew_rate_penalty*torch.eye(
                self.n_ctrl).unsqueeze(0).unsqueeze(0).repeat(self.T, n_batch, 1, 1)
            _C[:,:,:self.n_ctrl,:self.n_ctrl] = half_gamI
            _C[:,:,-self.n_ctrl:,:self.n_ctrl] = -half_gamI
            _C[:,:,:self.n_ctrl,-self.n_ctrl:] = -half_gamI
            _C[:,:,-self.n_ctrl:,-self.n_ctrl:] = half_gamI
            slew_C = _C.clone()
            # Original cost goes in the bottom-right (x_t, u_t) block; c gets zeros for u_{t-1}.
            _C = _C + torch.nn.ZeroPad2d((self.n_ctrl, 0, self.n_ctrl, 0))(C)
            _c = torch.cat((torch.zeros(self.T, n_batch, self.n_ctrl).type_as(c),c), 2)
            # Dynamics: the new u_{t-1} component copies u_t (_F0); the new x_{t+1}
            # uses the original F on (x_t, u_t) and ignores u_{t-1} (_F1).
            _F0 = torch.cat((torch.zeros(self.n_ctrl, self.n_state+self.n_ctrl), torch.eye(self.n_ctrl),), 1).type_as(F).unsqueeze(0).unsqueeze(0).repeat(self.T-1, n_batch, 1, 1)
            _F1 = torch.cat((torch.zeros(self.T-1, n_batch, self.n_state, self.n_ctrl).type_as(F),F), 3)
            _F = torch.cat((_F0, _F1), 2)
            if f is not None:
                _f = torch.cat((torch.zeros(self.T-1, n_batch, self.n_ctrl).type_as(f),f), 2)
            else:
                _f = Variable(torch.Tensor())
            u_data = detach_maybe(u)
            # u_{-1}: prev_ctrl (promoted to 3-D) if given, otherwise zeros.
            if self.prev_ctrl is not None:
                prev_u = self.prev_ctrl
                if prev_u.ndimension() == 1:
                    prev_u = prev_u.unsqueeze(0)
                if prev_u.ndimension() == 2:
                    prev_u = prev_u.unsqueeze(0)
                prev_u = prev_u.data
            else:
                prev_u = torch.zeros(1, n_batch, self.n_ctrl).type_as(u)
            # Controls shifted by one step: utm1s[t] = u_{t-1}.
            utm1s = torch.cat((prev_u, u_data[:-1])).clone()
            _x = torch.cat((utm1s, x), 2)
            _x_init = torch.cat((Variable(prev_u[0]), x_init), 1)
            if not isinstance(dynamics, LinDx):
                _dynamics = CtrlPassthroughDynamics(dynamics)
            else:
                _dynamics = None
            if isinstance(cost, QuadCost):
                _true_cost = QuadCost(_C, _c)
            else:
                _true_cost = SlewRateCost(cost, slew_C, self.n_state, self.n_ctrl)
            _lqr = LQRStep(n_state=_n_state, n_ctrl=self.n_ctrl, T=self.T,
                u_lower=self.u_lower, u_upper=self.u_upper,
                u_zero_I=self.u_zero_I,
                true_cost=_true_cost,
                true_dynamics=_dynamics,
                delta_u=self.delta_u,
                linesearch_decay=self.linesearch_decay,
                max_linesearch_iter=self.max_linesearch_iter,
                delta_space=True,
                current_x=_x, current_u=u,
                back_eps=self.back_eps,
                no_op_forward=no_op_forward,
            )
            x, *rest = _lqr(_x_init, _C, _c, _F, _f)
            # Drop the u_{t-1} part of the augmented state.
            x = x[:,:,self.n_ctrl:]
            return [x] + rest

    def approximate_cost(self, x, u, Cf, diff=True):
        """Quadratize a (possibly non-quadratic) cost around tau = (x, u).

        For each time step the gradient g_t and Hessian H_t of Cf(tau_t) are
        computed with autograd, and the linear term is returned as
        g_t - H_t tau_t, so the model 0.5 tau^T H_t tau + (g_t - H_t tau_t)^T tau
        has Cf's gradient and curvature at tau_t.

        Returns:
            (hessians, grads, costs), each stacked over time; taken from
            ``.data`` (detached) when ``diff`` is False.

        Prints an error and calls sys.exit(-1) if a slew-rate penalty is set.
        """
        with torch.enable_grad():
            tau = torch.cat((x, u), dim=2).data
            tau = Variable(tau, requires_grad=True)
            if self.slew_rate_penalty is not None:
                print("""
MPC Error: Using a non-convex cost with a slew rate penalty is not yet implemented.
The current implementation does not correctly do a line search.
More details: https://github.com/locuslab/mpc.pytorch/issues/12
""")
                sys.exit(-1)
                # NOTE(review): the two lines below (and the slew branch in the loop) are
                # unreachable because of the sys.exit above.
                differences = tau[1:, :, -self.n_ctrl:] - tau[:-1, :, -self.n_ctrl:]
                slew_penalty = (self.slew_rate_penalty * differences.pow(2)).sum(-1)
            costs = list()
            hessians = list()
            grads = list()
            for t in range(self.T):
                tau_t = tau[t]
                if self.slew_rate_penalty is not None:
                    cost = Cf(tau_t) + (slew_penalty[t-1] if t > 0 else 0)
                else:
                    cost = Cf(tau_t)
                # create_graph=True so the gradient can be differentiated again for the Hessian.
                grad = torch.autograd.grad(cost.sum(), tau_t, create_graph=True, retain_graph=True)[0]
                hessian = list()
                for v_i in range(tau.shape[2]):
                    hessian.append(torch.autograd.grad(grad[:, v_i].sum(), tau_t, retain_graph=True)[0])
                hessian = torch.stack(hessian, dim=-1)
                costs.append(cost)
                grads.append(grad - bmv(hessian, tau_t))
                hessians.append(hessian)
            costs = torch.stack(costs, dim=0)
            grads = torch.stack(grads, dim=0)
            hessians = torch.stack(hessians, dim=0)
            if not diff:
                return hessians.data, grads.data, costs.data
            return hessians, grads, costs

    # \@profile
    def linearize_dynamics(self, x, u, dynamics, diff):
        """Linearize the dynamics along a trajectory.

        Returns (F, f) with x_{t+1} ~= F_t [x_t; u_t] + f_t for the first T-1
        steps: F is [T-1, n_batch, n_state, n_state + n_ctrl] and f is
        [T-1, n_batch, n_state]. With ``diff=False`` they are computed
        from ``.data`` tensors and wrapped in fresh Variables.

        ANALYTIC evaluates ``dynamics`` once on all (x_t, u_t) pairs and uses
        ``dynamics.grad_input``. The other methods rebuild the states from
        x[0] by re-simulating ``dynamics`` with the given controls and
        differentiate step by step.
        """
        # TODO: Cleanup variable usage.
        n_batch = x[0].size(0)
        if self.grad_method == GradMethods.ANALYTIC:
            _u = Variable(u[:-1].view(-1, self.n_ctrl), requires_grad=True)
            _x = Variable(x[:-1].contiguous().view(-1, self.n_state), requires_grad=True)
            # This inefficiently calls dynamics again, but is worth it because
            # we can efficiently compute grad_input for every time step at once.
            _new_x = dynamics(_x, _u)
            # This check is a little expensive and should only be done if modifying this code.
            # assert torch.abs(_new_x.data - torch.cat(x[1:])).max() <= 1e-6
            if not diff:
                _new_x = _new_x.data
                _x = _x.data
                _u = _u.data
            R, S = dynamics.grad_input(_x, _u)
            # Affine offset so that new_x = R x + S u + f at the linearization point.
            f = _new_x - bmv(R, _x) - bmv(S, _u)
            f = f.view(self.T-1, n_batch, self.n_state)
            R = R.contiguous().view(self.T-1, n_batch, self.n_state, self.n_state)
            S = S.contiguous().view(self.T-1, n_batch, self.n_state, self.n_ctrl)
            F = torch.cat((R, S), 3)
            if not diff:
                F, f = list(map(Variable, [F, f]))
            return F, f
        else:
            # TODO: This is inefficient and confusing.
            # Only x[0] is reused; later states are re-simulated below with `dynamics`.
            x_init = x[0]
            x = [x_init]
            F, f = [], []
            for t in range(self.T):
                if t < self.T-1:
                    xt = Variable(x[t], requires_grad=True)
                    ut = Variable(u[t], requires_grad=True)
                    xut = torch.cat((xt, ut), 1)
                    new_x = dynamics(xt, ut)
                    # Linear dynamics approximation.
                    if self.grad_method in [GradMethods.AUTO_DIFF, GradMethods.ANALYTIC_CHECK]:
                        # One backward pass per output state dimension j gives row j of R_t and S_t.
                        Rt, St = [], []
                        for j in range(self.n_state):
                            Rj, Sj = torch.autograd.grad(
                                new_x[:,j].sum(), [xt, ut],
                                retain_graph=True)
                            if not diff:
                                Rj, Sj = Rj.data, Sj.data
                            Rt.append(Rj)
                            St.append(Sj)
                        Rt = torch.stack(Rt, dim=1)
                        St = torch.stack(St, dim=1)
                        if self.grad_method == GradMethods.ANALYTIC_CHECK:
                            # NOTE(review): everything after this assert is unreachable;
                            # ANALYTIC_CHECK is effectively disabled.
                            assert False # Not updated
                            Rt_autograd, St_autograd = Rt, St
                            Rt, St = dynamics.grad_input(xt, ut)
                            eps = 1e-8
                            if torch.max(torch.abs(Rt-Rt_autograd)).data[0] > eps or \
                            torch.max(torch.abs(St-St_autograd)).data[0] > eps:
                                print('''nmpc.ANALYTIC_CHECK error: The analytic derivative of the dynamics function may be off.''')
                            else:
                                print('''nmpc.ANALYTIC_CHECK: The analytic derivative of the dynamics function seems correct.
        Re-run with GradMethods.ANALYTIC to continue.''')
                            sys.exit(0)
                    elif self.grad_method == GradMethods.FINITE_DIFF:
                        # Central differences (step 1e-4), one batch element at a time.
                        Rt, St = [], []
                        for i in range(n_batch):
                            Ri = jacobian(lambda s: dynamics(s, ut[i]), xt[i], 1e-4)
                            Si = jacobian(lambda a : dynamics(xt[i], a), ut[i], 1e-4)
                            if not diff:
                                Ri, Si = Ri.data, Si.data
                            Rt.append(Ri)
                            St.append(Si)
                        Rt = torch.stack(Rt)
                        St = torch.stack(St)
                    else:
                        assert False
                    Ft = torch.cat((Rt, St), 2)
                    F.append(Ft)
                    if not diff:
                        xt, ut, new_x = xt.data, ut.data, new_x.data
                    ft = new_x - bmv(Rt, xt) - bmv(St, ut)
                    f.append(ft)
                if t < self.T-1:
                    x.append(detach_maybe(new_x))
            F = torch.stack(F, 0)
            f = torch.stack(f, 0)
            if not diff:
                F, f = list(map(Variable, [F, f]))
            return F, f



In [None]:
# @title pendulum
# https://github.com/locuslab/mpc.pytorch/blob/master/mpc/env_dx/pendulum.py
import torch
from torch.autograd import Function, Variable
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
import numpy as np
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

class PendulumDx(nn.Module):
    """Differentiable pendulum dynamics (ported from mpc.pytorch env_dx/pendulum.py).

    The state is ``(cos(th), sin(th), dth)`` (n_state = 3) and the control is
    a single torque (n_ctrl = 1), clamped to ``[-max_torque, max_torque]``
    inside ``forward``. One call advances the system by ``dt``.

    Args:
        params: 1-D tensor of physical parameters: ``(g, m, l)`` when
            ``simple`` is True, ``(g, m, l, d, b)`` otherwise. Defaults to
            ``(10, 1, 1)`` / ``(10, 1, 1, 0, 0)``.
        simple: Use the 3-parameter model (no damping / gravity bias).

    Raises:
        AssertionError: if ``params`` does not have 3 (simple) or 5 entries.
    """

    def __init__(self, params=None, simple=True):
        super().__init__()
        self.simple = simple
        self.max_torque = 2.0
        self.dt = 0.05
        self.n_state = 3
        self.n_ctrl = 1
        if params is None:
            if simple:
                # gravity (g), mass (m), length (l)
                self.params = Variable(torch.Tensor((10., 1., 1.)))
            else:
                # gravity (g), mass (m), length (l), damping (d), gravity bias (b)
                self.params = Variable(torch.Tensor((10., 1., 1., 0., 0.)))
        else:
            self.params = params
        # Parenthesised on purpose: `len(...) == 3 if simple else 5` parses as
        # `(len(...) == 3) if simple else 5`, which never fails when simple=False.
        assert len(self.params) == (3 if simple else 5)
        self.goal_state = torch.Tensor([1., 0., 0.])
        self.goal_weights = torch.Tensor([1., 1., 0.1])
        self.ctrl_penalty = 0.001
        # self.lower, self.upper = -2., 2.
        # self.mpc_eps = 1e-3
        # self.linesearch_decay = 0.2
        # self.max_linesearch_iter = 5

    def forward(self, x, u):
        """Advance the pendulum one step.

        Args:
            x: States [n_batch, 3] (or a single state [3]).
            u: Torques [n_batch, 1] (or a single control [1]).

        Returns:
            The next states, with the same leading shape as ``x``.
        """
        squeeze = x.ndimension() == 1
        if squeeze:
            x = x.unsqueeze(0)
            u = u.unsqueeze(0)
        assert x.ndimension() == 2
        assert x.shape[0] == u.shape[0]
        assert x.shape[1] == 3
        assert u.shape[1] == 1
        assert u.ndimension() == 2
        if x.is_cuda and not self.params.is_cuda:
            self.params = self.params.cuda()
        # hasattr guard keeps objects created before `simple` existed working.
        if not hasattr(self, 'simple') or self.simple:
            g, m, l = torch.unbind(self.params)
        else:
            g, m, l, d, b = torch.unbind(self.params)
        u = torch.clamp(u, -self.max_torque, self.max_torque)[:,0]
        cos_th, sin_th, dth = torch.unbind(x, dim=1)
        th = torch.atan2(sin_th, cos_th)
        if not hasattr(self, 'simple') or self.simple:
            newdth = dth + self.dt*(-3.*g/(2.*l) * (-sin_th) + 3. * u / (m*l**2))
        else:
            sin_th_bias = torch.sin(th + b)
            # NOTE(review): the damping term uses the angle `th`, as upstream does;
            # a velocity damping would use `dth` - confirm intent.
            newdth = dth + self.dt*(-3.*g/(2.*l)*(-sin_th_bias) + 3.*u/(m*l**2) - d*th)
        # Semi-implicit Euler: the angle is advanced with the updated velocity.
        newth = th + newdth*self.dt
        state = torch.stack((torch.cos(newth), torch.sin(newth), newdth), dim=1)
        if squeeze:
            state = state.squeeze(0)
        return state

    def get_frame(self, x, ax=None):
        """Draw the pendulum rod for a single state ``x`` (3 values).

        Returns:
            (fig, ax); a new figure is created when ``ax`` is None.
        """
        x = get_data_maybe(x.view(-1))
        assert len(x) == 3
        l = self.params[2].item()
        cos_th, sin_th, _ = torch.unbind(x)
        # Rod tip, with th measured from the vertical axis.
        tip_x = sin_th*l
        tip_y = cos_th*l
        if ax is None:
            fig, ax = plt.subplots(figsize=(6,6))
        else:
            fig = ax.get_figure()
        ax.plot((0, tip_x), (0, tip_y), color='k')
        ax.set_xlim((-l*1.2, l*1.2))
        ax.set_ylim((-l*1.2, l*1.2))
        return fig, ax

    # def get_true_obj(self): #cost terms for the swingup
    #     q = torch.cat((self.goal_weights, self.ctrl_penalty*torch.ones(self.n_ctrl)))
    #     assert not hasattr(self, 'mpc_lin')
    #     px = -torch.sqrt(self.goal_weights)*self.goal_state #+ self.mpc_lin
    #     p = torch.cat((px, torch.zeros(self.n_ctrl)))
    #     return Variable(q), Variable(p)



### wwwwwww

In [None]:
# @title setup
# !pip install mpc
# https://locuslab.github.io/mpc.pytorch/
# https://github.com/locuslab/mpc.pytorch/tree/master/examples
# https://colab.research.google.com/github/locuslab/mpc.pytorch/blob/master/examples/Pendulum%20Control.ipynb
import torch
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import os
import io
import base64
import tempfile
from IPython.display import HTML
from tqdm import tqdm

%matplotlib inline

In [None]:
# @title pytorch dx
# again

# Problem sizes for the toy learned-dynamics experiment (an earlier run used n_ctrl = 3).
n_state, n_ctrl = 2, 4
# Seed torch's RNG before the model below is constructed.
torch.manual_seed(1)
class model(nn.Module):
    """Toy learned dynamics: ``x_{t+1} = x_t + 0.5 * MLP(u_t)``.

    The MLP only sees the control; the state enters through the residual
    connection alone.

    Args:
        state_dim: State dimension; defaults to the notebook-level ``n_state``.
        ctrl_dim: Control dimension; defaults to the notebook-level ``n_ctrl``.
    """

    def __init__(self, state_dim=None, ctrl_dim=None):
        super().__init__()
        # The globals are only read when no explicit size is given, so `model()` is unchanged.
        state_dim = n_state if state_dim is None else state_dim
        ctrl_dim = n_ctrl if ctrl_dim is None else ctrl_dim
        # Created before the layers so the seeded RNG stream (and the Linear weights) is unchanged.
        # NOTE(review): shape (1, 2) only fits the first Linear layer when ctrl_dim == 2, so
        # forward(x) without `u` fails for the current n_ctrl = 4.
        self.param = nn.Parameter(torch.rand(1, 2), requires_grad=True)
        self.lst = [7, 10, 7]  # hidden widths; only lst[0] and lst[-1] are used
        self.lin = nn.Sequential(
            nn.Linear(ctrl_dim, self.lst[0]), nn.Tanh(),
            nn.Linear(self.lst[0], self.lst[-1]), nn.Tanh(),
            nn.Linear(self.lst[-1], state_dim),
        )
        self.n_state = state_dim
        self.n_ctrl = ctrl_dim

    def forward(self, x, u=None):
        """Return the next state for state ``x`` and control ``u``.

        ``u`` defaults to the learnable ``self.param``.
        """
        if u is None:
            u = self.param
        return x + 0.5 * self.lin(u)

dx=model()

# Receding-horizon (closed-loop) control of the learned dynamics `dx`.
# n_batch: batch size, T: number of closed-loop steps, mpc_T: MPC planning horizon
# (an earlier configuration used 16, 100, 20).
n_batch, T, mpc_T = 1, 50, 45
xinit = torch.tensor([[-0.20]*n_state]) # MPC asserts x_init is 2-D with size(0) == n_batch
x = xinit # current state
u_init = None # warm-start control sequence; None makes MPC start from zeros

# Quadratic tracking cost 0.5 * tau^T Q tau + p^T tau over tau = (x, u).
goal_weights = torch.Tensor([1.0]*n_state)
goal_state = torch.Tensor([-0.50]*n_state)
ctrl_penalty = 0.001
q = torch.cat((goal_weights, ctrl_penalty*torch.ones(n_ctrl))) # diagonal of Q
Q = torch.diag(q).unsqueeze(0).unsqueeze(0).repeat(mpc_T, n_batch, 1, 1) # [mpc_T, n_batch, n_tau, n_tau]
# With 0.5 * w * x^2 + px * x the state term is minimised at x = -px / w, so
# px = -w * goal puts the minimum exactly at the goal (-sqrt(w) * goal would
# land at goal / sqrt(w); the two agree only for unit weights).
px = -goal_weights*goal_state
p = torch.cat((px, torch.zeros(n_ctrl)))
p = p.unsqueeze(0).repeat(mpc_T, n_batch, 1) # [mpc_T, n_batch, n_tau]


for t in range(T):
    # A fresh solver per step, because the warm start u_init is a constructor argument.
    nominal_states, nominal_actions, nominal_objs = MPC(
        n_state, n_ctrl, mpc_T, # state dim, action dim, horizon
        u_init=u_init,
        u_lower=-1., u_upper=1.,
        lqr_iter=50, # number of LQR iterations to perform
        verbose=0,
        exit_unconverged=False,
        detach_unconverged=False,
        linesearch_decay=0.2,
        max_linesearch_iter=5,
        grad_method=GradMethods.AUTO_DIFF,
        eps=1e-2,
    )(x, QuadCost(Q, p), dx)

    # Apply only the first planned action; shift the remaining plan to warm-start the next solve.
    next_action = nominal_actions[0]
    u_init = torch.cat((nominal_actions[1:], torch.zeros(1, n_batch, n_ctrl)), dim=0)
    print(t, x.detach().numpy(), next_action.detach().numpy())
    # Detach so the autograd graph does not keep growing across closed-loop steps.
    x = dx(x, next_action).detach()
    # Expand the target to x's shape: same value, without MSELoss's size-mismatch warning.
    error = nn.MSELoss()(x, goal_state.expand_as(x))
    print("error", error.item())


0 [[-0.2 -0.2]] [[-0.17509563 -0.28452164  1.         -1.        ]]
1 [[-0.23185    -0.23529498]] [[-0.18100911 -0.30179918  1.         -1.        ]]
2 [[-0.26399294 -0.27027354]] [[-0.18096578 -0.31398958  1.         -1.        ]]
3 [[-0.29629573 -0.3050775 ]] [[-0.1974833 -0.336871   1.        -1.       ]]
4 [[-0.32907286 -0.33935136]] [[-0.20583177 -0.35086885  1.         -1.        ]]
5 [[-0.36211836 -0.37331837]] [[-0.22839351 -0.3674021   1.         -1.        ]]
6 [[-0.3956101 -0.4067604]] [[-0.2702965  -0.39168274  1.         -1.        ]]
7 [[-0.42983505 -0.43929994]] [[-0.31932482 -0.41390368  1.         -1.        ]]
8 [[-0.46482164 -0.47084403]] [[-0.47291186 -0.5129598   0.93203855 -0.65716136]]
9 [[-0.49130416 -0.49320322]] [[-0.63244087 -0.61670333  0.5832975  -0.34502947]]
10 [[-0.49831346 -0.49866238]] [[-0.6639701  -0.6327156   0.49365735 -0.26996472]]
11 [[-0.4996941  -0.49974266]] [[-0.6700189  -0.6355926   0.47605738 -0.2552817 ]]
12 [[-0.49994403 -0.49995053]] [[-

In [None]:
# Controllability check: from the origin apply one all-ones control step,
# then coast with zero control and watch where the state drifts.
torch.manual_seed(1)
dx=model()
x = torch.zeros(n_state)
u = torch.ones(n_ctrl)
for t in range(50):
    st = dx(x, u)
    # Log the state *before* the step together with the control applied to it.
    print(x.detach().numpy(), u.detach().numpy())
    x, u = st, torch.zeros(n_ctrl)


In [None]:
# def optimal_control(dx, init_st, goal_st):


In [None]:
# @title sentencepiece tokenizer
# IPython shell escapes (Colab/Jupyter only): install the library and clone the
# repository that holds the pre-trained CC100 SentencePiece models.
!pip install sentencepiece
# https://github.com/kutvonenaki/cc100-sentencepiece
!git clone https://github.com/kutvonenaki/cc100-sentencepiece.git

import sentencepiece as spm

# Files used from the cloned repo: cc100_en_vocab_8000.model / cc100_en_vocab_8000.vocab
modelpath="/content/cc100-sentencepiece/trained_tokenizers/cc100_en_vocab_8000.model"
sp = spm.SentencePieceProcessor(model_file=modelpath)
# text="yes tokenizers would be peferable, over traditional cat! dog?"
# encoded = sp.encode(text)
# print(encoded)
# decoded = sp.decode(encoded)
# print(decoded)

# enpieces = sp.encode_as_pieces(text)
# print(enpieces)
# decpieces = sp.decode_pieces(enpieces)
# print(decpieces)
# print(sp.get_piece_size())

# https://pytorch.org/text/stable/data_functional.html#torchtext.data.functional.sentencepiece_tokenizer
from torchtext.transforms import SentencePieceTokenizer
sp_tokenizer = SentencePieceTokenizer(modelpath)
# NOTE(review): not the cell's last expression, so this result is discarded and nothing is shown.
sp_tokenizer(["hello world", "attention is all you need!"])

# NOTE(review): `loader` is not defined anywhere above this cell, so this loop raises
# NameError unless it is defined elsewhere first - TODO confirm. The loop variable also
# rebinds the global `x` used as the state in the MPC cells.
for x in range(50):
    char=next(loader)
    print(char, end =" ")


## JEPA

In [None]:
# @title setup
# https://openreview.net/pdf?id=BZ5a1r-kVsf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import collections
device = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
# @title data

# data = open('input.txt', 'r').read()
# data = list(data)

# https://edisciplinas.usp.br/pluginfile.php/3403095/mod_resource/content/1/56ViktorFrankl_Mans%20Search.pdf
text='''
Only slowly could these men be guided back to the commonplace truth that no one has the right to do wrong, not
even if wrong has been done to them. We had to strive to
lead them back to this truth, or the consequences would
have been much worse than the loss of a few thousand stalks
of oats. I can still see the prisoner who rolled up his shirt
sleeves, thrust his right hand under my nose and shouted,
"May this hand be cut off if I don't stain it with blood
on the day when I get home!" I want to emphasize that the
man who said these words was not a bad fellow. He had
been the best of comrades in camp and afterwards.
Apart from the moral deformity resulting from the sudden release of mental pressure, there were two other
fundamental experiences which threatened to damage the
character of the liberated prisoner: bitterness and disillusionment when he returned to his former life.
Bitterness was caused by a number of things he came up
against in his former home town. When, on his return, a
man found that in many places he was met only with a
shrug of the shoulders and with hackneyed phrases, he
tended to become bitter and to ask himself why he had
gone through all that he had. When he heard the same
phrases nearly everywhere—"We did not know about it,"
and "We, too, have suffered," then he asked himself, have
they really nothing better to say to me?
The experience of disillusionment is different. Here it
was not one's fellow man (whose superficiality and lack of
feeling was so disgusting that one finally felt like creeping
into a hole and neither hearing nor seeing human beings
any more) but fate itself which seemed so cruel. A man who
Experiences in a Concentration Camp 99
for years had thought he had reached the absolute limit of
all possible suffering now found that suffering has no limits,
and that he could suffer still more, and still more intensely.
When we spoke about attempts to give a man in camp
mental courage, we said that he had to be shown something
to look forward to in the future. He had to be reminded
that life still waited for him, that a human being waited for
his return. But after liberation? There were some men who
found that no one awaited them. Woe to him who found
that the person whose memory alone had given him courage
in camp did not exist any more! Woe to him who, when the
day of his dreams finally came, found it so different from all
he had longed for! Perhaps he boarded a trolley, traveled
out to the home which he had seen for years in his mind,
and only in his mind, and pressed the bell, just as he has
longed to do in thousands of dreams, only to find that the
person who should open the door was not there, and would
never be there again.
'''

# # make dataset
# text=text.replace('\n','')
# data=list(text)
# # print(data)
# # data=sorted([ord(x) for x in set(data)])
# dataset=[ord(x)-31 for x in data]
# # dataset = map(lambda x:x.strip("8"), lst)
# # [f(x) if condition else g(x) for x in sequence]
# dataset=[0 if x>91 else x for x in dataset] # 0 unk, 1-91
# vocab_size=92
# print(dataset)
# ASCII reference: 32 = space; 65-90 = A-Z; 97-122 = a-z


# Round-trip the corpus through the sentencepiece model `sp` loaded in the tokenizer cell
# and print both the piece ids and the decoded text.
encoded = sp.encode(text)
print(encoded)
decoded = sp.decode(encoded)
print(decoded)


from torch.utils.data import Dataset
class Datasetme(Dataset): #https://www.kaggle.com/code/pinocookie/pytorch-dataset-and-dataloader/notebook
    """Character-level dataset whose items are learned character embeddings.

    Every character of ``text`` (newlines removed) is mapped by :meth:`encode` to an
    integer id in ``[0, 92)`` (0 = unknown), and ``dataset[i]`` returns
    ``self.embed(id)``, a tensor of shape ``(embed_size,)``.
    """
    def __init__(self, text, embed_size=2):
        super().__init__()
        # Newlines are deleted outright (not replaced by a space).
        text=text.replace('\n','')
        data=list(text)
        self.dataset=[self.encode(x) for x in data] # 0 unk, 1-91
        self.vocab_size=vocab_size=92
        self.embed_size=embed_size
        self.batch_size = 1
        self.embed=nn.Embedding(vocab_size, embed_size)

    def __len__(self):
        # One less than the number of characters, so samplers never serve the last one.
        # NOTE(review): looks like a leftover from (char, next_char) pair items — confirm.
        return len(self.dataset)-1

    def __getitem__(self, index):
        return self.embeder(self.dataset[index])

    def embeder(self, x):
        """Look up the embedding vector for integer id ``x``."""
        return self.embed(torch.tensor(x))

    def encode(self, char):
        """Map a character to an id in [0, 91]; only ' '..'z' get non-zero ids."""
        x = ord(char)-31
        # Characters above 'z' and control characters (ord < 31, which used to give
        # negative ids that nn.Embedding rejects) both map to 0 = unknown.
        if x>91 or x<0: x=0
        return x

    def decode(self, x):
        """Inverse of encode for ids 1-91 (id 0 decodes to chr(31), not a real character)."""
        return chr(x+31)
# Embedding-valued character dataset over the corpus. NOTE: both `Datasetme` and `dataset`
# are rebound by the next definition below, so later cells see the token-id variant.
dataset=Datasetme(text)


from torch.utils.data import Dataset
class Datasetme(Dataset):
    """Character-level dataset whose items are integer token ids.

    Newlines are replaced by spaces, then :meth:`token` maps each character to an id in
    ``[0, 92)``: ' '..'z' become 1..91 and every other character becomes 0 (unknown).
    ``embed_size`` is accepted for signature compatibility but is unused.
    """
    def __init__(self, text, embed_size=2):
        super().__init__()
        self.dataset=self.tokenize(text.replace('\n',' '))
        self.vocab_size=92

    def __len__(self):
        # One less than the number of tokens, so samplers never serve the last one.
        # NOTE(review): probably a leftover from (token, next_token) pair items — confirm.
        return len(self.dataset)-1

    def __getitem__(self, index):
        return self.dataset[index]

    def tokenize(self, string):
        """Map each character of ``string`` to its token id (0 unk, 1-91)."""
        return [self.token(x) for x in string]

    def token(self, char):
        """Map one character to an id in [0, 91] (0 = unknown)."""
        x = ord(char)-31
        # Control characters (ord < 31) used to produce negative ids outside the vocabulary.
        if x>91 or x<0: x=0
        return x

    def detoken(self, x):
        """Inverse of token for ids 1-91."""
        return chr(x+31)

    def detokenize(self, string):
        """Map a sequence of ids back to characters; returns a list, not a str."""
        return [self.detoken(x) for x in string]
dataset=Datasetme(text)

# Plain iterator over the dataset: Python's iter() walks __getitem__(0), __getitem__(1), ...
loader=iter(dataset)

# Quick round-trip check of the character tokenizer. It uses its own variable so that the
# corpus held in `text` is not overwritten (re-running this cell used to rebuild `dataset`
# from the demo sentence).
sample_text="yes tokenizers would be peferable over traditional cat dog"
encoded = dataset.tokenize(sample_text)
print(encoded)
decoded = dataset.detokenize(encoded)
print(decoded)



In [None]:
# @title jepa
def off_diagonal(x):
    """Return the off-diagonal entries of the square matrix ``x`` as a 1-D tensor (row-major)."""
    n, m = x.shape
    assert n == m
    # Dropping the last element and viewing as (n-1, n+1) lines every diagonal entry up in
    # column 0, which is then sliced away.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


class JEPA(nn.Module):
    """Joint-Embedding Predictive Architecture trained with a VICReg regulariser.

    Components:
        enc_x / enc_y: Linear+Tanh encoders for the context input x and the target input y
            (``loss`` currently bypasses them and treats its inputs as representations).
        pred: MLP mapping ``cat([sx, z])`` (dim_sx + dim_z features) to a predicted state (dim_sy).
        exp_x / exp_y: linear expanders to ``dim_v`` features, used only by ``vicreg``.
        g_st: learnable input of shape (1, in_dimx), initialised with ``torch.rand``.
    """

    def __init__(self, in_dimx, in_dimy, dim_sx, dim_sy, dim_z, dim_v):
        super(JEPA, self).__init__()
        self.enc_x = nn.Sequential(
            nn.Linear(in_dimx, dim_sx),
            nn.Tanh(),
            )
        # enc_y produces sy, which exp_y and the prediction target expect to be dim_sy wide
        # (it used to be built with dim_sx, which only worked while dim_sx == dim_sy).
        self.enc_y = nn.Sequential(
            nn.Linear(in_dimy, dim_sy),
            nn.Tanh(),
            )
        self.pred = nn.Sequential(
            nn.Linear(dim_sx + dim_z, dim_sy),
            nn.ReLU(),
            nn.Linear(dim_sy, dim_sy),
            )
        self.exp_x = nn.Sequential(
            nn.Linear(dim_sx, dim_v),
            )
        self.exp_y = nn.Sequential(
            nn.Linear(dim_sy, dim_v),
            )
        self.dim_z = dim_z
        self.g_st = nn.Parameter(torch.rand(1,in_dimx), requires_grad=True)

    # https://arxiv.org/pdf/2105.04906.pdf
    def vicreg(self, x, y): # https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
        """VICReg loss between two batches of embeddings of shape (batch, features).

        1-D inputs are treated as (batch, 1). Needs batch >= 2, because variance and
        covariance divide by ``batch - 1``. Prints the three weighted terms for debugging.
        """
        sim_coeff = 25.0  # λ
        std_coeff = 25.0  # µ
        cov_coeff = 1.0   # ν
        if x.dim() == 1:
            x = x.view(-1, 1)
        if y.dim() == 1:
            y = y.view(-1, 1)
        batch_size, num_features = x.shape

        # invariance loss s(Z, Z')
        repr_loss = F.mse_loss(x, y)
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # variance loss
        std_x = torch.sqrt(x.var(dim=0) + 0.0001) #ϵ=0.0001
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        # covariance loss: C(Z) is the (features, features) covariance. Transposing x first
        # would give a (batch, batch) Gram matrix instead.
        cov_x = (x.T @ x) / (batch_size - 1)
        cov_y = (y.T @ y) / (batch_size - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(num_features) \
            + off_diagonal(cov_y).pow_(2).sum().div(num_features) #c(Z)
        loss = (sim_coeff * repr_loss + std_coeff * std_loss + cov_coeff * cov_loss)
        print("in vicreg ",sim_coeff * repr_loss , std_coeff * std_loss , cov_coeff * cov_loss)
        return loss

    # https://stackoverflow.com/questions/67328098/how-to-find-input-that-maximizes-output-of-a-neural-network-using-pytorch
    def argm(self, sx, sy, dim_z):
        """Search for the latent z that best explains sy from sx under the current predictor.

        Runs 100 SGD steps (lr 0.1) on a random z of shape (batch, dim_z), minimising
        MSE(pred(cat([sx, z])), sy). The predictor's weights are frozen during the search,
        and their previous requires_grad flags are restored afterwards. Returns z detached.
        """
        batch = sx.size(dim=0)
        sx = sx.detach()
        sy = sy.detach()
        # Create z as a leaf directly on sx's device so that the optimizer updates exactly the
        # tensor used in the forward pass. Calling `.to(device)` after building the optimizer
        # produced a copy that never changed.
        z = torch.rand(batch, dim_z, device=sx.device, dtype=sx.dtype, requires_grad=True)
        optim = torch.optim.SGD([z], lr=1e-1)
        lossfn = torch.nn.MSELoss()
        num_steps = 100
        prev_flags = [p.requires_grad for p in self.pred.parameters()]
        self.pred.requires_grad_(False)
        try:
            with torch.enable_grad():  # still works when called under torch.no_grad()
                for _ in range(num_steps):
                    sy_ = self.pred(torch.cat([sx, z], dim=-1))
                    loss = lossfn(sy_, sy)
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
        finally:
            # Unfreeze the predictor again; otherwise it never trains after the first call.
            for p, flag in zip(self.pred.parameters(), prev_flags):
                p.requires_grad_(flag)
        return z.detach()

    def loss(self, x, y):
        """Prediction MSE at the best-fitting latent, plus the VICReg loss on expanded states.

        The encoders are currently bypassed: x and y are used directly as sx and sy.
        """
        sx, sy = x, y
        z = self.argm(sx, sy, self.dim_z)
        sxz = torch.cat([sx, z], dim=-1)
        sy_ = self.pred(sxz)
        mseloss = nn.MSELoss()(sy, sy_)
        vx = self.exp_x(sx)
        vy = self.exp_y(sy)
        vicloss = self.vicreg(vx, vy)
        return mseloss + vicloss

    def forward(self, sx, a):
        """Predict the next state from state sx and action/latent a (concatenated on the last dim)."""
        print("in fwd",sx.shape,a.shape)
        sxz = torch.cat([sx, a], dim=-1)
        sy_ = self.pred(sxz)
        return sy_

    def encode(self, x):
        """Encode a raw input x with enc_x."""
        return self.enc_x(x)

# Sizes are hard-coded here (the current dataset exposes vocab_size=92; embed size is 2).
vocab_size, embed_size = 92, 2

percept_size = 64  # dim of raw data from perceiving world
encode_size = 64   # dim of inner model of world state
in_dimx = in_dimy = percept_size
dim_sx = dim_sy = encode_size
dim_z = embed_size  # dim of action
dim_v = 5           # expanded dim for vicreg regularisation

model = JEPA(in_dimx, in_dimy, dim_sx, dim_sy, dim_z, dim_v).to(device)

# Smoke test: evaluate the loss once on a random batch of 10.
batch = 10
x = torch.rand(batch, in_dimx).to(device)
y = torch.rand(batch, in_dimy).to(device)
loss = model.loss(x, y)



tensor([0.0326, 0.0571, 0.0159, 0.0333, 0.0406], device='cuda:0',
       grad_fn=<VarBackward0>) tensor([0.0272, 0.0178, 0.0244, 0.0065, 0.0253], device='cuda:0',
       grad_fn=<VarBackward0>)
tensor([0.1809, 0.2391, 0.1265, 0.1828, 0.2017], device='cuda:0',
       grad_fn=<SqrtBackward0>) tensor([0.1652, 0.1339, 0.1564, 0.0812, 0.1592], device='cuda:0',
       grad_fn=<SqrtBackward0>)
in vicreg tensor(0.0853, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(0.8373, device='cuda:0', grad_fn=<AddBackward0>) tensor(0.0596, device='cuda:0', grad_fn=<AddBackward0>)


#### train / eval

In [None]:


# AdamW over all JEPA parameters (g_st is an nn.Parameter, so it is included).
# NOTE(review): `lr` and `betas` are not assigned in the visible code; the only definitions
# (lr=1e-4, betas=(0.9, 0.95)) are commented out in the archived train/eval cell. This line
# therefore raises NameError unless an earlier cell sets them.
optimizer = torch.optim.AdamW(model.parameters(), lr = lr, betas = betas)
# optimizer = torch.optim.AdamW(model.parameters, lr = lr, betas = betas)

def loss_fn(logits, y):
    """Mean cross-entropy over all positions: classes are on the last dim of ``logits``."""
    num_classes = logits.size(-1)
    flat_logits = logits.view(-1, num_classes)
    flat_targets = y.view(-1)
    return F.cross_entropy(flat_logits, flat_targets)



In [None]:
batch_size=10
from torch.utils.data.dataloader import DataLoader
# train_loader = DataLoader(train_dataset, shuffle = True, pin_memory = True, batch_size = batch_size, num_workers = 4)
# test_loader = DataLoader(test_dataset, shuffle = True, pin_memory = True, batch_size = batch_size, num_workers = 0)
# Sequential (unshuffled) batches of 10 items from `dataset`. Datasetme.__len__ returns
# len - 1, so the last character is never included.
loader = DataLoader(dataset, shuffle = False, batch_size = batch_size, num_workers = 0)

epochs=1
for epoch in range(epochs):
    # run_epoch(train_loader)
    # train(loader, model, loss_fn, optimizer)
    # NOTE(review): no `eval` function is defined in the visible code (the archived one is
    # commented out), so this resolves to Python's builtin eval(), which rejects these
    # arguments. The archived version also unpacks (x, y) pairs, while this loader yields
    # single batches.
    test_loss = eval(loader, model, loss_fn)
    print('Test Loss:', test_loss)


in eval torch.Size([1, 64]) torch.Size([10, 2])
in fwd torch.Size([1, 64]) torch.Size([10, 2])


RuntimeError: ignored

In [None]:
def dembed(embedding, ds=None):
    """Map an embedding vector back to the character whose embedding row is nearest.

    Args:
        embedding: vector comparable with the rows of ``ds.embed.weight`` (length embed_size);
            tensors on any device, or plain sequences, are accepted.
        ds: object exposing ``embed`` (an ``nn.Embedding``) and ``decode(id) -> str``.
            Defaults to the notebook-level ``dataset``.
    Returns:
        ``ds.decode(i)`` for the row ``i`` with the smallest Euclidean distance.
    """
    if ds is None:
        ds = dataset
    table = ds.embed.weight.detach()
    # Move the query onto the table's device/dtype, since model outputs may live on the GPU.
    query = torch.as_tensor(embedding).detach().to(device=table.device, dtype=table.dtype)
    nearest = torch.argmin(torch.norm(table - query, dim=1)).item()
    return ds.decode(nearest)

# Decode every batch from `loader` back to characters by nearest-embedding lookup.
# NOTE(review): dembed expects embedding vectors (rows of dataset.embed), i.e. items from the
# first Datasetme. The Datasetme bound to `dataset` last in the visible code yields integer
# token ids and has no `embed`/`decode`. The recorded output was presumably produced with
# the first variant.
pbar = enumerate(loader)
x = model.g_st.to(device)  # not used by the loop below (the comprehension's `x` is its own)
for it, a in pbar:
    # print(it, a)
    print(it, [dembed(x) for x in a])

0 ['O', 'n', 'l', 'y', ' ', 's', 'l', 'o', 'w', 'l']
1 ['y', ' ', 'c', 'o', 'u', 'l', 'd', ' ', 't', 'h']
2 ['e', 's', 'e', ' ', 'm', 'e', 'n', ' ', 'b', 'e']
3 [' ', 'g', 'u', 'i', 'd', 'e', 'd', ' ', 'b', 'a']
4 ['c', 'k', ' ', 't', 'o', ' ', 't', 'h', 'e', ' ']
5 ['c', 'o', 'm', 'm', 'o', 'n', 'p', 'l', 'a', 'c']
6 ['e', ' ', 't', 'r', 'u', 't', 'h', ' ', 't', 'h']
7 ['a', 't', ' ', 'n', 'o', ' ', 'o', 'n', 'e', ' ']
8 ['h', 'a', 's', ' ', 't', 'h', 'e', ' ', 'r', 'i']
9 ['g', 'h', 't', ' ', 't', 'o', ' ', 'd', 'o', ' ']
10 ['w', 'r', 'o', 'n', 'g', ',', ' ', 'n', 'o', 't']
11 ['e', 'v', 'e', 'n', ' ', 'i', 'f', ' ', 'w', 'r']
12 ['o', 'n', 'g', ' ', 'h', 'a', 's', ' ', 'b', 'e']
13 ['e', 'n', ' ', 'd', 'o', 'n', 'e', ' ', 't', 'o']
14 [' ', 't', 'h', 'e', 'm', '.', ' ', 'W', 'e', ' ']
15 ['h', 'a', 'd', ' ', 't', 'o', ' ', 's', 't', 'r']
16 ['i', 'v', 'e', ' ', 't', 'o', 'l', 'e', 'a', 'd']
17 [' ', 't', 'h', 'e', 'm', ' ', 'b', 'a', 'c', 'k']
18 [' ', 't', 'o', ' ', 't', 'h', 'i',

In [None]:
# @title save
from google.colab import drive

# Checkpoint destination: a folder on the mounted Google Drive.
# (To keep it on the Colab VM only, use PATH="/content/" and name='model.pth'.)
drive.mount('/content/gdrive')
PATH = "/content/gdrive/MyDrive/torch_save/"
name = 'jepa_mpc.pth'

torch.save(model.state_dict(), f"{PATH}{name}")

# To restore: model.load_state_dict(torch.load(PATH+name))


#### inference

In [None]:

# One prediction step: encode the input `x`, then use the embedding of a single character
# as the action.
char = 'l'
a = dataset.embeder(dataset.encode(char))
print(a)
a = a.clone().view(1, -1).to(device)  # one row, ready to concatenate with sx
sx = model.encode(x)
sy = model(sx, a)
print(sy)


In [None]:

# Target state (64 values) handed to the planner, starting from the learnable input g_st.
# NOTE(review): `optimal_control` is not defined in the visible code (only a commented-out
# `# def optimal_control(dx, init_st, goal_st):` exists), so this cell raises NameError as written.
goal_st=torch.tensor([-0.0426, -0.2183, -0.0351,  0.0014, -0.0591, -0.0012, -0.0198,  0.0319,
          0.1887,  0.1562,  0.0134,  0.1240, -0.1527, -0.0601,  0.0372,  0.0323,
          0.0605, -0.0699,  0.0766, -0.0679, -0.0489, -0.1938, -0.1501, -0.0716,
          0.2145,  0.0125, -0.1071, -0.0086,  0.2173,  0.0805, -0.1742,  0.0086,
         -0.0442,  0.2411, -0.0668, -0.1390,  0.0308,  0.1623, -0.2207,  0.1363,
          0.0945,  0.1521, -0.1333,  0.0900, -0.0801,  0.1209,  0.0082,  0.0654,
          0.0111,  0.1649,  0.1243, -0.2940, -0.0775, -0.0036,  0.1020,  0.0426,
         -0.2507,  0.1013, -0.0405,  0.0273,  0.0664, -0.0432,  0.2488,  0.0915]).to(device)
init_st=model.g_st
dx=model
ctrl=optimal_control(dx, init_st, goal_st)
print(ctrl)


In [None]:
# Inspect the model's learnable start input g_st.
print(model.g_st)

In [None]:

context = "This is what "

out=[]
x = model.g_st
# Encode the start input once and feed each prediction back in as the next state.
# Previously x was re-encoded on every step, so `sx=sy` had no effect and every step
# started from the same state.
# NOTE(review): dataset.encode / dataset.embeder exist only on the embedding variant of
# Datasetme.
sx=model.encode(x)
for char in context:
    a=dataset.encode(char)
    a=dataset.embeder(a)
    a=a.clone().view(1,-1).to(device)
    sy = model(sx, a)
    # NOTE(review): `out` is never filled. sy is a state vector, not a character, and this
    # notebook does not yet define a readout that turns it into text.
    sx=sy
print(''.join(out))




### archive

#### train eval

In [None]:
# from tqdm import tqdm

# grad_clip_norm=1
# lr=1e-4
# betas=(0.9, 0.95)
# batch_size=1
# def train(loader, model, loss_fn, optimizer):
#     model.train()
#     losses = []
#     pbar = tqdm(enumerate(loader), total = len(loader))
#     for it, (x, y) in pbar:
#         # print("x,y",x.dtype,y.dtype) #torch.float32 [1,4]
#         # print("x,y",x.shape,y.shape)
#         x = x.to(device)
#         y = y.to(device)
#         # print("x",x)
#         # with torch.set_grad_enabled(True):

#         # logits = model(x)
#         # # loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
#         # loss = loss_fn(logits, y)
#         loss = model.loss(x,y)
#         losses.append(loss.item())

#         model.zero_grad()
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
#         optimizer.step()
#         # lr = lr
#         # pbar.set_description(f"epoch {epoch + 1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")
#         pbar.set_description(f"epoch {epoch + 1} iter {it}: train loss {loss.item():.5f}")


# def eval(loader, model, loss_fn):
#     model.eval()
#     losses = []
#     pbar = enumerate(loader)
#     for it, (x, y) in pbar:
#         x = x.to(device)
#         y = y.to(device)
#         # with torch.set_grad_enabled(False):
#         with torch.no_grad():
#             # logits = model(x)
#             # # loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
#             # loss = loss_fn(logits, y)
#             loss = model.loss(x,y)
#             losses.append(loss.item())
#     test_loss = float(np.mean(losses))
#     # logger.info("test loss: %f", test_loss)
#     return test_loss


In [None]:
    # def loss(self, x, y):
    #     sx = self.enc_x(x)
    #     sy = self.enc_y(y)
    #     sx=sx.flatten()
    #     sy=sy.flatten()
    #     z = self.argm(sx, sy)
    #     sxz = torch.cat([sx, z], dim=-1)
    #     sy_ = self.pred(sxz)
    #     # loss(sy, sy_)
    #     mseloss = nn.MSELoss()(sy, sy_)

    #     vx = self.exp_x(sx)
    #     vy = self.exp_y(sy)
    #     # print("vx",vx.shape) #[40]
    #     vicloss = self.vicreg(vx, vy)
    #     return mseloss + vicloss


In [None]:
# @title archive optuna argm
# !pip install optuna
# import optuna

    # def argm(self, sx, sy):
    #     optuna.logging.set_verbosity(optuna.logging.WARNING)
    #     sampler = optuna.samplers.NSGAIISampler()
    #     # sampler = optuna.samplers.MOTPESampler()
    #     study = optuna.create_study(direction="maximize", sampler=sampler, pruner=optuna.pruners.MedianPruner())
    #     # study = optuna.create_study()
    #     # print("sx",sx.shape)
    #     # sx=sx.flatten()
    #     def objective(trial):
    #         z = trial.suggest_uniform('z', -1, 1)
    #         # print("z trail",sx,z)
    #         z=torch.tensor([z])
    #         sxz = torch.cat([sx, z], dim=-1)
    #         sy_ = self.pred(sxz)
    #         mseloss = nn.MSELoss()(sy, sy_)
    #         return mseloss
    #     study.optimize(objective, n_trials=100)
    #     st=study.best_params
    #     # print("st",st['z'])
    #     st=torch.tensor([st['z']])
    #     return st


    # def argm(self, SX, SY, dim_z):
    #     optuna.logging.set_verbosity(70)
    #     sampler = optuna.samplers.NSGAIISampler()
    #     # sampler = optuna.samplers.MOTPESampler()
    #     pruner = optuna.pruners.MedianPruner()
    #     batch_size=1
    #     # if sx.dim() == 2: batch_size,_=sx.shape
    #     if SX.dim() == 2: batch_size,_=SX.shape
    #     s=[]
    #     for i in range(batch_size):
    #         sx, sy = SX[i],SY[i]
    #         study = optuna.create_study(direction="minimize", sampler=sampler, pruner=pruner)
    #         def objective(trial):
    #             if dim_z<0:
    #                 z = [trial.suggest_categorical("z", np.linspace(-1,1,-dim_z,dtype=np.float32))]
    #             elif dim_z>0:
    #                 z=[]
    #                 for d in range(dim_z):
    #                     z.append(trial.suggest_uniform(chr(d), -1, 1))
    #             # z = trial.suggest_uniform('z', -1, 1)
    #             # print("z trail",z)
    #             z=torch.tensor(z).to(device)
    #             # print("sx, z",sx.dtype, z.dtype) #[500, 20] [1]
    #             sxz = torch.cat([sx, z], dim=-1)
    #             # print("sxz",sxz,sxz.shape)
    #             sy_ = self.pred(sxz)
    #             mseloss = nn.MSELoss()(sy, sy_)
    #             return mseloss
    #         study.optimize(objective, n_trials=10)
    #         st=list(study.best_params.values())
    #         # print(s)
    #         # st=torch.tensor([st['z']])
    #         s.append(st)
    #     return torch.tensor(s)

