In [1]:
import pybullet as p
import pybullet_data as pd
import math
import time
import numpy as onp
from replay_buffer import ReplayBuffer
from panda_chef import PandaChefEnv

In [2]:
env = PandaChefEnv(render=True)
# Sizes of the flat observation and action vectors (first axis of each space's shape).
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]


In [11]:
import jax
import jax.numpy as np
import jax.random as jnp_random
from jax import grad, jacfwd, jit, partial, vmap
from jax.lax import scan
from jax.experimental import stax, optimizers
from jax.ops import index, index_add, index_update

In [16]:
# Learned dynamics model. Input: augmented state [reward, *state] concatenated
# with the action. Output: 2 * (state_dim + 1) values, which `loss` and `f`
# split in half into a mean residual and a log-scale for the augmented state.
rng = jnp_random.PRNGKey(0)
in_shape = (-1, 1 + state_dim + action_dim)
model_init, model = stax.serial(
    stax.Dense(64),
    stax.Relu,
    stax.Dense(64),
    stax.Relu,
    stax.Dense(2 * (state_dim + 1)),
)
out_shape, params = model_init(rng, in_shape)


def aug_state(r, s):
    """Build the augmented state vector [r, *s].

    Args:
        r: scalar reward to place in slot 0.
        s: 1-D state vector.

    Returns:
        A NumPy array with `r` prepended to `s`.
    """
    head = [r]
    return onp.concatenate((head, s))

def log_prob(mu, log_std, y):
    """Diagonal-Gaussian log-likelihood of `y`, up to an additive constant.

    The head output is clipped to c = clip(log_std, -5, 2) and exp(c) is used
    as the divisor of the squared error. The result is
        -0.5 * sum((mu - y)**2 / exp(c)) - 0.5 * sum(c),
    which is the Gaussian log-density with *variance* exp(c). So despite its
    name, `log_std` behaves as a log-variance here.

    Args:
        mu: predicted mean.
        log_std: raw log-scale head output (clipped to [-5, 2]).
        y: observed target, same shape as `mu`.

    Returns:
        Scalar log-likelihood summed over all elements.
    """
    clipped = np.clip(log_std, -5, 2)
    scale = np.exp(clipped)
    squared_err = np.square(mu - y) / scale
    return -0.5 * np.sum(squared_err) - 0.5 * np.sum(clipped)
    
def loss(params, batch):
    """Mean negative log-likelihood of the dynamics model on a batch.

    Args:
        params: model parameters.
        batch: tuple (x, y, u) of current augmented states, next augmented
            states and actions, each with the batch on axis 0.

    Returns:
        Scalar loss. The model predicts a residual, so the predicted mean of
        `y` is `x + mu`.
    """
    x, y, u = batch
    inputs = np.concatenate([x, u], axis=1)
    mu, log_std = np.split(model(params, inputs), 2, axis=1)
    per_sample = vmap(log_prob)(x + mu, log_std, y)
    return -per_sample.mean()


@jit
def step(i, opt_state, batch):
    """One jitted optimizer step: gradient of `loss` on `batch`, then `opt_update`.

    Returns the new optimizer state.
    """
    grads = grad(loss)(get_params(opt_state), batch)
    return opt_update(i, grads, opt_state)


def batch_update(i, opt_state, replay_buffer, batch_size, verbose=False):
    """Sample a batch from the replay buffer and take one optimizer step on it.

    Args:
        i: optimizer step counter.
        opt_state: current optimizer state.
        replay_buffer: buffer whose `sample(batch_size)` returns
            (state, action, reward, next_state, next_reward).
        batch_size: number of transitions to sample.
        verbose: if True, print the loss before the step.

    Returns:
        The updated optimizer state.
    """
    state, action, reward, next_state, next_reward = replay_buffer.sample(batch_size)

    def augment(r, s):
        # Put the reward in column 0, matching the [reward, *state] layout.
        return np.concatenate([r.reshape(-1, 1), s], axis=1)

    batch = (augment(reward, state), augment(next_reward, next_state), action)
    if verbose:
        print(loss(get_params(opt_state), batch))
    return step(i, opt_state, batch)


@jit
def f(state, u, params):
    """Scan body: advance the augmented state one step with the model mean.

    Args:
        state: carry tuple (x, key). The key is passed through unchanged.
        u: action for this step.
        params: model parameters.

    Returns:
        ((x_next, key), x_next[0]), where x_next = x + predicted mean residual
        and element 0 is the predicted reward. The log-scale half of the
        model output is ignored.
    """
    x, key = state
    mu, _log_std = np.split(model(params, np.concatenate([x, u])), 2)
    x_next = x + mu
    return (x_next, key), x_next[0]

