
Commit cea6676
Added logic to automatically pick number of solver contexts for LUCuda.
luisenp committed Dec 1, 2022
1 parent 77fab0a commit cea6676
Showing 5 changed files with 107 additions and 46 deletions.
9 changes: 5 additions & 4 deletions tests/optimizer/linear/test_lu_cuda_sparse_solver.py
@@ -86,9 +86,7 @@ def check_sparse_solver_multistep(batch_size: int, test_exception: bool):
void_objective = th.Objective()
void_ordering = th.VariableOrdering(void_objective, default_order=False)
solver = th.LUCudaSparseSolver(
void_objective,
linearization_kwargs={"ordering": void_ordering},
num_solver_contexts=(num_steps - 1) if test_exception else num_steps,
void_objective, linearization_kwargs={"ordering": void_ordering}
)
linearization = solver.linearization

@@ -100,7 +98,10 @@ def check_sparse_solver_multistep(batch_size: int, test_exception: bool):
linearization.A_row_ptr = row_ptr

# Only need this line for the test since the objective is a mock
solver.reset(batch_size=batch_size)
solver.reset(
batch_size=batch_size,
num_solver_contexts=(num_steps - 1) if test_exception else num_steps,
)

As = [
torch.randn(
9 changes: 5 additions & 4 deletions theseus/optimizer/linear/baspacho_sparse_solver.py
@@ -27,7 +27,6 @@ def __init__(
objective: Objective,
linearization_cls: Optional[Type[Linearization]] = None,
linearization_kwargs: Optional[Dict[str, Any]] = None,
num_solver_contexts: int = 1,
dev: DeviceType = DEFAULT_DEVICE,
**kwargs,
):
@@ -41,12 +40,14 @@ def __init__(
super().__init__(objective, linearization_cls, linearization_kwargs, **kwargs)
self.linearization: SparseLinearization = self.linearization

self._num_solver_contexts: int = num_solver_contexts

self._has_been_reset = False
if self.linearization.structure().num_rows:
self.reset(dev)

def reset(self, dev: DeviceType = DEFAULT_DEVICE):
def reset(self, dev: DeviceType = DEFAULT_DEVICE, **kwargs):
if self._has_been_reset:
return
self._has_been_reset = True
if dev == "cuda" and not torch.cuda.is_available():
raise RuntimeError(
"BaspachoSparseSolver: Cuda requested (dev='cuda') but not\n"
4 changes: 4 additions & 0 deletions theseus/optimizer/linear/linear_solver.py
@@ -26,6 +26,10 @@ def __init__(
objective, **linearization_kwargs
)

# Deliberately not abstract since some solvers don't need this
def reset(self, **kwargs):
pass

@abc.abstractmethod
def solve(
self, damping: Optional[Union[float, torch.Tensor]] = None, **kwargs
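The base-class change above gives every LinearSolver a no-op reset() that subclasses may override when they cache per-problem state. A minimal standalone sketch of that contract follows; the class names are illustrative stand-ins, not theseus classes:

import abc


class BaseSolver(abc.ABC):
    # Deliberately not abstract: solvers without per-problem state can ignore it.
    def reset(self, **kwargs):
        pass

    @abc.abstractmethod
    def solve(self):
        ...


class StatefulSolver(BaseSolver):
    # A solver that caches per-problem state overrides reset() to rebuild it.
    def reset(self, batch_size: int = 16, **kwargs):
        self._cache = {"batch_size": batch_size}

    def solve(self):
        return self._cache


solver = StatefulSolver()
solver.reset(batch_size=4, num_solver_contexts=2)  # extra kwargs a solver does not use are ignored
print(solver.solve())                               # {'batch_size': 4}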
65 changes: 52 additions & 13 deletions theseus/optimizer/linear/lu_cuda_sparse_solver.py
@@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import dataclasses
from typing import Any, Dict, List, Optional, Type, Union

import torch
@@ -16,12 +18,36 @@


class LUCudaSparseSolver(LinearSolver):
# Class for keeping track of the inputs used when `reset()` is called
# Mostly useful for the `fill_defaults()` method
@dataclasses.dataclass
class ResetCtx:
_DEFAULTS = {"batch_size": 16, "num_solver_contexts": 1}

batch_size: Optional[int]
num_solver_contexts: Optional[int]

def fill_defaults(
self, another_ctx: "LUCudaSparseSolver.ResetCtx"
) -> "LUCudaSparseSolver.ResetCtx":
my_fields = dataclasses.asdict(self)
other_fields = dataclasses.asdict(another_ctx)
new_ctx_fields = copy.copy(LUCudaSparseSolver.ResetCtx._DEFAULTS)
for k, v in my_fields.items():
if v is None:
other_v = other_fields[k]
if other_v is not None:
new_ctx_fields[k] = other_v
else:
new_ctx_fields[k] = v
return LUCudaSparseSolver.ResetCtx(**new_ctx_fields)

def __init__(
self,
objective: Objective,
linearization_cls: Optional[Type[Linearization]] = None,
linearization_kwargs: Optional[Dict[str, Any]] = None,
num_solver_contexts: int = 1,
num_solver_contexts: Optional[int] = None,
batch_size: Optional[int] = None,
auto_reset: bool = True,
**kwargs,
@@ -39,21 +65,33 @@ def __init__(
super().__init__(objective, linearization_cls, linearization_kwargs, **kwargs)
self.linearization: SparseLinearization = self.linearization

self._num_solver_contexts: int = num_solver_contexts

self._last_reset_ctx = LUCudaSparseSolver.ResetCtx(None, None)
if self.linearization.structure().num_rows:
if batch_size is not None:
self.reset(batch_size=batch_size)
else:
self.reset()
self.reset(batch_size=batch_size, num_solver_contexts=num_solver_contexts)

self._objective = objective
self._auto_reset = auto_reset

def reset(self, batch_size: int = 16):
def reset(
self,
batch_size: Optional[int] = None,
num_solver_contexts: Optional[int] = None,
**kwargs,
):
# For any inputs that are None, this tries to set their values to
# those used in the last call to `reset()`. If that value is also None
# (i.e., reset() has never been called before), then it fills them
# with the base default values
ctx = LUCudaSparseSolver.ResetCtx(
batch_size, num_solver_contexts
).fill_defaults(self._last_reset_ctx)
# As a consequence of the above, reset() only runs if it has never
# been run before, or if at least one of the parameters is explicitly
# requested to be different from those used in the last reset
if ctx == self._last_reset_ctx:
return
if not torch.cuda.is_available():
raise RuntimeError("Cuda not available, LUCudaSparseSolver cannot be used")

try:
from theseus.extlib.cusolver_lu_solver import CusolverLUSolver
except Exception as e:
@@ -80,14 +118,15 @@ def reset(self, batch_size: int = 16):
AtA_col_ind = torch.tensor(AtA_mock.indices, dtype=torch.int32).cuda()
self._solver_contexts: List[CusolverLUSolver] = [
CusolverLUSolver(
batch_size,
ctx.batch_size,
AtA_mock.shape[1],
AtA_row_ptr,
AtA_col_ind,
)
for _ in range(self._num_solver_contexts)
for _ in range(ctx.num_solver_contexts)
]
self._last_solver_context: int = self._num_solver_contexts - 1
self._last_solver_context: int = ctx.num_solver_contexts - 1
self._last_reset_ctx = ctx

def solve(
self,
@@ -106,7 +145,7 @@

self._last_solver_context = (
self._last_solver_context + 1
) % self._num_solver_contexts
) % self._last_reset_ctx.num_solver_contexts

if damping is None:
damping_alpha_beta = None
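The ResetCtx dataclass above resolves each reset() argument in three steps: use the explicitly passed value if given, otherwise fall back to the value from the previous reset(), otherwise use the base default (batch_size=16, num_solver_contexts=1). A standalone sketch of that resolution logic, mirroring the diff without importing theseus:

import copy
import dataclasses
from typing import Optional

_DEFAULTS = {"batch_size": 16, "num_solver_contexts": 1}


@dataclasses.dataclass
class ResetCtx:
    batch_size: Optional[int]
    num_solver_contexts: Optional[int]

    def fill_defaults(self, another_ctx: "ResetCtx") -> "ResetCtx":
        # Resolution order per field: self (explicit) -> another_ctx (last reset) -> default
        new_fields = copy.copy(_DEFAULTS)
        other_fields = dataclasses.asdict(another_ctx)
        for k, v in dataclasses.asdict(self).items():
            if v is None:
                if other_fields[k] is not None:
                    new_fields[k] = other_fields[k]
            else:
                new_fields[k] = v
        return ResetCtx(**new_fields)


last = ResetCtx(None, None)                      # solver constructed, reset() never called
first = ResetCtx(8, None).fill_defaults(last)
print(first)                                     # ResetCtx(batch_size=8, num_solver_contexts=1)
second = ResetCtx(None, 3).fill_defaults(first)  # only change the number of contexts
print(second)                                    # ResetCtx(batch_size=8, num_solver_contexts=3)
print(ResetCtx(None, 3).fill_defaults(second) == second)  # True -> reset() would early-return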
66 changes: 41 additions & 25 deletions theseus/optimizer/nonlinear/nonlinear_optimizer.py
@@ -15,7 +15,7 @@

from theseus.core import Objective
from theseus.optimizer import Linearization, Optimizer, OptimizerInfo
from theseus.optimizer.linear import LinearSolver
from theseus.optimizer.linear import LinearSolver, LUCudaSparseSolver


@dataclass
@@ -416,6 +416,29 @@ def _optimize_loop(
] = NonlinearOptimizerStatus.MAX_ITERATIONS
return iters_done

# Returns how many iterations to do with and without autograd
def _split_backward_iters(self, **kwargs) -> Tuple[int, int]:
if kwargs["backward_mode"] == BackwardMode.TRUNCATED:
if "backward_num_iterations" not in kwargs:
raise ValueError("backward_num_iterations expected but not received.")
if kwargs["backward_num_iterations"] > self.params.max_iterations:
warnings.warn(
f"Input backward_num_iterations "
f"(={kwargs['backward_num_iterations']}) > "
f"max_iterations (={self.params.max_iterations}). "
f"Using backward_num_iterations=max_iterations."
)
backward_num_iters = min(
kwargs["backward_num_iterations"], self.params.max_iterations
)
else:
backward_num_iters = {
BackwardMode.UNROLL: self.params.max_iterations,
BackwardMode.DLM: self.params.max_iterations,
BackwardMode.IMPLICIT: 1,
}[kwargs["backward_mode"]]
return backward_num_iters, self.params.max_iterations - backward_num_iters

# `track_best_solution` keeps a **detached** copy (as in no gradient info)
# of the best variables found, but it is optional to avoid unnecessary copying
# if this is not needed
@@ -432,8 +455,9 @@ def _optimize_impl(
end_iter_callback: Optional[EndIterCallbackType] = None,
**kwargs,
) -> OptimizerInfo:
self.reset(**kwargs)
backward_mode = BackwardMode.resolve(backward_mode)
kwargs_plus_bwd_mode = {**kwargs, **{"backward_mode": backward_mode}}
self.reset(**kwargs_plus_bwd_mode)
with torch.no_grad():
info = self._init_info(
track_best_solution, track_err_history, track_state_history
@@ -445,10 +469,13 @@
f"Error: {info.last_err.mean().item()}"
)

backward_num_iters, no_grad_num_iters = self._split_backward_iters(
**kwargs_plus_bwd_mode
)
if backward_mode in [BackwardMode.UNROLL, BackwardMode.DLM]:
self._optimize_loop(
start_iter=0,
num_iter=self.params.max_iterations,
num_iter=backward_num_iters,
info=info,
verbose=verbose,
truncated_grad_loop=False,
@@ -461,29 +488,10 @@
] = -1
return info
elif backward_mode in [BackwardMode.IMPLICIT, BackwardMode.TRUNCATED]:
if backward_mode == BackwardMode.IMPLICIT:
backward_num_iterations = 1
else:
if "backward_num_iterations" not in kwargs:
raise ValueError(
"backward_num_iterations expected but not received"
)
if kwargs["backward_num_iterations"] > self.params.max_iterations:
warnings.warn(
f"Input backward_num_iterations "
f"(={kwargs['backward_num_iterations']}) > "
f"max_iterations (={self.params.max_iterations}). "
f"Using backward_num_iterations=max_iterations."
)
backward_num_iterations = min(
kwargs["backward_num_iterations"], self.params.max_iterations
)

num_no_grad_iter = self.params.max_iterations - backward_num_iterations
with torch.no_grad():
# actual_num_iters could be < num_iter due to early convergence
no_grad_iters_done = self._optimize_loop(
num_iter=num_no_grad_iter,
num_iter=no_grad_num_iters,
info=info,
verbose=verbose,
truncated_grad_loop=False,
@@ -495,7 +503,7 @@
track_best_solution, track_err_history, track_state_history
)
grad_iters_done = self._optimize_loop(
num_iter=backward_num_iterations,
num_iter=backward_num_iters,
info=grad_loop_info,
verbose=verbose,
truncated_grad_loop=True,
@@ -575,7 +583,15 @@ def _step(
# problem. Optimizer loop will pass all optimizer kwargs to this method.
# Deliberately not abstract, since some optimizers might not need this
def reset(self, **kwargs) -> None:
pass
backward_num_iters, _ = self._split_backward_iters(**kwargs)
if (
isinstance(self.linear_solver, LUCudaSparseSolver)
and "num_solver_contexts" not in kwargs
):
# Automatically set the number of solver contexts for the given max iterations
kwargs["num_solver_contexts"] = backward_num_iters
self.linear_solver.reset(**kwargs)

# Called at the end of step() but before variables are updated to their new values.
# This method can be used to update any internal state of the optimizer and
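The two additions above work together: _split_backward_iters() decides how many iterations run with and without gradient tracking for each backward mode, and reset() forwards the gradient-tracked count as num_solver_contexts whenever the linear solver is an LUCudaSparseSolver and the caller did not set that value explicitly. A standalone sketch of the arithmetic under assumed settings (max_iterations=10); the enum and helper below are simplified stand-ins, not the theseus classes:

import enum
import warnings
from typing import Tuple


class BackwardMode(enum.Enum):
    UNROLL = 0
    IMPLICIT = 1
    TRUNCATED = 2
    DLM = 3


def split_backward_iters(max_iterations: int, backward_mode: BackwardMode,
                         **kwargs) -> Tuple[int, int]:
    # Returns (iterations with autograd, iterations without autograd)
    if backward_mode == BackwardMode.TRUNCATED:
        if "backward_num_iterations" not in kwargs:
            raise ValueError("backward_num_iterations expected but not received.")
        if kwargs["backward_num_iterations"] > max_iterations:
            warnings.warn("backward_num_iterations > max_iterations; capping it.")
        backward_num_iters = min(kwargs["backward_num_iterations"], max_iterations)
    else:
        backward_num_iters = {
            BackwardMode.UNROLL: max_iterations,
            BackwardMode.DLM: max_iterations,
            BackwardMode.IMPLICIT: 1,
        }[backward_mode]
    return backward_num_iters, max_iterations - backward_num_iters


# With max_iterations=10:
print(split_backward_iters(10, BackwardMode.UNROLL))                                # (10, 0)
print(split_backward_iters(10, BackwardMode.IMPLICIT))                              # (1, 9)
print(split_backward_iters(10, BackwardMode.TRUNCATED, backward_num_iterations=3))  # (3, 7)

# reset() would then pass num_solver_contexts equal to the first value (10, 1, or 3
# here) to the LUCuda solver, unless the caller supplied num_solver_contexts itself.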
