In [94]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pydrake
from pydrake.all import (
    MathematicalProgram,
    RotationMatrix,
    RollPitchYaw,
    SnoptSolver,
    IpoptSolver,
    GurobiSolver,
    SolverOptions,
    RandomGenerator,
    UniformlyRandomRotationMatrix,
    MixedIntegerRotationConstraintGenerator,
    IntervalBinning,
)
import os
from spatial_scene_grammars.rules import add_bingham_cost, WorldFrameBinghamRotationRule

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [100]:
# Nonlinear optimization version
def try_solve(R_targ, R_init, verbose=False):
    prog = MathematicalProgram()
    R = prog.NewContinuousVariables(3, 3)
    prog.AddBoundingBoxConstraint(-np.ones(9), np.ones(9), R.flatten())
    RtR = R.T.dot(R)
    z_dir = np.cross(R[:, 0], R[:, 1])
    prog.AddConstraint(np.dot(z_dir, R[:, 2]) >= 0.)
    # Strong bounding box on rotation matrix elements.
    eye = np.eye(3)
    for i in range(3):
        for j in range(3):
            prog.SetInitialGuess(R[i, j], R_init[i, j])
            prog.AddConstraint(RtR[i, j] == eye[i, j])

    rule = WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances(
        RotationMatrix(R_targ), [10., 9., 8]
    )
    Z = np.diag(rule._bingham_dist.z.numpy())
    M = rule._bingham_dist.m.numpy()
    log_normalizer = rule._bingham_dist._norm_const.item()
    
    mode = M[:, -1].reshape(4, 1)
    modemodeT = mode.dot(mode.T)
    expected_objective = np.trace(Z.dot(M.T.dot(modemodeT.dot(M)))) - log_normalizer
    qqt = add_bingham_cost(prog, R, True, M, Z, log_normalizer)

    options = SolverOptions()
    logfile = "/tmp/snopt.log"
    os.system("rm %s" % logfile)
    #solver = SnoptSolver()
    #ptions.SetOption(solver.id(), "Print file", logfile)
    solver = IpoptSolver()

    result = solver.Solve(prog, None, options)
    #with open(logfile) as f:
    #    print(f.read())

    success = result.is_success()
    qqt_opt = result.GetSolution(qqt)
    R_opt = result.GetSolution(R)
    q_expected = RotationMatrix(R_opt).ToQuaternion()
    qqt_expected = q_expected.wxyz().reshape(4, 1).dot(q_expected.wxyz().reshape(1, 4))
    optimal_cost = result.get_optimal_cost()
    qqt_err = np.sum(np.abs(qqt_opt - qqt_expected))
    if verbose or success == False or qqt_err > 1E-6 or np.abs(optimal_cost + expected_objective) > 1E-3:
        print("Expected obj: ", expected_objective)
        print("Optimal cost: ", optimal_cost)
        print("qqt: ", qqt_opt)
        print("R: ", R_opt)
        print("Total R err: ", np.sum(np.abs(R_opt.T.dot(R_targ) - np.eye(3))))
        print("qqt expected: ", qqt_expected)
        print("Total qqt err: ", qqt_err)
        print("Total qqt from mode: ", np.sum(np.abs(qqt_opt - modemodeT)))
        print("Singular vals of qqt: ", np.linalg.svd(qqt_opt)[1])
for k in range(5):
    R_targ = UniformlyRandomRotationMatrix(RandomGenerator(k)).matrix()
    R_init = UniformlyRandomRotationMatrix(RandomGenerator(2*k)).matrix()
    print(try_solve(R_targ, R_init, verbose=False))

None
None
None
None
None


