In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import cyipopt

from functools import partial
from typing import *
from dataclasses import dataclass, field
from jaxlie import SE3, SO3
import jax_dataclasses as jdc

from sdf_world.sdf_world import *
from sdf_world.robot_model import *
from sdf_world.robots import *
from sdf_world.sparse_ipopt import *
from sdf_world.ik import *

import time

In [2]:
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7005/static/


In [3]:
PREDEFINED_ROBOTS

['gen3+hand_e', 'panda+panda_hand']

In [4]:
gen3 = get_predefined_robot(world.vis, "gen3+hand_e")
panda = get_predefined_robot(world.vis, "panda+panda_hand")

In [5]:
#load sdf meshes
table_lengths = [0.4, 0.4, 0.2]
table_start = Box(world.vis, "table_start", table_lengths, 'white', 0.5)
table_goal = Box(world.vis, "table_goal", table_lengths, 'white', 0.5)
obstacle = Box(world.vis, "obstacle", [0.4, 0.2, 0.35], 'white', 0.5)

obj_start = Mesh(world.vis, "obj_start", 
                 "./sdf_world/assets/object/mesh.obj",
                 color="blue", alpha=0.5)
obj_goal = Mesh(world.vis, "obj_goal", 
                "./sdf_world/assets/object/mesh.obj",
                color="green", alpha=0.5)
obj_ho = Mesh(world.vis, "obj_ho", 
                "./sdf_world/assets/object/mesh.obj",
                color="yellow", alpha=0.5)
obj = Mesh(world.vis, "obj", 
                "./sdf_world/assets/object/mesh.obj",
                color="white", alpha=0.8)

In [6]:
def SE3_trans(xyz):
    return SE3.from_translation(jnp.array(xyz))

ydev = 0.4
gen3_base_pose = SE3_trans([0, -ydev, 0])
panda_base_pose = SE3_trans([0, ydev, 0])
gen3.set_base_pose(gen3_base_pose)
panda.set_base_pose(panda_base_pose)

table_start.set_translate([0.45, -ydev, 0.2/2])
table_goal.set_translate([0.45, ydev, 0.2/2])
obj_lengths = obj_start.mesh.bounding_box.primitive.extents
obj_start.set_translate([0.45, -0.4, obj_lengths[-1]/2+table_lengths[-1]])
trans_goal = jnp.array([0.45, 0.4, obj_lengths[-2]/2+table_lengths[-1]])
obstacle.set_translate([0.45, 0., 0.35/2])
obj_goal_pose = SE3.from_rotation_and_translation(
    SO3.from_rpy_radians(jnp.pi/2, 0,0), trans_goal)
obj_goal.set_pose(obj_goal_pose)
obj.set_translate([0.45, -0.4, obj_lengths[-1]/2+table_lengths[-1]])
obj_ho.set_translate([0.3, 0., 0.6])

In [7]:
env = SDFContainer([table_start, table_goal, obstacle], safe_dist=None)

In [8]:
hande = Gripper(world.vis, "hande", gen3.gripper.model, 
                gen3.gripper.tool_pose_offset, max_width=0.05, scale=0.001)
panda_hand = Gripper(world.vis, "panda_hand", panda.gripper.model, 
                    panda.gripper.tool_pose_offset, max_width=0.08, is_rev_type=False)

In [9]:
from sdf_world.network import *
#learned model
# models 
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restored_grasp = orbax_checkpointer.restore("model/grasp_net_prob_dist")
restored_manip_gen3 = orbax_checkpointer.restore("model/manip_net_gen3")
restored_manip_panda = orbax_checkpointer.restore("model/manip_net")

#grasp net
grasp_net = GraspNet(32, 6)
grasp_fn = lambda x: grasp_net.apply(restored_grasp["params"], x)

grasp_logit_fn = lambda g: grasp_fn(g)[0]
grasp_dist_fn = lambda g: grasp_fn(g)[1]

