In [1]:
from varyingsim.models.vq_vae import VQVAE
from varyingsim.datasets.fov_dataset import SmoothFovDataset
from varyingsim.envs.push_box_offset import PushBoxOffset
from varyingsim.models.osi import OSIModel
from varyingsim.models.feed_forward import FeedForward
from varyingsim.util.view import get_transform, global_to_local_obs
import torch, torchgeometry
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Second positional argument to SmoothFovDataset below; the draft OSIModel /
# VQVAE constructor calls at the bottom of this cell pass it as well.
# NOTE(review): presumably a history/horizon length — confirm against SmoothFovDataset.
H = 2

# NOTE(review): hard-coded absolute path; this cell only runs where the pickle
# exists at exactly this location.
location = '/data/varyingsim/datasets/push_box_com_n_100_T_2000.pickle'
# `dataset` is not referenced again in the visible cells.
dataset = SmoothFovDataset(location, H, obs_skip=50, include_full=True)

# Module-level environment instance; the later rollout cell resets and steps it.
env = PushBoxOffset()

# Unfinished draft of the model setup, kept commented out. The dimension
# values were never filled in, so these lines cannot be enabled as-is.
# d_in = env.
# d_param = 
# d_share = 
# d_hidden_shared = 
# d_hidden_osi = 

# osi_model = OSIModel(env, H, d_in, d_param, d_share, d_hidden_shared, d_hidden_osi)

# NOTE(review): `encoder`, `decoder`, `k` and `d` are not defined anywhere in the visible cells.
# model = VQVAE(env, H, encoder, decoder, k, d, device='cuda')

In [3]:
def plot_obs(obs, scale=10.0, lim=0.5):
    """Sketch a single observation in the plane and show the figure.

    The box is drawn as an arrow anchored at ``obs[:2]``; the pusher is drawn
    as a scatter point at ``obs[7:9]`` shifted by ``(-0.2, 0.0)``.

    Args:
        obs: 1-D torch tensor. ``obs[3:7]`` is passed to
            ``torchgeometry.quaternion_to_angle_axis`` as the box orientation.
            NOTE(review): the layout is inferred only from the slices used
            here — confirm against the environment's observation definition.
        scale: The unit direction vector is divided by this value, so the
            arrow has length ``1 / scale``.
        lim: Both axes are clipped to ``[-lim, lim]``.
    """
    # Third component of the angle-axis vector, used as the planar heading.
    # NOTE(review): this equals the yaw only when the rotation is about z.
    heading = torchgeometry.quaternion_to_angle_axis(obs[3:7])[2]
    arrow_dx = np.cos(heading) / scale
    arrow_dy = np.sin(heading) / scale

    box_pos = obs[:2]
    # NOTE(review): the 0.2 x-shift looks like a fixed pusher offset — confirm.
    pusher_shift = torch.tensor([0.2, 0.0])
    pusher_pos = obs[7:9] - pusher_shift

    plt.arrow(box_pos[0], box_pos[1], arrow_dx, arrow_dy, width=0.004, label='box_xy')
    plt.scatter(pusher_pos[0], pusher_pos[1], label='pusher_xy')

    axis_bounds = (-lim, lim)
    plt.xlim(*axis_bounds)
    plt.ylim(*axis_bounds)
    plt.show()

In [4]:
def render_obs(obs, T=1000):
    """Load the state encoded in `obs` into a fresh PushBoxOffset env and render it.

    A new environment is created, reset, and forced into the state taken from
    `obs`; the simulator is then stepped and rendered `T` times. The
    environment is always closed afterwards, even if one of those calls raises.

    Args:
        obs: 1-D torch tensor. It is split into chunks of 9 elements: the
            first chunk is used as qpos and the second as qvel.
        T: Number of simulator steps (and rendered frames).

    Raises:
        ValueError: If `obs` does not split into exactly two chunks
            (i.e. it has 9 or fewer, or more than 18, elements).
    """
    env = PushBoxOffset()
    try:
        env.reset()
        qpos, qvel = torch.split(obs, 9)
        # detach()/cpu() are no-ops for a plain CPU tensor but keep .numpy()
        # from failing on tensors that track gradients or live on a GPU.
        env.set_state(qpos.detach().cpu().numpy(), qvel.detach().cpu().numpy())
        for _ in range(T):
            # Advances the underlying simulator directly (not env.step), so no
            # action is passed in.
            env.sim.step()
            env.render()
    finally:
        # Release the environment even when stepping or rendering fails.
        env.close()

In [5]:
# --- Reference rollout: reset the shared env and take a single step. ---
obs = env.reset()
act = [0.5, 0.0, 0.0, 1.0]
obs, rew, done, info = env.step(act)
# In the saved output, the first 9 entries of obs match env.sim.data.qpos.
print('obs')
print(obs)
print('qpos')
print(env.sim.data.qpos)
# float32 tensor copies of the action and observation.
# (act_t is not used in the visible cells.)
act_t = torch.tensor(act).float()
obs_t = torch.tensor(obs).float()

