In [6]:
import os

import jax
from jbdl.rbdl.utils import ModelWrapper
from jax import device_put
from jbdl.rbdl.utils import xyz2int
import jax.numpy as jnp
from jbdl.rbdl.dynamics import forward_dynamics_core
from jbdl.rbdl.contact import detect_contact_core
from jbdl.rbdl.dynamics.state_fun_ode import dynamics_fun_extend_core, events_fun_extend_core
from jbdl.rbdl.dynamics import composite_rigid_body_algorithm_core
from jbdl.rbdl.contact.impulsive_dynamics import impulsive_dynamics_extend_core
from jbdl.rbdl.ode.solve_ivp import integrate_dynamics
from jax.custom_derivatives import closure_convert
import math
from jax.api import jit
from functools import partial
from jbdl.rbdl.tools import plot_model
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
import numpy as np
import sys

# print(os.getcwd())
# # os.path.join(os.getcwd())
# # CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
# CURRENT_PATH = os.getcwd()
# print("CURRENT_PATH",CURRENT_PATH)
# SCRIPTS_PATH = os.path.dirname(CURRENT_PATH)
# MODEL_DATA_PATH = os.path.join(SCRIPTS_PATH, "model_data") 
# mdlw = ModelWrapper()
# mdlw.load(os.path.join(MODEL_DATA_PATH, 'half_max_v1.json'))
# model = mdlw.model





