## Import required libraries
### Author: Sameer
### Date: May 2019

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from CartPole import CartPole
# from CartPole_GPS import CartPole_GPS

from ilqr.dynamics import constrain
from copy import deepcopy

from EstimateDynamics import local_estimate
from GMM import Estimated_Dynamics_Prior

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel

from mujoco_py import load_model_from_path, MjSim, MjViewer
import mujoco_py

import time



### Formulate the iLQR problem

In [2]:
# Problem definition for the cart-pole iLQR run.
#
#   dt          time step
#   N           number of control points in the trajectory
#   x0          initial state
#   x_goal      final state
#   Q           state cost
#   R           control cost
#   Q_terminal  cost at the final step
#   x_dynamics  information regarding the system:
#       x_dynamics[0] = m = mass of the pendulum bob
#       x_dynamics[1] = M = mass of the cart
#       x_dynamics[2] = L = length of the massless rod
#       x_dynamics[3] = g = gravity
#       x_dynamics[4] = d = damping in the system
dt = 0.05
N = 600  # number of time steps in the trajectory

# m=0.1, M=1, L=1, g=9.80665, d=0
x_dynamics = np.array([0.1, 1, 1, 9.80665, 0])

# Initial and goal states (4 entries each); the third entry starts at 3.14.
x0 = np.array([0.0, 0.0, 3.14, 0.0])
x_goal = np.zeros(4)

# Instantaneous state cost (diagonal).
# NOTE(review): Q is 5x5 although x0 has 4 entries — presumably the cost acts on
# an augmented state [x, x_dot, sin(theta), cos(theta), theta_dot]; confirm
# against CartPole.
Q = np.diag([1.0, 10.0, 100.0, 100.0, 10.0])

# Terminal state cost: every diagonal entry weighted by 100.
Q_terminal = 100 * np.eye(5)

# Instantaneous control cost (single control input).
R = np.eye(1)

### iLQR on Cart Pole

In [3]:
# Build the cart-pole problem from the settings defined above and run iLQR on it.
# The call prints one "iteration ... accepted ..." line per iteration (see the
# output below), so this cell is slow and noisy to re-run.
cartpole_prob = CartPole(dt, N, x_dynamics, x0, x_goal, Q, R, Q_terminal)
# The four return values are unpacked as the state trajectory `xs`, the control
# sequence `us`, and two further arrays `K` and `k`.
# NOTE(review): K and k are presumably the controller gains — confirm which is
# the feedback and which the feed-forward term against run_IterLinQuadReg.
xs, us, K, k = cartpole_prob.run_IterLinQuadReg()

iteration 0 accepted 213048.29221466993 [-0.1920869  -0.36947147 -0.28879993 -1.29797015]
iteration 1 accepted 212317.96651019467 [ 0.00646396 -0.13916409  0.07197629 -0.14933402]
iteration 2 accepted 212031.58516764553 [ 0.02914881 -0.1065376   0.07636714 -0.13755234]
iteration 3 accepted 211764.5205322962 [ 0.02844512 -0.10666818  0.07842651 -0.1389912 ]
iteration 4 accepted 211511.15509814498 [ 0.02851199 -0.10663961  0.08058652 -0.1399636 ]
iteration 5 accepted 211268.97241963306 [ 0.02860753 -0.10666771  0.08266774 -0.14096943]
iteration 6 accepted 211036.33877322043 [ 0.02869999 -0.10670817  0.08468708 -0.14196921]
iteration 7 accepted 210812.0605263559 [ 0.02878825 -0.10675937  0.08664882 -0.14296321]
iteration 8 accepted 210595.18523357215 [ 0.02887304 -0.10681952  0.08855712 -0.14395069]
iteration 9 accepted 210384.90008794953 [ 0.02895488 -0.1068872   0.09041565 -0.14493118]
iteration 10 accepted 210180.47698212363 [ 0.02903411 -0.10696119  0.09222771 -0.14590432]
iteration 1