# --- Second rollout: same action, but with the orientation overwritten. ---
# NOTE(review): act_90 is identical to act; only the state differs.
act_90 = [0.5, 0.0, 0.0, 1.0]
env.reset()
obs_90, rew, done, info = env.step(act_90)
# torch.split(..., 9) yields chunks of at most 9 elements; the saved output
# shows a 17-element observation, giving a 9-element qpos and 8-element qvel.
qpos, qvel = torch.split(torch.tensor(obs_90).float(), 9)
# Overwrite entries 3:7 with the quaternion built from the angle-axis vector (0, 0, pi/2).
qpos[3:7] = torchgeometry.angle_axis_to_quaternion(torch.tensor([0,0, np.pi / 2]))
# Push the modified state into the simulator and step once more from it.
env.set_state(qpos.numpy(), qvel.numpy())
obs_90, rew, done, info = env.step(act_90)
print('obs_90')
print(obs_90)
print('qpos_90')
print(env.sim.data.qpos)
# (act_t_90 is not used in the visible cells.)
act_t_90 = torch.tensor(act_90).float()
obs_t_90 = torch.tensor(obs_90).float()

# get_transform returns three values per observation; the second and third are
# passed to global_to_local_obs below.
# NOTE(review): in the saved outputs, obs[7] prints as 0.3495 here but obs_t[7]
# prints as 0.1495 in the later print cell (same 0.2 shift for obs_t_90). That
# suggests one of these helpers modifies its input tensor in place, or that the
# cells were run out of order — confirm.
M, M_inv, all_angles = get_transform(obs_t)
M_90, M_90_inv, all_angles_90 = get_transform(obs_t_90)

# Observations expressed relative to each rollout's own transform.
obs_relative = global_to_local_obs(obs_t, M_inv, all_angles)
obs_relative_90 = global_to_local_obs(obs_t_90, M_90_inv, all_angles_90)

obs
[-3.91630367e-06  4.16045359e-23  9.99868682e-02  1.00000000e+00
 -4.35181641e-22 -1.58072301e-05  5.89387006e-22  3.49523994e-01
  4.74351966e-22 -1.29018759e-03 -6.98785463e-22 -3.45926836e-03
 -1.85035569e-19 -8.48796378e-03  3.54680008e-19 -1.12709195e-01
  8.68548480e-20]
qpos
[-3.91630367e-06  4.16045359e-23  9.99868682e-02  1.00000000e+00
 -4.35181641e-22 -1.58072301e-05  5.89387006e-22  3.49523994e-01
  4.74351966e-22]
obs_90
[-8.21194763e-06 -1.47160374e-06  9.99736207e-02  7.07106783e-01
  1.14297631e-05 -1.16786245e-05  7.07106780e-01  1.99996067e-01
  1.49522002e-01 -8.83270450e-04 -7.30372325e-04 -3.16069381e-03
 -1.53155896e-04 -8.00996280e-03 -2.00026012e-06 -7.95736027e-06
 -1.12688755e-01]
qpos_90
[-8.21194763e-06 -1.47160374e-06  9.99736207e-02  7.07106783e-01
  1.14297631e-05 -1.16786245e-05  7.07106780e-01  1.99996067e-01
  1.49522002e-01]


In [8]:
# Print the two absolute observations, a blank separator line, then the two
# relative observations, one item per line.
print(obs_t, obs_t_90, '', obs_relative, obs_relative_90, sep='\n')

# Optional visual checks, currently disabled:
# plot_obs(obs_t)
# plot_obs(obs_t_90)
# render_obs(obs_t_90)

# Render the observation from the first rollout (orientation not overwritten).
render_obs(obs_t)


tensor([-3.9163e-06,  4.1605e-23,  9.9987e-02,  1.0000e+00, -4.3518e-22,
        -1.5807e-05,  5.8939e-22,  1.4952e-01,  4.7435e-22, -1.2902e-03,
        -6.9879e-22, -3.4593e-03, -1.8504e-19, -8.4880e-03,  3.5468e-19,
        -1.1271e-01,  8.6855e-20])
tensor([-8.2119e-06, -1.4716e-06,  9.9974e-02,  7.0711e-01,  1.1430e-05,
        -1.1679e-05,  7.0711e-01, -3.9339e-06,  1.4952e-01, -8.8327e-04,
        -7.3037e-04, -3.1607e-03, -1.5316e-04, -8.0100e-03, -2.0003e-06,
        -7.9574e-06, -1.1269e-01])

tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
        -2.5644e-43,  0.0000e+00, -1.4953e-01, -2.5648e-22, -1.2903e-03,
        -6.9425e-22, -3.4592e-03, -1.8503e-19, -8.4880e-03,  3.5468e-19,
        -1.1271e-01,  8.6988e-20])
tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
        -9.0949e-13,  0.0000e+00, -1.4952e-01,  4.2845e-06, -7.3048e-04,
         8.8327e-04, -3.1607e-03, -8.0100e-03,  1.5316e-04, -1.7384e-06,
        -1.1269e-0

In [7]:
# A model trainer takes in a model and trains it.
# TODO: compare relative (local-frame) observations against absolute (global-frame) ones.
