In [1]:
import time
import jax.numpy as np
from jax import random, grad, jit, vmap, vjp
from jax import jacfwd, jacrev

import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
plt.rcParams.update({'font.size': 10})

from utils.utils import *
from cartpole_policy import policy, policy_jit
from ut_utils.ut_utils import *
from robot_models.custom_cartpole_constrained import CustomCartPoleEnv
from robot_models.cartpole2D import step
from gym_wrappers.record_video import RecordVideo

In [2]:
key = random.PRNGKey(0)

def generate_psd_params():
    n = 4
    N = 50
    diag = random.uniform(key, shape=(n,1))[:,0] + n
    off_diag = random.uniform(key, shape=(int( (n**2-n)/2.0 ),1))[:,0]
    params = np.append(diag, off_diag, axis = 0).reshape(1,-1)
    for i in range(1,50):
        # Diagonal elements
        params_temp = random.uniform( key, shape=( 1,int(n + (n**2 -n)/2.0)) )
        params = np.append( params, params_temp, axis = 0 )    
    return params

def get_future_reward(X, horizon, dt_outer, dynamics_params, params_policy):
    states, weights = initialize_sigma_points_jit(X)
    reward = 0
    for i in range(40):  
        mean_position = get_mean( states, weights )
        solution = policy( params_policy, mean_position )
        next_states_expanded, next_weights_expanded = sigma_point_expand( states, weights, solution, dt_outer, dynamics_params)#, gps )        
        next_states, next_weights = sigma_point_compress( next_states_expanded, next_weights_expanded )
        states = next_states
        weights = next_weights
        reward = reward + reward_UT_Mean_Evaluator_basic( states, weights )
    return reward

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# get_future_reward_grad = grad(get_future_reward)
# get_future_reward_jit = jit(get_future_reward)

get_future_reward_grad = grad(get_future_reward, 4)
get_future_reward_grad_jit = jit(get_future_reward_grad)

In [4]:
# Set up environment
env_to_render = CustomCartPoleEnv(render_mode="human")
env = RecordVideo( env_to_render, video_folder="/home/hardik/Desktop/", name_prefix="cartpole_constrained_H20" )
observation, info = env.reset(seed=42)

polemass_length, gravity, length, masspole, total_mass, tau = env.polemass_length, env.gravity, env.length, env.masspole, env.total_mass, env.tau
dynamics_params = np.array([ polemass_length, gravity, length, masspole, total_mass])#, tau ])


  logger.warn(


x:0.0, theta:3.141592653589793
h1
x:0.0, theta:3.141592653589793
h1


  logger.warn(


In [5]:
# Initialize parameters
N = 50
H = 20
lr_rate = 0.01
param_w = random.uniform(key, shape=(N,1))[:,0] - 0.5#+ 0.5#+ 2.0  #0.5 work with Lr: 5.0
param_mu = random.uniform(key, shape=(4,N))- 0.5 * np.ones((4,N)) #- 3.5 * np.ones((4,N))
param_Sigma = generate_psd_params() # 10,N
params_policy = np.append( param_w, np.append( param_mu.reshape(-1,1)[:,0], param_Sigma.reshape(-1,1)[:,0] ) )

t = 0
dt_inner = 0.1
dt_outer = 0.1
tf = 5.0

state = np.copy(env.get_state())

In [6]:
get_future_reward_grad_jit( state, H, dt_outer, dynamics_params, params_policy )

2023-04-07 13:08:45.986617: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_get_future_reward] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-04-07 16:45:26.927777: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3h38m40.940927814s

********************************
[Compiling module jit_get_future_reward] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


Array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, na

In [7]:
key = random.PRNGKey(10)