iteration 91 accepted 178637.78381546584 [ 0.03213405 -0.11267795  0.18180995 -0.20700273]
iteration 92 accepted 178577.92189679213 [ 0.0321425  -0.1127115   0.18263211 -0.20764175]
iteration 93 accepted 178533.1611734013 [ 0.03215207 -0.11274466  0.18345148 -0.20827975]
iteration 94 accepted 178498.33192349153 [ 0.03216264 -0.1127774   0.18426812 -0.20891678]
iteration 95 accepted 178470.15679316764 [ 0.03217407 -0.1128097   0.1850821  -0.20955286]
iteration 96 accepted 178446.51596349184 [ 0.03218622 -0.11284153  0.18589347 -0.21018801]
iteration 97 accepted 178426.0086224182 [ 0.03219897 -0.11287286  0.1867023  -0.21082227]
iteration 98 accepted 178407.68974462955 [ 0.03221221 -0.11290369  0.18750865 -0.21145567]
iteration 99 accepted 178390.9094012413 [ 0.03222584 -0.11293399  0.18831257 -0.21208822]
iteration 100 accepted 178375.2126360218 [ 0.03223979 -0.11296374  0.18911413 -0.21271997]
iteration 101 accepted 178360.27575891395 [ 0.03225398 -0.11299294  0.18991338 -0.21335094]
i