In [3]:
class Half_Qaudrupedal():
    """
    Description:
        Simulation environment for a "half" quadrupedal robot built on jbdl.

        The rigid-body model is read from ``model_data/half_max_v1.json``,
        which is looked up in ``<parent of the current working directory>/model_data``.
        The state vector is ``[q, qdot]`` with ``q = x[0:NB]`` and
        ``qdot = x[NB:]``; the built-in initial pose uses 7 generalized
        coordinates and 7 velocities.

        NOTE(review): an earlier version of this description spoke of a
        UNITREE robot with 14 joints, the first two being a virtual joint and
        a base-to-chassis joint -- confirm against the model file.

    Attributes:
        model: the model dictionary exposed by ``ModelWrapper``.
        pure_args: tuple
            ``(Xtree, I, contactpoint, u, a_grav, contact_force_lb,
            contact_force_ub, contact_pos_lb, contact_vel_lb, contact_vel_ub, mu)``
            in the positional order expected by ``dynamics_step``; ``u`` is a
            zero control vector of length 4.
        dynamics_step: callable ``dynamics_step(y0, *pure_args)`` that calls
            ``integrate_dynamics`` over ``t_span = (0.0, 2e-3)`` with
            ``delta_t = 5e-4`` (passing the event and impulse callbacks) and
            returns the last row of the solution.
    """
    def __init__(self, reward_fn=None, seed=0):
        # reward_fn and seed are accepted for interface compatibility but are
        # not used anywhere in this constructor.

        # The working directory is used as the anchor for the model path
        # (rather than __file__, which is not defined inside a notebook).
        CURRENT_PATH = os.getcwd()
        SCRIPTS_PATH = os.path.dirname(CURRENT_PATH)
        MODEL_DATA_PATH = os.path.join(SCRIPTS_PATH, "model_data")
        mdlw = ModelWrapper()
        mdlw.load(os.path.join(MODEL_DATA_PATH, 'half_max_v1.json'))
        model = mdlw.model
        self.model = model

        # Structural quantities: converted to plain ints / tuples and bound
        # below with functools.partial instead of being passed per call.
        NC = int(model["NC"])
        NB = int(model["NB"])
        nf = int(model["nf"])
        ST = model["ST"]
        idcontact = tuple(model["idcontact"])
        parent = tuple(model["parent"])
        jtype = tuple(model["jtype"])
        jaxis = xyz2int(model["jaxis"])
        ncp = 0

        # Quantities passed positionally on every call (see self.pure_args).
        contact_cond = model["contact_cond"]
        Xtree = device_put(model["Xtree"])
        contactpoint = model["contactpoint"]
        I = device_put(model["I"])
        a_grav = device_put(model["a_grav"])
        mu = device_put(0.9)
        contact_force_lb = device_put(contact_cond["contact_force_lb"])
        contact_force_ub = device_put(contact_cond["contact_force_ub"])
        contact_pos_lb = contact_cond["contact_pos_lb"]
        contact_vel_lb = contact_cond["contact_vel_lb"]
        contact_vel_ub = contact_cond["contact_vel_ub"]

        # Initial pose used only for the sanity-check step printed at the end.
        q0 = jnp.array([0.0,  0.4125, 0.0, math.pi/6, math.pi/6, -math.pi/3, -math.pi/3])
        qdot0 = jnp.zeros((7, ))
        x0 = jnp.hstack([q0, qdot0])

        # Integration window and step captured by _dynamics_step below.
        t_span = (0.0, 2e-3)
        delta_t = 5e-4

        def dynamics_fun(x, t, Xtree, I, contactpoint, u, a_grav,
                         contact_force_lb, contact_force_ub, contact_pos_lb,
                         contact_vel_lb, contact_vel_ub, mu,
                         ST, idcontact, parent, jtype, jaxis, NB, NC, nf, ncp):
            """Return the state derivative for state x under control u."""
            q = x[0:NB]
            qdot = x[NB:]
            tau = jnp.matmul(ST, u)
            flag_contact = detect_contact_core(
                Xtree, q, qdot, contactpoint, contact_pos_lb, contact_vel_lb,
                contact_vel_ub, idcontact, parent, jtype, jaxis, NC)
            # Only the state derivative is needed; the other two outputs are ignored.
            xdot, _fqp, _H = dynamics_fun_extend_core(
                Xtree, I, q, qdot, contactpoint, tau, a_grav, contact_force_lb,
                contact_force_ub, idcontact, flag_contact, parent, jtype, jaxis,
                NB, NC, nf, ncp, mu)
            return xdot

        def events_fun(y, t, Xtree, I, contactpoint, u, a_grav,
                       contact_force_lb, contact_force_ub, contact_pos_lb,
                       contact_vel_lb, contact_vel_ub, mu,
                       ST, idcontact, parent, jtype, jaxis, NB, NC, nf, ncp):
            """Return the event value computed by events_fun_extend_core for state y."""
            q = y[0:NB]
            qdot = y[NB:]
            flag_contact = detect_contact_core(
                Xtree, q, qdot, contactpoint, contact_pos_lb, contact_vel_lb,
                contact_vel_ub, idcontact, parent, jtype, jaxis, NC)
            value = events_fun_extend_core(
                Xtree, q, contactpoint, idcontact, flag_contact, parent, jtype, jaxis, NC)
            return value

        def impulsive_dynamics_fun(y, t, Xtree, I, contactpoint, u, a_grav,
                                   contact_force_lb, contact_force_ub, contact_pos_lb,
                                   contact_vel_lb, contact_vel_ub, mu,
                                   ST, idcontact, parent, jtype, jaxis, NB, NC, nf, ncp):
            """Return state y with its velocity part replaced by the post-impulse velocity."""
            q = y[0:NB]
            qdot = y[NB:]
            H = composite_rigid_body_algorithm_core(Xtree, I, parent, jtype, jaxis, NB, q)
            flag_contact = detect_contact_core(
                Xtree, q, qdot, contactpoint, contact_pos_lb, contact_vel_lb,
                contact_vel_ub, idcontact, parent, jtype, jaxis, NC)
            qdot_impulse = impulsive_dynamics_extend_core(
                Xtree, q, qdot, contactpoint, H, idcontact, flag_contact,
                parent, jtype, jaxis, NB, NC, nf)
            qdot_impulse = qdot_impulse.flatten()
            y_new = jnp.hstack([q, qdot_impulse])
            return y_new

        # Bind the structural arguments once so that the remaining positional
        # arguments are exactly (state, t, *pure_args).
        static_kwargs = dict(ST=ST, idcontact=idcontact, parent=parent, jtype=jtype,
                             jaxis=jaxis, NB=NB, NC=NC, nf=nf, ncp=ncp)
        pure_dynamics_fun = partial(dynamics_fun, **static_kwargs)
        pure_events_fun = partial(events_fun, **static_kwargs)
        pure_impulsive_fun = partial(impulsive_dynamics_fun, **static_kwargs)

        def _dynamics_step(y0, *args):
            """Integrate from y0 over t_span and return the last row of the solution."""
            t_eval, sol = integrate_dynamics(
                pure_dynamics_fun, y0, t_span, delta_t,
                pure_events_fun, pure_impulsive_fun, args=args)
            yT = sol[-1, :]
            return yT

        u = jnp.zeros((4,))
        self.pure_args = (Xtree, I, contactpoint, u, a_grav, contact_force_lb, contact_force_ub,  contact_pos_lb, contact_vel_lb, contact_vel_ub, mu)

        self.dynamics_step = _dynamics_step
        # One step from the initial pose with zero control; its result is printed.
        print(self.dynamics_step(x0, *self.pure_args))



