# Overview

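This notebook analyzes the use of Levenberg-Marquardt (LM) optimization for IK refinement. For several `repeat_counts` settings of `generate_exact_ik_solutions`, it measures how the runtime, and the percentage of target poses for which a solution within the position/rotation error thresholds is found, vary with batch size.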

In [None]:
# Notebook setup: automatically reload edited Python modules before each cell runs.
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML

# Stretch the notebook container to the full browser width so wide tensor printouts and plots fit.
# NOTE(review): `.container` is a classic-Notebook CSS class; presumably a no-op in JupyterLab — confirm.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np
import torch
from jrl.utils import evenly_spaced_colors, set_seed

from ikflow.model_loading import get_ik_solver

torch.set_printoptions(linewidth=300, precision=6, sci_mode=False)

In [None]:
# Error tolerances handed to `generate_exact_ik_solutions` in the benchmark.
# NOTE(review): units presumably meters (position) and radians (rotation) — confirm against ikflow.
POS_ERROR_THRESHOLD = 0.001
ROT_ERROR_THRESHOLD = 0.01
# Pretrained IKFlow model under test.
MODEL_NAME = "panda__full__lp191_5.25m"

# get_ik_solver returns a pair; only the solver itself is used in this notebook.
ikflow_solver, _ = get_ik_solver(MODEL_NAME)
robot = ikflow_solver.robot  # shorthand for the solver's robot object

In [None]:
def debug_dist_to_jlims(robot, qs):
    """Assert that every joint angle respects the robot's joint limits and return each sample's distance to
    its nearest limit.

    Test with:
        qs_random, _ = ikflow_solver.robot.sample_joint_angles_and_poses(5)
        debug_dist_to_jlims(robot, torch.tensor(qs_random))

    Args:
        robot: object exposing ``actuated_joints_limits``, an iterable of ``(lower, upper)`` pairs; pair ``i``
            applies to column ``i`` of ``qs``.
        qs: 2-D tensor of joint angles, one row per sample. May live on any device.

    Returns:
        1-D tensor with one entry per row of ``qs``, on ``qs``'s device, holding the smallest absolute distance
        between any of that row's joint angles and the corresponding lower/upper limit, capped at 100. The
        dtype is ``qs``'s dtype promoted with float32. The tensor is also printed.

    Raises:
        AssertionError: if any angle is below ``lower - 1e-5`` or above ``upper + 1e-5``.
    """
    eps = 1e-5  # tolerate values sitting on a limit up to float rounding
    # Allocate on qs's device: a CPU buffer makes torch.minimum fail when qs is a CUDA tensor. Promoting
    # with float32 keeps the result dtype identical to the previous float32-buffer behaviour.
    mins = torch.full(
        (qs.shape[0],), 100.0, dtype=torch.promote_types(qs.dtype, torch.float32), device=qs.device
    )
    for i, (lower, upper) in enumerate(robot.actuated_joints_limits):
        q_i = qs[:, i]
        # `.all()` is True on an empty batch, whereas torch.min/torch.max raise on empty tensors. The messages
        # are only built when an assertion fails, so they never reduce an empty tensor.
        assert (q_i > lower - eps).all(), f"{torch.min(q_i)} !> {lower}"
        assert (q_i < upper + eps).all(), f"{torch.max(q_i)} !< {upper}"
        mins = torch.minimum(mins, torch.abs(q_i - lower))
        mins = torch.minimum(mins, torch.abs(q_i - upper))
    print("distances to joint limits:", mins)
    return mins

In [None]:
def plot_runtime_curve(curves, labels, save_path=None):
    """Plot mean runtime (left) and mean success percentage (right) against batch size, one line per curve.

    Args:
        curves: sequence of 5-tuples ``(batch_sizes, runtimes, runtime_stds, success_pcts, success_pct_stds)``.
            Each element must support element-wise ``+``/``-`` (numpy arrays in this notebook), because the
            shaded bands span ``mean - std`` to ``mean + std``.
        labels: legend label for each entry of ``curves``, in the same order.
        save_path: optional path; when given, the figure is also written there (e.g. ``"lma_errors.pdf"``)
            before it is shown.

    Raises:
        ValueError: if ``curves`` and ``labels`` differ in length. Without this check, ``zip`` would silently
            drop the unmatched entries.
    """
    if len(curves) != len(labels):
        raise ValueError(f"got {len(curves)} curves but {len(labels)} labels")

    fig, (axl, axr) = plt.subplots(1, 2, figsize=(18, 8))
    fig.suptitle("Levenberg-Marquardt IK Convergence")

    axl.set_title("Runtime")
    axl.grid(alpha=0.2)
    axl.set_xlabel("batch size")
    axl.set_ylabel("runtime (s)")

    axr.set_title("Success Pct")
    axr.grid(alpha=0.2)
    axr.set_xlabel("batch size")
    axr.set_ylabel("success pct (%)")

    # NOTE(review): requests 1.5x as many colors as curves and zip below uses only the first len(curves);
    # presumably to avoid the palette's final colors — confirm.
    colors = evenly_spaced_colors(int(1.5 * len(curves)))

    for (batch_sizes, runtimes, runtime_stds, success_pcts, success_pct_stds), label, color in zip(
        curves, labels, colors
    ):
        axl.plot(batch_sizes, runtimes, label=label, color=color)
        axl.fill_between(batch_sizes, runtimes - runtime_stds, runtimes + runtime_stds, alpha=0.15, color=color)
        axl.scatter(batch_sizes, runtimes, s=15, color=color)

        axr.plot(batch_sizes, success_pcts, label=label, color=color)
        axr.fill_between(
            batch_sizes, success_pcts - success_pct_stds, success_pcts + success_pct_stds, alpha=0.15, color=color
        )
        axr.scatter(batch_sizes, success_pcts, s=15, color=color)

    # Both panels share labels and colors, so a single legend on the right panel covers both.
    axr.legend()
    if save_path is not None:
        # Save before plt.show(), which may release the figure.
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
# Benchmark `generate_exact_ik_solutions` for each `repeat_counts` setting. For every batch size, time `k_retry`
# calls on freshly sampled, non-self-colliding target poses, and record the mean/std runtime and the mean/std
# percentage of poses for which a valid solution was returned.
batch_sizes = [10, 100, 500, 1000, 2000]
k_retry = 3  # timed calls per (repeat_counts, batch_size) pair

all_repeat_counts = [(1, 2, 5), (1, 3, 10), (1, 4, 10), (1, 5, 10)]

device = "cuda:0"


def _sync_device():
    """Block until queued GPU work on `device` has finished (no-op for non-CUDA devices)."""
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)


def _sample_target_poses(batch_size):
    """Sample `batch_size` non-self-colliding target poses as a float32 tensor on `device`."""
    _, poses = ikflow_solver.robot.sample_joint_angles_and_poses(batch_size, only_non_self_colliding=True)
    return torch.tensor(poses, device=device, dtype=torch.float32)


def _timed_exact_ik(target_poses, repeat_counts):
    """Run `generate_exact_ik_solutions` once; return (seconds elapsed, % of poses with a valid solution)."""
    # CUDA kernels run asynchronously: drain earlier work before starting the clock, and wait for this call's
    # work before stopping it, so the measured time belongs to this call alone.
    _sync_device()
    t0 = perf_counter()
    _, valid_solutions = ikflow_solver.generate_exact_ik_solutions(
        target_poses,
        pos_error_threshold=POS_ERROR_THRESHOLD,
        rot_error_threshold=ROT_ERROR_THRESHOLD,
        repeat_counts=repeat_counts,
    )
    _sync_device()
    elapsed = perf_counter() - t0
    return elapsed, 100 * valid_solutions.sum().item() / target_poses.shape[0]


# Untimed warm-up call, so that one-time start-up costs (e.g. CUDA kernel/library initialization) are not
# billed to the first configuration's smallest batch.
_timed_exact_ik(_sample_target_poses(batch_sizes[0]), all_repeat_counts[0])

curves = []
labels = []

# Seed after the warm-up so the timed runs sample the same poses they would have without it.
set_seed()

for repeat_counts in all_repeat_counts:
    runtimes = []
    success_pcts = []
    runtime_stds = []
    success_pct_stds = []

    for batch_size in batch_sizes:
        # Poses are sampled before the clock starts, so sampling time is excluded from the runtime.
        trials = [_timed_exact_ik(_sample_target_poses(batch_size), repeat_counts) for _ in range(k_retry)]
        sub_runtimes = [runtime for runtime, _ in trials]
        sub_success_pcts = [success_pct for _, success_pct in trials]

        runtimes.append(np.mean(sub_runtimes))
        runtime_stds.append(np.std(sub_runtimes))

        success_pcts.append(np.mean(sub_success_pcts))
        success_pct_stds.append(np.std(sub_success_pcts))

        print(f"batch_size: {batch_size}, runtime: {runtimes[-1]:.3f}s")

    curves.append(
        (
            np.array(batch_sizes),
            np.array(runtimes),
            np.array(runtime_stds),
            np.array(success_pcts),
            np.array(success_pct_stds),
        )
    )
    labels.append(f"repeat_counts: {repeat_counts}")

In [None]:
# Runtime and success-rate curves for every repeat_counts setting benchmarked above.
plot_runtime_curve(curves=curves, labels=labels)

In [None]:
# Measure how long copying a small (1000 x 8) float32 tensor from the GPU to host memory takes.
x = torch.zeros((1000, 8), device="cuda:0")
# torch.zeros on CUDA only enqueues a kernel; wait for it so its cost is not attributed to the copy.
torch.cuda.synchronize("cuda:0")
# perf_counter is the high-resolution monotonic clock intended for short intervals like this one.
t0 = perf_counter()
x_cpu = x.cpu()
print(f"gpu -> cpu copy time: {1000 * (perf_counter() - t0):.8f}ms")