In [1]:
import pybullet as p
import pybullet_data as pd
import math
import time
import numpy as onp
from replay_buffer import ReplayBuffer
from panda_chef import PandaChefEnv
from IPython.display import clear_output

In [2]:
# Simulated Panda environment with rendering enabled (render=True).
env = PandaChefEnv(render=True)
# Flat observation / action sizes; they size the dynamics network's input and
# output layers in the model cell below.
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]


In [3]:
import jax
import jax.numpy as np
import jax.random as jnp_random
from jax import grad, jacfwd, jit, partial, vmap
from jax.lax import scan
from jax.experimental import stax, optimizers
from jax.ops import index, index_add, index_update

In [4]:
# Learned dynamics model.
# Input row:  [reward, state, action]   (state_dim + action_dim + 1 features)
# Output row: [next_reward, next_state] (state_dim + 1 features)
# Two hidden ReLU layers of width 128; parameters initialised from a fixed seed.
rng = jnp_random.PRNGKey(0)
in_shape = (-1, state_dim + action_dim + 1)
model_init, model = stax.serial(
    stax.Dense(128),
    stax.Relu,
    stax.Dense(128),
    stax.Relu,
    stax.Dense(state_dim + 1),
)
out_shape, params = model_init(rng, in_shape)


def aug_state(r, s):
    """Prepend the scalar reward ``r`` to the state vector ``s``.

    The result is the "augmented state" layout ``[reward, *state]`` that the
    dynamics model consumes and predicts.
    """
    reward_part = [r]
    return onp.concatenate((reward_part, s))

def log_prob(mu, log_std, y):
    """Diagonal-Gaussian log-likelihood of ``y`` under N(mu, exp(log_std) ** 2).

    The result is summed over all (broadcast) elements. The additive constant
    ``-0.5 * log(2 * pi)`` per element is omitted.

    Args:
        mu: mean; broadcastable against ``y``.
        log_std: log standard deviation; broadcastable against ``mu - y``.
            Clipped to [-5, 2] for numerical stability.
        y: observed values.

    Returns:
        Scalar ``-0.5 * sum((mu - y)**2 / std**2) - sum(log_std)``.
    """
    clipped_log_std = np.clip(log_std, -5, 2)
    # Broadcast so a scalar / shared log_std contributes its normaliser once per
    # element, not once in total.
    residual, clipped_log_std = np.broadcast_arrays(mu - y, clipped_log_std)
    # Divide by the variance (std**2), not by std.
    var = np.exp(2.0 * clipped_log_std)
    return -0.5 * np.sum(np.square(residual) / var) - np.sum(clipped_log_std)
# Jacobian of the model output w.r.t. its input (argnums=1). Only needed by the
# optional, disabled smoothness penalty described inside `loss` below.
df = jit(jacfwd(model, argnums=1))
@jit
def loss(params, batch):
    """Half mean-squared one-step prediction error of the dynamics model.

    Args:
        params: stax parameters of ``model``.
        batch: ``(x, y, u)`` where x holds the augmented inputs
            ``[reward, state]``, y the targets ``[next_reward, next_state]``,
            and u the actions (layout as built by ``batch_update``).

    Returns:
        Scalar ``0.5 * mean((model([x, u]) - y) ** 2)``.
    """
    x_aug, y_target, actions = batch
    prediction = model(params, np.concatenate([x_aug, actions], axis=1))
    # Optional (disabled) smoothness penalty:
    #   + 1e-3 * mean(df(params, concatenate([x_aug, actions], axis=1)) ** 2)
    return 0.5 * np.mean(np.square(prediction - y_target))


@jit
def step(i, opt_state, batch):
    """Apply one optimiser update to the model parameters.

    ``get_params`` / ``opt_update`` are the module-level optimiser functions
    (created in this cell, after this definition, and looked up at call time).
    ``i`` is the optimiser step counter passed through to ``opt_update``.
    """
    grads = grad(loss)(get_params(opt_state), batch)
    return opt_update(i, grads, opt_state)


