In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os
import inspect
# Make the sibling ``src`` directory importable.
# A notebook cell has no real source file: inspecting the current frame yields
# the kernel's synthetic cell name, which recent ipykernel versions place in a
# temporary directory.  The kernel's working directory (the notebook folder by
# default) is the stable anchor.
currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
# ``os.sys`` is used on purpose: the bare name ``sys`` is rebound to the
# dynamics model later in this notebook.
os.sys.path.insert(1, os.path.join(parentdir, 'src'))
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import time
from tqdm import trange
import matplotlib.pyplot as plt
import matplotlib

In [None]:
import gym
import pybullet as p
import stage.envs
from stage.tasks.twolink.reaching import TwoLinkReaching
from stage.utils.nn import use_gpu
# Project helper from stage.utils.nn; presumably switches the default tensor
# device to the GPU -- TODO confirm against its implementation.
use_gpu()

In [None]:
# Problem dimensions shared by the cells below.
X, A = 2, 1   # state width, action width
XA = X + A    # width of one concatenated [state, action] row

In [None]:
class Integrator(nn.Module):
    """Double integrator advanced with explicit Euler steps of size ``dt``.

    A state row is ``[q, v]`` (column 0 and the remaining columns of ``x``);
    the action is added to ``v`` scaled by ``dt``.
    """

    def __init__(self):
        super().__init__()
        self.dt = 0.01  # integration step
        self.X = 2      # state width
        self.A = 1      # action width

    def forward(self, x, a):
        """Advance a batch of states by one step.

        Args:
            x: tensor of shape (B, X); column 0 is ``q``, the rest is ``v``.
            a: action tensor that broadcasts against ``v`` ((B, A) here).

        Returns:
            A new tensor shaped like ``x`` with the next state; ``x`` itself
            is left unmodified.
        """
        q, v = x[:, 0:1], x[:, 1:]
        x_ = torch.zeros_like(x)
        # q is advanced with the velocity from *before* this step's update.
        x_[:, 0:1] = q + self.dt * v
        x_[:, 1:] = v + self.dt * a
        return x_

    def unroll(self, x, a_seq, horizon=None):
        """Roll the dynamics forward along an action sequence.

        Args:
            x: initial states, shape (B, X).
            a_seq: actions, shape (B, L, A).
            horizon: number of steps to simulate; defaults to ``L``.  Only the
                first ``horizon`` actions of ``a_seq`` are used.

        Returns:
            Tensor of shape (B, horizon, X + A).  Row ``n`` holds the state
            reached *after* applying action ``n``, followed by that action.
        """
        B, L, A = a_seq.shape
        if horizon is None:
            horizon = L
        # Take the state width from the input itself instead of a notebook
        # global, and allocate with the dtype/device of ``x`` so the buffer
        # always matches what is written into it.
        S = x.new_zeros(B, horizon, x.shape[1] + A)
        for n in range(horizon):
            a = a_seq[:, n, :]
            x = self.forward(x, a)
            S[:, n, :] = torch.cat((x, a), dim=1)

        return S

In [None]:
class Cost(nn.Module):
    """Quadratic running cost on position and control effort."""

    def __init__(self):
        super().__init__()
        self.dt = 0.01

    def forward(self, x, a):
        """Return the per-row cost as a (B, 1) tensor.

        Args:
            x: states, shape (B, X); only column 0 is penalised.
            a: actions, shape (B, A); only column 0 is penalised.
        """
        # The effort term must be squared: a linear term would *reward*
        # negative actions instead of penalising their magnitude.
        return x[:, 0:1] ** 2 + 1e-6 * a[:, 0:1] ** 2

In [None]:
# Dynamics model and running cost shared by both planners below.
sys = Integrator()
cost = Cost()
# Fix the seed so the random start state (and the sampling in later cells) is
# reproducible under Restart & Run All.
torch.manual_seed(0)
x = torch.randn(1, X)

### Single shooting

In [None]:
# Cross-entropy-method search over one open-loop action sequence.
H = 90      # planning horizon (steps)
K = 10000   # candidate sequences sampled per iteration
E = 100     # elites kept to refit the sampling distribution
mean = torch.zeros(H, A)
# Initial variance: std = 200 / 4, i.e. a quarter of the [-100, 100] action range.
var = (200**2)/16*torch.ones(H, A)