#manip net: input: wxyzxyz
manip_net = ManipNet(64)
zflip = SE3.from_rotation(SO3.from_z_radians(jnp.pi))
def manip_gen3(grasp_pose_world:SE3, robot_pose:SE3):
    grasp_pose_robot = robot_pose.inverse() @ grasp_pose_world
    grasp_pose_robot_flip = grasp_pose_robot @ zflip
    inputs = jnp.vstack([
        grasp_pose_robot.parameters(), 
        grasp_pose_robot_flip.parameters() ])
    return jax.vmap(manip_net.apply, (None, 0))(
        restored_manip_gen3["params"], inputs).flatten().max()

def manip_panda(grasp_pose_world:SE3, robot_pose:SE3):
    grasp_pose_robot = robot_pose.inverse() @ grasp_pose_world
    grasp_pose_robot_flip = grasp_pose_robot @ zflip
    inputs = jnp.vstack([
        grasp_pose_robot.parameters(), 
        grasp_pose_robot_flip.parameters() ])
    return jax.vmap(manip_net.apply, (None, 0))(
        restored_manip_panda["params"], inputs).flatten().max()

In [10]:
#functions
def grasp_reconst(grasp:Array):
    rot = SO3(grasp_fn(grasp)[2:]).normalize()
    trans = grasp/restored_grasp["scale_to_norm"]
    return SE3.from_rotation_and_translation(rot, trans)
def grasp_embedding(grasp_point):
    grasp = grasp_point * restored_grasp["scale_to_norm"]
    return grasp

In [11]:
from sdf_world.ik import IKConfig, get_ik_fn, fk_err

gen3_fk_ee_fn = gen3.get_fk_ee_fn()
panda_fk_ee_fn = panda.get_fk_ee_fn()

gen3_ik_config = IKConfig(
    jax.tree_util.Partial(gen3_fk_ee_fn), gen3_base_pose, 0.01, 0.05, 20)
panda_ik_config = IKConfig(
    jax.tree_util.Partial(panda_fk_ee_fn), panda_base_pose, 0.01, 0.05, 20)
ik_fn = get_ik_fn()

def fk_constr(q:Array, grasp:Array, obj_pose:Array, robot_pose:Array, fk_ee_fn:Callable):
    target_pose = SE3(obj_pose) @ grasp_reconst(grasp)
    return fk_err(q, target_pose, SE3(robot_pose), fk_ee_fn)
fk_constr_gen3 = partial(fk_constr, fk_ee_fn=gen3_fk_ee_fn)
jac_fk_constr_gen3 = jax.jacfwd(fk_constr_gen3, argnums=[0,1])
fk_constr_panda = partial(fk_constr, fk_ee_fn=panda_fk_ee_fn)
jac_fk_constr_panda = jax.jacfwd(fk_constr_panda, argnums=[0,1])

def ik(q_init, target_pose, config):
    init_val = (q_init, target_pose, 1., 0, config)
    result = ik_fn(init_val=init_val)
    if result[3] >= 20:
        return None
    return result[0]

In [79]:
grasp_logit_fn
def manip_constr_fn_gen3(grasp, obj_pose, robot_pose):
    grasp_pose_world = SE3(obj_pose) @ grasp_reconst(grasp)
    return manip_gen3(grasp_pose_world, SE3(robot_pose))
def manip_constr_fn_panda(grasp, obj_pose, robot_pose):
    grasp_pose_world = SE3(obj_pose) @ grasp_reconst(grasp)
    return manip_panda(grasp_pose_world, SE3(robot_pose))

def manip_obj(grasp1, grasp2, ho_pose, robot_pose1, robot_pose2):
    ho_pose = SE3(ho_pose).normalize()
    grasp_pose1 = ho_pose @ grasp_reconst(grasp1)
    grasp_pose2 = ho_pose @ grasp_reconst(grasp2)
    manip_sum = manip_gen3(grasp_pose1, SE3(robot_pose1)) + \
            manip_panda(grasp_pose2, SE3(robot_pose2))
    return - manip_sum
jac_manip_obj = jax.grad(manip_obj, argnums=[0,1,2])

In [80]:
def qtn_constr(pose_like):
    qtn_like = pose_like[:4]
    return jnp.sum(qtn_like**2) - 1.
jac_qtn_constr = jax.grad(qtn_constr, argnums=[0])

