In [14]:
from typing import Any, Optional
import numpy as np
from time import time
import matplotlib.pyplot as plt

import pandas as pd
import scipy
from tqdm import tqdm

from common.evaluate import evaluate_pose_error_J3d_P2d
from paik.solver import NSF, PAIK, Solver, get_solver
from ikp import get_robot, numerical_inverse_kinematics_batch, compute_mmd, gaussian_kernel, inverse_multiquadric_kernel

import torch
from functools import partial
import os
import itertools
from tqdm.contrib import itertools as tqdm_itertools

from paik.file import load_pickle, save_pickle
from latent_space_sampler import Retriever


# set the same random seed for reproducibility
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7effad6d6870>

In [19]:
def solver_batch(solver, P, num_sols, std=0.001, retriever: Optional[Retriever] = None, J_ref=None, radius=0.0, num_clusters=30, num_seeds_per_pose=10, use_samples=int(5e6), verbose=False, retr_type='cluster'):
    # shape: (num_sols, num_poses, m)
    P_num_sols = np.expand_dims(P, axis=0).repeat(num_sols, axis=0)
    # shape: (num_sols*num_poses, n)
    P_num_sols = P_num_sols.reshape(-1, P.shape[-1])
    
    J_ref_num_sols = None
    if J_ref is not None:
        J_ref_num_sols = np.expand_dims(J_ref, axis=0).repeat(num_sols, axis=0)
        J_ref_num_sols = J_ref_num_sols.reshape(-1, J_ref.shape[-1])

    if isinstance(solver, PAIK):
        solver.base_std = std
        F = solver.get_reference_partition_label(P=P, J=J_ref, num_sols=num_sols)
        # shape: (1, num_sols*num_poses, n)
        J_hat = solver.generate_ik_solutions(P=P_num_sols, F=F, verbose=verbose)
    elif isinstance(solver, NSF):
        if retriever is None:
            solver.base_std = std
            J_hat = solver.generate_ik_solutions(P=P, num_sols=num_sols)
        else:
            if retr_type == 'cluster':
                latents = retriever.cluster_retriever(seeds=J_ref, num_poses=P.shape[0], num_sols=num_sols, max_samples=use_samples, radius=radius, n_clusters=num_clusters)
            elif retr_type == 'random':
                latents = retriever.random_retriever(seeds=J_ref, num_poses=P.shape[0], max_samples=use_samples, num_sols=num_sols, radius=radius)
            elif retr_type == 'numerical':
                latents = retriever.numerical_retriever(poses=P, seeds=J_ref, num_sols=num_sols, num_seeds_per_pose=num_seeds_per_pose, radius=radius)
            J_hat = solver.generate_ik_solutions(P=P_num_sols, latents=latents, verbose=verbose)
    else:
        J_hat = np.empty((num_sols, P.shape[0], solver.robot.n_dofs))
        P_torch = torch.tensor(P, dtype=torch.float32).to('cuda')
        for i, p in enumerate(P_torch):
            solutions = solver.generate_ik_solutions(
                p,
                num_sols,
                latent_distribution='gaussian',
                latent_scale=std,
                clamp_to_joint_limits=False,
            )
            J_hat[:, i] = solutions.detach().cpu().numpy()
    # return shape: (num_sols, num_poses, n)
    return J_hat.reshape(num_sols, P.shape[0], -1)


