In [1]:
import pybullet as p
import pybullet_data as pd
import math
import time
import numpy as onp
from replay_buffer import ReplayBuffer
from panda_chef import PandaChefEnv
from IPython.display import clear_output

In [2]:
env = PandaChefEnv(render=True)
# Flat sizes of one observation and one action; they size the dynamics model's
# input/output layers and the MPC action plan.
state_dim, action_dim = (
    space.shape[0] for space in (env.observation_space, env.action_space)
)


In [3]:
import jax
import jax.numpy as np
import jax.random as jnp_random
from jax import grad, jacfwd, jit, partial, vmap
from jax.lax import scan
from jax.experimental import stax, optimizers
from jax.ops import index, index_add, index_update

In [4]:
# Dynamics model: an MLP with two 200-unit SELU hidden layers. The input row is
# [reward, *state, *action]; the output row is [next_reward, *next_state].
rng = jnp_random.PRNGKey(0)
in_shape = (-1, 1 + state_dim + action_dim)
model_init, model = stax.serial(
    stax.Dense(200),
    stax.Selu,
    stax.Dense(200),
    stax.Selu,
    stax.Dense(1 + state_dim),
)
out_shape, params = model_init(rng, in_shape)


def aug_state(r, s):
    """Build the augmented state vector ``[r, *s]``.

    The scalar reward ``r`` is placed in front of the 1-D observation ``s`` so
    the dynamics model can condition on (and predict) the reward alongside the
    state.
    """
    return onp.hstack(([r], s))

def log_prob(mu, y):
    """Unit-variance Gaussian log-likelihood of ``y`` under mean ``mu``.

    The additive normalization constant is omitted, leaving
    ``-0.5 * ||mu - y||^2``.
    """
    return -0.5 * np.square(mu - y).sum()
    
def loss(params, batch):
    """Mean negative log-likelihood of the model's one-step predictions.

    ``batch`` is ``(x, y, u)``: augmented states, augmented next states and
    actions, stacked row-wise. The model input is ``[x, u]`` concatenated per row.
    """
    x, y, u = batch
    inputs = np.concatenate([x, u], axis=1)
    predictions = model(params, inputs)
    per_sample = vmap(log_prob)(predictions, y)
    return -per_sample.mean()


@jit
def step(i, opt_state, batch):
    """Apply one jitted optimizer update to ``opt_state`` using ``batch``.

    ``i`` is the global step index forwarded to ``opt_update``.
    """
    current_params = get_params(opt_state)
    gradients = grad(loss)(current_params, batch)
    new_state = opt_update(i, gradients, opt_state)
    return new_state


def batch_update(i, opt_state, replay_buffer, batch_size, verbose=False):
    """Sample a minibatch from ``replay_buffer`` and take one optimizer step.

    Rewards are prepended as the first column of the states so that inputs and
    targets match the model's ``[reward, *state]`` layout. When ``verbose`` is
    true, the pre-update loss is printed.
    """
    state, action, reward, next_state, next_reward = replay_buffer.sample(batch_size)
    batch = (
        np.concatenate([reward.reshape(-1, 1), state], axis=1),
        np.concatenate([next_reward.reshape(-1, 1), next_state], axis=1),
        action,
    )
    if verbose:
        print(loss(get_params(opt_state), batch))
    return step(i, opt_state, batch)


@jit
def f(state, u, params):
    """One ``scan`` step of a model rollout.

    ``state`` is the carry ``(x, key)``, where ``x`` is the augmented state fed
    to the model together with action ``u``. The model's prediction becomes the
    new ``x``; ``key`` is passed through unchanged. The per-step output is the
    prediction's first entry (the predicted reward).
    """
    current, key = state
    model_input = np.concatenate([current, u])
    prediction = model(params, model_input)
    predicted_reward = prediction[0]
    return (prediction, key), predicted_reward

@jit
def ell(u, x0, params, key):
    """Model-predicted cumulative reward of executing plan ``u`` from ``x0``.

    Rolls the learned model forward with ``scan`` (one step per row of ``u``)
    and sums the predicted reward of every step.
    """
    rollout_step = partial(f, params=params)
    _final_carry, predicted_rewards = scan(rollout_step, (x0, key), u)
    return predicted_rewards.sum()

# Jitted gradient of the predicted return with respect to the plan ``u``
# (``grad`` differentiates the first positional argument).
dell = jit(grad(ell))