In [38]:
# hand collision
gen3_hand_pc = hande.get_hand_pc_wrt_tool_pose()
panda_hand_pc = panda_hand.get_hand_pc_wrt_tool_pose()
def hand_col_constr(grasp, obj_pose, hand_pc):
    tool_pose = SE3(obj_pose) @ grasp_reconst(grasp)
    assigned_pc = jax.vmap(tool_pose.apply)(hand_pc)
    return env.distances(assigned_pc).min()
hand_col_constr_gen3 = jax.tree_util.Partial(
    hand_col_constr, hand_pc=gen3_hand_pc)
hand_col_constr_panda = jax.tree_util.Partial(
    hand_col_constr, hand_pc=panda_hand_pc)
jac_hand_col_constr_gen3 = jax.grad(hand_col_constr_gen3, argnums=[0])
jac_hand_col_constr_panda = jax.grad(hand_col_constr_panda, argnums=[0])

In [30]:
fk_point_gen3 = gen3.arm.model.get_fk_point_fn()
fk_point_panda = panda.arm.model.get_fk_point_fn()
def robot_pc_assign(q, robot, robot_pose, fk_point_fn):
    points = []
    for i, pc in enumerate(robot.pcs):
        if pc is None: continue
        link_idx = i if i <= len(robot.arm.model.ub) else len(robot.arm.model.ub)
        pc = jax.vmap(fk_point_fn, (None,None,0))(q, link_idx, pc)
        points.append(pc)
    points = jnp.vstack(points)
    points = jax.vmap(robot_pose.apply)(points)
    return points

def robot_col_constr(q, pc_assign_fn):
    points = pc_assign_fn(q)
    distances = env.distances(points)
    return distances.min()

robot_pc_assign_gen3 = jax.tree_util.Partial(
    robot_pc_assign, robot=gen3, robot_pose=gen3_base_pose, fk_point_fn=fk_point_gen3)
robot_pc_assign_panda = jax.tree_util.Partial(
    robot_pc_assign, robot=panda, robot_pose=panda_base_pose, fk_point_fn=fk_point_panda)
robot_col_constr_gen3 = jax.tree_util.Partial(
    robot_col_constr, pc_assign_fn=robot_pc_assign_gen3)
robot_col_constr_panda = jax.tree_util.Partial(
    robot_col_constr, pc_assign_fn=robot_pc_assign_panda)

jac_robot_col_constr_gen3 = jax.grad(robot_col_constr_gen3, argnums=[0])
jac_robot_col_constr_panda = jax.grad(robot_col_constr_panda, argnums=[0])

In [33]:
pc1 = PointCloud(world.vis, "hand_pc", hande.hand_pc, 0.01, "red")
pc2 = PointCloud(world.vis, "robot_pc", hande.hand_pc, 0.005, "blue")

In [77]:
def get_rot_mid(rot1, rot2):
    return rot1 @ SO3.exp((rot1.inverse() @ rot2).log()/2)
# given
gen3_neutral = jnp.array([0,0,0.,np.pi/3,0,np.pi/2,0.2])
panda_neutral = (panda.arm.model.ub + panda.arm.model.lb)/2
# problem:
obj_start_pose = obj_start.pose
obj_goal_pose = obj_goal.pose

obj_ho_rot = get_rot_mid(obj_start_pose.rotation(), obj_goal_pose.rotation())
obj_ho_pos = jnp.array([0.3, 0., 0.6])
obj_ho_pose_init = SE3.from_rotation_and_translation(obj_ho_rot, obj_ho_pos)
obj_ho.set_pose(obj_ho_pose_init)

In [100]:
ws_lb = np.array([-1, -1, -1, -1, -0.,-0.5,0.2])
ws_ub = np.array([1, 1, 1, 1, 0.5, 0.5, 1.])

In [103]:
bdr1 = SparseIPOPT()
bdr1.add_variable("grasp_1", 3, -1., 1.)
bdr1.add_variable("grasp_2", 3, -1., 1.)
bdr1.add_variable("pose_ho", 7, ws_lb, ws_ub)