def random_ikp(robot, poses: np.ndarray, solve_fn_batch: Any, num_poses: int, num_sols: int, J_hat_num: Optional[np.ndarray] = None):
    begin = time()
    # shape: (num_poses, num_sols, num_dofs or n)
    J_hat = solve_fn_batch(P=poses, num_sols=num_sols)
    assert J_hat.shape == (
        num_sols, num_poses, robot.n_dofs), f"J_hat shape {J_hat.shape} is not correct"

    l2, ang = evaluate_pose_error_J3d_P2d(
        #init(num_sols, num_poses, num_dofs or n)
        robot, J_hat, poses, return_all=True
    )
    
    num_sols_time_ms = round((time() - begin) / len(poses), 3) * 1000
    
    ret_results = {}
    l2_mean = np.nanmean(l2)
    ang_mean = np.nanmean(ang)
    
    ret_results[f'{num_poses}_{num_sols}'] = {
        "l2_mm": l2_mean * 1000,
        "ang_deg": np.rad2deg(ang_mean),
        "num_sols_time_ms": num_sols_time_ms
    }
    
    if J_hat_num is None:
        mmd_guassian = np.nan
        mmd_imq = np.nan
    else:
        mmd_guassian_list = np.empty((num_poses))
        mmd_imq_list = np.empty((num_poses))
        for i in range(num_poses):
            mmd_guassian_list[i] = compute_mmd(J_hat[:, i], J_hat_num[:, i], kernel=gaussian_kernel)
            mmd_imq_list[i] = compute_mmd(J_hat[:, i], J_hat_num[:, i], kernel=inverse_multiquadric_kernel)
        mmd_guassian = mmd_guassian_list.mean()
        mmd_imq = mmd_imq_list.mean()
        
    ret_results[f'{num_poses}_{num_sols}']['mmd_guassian'] = mmd_guassian
    ret_results[f'{num_poses}_{num_sols}']['mmd_imq'] = mmd_imq
    
    ret_results['ik_sols'] = J_hat

    return ret_results

def plot_3d_scatter(array, file_path):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(array[:, -1], array[:, -2], array[:, -3])
    ax.set_xlabel('Last 1 Dimension')
    ax.set_ylabel('Last 2 Dimension')
    ax.set_zlabel('Last 3 Dimension')
    plt.savefig(file_path)
    plt.show()
    
def parse_ik_sols(results: dict):
    cp = results.copy()
    ik_sols = {}
    for k, v in cp.items():
        if 'ik_sols' in v:
            ik_sols[k] = v['ik_sols']
    return ik_sols
    
    
def plot_random_3d_joints_scatter(keys, ik_sols_dict: dict, c: dict, marker: dict, label: dict, seeds: np.ndarray, file_path: str, joint_nums=[-1, -2, -3]):
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
    
    x, y, z = joint_nums
    
    j = 0
    colors = plt.cm.tab20.colors
    markers = ['o', 's', 'd', '^', 'v', '<', '>', 'p', 'P', '*', 'h', 'H', '+', 'x', 'X', '|', '_', '1', '2', '3', '4']
    
    ax.scatter(seeds[:, x], seeds[:, y], seeds[:, z], c='brown', marker='p', label='Seeds', s=100)
    for i in keys:
        Ji = ik_sols_dict[i].reshape(-1, ik_sols_dict[i].shape[-1])
        
        # if c[i], marker[i], label[i] (dict) is not exist, use a default color, marker, label
        if i not in c:
            # color is a tuple of RGB (0-1)
            c[i] = np.atleast_2d(colors[j])
            j += 1

        if i not in marker:
            marker[i] = markers[j]
            j += 1

        if i not in label:
            label[i] = i.upper()
                        
        if i == 'numerical':
            # alpha is used to make the NUM solution more transparent
            ax.scatter(Ji[:, x], Ji[:, y], Ji[:, z], c='gray', marker='x', label=label[i], alpha=0.6, s=50)
        else:
            ax.scatter(Ji[:, x], Ji[:, y], Ji[:, z], c=c[i], marker=marker[i], label=label[i], s=70)
    
    ax.set_xlabel(f'Joint {x}')
    ax.set_ylabel(f'Joint {y}')
    ax.set_zlabel(f'Joint {z}')
    
    ax.legend()
    plt.show()    
    fig.savefig(file_path)