def batch_update(i, opt_state, replay_buffer, batch_size, verbose=False):
    """Sample a minibatch from ``replay_buffer`` and take one optimiser step.

    Inputs and targets use the reward-augmented layout
    ``[reward, state]`` / ``[next_reward, next_state]``.

    Args:
        i: optimiser step counter.
        opt_state: current optimiser state.
        replay_buffer: object whose ``sample(batch_size)`` returns
            ``(state, action, reward, next_state, next_reward)`` arrays.
        batch_size: number of transitions to sample.
        verbose: if true, print the loss on this batch before updating.

    Returns:
        The updated optimiser state.
    """
    states, actions, rewards, next_states, next_rewards = replay_buffer.sample(batch_size)
    inputs = np.concatenate([rewards.reshape(-1, 1), states], axis=1)
    targets = np.concatenate([next_rewards.reshape(-1, 1), next_states], axis=1)
    batch = (inputs, targets, actions)
    if verbose:
        print(loss(get_params(opt_state), batch))
    return step(i, opt_state, batch)


@jit
def f(x, u, params):
    """One model rollout step, shaped as a ``lax.scan`` body.

    Args:
        x: carry; batch of augmented states ``[reward, state]`` (one row each).
        u: batch of actions for this time step.
        params: model parameters.

    Returns:
        ``(next_x, reward)`` — the predicted augmented states (new carry) and
        their first column, the predicted reward, as the scan output.
    """
    next_x = model(params, np.concatenate([x, u], axis=1))
    return next_x, next_x[:, 0]

@jit
def weight_update(eps, sk, temperature=0.1, floor=1e-5):
    """MPPI correction for a single time step: reward-weighted average of noise.

    Args:
        eps: noise samples for this step, one row per rollout sample
            (called per time step via ``vmap`` from ``predict_from_samples``).
        sk: per-sample score for this step; larger is better.
        temperature: MPPI temperature; smaller values concentrate the weight on
            the best samples. Defaults to the previously hard-coded 0.1.
        floor: added to every unnormalised weight so no sample's weight is
            exactly zero. Defaults to the previously hard-coded 1e-5.

    Returns:
        ``sum_k w_k * eps_k`` with ``w`` the normalised softmax-style weights.
    """
    # Shift by the max so the largest exponent is 0 and exp() cannot overflow.
    shifted = sk - np.max(sk)
    weights = np.exp(shifted / temperature) + floor
    weights = weights / np.sum(weights)
    return np.dot(weights, eps)
@jit
def predict_from_samples(u, params, x0, eps):
    """Roll out noisy copies of the plan through the model and apply the MPPI update.

    Args:
        u: nominal plan, one action row per time step.
        params: model parameters.
        x0: batch of initial augmented states, one row per sample.
        eps: noise indexed ``[time, sample, action]`` (as drawn in ``mppi``).

    Returns:
        ``u`` plus, at each time step, the reward-to-go-weighted average noise.
    """
    perturbed_plans = np.expand_dims(u, axis=1) + eps
    _, step_rewards = scan(partial(f, params=params), x0, perturbed_plans)
    # Reverse cumulative sum over time: reward-to-go from each step onward.
    reward_to_go = np.flip(np.cumsum(np.flip(step_rewards, axis=0), axis=0), axis=0)
    correction = vmap(weight_update)(eps, reward_to_go)
    return u + correction

# Number of noisy rollouts per MPPI iteration. `mppi` is jitted, so this value is
# read when the function is traced; changing it later does not take effect
# until `mppi` is retraced.
n_samples = 100
@jit
def mppi(u, x0, params, key):
    """One MPPI planning iteration.

    Args:
        u: current plan, one action row per time step.
        x0: current augmented state ``[reward, state]`` (flattened to one row).
        params: model parameters.
        key: JAX PRNG key.

    Returns:
        ``(u_new, key)`` — the updated plan and the advanced PRNG key.
    """
    key, noise_key = jnp_random.split(key)
    horizon, act_dim = u.shape[0], u.shape[1]
    # Gaussian exploration noise with std 0.1, indexed [time, sample, action].
    eps = 0.1 * jnp_random.normal(noise_key, shape=(horizon, n_samples, act_dim))
    start_states = np.repeat(x0.reshape(1, -1), n_samples, axis=0)
    return predict_from_samples(u, params, start_states, eps), key