clamp = lambda u: torch.clamp(u, min=-100, max=100)
for i in range(10):
    # ``var`` is a variance (it is refit with torch.var below), so the noise
    # has to be scaled by its square root, not by the variance itself.
    samples = mean + var.sqrt() * torch.randn(K, H, A)
    samples = clamp(samples)

    # Simulate every candidate from the same start state and total its cost.
    S = sys.unroll(x.expand(K, X), samples)
    Sx = S[:, :, :X].view(-1, X)
    Sa = S[:, :, X:X+A].view(-1, A)
    J = cost(Sx, Sa).view(K, -1).sum(dim=1)

    # Slice the ranking before indexing so only the E elites are gathered.
    elites = samples[torch.argsort(J)[:E]]
    new_mean = torch.mean(elites, dim=0)
    new_var = torch.var(elites, dim=0)

    # Smoothed update of the sampling distribution.
    mean = 0.1 * mean + 0.9 * new_mean
    var = 0.1 * var + 0.9 * new_var

    # Stop early once the distribution has collapsed.
    if torch.max(var) < 0.001:
        break

actions = mean.unsqueeze(0)

In [None]:
# Replay the planned action sequence from the start state and plot column 0
# of the resulting states.
S = sys.unroll(x, actions)
Sx = S[..., :X].reshape(-1, X).detach().cpu().numpy()
plt.plot(Sx[:, 0])

### Multiple shooting

In [None]:
# Cross-entropy-method multiple shooting: the horizon is cut into M segments
# of L steps, and both the segment start states and the actions are sampled.
K = 100000  # candidates sampled per iteration
EX = 100    # elites used to refit the start-state distribution
EA = 1000   # elites used to refit the action distribution
H = 90      # planning horizon (steps)
L = 45      # steps per shooting segment
M = H // L  # number of segments

mean_x = torch.zeros(M, X)
var_x = (10**2)/16*torch.ones(M, X)    # std = 10 / 4 over the [-5, 5] state range
mean_a = torch.zeros(H, A)
var_a = (200**2)/16*torch.ones(H, A)   # std = 200 / 4 over the [-100, 100] action range

clamp_a = lambda a: torch.clamp(a, min=-100, max=100)
clamp_x = lambda s: torch.clamp(s, min=-5, max=5)

for i in range(10):
    # ``var_*`` are variances (refit with torch.var below), so the noise is
    # scaled by their square roots.
    samples_x = clamp_x(mean_x + var_x.sqrt() * torch.randn(K, M, X))
    samples_a = clamp_a(mean_a + var_a.sqrt() * torch.randn(K, H, A))

    # The first segment always starts from the true initial state.
    samples_x[:, 0, :] = x.expand(K, X)
    J = torch.zeros(K)   # accumulated running cost
    G = torch.zeros(K)   # accumulated gap (defect) penalty

    for m in range(M):
        x0 = samples_x[:, m, :]
        S = sys.unroll(x0, samples_a[:, m*L:, :], L)

        if m < M - 1:
            # Mismatch between where this segment ends and where the next
            # segment is sampled to start.
            gap = S[:, -1, :X] - samples_x[:, m + 1, :]
            gap_cost = L * torch.norm(gap, p=2, dim=1)**2
        else:
            # The last segment has no successor.  Use explicit zeros rather
            # than reusing ``gap`` from an earlier iteration, which is
            # undefined when M == 1.
            gap_cost = torch.zeros(K)

        Sx = S[:, :, :X].view(-1, X)
        Sa = S[:, :, X:X+A].view(-1, A)

        Jm = cost(Sx, Sa).view(K, -1).sum(dim=1)
        J += Jm
        G += gap_cost

    # Rank once on the total objective and reuse it for both elite sets.
    rank = torch.argsort(J + G)

    elites_x = samples_x[rank[:EX]]
    new_mean_x = torch.mean(elites_x, dim=0)
    new_var_x = torch.var(elites_x, dim=0)

    mean_x = 0.1 * mean_x + 0.9 * new_mean_x
    var_x = 0.1 * var_x + 0.9 * new_var_x

    elites_a = samples_a[rank[:EA]]
    new_mean_a = torch.mean(elites_a, dim=0)
    new_var_a = torch.var(elites_a, dim=0)

    mean_a = 0.1 * mean_a + 0.9 * new_mean_a
    var_a = 0.1 * var_a + 0.9 * new_var_a

    # Stop early once both distributions have collapsed.
    if torch.max(var_x) < 0.001 and torch.max(var_a) < 0.001:
        break

actions = mean_a.unsqueeze(0)
states = mean_x.unsqueeze(0)

In [None]:
# Re-simulate each shooting segment from its own optimised start state, then
# lay the segments end to end and plot column 0 of the states.
segments = []
for m in range(M):
    S = sys.unroll(states[:, m, :], actions[:, m*L:, :], L)
    segments.append(S[..., :X].reshape(-1, X))
Sx = torch.stack(segments).view(H, -1).detach().cpu().numpy()
plt.plot(Sx[:, 0])

In [None]:
# Replay the full action sequence in one pass from the first start state
# (ignoring the intermediate start states) and plot column 0 of the states.
S = sys.unroll(states[:, 0, :], actions)
Sx = S[..., :X].reshape(-1, X).detach().cpu().numpy()
plt.plot(Sx[:, 0])