def mpc(u, x0, params, key, n_iters=10, step_size=1e-1, u_bounds=None):
    """Refine an action plan by gradient ascent on model-predicted return.

    Args:
        u: action plan with one row per planning step (e.g. ``(20, action_dim)``
            as allocated by the training loop). It is used as a warm start.
        x0: augmented start state ``[reward, *state]``.
        params: dynamics-model parameters.
        key: PRNG key. It is split once per call, and the advanced key is returned.
        n_iters: number of gradient-ascent iterations (default 10, the previous
            hard-coded value).
        step_size: ascent step size (default 0.1, the previous hard-coded value).
        u_bounds: optional ``(low, high)``. When given, the plan is clipped after
            every ascent step (projected gradient ascent). This prevents the
            model from being queried with actions outside the range that will
            actually be executed. Default ``None`` keeps the unconstrained
            behavior.

    Returns:
        ``(u, key)``: the refined plan and the advanced PRNG key.
    """
    key, subkey = jnp_random.split(key)
    for _ in range(n_iters):
        # Ascent (``+``) because ``dell`` is the gradient of predicted *reward*.
        u = u + step_size * dell(u, x0, params, subkey)
        if u_bounds is not None:
            u = np.clip(u, *u_bounds)
    return u, key

# Adam (learning rate 3e-4) for the dynamics model. ``step`` and
# ``batch_update`` look up ``opt_update``/``get_params`` by name when called.
opt_init, opt_update, get_params = optimizers.adam(3e-4)




In [5]:
# Online model-based RL loop: plan with MPC through the learned model, execute
# the first planned action, store the transition, and fit the model once per
# environment step after the buffer holds more than one batch.
frame_idx = 0
max_frames = 100000
batch_size = 128
horizon = 20              # MPC planning horizon (rows of the action plan ``u``)
max_episode_steps = 300
key = jnp_random.PRNGKey(0)
opt_state = opt_init(params)
u = np.zeros((horizon, action_dim))
replay_buffer = ReplayBuffer(100000)
i = 0  # global optimizer step counter passed to ``batch_update``

while frame_idx < max_frames:
    reward = 0.
    ep_reward = 0.
    state = env.reset()
    u = np.zeros_like(u)  # cold-start the plan at every episode
    for t in range(max_episode_steps):
        u, key = mpc(u, aug_state(reward, state), get_params(opt_state), key)
        action = onp.clip(u[0].copy(), -1, 1)
        next_state, next_reward, done, _ = env.step(action)
        # Count environment steps; without this the ``max_frames`` budget is
        # never reached and the outer loop never terminates.
        frame_idx += 1

        replay_buffer.push(state, action, reward, next_state, next_reward)
        state = next_state
        reward = next_reward
        ep_reward += reward

        # Receding horizon: drop the executed action and append a zero action
        # (equivalent to the former ``index_update`` shift, but version-agnostic).
        u = np.concatenate([u[1:], np.zeros_like(u[:1])], axis=0)

        if len(replay_buffer) > batch_size:
            opt_state = batch_update(i, opt_state, replay_buffer, batch_size, False)
            i += 1
        if done or frame_idx >= max_frames:
            break

    print(ep_reward)
        


-15.649641451766556
77.46600965126996
-69.36713196207081
-9.062986224120717
136.85661171288726
4.23972847511782
-35.01927899368862
-52.41269172828994
-10.464039875137734
23.603895854964403
-24.996333980441435
9.676947965716138
-2.9444670352354305
13.864147546138925
-22.16359433329475
36.43540205922821
140.7477316182396
19.912598721526916
18.738701834828316
15.7310477804
28.60532848129464
-22.033060402505715
38.22067866879841
83.05165713808368
81.22392451173704
-15.344561113887792
-8.988744263149524
25.916904788008782
-8.357589563005003
117.48807207458842
114.97919616856743
143.10624725152073
82.70972856050273
-16.534199244044583
16.767105443328923
14.144352307329495
23.573663652456947
13.606659215771009
3.9672852850142224
68.95457104953603
4.33851616368476
13.118408474918915
17.97621921794389
-5.043261243791158


KeyboardInterrupt: 

In [13]:
# Inspect the warm-started action plan left in the kernel after the training
# loop was interrupted (one row per planning step, one column per action dim).
u

DeviceArray([[ 3.7905801e-02,  9.2090685e-03,  1.8728929e-02],
             [ 3.7591286e-02,  8.7462729e-03,  1.4024430e-02],
             [ 3.0980578e-02,  5.5928556e-03,  8.2805436e-03],
             [ 2.7284751e-02,  3.7955828e-03,  5.1098219e-03],
             [ 2.3495577e-02,  1.2316168e-03,  6.9543985e-03],
             [ 2.0937685e-02,  1.6737508e-04,  1.7669548e-03],
             [ 1.9587738e-02,  1.2163789e-04, -2.9218011e-03],
             [ 1.7357092e-02,  2.9553066e-05, -3.7339705e-03],
             [ 1.6358618e-02,  2.9275171e-04, -2.3641556e-03],
             [ 1.3893484e-02,  4.8241421e-04, -1.8130792e-03],
             [ 1.1825153e-02,  7.1047043e-04, -1.5654473e-03],
             [ 1.0065615e-02,  6.1341288e-04, -2.0694488e-03],
             [ 7.8180451e-03,  4.1873782e-04, -2.4115341e-03],
             [ 6.2643229e-03,  1.1565544e-04, -2.3565756e-03],
             [ 4.9185231e-03,  1.4997541e-04, -2.0029189e-03],
             [ 3.5861032e-03,  9.3903531e-05, -1.656726