# Learn Actor Critic

In [1]:
# Setup cell: imports, XLA/JAX configuration, and shared handles (device, model path, jitted mjx.step).
# The XLA_* environment variables are set before `import jax` so they are in place
# when the XLA backend initialises.
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt


import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".60"

# Disable up-front GPU memory preallocation ("false"); the commented line would force it on.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# XLA flag for NVIDIA GPUs: allow Triton-generated GEMM kernels for any matmul.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
)
# Backend used for jax.devices() and jax.jit below; switch to 'cpu' (or 'METAL') as needed.
backend = 'gpu'
# backend = 'METAL'
# backend = 'cpu'

import jax
from jax import numpy as jp
import flax
import flax.linen as nn
import optax


# Enable JAX's persistent compilation cache in ./jax_cache. Entries of any size
# are stored, but only compilations that took at least 5 seconds are cached.
jax.config.update("jax_compilation_cache_dir", "./jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 5)
# jax.config.update("jax_explain_cache_misses", True)

# Debug NaNs (uncomment to raise on the first NaN produced)
# jax.config.update("jax_debug_nans", True)


# More legible printing from jax.numpy.
jp.set_printoptions(precision=4, suppress=True, linewidth=100)

import mujoco
import mujoco.mjx as mjx
# NOTE(review): `scan` and `types` are not used in any visible cell.
from mujoco.mjx._src import scan
from mujoco.mjx._src import types

# More legible printing from numpy.
np.set_printoptions(precision=4, suppress=True, linewidth=100)

# Clear whatever output the imports above produced (e.g. warnings).
from IPython.display import clear_output
clear_output()

# First device of the selected backend; models/data are placed here.
device = jax.devices(backend=backend)[0]

model_path = '../model/inverted_pendulum.xml'

# Jitted single mjx physics step on the selected backend.
mjx_step = jax.jit(mjx.step, backend=backend)



# NOTE(review): `multiple_steps` is only defined in a later cell; uncommenting this here would raise NameError.
# mjx_multiple_steps = jax.jit(multiple_steps, backend=backend, )

## Load Model

In [2]:
# Load the MuJoCo model, create its MJX counterpart, and build initial data for both.
# NOTE(review): `biomtu` looks like a project-specific MJX extension; it is only
# referenced by the commented-out equilibrium call below.
from mujoco.mjx._src.biomtu.acceleration_mtu import acceleration_mtu

mj_model = mujoco.MjModel.from_xml_path(model_path)
mjx_model = mjx.put_model(mj_model,device=device)

# Disable passive forces (mjDSBL_PASSIVE) on the MJX model only; mj_model keeps its original flags.
opt = mjx_model.opt.replace(disableflags = mjx_model.opt.disableflags |mujoco.mjtDisableBit.mjDSBL_PASSIVE)
mjx_model = mjx_model.replace(opt=opt)

mjx_data = mjx.make_data(mjx_model)
mj_data = mujoco.MjData(mj_model)

# Load the keyframe (disabled)
# mjx_data = mjx_data.replace(qpos = mj_model.key_qpos[0])
# mj_data.qpos = mj_model.key_qpos[0]

# Calculate equilibrium (disabled)
# mjx_data = acceleration_mtu.calc_equilibrium(mjx_model, mjx_data)
# Take one physics step; the first call also triggers compilation of mjx_step.
mjx_data = mjx_step(mjx_model, mjx_data)

## Control Neural Network and Critic Neural Network
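For now these networks only work for the inverted pendulum. The policy takes `2*nq` inputs (qpos and qvel) and outputs a single action. The critic takes that observation concatenated with the one action.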

In [3]:

class Controller_NN(nn.Module):
    """Stochastic policy MLP.

    Four ReLU hidden layers of width 400, then a linear head that produces a
    mean and a log-standard-deviation for each of ``out_dims`` actions.

    Attributes:
        out_dims: number of action dimensions (default 1, the inverted pendulum).
        noise_scale: factor applied to the Gaussian exploration noise
            (default 0.3, the value previously hard-coded).

    Calling the module returns ``(samples, mean, logstd)``. With
    ``out_dims == 1`` each has the input's batch shape (a scalar for an
    unbatched 1-D input). Otherwise each carries a trailing axis of size
    ``out_dims``.
    """
    # Annotated so Flax treats these as dataclass fields (settable through the
    # constructor); the defaults reproduce the previous behaviour.
    out_dims: int = 1
    noise_scale: float = 0.3

    def setup(self):
        # Features means the output dimension. Attribute names fix the
        # parameter-tree keys ('linear1' ... 'linear5').
        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # The last layer outputs the mean and logstd for every action dimension.
        self.linear5 = nn.Dense(features=self.out_dims*2)

    def __call__(self, x, key):
        """Sample an action.

        Args:
            x: observation(s); the last axis is the feature axis.
            key: PRNG key used to draw the exploration noise.

        Returns:
            (samples, mean, logstd), where samples = mean + noise_scale * std * N(0, 1).
        """
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            x = nn.relu(layer(x))
        x = self.linear5(x)
        # Split along the feature (last) axis. Indexing x[0] / x[1] only works for
        # an unbatched 1-D output; with a leading batch axis it selected rows
        # instead of the mean/logstd columns.
        mean = x[..., :self.out_dims]
        logstd = x[..., self.out_dims:]
        if self.out_dims == 1:
            # Preserve the scalar-per-sample convention for a single actuator.
            mean = mean[..., 0]
            logstd = logstd[..., 0]
        std = jp.exp(logstd)
        # One independent normal draw per output element. For an unbatched
        # single-actuator input mean.shape == (), which is the same draw as
        # jax.random.normal(key).
        noise = jax.random.normal(key, mean.shape)
        samples = noise * std * self.noise_scale + mean

        return samples, mean, logstd

# Instantiate the policy network and initialise its parameters.
controller = Controller_NN()

# Init the model
# NOTE(review): `key` seeds parameter init and is also the parent of `sub_keys`;
# JAX recommends not reusing a key once it has been split.
key = jax.random.key(66)
sub_keys = jax.random.split(key,1)
# The second argument is a dummy input that fixes the input shape
# (batch of 1, 2*nq features = qpos + qvel); the third is the sampling key __call__ requires.
controller_params = controller.init(key,jp.empty([1, mjx_model.nq*2]),sub_keys[0])
# print(params)
# jit_controller(params, states, key) -> (samples, mean, logstd)
jit_controller = jax.jit(lambda params, states, key : controller.apply(params, states, key))

class QCritic_NN(nn.Module):
    """Q-value critic: maps a concatenated (observation, action) vector to one value.

    Architecture: four ReLU hidden layers of width 400 followed by a linear
    layer with a single output unit. The output keeps a trailing axis of size 1.
    """

    def setup(self):
        # Attribute names determine the parameter-tree keys ('linear1' ... 'linear5').
        self.linear1 = nn.Dense(features=400)
        self.linear2 = nn.Dense(features=400)
        self.linear3 = nn.Dense(features=400)
        self.linear4 = nn.Dense(features=400)
        # Output head: one unbounded Q-value estimate (no sampling, no activation).
        self.linear5 = nn.Dense(features=1)

    def __call__(self, x):
        hidden = x
        for dense in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = nn.relu(dense(hidden))
        return self.linear5(hidden)

critic = QCritic_NN()

# Init the critic by providing a dummy input; the dummy fixes the input shape.
# NOTE(review): sub_keys[0] is split here and is also reused as the sampling key in the smoke test below.
key = jax.random.split(sub_keys[0],1)[0]
# NOTE(review): `sub_key` is not used in any visible cell.
sub_key = jax.random.split(key,1)[0]

# The critic's input is the observation concatenated with the action:
# observation = 2*nq values (qpos + qvel), action = 1 value.
critic_params = critic.init(key, jp.empty([1,mjx_model.nq*2 + 1]))
# jit_critic(params, state_acts) -> Q-value with a trailing axis of size 1
jit_critic = jax.jit(lambda params, state_acts: critic.apply(params, state_acts))

# Smoke-test both networks on all-ones inputs.
print(jit_controller(controller_params, jp.ones(mjx_model.nq*2), sub_keys[0]))
print(jit_critic(critic_params, jp.ones(mjx_model.nq*2 + 1)))


(Array(-0.2629, dtype=float32), Array(-0.0379, dtype=float32), Array(0.0569, dtype=float32))
[-0.1495]


## Combine the Neural Network and Simulation into One JAX Function

In [4]:
def step_fn(carry, _):
    """lax.scan body that advances MJX physics by one uncontrolled step.

    ``carry`` is ``(data, model)``. The per-iteration input ``_`` (None when
    scanned with ``xs=None``) is passed straight back as the per-step output.
    """
    data, model = carry
    return (mjx.step(model, data), model), _

def multiple_steps(model, data):
    """Run 10 uncontrolled MJX steps with lax.scan and return the final data."""
    (final_data, _), _ = jax.lax.scan(step_fn, (data, model), None, length=10)
    return final_data

def nn_mjx_one_step(nn_params, model, data, key):
    """Advance the simulation one physics step with the policy in the loop.

    The observation is ``concat(qpos, qvel)``. The controller's sampled action
    (its first output) is written to ``data.ctrl`` before calling ``mjx.step``.

    Args:
        nn_params: controller parameters.
        model: mjx.Model to step.
        data: current mjx.Data.
        key: PRNG key for this step.

    Returns:
        (new_data, new_key): the stepped data and a fresh key for the next step.
    """
    states = jp.concatenate([data.qpos, data.qvel])
    # Split before sampling so the key consumed by the sampler is never reused
    # to derive the next step's key (JAX: never reuse a key).
    new_key, sample_key = jax.random.split(key)
    act = jit_controller(nn_params, states, sample_key)[0]
    # Shape the action like the control vector; for a single actuator this is
    # the same length-1 array that jp.array([act]) produced.
    data = data.replace(ctrl=jp.reshape(act, data.ctrl.shape))
    new_data = mjx.step(model, data)
    return new_data, new_key

def nn_step_fn(carry, _):
    """lax.scan body: one policy-driven step, emitting that step's loss.

    ``carry`` is ``(nn_params, model, data, key)``. The emitted value is
    ``qpos[1] ** 2`` after the step.
    """
    nn_params, model, data, key = carry
    stepped, next_key = nn_mjx_one_step(nn_params, model, data, key)
    # Loss (lower is better), not a reward: squared second generalized
    # coordinate; presumably the pole angle — confirm against the model XML.
    step_loss = stepped.qpos[1]**2
    return (nn_params, model, stepped, next_key), step_loss

def decay_sum_scan(x, decay):
    """Exponentially decayed running sum along the leading axis of ``x``.

    Returns ``s`` with ``s[0] = x[0]`` and ``s[t] = x[t] + decay * s[t-1]``,
    stacked along the leading axis.
    """
    def _accumulate(running, current):
        updated = current + decay * running
        return updated, updated

    _, running_sums = jax.lax.scan(_accumulate, jp.zeros(x.shape[1:]), x)
    return running_sums

@jax.jit
def jit_nn_multi_steps(nn_params, model, data, key):
    """Roll the policy-in-the-loop simulation forward 10 steps.

    Returns ``(new_data, loss, new_key)``. ``loss`` is the final entry of the
    0.95-decayed running sum of per-step losses, i.e.
    ``sum_t 0.95 ** (9 - t) * loss_t``, so later steps weigh the most.
    NOTE(review): a conventional discounted return weights the earliest step
    most; confirm this ordering is intended.
    """
    repeat_length = 10
    (_, _, new_data, new_key), losses = jax.lax.scan(
        nn_step_fn, (nn_params, model, data, key), None, length=repeat_length)
    discounted = decay_sum_scan(losses, 0.95)
    return new_data, discounted[-1], new_key

@jax.jit
def jit_v_nn_multi_steps(nn_params, model, data, key):
    """Batched jit_nn_multi_steps: vmaps over axis 0 of ``data`` and ``key``.

    Parameters and model are shared across the batch.
    """
    batched_rollout = jax.vmap(jit_nn_multi_steps, in_axes=(None, None, 0, 0))
    return batched_rollout(nn_params, model, data, key)

# Fork every trajectory: run jit_v_nn_multi_steps once per column of `key`.
@jax.jit
def jit_vv_nn_multi_steps(nn_params, model, data, key):
    """Forked batched rollout.

    ``data`` is batched along axis 0 and shared by every fork. ``key`` is
    vmapped along axis 1, so a ``(batch, fork)`` key array gives one 10-step
    batched rollout per fork. Outputs gain a leading fork axis, giving
    ``(fork, batch, ...)``.
    """
    fork_rollout = jax.vmap(jit_v_nn_multi_steps, in_axes=(None, None, None, 1))
    return fork_rollout(nn_params, model, data, key)

## Environment Control 

In [5]:
def _make_single_data(model, _batch_dummy):
    """Build one fresh mjx.Data for ``model``; the second argument only carries the vmap axis."""
    return mjx.make_data(model)


# Built once and reused by every reset() call. Creating jax.jit(...) around a
# fresh lambda inside reset() gives JAX a new function object each call, which
# bypasses its jit cache and forces a retrace on every reset.
_v_make_data = jax.jit(jax.vmap(_make_single_data, in_axes=(None, 0), out_axes=0))


def reset(model, batch_size):
    """Return ``batch_size`` fresh mjx.Data instances stacked along a leading axis.

    Args:
        model: mjx.Model used for every instance.
        batch_size: number of environments.

    Returns:
        mjx.Data whose leaves have a leading axis of size ``batch_size``.
    """
    batch_dummy = jp.zeros(batch_size)
    return _v_make_data(model, batch_dummy)

## Visualize the model and controller

In [None]:
# Interactive check: drive the passive MuJoCo viewer with the policy network.
# Each frame copies viewer-side state into mjx_data, runs 10 policy-controlled
# MJX steps, and copies the result back for display.
# NOTE(review): this loop blocks until the viewer window is closed, so
# "Run All" stalls here; it also rebinds the globals mjx_data, mjx_model and key.
import mujoco.viewer
import time

previous_frame_time = time.time()
# NOTE(review): `i` is not used in this cell.
i = 0
key = jax.random.key(334)
with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    while viewer.is_running():
        # Update mjx_data from mj_data. The mj_data was modified by the viewer
        # mjx_data = mjx_data.replace(ctrl=mj_data.ctrl, xfrc_applied=mj_data.xfrc_applied)
        # Use the neural network to generate the ctrl signal
        # Generate key
        
        # Copy applied perturbation forces and the viewer's state (qpos, qvel, time) into mjx_data.
        mjx_data = mjx_data.replace(xfrc_applied=jp.array(mj_data.xfrc_applied, dtype=jp.float32))
        mjx_data = mjx_data.replace(
            qpos= jp.array(mj_data.qpos, dtype=jp.float32),
            qvel= jp.array(mj_data.qvel, dtype=jp.float32),
            time = jp.array(mj_data.time, dtype=jp.float32))
        
        # Update mjx_model from mj_model (presumably so option edits made in the viewer UI take effect).
        # Only these four options are copied; disableflags (incl. mjDSBL_PASSIVE) is left as set on mjx_model.
        mjx_model = mjx_model.tree_replace({
            'opt.gravity': jp.array(mj_model.opt.gravity, dtype=jp.float32),
            'opt.tolerance': jp.array(mj_model.opt.tolerance, dtype=jp.float32),
            'opt.ls_tolerance': jp.array(mj_model.opt.ls_tolerance, dtype=jp.float32),
            'opt.timestep': jp.array(mj_model.opt.timestep, dtype=jp.float32),
        })
        
        # mjx_data = mjx_step(mjx_model, mjx_data)
        # 10 policy-controlled steps per frame; `loss` is not used here.
        mjx_data, loss, key = jit_nn_multi_steps(controller_params, mjx_model, mjx_data, key)
        mjx.get_data_into(mj_data, mj_model, mjx_data)
        
        # Wall-clock timestamp taken after this frame's simulation work
        current_frame_time = time.time()
    
        # Calculate the difference in time from the last frame
        time_between_frames = current_frame_time - previous_frame_time
    
        # Print the time between frames (one line per frame)
        print(f"Time between frames: {time_between_frames} seconds")
        previous_frame_time = current_frame_time
        
        # print("ACT:", mjx_data.biomtu.act)
        # print(mjx_data.qpos)
        # print(mjx_data.sensordata)
        # print(len(mjx_data.qvel))
        viewer.sync()

## Test: Generate Branching Rollouts with the Stochastic Policy Network

In [16]:
# Let's generate some observations and rewards

def reshape_data(datas, batch_size, fock_size):
    """Merge the two leading axes of every leaf in ``datas`` into one.

    Each leaf is assumed to be shaped ``(fork, batch, ...)``. It becomes
    ``(batch_size * fock_size, ...)`` in row-major order, with trailing axes
    preserved. Leaves with fewer than two axes raise IndexError.
    """
    merged_size = batch_size * fock_size

    def _flatten_leading(leaf):
        target = list(leaf.shape)
        del target[0]              # drop the outer (fork) axis ...
        target[0] = merged_size    # ... and widen the batch axis to cover it
        return leaf.reshape(target)

    return jax.tree.map(_flatten_leading, datas)

# Roll out `batch_size` environments for 100 iterations (10 steps each). At
# iteration 20, fork every environment into `fock_size` copies with independent keys.
# "fock" presumably means "fork".
init_batch_size = 100
batch_size = init_batch_size
fock_size = 100

# NOTE(review): these pools are never filled in the visible cells.
states_pool = []
rewards_pool = []
actions_pool = []

# Start with the fixed init state
key = jax.random.key(334)
keys = jax.random.split(key, batch_size)
datas = reset(mjx_model,batch_size)
# print(datas.qpos)
# NOTE(review): the per-iteration print below floods the output; consider tqdm or printing once.
for i in range(100):
    datas, loss, keys = jit_v_nn_multi_steps(controller_params, mjx_model, datas, keys)
    print(datas.qvel.shape, datas.ten_J.shape)
    if(i == 20):
        # This iteration runs the normal rollout above AND a forked rollout below.
        # keys: (batch,) -> (batch, fork)
        keys = jax.vmap(jax.random.split, in_axes=(0,None))(keys, fock_size)
        # Outputs come back as (fork, batch, ...)
        datas, loss, keys = jit_vv_nn_multi_steps(controller_params, mjx_model, datas, keys)
        # Flatten keys and data with the same fork-major ordering so they stay paired.
        keys = keys.reshape(batch_size*fock_size)
        datas = reshape_data(datas, batch_size, fock_size)
        # print(datas.qvel.shape, datas.ten_J.shape)
        # batch_size is rebound to the forked total (10000 here) for later code.
        batch_size = batch_size*fock_size
        
print(datas.qpos)
print(datas.qpos.shape)

(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(100, 2) (100, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2) (10000, 0, 2)
(10000, 2

In [None]:
def change_shape(x):
    """Merge the two leading axes of ``x`` into one, keeping every trailing axis.

    Intended for leaves shaped ``(fork, batch, ...)``; returns
    ``(fork * batch, ...)`` in row-major order. ``x`` must have at least two axes.
    """
    # The merged size comes from x itself rather than the globals batch_size /
    # fock_size. The rollout cell rebinds batch_size to batch_size * fock_size,
    # so a global-based size is wrong once that cell has run.
    # Trailing axes are kept explicitly instead of using -1. This preserves
    # per-leaf structure (e.g. (F, B, 3, 3) -> (F*B, 3, 3)) and also handles
    # zero-size leaves, where -1 is ambiguous.
    return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
# Scratch check: merge the leading axes of every leaf of `datas`.
# NOTE(review): the saved output ((50, 2) / (5, 10, 2)) comes from an earlier run
# with different sizes. After the rollout cell above, `datas` is already flattened
# to (10000, ...), so this no longer demonstrates a (fork, batch) merge.
d = jax.tree.map(change_shape,datas)
print(d.qpos.shape)
print(datas.qpos.shape)


(50, 2)
(5, 10, 2)


In [None]:
# Scratch: inspect the shapes produced by per-key splitting (as used when forking).
# NOTE(review): the saved output ((50,) keys) is from an earlier run; after the
# rollout cell, `keys` has 10000 entries.
print(keys.shape)
kk = jax.vmap(jax.random.split, in_axes=(0,None))(keys, 5)
print(kk.shape)
# Reshaping to its own shape is a no-op. NOTE(review): a tuple (kk.shape) is the
# conventional `shape` argument rather than a jax array.
k = jp.reshape(kk, jp.array(kk.shape))
print(k.shape)

# Trailing shape as a list (same construction used in reshape_data).
[*kk.shape][1:]
# print(kk.shape)

(50,)
(50, 5)
(50, 5)


5

## Learn the Critic Neural Network

In [None]:
# Inspect the control vector of the single-environment mjx_data (last set by the viewer loop, if it ran).
print(mjx_data.ctrl)