iteration 181 accepted 174090.90064500662 [ 0.03320559 -0.11353949  0.25012889 -0.2637463 ]
iteration 182 accepted 173731.9944503023 [ 0.03318784 -0.11351883  0.25087084 -0.2644009 ]
iteration 183 accepted 173320.58356782995 [ 0.03315599 -0.11349446  0.25161303 -0.26505585]
iteration 184 accepted 172844.81900018977 [ 0.03310558 -0.11346539  0.25235539 -0.26571089]
iteration 185 accepted 172288.7970616047 [ 0.03303196 -0.11343055  0.25309787 -0.26636573]
iteration 186 accepted 171630.90054617223 [ 0.03293178 -0.1133892   0.25384042 -0.26702019]
iteration 187 accepted 170841.39926464343 [ 0.03280714 -0.11334183  0.25458309 -0.26767446]
iteration 188 accepted 169878.17184797002 [ 0.03267618 -0.11329259  0.2553262  -0.26832979]
iteration 189 accepted 168676.47401210354 [ 0.0325968  -0.1132546   0.25607074 -0.26899018]
iteration 190 accepted 167119.2112273683 [ 0.0326967  -0.11325681  0.25681889 -0.26966426]
iteration 191 accepted 164934.14264471718 [ 0.03289724 -0.11328084  0.25756973 -0.2

iteration 271 accepted 159874.0362641587 [ 0.03353607 -0.10962795  0.32156198 -0.33098762]
iteration 272 accepted 159870.22340076568 [ 0.03353661 -0.10954874  0.32243974 -0.3318673 ]
iteration 273 accepted 159866.39180401273 [ 0.03353702 -0.10946852  0.32332009 -0.33275094]
iteration 274 accepted 159862.54170899405 [ 0.03353729 -0.10938728  0.32420305 -0.33363856]
iteration 275 accepted 159858.67334158873 [ 0.03353741 -0.109305    0.32508864 -0.33453021]
iteration 276 accepted 159854.78691905725 [ 0.03353739 -0.10922168  0.32597691 -0.33542593]
iteration 277 accepted 159850.88265071102 [ 0.03353722 -0.1091373   0.32686787 -0.33632576]
iteration 278 accepted 159846.96073864913 [ 0.0335369  -0.10905186  0.32776154 -0.33722975]
iteration 279 accepted 159843.02137853717 [ 0.03353642 -0.10896534  0.32865797 -0.33813792]
iteration 280 accepted 159839.06476042018 [ 0.03353579 -0.10887773  0.32955716 -0.33905033]
iteration 281 accepted 159835.09106954344 [ 0.033535   -0.10878902  0.33045915 -0

iteration 361 accepted 159474.61436050327 [ 0.03261537 -0.0968276   0.41424706 -0.43171147]
iteration 362 accepted 159469.52985721448 [ 0.03258939 -0.09659631  0.41547409 -0.43315636]
iteration 363 accepted 159464.417178515 [ 0.03256299 -0.09636235  0.41670645 -0.4346106 ]
iteration 364 accepted 159459.27493451288 [ 0.03253616 -0.09612568  0.41794416 -0.43607425]
iteration 365 accepted 159454.1016633041 [ 0.03250889 -0.09588629  0.41918726 -0.43754741]
iteration 366 accepted 159448.8958282862 [ 0.03248119 -0.09564413  0.42043578 -0.43903016]
iteration 367 accepted 159443.6558153555 [ 0.03245304 -0.09539918  0.42168976 -0.44052258]
iteration 368 accepted 159438.37992999578 [ 0.03242445 -0.09515141  0.42294923 -0.44202476]
iteration 369 accepted 159433.06639423987 [ 0.03239542 -0.09490078  0.42421421 -0.44353679]
iteration 370 accepted 159427.71334350872 [ 0.03236593 -0.09464726  0.42548475 -0.44505875]
iteration 371 accepted 159422.31882332131 [ 0.032336   -0.09439082  0.42676087 -0.446

iteration 451 accepted 158427.8172056292 [ 0.02783249 -0.0608966   0.54922999 -0.60970447]
iteration 452 accepted 158394.45716382167 [ 0.02773211 -0.06026362  0.5510318  -0.61234697]
iteration 453 accepted 158359.21692943687 [ 0.02762919 -0.05962376  0.55284023 -0.61500629]
iteration 454 accepted 158321.90725846405 [ 0.02752349 -0.05897695  0.55465526 -0.61768245]
iteration 455 accepted 158282.3166340692 [ 0.02741476 -0.05832307  0.55647685 -0.6203755 ]
iteration 456 accepted 158240.20809299164 [ 0.02730269 -0.05766201  0.55830498 -0.62308545]
iteration 457 accepted 158195.31553800375 [ 0.02718692 -0.05699367  0.56013961 -0.62581232]
iteration 458 accepted 158147.33944994755 [ 0.02706703 -0.0563179   0.56198069 -0.62855612]
iteration 459 accepted 158095.9418966236 [ 0.02694253 -0.05563457  0.5638282  -0.63131686]
iteration 460 accepted 158040.74071652058 [ 0.02681285 -0.0549435   0.56568209 -0.63409455]
iteration 461 accepted 157981.30273260633 [ 0.02667731 -0.05424451  0.5675423  -0.6

iteration 540 accepted 119627.15477177688 [ 0.01658933  0.02432133  0.73056906 -0.90925024]
iteration 541 accepted 119625.81901886949 [ 0.01638943  0.02564757  0.73278226 -0.91328466]
iteration 542 accepted 119624.51666400902 [ 0.01618802  0.02698216  0.73499793 -0.91733131]
iteration 543 accepted 119623.24625368156 [ 0.0159851   0.0283251   0.73721606 -0.92139012]
iteration 544 accepted 119622.00638078342 [ 0.01578067  0.02967641  0.73943665 -0.92546105]
iteration 545 accepted 119620.79569088864 [ 0.01557472  0.0310361   0.74165968 -0.92954404]
iteration 546 accepted 119619.61288741864 [ 0.01536725  0.03240418  0.74388515 -0.93363905]
iteration 547 accepted 119618.45673568881 [ 0.01515826  0.03378066  0.74611307 -0.93774604]
iteration 548 accepted 119617.32606584363 [ 0.01494775  0.03516556  0.74834344 -0.94186496]
iteration 549 accepted 119616.21977472707 [ 0.01473571  0.0365589   0.75057627 -0.9459958 ]
iteration 550 accepted 119615.13682676051 [ 0.01452215  0.0379607   0.75281157 -

iteration 628 accepted 119570.78419249076 [-0.00824215  0.18172144  0.9443932  -1.32401872]
iteration 629 accepted 119570.49973976513 [-0.00865947  0.18428472  0.94737847 -1.33000316]
iteration 630 accepted 119570.21678435808 [-0.00908259  0.18688156  0.95039321 -1.33604744]
iteration 631 accepted 119569.93507601894 [-0.00951169  0.18951312  0.95343858 -1.34215378]
iteration 632 accepted 119569.65435932174 [-0.009947    0.19218065  0.95651585 -1.3483245 ]
iteration 633 accepted 119569.37437307985 [-0.01038872  0.19488542  0.95962629 -1.35456202]
iteration 634 accepted 119569.09484970628 [-0.01083709  0.19762879  0.96277128 -1.3608689 ]
iteration 635 accepted 119568.8155145074 [-0.01129237  0.20041219  0.96595223 -1.3672478 ]
iteration 636 accepted 119568.53608491129 [-0.01175479  0.20323711  0.96917067 -1.37370153]
iteration 637 accepted 119568.25626961599 [-0.01222465  0.20610513  0.97242814 -1.38023302]
iteration 638 accepted 119567.9757676562 [-0.01270222  0.2090179   0.97572633 -1.

iteration 718 accepted 109343.28108243932 [ 0.00156487 -0.00349134  0.00033833 -0.00291212]
iteration 719 accepted 109343.27753941831 [ 0.00156484 -0.00349131  0.00033833 -0.00291211]
iteration 720 accepted 109343.27413819938 [ 0.00156482 -0.00349129  0.00033832 -0.00291209]
iteration 721 accepted 109343.27087299219 [ 0.0015648  -0.00349127  0.00033832 -0.00291207]
iteration 722 accepted 109343.26773825778 [ 0.0015648  -0.00349125  0.00033832 -0.00291205]
iteration 723 accepted 109343.26472869307 [ 0.0015648  -0.00349123  0.00033832 -0.00291204]
iteration 724 accepted 109343.26183921547 [ 0.0015648  -0.00349122  0.00033831 -0.00291202]
iteration 725 accepted 109343.25906495185 [ 0.00156481 -0.0034912   0.00033831 -0.00291201]
iteration 726 accepted 109343.25640122716 [ 0.00156482 -0.00349118  0.00033831 -0.00291199]
iteration 727 accepted 109343.25384355625 [ 0.00156483 -0.00349117  0.00033831 -0.00291198]
iteration 728 accepted 109343.25138763442 [ 0.00156484 -0.00349116  0.0003383  -

iteration 808 accepted 109343.19175869711 [ 0.00156555 -0.00348992  0.00033813 -0.00291076]
iteration 809 accepted 109343.19160143517 [ 0.00156555 -0.00348991  0.00033813 -0.00291074]
iteration 810 accepted 109343.19144734529 [ 0.00156555 -0.00348989  0.00033813 -0.00291072]
iteration 811 accepted 109343.19129629497 [ 0.00156555 -0.00348987  0.00033813 -0.00291071]
iteration 812 accepted 109343.19114815605 [ 0.00156555 -0.00348985  0.00033812 -0.00291069]
iteration 813 accepted 109343.1910028065 [ 0.00156555 -0.00348983  0.00033812 -0.00291068]
iteration 814 accepted 109343.19086012931 [ 0.00156555 -0.00348982  0.00033812 -0.00291066]
iteration 815 accepted 109343.19072001238 [ 0.00156555 -0.0034898   0.00033812 -0.00291064]
iteration 816 accepted 109343.190582348 [ 0.00156555 -0.00348978  0.00033811 -0.00291063]
iteration 817 accepted 109343.19044703309 [ 0.00156555 -0.00348976  0.00033811 -0.00291061]
iteration 818 accepted 109343.19031396892 [ 0.00156555 -0.00348975  0.00033811 -0.0

iteration 898 accepted 109343.18263289829 [ 0.00156527 -0.00348827  0.00033792 -0.00290931]
iteration 899 accepted 109343.18254924953 [ 0.00156526 -0.00348825  0.00033792 -0.00290929]
iteration 900 accepted 109343.18246566126 [ 0.00156526 -0.00348824  0.00033791 -0.00290928]
iteration 901 accepted 109343.18238213086 [ 0.00156525 -0.00348822  0.00033791 -0.00290926]
iteration 902 accepted 109343.18229865539 [ 0.00156525 -0.0034882   0.00033791 -0.00290924]
iteration 903 accepted 109343.18221523237 [ 0.00156524 -0.00348818  0.00033791 -0.00290923]
iteration 904 accepted 109343.18213185949 [ 0.00156524 -0.00348816  0.0003379  -0.00290921]
iteration 905 accepted 109343.18204853401 [ 0.00156523 -0.00348814  0.0003379  -0.0029092 ]
iteration 906 accepted 109343.18196525394 [ 0.00156523 -0.00348812  0.0003379  -0.00290918]
iteration 907 accepted 109343.18188201723 [ 0.00156522 -0.0034881   0.0003379  -0.00290916]
iteration 908 accepted 109343.18179882137 [ 0.00156522 -0.00348809  0.00033789 -

iteration 988 accepted 109343.17518539653 [ 0.00156478 -0.00348658  0.0003377  -0.00290785]
iteration 989 accepted 109343.17510273894 [ 0.00156478 -0.00348656  0.0003377  -0.00290783]
iteration 990 accepted 109343.17502007632 [ 0.00156477 -0.00348654  0.00033769 -0.00290782]
iteration 991 accepted 109343.1749374091 [ 0.00156476 -0.00348652  0.00033769 -0.0029078 ]
iteration 992 accepted 109343.17485473688 [ 0.00156476 -0.0034865   0.00033769 -0.00290778]
iteration 993 accepted 109343.17477205941 [ 0.00156475 -0.00348648  0.00033769 -0.00290777]
iteration 994 accepted 109343.1746893768 [ 0.00156475 -0.00348646  0.00033768 -0.00290775]
iteration 995 accepted 109343.17460668912 [ 0.00156474 -0.00348644  0.00033768 -0.00290773]
iteration 996 accepted 109343.17452399617 [ 0.00156474 -0.00348642  0.00033768 -0.00290772]
iteration 997 accepted 109343.17444129816 [ 0.00156473 -0.00348641  0.00033768 -0.0029077 ]
iteration 998 accepted 109343.17435859481 [ 0.00156472 -0.00348639  0.00033767 -0.

In [5]:
# Split the optimized state matrix into individual signals for plotting and
# analysis.
t = np.arange(N + 1) * dt  # N + 1 time stamps spaced dt apart
x = xs[:, 0]  # Position
x_dot = xs[:, 1]  # Velocity
# De-augment before taking the angle, then unwrap it for smoother plots.
theta = np.unwrap(cartpole_prob.deaugment_state(xs)[:, 2])
# Angular velocity. Fixed: this used to read column 3, but with 5-column rows
# laid out as [x, x_dot, sin(theta), cos(theta), theta_dot] column 3 is
# cos(theta) and the angular velocity is the last column.
# NOTE(review): assumes xs uses that augmented layout — confirm against CartPole.
theta_dot = xs[:, 4]
# Controls passed through ilqr.dynamics.constrain with bounds (-1, 1).
us_scaled = constrain(us, -1, 1)

### Simulate the real system and generate the data
The cost matrices, initial position and goal position remain the same as in the problem above, since together they define one policy. The initial and goal positions must still be passed explicitly to the function. The cost matrices do not need to be passed (the penalty on the system is assumed to be the same); they are only used to calculate the cost of the trajectory. The correct control action must be passed. The parameter gamma indicates how much of the original data you want to keep.

Variance of the Gaussian noise will be taken as input from a Unif(0, var_range) uniform distribution. Inputs: x_initial, x_goal, u, n_rollouts, pattern='Normal', pattern_rand=False, var_range=10, gamma=0.2, percent=20

Pattern controls how the control sequence is modified after applying white Gaussian noise (zero mean).
- Normal: based on the correction/mixing parameter gamma generate control (gamma controls how much noise we want).
- MissingValue: based on the given percentage, sets that many values of the control to zero (it implicitly uses the "Normal" generated control).
- Shuffle: shuffles the entire "Normal" generated control sequence.
- TimeDelay: takes the "Normal" generated control and shifts it by 1 index i.e. one unit time delay.
- Extreme: sets gamma to zero and generates the control based only on noise.

If 'pattern_rand' is 'True', the pattern does not need to be passed explicitly; one is chosen randomly for every rollout (default is 'False'). If you want to choose a specific pattern, pass it explicitly.

In [None]:
# Generate 10 noisy rollouts around the iLQR control sequence `us`, letting the
# perturbation pattern be picked at random for each rollout.
# NOTE(review): no random seed is set in this notebook, so the rollouts are
# presumably not reproducible between runs — confirm how gen_rollouts draws noise.
x_rollout, u_rollout, local_policy, cost = cartpole_prob.gen_rollouts(
    x0,
    x_goal,
    us,
    n_rollouts=10,
    pattern_rand=True,
    var_range=10,
    gamma=0.2,
    percent=20,
)

### Local system dynamics/model estimate
local_estimate: the function takes the states (arranged in a special format, [x(t), u(t), x(t+1)]), the number of Gaussian mixtures and the number of states.

In [None]:
# Configure the dynamics prior, feed it the rollouts, then fit it to the same
# rollout data; the fit result is unpacked into A, B and C.
model = Estimated_Dynamics_Prior(
    init_sequential=False,
    eigreg=False,
    warmstart=True,
    min_samples_per_cluster=20,
    max_clusters=50,
    max_samples=20,
    strength=1.0,
)
model.update_prior(x_rollout, u_rollout)
A, B, C = model.fit(x_rollout, u_rollout)

In [None]:
# Show the shapes of the three fitted arrays, one per line, in the order A, B, C.
for fitted in (A, B, C):
    print(fitted.shape)

In [6]:
# Path of the MuJoCo cart-pole description, given relative to the working
# directory the notebook is run from.
Model = "mujoco/cartpole.xml"
# Load the model from that path and wrap it in a simulation object; `sim` is
# what the replay cell below steps and renders.
model_loaded = load_model_from_path(Model)
sim = MjSim(model_loaded)

In [7]:
# Replay the open-loop iLQR controls (`us_scaled`) in the MuJoCo cart-pole,
# rendering every step, and stop early once the goal state is reached.
#
# Changes from the previous version of this cell:
#   * The goal test compared floats with `==` (and compared the cart position
#     with 1.0 although x_goal is all zeros), so it could practically never
#     fire. It now checks the full state against x_goal within a tolerance.
#   * `t = 0` overwrote the time vector `t` built for plotting; that assignment
#     and the other unused locals (`start_time`, `final`) are gone.
#   * A closed-loop control was computed every step from the gains `k`, `K` and
#     the error between xs[i] (5 entries) and the simulator state
#     [qpos0, qvel0, sin(qpos1), cos(qpos1), qvel1], but it was never applied
#     (the assignment was commented out), so that dead computation was removed.
GOAL_TOL = 1e-2  # per-component tolerance for declaring the goal state reached

viewer = mujoco_py.MjViewer(sim)

# Start the simulator from x0 = [x, x_dot, theta, theta_dot].
sim.data.qpos[0] = x0[0]
sim.data.qpos[1] = x0[2]
sim.data.qvel[0] = x0[1]
sim.data.qvel[1] = x0[3]

for u in us_scaled:
    sim.data.ctrl[0] = u
    sim.step()
    viewer.render()

    # Angle error wrapped into (-pi, pi] so that full turns count as the goal.
    angle_diff = sim.data.qpos[1] - x_goal[2]
    goal_error = np.array([
        sim.data.qpos[0] - x_goal[0],
        sim.data.qvel[0] - x_goal[1],
        np.arctan2(np.sin(angle_diff), np.cos(angle_diff)),
        sim.data.qvel[1] - x_goal[3],
    ])
    if np.all(np.abs(goal_error) < GOAL_TOL):
        print('states reached')
        break
print(sim.get_state())

Creating window glfw
MjSimState(time=30.00000000000029, qpos=array([-0.17557436, -2.23021848]), qvel=array([-1.02726635, -2.01922997]), act=None, udd_state={})


In [None]:
# Pause for 5 seconds. `time` is already imported in the imports cell at the
# top of the notebook, so the redundant re-import was dropped.
# NOTE(review): presumably this gives the viewer window time before the next
# simulation cell runs — confirm whether the pause is still needed.
time.sleep(5)

In [None]:
# Run the same trajectory through the project's Mujoco_sim wrapper.
from Simulator import Mujoco_sim
# Same model file as the raw mujoco_py cells above (re-assigns `Model`).
Model = "mujoco/cartpole.xml"
# NOTE(review): the meaning of the second positional argument (True) is not
# visible here — confirm against Mujoco_sim.
cart_pole_simulator = Mujoco_sim(Model,True)
# Hand over the iLQR trajectory, controls, the two gain arrays and the initial
# state. Note the order k, K here versus K, k in the unpacking of the iLQR result.
cart_pole_simulator.load(xs,us,k,K,x0,initial=False)
cart_pole_simulator.runSimulation()

In [None]:
# Run the simulation again on the existing `cart_pole_simulator` object
# (requires the cell that creates and loads it to have been run first).
cart_pole_simulator.runSimulation()

In [None]:
# Largest entry anywhere in the state trajectory.
xs.max()

In [None]:
# Largest entry anywhere in the control sequence.
us.max()

In [None]:
# Plot the optimized control sequence against the time-step index.
# `plt` comes from the imports cell at the top of the notebook; the redundant
# re-import and the commented-out state plot were removed, and the figure now
# has a title and axis labels so it can be read on its own.
fig, ax = plt.subplots()
ax.plot(us)
ax.set(title="iLQR control sequence", xlabel="time step", ylabel="control input u")
plt.show()

In [None]:
# Make numpy print arrays in full instead of summarizing them with "...".
# This changes a global numpy setting for the rest of the kernel session, so
# any large array displayed after this cell is dumped in its entirety.
import sys 
np.set_printoptions(threshold=sys.maxsize)