def selective_ik(record_dir, robot_name, num_poses, num_sols, paik_std_list, radius_list, num_clusters_list, num_seeds_per_pose_list):
    
    robot = get_robot(robot_name)
    nsf = get_solver(arch_name="nsf", robot=robot, load=True, work_dir='/home/luca/paik')
    retriever = Retriever(nsf)
    max_samples = int(5e6)
    retriever.init([max_samples], num_clusters_list)
    paik = get_solver(arch_name="paik", robot=robot, load=True, work_dir='/home/luca/paik')
    
    func_name = f"selective_ik_{robot_name}_{num_sols}"
    file_path = f"{record_dir}/{func_name}.pkl"

    results = {}
    # if os.path.exists(file_path):
    #     results = load_pickle(file_path)
        
    # Generate one random pose
    if 'pose' in results:
        J_ref = results['J_ref']
        poses = results['poses']
    else:
        _, poses = nsf.robot.sample_joint_angles_and_poses(n=num_poses)
        results['poses'] = poses
        save_pickle(file_path, results)

    print(f"Start numerical IK...")
    num_solver_batch = partial(numerical_inverse_kinematics_batch, solver=nsf)    
    results['num'] = random_ikp(robot, poses, num_solver_batch, num_poses=num_poses, num_sols=num_sols)
    save_pickle(file_path, results)    
    print(f"Results numerical IK are saved in {file_path}")
    
    J_hat_num = results['num']['ik_sols']
    J_ref = J_hat_num[0]
    
    
    print(f"Start paik...")
    # paik's variable: num_poses, num_sols, std, 
    for std in tqdm(paik_std_list):
        paik_solver_batch = partial(solver_batch, solver=paik, std=std, J_ref=J_ref)
        name = f'paik_gaussian_{std}'
        if name not in results:
            results[name] = random_ikp(robot, poses, paik_solver_batch, num_poses=num_poses, num_sols=num_sols, J_hat_num=J_hat_num)
            save_pickle(file_path, results) 
    print(f"Results paik are saved in {file_path}")
    
    
    print(f"Start paik...")
    # paik's variable: num_poses, num_sols, std, 
    for std in tqdm(paik_std_list):
        paik_solver_batch = partial(solver_batch, solver=paik, std=std, J_ref=J_ref)
        name = f'paik_gaussian_{std}'
        if name not in results:
            results[name] = random_ikp(robot, poses, paik_solver_batch, num_poses=num_poses, num_sols=num_sols, J_hat_num=J_hat_num)
            save_pickle(file_path, results) 
    print(f"Results paik are saved in {file_path}")

    # print(f"Start nsf w/o retreiver...")
    # # nsf's variable: std
    # for std in tqdm(paik_std_list):
    #     nsf_solver_batch = partial(solver_batch, solver=nsf, std=std, retriever=None, J_ref=J_ref)
    #     name = f'nsf_gaussian_{std}'
    #     if name not in results:
    #         results[name] = random_ikp(robot, poses, nsf_solver_batch, num_poses=num_poses, num_sols=num_sols, J_hat_num=J_hat_num)
    #         save_pickle(file_path, results)

    print(f"Start nsf with cluster retriever...")    
    # nsf's variable: num_poses, num_sols, max_samples, radius, num_clusters
    use_samples = max_samples
    for radius, num_clusters in tqdm_itertools.product(radius_list, num_clusters_list):
        nsf_solver_batch = partial(solver_batch, solver=nsf, radius=radius, num_clusters=num_clusters, retriever=retriever, use_samples=use_samples, retr_type='cluster', J_ref=J_ref)
        name = f'nsf_cluster_{radius}_{num_clusters}'
        if name not in results:
            results[name] = random_ikp(robot, poses, nsf_solver_batch, num_poses=num_poses, num_sols=num_sols, J_hat_num=J_hat_num)
            save_pickle(file_path, results)
    print(f"Results nsf with cluster retriever are saved in {file_path}")
    
    print(f"Start nsf with random retriever...")
    # nsf's variable: num_poses, num_sols, max_samples, radius
    for radius, num_clusters in tqdm_itertools.product(radius_list, num_clusters_list):
        use_samples = min(max_samples, num_clusters)
        nsf_solver_batch = partial(solver_batch, solver=nsf, radius=radius, retriever=retriever, use_samples=use_samples, retr_type='random', J_ref=J_ref)
        name = f'nsf_random_{radius}_{use_samples}'
        if name not in results:
            results[name] = random_ikp(robot, poses, nsf_solver_batch, num_poses=num_poses, num_sols=num_sols, J_hat_num=J_hat_num)
            save_pickle(file_path, results)
            
    print(f"Start nsf with numerical retriever...")
    # nsf's variable: num_poses, num_sols, max_samples, radius, num_seeds_per_pose
    for radius, num_seeds_per_pose in tqdm_itertools.product(radius_list, num_seeds_per_pose_list):
        nsf_solver_batch = partial(solver_batch, solver=nsf, radius=radius, retriever=retriever, num_seeds_per_pose=num_seeds_per_pose, retr_type='numerical', J_ref=J_ref)
        name = f'nsf_numerical_{radius}_{num_seeds_per_pose}'
        if name not in results:
            results[name] = random_ikp(robot, poses, nsf_solver_batch, num_poses=num_poses, num_sols=num_sols, J_hat_num=J_hat_num)
            save_pickle(file_path, results)
            
    # for k, v in results.items():
    #     if 'ik_sols' in v:
    #         ik_sols = v['ik_sols']
    #         file_path = f"{record_dir}/{func_name}_{k}.png"
    #         plot_3d_scatter(ik_sols.reshape(-1, ik_sols.shape[-1]), file_path)
    
    # drop pose in results
    results.pop('poses')
    # drop joint_config_ref in results
    results.pop('joint_config_ref')

    ik_sols = parse_ik_sols(results)
    keys_list = [(key, 'numerical') for key in ik_sols.keys() if key != 'num']
    file_path_list = [f"{record_dir}/{func_name}_{key}.png" for key in ik_sols.keys() if key != 'num']
    for keys, file_path in zip(keys_list, file_path_list):
        plot_random_3d_joints_scatter(keys, ik_sols, {}, {}, {}, joint_config_ref, file_path)
    

    # stat_results w/o pose, ik_sols 
    stat_results = {}
    for k, v in results.items():
        v.pop('ik_sols')
        stat_results[k] = v        
    
    df = pd.DataFrame(stat_results).T
    # round to 4 decimal places
    df = df.round(4)
    print(df)
    file_path = f"{record_dir}/selective_ik_evaluation_results_{robot_name}_{num_sols}.csv"
    df.to_csv(file_path)
    print(f"Results are saved in {file_path}")

