Skip to content

Commit

Permalink
Add a unit test to check the PGO loss values automatically (#456)
Browse files Browse the repository at this point in the history
* Made profiler optional in pose_graph_synthetic script.

* Changed pgo synthetic script to return losses and made savemat optional.

* Added a test to check the PGO loss values.

* Fix incorrect torch device in PGO synthetic script.

* Simplified linear solver cls config.

* Added other linear solvers to the PGO test.

* Clean up time/mem measurement code to avoid errors when CUDA is not available.

* Add test for baspacho.
  • Loading branch information
luisenp committed Feb 24, 2023
1 parent 6cb8272 commit f03c5aa
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 109 deletions.
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ run_other_unit_tests: &run_other_unit_tests
working_directory: ~/project
command: |
pip install -e ".[dev]"
python -m pytest tests -m "not cudaext" --ignore=tests/geometry/ --ignore=tests/optimizer --ignore-glob=tests/test_theseus_layer.py --ignore=tests/labs/lie
python -m pytest tests -m "not cudaext" -s --ignore=tests/geometry/ --ignore=tests/optimizer --ignore-glob=tests/test_theseus_layer.py --ignore=tests/labs/lie
run_gpu_tests: &run_gpu_tests
- run:
Expand All @@ -222,6 +222,7 @@ run_gpu_tests: &run_gpu_tests
command: |
pytest -s tests/test_theseus_layer.py
pytest -s tests -m "cudaext"
pytest -s tests/test_pgo_benchmark.py -s
build_cuda11_wheel: &build_cuda11_wheel
- run:
Expand Down
4 changes: 4 additions & 0 deletions examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
4 changes: 4 additions & 0 deletions examples/configs/pose_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
8 changes: 4 additions & 4 deletions examples/configs/pose_graph/pose_graph_synthetic.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
seed: 1
solver_device: "cuda"
solver_type: "lucuda"
device: "cuda"
profile: True
savemat: True

num_poses: 256
rotation_noise: 0.05
Expand All @@ -10,16 +11,15 @@ loop_closure_outlier_ratio: 0.25
dataset_size: 256
batch_size: 128


inner_optim:
optimizer_cls: LevenbergMarquardt
linear_solver_cls: LUCudaSparseSolver
optimizer_kwargs:
backward_mode: implicit
verbose: true
track_err_history: true
__keep_final_step_size__: true
adaptive_damping: true
solver: sparse
max_iters: 10
step_size: 0.75
regularize: true
Expand Down
211 changes: 110 additions & 101 deletions examples/pose_graph/pose_graph_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import cProfile
import io
import logging
import os
import pathlib
import pstats
import random
import subprocess
Expand All @@ -21,13 +19,35 @@

import theseus as th
import theseus.utils.examples as theg
from theseus.optimizer.linear import LinearSolver
from theseus.optimizer.linearization import Linearization
from theseus.utils import Timer

# Logger
log = logging.getLogger(__name__)


# Simple wrapper to make cProfile profiler optional
class Profiler:
    """Thin wrapper that makes a ``cProfile.Profile`` optional.

    When ``active`` is false every method is a no-op, so call sites can
    unconditionally call ``enable()`` / ``disable()`` / ``print()`` without
    checking whether profiling was requested.

    Args:
        c_profiler: the underlying profiler that collects the data.
        active: whether calls should be forwarded to ``c_profiler``.
    """

    def __init__(self, c_profiler: cProfile.Profile, active: bool):
        self.c_profiler = c_profiler
        self.active = active

    def enable(self) -> None:
        """Start collecting profiling data (no-op when inactive)."""
        if self.active:
            self.c_profiler.enable()

    def disable(self) -> None:
        """Stop collecting profiling data (no-op when inactive)."""
        if self.active:
            self.c_profiler.disable()

    def print(self) -> None:
        """Print the collected stats sorted by cumulative time.

        Does nothing when the wrapper is inactive, or when it is active but
        no profiling data was ever recorded.
        """
        if not self.active:
            return
        s = io.StringIO()
        try:
            ps = pstats.Stats(self.c_profiler, stream=s)
        except TypeError:
            # pstats.Stats raises TypeError when the profiler holds no stats
            # (e.g. enable() was never reached); there is nothing to report.
            return
        ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats()
        print(s.getvalue())


def print_histogram(
pg: theg.PoseGraphDataset, var_dict: Dict[str, torch.Tensor], msg: str
):
Expand Down Expand Up @@ -75,32 +95,45 @@ def pose_loss(
return loss


def run(
cfg: omegaconf.OmegaConf, pg: theg.PoseGraphDataset, results_path: pathlib.Path
):
device = torch.device("cuda")
dtype = torch.float64
pr = cProfile.Profile()
def _maybe_reset_cuda_peak_mem():
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()

LINEARIZATION_MODE: Dict[str, Type[Linearization]] = {
"sparse": th.SparseLinearization,
"dense": th.DenseLinearization,
}

LINEAR_SOLVER_MODE: Dict[str, Type[LinearSolver]] = {
"sparse": cast(
Type[LinearSolver],
(
th.BaspachoSparseSolver
if cast(str, cfg.solver_type) == "baspacho"
else th.LUCudaSparseSolver
)
if cast(str, cfg.solver_device) == "cuda"
else th.CholmodSparseSolver,
),
"dense": th.CholeskyDenseSolver,
}
def _maybe_get_cuda_max_mem_alloc():
return (
torch.cuda.max_memory_allocated() / 1048576
if torch.cuda.is_available()
else torch.nan
)


def run(cfg: omegaconf.OmegaConf):
log.info((subprocess.check_output("lscpu", shell=True).strip()).decode())

torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
random.seed(cfg.seed)

# create (or load) dataset
rng = torch.Generator()
rng.manual_seed(0)
dtype = torch.float64
pg, _ = theg.PoseGraphDataset.generate_synthetic_3D(
num_poses=cfg.num_poses,
rotation_noise=cfg.rotation_noise,
translation_noise=cfg.translation_noise,
loop_closure_ratio=cfg.loop_closure_ratio,
loop_closure_outlier_ratio=cfg.loop_closure_outlier_ratio,
batch_size=cfg.batch_size,
dataset_size=cfg.dataset_size,
generator=rng,
dtype=dtype,
)

device = torch.device(cfg.device)
dtype = torch.float64
profiler = Profiler(cProfile.Profile(), cfg.profile)
pg.to(device=device)

with torch.no_grad():
Expand All @@ -118,11 +151,6 @@ def run(
pose_indices: List[int] = [index for index, _ in enumerate(pg_batch.poses)]
gt_pose_indices: List[int] = []

forward_times = []
backward_times = []
forward_mems = []
backward_mems = []

for edge in pg_batch.edges:
relative_pose_cost = th.Between(
pg_batch.poses[edge.i],
Expand Down Expand Up @@ -176,8 +204,7 @@ def run(
objective,
max_iterations=cfg.inner_optim.max_iters,
step_size=cfg.inner_optim.step_size,
linearization_cls=LINEARIZATION_MODE[cast(str, cfg.inner_optim.solver)],
linear_solver_cls=LINEAR_SOLVER_MODE[cast(str, cfg.inner_optim.solver)],
linear_solver_cls=getattr(th, cfg.inner_optim.linear_solver_cls),
)

# Set up Theseus layer
Expand All @@ -198,8 +225,6 @@ def run(

def run_batch(batch_idx: int):
log.info(f" ------------------- Batch {batch_idx} ------------------- ")
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

pg_batch = pg.get_batch_dataset(batch_idx=batch_idx)
theseus_inputs = get_batch_data(pg_batch, pose_indices, gt_pose_indices)
Expand All @@ -208,37 +233,34 @@ def run_batch(batch_idx: int):
with torch.no_grad():
pose_loss_ref = pose_loss(pg_batch.poses, pg_batch.gt_poses)

start_event.record()
torch.cuda.reset_peak_memory_stats()
pr.enable()
theseus_outputs, _ = theseus_optim.forward(
input_tensors=theseus_inputs,
optimizer_kwargs={**cfg.inner_optim.optimizer_kwargs},
)
pr.disable()
end_event.record()

torch.cuda.synchronize()
forward_time = start_event.elapsed_time(end_event)
forward_mem = torch.cuda.max_memory_allocated() / 1048576
timer = Timer(device)
with timer:
_maybe_reset_cuda_peak_mem()
profiler.enable()
theseus_outputs, _ = theseus_optim.forward(
input_tensors=theseus_inputs,
optimizer_kwargs={**cfg.inner_optim.optimizer_kwargs},
)
profiler.disable()
forward_time = 1000 * timer.elapsed_time
forward_mem = _maybe_get_cuda_max_mem_alloc()
log.info(f"Forward pass took {forward_time} ms.")
log.info(f"Forward pass used {forward_mem} MBs.")

start_event.record()
torch.cuda.reset_peak_memory_stats()
pr.enable()
model_optimizer.zero_grad()
loss = (pose_loss(pose_vars, pg_batch.gt_poses) - pose_loss_ref) / pose_loss_ref
loss.backward()
model_optimizer.step()
backward_mem = torch.cuda.max_memory_allocated() / 1048576
pr.disable()
end_event.record()

torch.cuda.synchronize()
backward_time = start_event.elapsed_time(end_event)
log.info(f"Forward pass used {forward_mem} GPU MBs.")

with timer:
_maybe_reset_cuda_peak_mem()
profiler.enable()
model_optimizer.zero_grad()
loss = (
pose_loss(pose_vars, pg_batch.gt_poses) - pose_loss_ref
) / pose_loss_ref
loss.backward()
model_optimizer.step()
backward_mem = _maybe_get_cuda_max_mem_alloc()
profiler.disable()
backward_time = 1000 * timer.elapsed_time
log.info(f"Backward pass took {backward_time} ms.")
log.info(f"Backward pass used {backward_mem} MBs.")
log.info(f"Backward pass used {backward_mem} GPU MBs.")

loss_value = torch.sum(loss.detach()).item()
log.info(
Expand All @@ -249,77 +271,64 @@ def run_batch(batch_idx: int):

print_histogram(pg_batch, theseus_outputs, "Output histogram:")

return [forward_time, backward_time, forward_mem, backward_mem]
return [forward_time, backward_time, forward_mem, backward_mem, loss.item()]

forward_times = []
backward_times = []
forward_mems = []
backward_mems = []
losses = []
for epoch in range(num_epochs):
log.info(f" ******************* EPOCH {epoch} ******************* ")

forward_time_epoch = []
backward_time_epoch = []
forward_mem_epoch = []
backward_mem_epoch = []

losses_epoch = []
for batch_idx in range(pg.num_batches):
if batch_idx == cfg.outer_optim.max_num_batches:
break
forward_time, backward_time, forward_mem, backward_mem = run_batch(
forward_time, backward_time, forward_mem, backward_mem, loss = run_batch(
batch_idx
)

forward_time_epoch.append(forward_time)
backward_time_epoch.append(backward_time)
forward_mem_epoch.append(forward_mem)
backward_mem_epoch.append(backward_mem)
losses_epoch.append(loss)

forward_times.append(forward_time_epoch)
backward_times.append(backward_time_epoch)
forward_mems.append(forward_mem_epoch)
backward_mems.append(backward_mem_epoch)
losses.append(losses_epoch)

results = omegaconf.OmegaConf.to_container(cfg)
results["forward_time"] = forward_times
results["backward_time"] = backward_times
results["forward_mem"] = forward_mems
results["backward_mem"] = backward_mems
file = (
f"pgo_{cfg.solver_device}_{cfg.inner_optim.solver}_{cfg.num_poses}_"
fname = (
f"pgo_{cfg.device}_{cfg.inner_optim.linear_solver_cls.lower()}_{cfg.num_poses}_"
f"{cfg.dataset_size}_{cfg.batch_size}.mat"
)
savemat(file, results)
print(fname)
if cfg.savemat:
savemat(fname, results)

s = io.StringIO()
sortby = pstats.SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())
profiler.print()
return losses


@hydra.main(config_path="../configs/pose_graph", config_name="pose_graph_synthetic")
@hydra.main(
config_path="../configs/pose_graph",
config_name="pose_graph_synthetic",
version_base="1.1",
)
def main(cfg):
log.info((subprocess.check_output("lscpu", shell=True).strip()).decode())

torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
random.seed(cfg.seed)

# create (or load) dataset
rng = torch.Generator()
rng.manual_seed(0)
dtype = torch.float64
pg, _ = theg.PoseGraphDataset.generate_synthetic_3D(
num_poses=cfg.num_poses,
rotation_noise=cfg.rotation_noise,
translation_noise=cfg.translation_noise,
loop_closure_ratio=cfg.loop_closure_ratio,
loop_closure_outlier_ratio=cfg.loop_closure_outlier_ratio,
batch_size=cfg.batch_size,
dataset_size=cfg.dataset_size,
generator=rng,
dtype=dtype,
)

results_path = pathlib.Path(os.getcwd())
run(cfg, pg, results_path)
run(cfg)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit f03c5aa

Please sign in to comment.