@jit
def ell(u, x0, params, key):
    """Total predicted reward of the action plan `u` rolled out from `x0`.

    `u` is scanned along its leading axis, one row per planning step.
    """
    _final_carry, rewards = scan(partial(f, params=params), (x0, key), u)
    return rewards.sum()

# Gradient of the planning objective with respect to the plan `u` (argument 0).
dell = jit(grad(ell, argnums=0))

def mpc(u, x0, params, key, n_iters=10, step_size=1e-2, action_bounds=None):
    """Refine an open-loop action plan by gradient ascent on predicted reward.

    Args:
        u: action plan, one row per planning step (the axis `ell` scans over).
        x0: augmented start state [reward, *state].
        params: dynamics-model parameters.
        key: PRNG key. It is split, and the subkey is passed to the objective.
            NOTE: `f` currently ignores the key, so planning is deterministic.
        n_iters: number of gradient-ascent iterations (default 10).
        step_size: ascent step size (default 1e-2).
        action_bounds: optional (low, high) pair. If given, the plan is
            clipped into this box after every step (projected gradient ascent).
            None (default) leaves the plan unbounded, as before.

    Returns:
        (u, key): the refined plan and the advanced PRNG key.
    """
    key, subkey = jnp_random.split(key)
    for _ in range(n_iters):
        u = u + step_size * dell(u, x0, params, subkey)
        if action_bounds is not None:
            # Keep the whole horizon inside the action box, not just the
            # first (executed) action, so the model is queried in-range.
            u = np.clip(u, *action_bounds)
    return u, key

# Adam with learning rate 3e-3 for fitting the dynamics model. It unpacks into
# the (init, update, get_params) triple used by `step` and the training loop.
opt_init, opt_update, get_params = optimizers.adam(3e-3)


In [18]:
frame_idx = 0
max_frames = 100000
batch_size = 128
horizon = 50  # MPC planning horizon: number of rows in the action plan `u`
key = jnp_random.PRNGKey(0)
opt_state = opt_init(params)
u = np.zeros((horizon, action_dim))
replay_buffer = ReplayBuffer(100000)
i = 0  # optimizer step counter passed to `batch_update`

# Model-based RL loop:
# 1. Plan with MPC on the learned model.
# 2. Execute the first planned action and store the transition.
# 3. Once the buffer holds more than one batch, take one model-fitting step
#    per environment step.
while frame_idx < max_frames:
    reward = 0.
    ep_reward = 0.
    state = env.reset()
    u = np.zeros_like(u)  # start every episode from an all-zero plan
    for t in range(300):

        u, key = mpc(u, aug_state(reward, state), get_params(opt_state), key)
        # Only the first planned action is executed, clipped to [-1, 1].
        action = onp.clip(u[0].copy(), -1, 1)
        # Use `info` rather than `_`: assigning `_` at notebook top level
        # stops IPython from updating its last-output variable.
        next_state, next_reward, done, info = env.step(action)
        # Count environment steps. Without this the `while` condition never
        # changes and the loop only ends on KeyboardInterrupt.
        frame_idx += 1

        # `reward` is the reward received *before* this step, i.e. the
        # reward slot of the augmented state used for planning above.
        replay_buffer.push(state, action, reward, next_state, next_reward)
        state = next_state
        reward = next_reward
        ep_reward += reward
        # Receding horizon: drop the executed action and pad the plan with a zero row.
        u = np.concatenate([u[1:], np.zeros_like(u[:1])], axis=0)
        if len(replay_buffer) > batch_size:
            opt_state = batch_update(i, opt_state, replay_buffer, batch_size, False)
            i += 1
        if done or frame_idx >= max_frames:
            break

    print(ep_reward)
        


9.927571333019193
-2.1441011084033077
-1.3655115160510822
17.617692337907602
-1.2913681024706418
-0.31643383975530054
4.004122968131249
-2.1862341915683245
-3.694580600755451
-2.0677407862415422
-7.867627597760358
-0.3184463377461689
10.95705856877712
-0.9538330962798951
12.941825627351596
1.5140432613863268
-2.8484324910522583
-2.9778203526488043
-4.700426955304174
1.5850683639009082
8.415367285176485
-2.3325739087019337
8.8480963842855
4.931654877772724
-2.515329105139924
-6.895318186032073
-1.0526172060144108
-7.325666395777756
3.8338661465518427
11.360403521390305
-0.7459321090668616
6.951572468584835
12.550339453034507
8.917955756534631
-2.951873463482971
-1.241161994507907
7.623577299146231
-4.084252624473659
6.509774828490861
8.795926955993835
0.24840437599710846
-3.1877631797162893
11.543814561986943
5.575382306324485
-2.156468789561664
-3.3366880352543333
7.023900682089261
3.0017699903442483
-1.9666160634358347
12.290379465782907
-0.8028695985670827
8.798727079506873
0.6157412

