diff --git a/examples/lie_labs.py b/examples/lie_labs.py
index 63bb95c3d..7c573dbc8 100644
--- a/examples/lie_labs.py
+++ b/examples/lie_labs.py
@@ -10,17 +10,69 @@
 print(f"Created SE3 tensor with shape {g1.shape}")
 g2 = g1.clone()
 
-# Can create from a tensor as long as it's consistent with the desired ltype
-g3_data = lieF.so3.rand(5)
-g3 = lie.from_tensor(g3_data, lie.SO3)
+# Identity element
+i1 = lie.SO3.identity(2)
+i2 = lie.SE3.identity(2)
+print("SO3 identity", i1, i1.shape)
+print("SE3 identity", i2, i2.shape)
+
+# Indexing
+g1_slice = g1[2:4]
+assert g1_slice.shape == (2, 3, 4)
+torch.testing.assert_close(g1_slice._t, g1._t[2:4])  # type: ignore
+try:
+    bad = g1[3, 2]
+except NotImplementedError:
+    print("INDEXING ERROR: Can only slice the first dimension for now.")
+
+# ## Different constructors
+g3_data = lieF.SO3.rand(5, requires_grad=True)  # this is a regular tensor with SO3 data
+# Can create from a tensor as long as it's consistent with the desired ltype
+g3 = lie.from_tensor(g3_data, lie.SO3)  # keeps grad history
+assert g3.grad_fn is not None
 try:
     x = lie.from_tensor(torch.zeros(1, 3, 3), lie.SO3)
 except ValueError as e:
     print(f"ERROR: {e}")
+
+
+def is_shared(t1, t2):  # utility to check if memory is shared
+    return t1.storage().data_ptr() == t2.storage().data_ptr()
+
+
+# # Let's check different copy vs no-copy options
+# -- lie.SO3() lie.SE3()
+g3_leaf = lie.SO3(g3_data)  # creates a leaf tensor and copies data
+assert g3_leaf.grad_fn is None
+assert not is_shared(g3_leaf, g3_data)
+
+# -- lie.LieTensor() constructor is equivalent to lie.SO3()
+g3_leaf_2 = lie.LieTensor(g3_data, lie.SO3)
+assert g3_leaf_2.grad_fn is None
+assert not is_shared(g3_leaf_2, g3_data)
+
+# -- as_lietensor()
 g4 = lie.as_lietensor(g3_data, lie.SO3)
-g5 = lie.cast(g3_data, lie.SO3)  # alias for as_lietensor
+assert is_shared(g3_data, g4)  # shares storage if possible
+assert g4.grad_fn is not None  # result is not a leaf
+# Calling with a LieTensor returns the same tensor...
+g5 = lie.as_lietensor(g3, lie.SO3)
+assert g5 is g3
+# ... unless dtype or device is different
+g5_double = lie.as_lietensor(g3, lie.SO3, dtype=torch.double)
+assert g5_double is not g3
+assert not is_shared(g5_double, g3)
+
+# -- cast()
+g6 = lie.cast(g3_data, lie.SO3)  # alias for as_lietensor
+assert is_shared(g3_data, g6)
+
+# -- LieTensor.new()
+g7 = g3.new(g3_data)
+assert is_shared(g3_data, g7)  # shares storage
+assert g7.grad_fn is not None  # differentiable
 
 # ### Lie operations
 v = torch.randn(batch_size, 6)
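
Note on the copy vs. no-copy split exercised in this patch: it mirrors core PyTorch, where torch.tensor() always copies and yields a leaf with no grad history, while torch.as_tensor() shares storage when dtype and device already match and only copies on conversion. The sketch below shows that baseline behavior using only standard torch calls, as a point of comparison for the lie constructors above; it is an illustrative analogy, not part of the patch.

import torch

data = torch.randn(5, 4, requires_grad=True)

# torch.tensor() copies: the result is a leaf with no grad history and its
# own storage -- analogous to lie.SO3(g3_data) in the patch above.
copied = torch.tensor(data)  # PyTorch warns and suggests detach().clone() here
assert copied.grad_fn is None
assert copied.data_ptr() != data.data_ptr()

# torch.as_tensor() avoids the copy when dtype/device already match, so
# storage is shared -- analogous to lie.as_lietensor(g3_data, lie.SO3).
shared = torch.as_tensor(data)
assert shared.data_ptr() == data.data_ptr()

# A dtype conversion forces a copy, just like
# lie.as_lietensor(g3, lie.SO3, dtype=torch.double) above.
converted = torch.as_tensor(data, dtype=torch.double)
assert converted.data_ptr() != data.data_ptr()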