bdr1.add_parameter("pose_o1", 7, obj_start_pose.parameters())
bdr1.add_parameter("pose_o2", 7, obj_goal_pose.parameters())
bdr1.add_parameter("pose_gen3", 7, gen3_base_pose.parameters())
bdr1.add_parameter("pose_panda", 7, panda_base_pose.parameters())

bdr1.register_fn("grasp_logit", [3], 1,
                 grasp_logit_fn, jax.grad(grasp_logit_fn, argnums=[0]))
bdr1.register_fn("manip_gen3", [7, 7], 1,
                 manip_constr_fn_gen3, 
                 jax.grad(manip_constr_fn_gen3, argnums=[0,1]),
                 jac_out_argnums=[0,1])
bdr1.register_fn("manip_panda", [7, 7], 1,
                 manip_constr_fn_panda, 
                 jax.grad(manip_constr_fn_panda, argnums=[0,1]),
                 jac_out_argnums=[0,1])
bdr1.register_fn("hand_col_gen3", [3, 7], 1,
                 hand_col_constr_gen3, 
                 jac_hand_col_constr_gen3,
                 jac_out_argnums=[0])
bdr1.register_fn("hand_col_panda", [3, 7], 1,
                 hand_col_constr_panda, 
                 jac_hand_col_constr_panda,
                 jac_out_argnums=[0])
bdr1.register_fn("manip_obj", [3,3,7,7,7], 1,
                 manip_obj, jac_manip_obj, 
                 jac_out_argnums=[0,1,2])
bdr1.register_fn("qtn_constr", [7], 1, 
                 qtn_constr, jac_qtn_constr)

def debug(xdict):
    grasps = [xdict["grasp_1"], xdict["grasp_2"]]
    obj_poses = [xdict["pose_o1"], xdict["pose_o2"]]
    hands = [hande, panda_hand]
    hand_pcs = [gen3_hand_pc, panda_hand_pc]
    ho_pose = SE3(xdict["pose_ho"])

    pcs = []
    for i in range(2):
        handover_pose = ho_pose @ grasp_reconst(grasps[i]) #
        grip_pose = SE3(obj_poses[i]) @ grasp_reconst(grasps[i]) #
        hands[i].set_tool_pose(handover_pose)
        pcs.append(jax.vmap(grip_pose.apply)(hand_pcs[i]))
    pc1.reload(points=np.vstack(pcs), size=0.005)
    obj_ho.set_pose(ho_pose)
    time.sleep(0.1)
bdr1.set_debug_callback(debug)

bdr1.set_objective("manip_obj", ["grasp_1", "grasp_2", "pose_ho","pose_gen3","pose_panda"])

bdr1.set_constr("grasp_constr_pick", "grasp_logit", ["grasp_1"], 1., np.inf)
bdr1.set_constr("grasp_constr_place", "grasp_logit", ["grasp_2"], 1., np.inf)

bdr1.set_constr("manip_constr_ho1", "manip_gen3", 
                ["grasp_1", "pose_ho", "pose_gen3"], 0.4, np.inf)
bdr1.set_constr("manip_constr_ho2", "manip_panda", 
                ["grasp_2", "pose_ho", "pose_panda"], 0.4, np.inf)

bdr1.set_constr("hand_col_constr_pick", "hand_col_gen3", 
                ["grasp_1", "pose_o1"], 0.02, np.inf)
bdr1.set_constr("hand_col_constr_place", "hand_col_panda", 
                ["grasp_2", "pose_o2"], 0.02, np.inf)
bdr1.set_constr("qtn_constr", "qtn_constr",
                ["pose_ho"], 0., 0.)
ipopt_grasp = bdr1.build(compile=True)

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
ooo--------------------------------------
---ooo-----------------------------------
ooo---ooooooo----------------------------
---oooooooooo----------------------------
ooo--------------------------------------
---ooo-----------------------------------
------ooooooo----------------------------


In [115]:
bdr2 = SparseIPOPT()
bdr2.add_variable("grasp_1", 3, -1., 1.)
bdr2.add_variable("q_pick", 7, gen3.arm.model.lb, gen3.arm.model.ub)
bdr2.add_variable("grasp_2", 3, -1., 1.)
bdr2.add_variable("q_place", 7, panda.arm.model.lb, panda.arm.model.ub)
bdr1.add_variable("pose_ho", 7, ws_lb, ws_ub)

