Skip to content

Commit

Permalink
Added a test to check that computed vectorized errors are correct.
Browse files Browse the repository at this point in the history
  • Loading branch information
luisenp committed Jun 6, 2022
1 parent fc13ffb commit eadbf65
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
47 changes: 47 additions & 0 deletions theseus/core/tests/test_vectorizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch

import theseus as th
Expand Down Expand Up @@ -140,3 +141,49 @@ def test_correct_schemas_and_shared_vars():
assert len(cost_fns) == 1
seen_cnt[6] += 1
assert seen_cnt == [1] * 7


def test_vectorized_error():
rng = np.random.default_rng(0)
generator = torch.Generator()
generator.manual_seed(0)
for _ in range(20):
dim = rng.choice([1, 2])
objective = th.Objective()
batch_size = rng.choice(range(1, 11))

vectors = [
th.Vector(
data=torch.randn(batch_size, dim, generator=generator), name=f"v{i}"
)
for i in range(rng.choice([1, 10]))
]
target = th.Vector(dim, name="target")
w = th.ScaleCostWeight(torch.randn(1, generator=generator))
for v in vectors:
objective.add(th.Difference(v, w, target))

se3s = [
th.SE3(
data=th.SE3.rand(batch_size, generator=generator).data,
requires_check=False,
)
for i in range(rng.choice([1, 10]))
]
s_target = th.SE3.rand(1, generator=generator)
ws = th.DiagonalCostWeight(torch.randn(6, generator=generator))
# ws = th.ScaleCostWeight(torch.randn(1, generator=generator))
for s in se3s:
objective.add(th.Difference(s, ws, s_target))

vectorization = th.Vectorize(objective)
objective.update()

assert objective._cost_functions_iterable is vectorization._cost_fn_wrappers
for w in vectorization._cost_fn_wrappers:
for cost_fn in objective.cost_functions.values():
if cost_fn is w.cost_fn:
w_jac, w_err = cost_fn.weighted_jacobians_error()
assert w._cached_error.allclose(w_err)
for jac, exp_jac in zip(w._cached_jacobians, w_jac):
assert jac.allclose(exp_jac, atol=1e-6)
3 changes: 3 additions & 0 deletions theseus/core/vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def _update_vectorized_vars(
else:
var.update(torch.cat(names_to_data[name], dim=0))

# Computes the error of the vectorized cost function and distributes the error
# to the cost function wrappers. The list of wrappers must correspond to the
# same schema from which the vectorized cost function was obtained.
@staticmethod
def _compute_error_and_replace_wrapper_caches(
vectorized_cost_fn: CostFunction,
Expand Down

0 comments on commit eadbf65

Please sign in to comment.