# Overview

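This notebook analyzes the use of Levenberg-Marquardt for inverse kinematics (IK) refinement. The cells below benchmark IK solver inference time across batch sizes, comparing the un-compiled model against several compiled configurations.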

In [None]:
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
# Inject a CSS rule that forces elements with class `container` to use the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np
import torch
from jrl.math_utils import geodesic_distance_between_quaternions
from jrl.utils import set_seed, make_text_green_or_red, evenly_spaced_colors

from ikflow.model_loading import get_ik_solver

# Print tensors on wide lines with 6 decimal places and without scientific notation.
torch.set_printoptions(linewidth=300, precision=6, sci_mode=False)

In [None]:
# Name of the model handed to `get_ik_solver` for every solver variant below.
MODEL_NAME = "panda__full__lp191_5.25m"

# Compile configurations under comparison. Each dict is passed as `compile_model` when the
# corresponding solver is loaded, and is also embedded in that solver's plot legend label.
compile_1_cfg = dict(fullgraph=False, mode="default", backend="inductor")
compile_2_cfg = dict(fullgraph=True, mode="default", backend="inductor")
compile_3_cfg = dict(fullgraph=False, mode="default", backend="inductor", dynamic=True)


# === Errors
# {"fullgraph": True, "mode": "default" , "backend": "cudagraphs"}
# ^ doesn't work (RuntimeError: Node '5bec40': [(7,)] -> GLOWCouplingBlock -> [(7,)] encountered an error.)
# {"fullgraph": True, "mode": "max-autotune" , "backend": "inductor"}
# ^ doesn't work (RuntimeError: Node '5bec40': [(7,)] -> GLOWCouplingBlock -> [(7,)] encountered an error.)
# {"fullgraph": False, "mode": "max-autotune" , "backend": "inductor"}
# ^ really slow
# {"fullgraph": False, "mode": "default" , "backend": "onnxrt", "dynamic": True}
# ^ ONNX Runtime (onnxrt) unavailable
# {"fullgraph": True, "mode": "reduce-overhead" , "backend": "inductor", "dynamic": True}
# ^ really slow

# >>> torch._dynamo.list_backends()
# ['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'openxla_eval', 'tvm']

# Baseline: load the solver with compilation disabled (compile_model=None). Only the first
# element of the returned pair is used.
solver, _ = get_ik_solver(MODEL_NAME, compile_model=None)
# Robot attached to the baseline solver; used to sample target poses for every benchmark run.
robot = solver.robot
# One additional solver per compile configuration, all built from the same MODEL_NAME.
# NOTE(review): presumably each config dict is forwarded to torch.compile — confirm in
# ikflow.model_loading.get_ik_solver.
solver_compiled_1, _  = get_ik_solver(MODEL_NAME, compile_model=compile_1_cfg)
solver_compiled_2, _  = get_ik_solver(MODEL_NAME, compile_model=compile_2_cfg)
solver_compiled_3, _  = get_ik_solver(MODEL_NAME, compile_model=compile_3_cfg)

In [None]:
# Benchmark settings shared by the timing and plotting cells.
device = "cuda:0"  # device the sampled target poses are placed on
k_retry = 10  # number of timed repetitions per batch size
batch_sizes = [10, 100, 500, 1_000, 2_500, 5_000, 7_500, 10_000]  # x-axis of the runtime plot

curves, labels = [], []


def eval_runtime(solver, n_warmup=1):
    """Benchmark `solver.generate_ik_solutions` for every batch size in `batch_sizes`.

    For each batch size the solver is first called `n_warmup` times without timing, then
    `k_retry` times with timing. Every call receives freshly sampled target poses from the
    module-level `robot`, placed on the module-level `device`.

    Two measures keep the timings meaningful:
      * The device is synchronized immediately before the timer starts and immediately after
        the solver returns. CUDA kernels are launched asynchronously, so without this the
        timer could stop before the queued GPU work has finished.
      * The warm-up calls keep one-off costs out of the statistics, e.g. lazy compilation on
        the first call or recompilation for a new input shape in compiled models.

    Args:
        solver: Object exposing `generate_ik_solutions(target_poses)`.
        n_warmup: Number of untimed calls per batch size made before timing starts. Pass 0 to
            include first-call overhead in the measurements.

    Returns:
        Tuple `(runtimes, runtime_stds)` of numpy arrays with one entry per batch size: the
        mean and the standard deviation, in seconds, of the `k_retry` timed calls.
    """
    torch_device = torch.device(device)

    def _wait_for_device():
        # Block until all queued CUDA work has finished; nothing to wait for on other devices.
        if torch_device.type == "cuda":
            torch.cuda.synchronize(torch_device)

    def _sample_target_poses(batch_size):
        # Only the poses are needed; the sampled joint angles are discarded.
        _, target_poses = robot.sample_joint_angles_and_poses(batch_size, only_non_self_colliding=False)
        return torch.tensor(target_poses, device=device, dtype=torch.float32)

    runtimes = []
    runtime_stds = []

    for batch_size in batch_sizes:
        # Untimed warm-up at this exact batch size.
        for _ in range(n_warmup):
            solver.generate_ik_solutions(_sample_target_poses(batch_size))

        sub_runtimes = []
        for _ in range(k_retry):
            # Sampling and the host-to-device copy stay outside the timed region.
            target_poses = _sample_target_poses(batch_size)
            _wait_for_device()
            t0 = perf_counter()
            solver.generate_ik_solutions(target_poses)
            _wait_for_device()
            sub_runtimes.append(perf_counter() - t0)
        runtimes.append(np.mean(sub_runtimes))
        runtime_stds.append(np.std(sub_runtimes))
        print(f"batch_size: {batch_size},\tmean runtime: {runtimes[-1]:.5f} s,\tsub_runtimes: {sub_runtimes}")
    return np.array(runtimes), np.array(runtime_stds)

# Time every solver variant in turn: the un-compiled baseline first, then the three compiled
# solvers. Each call prints its per-batch-size timings and returns two arrays (mean and
# standard deviation per batch size) that are kept for plotting.
base_runtimes, base_runtime_stds = eval_runtime(solver)
compiled_1_runtimes, compiled_1_runtime_stds = eval_runtime(solver_compiled_1)
compiled_2_runtimes, compiled_2_runtime_stds = eval_runtime(solver_compiled_2)
compiled_3_runtimes, compiled_3_runtime_stds = eval_runtime(solver_compiled_3)

In [None]:
def plot_runtime_curve(multi_runtimes_means, multi_runtimes_stds, labels):
    """Plot inference time against batch size, one curve per solver variant.

    Every curve is drawn as a line with scatter markers at the module-level `batch_sizes`,
    plus a translucent band spanning mean minus std to mean plus std. The figure is shown
    immediately; nothing is returned.

    Args:
        multi_runtimes_means: Sequence of arrays holding the mean runtime [s] per batch size,
            one array per curve.
        multi_runtimes_stds: Sequence of arrays holding the runtime standard deviation [s] per
            batch size, aligned with `multi_runtimes_means`.
        labels: Legend label for each curve, aligned with `multi_runtimes_means`.
    """
    _, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
    ax.set_title("Batch size vs. Inference time for default vs. compiled nn.Module")
    ax.grid(alpha=0.2)
    ax.set_xlabel("Batch size")
    ax.set_ylabel("Inference time [s]")

    # One distinct color per curve.
    n_curves = len(multi_runtimes_means)
    palette = evenly_spaced_colors(n_curves)

    for means, stds, curve_label, curve_color in zip(multi_runtimes_means, multi_runtimes_stds, labels, palette):
        ax.plot(batch_sizes, means, label=curve_label, color=curve_color)
        # Shaded band around the mean.
        lower_band = means - stds
        upper_band = means + stds
        ax.fill_between(batch_sizes, lower_band, upper_band, alpha=0.15, color=curve_color)
        ax.scatter(batch_sizes, means, s=15, color=curve_color)

    ax.legend()
    plt.show()

# Draw the un-compiled baseline and the three compiled solvers in a single figure; each
# compiled curve is labelled with the configuration it was built from.
plot_runtime_curve(
    multi_runtimes_means=[base_runtimes, compiled_1_runtimes, compiled_2_runtimes, compiled_3_runtimes],
    multi_runtimes_stds=[base_runtime_stds, compiled_1_runtime_stds, compiled_2_runtime_stds, compiled_3_runtime_stds],
    labels=[
        "un-compiled",
        f"compiled_1: {compile_1_cfg}",
        f"compiled_2: {compile_2_cfg}",
        f"compiled_3: {compile_3_cfg}",
    ],
)