bdr2.add_parameter("pose_o1", 7, obj_start_pose.parameters())
bdr2.add_parameter("pose_o2", 7, obj_goal_pose.parameters())
bdr2.add_parameter("pose_gen3", 7, gen3_base_pose.parameters())
bdr2.add_parameter("pose_panda", 7, panda_base_pose.parameters())

bdr2.register_fn("grasp_logit", [3], 1,
                 grasp_logit_fn, jax.grad(grasp_logit_fn, argnums=[0]))
# bdr2.register_fn("manip_gen3", [7, 7], 1,
#                  manip_constr_fn_gen3, 
#                  jax.grad(manip_constr_fn_gen3, argnums=[0]),
#                  jac_out_argnums=[0])
# bdr2.register_fn("manip_gen3", [7, 7], 1,
#                  manip_constr_fn_gen3, 
#                  jax.grad(manip_constr_fn_gen3, argnums=[0]),
#                  jac_out_argnums=[0])
bdr2.register_fn("hand_col_gen3", [3, 7], 1,
                 hand_col_constr_gen3, 
                 jac_hand_col_constr_gen3,
                 jac_out_argnums=[0])
bdr2.register_fn("hand_col_panda", [3, 7], 1,
                 hand_col_constr_panda, 
                 jac_hand_col_constr_panda,
                 jac_out_argnums=[0])
bdr2.register_fn("kin_gen3", [7, 3, 7, 7], 6,
                 fk_constr_gen3, jac_fk_constr_gen3, 
                 jac_out_argnums=[0,1])
bdr2.register_fn("kin_panda", [7, 3, 7, 7], 6,
                 fk_constr_panda, jac_fk_constr_panda, 
                 jac_out_argnums=[0,1])
bdr1.register_fn("qtn_constr", [7], 1, 
                 qtn_constr, jac_qtn_constr)
bdr2.register_fn("robot_col_gen3", [7], 1,
                 robot_col_constr_gen3, jac_robot_col_constr_gen3)
bdr2.register_fn("robot_col_panda", [7], 1,
                 robot_col_constr_panda, jac_robot_col_constr_panda)

def debug_fn(xdict):
    grasps = [xdict["grasp_1"], xdict["grasp_2"]]
    qs = [xdict["q_pick"], xdict["q_place"]]
    obj_poses = [xdict["pose_o1"], xdict["pose_o2"]]
    ho_pose = SE3(xdict["pose_ho"])
    hands = [hande, panda_hand]
    robots: List[ArmGripper] = [gen3, panda]

    for i in range(2):
        tool_pose = SE3(obj_poses[i]) @ grasp_reconst(grasps[i])
        hands[i].set_tool_pose(tool_pose)
        robots[i].set_joint_angles(qs[i])
    obj_ho.set_pose(ho_pose)
    
    pcs2 = []
    pcs2.append(robot_pc_assign_gen3(qs[0]))
    pcs2.append(robot_pc_assign_panda(qs[1]))
    pc2.reload(points=np.vstack(pcs2), size=0.01)
    time.sleep(0.1)
bdr2.set_debug_callback(debug_fn)

bdr2.set_constr("grasp_constr_pick", "grasp_logit", ["grasp_1"], 1., np.inf)
bdr2.set_constr("grasp_constr_place", "grasp_logit", ["grasp_2"], 1., np.inf)
#bdr2.set_constr("manip_constr_pick", "manip_gen3", ["grasp_1", "pose_o1", "pose_gen3"], 0.2, np.inf)
bdr2.set_constr("kin_constr_pick", "kin_gen3", 
                ["q_pick", "grasp_1", "pose_o1", "pose_gen3"], 0., 0.)
bdr2.set_constr("kin_constr_place", "kin_panda", 
                ["q_place", "grasp_2", "pose_o2", "pose_panda"], 0., 0.)
bdr2.set_constr("kin_constr_ho1", "kin_gen3", 
                ["q_pick", "grasp_1", "pose_ho", "pose_gen3"], 0., 0.)