In [20]:
from common.config import Config_IKP
config = Config_IKP()

config.workdir = '/mnt/d/pads/Documents/paik_store'

kwarg = {
    'record_dir': config.record_dir,
    'robot_name': 'panda',
    'num_poses': 1,
    'num_sols': 100,  # 300, 500, 1000
    'paik_std_list': [0.001, 0.1], # 0.001, 0.1, 0.25, 0.5, 0.7
    'radius_list': [0.001, 0.1], # 0, 0.1, 0.3, 0.5, 0.7, 0.9
    'num_seeds_per_pose_list': [10], # 10, 20, 30, 40, 50
    'num_clusters_list': [30] # 13, 16, 19, 25, 30, 40
}

robot_names = ["panda"] # "panda", "fetch", "fetch_arm", "atlas_arm", "atlas_waist_arm", "baxter_arm"

for robot_name in robot_names:
    print(f"Start to evaluate {robot_name}...")
    kwarg['robot_name'] = robot_name
    selective_ik(**kwarg)

Start to evaluate panda...
WorldModel::LoadRobot: /home/luca/.cache/jrl/temp_urdfs/panda_arm_hand_formatted_link_filepaths_absolute.urdf
joint mimic: no multiplier, using default value of 1 
joint mimic: no offset, using default value of 0 
URDFParser: Link size: 17
URDFParser: Joint size: 12
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link0.dae (59388 verts, 20478 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link1.dae (37309 verts, 12516 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link2.dae (37892 verts, 12716 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link3.dae (42512 verts, 14233 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link4.dae (43520 verts, 14620 tris)
L

  0%|          | 0/1 [00:00<?, ?it/s]

Start to initialize numerical retriever...
[SUCCESS] load from /home/luca/paik/weights/panda/0904-1939
[SUCCESS] load best date 0904-1939 with l2 0.00297 from /home/luca/paik/weights/panda/best_date_paik.csv.
Start numerical IK...
Results numerical IK are saved in /mnt/d/pads/Documents/paik_store/record/2024_11_13/selective_ik_panda_100.pkl
Start paik...


100%|██████████| 2/2 [00:00<00:00,  2.56it/s]


Results paik are saved in /mnt/d/pads/Documents/paik_store/record/2024_11_13/selective_ik_panda_100.pkl
Start paik...


100%|██████████| 2/2 [00:00<00:00, 19737.90it/s]

Results paik are saved in /mnt/d/pads/Documents/paik_store/record/2024_11_13/selective_ik_panda_100.pkl
Start nsf with cluster retriever...





  0%|          | 0/2 [00:00<?, ?it/s]

Results nsf with cluster retriever are saved in /mnt/d/pads/Documents/paik_store/record/2024_11_13/selective_ik_panda_100.pkl
Start nsf with random retriever...


  0%|          | 0/2 [00:00<?, ?it/s]

Start to random retriever...
Start to random retriever...
Start nsf with numerical retriever...


  0%|          | 0/2 [00:00<?, ?it/s]

TypeError: numerical_retriever() got an unexpected keyword argument 'num_seeds_per_pose'

In [None]:
nsf = get_solver(arch_name="nsf", robot=get_robot('panda'), load=True, work_dir='/home/luca/paik')
_, pose = nsf.robot.sample_joint_angles_and_poses(n=1)
num_sols = 100
ik_sols_num = numerical_inverse_kinematics_batch(solver=nsf, P=pose, num_sols=num_sols)
Z = nsf.generate_z_from_ik_solutions(pose, ik_sols_num.reshape(num_sols, nsf.robot.n_dofs))
ik_sols_nsf = nsf.generate_ik_solutions(P=pose, num_sols=num_sols, latents=Z, verbose=True)
l2, ang = evaluate_pose_error_J3d_P2d(robot=nsf.robot, J=ik_sols_nsf.reshape(num_sols, 1, nsf.robot.n_dofs), P=pose, return_all=True)

WorldModel::LoadRobot: /home/luca/.cache/jrl/temp_urdfs/panda_arm_hand_formatted_link_filepaths_absolute.urdf
joint mimic: no multiplier, using default value of 1 
joint mimic: no offset, using default value of 0 
URDFParser: Link size: 17
URDFParser: Joint size: 12
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link0.dae (59388 verts, 20478 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link1.dae (37309 verts, 12516 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link2.dae (37892 verts, 12716 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link3.dae (42512 verts, 14233 tris)
LoadAssimp: Loaded model /home/luca/miniconda3/lib/python3.9/site-packages/jrl/urdfs/panda/meshes/visual/link4.dae (43520 verts, 14620 tris)
LoadAssimp: Loaded model /ho

100%|██████████| 1/1 [00:00<00:00,  4.57it/s]
100%|██████████| 1/1 [00:00<00:00, 10.08it/s]


In [None]:
l2.mean(), np.rad2deg(ang.mean())

(0.006941903072011681, 1.0502243409556615)