Baspacho device doesn't accept "cuda:0" #597

Open
luisenp opened this issue Sep 10, 2023 · 0 comments
luisenp commented Sep 10, 2023

With smaller w and h and `autograd_mode="dense"`, I found what looks like a bug in Theseus:
import torch
import theseus as th
import torchlie.functional as lieF


def _photometric_error(T_wc1: th.SE3, T_wc2: th.SE3,
                       d1: th.Vector, d2: th.Vector,
                       image1: th.Variable, image2: th.Variable,
                       p_unit_sphere_or_plane: th.Variable,
                       k: th.Variable) -> torch.Tensor:
    """note: undistort and crop image ahead, since mask not supported."""
    batch_size = image1.shape[0]
    h = image1.shape[2]
    w = image1.shape[3]

    # bidirectional projection: 2 -> 1
    T_c1c2 = T_wc1.inverse().compose(T_wc2)
    p_c1 = lieF.SE3.transform(T_c1c2.tensor, p_unit_sphere_or_plane.tensor * d2.tensor.view(batch_size, h, w, 1))
    p1 = p_c1[:, :, :, :2] / p_c1[:, :, :, 2, None]  # (batch_size, h, w, 2)
    p1 = p1 * k.tensor[:, :2].unsqueeze(1).unsqueeze(2) + k.tensor[:, 2:4].unsqueeze(1).unsqueeze(2)
    image2_in_1 = torch.nn.functional.grid_sample(image1.tensor, p1, padding_mode="border")
    err12 = image2_in_1 - image2.tensor

    # bidirectional projection: 1 -> 2
    T_c2c1 = T_c1c2.inverse()
    p_c2 = lieF.SE3.transform(T_c2c1.tensor, p_unit_sphere_or_plane.tensor * d1.tensor.view(batch_size, h, w, 1))
    p2 = p_c2[:, :, :, :2] / p_c2[:, :, :, 2, None]  # (batch_size, h, w, 2)
    p2 = p2 * k.tensor[:, :2].unsqueeze(1).unsqueeze(2) + k.tensor[:, 2:4].unsqueeze(1).unsqueeze(2)
    image1_in_2 = torch.nn.functional.grid_sample(image2.tensor, p2, padding_mode="border")
    err21 = image1_in_2 - image1.tensor

    # the original error dim is too big, make it impossible for Jacobian (batch_size, err_dim, var_dof)
    err = torch.cat((err21, err12), dim=1)
    err = torch.nn.functional.huber_loss(err, torch.zeros(1, dtype=err.dtype, device=err.device),
                                         reduction="none", delta=0.5)
    return torch.sum(err, dim=(1, 2, 3)).unsqueeze(1)  # (batch_size, 1)


def photometric_error_fix_pose(optim_vars, aux_vars) -> torch.Tensor:
    d1: th.Vector = optim_vars[0]  # unknown scale, > 0, better rescale ahead, (batch_size, h x w x 1)
    d2: th.Vector = optim_vars[1]  # unknown scale, > 0, better rescale ahead, (batch_size, h x w x 1)

    image1: th.Variable = aux_vars[0]  # (batch_size, 1, h, w) {normalized_intensity}
    image2: th.Variable = aux_vars[1]  # (batch_size, 1, h, w) {normalized_intensity}
    p_unit_sphere_or_plane: th.Variable = aux_vars[2]  # (batch_size=1, h, w, 3)
    k: th.Variable = aux_vars[3]  # (batch_size=1, 4) {fx, fy, cx, cy}
    T_wc1: th.SE3 = aux_vars[4]
    T_wc2: th.SE3 = aux_vars[5]

    return _photometric_error(T_wc1, T_wc2, d1, d2, image1, image2, p_unit_sphere_or_plane, k)


def main():
    device = "cuda:0"
    dtype = torch.float32
    h = 10
    w = 10
    image1 = th.Variable(torch.rand(1, 1, h, w, dtype=dtype, device=device))
    image2 = th.Variable(torch.rand(1, 1, h, w, dtype=dtype, device=device))
    depth1 = th.Vector.rand(1, h * w, dtype=dtype, device=device, requires_grad=True)
    depth2 = th.Vector.rand(1, h * w, dtype=dtype, device=device, requires_grad=True)
    T_wc1 = th.SE3.rand(1, dtype=dtype, device=device, requires_grad=True)
    T_wc2 = th.SE3.rand(1, dtype=dtype, device=device, requires_grad=True)

    k = th.Variable(torch.tensor([481.20, -480.00, 319.50, 239.50], dtype=dtype, device=device).view(1, 4))
    with torch.no_grad():
        x_coords = torch.linspace(0, w - 1, w, device=k.device)
        y_coords = torch.linspace(0, h - 1, h, device=k.device)
        fx, fy, cx, cy = k[0, 0], k[0, 1], k[0, 2], k[0, 3]
        x_normalized = (x_coords - cx) / fx  # (w)
        y_normalized = (y_coords - cy) / fy  # (h)
        p_unit_sphere_or_plane = torch.empty((1, h, w, 3), dtype=k.dtype, device=k.device)
        p_unit_sphere_or_plane[..., 0] = x_normalized.view(1, 1, w)
        p_unit_sphere_or_plane[..., 1] = y_normalized.view(1, h, 1)
        p_unit_sphere_or_plane[..., 2] = 1
    p_unit_sphere_or_plane = th.Variable(p_unit_sphere_or_plane)  # (batch_size=1, h, w, 3)

    optim_vars = depth1, depth2
    aux_vars = image1, image2, p_unit_sphere_or_plane, k, T_wc1, T_wc2
    weight = th.ScaleCostWeight(torch.tensor(1., dtype=dtype, device=device))
    cost_function = th.AutoDiffCostFunction(optim_vars, photometric_error_fix_pose, 1, weight,
                                            aux_vars=aux_vars, autograd_mode="dense")
    objective = th.Objective(dtype=k.dtype)
    objective.device = device
    objective.add(cost_function)
    optimizer = th.LevenbergMarquardt(
        objective,
        linear_solver_cls=th.BaspachoSparseSolver,
        linearization_cls=th.SparseLinearization,
    )
    theseus_layer = th.TheseusLayer(optimizer)
    theseus_outputs, info = theseus_layer.forward(optimizer_kwargs={
        "verbose": True,
        "adaptive_damping": True})


if __name__ == "__main__":
    main()

Traceback (most recent call last):
  File "/home/huangkun/Git/lora_slam/test/reproduce.py", line 98, in <module>
    main()
  File "/home/huangkun/Git/lora_slam/test/reproduce.py", line 86, in main
    optimizer = th.LevenbergMarquardt(
  File "/home/huangkun/Git/theseus/theseus/optimizer/nonlinear/levenberg_marquardt.py", line 69, in __init__
    super().__init__(
  File "/home/huangkun/Git/theseus/theseus/optimizer/nonlinear/nonlinear_least_squares.py", line 90, in __init__
    self.linear_solver = linear_solver_cls(
  File "/home/huangkun/Git/theseus/theseus/optimizer/linear/baspacho_sparse_solver.py", line 45, in __init__
    self.reset()
  File "/home/huangkun/Git/theseus/theseus/optimizer/linear/baspacho_sparse_solver.py", line 111, in reset
    self.symbolic_decomposition = SymbolicDecomposition(
RuntimeError: Expected device == "cpu" || device == "cuda" to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

baspacho_sparse_solver only accepts the device names "cpu" or "cuda", but mine is "cuda:0".

Originally posted by @EXing in #596 (comment)
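
A workaround that appears to sidestep the error (a sketch, under the assumption that the solver takes its device string from the objective and that BaSpaCho's SymbolicDecomposition only understands the bare device type): normalize the device to its type before handing it to the objective, while still allocating the tensors on "cuda:0".

import torch
import theseus as th

# Hypothetical workaround, not an official Theseus fix: strip the ":0"
# index so that only the device *type* ("cpu" or "cuda") reaches the
# solver. "cuda" resolves to the current CUDA device, so tensors built
# on "cuda:0" remain on the same GPU.
device = torch.device("cuda:0")      # where the tensors live
solver_device = device.type          # "cuda", index stripped

objective = th.Objective(dtype=torch.float32)
objective.device = solver_device     # same assignment as in the script above
# ... then add the cost function and build th.LevenbergMarquardt with
# th.BaspachoSparseSolver exactly as in the reproduction script.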