In [106]:
# MI optimization version
def try_solve(R_targ, verbose=False):
    prog = MathematicalProgram()
    R = prog.NewContinuousVariables(3, 3)
    
    mip_rot_gen = MixedIntegerRotationConstraintGenerator(
        approach = MixedIntegerRotationConstraintGenerator.Approach.kBilinearMcCormick,
        num_intervals_per_half_axis=3,
        interval_binning = IntervalBinning.kLinear
    )
    mip_rot_gen.AddToProgram(R, prog)

    rule = WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances(
        RotationMatrix(R_targ), [100., 90., 80]
    )
    Z = np.diag(rule._bingham_dist.z.numpy())
    M = rule._bingham_dist.m.numpy()
    log_normalizer = rule._bingham_dist._norm_const.item()
    
    mode = M[:, -1].reshape(4, 1)
    modemodeT = mode.dot(mode.T)
    expected_objective = np.trace(Z.dot(M.T.dot(modemodeT.dot(M)))) - log_normalizer

    active = prog.NewBinaryVariables(1)[0]
    prog.AddLinearEqualityConstraint(active == 1.)
    qqt = add_bingham_cost(prog, R, active, M, Z, log_normalizer)

    solver = GurobiSolver()
    options = SolverOptions()
    logfile = "/tmp/gurobi.log"
    os.system("rm -f %s" % logfile)
    options.SetOption(solver.id(), "LogFile", logfile)
    options.SetOption(solver.id(), "MIPGap", 1E-3)
    result = solver.Solve(prog, None, options)

    success = result.is_success()
    qqt_opt = result.GetSolution(qqt)
    R_opt = result.GetSolution(R)
    q_expected = RotationMatrix(R_opt).ToQuaternion()
    qqt_expected = q_expected.wxyz().reshape(4, 1).dot(q_expected.wxyz().reshape(1, 4))
    optimal_cost = result.get_optimal_cost()
    qqt_err = np.sum(np.abs(qqt_opt - qqt_expected))
    if verbose or success == False or qqt_err > 1E-6:
        print("Expected obj: ", expected_objective)
        print("Optimal cost: ", optimal_cost)
        print("qqt: ", qqt_opt)
        print("R: ", R_opt)
        print("Total R err: ", np.sum(np.abs(R_opt.T.dot(R_targ) - np.eye(3))))
        print("qqt expected: ", qqt_expected)
        print("Total qqt err: ", qqt_err)
        print("Total qqt from mode: ", np.sum(np.abs(qqt_opt - modemodeT)))
        print("Singular vals of qqt: ", np.linalg.svd(qqt_opt)[1])
    return np.sum(qqt_opt - qqt_expected) < 1E-4
for k in range(1):
    R_targ = UniformlyRandomRotationMatrix(RandomGenerator(k)).matrix()
    #R_targ = RotationMatrix(RollPitchYaw(np.pi/4., 0., 0.)).matrix()
    print("r TARG: ", R_targ)
    print(try_solve(R_targ, verbose=False))

r TARG:  [[-0.18568923  0.08435438  0.97898102]
 [-0.93864483  0.2794536  -0.20211768]
 [-0.29062929 -0.95644656  0.02728728]]




Expected obj:  -0.013063191649785257
Optimal cost:  -1.0915310594167558
qqt:  [[ 0.29166667 -0.20833333  0.29166667 -0.29166667]
 [-0.20833333  0.125      -0.20833333  0.20833333]
 [ 0.29166667 -0.20833333  0.29166667 -0.29166667]
 [-0.29166667  0.20833333 -0.29166667  0.29166667]]
R:  [[-0.16666667  0.16666667  1.        ]
 [-1.          0.16666667 -0.16666667]
 [-0.16666667 -1.          0.16666667]]
Total R err:  0.6255847073956433
qqt expected:  [[ 0.28488372 -0.20348837  0.28488372 -0.28488372]
 [-0.20348837  0.14534884 -0.20348837  0.20348837]
 [ 0.28488372 -0.20348837  0.28488372 -0.28488372]
 [-0.28488372  0.20348837 -0.28488372  0.28488372]]
Total qqt err:  0.11046511627907046
Total qqt from mode:  0.38920772944675797
Singular vals of qqt:  [1.02041650e+00 2.04164999e-02 6.73519211e-17 5.47382213e-48]
True