In [4]:
# Build the environment once for the whole notebook. The constructor loads the
# model file and prints the state reached after one dynamics step.
half_Qaudrupedal = Half_Qaudrupedal()







[ 4.7105610e-07  4.1248140e-01  1.4034277e-06  5.2368802e-01
  5.2368540e-01 -1.0473584e+00 -1.0473539e+00  4.7097684e-04
 -1.8617591e-02  1.4031911e-03  8.9197405e-02  8.6597353e-02
 -1.6085698e-01 -1.5618184e-01]


In [18]:
from numbers import Real
import jax
import jax.numpy as jnp
import numpy as np
import functools
import numpy.random as npr

class Deep_Agent():
    """Actor-critic agent made of two small tanh MLPs.

    The actor maps a state to ``2 * action_size`` outputs that are split into
    a mean ``mu`` and a scale ``sigma``; the critic maps a state to one value.

    Attributes:
        state_size: length of the state vector fed to both networks.
        action_size: length of the sampled action vector.
        params: actor parameters, a list of ``(weights, bias)`` pairs.
        value_params: critic parameters, a list of ``(weights, bias)`` pairs.
        value_losses: empty list created at construction time.
        h_t: zero vector of length 4 created at construction time.
        action, state: last action / state seen by ``sample_action``.
    """

    def __init__(
        self,
        state_size,
        action_size,
        seed=0,
    ):
        """Create the actor and critic parameters.

        Args:
            state_size: input width of both networks.
            action_size: number of action dimensions.
            seed: base seed for the parameter initialisation. The actor uses
                ``seed`` and the critic ``seed + 1`` so that the two networks
                do not start from identical weights.
        """
        self.state_size = state_size
        self.action_size = action_size

        #init actor and critic
        param_scale = 0.1
        actor_layer_sizes = [self.state_size, 512, 128, 2 * self.action_size]
        self.params = self.init_random_params(param_scale, actor_layer_sizes, seed=seed)
        critic_layer_sizes = [self.state_size, 512, 128, 1]
        self.value_params = self.init_random_params(param_scale, critic_layer_sizes, seed=seed + 1)

        self.value_losses = []
        self.h_t = jnp.zeros(4)

    def init_random_params(self, scale, layer_sizes, seed=0):
        """Return Gaussian-initialised ``(weights, bias)`` pairs, one per layer.

        Args:
            scale: multiplier applied to the standard-normal draws.
            layer_sizes: layer widths, input first; consecutive pairs define
                the weight shapes.
            seed: seed of the private ``RandomState`` used for the draws.
        """
        rng = npr.RandomState(seed)
        return [(scale * rng.randn(m, n), scale * rng.randn(n))
                for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

    @staticmethod
    def _hidden_features(state, params):
        """Run ``state`` through every layer but the last, with tanh activations."""
        activations = state
        for w, b in params[:-1]:
            activations = jnp.tanh(jnp.dot(activations, w) + b)
        return activations

    def sample_action(self, state, params):
        """Sample ``mu + sigma * eps`` from the actor and remember it.

        The noise ``eps`` is drawn from NumPy's global generator with one
        independent standard-normal value per action dimension. The sampled
        action and the input state are stored on ``self.action`` and
        ``self.state``.

        Args:
            state: state vector (a leading batch dimension is also accepted).
            params: actor parameters as a list of ``(weights, bias)`` pairs.

        Returns:
            The sampled action, with the same shape as ``mu``.
        """
        final_w, final_b = params[-1]
        logits = jnp.dot(self._hidden_features(state, params), final_w) + final_b
        # Split along the last axis so the first half is mu and the second sigma.
        mu, sigma = jnp.split(logits, 2, axis=-1)

        # Independent noise per action dimension; a single shared draw would
        # perturb all dimensions by the same amount.
        eps = np.random.randn(*mu.shape)
        self.action = mu + sigma * eps
        self.state = state
        return self.action

    def value(self, state, params):
        """
        estimate the value of state

        Args:
            state: state vector (a leading batch dimension is also accepted).
            params: critic parameters as a list of ``(weights, bias)`` pairs.

        Returns:
            The first entry along the last axis of the critic output.
        """
        final_w, final_b = params[-1]
        logits = jnp.dot(self._hidden_features(state, params), final_w) + final_b
        return logits[..., 0]


In [27]:
# Smoke test of the policy network: sample one action for an all-zero state.
agent = Deep_Agent(state_size=14, action_size=7)

# The 14-element state is assembled from two 7-element zero halves.
init_state = jnp.zeros(7)
second_state = jnp.zeros(7)
state = jnp.hstack((init_state, second_state))
print("state", state)

# Only entries 3..6 of the sampled 7-element action are kept as the control u.
control = agent.sample_action(state, agent.params)
u = control[slice(3, 7)]
print(u)

state [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[-0.13894713 -0.09574576 -0.54688954  0.36939225]


In [8]:
def step_fun(xk,params):
    """Loss after one simulated step under a PD control law.

    Args:
        xk: current state; entries 3:7 are read as positions and 10:14 as the
            matching velocities.
        params: pair ``(kp, kd)`` of proportional and derivative gains.

    Returns:
        Sum of squared position and velocity errors (same index ranges) of the
        state returned by ``half_Qaudrupedal.dynamics_step``.
    """
    target_q = jnp.array([0.0,  0.0, 0.0, math.pi/6, -math.pi/6, -math.pi/3, math.pi/3])
    target_qdot = jnp.zeros((7, ))
    kp, kd = params

    # PD control computed from the tracking error of the current state.
    pos_err = target_q[3:7] - xk[3:7]
    vel_err = target_qdot[3:7] - xk[10:14]
    u = kp * pos_err + kd * vel_err

    # Re-use the environment's argument tuple, swapping in the new control,
    # which sits at position 3.
    base_args = half_Qaudrupedal.pure_args
    pure_args = base_args[:3] + (u,) + base_args[4:]
    x_next = half_Qaudrupedal.dynamics_step(xk, *pure_args)

    pos_cost = jnp.sum((target_q[3:7] - x_next[3:7])**2)
    vel_cost = jnp.sum((target_qdot[3:7] - x_next[10:14])**2)
    return pos_cost + vel_cost

# grad_fun(xk, params) returns (loss, d loss / d params): argnums=1 selects the
# second positional argument of step_fun as the one to differentiate.
grad_fun = jax.value_and_grad(step_fun,argnums=1)

def update(grads, params, lr=1e-3):
    """
    Description: update weights

    Performs one plain gradient-descent step on the two gains.

    Args:
        grads: pair ``(dkp, dkd)`` of gradients with respect to the gains.
        params: pair ``(kp, kd)`` holding the current gains.
        lr: learning rate (step size).

    Returns:
        A length-2 ``jnp`` array ``[kp - lr * dkp, kd - lr * dkd]``.
    """
    # Take the current gains from `params`, not from notebook globals, so that
    # successive updates accumulate instead of restarting from fixed values.
    kp, kd = params
    dkp, dkd = grads
    params = jnp.array([kp - lr*dkp, kd - lr*dkd])

    return params

In [None]:
# Gain-tuning loop: repeatedly evaluates the one-step loss and its gradient
# with respect to (kp, kd) and applies a gradient step.
# half_Qaudrupedal = Half_Qaudrupedal()
# %%
# IPython magic that selects an interactive matplotlib backend; the plotting
# code further down is currently commented out.
%matplotlib 

# Initial state: pose q0 with zero velocities, stacked as [q, qdot].
q0 = np.array([0,  0.5, 0, math.pi/6, -math.pi/6, -math.pi/3, math.pi/3]) # stand with leg in
qdot0 = jnp.zeros((7, ))
x0 = jnp.hstack([q0, qdot0])

# Target pose and velocity (not referenced again by the active code in this cell).
q_star = jnp.array([0.0,  0.0, 0.0, math.pi/6, -math.pi/6, -math.pi/3, math.pi/3])
qdot_star = jnp.zeros((7, ))

# NOTE(review): the first (kp, kd) pair is overwritten immediately by the
# second one, so only kp = 50, kd = 1 take effect.
kp = 200
kd = 3
kp = 50
kd = 1
# xksv and T are not used by the active code in this cell.
xksv = []
T = 2e-3

xk = x0
# plt.figure()
# plt.ion()

# fig = plt.gcf()
# ax = Axes3D(fig)
# Initial gains handed to grad_fun, in the order (kp, kd).
params = jnp.array([50.,1.])

# NOTE(review): xk is never reassigned inside the loop, so every iteration
# evaluates the loss and gradient from the same state x0 -- confirm intended.
for i in range(500):
    print(i)
#     u = kp * (q_star[3:7] - xk[3:7]) + kd * (qdot_star[3:7] - xk[10:14])
    # u = jnp.array([0., 0., 0., 0.])
    
#     control = agent.sample_action(xk, agent.params)
#     u = control[3:7]
#     Xtree, I, contactpoint, u0, a_grav, contact_force_lb, contact_force_ub, contact_pos_lb, contact_vel_lb, contact_vel_ub,mu = half_Qaudrupedal.pure_args
#     pure_args = (Xtree, I, contactpoint, u, a_grav, contact_force_lb, contact_force_ub, contact_pos_lb, contact_vel_lb, contact_vel_ub,mu)
#     # print("xk:", xk)
#     # print("u", u)
#     xk = half_Qaudrupedal.dynamics_step(xk, *pure_args)
    
    # One optimisation step: loss and gradient w.r.t. params, then the update.
    loss, grads = grad_fun(xk,params)
    print("loss",loss)
    params = update(grads, params, 1e-3)
    
#     # xksv.append(xk)
#     ax.clear()
#     plot_model(half_Qaudrupedal.model, xk[0:7], ax)
#     # fcqp = np.array([0, 0, 1, 0, 0, 1])
#     # plot_contact_force(model, xk[0:7], contact_force["fc"], contact_force["fcqp"], contact_force["fcpd"], 'fcqp', ax)
#     ax.view_init(elev=0,azim=-90)
#     ax.set_xlabel('X')
#     ax.set_xlim(-0.3, -0.3+0.6)
#     ax.set_ylabel('Y')
#     ax.set_ylim(-0.15, -0.15+0.6)
#     ax.set_zlabel('Z')
#     ax.set_zlim(-0.1, -0.1+0.6)
#     ax.set_title('Frame')
#     plt.pause(1e-8)
#     # fig.canvas.draw()
# plt.ioff()    

Using matplotlib backend: MacOSX
0