KeyboardInterrupt: 

In [None]:
# p.loadURDF('./pan_tefal/pan_tefal.urdf', np.array([1.0, 1., 0.2]))
# Advance the PyBullet simulation by 100 physics steps. The loop variable is
# given a real name instead of `_`: assigning `_` at notebook top level makes
# IPython stop updating its `_` (last output) variable.
for _sim_step in range(100):
    p.stepSimulation()

In [None]:

# Send a single [targetPosX, targetPosY, targetPosTH] command to `panda`.
# NOTE(review): neither `panda` (presumably a robot-wrapper object with a
# `step` method) nor the targetPos* values are created by any earlier cell
# shown here. targetPos* are only assigned inside the `while True` slider loop,
# so this cell fails on a fresh kernel.
# TODO: define them in a setup cell or delete this scratch cell.
panda.step(np.array([targetPosX, targetPosY, targetPosTH]))

In [None]:
# Interactive teleoperation loop: runs until interrupted.
# NOTE(review): targetPosXId / targetPosYId / targetPosTHId and `panda` are not
# defined by any cell shown here. They presumably come from
# p.addUserDebugParameter calls and a robot wrapper in a deleted cell; TODO confirm.
while True:
    # Read the three debug parameters in X, Y, TH order into the
    # targetPosX / targetPosY / targetPosTH globals.
    targetPosX, targetPosY, targetPosTH = (
        p.readUserDebugParameter(param_id)
        for param_id in (targetPosXId, targetPosYId, targetPosTHId)
    )
    panda.step(np.array([targetPosX, targetPosY, targetPosTH]))
    time.sleep(1. / 60.)  # throttle to at most ~60 iterations per second
#     print(p.getBasePositionAndOrientation(pizza))
#     print(p.getBaseVelocity(pizza))

In [7]:
# Inspect PyBullet's link-state tuple for link 12 of the robot (output below).
# NOTE(review): needs a `panda` object with a `.robot_id` attribute. No cell
# shown here creates one; the later `panda = p.loadURDF(...)` cell binds
# `panda` to loadURDF's return value instead. TODO confirm where it comes from.
p.getLinkState(panda.robot_id, 12)

((0.7332549623147928, -9.028825902486918e-05, 0.23101649172868027),
 (-0.05122355687644141,
  1.6301575308878612e-06,
  0.9986872102885749,
  5.6782854699545384e-05),
 (0.0, 0.0, 0.0),
 (0.0, 0.0, 0.0, 1.0),
 (0.7332549691200256, -9.028826025314629e-05, 0.23101648688316345),
 (-0.05122355744242668,
  1.6301576124533312e-06,
  0.998687207698822,
  5.678285378962755e-05))

In [20]:
# Inspect link 0 of the `pizza` body.
# NOTE(review): `pizza` is not defined by any cell shown here, so this fails on
# a fresh kernel. TODO confirm where it is loaded.
p.getLinkState(pizza, 0)

In [4]:
# Configure PyBullet's planar-reflection debug-visualizer option, passing `groundid`.
# NOTE(review): `groundid` (presumably the ground-plane body id) is not defined
# by any cell shown here. TODO confirm.
p.configureDebugVisualizer(p.COV_ENABLE_PLANAR_REFLECTION, groundid)

In [5]:
# Scratch inspection of `panda.panda`. As recorded below, this raised NameError
# because `panda` did not exist when the cell ran. TODO delete or move after `panda` is created.
panda.panda

NameError: name 'panda' is not defined

In [6]:
# Load the Franka Panda URDF at the origin with a fixed base.
# NOTE(review): `flags` is not defined by any cell shown here. Also, `np` is
# jax.numpy at this point (rebound by the JAX import cell), so the base
# position is a JAX array; a plain list [0, 0, 0] would do.
# PyBullet's loadURDF presumably returns an integer body id, which would not
# support the `panda.step` / `panda.robot_id` uses elsewhere. TODO confirm.
panda = p.loadURDF("franka_panda/panda.urdf", np.array([0,0,0]), useFixedBase=True, flags=flags)
