In [1]:
import os

import numpy as np
import torch

from common.evaluate import evaluate_pose_error_J3d_P2d, mmd_J3d_J3d
from paik.solver import NSF, PAIK, Solver, get_solver
# Seed NumPy's global RNG and PyTorch's default generator so the random draws
# in the cells below are repeatable. torch.manual_seed returns the generator
# object, which is why this cell displays a torch Generator repr as its output.
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f1df4266170>

In [2]:
robot_name = 'panda'
# Directory handed to get_solver as work_dir. It can be overridden with the
# PAIK_WORK_DIR environment variable so the notebook is not tied to a single
# machine; the original location is kept as the fallback, so behavior is
# unchanged when the variable is not set.
work_dir = os.environ.get("PAIK_WORK_DIR", "/home/luca/paik")
# Build the "nsf" solver for the chosen robot. The recorded output of this
# cell shows the robot URDF and its meshes being loaded.
nsf = get_solver(arch_name="nsf", robot_name=robot_name, load=True, work_dir=work_dir)

WorldModel::LoadRobot: /home/luca/.cache/jrl/temp_urdfs/panda_arm_hand_formatted_link_filepaths_absolute.urdf
joint mimic: no multiplier, using default value of 1 
joint mimic: no offset, using default value of 0 
URDFParser: Link size: 17
URDFParser: Joint size: 12
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link0.dae (59388 verts, 20478 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link1.dae (37309 verts, 12516 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link2.dae (37892 verts, 12716 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link3.dae (42512 verts, 14233 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link4.dae (43520 verts, 14620 tris)
LoadAssimp: Loaded model /ho

In [3]:
# Sizes used throughout the notebook: how many target poses to sample and how
# many IK solutions to request.
num_poses = 1
num_sols = 10
# Run the solver's built-in random IK check. In the recorded output it logs the
# retrieved latent ids and prints l2 / ang error statistics plus a summary row
# with the inference time.
nsf.random_ikp(num_poses=num_poses, num_sols=num_sols)
# Sample joint configurations Q and poses P from the robot model (presumably
# P holds the poses reached by Q -- confirm against the robot API).
# NOTE(review): later recorded outputs show P and Q with shape (10, 7), which
# does not match num_poses = 1 here. The saved outputs appear to come from runs
# with different settings; confirm the intended value before re-running.
Q, P = nsf.robot.sample_joint_angles_and_poses(n=num_poses)

[INFO] Retrieve latent ids: [3 4 3 6 5 8 8 9 4 9]
              l2        ang
count  10.000000  10.000000
mean    0.004062   0.711490
std     0.000673   0.366385
min     0.003397   0.197257
25%     0.003495   0.402913
50%     0.004000   0.802059
75%     0.004172   0.872497
max     0.005201   1.172909
  l2 (mm)    ang (deg)    inference_time (ms)
---------  -----------  ---------------------
     4.06         0.71                    242


In [5]:
# Request num_sols IK solutions for the sampled poses P. The recorded output
# logs the latent ids that were retrieved for this call.
ik_1 = nsf.generate_ik_solutions(P, num_sols=num_sols)

[INFO] Retrieve latent ids: [9 0 8 2 2 8 2 6 2 9]


In [7]:
# Show two of the returned solutions side by side.
# NOTE(review): in the recorded output these two entries are identical, and the
# "Retrieve latent ids" log printed when ik_1 was generated lists the same id
# (2) at positions 3 and 4. Presumably identical latents yield identical
# solutions -- confirm, and consider de-duplicating the retrieved ids.
ik_1[3], ik_1[4]

(array([[ 0.68666255,  0.60162404, -1.89308982, -2.32056342, -1.63677905,
          1.40748783,  3.02942408]]),
 array([[ 0.68666255,  0.60162404, -1.89308982, -2.32056342, -1.63677905,
          1.40748783,  3.02942408]]))

In [4]:
# Draw zero-mean Gaussian latents scaled by 0.25, laid out as
# (num_sols, num_poses, nsf.n), and hand them to the solver together with P.
z = 0.25 * np.random.randn(
    num_sols,
    num_poses,
    nsf.n,
)
J_hat = nsf.generate_ik_solutions_z(P, z)

[INFO] zsample: torch.Size([10, 10, 7])


In [5]:
# Score the candidates in J_hat against the target poses P.
# J_hat is expected with shape (num_sols, num_poses, num_dofs or n).
l2, ang = evaluate_pose_error_J3d_P2d(nsf.robot, J_hat, P, return_all=True)
# Display the mean of each error array.
(l2.mean(), ang.mean())

(0.002391491343725345, 0.026514987178270357)

In [6]:
# Compute latents from the solutions J_hat and their poses P (judging by the
# method name this inverts the generation step -- confirm in the solver).
# Note that Z is reassigned by a later cell (generate_z_from_dataset).
Z = nsf.generate_z_from_ik_solutions(P, J_hat)

In [7]:
# Build the solver's condition array from its stored poses nsf.P.
C = nsf._get_conditions(nsf.P)

# Number of rows per batch in the reshapes below.
batch_size = 4000
# Pass the conditions and the stored joint data nsf.J through the solver's
# "divisible" helper (presumably it makes the row count a multiple of
# batch_size -- confirm). The second return value is kept for C as
# `complementary` and discarded for J.
C_batch, complementary = nsf._get_divisible_conditions(C, batch_size)
# NOTE: in IPython, assigning to `_` shadows the "last output" variable.
J_batch, _ = nsf._get_divisible_conditions(nsf.J, batch_size)
# Regroup rows into (num_batches, batch_size, feature_dim); the -1 lets the
# number of batches be inferred from the total row count.
C_batch = C_batch.reshape(-1, batch_size, C_batch.shape[-1])
J_batch = J_batch.reshape(-1, batch_size, J_batch.shape[-1])
# Normalize each array with the solver's helper; "C" / "J" name which data is
# being normalized, and return_torch=True asks for torch output (the next
# cell's recorded output is a torch.Size).
C_batch = nsf.normalize_input_data(C_batch, "C", return_torch=True)
J_batch = nsf.normalize_input_data(J_batch, "J", return_torch=True)

In [8]:
# Check the batched layout; the recorded output is (1250, 4000, 8), i.e. 1250
# batches of batch_size = 4000 rows with 8 values per row.
C_batch.shape

torch.Size([1250, 4000, 8])

In [10]:
# Compute latents for the solver's dataset (the recorded progress bar shows
# 1250 iterations taking about 46 s) and attach the result to the solver as
# nsf.Z so that later cells can index into it.
nsf.Z = Z = nsf.generate_z_from_dataset()

100%|██████████| 1250/1250 [00:46<00:00, 27.06it/s]


In [11]:
# Check the latent table's size; the recorded output is (5000000, 7).
Z.shape

(5000000, 7)

In [24]:
# Check the sampled poses and joint configurations.
# NOTE(review): the recorded output is ((10, 7), (10, 7)), i.e. ten poses, even
# though the sampling cell above sets num_poses = 1 -- the outputs were saved
# from an out-of-order run; confirm which value is intended.
P.shape, Q.shape

((10, 7), (10, 7))

In [25]:
# Look up, for every configuration in Q, the latents stored for its k nearest
# neighbours in the solver's joint-space index.
k = 3
# kneighbors returns indices shaped (num_poses, k); transposing gives
# (k, num_poses) and flattening gives (k * num_poses,), so that the reshape
# below groups the latents by neighbour rank first and by pose second.
ids = nsf.J_knn.kneighbors(Q, n_neighbors=k, return_distance=False).T.flatten()
Z_from_retriever = nsf.Z[ids].reshape((k, -1, nsf.n))
Z_from_retriever.shape

(3, 10, 7)

In [26]:
# Generate IK solutions for P using the solver's default settings. The log
# printed by this cell reports a zsample of size [3, 10, 7].
# NOTE(review): other cells call this method with num_sols=... or latent=...;
# the signature apparently changed between recorded runs -- confirm which
# form the current solver accepts.
J_hat = nsf.generate_ik_solutions(P)

[INFO] zsample: torch.Size([3, 10, 7])


In [29]:
# Score the generated solutions against the target poses P.
# J_hat is expected with shape (num_sols, num_poses, num_dofs or n).
l2, ang = evaluate_pose_error_J3d_P2d(nsf.robot, J_hat, P, return_all=True)

# Display the solutions themselves followed by the mean of each error array.
(J_hat, l2.mean(), ang.mean())

(array([[[-3.31831395e-01,  5.53793683e-01, -2.93761987e-01,
          -1.00853130e+00,  2.98260301e-01,  3.52948691e+00,
           2.86497496e+00],
         [ 9.58107473e-01,  1.53841400e+00, -1.60083879e+00,
          -1.03044418e+00, -2.81750608e+00,  1.17028075e+00,
           2.65757530e+00],
         [ 1.26421746e+00, -1.08115221e+00, -6.19726572e-01,
          -9.68137730e-01,  1.75946170e+00,  5.96916836e-01,
           6.98580843e-04],
         [ 2.95476622e+00, -1.60395111e+00, -2.41647105e+00,
          -1.71153156e+00,  2.55749870e+00,  3.80441753e-01,
           2.64321100e+00],
         [ 1.63458720e-01, -5.87918855e-01,  2.15731162e+00,
          -2.05632113e+00,  5.77481415e-01,  1.74245126e+00,
           2.04423553e+00],
         [-1.98495390e+00,  1.50547179e+00, -1.42433026e+00,
          -2.81718394e+00, -1.20824063e-02,  2.98994683e+00,
          -4.54542510e-01],
         [ 1.52867439e+00,  2.00881696e-01, -2.63599288e+00,
          -9.19316771e-01,  5.36244304e

In [28]:
# Display the sampled joint configurations for visual comparison with the
# generated solutions shown in the previous cell.
Q

array([[-0.47813367,  0.75401898,  0.38477857, -0.34651987,  1.56624594,
         3.37616003,  1.23805015],
       [ 0.9516884 ,  1.61753734, -1.71370916, -1.02782358, -2.69299185,
         1.19453485,  2.69865112],
       [ 1.17901648, -1.02719825, -0.73656699, -0.97782709,  1.95747637,
         0.52060569,  0.0307969 ],
       [ 2.85962467, -1.70733469, -2.34655172, -1.67812004,  2.29232654,
         0.45315506,  2.48292335],
       [ 0.34651804, -0.83356563,  1.9193927 , -2.0292541 ,  0.7855137 ,
         1.59226584,  1.89213008],
       [-1.96927503,  1.56327791, -1.35026079, -2.80812454,  0.30779575,
         2.97060507, -0.69504103],
       [ 0.36424706,  0.49430428, -2.40168094, -0.83186423,  1.63553567,
         2.44896337,  1.07883102],
       [-0.17201711, -0.84923803,  2.13306025, -2.38871523, -2.44976751,
         0.34778621,  0.51843398],
       [-2.72520849, -1.30546461, -2.28180756, -1.75879815, -1.92776707,
         3.37079207, -2.25003372],
       [-2.81876333,  0.8899

In [15]:
from sklearn.neighbors import KernelDensity
import numpy as np
# Stand-alone kernel density estimation demo on synthetic data: 1000 uniform
# samples in 10 dimensions, drawn from a locally seeded generator.
rng = np.random.RandomState(42)
X = rng.random_sample(size=(1000, 10))
# Gaussian kernel with the bandwidth picked by the 'scott' rule.
kde = KernelDensity(kernel='gaussian', bandwidth='scott')
kde.fit(X)
# score_samples returns log-densities; exponentiate to get densities for the
# first three samples and display them.
log_density = kde.score_samples(X[0:3])
density = np.exp(log_density)
density

array([0.00188191, 0.0018762 , 0.00262504])

In [17]:
def retrieve_latent_random(nsf, P, k=3, rng=None):
    """Pick random rows of ``nsf.Z`` to use as latents for the poses in ``P``.

    Rows are drawn uniformly and with replacement. The draw does not depend on
    the pose values; only the number of rows in ``P`` is used.

    Args:
        nsf: Object exposing ``Z`` (2-D array with one latent per row) and
            ``n`` (size of the last axis of the result; rows of ``Z`` are
            expected to have this length).
        P: Array of poses; only ``P.shape[0]`` is read.
        k: Number of latents to draw per pose.
        rng: Optional object with a NumPy-style ``choice(a, size=...)`` method,
            e.g. ``np.random.default_rng(seed)`` or ``np.random.RandomState``.
            Defaults to the global ``np.random`` state, so ``np.random.seed``
            keeps controlling the draw exactly as before.

    Returns:
        Array of shape ``(k, P.shape[0], nsf.n)``.
    """
    num_poses = P.shape[0]
    sampler = np.random if rng is None else rng
    ids = sampler.choice(nsf.Z.shape[0], size=k * num_poses)
    # Explicit dimensions instead of -1: NumPy cannot infer a -1 axis when the
    # array is empty, so the original form failed for zero poses.
    return nsf.Z[ids].reshape(k, num_poses, nsf.n)

# Baseline run: feed the solver latents picked at random from nsf.Z (three per
# pose) instead of retrieved ones.
J_hat = nsf.generate_ik_solutions(
    P,
    latent=retrieve_latent_random(nsf, P, k=3),
)

# Score the baseline solutions against the target poses P.
# J_hat is expected with shape (num_sols, num_poses, num_dofs or n).
l2, ang = evaluate_pose_error_J3d_P2d(nsf.robot, J_hat, P, return_all=True)

# Display the solutions followed by the mean of each error array.
(J_hat, l2.mean(), ang.mean())

[INFO] zsample: torch.Size([3, 10, 7])


(array([[[-0.56983889, -1.75718682,  1.92132765, -1.49288864,
           2.63523365,  2.13125222, -1.31622444],
         [-2.53703618, -1.08433767,  2.46056838, -1.2939548 ,
           0.36704743,  1.0854832 ,  0.09402749],
         [ 2.11695889,  1.54050658, -1.53919242, -1.42744076,
          -1.04969303,  2.17757248,  0.03774959],
         [-1.63435316, -0.33219711,  0.35427749, -2.13327943,
          -2.43783981,  2.29258477, -0.27085495],
         [-2.18398905,  1.25297939,  2.3516816 , -2.46395189,
           0.07793708,  1.37128505, -1.19071427],
         [ 0.55906822,  1.69489908, -2.7262792 , -2.11871862,
           2.55428871,  2.8924417 , -3.25569602],
         [-0.57813888, -1.42137614,  0.87403833, -0.33850012,
          -1.9896887 ,  0.93833362, -2.51586121],
         [ 0.9807585 , -0.49972024, -0.42998273, -2.06959917,
           0.3926429 ,  1.943103  , -0.0969265 ],
         [ 0.66406813, -1.19950841,  1.33142633, -0.60559512,
           1.80216964,  2.95753846, -2.617