# Adam (step size 3e-4) for the dynamics-model parameters. `step` and
# `batch_update` above look these names up at call time, so this line must run
# before training starts.
opt_init, opt_update, get_params = optimizers.adam(3e-4)




In [5]:
# Online model-based RL loop: plan with MPPI on the learned model, act, store the
# transition, and take one model-training step per environment step.
frame_idx = 0
max_frames = 100000
batch_size = 128
PLAN_HORIZON = 20         # MPPI plan length (rows of `u`)
MAX_EPISODE_STEPS = 300   # cap on environment steps per episode
key = jnp_random.PRNGKey(0)
opt_state = opt_init(params)
u = np.zeros((PLAN_HORIZON, action_dim))
replay_buffer = ReplayBuffer(100000)
i = 0  # optimiser step counter passed to batch_update

while frame_idx < max_frames:
    # `reward` is the previous step's reward, prepended to the state by
    # aug_state; it starts at 0 for each new episode.
    reward = 0.
    ep_reward = 0.
    state = env.reset()
    u = np.zeros_like(u)  # drop the previous episode's plan
    for _ in range(MAX_EPISODE_STEPS):
        u, key = mppi(u, aug_state(reward, state), get_params(opt_state), key)
        # Execute the first planned action; convert to a numpy array so the
        # env and replay buffer get a numpy array rather than a JAX array.
        action = onp.clip(onp.asarray(u[0]), -1, 1)
        next_state, next_reward, done, _ = env.step(action)
        frame_idx += 1  # previously never incremented -> the while loop never ended

        replay_buffer.push(state, action, reward, next_state, next_reward)
        state = next_state
        reward = next_reward
        ep_reward += reward

        # Warm start: shift the plan one step forward and zero the new last step.
        u = np.concatenate([u[1:], np.zeros_like(u[:1])], axis=0)

        if len(replay_buffer) > batch_size:
            opt_state = batch_update(i, opt_state, replay_buffer, batch_size, False)
            i += 1
        if done or frame_idx >= max_frames:
            break

    # Replace the previous episode's printout instead of clearing on every step.
    clear_output(wait=True)
    print(ep_reward)
        


KeyboardInterrupt: 

In [13]:
# Display the MPPI plan `u` left by the training loop (initialised as
# np.zeros((20, action_dim)): one row per planning step).
u

DeviceArray([[ 3.7905801e-02,  9.2090685e-03,  1.8728929e-02],
             [ 3.7591286e-02,  8.7462729e-03,  1.4024430e-02],
             [ 3.0980578e-02,  5.5928556e-03,  8.2805436e-03],
             [ 2.7284751e-02,  3.7955828e-03,  5.1098219e-03],
             [ 2.3495577e-02,  1.2316168e-03,  6.9543985e-03],
             [ 2.0937685e-02,  1.6737508e-04,  1.7669548e-03],
             [ 1.9587738e-02,  1.2163789e-04, -2.9218011e-03],
             [ 1.7357092e-02,  2.9553066e-05, -3.7339705e-03],
             [ 1.6358618e-02,  2.9275171e-04, -2.3641556e-03],
             [ 1.3893484e-02,  4.8241421e-04, -1.8130792e-03],
             [ 1.1825153e-02,  7.1047043e-04, -1.5654473e-03],
             [ 1.0065615e-02,  6.1341288e-04, -2.0694488e-03],
             [ 7.8180451e-03,  4.1873782e-04, -2.4115341e-03],
             [ 6.2643229e-03,  1.1565544e-04, -2.3565756e-03],
             [ 4.9185231e-03,  1.4997541e-04, -2.0029189e-03],
             [ 3.5861032e-03,  9.3903531e-05, -1.656726