Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorization refactor #205

Merged
merged 27 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
424c714
Created a wrapper cost function class that combines the aux vars for …
luisenp Jun 1, 2022
ee9e235
Disabled support for optimization variables in cost weights.
luisenp Jun 1, 2022
ea74465
Changed Objective to iterate over CFWrapper if available, and Theseus…
luisenp Jun 1, 2022
1b3af0b
Added a Vectorizer class and moved CFWrappers there.
luisenp Jun 1, 2022
2d3f9d2
Renamed vectorizer as Vectorize, added logic to replace Objective ite…
luisenp Jun 1, 2022
6a146cb
Added a CostFunctionSchema -> List[CostFunction] to use for vectoriza…
luisenp Jun 2, 2022
6c6a887
_CostFunctionWrapper is now meant to just store a cached value coming…
luisenp Jun 2, 2022
77ac280
Added code to automatically compute shared vars in Vectorize.
luisenp Jun 2, 2022
31237da
Changed vectorized costs construction to ensure that their weight is …
luisenp Jun 2, 2022
d30e1af
Implemented main cost function vectorization logic.
luisenp Jun 6, 2022
36e89c7
Updated bug that was causing detached gradients.
luisenp Jun 6, 2022
376e8ef
Fixed invalid check in theseus end-to-end unit tests.
luisenp Jun 6, 2022
ae6db18
Added unit test for schema and shared var computation.
luisenp Jun 6, 2022
0a2ee0a
Added a test to check that computed vectorized errors are correct.
luisenp Jun 6, 2022
58cee83
Moved vectorization update call to base linearization class.
luisenp Jun 7, 2022
7e60f87
Changed code to allow batch_size > 1 in shared variables.
luisenp Jun 7, 2022
399bb90
Fixed unit test and added call to Objective.update() in update_vector…
luisenp Jun 7, 2022
10cbf1c
Added new private iterator for vectorized costs.
luisenp Jun 7, 2022
10b208a
Replaced _register_vars_in_list with TheseusFunction.register_vars.
luisenp Jun 9, 2022
db5f366
Renamed vectorize_cost_fns kwarg as vectorize.
luisenp Jun 9, 2022
bb83db3
Added license headers.
luisenp Jun 9, 2022
1d0cd20
Small refactor.
luisenp Jun 9, 2022
e902924
Fixed bug that was preventing vectorized costs to work with to(). End…
luisenp Jun 9, 2022
0ec439f
Renamed the private Objective cost function iterator to _get_iterator().
luisenp Jun 9, 2022
aab9ead
Renamed kwarg in register_vars.
luisenp Jun 9, 2022
e57f310
Set vectorize=True for inverse kinematics and backward tests.
luisenp Jun 9, 2022
d6a434f
Remove lingering comments.
luisenp Jun 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions theseus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Objective,
ScaleCostWeight,
Variable,
Vectorize,
)
from .geometry import (
SE2,
Expand Down
1 change: 1 addition & 0 deletions theseus/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .cost_weight import CostWeight, DiagonalCostWeight, ScaleCostWeight
from .objective import Objective
from .variable import Variable
from .vectorizer import Vectorize
14 changes: 2 additions & 12 deletions theseus/core/cost_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,12 @@ def __init__(
# this avoids doing aux_vars=[], which is a bad default since [] is mutable
aux_vars = aux_vars or []

def _register_vars_in_list(var_list_, is_optim=False):
for var_ in var_list_:
if hasattr(self, var_.name):
raise RuntimeError(f"Variable name {var_.name} is not allowed.")
setattr(self, var_.name, var_)
if is_optim:
self.register_optim_var(var_.name)
else:
self.register_aux_var(var_.name)

if len(optim_vars) < 1:
raise ValueError(
"AutodiffCostFunction must receive at least one optimization variable."
)
_register_vars_in_list(optim_vars, is_optim=True)
_register_vars_in_list(aux_vars, is_optim=False)
self.register_vars(optim_vars, is_optim_vars=True)
self.register_vars(aux_vars, is_optim_vars=False)

self._err_fn = err_fn
self._dim = dim
Expand Down
31 changes: 26 additions & 5 deletions theseus/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import warnings
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union

import torch

Expand Down Expand Up @@ -61,6 +61,15 @@ def __init__(self, dtype: Optional[torch.dtype] = None):
# objective structure might break optimizer initialization).
self.current_version = 0

# ---- Callbacks for vectorization ---- #
# This gets replaced when cost function vectorization is used
self._cost_functions_iterable: Optional[Iterable[CostFunction]] = None

# Used to vectorize cost functions after update
self._vectorization_run: Optional[Callable] = None

self._vectorization_to: Optional[Callable] = None

def _add_function_variables(
self,
function: TheseusFunction,
Expand Down Expand Up @@ -160,15 +169,14 @@ def add(self, cost_function: CostFunction):
self.cost_functions_for_weights[cost_function.weight] = []

if cost_function.weight.num_optim_vars() > 0:
warnings.warn(
raise RuntimeError(
f"The cost weight associated to {cost_function.name} receives one "
"or more optimization variables. Differentiating cost "
"weights with respect to optimization variables is not currently "
"supported, thus jacobians computed by our optimizers will be "
"incorrect. You may want to consider moving the weight computation "
"inside the cost function, so that the cost weight only receives "
"auxiliary variables.",
RuntimeWarning,
"auxiliary variables."
)

self.cost_functions_for_weights[cost_function.weight].append(cost_function)
Expand Down Expand Up @@ -471,9 +479,20 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
batch_sizes.extend([v.data.shape[0] for v in self.aux_vars.values()])
self._batch_size = _get_batch_size(batch_sizes)

def update_vectorization(self):
if self._vectorization_run is not None:
if self._batch_size is None:
self.update()
self._vectorization_run()

# iterates over cost functions
def __iter__(self):
return iter([f for f in self.cost_functions.values()])
return iter([cf for cf in self.cost_functions.values()])

def _get_iterator(self):
if self._cost_functions_iterable is None:
return iter([cf for cf in self.cost_functions.values()])
return iter([cf for cf in self._cost_functions_iterable])

# Applies to() with given args to all tensors in the objective
def to(self, *args, **kwargs):
Expand All @@ -482,3 +501,5 @@ def to(self, *args, **kwargs):
device, dtype, *_ = torch._C._nn._parse_to(*args, **kwargs)
self.device = device or self.device
self.dtype = dtype or self.dtype
if self._vectorization_to is not None:
self._vectorization_to(*args, **kwargs)
4 changes: 2 additions & 2 deletions theseus/core/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def __init__(

def error(self):
mu = torch.stack([v.data for v in self.optim_vars]).sum()
return mu * torch.ones(self._dim)
return mu * torch.ones(1, self._dim)

def jacobians(self):
return [self.error()] * len(self._optim_vars_attr_names)
return [torch.ones(1, self._dim, self._dim)] * len(self._optim_vars_attr_names)

def dim(self) -> int:
return self._dim
Expand Down
26 changes: 9 additions & 17 deletions theseus/core/tests/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _create_cost_function_with_n_vars_and_m_aux(
yet_another_cost_function = _create_cost_function_with_n_vars_and_m_aux(
"yet_another", ["yet_var_1"], ["yet_aux_1"], cost_weight
)
with pytest.warns(RuntimeWarning): # optim var associated to weight
with pytest.raises(RuntimeError): # optim var associated to weight
objective.add(yet_another_cost_function) # no conflict here

cost_weight_with_conflict_in_aux_var = MockCostWeight(
Expand Down Expand Up @@ -151,24 +151,19 @@ def test_add_and_erase_step_by_step():
var3 = MockVar(1, data=None, name="var3")
aux1 = MockVar(1, data=None, name="aux1")
aux2 = MockVar(1, data=None, name="aux2")
cw1 = MockCostWeight(aux1, name="cw1", add_dummy_var_with_name="ignored_optim_var")
cw2 = MockCostWeight(aux2, name="cw2", add_optim_var=var1)
cw1 = MockCostWeight(
aux1, name="cw1"
) # , add_dummy_var_with_name="ignored_optim_var")
cw2 = MockCostWeight(aux2, name="cw2") # , add_optim_var=var1)

cf1 = MockCostFunction([var1, var2], [aux1, aux2], cw1, name="cf1")
cf2 = MockCostFunction([var1, var3], [aux1], cw2, name="cf2")
cf3 = MockCostFunction([var2, var3], [aux2], cw2, name="cf3")

objective = th.Objective()
for cost_function in [cf1, cf2, cf3]:
if cost_function is not cf3:
with pytest.warns(RuntimeWarning):
# a warning should emit the first time cw1/cw2 are added
objective.add(cost_function)
else:
objective.add(cost_function)
objective.add(cost_function)

for name in ["var1", "ignored_optim_var"]:
assert name in objective.cost_weight_optim_vars
for name in ["var1", "var2", "var2"]:
assert name in objective.optim_vars
for name in ["aux1", "aux2"]:
Expand Down Expand Up @@ -207,13 +202,10 @@ def _check_all_vars(v1_lis_, v2_lis_, v3_lis_, cw1_opt_lis_, a1_lis_, a2_lis_):
_check_funs_for_variable(var1, v1_lis_)
_check_funs_for_variable(var2, v2_lis_)
_check_funs_for_variable(var3, v3_lis_)
_check_funs_for_variable(
cw1.optim_var_at(0), cw1_opt_lis_, is_cost_weight_optim=True
)
_check_funs_for_variable(aux1, a1_lis_, optim_var=False)
_check_funs_for_variable(aux2, a2_lis_, optim_var=False)

v1_lis = [cw2, cf1, cf2]
v1_lis = [cf1, cf2]
v2_lis = [cf1, cf3]
v3_lis = [cf2, cf3]
cw1o_lis = [cw1]
Expand All @@ -223,7 +215,7 @@ def _check_all_vars(v1_lis_, v2_lis_, v3_lis_, cw1_opt_lis_, a1_lis_, a2_lis_):
_check_all_vars(v1_lis, v2_lis, v3_lis, cw1o_lis, a1_lis, a2_lis)

objective.erase("cf1")
v1_lis = [cw2, cf2]
v1_lis = [cf2]
v2_lis = [cf3]
cw1o_lis = []
a1_lis = [cf2] # cf1 and cw1 are deleted, since cw1 not used by any other cost fn
Expand All @@ -232,7 +224,7 @@ def _check_all_vars(v1_lis_, v2_lis_, v3_lis_, cw1_opt_lis_, a1_lis_, a2_lis_):
assert cw1 not in objective.cost_functions_for_weights

objective.erase("cf2")
v1_lis = [cw2] # cw2 still used by cf3
v1_lis = []
v3_lis = [cf3]
a1_lis = []
_check_all_vars(v1_lis, v2_lis, v3_lis, cw1o_lis, a1_lis, a2_lis)
Expand Down
194 changes: 194 additions & 0 deletions theseus/core/tests/test_vectorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

import theseus as th
from theseus.core.vectorizer import _CostFunctionWrapper


def test_costs_vars_and_err_before_vectorization():
for _ in range(20):
objective = th.Objective()
batch_size = torch.randint(low=1, high=10, size=(1,)).item()
v1 = th.Vector(data=torch.randn(batch_size, 1), name="v1")
v2 = th.Vector(data=torch.randn(batch_size, 1), name="v2")
odummy = th.Vector(1, name="odummy")
t1 = th.Vector(data=torch.zeros(1, 1), name="t1")
adummy = th.Variable(data=torch.zeros(1, 1), name="adummy")
cw1 = th.ScaleCostWeight(th.Variable(torch.zeros(1, 1), name="w1"))
cw2 = th.ScaleCostWeight(th.Variable(torch.zeros(1, 1), name="w2"))
cf1 = th.Difference(v1, cw1, t1)

# Also test with autodiff cost
def err_fn(optim_vars, aux_vars):
return optim_vars[0] - aux_vars[0]

cf2 = th.AutoDiffCostFunction([v2, odummy], err_fn, 1, cw2, [t1, adummy])

# Chech that vectorizer's has the correct number of wrappers
objective.add(cf1)
objective.add(cf2)
th.Vectorize(objective)

# Update weights after creating vectorizer to see if data is picked up correctly
w1 = torch.randn(1, 1) # also check that broadcasting works
w2 = torch.randn(batch_size, 1)

# disable for this test since we are not checking the result
objective._vectorization_run = None
objective.update({"w1": w1, "w2": w2})

def _check_attr(cf, var):
return hasattr(cf, var.name) and getattr(cf, var.name) is var

# Check that the vectorizer's cost functions have the right variables and error
saw_cf1 = False
saw_cf2 = False
for cf in objective._get_iterator():
assert isinstance(cf, _CostFunctionWrapper)
optim_vars = [v for v in cf.optim_vars]
aux_vars = [v for v in cf.aux_vars]
assert t1 in aux_vars
assert _check_attr(cf, t1)
w_err = cf.weighted_error()
if cf.cost_fn is cf1:
assert v1 in optim_vars
assert w_err.allclose((v1.data - t1.data) * w1)
assert _check_attr(cf, v1)
saw_cf1 = True
elif cf.cost_fn is cf2:
assert v2 in optim_vars and odummy in optim_vars
assert adummy in aux_vars
assert _check_attr(cf, v2) and _check_attr(cf, odummy)
assert w_err.allclose((v2.data - t1.data) * w2)
saw_cf2 = True
else:
assert False
assert saw_cf1 and saw_cf2


def test_correct_schemas_and_shared_vars():
luisenp marked this conversation as resolved.
Show resolved Hide resolved
v1 = th.Vector(1)
v2 = th.Vector(1)
tv = th.Vector(1)
w1 = th.ScaleCostWeight(1.0)
mv = th.Vector(1)

v3 = th.Vector(3)
v4 = th.Vector(3)

s1 = th.SE2()
s2 = th.SE2()
ts = th.SE2()

objective = th.Objective()
# these two can be grouped
cf1 = th.Difference(v1, w1, tv)
cf2 = th.Difference(v2, w1, tv)
objective.add(cf1)
objective.add(cf2)

# this one uses the same weight and v1, v2, but cannot be grouped
cf3 = th.Between(v1, v2, w1, mv)
objective.add(cf3)

# this one is the same cost function type, var type, and weight but different
# dimension, so cannot be grouped either
cf4 = th.Difference(v3, w1, v4)
objective.add(cf4)

# Now add another group with a different data-type (no-shared weight)
w2 = th.ScaleCostWeight(1.0)
w3 = th.ScaleCostWeight(2.0)
cf5 = th.Difference(s1, w2, ts)
cf6 = th.Difference(s2, w3, ts)
objective.add(cf5)
objective.add(cf6)

# Not grouped with anything cf1 and cf2 because weight type is different
w7 = th.DiagonalCostWeight([1.0])
cf7 = th.Difference(v1, w7, tv)
objective.add(cf7)

vectorization = th.Vectorize(objective)

assert len(vectorization._schema_dict) == 5
seen_cnt = [0] * 7
for schema, cost_fn_wrappers in vectorization._schema_dict.items():
cost_fns = [w.cost_fn for w in cost_fn_wrappers]
var_names = vectorization._var_names[schema]
if cf1 in cost_fns:
assert len(cost_fns) == 2
assert cf2 in cost_fns
seen_cnt[0] += 1
seen_cnt[1] += 1
assert f"{th.Vectorize._SHARED_TOKEN}{w1.scale.name}" in var_names
assert f"{th.Vectorize._SHARED_TOKEN}{tv.name}" in var_names
if cf3 in cost_fns:
assert len(cost_fns) == 1
seen_cnt[2] += 1
if cf4 in cost_fns:
assert len(cost_fns) == 1
seen_cnt[3] += 1
if cf5 in cost_fns:
assert len(cost_fns) == 2
assert cf6 in cost_fns
seen_cnt[4] += 1
seen_cnt[5] += 1
assert f"{th.Vectorize._SHARED_TOKEN}{w2.scale.name}" not in var_names
assert f"{th.Vectorize._SHARED_TOKEN}{w3.scale.name}" not in var_names
assert f"{th.Vectorize._SHARED_TOKEN}{ts.name}" in var_names
if cf7 in cost_fns:
assert len(cost_fns) == 1
seen_cnt[6] += 1
assert seen_cnt == [1] * 7


def test_vectorized_error():
rng = np.random.default_rng(0)
generator = torch.Generator()
generator.manual_seed(0)
for _ in range(20):
dim = rng.choice([1, 2])
objective = th.Objective()
batch_size = rng.choice(range(1, 11))

vectors = [
th.Vector(
data=torch.randn(batch_size, dim, generator=generator), name=f"v{i}"
)
for i in range(rng.choice([1, 10]))
]
target = th.Vector(dim, name="target")
w = th.ScaleCostWeight(torch.randn(1, generator=generator))
for v in vectors:
objective.add(th.Difference(v, w, target))

se3s = [
th.SE3(
data=th.SE3.rand(batch_size, generator=generator).data,
requires_check=False,
)
for i in range(rng.choice([1, 10]))
]
s_target = th.SE3.rand(1, generator=generator)
ws = th.DiagonalCostWeight(torch.randn(6, generator=generator))
# ws = th.ScaleCostWeight(torch.randn(1, generator=generator))
for s in se3s:
objective.add(th.Difference(s, ws, s_target))

vectorization = th.Vectorize(objective)
objective.update_vectorization()

assert objective._cost_functions_iterable is vectorization._cost_fn_wrappers
for w in vectorization._cost_fn_wrappers:
for cost_fn in objective.cost_functions.values():
if cost_fn is w.cost_fn:
w_jac, w_err = cost_fn.weighted_jacobians_error()
assert w._cached_error.allclose(w_err)
for jac, exp_jac in zip(w._cached_jacobians, w_jac):
assert jac.allclose(exp_jac, atol=1e-6)
Loading