Add support for masking jacobians of zero weights in the batch (#398)
* Added CostWeight.is_zero() method.

* Add a masked_variables context manager for temporarily masking variables' tensors.

* Add logic to skip jacobians computation for zero weights in batch.

* Enable masked jacobians in vectorization.

* Detached the zero-mask computation and switched to a smaller EPS.

* Added is_zero for GPCostWeight.

* Changed scale and diagonal weight is_zero to use == 0.
luisenp committed Jan 13, 2023
1 parent 556d6ba commit f77c30c
Showing 13 changed files with 312 additions and 55 deletions.
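As a rough illustration of the zero-weight masking idea described in the commit message (a sketch only, not the actual Theseus implementation; scale_is_zero is a hypothetical helper):

    import torch

    def scale_is_zero(scale: torch.Tensor) -> torch.Tensor:
        # A batch entry counts as "zero" only when every element of its weight
        # is exactly 0 (matching the "== 0" choice above). The mask is detached
        # so no gradients ever flow through it.
        return (scale == 0).all(dim=-1).detach()

    # Example: a (batch_size, 1) scale weight with half the batch zeroed out.
    w = torch.randn(4, 1) * torch.tensor([[1.0], [0.0], [1.0], [0.0]])
    print(scale_is_zero(w))  # tensor([False,  True, False,  True])

Entries flagged by such a mask can have their jacobian computation skipped entirely, since a zero weight removes their contribution to the linear system.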
6 changes: 6 additions & 0 deletions tests/core/common.py
@@ -80,6 +80,9 @@ def weight_jacobians_and_error(self, jacobians, error):
def _copy_impl(self, new_name=None):
return MockCostWeight(self.the_data.copy(), name=new_name)

def is_zero(self):
raise NotImplementedError


class NullCostWeight(th.CostWeight):
def __init__(self):
@@ -97,6 +100,9 @@ def weight_jacobians_and_error(self, jacobians, error):
def _copy_impl(self, new_name=None):
return NullCostWeight()

def is_zero(self):
raise NotImplementedError


class MockCostFunction(th.CostFunction):
def __init__(
141 changes: 91 additions & 50 deletions tests/core/test_robust_cost.py
@@ -2,17 +2,27 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest

import torch

import theseus as th
from tests.core.common import BATCH_SIZES_TO_TEST


def _new_robust_cf(batch_size, loss_cls, generator) -> th.RobustCostFunction:
def _new_robust_cf(
batch_size, loss_cls, generator, masked_weight=False
) -> th.RobustCostFunction:
v1 = th.rand_se3(batch_size, generator=generator)
v2 = th.rand_se3(batch_size, generator=generator)
w = th.ScaleCostWeight(torch.randn(1, generator=generator))
if masked_weight:
mask = torch.randint(2, (batch_size, 1), generator=generator).bool()
assert mask.any()
assert not mask.all()
w_tensor = torch.randn(batch_size, 1, generator=generator) * mask
else:
w_tensor = torch.randn(1, generator=generator)
w = th.ScaleCostWeight(w_tensor)
cf = th.Local(v1, v2, w)
ll_radius = th.Variable(tensor=torch.randn(1, 1, generator=generator))
return th.RobustCostFunction(cf, loss_cls, ll_radius)
@@ -38,60 +48,91 @@ def test_robust_cost_weighted_error():
expected_rho2 = loss_cls.evaluate(
(e * e).sum(dim=1, keepdim=True), robust_cf.log_loss_radius.tensor
)
assert rho2.allclose(expected_rho2)
torch.testing.assert_close(rho2, expected_rho2)


def test_robust_cost_grad_form():
@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("loss_cls", [th.WelschLoss, th.HuberLoss])
def test_robust_cost_grad_form(batch_size, loss_cls):
generator = torch.Generator()
generator.manual_seed(0)
for _ in range(10):
for batch_size in BATCH_SIZES_TO_TEST:
for loss_cls in [th.WelschLoss, th.HuberLoss]:
robust_cf = _new_robust_cf(batch_size, loss_cls, generator)
cf = robust_cf.cost_function
jacs, e = cf.weighted_jacobians_error()
cf_grad = _grad(jacs[0], e)
e_norm = (e * e).sum(1, keepdim=True)
rho_prime = loss_cls.linearize(e_norm, robust_cf.log_loss_radius.tensor)
# `weighted_jacobians_error()` is written so that it results in a
# gradient equal to drho_de2 * J^T * e, which in the code is
# `rho_prime * cf_grad`.
expected_grad = rho_prime.view(-1, 1, 1) * cf_grad
rescaled_jac, rescaled_e = robust_cf.weighted_jacobians_error()
grad = _grad(rescaled_jac[0], rescaled_e)
assert grad.allclose(expected_grad, atol=1e-6)


def test_robust_cost_jacobians():
robust_cf = _new_robust_cf(batch_size, loss_cls, generator)
cf = robust_cf.cost_function
jacs, e = cf.weighted_jacobians_error()
cf_grad = _grad(jacs[0], e)
e_norm = (e * e).sum(1, keepdim=True)
rho_prime = loss_cls.linearize(e_norm, robust_cf.log_loss_radius.tensor)
# `weighted_jacobians_error()` is written so that it results in a
# gradient equal to drho_de2 * J^T * e, which in the code is
# `rho_prime * cf_grad`.
expected_grad = rho_prime.view(-1, 1, 1) * cf_grad
rescaled_jac, rescaled_e = robust_cf.weighted_jacobians_error()
grad = _grad(rescaled_jac[0], rescaled_e)
torch.testing.assert_close(grad, expected_grad, atol=1e-6, rtol=1e-6)


@pytest.mark.parametrize("batch_size", BATCH_SIZES_TO_TEST)
@pytest.mark.parametrize("loss_cls", [th.WelschLoss, th.HuberLoss])
def test_robust_cost_jacobians(batch_size, loss_cls):
generator = torch.Generator()
generator.manual_seed(0)

for _ in range(10):
for batch_size in BATCH_SIZES_TO_TEST:
for loss_cls in [th.WelschLoss, th.HuberLoss]:
robust_cf = _new_robust_cf(batch_size, loss_cls, generator)
v1, v2 = robust_cf.cost_function.var, robust_cf.cost_function.target
v_aux = v1.copy()
ll_radius = robust_cf.log_loss_radius
w = robust_cf.cost_function.weight

def test_fn(v_data):
v_aux.update(v_data)
new_robust_cf = th.RobustCostFunction(
th.Local(v_aux, v2, w), loss_cls, ll_radius
)
e = new_robust_cf.cost_function.weighted_error()
e_norm = (e * e).sum(1, keepdim=True)
return loss_cls.evaluate(e_norm, ll_radius.tensor) / 2.0

aux_id = torch.arange(batch_size)
grad_raw_dense = torch.autograd.functional.jacobian(
test_fn, (v1.tensor,)
)[0]
grad_raw_sparse = grad_raw_dense[aux_id, :, aux_id]
expected_grad = v1.project(grad_raw_sparse, is_sparse=True)

rescaled_jac, rescaled_err = robust_cf.weighted_jacobians_error()
grad = _grad(rescaled_jac[0], rescaled_err)

assert grad.allclose(expected_grad, atol=1e-2)
robust_cf = _new_robust_cf(batch_size, loss_cls, generator)
v1, v2 = robust_cf.cost_function.var, robust_cf.cost_function.target
v_aux = v1.copy()
ll_radius = robust_cf.log_loss_radius
w = robust_cf.cost_function.weight

def test_fn(v_data):
v_aux.update(v_data)
new_robust_cf = th.RobustCostFunction(
th.Local(v_aux, v2, w), loss_cls, ll_radius
)
e = new_robust_cf.cost_function.weighted_error()
e_norm = (e * e).sum(1, keepdim=True)
return loss_cls.evaluate(e_norm, ll_radius.tensor) / 2.0

aux_id = torch.arange(batch_size)
grad_raw_dense = torch.autograd.functional.jacobian(test_fn, (v1.tensor,))[0]
grad_raw_sparse = grad_raw_dense[aux_id, :, aux_id]
expected_grad = v1.project(grad_raw_sparse, is_sparse=True)

rescaled_jac, rescaled_err = robust_cf.weighted_jacobians_error()
grad = _grad(rescaled_jac[0], rescaled_err)

torch.testing.assert_close(grad, expected_grad, atol=1e-2, rtol=1e-2)


def test_masked_jacobians_called(monkeypatch):
rng = torch.Generator()
rng.manual_seed(0)
robust_cf = _new_robust_cf(128, th.WelschLoss, rng, masked_weight=True)
robust_cf._supports_masking = True

called = [False]

def masked_jacobians_mock(cost_fn, mask):
called[0] = True
return cost_fn.jacobians()

monkeypatch.setattr(
th.core.cost_function, "masked_jacobians", masked_jacobians_mock
)
robust_cf.weighted_jacobians_error()
assert called[0]


@pytest.mark.parametrize("loss_cls", [th.WelschLoss, th.HuberLoss])
def test_mask_jacobians(loss_cls):
batch_size = 512
rng = torch.Generator()
rng.manual_seed(0)
robust_cf = _new_robust_cf(batch_size, loss_cls, rng, masked_weight=True)
jac_expected, err_expected = robust_cf.weighted_jacobians_error()
robust_cf._supports_masking = True
jac, err = robust_cf.weighted_jacobians_error()
torch.testing.assert_close(err, err_expected)
for j1, j2 in zip(jac, jac_expected):
torch.testing.assert_close(j1, j2)
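The two masking tests above check that turning masking on does not change the weighted jacobians or error. The underlying pattern can be sketched roughly like this (a simplified stand-in, not the actual masked_jacobians implementation; masked_jacobian and jac_fn are illustrative names):

    import torch

    def masked_jacobian(jac_fn, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # jac_fn maps a (b, dof) batch to a (b, dim, dof) jacobian; mask is a
        # (batch_size,) boolean tensor, True where the weight is nonzero.
        # Jacobians are computed only for the active entries; rows belonging to
        # zero-weight entries are left as zeros, which is what they would become
        # after weighting anyway.
        jac_active = jac_fn(x[mask])
        jac = x.new_zeros(x.shape[0], *jac_active.shape[1:])
        jac[mask] = jac_active
        return jac

    # Example with a trivial linear cost whose jacobian is the identity.
    x = torch.randn(6, 3)
    mask = torch.tensor([True, False, True, True, False, True])
    jac = masked_jacobian(lambda t: torch.eye(3).expand(t.shape[0], 3, 3), x, mask)
    assert (jac[~mask] == 0).all()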
86 changes: 86 additions & 0 deletions tests/core/test_vectorizer.py
@@ -317,3 +317,89 @@ def test_vectorized_retract():

for v1, v2 in zip(variables, variables_vectorized):
assert v1.tensor.allclose(v2.tensor)


# This solves a very simple objective of the form sum (wi * (xi - ti)) **2, where
# some wi can be zero with some probability. When vectorize=True, our vectorization
# class will compute masked batched jacobians. So, this function can be used to test
# that the solution is the same when this feature is on/off. We also check if we
# can do a backward pass when this masking is used.
def _solve_fn_for_masked_jacobians(
batch_size, dof, num_costs, weight_cls, vectorize, device
):
rng = torch.Generator()
rng.manual_seed(batch_size)
obj = th.Objective()
variables = [th.Vector(dof=dof, name=f"x{i}") for i in range(num_costs)]
targets = [
th.Vector(tensor=torch.randn(batch_size, dof, generator=rng), name=f"t{i}")
for i in range(num_costs)
]
base_tensor = torch.ones(
batch_size, dof if weight_cls == th.DiagonalCostWeight else 1, device=device
)
# Wrapped into a param to pass to torch optimizer if necessary
params = [
torch.nn.Parameter(
base_tensor.clone() * (torch.rand(1, generator=rng).item() > 0.9)
)
for _ in range(num_costs)
]
weights = [weight_cls(params[i]) for i in range(num_costs)]
for i in range(num_costs):
obj.add(th.Difference(variables[i], targets[i], weights[i], name=f"cf{i}"))

input_tensors = {
f"x{i}": torch.ones(batch_size, dof, device=device) for i in range(num_costs)
}
layer = th.TheseusLayer(
th.LevenbergMarquardt(obj, step_size=0.1, max_iterations=5),
vectorize=vectorize,
)
layer.to(device=device)
sol, _ = layer.forward(input_tensors)

# Check that we can backprop through this without errors
if vectorize:
optim = torch.optim.Adam(params, lr=1e-4)
for _ in range(5): # do a few steps
optim.zero_grad()
layer.forward(input_tensors)
loss = obj.error_squared_norm().sum()
loss.backward()
optim.step()

return sol


@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("dof", [1, 4])
@pytest.mark.parametrize("num_costs", [1, 64])
@pytest.mark.parametrize("weight_cls", [th.ScaleCostWeight, th.DiagonalCostWeight])
def test_masked_jacobians(batch_size, dof, num_costs, weight_cls):
device = "cuda:0" if torch.cuda.is_available() else "cpu"

sol1 = _solve_fn_for_masked_jacobians(
batch_size, dof, num_costs, weight_cls, True, device
)
sol2 = _solve_fn_for_masked_jacobians(
batch_size, dof, num_costs, weight_cls, False, device
)

for i in range(num_costs):
torch.testing.assert_close(sol1[f"x{i}"], sol2[f"x{i}"])


def test_masked_jacobians_called(monkeypatch):
called = [False]

def masked_jacobians_mock(cost_fn, mask):
called[0] = True
return cost_fn.jacobians()

monkeypatch.setattr(
th.core.cost_function, "masked_jacobians", masked_jacobians_mock
)

_solve_fn_for_masked_jacobians(128, 2, 16, th.ScaleCostWeight, True, "cpu")
assert called[0]
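For reference, the toy problem solved by _solve_fn_for_masked_jacobians above is, per batch element,

    \min_{x_1, \dots, x_N} \sum_{i=1}^{N} \| w_i \odot (x_i - t_i) \|^2

where some of the w_i are randomly set to zero (elementwise for DiagonalCostWeight, a single scalar per cost for ScaleCostWeight). With vectorize=True the masked-jacobian path is exercised, and test_masked_jacobians above verifies that the solution matches the non-vectorized one.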
2 changes: 2 additions & 0 deletions tests/embodied/motionmodel/test_double_integrator.py
@@ -55,6 +55,7 @@ def test_gp_motion_model_cost_weight_weights():

def test_gp_motion_model_cost_weight_copy():
q_inv = torch.randn(10, 2, 2)
q_inv = torch.bmm(q_inv.transpose(1, 2), q_inv) # make it pos. def.
dt = torch.rand(1)
cost_weight = th.eb.GPCostWeight(q_inv, dt, name="gp")
check_another_theseus_function_is_copy(
@@ -103,6 +104,7 @@ def test_gp_motion_model_cost_function_error_vector_vars():
dt = th.Variable(torch.rand(1).double())

q_inv = torch.randn(batch_size, dof, dof).double()
q_inv = torch.bmm(q_inv.transpose(1, 2), q_inv) # make it pos. def.
# won't be used for the test, but it's required by cost_function's constructor
cost_weight = th.eb.GPCostWeight(q_inv, dt)
cost_function = th.eb.GPMotionModel(
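A short note on the two q_inv lines added above (standard linear algebra, not specific to this commit): for any real square matrix A, the product A^T A is symmetric and positive semi-definite, since

    x^\top (A^\top A) x = \| A x \|^2 \ge 0 \quad \text{for all } x,

and it is positive definite whenever A is invertible, which holds almost surely for a random Gaussian matrix. This is presumably why the tests now build q_inv this way before passing it to GPCostWeight.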
3 changes: 3 additions & 0 deletions tests/optimizer/linearization_test_utils.py
@@ -115,6 +115,9 @@ def weight_jacobians_and_error(self, jacobians, error):
def _copy_impl(self, new_name=None):
raise NotImplementedError

def is_zero(self):
raise NotImplementedError


def build_test_objective_and_linear_system():
# This function creates an objective that results in the
2 changes: 2 additions & 0 deletions theseus/__init__.py
@@ -21,6 +21,8 @@
Vectorize,
WelschLoss,
as_variable,
masked_jacobians,
masked_variables,
)
from .geometry import ( # usort: skip
LieGroup,
10 changes: 8 additions & 2 deletions theseus/core/__init__.py
@@ -3,10 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .cost_function import AutoDiffCostFunction, AutogradMode, CostFunction, ErrFnType
from .cost_function import (
AutoDiffCostFunction,
AutogradMode,
CostFunction,
ErrFnType,
masked_jacobians,
)
from .cost_weight import CostWeight, DiagonalCostWeight, ScaleCostWeight
from .objective import Objective
from .robust_cost_function import RobustCostFunction
from .robust_loss import HuberLoss, RobustLoss, WelschLoss
from .variable import Variable, as_variable
from .variable import Variable, as_variable, masked_variables
from .vectorizer import Vectorize