bdr2.set_constr("kin_constr_place", "kin_panda", 
                ["q_place", "grasp_2", "pose_ho", "pose_panda"], 0., 0.)
#collision
bdr2.set_constr("hand_col_constr_pick", "hand_col_gen3", 
                ["grasp_1", "pose_o1"], 0.02, np.inf)
bdr2.set_constr("hand_col_constr_place", "hand_col_panda", 
                ["grasp_2", "pose_o2"], 0.02, np.inf)
bdr2.set_constr("robot_col_pick", "robot_col_gen3", 
                ["q_pick"], 0.02, np.inf)
bdr2.set_constr("robot_col_place", "robot_col_panda", 
                ["q_place"], 0.02, np.inf)
ipopt_full = bdr2.build(compile=False)

compiling objective ...
compiling gradient ...
compiling constraints ...
compiling jacobian ...
ooo---------------------------------------------
----------ooo-----------------------------------
oooooooooo--------------------------------------
oooooooooo--------------------------------------
oooooooooo--------------------------------------
oooooooooo--------------------------------------
oooooooooo--------------------------------------
oooooooooo--------------------------------------
----------oooooooooo----------------------------
----------oooooooooo----------------------------
----------oooooooooo----------------------------
----------oooooooooo----------------------------
----------oooooooooo----------------------------
----------oooooooooo----------------------------
ooo---------------------------------------------
----------ooo-----------------------------------
---ooooooo--------------------------------------
-------------ooooooo----------------------------


In [112]:
x0 = bdr1.get_init_value(
    {"grasp_1":np.random.uniform(-1,1,3), 
     "grasp_2":np.random.uniform(-1,1,3),
     "pose_ho":obj_ho_pose_init.parameters()
})
ipopt_grasp.add_option("print_level", 5)
sol, info = ipopt_grasp.solve(x0)
sol_dict = bdr1.split_solution(sol)
print(info["status_msg"])

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        7
Number of nonzeros in inequality constraint Jacobian.:       32
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       41
                     variables with only lower bounds:        0
                variables with lower and upper bounds:       13
                     variables with only upper bounds:        0
Total number of equality constraints.................:        1
Total number of inequality constraints...............:        6
        inequality constraints with only lower bounds:        6
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0 -1.3973939e+00 1.96e+02 1.26e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [113]:
grasp_pick = sol_dict["grasp_1"]
grasp_place = sol_dict["grasp_2"]
pose_ho = SE3(sol_dict["pose_ho"])
grasp_pose_pick = (obj_start_pose @ grasp_reconst(grasp_pick))
grasp_pose_place = (obj_goal_pose @ grasp_reconst(grasp_place))
grasp_pose_ho1 = (pose_ho @ grasp_reconst(grasp_pick))
grasp_pose_ho2 = (pose_ho @ grasp_reconst(grasp_place))

q_pick = ik(gen3_neutral, grasp_pose_pick, gen3_ik_config)
q_place = ik(panda_neutral, grasp_pose_place, panda_ik_config)
q_ho1 = ik(gen3_neutral, grasp_pose_ho1, gen3_ik_config)
q_ho2 = ik(panda_neutral, grasp_pose_ho2, panda_ik_config)

sol_dict["q_pick"] = q_pick
sol_dict["q_place"] = q_place
sol_dict["q_ho1"] = q_ho1
sol_dict["q_ho2"] = q_ho2
gen3.set_joint_angles(q_ho1)
panda.set_joint_angles(q_ho2)

In [116]:
x0 = bdr2.get_init_value(sol_dict)
ipopt_full.add_option("print_level", 5)
sol, info = ipopt_full.solve(x0)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:      120
Number of nonzeros in inequality constraint Jacobian.:       26
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       48
                     variables with only lower bounds:        0
                variables with lower and upper bounds:       20
                     variables with only upper bounds:        0
Total number of equality constraints.................:       12
Total number of inequality constraints...............:        6
        inequality constraints with only lower bounds:        6
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  0.0000000e+00 8.16e-02 1.01e+00   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [47]:
sol_dict = bdr2.split_solution(sol)
grasp_logit_fn(sol_dict["grasp_2"])

Array(-21.95515, dtype=float32)