In [None]:
#default_exp models

In [None]:
#exporti
import torch
import random
from itertools import product

In [None]:
#hide
from nbdev.showdoc import show_doc

# Equivariance

In [None]:
#export
class EquivarianceWrapper:
    """
    A class that represents an equivariance wrapper [1] that implements group equivariance via group averaging [2].

    The wrapper enumerates a transformation group of 90° rotations and/or mirrors acting on inputs of shape
    `(batch, channels, x, y, z)` (rotations assert a 5d input). Every group element is a
    `(transform, inverse_transform)` pair: `transform` moves the voxels *and* rotates/sign-flips the vector-field
    channels listed in `preprocessing.vector_directions`, while `inverse_transform` only undoes the spatial part
    (via `torch.rot90`/`torch.flip`) and leaves channel values untouched.
    """
    def __init__(self, 
                 preprocessing:"dl4to.preprocessing.Preprocessing"=None, # The preprocessing strategy to use. This is used in the equivariance wrapper to obtain the scalar and vector field information of the input.
                 rotate:bool=True, # Whether to include rotational equivariance in the transformation group.
                 mirror:bool=True, # Whether to include mirror equivariance in the transformation group.
                 dim:int=2, # The dimension of the transformation group. Specifically, a 2d transformation group does not consider rotations and mirrors along the z-axis.
                 rotate_twice:bool=False, # Whether double-rotations should be used, where the input is rotated twice, along two different axes. This may result in a larger transformation group.
                 sample_rate:float=1. # The rate of transformations that should be randomly sampled in the forward pass. `sample_rate=1.` defaults to all transformations being used in the wrapper. A smaller choice may be beneficial if memory constraints don't allow for the applications of all transformations in each forward pass.
                ):
        # Goes through the property setter, which also derives `vector_field_channels`.
        self.preprocessing = preprocessing
        assert dim == 2 or dim == 3, f"`dim` must be 2 or 3, got {dim}."
        self.rotate = rotate
        self.mirror = mirror
        self.dim = dim
        self.sample_rate = sample_rate
        self.cube_face_indices = self._get_cube_face_indices()
        self.rotate_twice = rotate_twice
        self.name = self._get_name()


    @property
    def preprocessing(self):
        return self._preprocessing


    @preprocessing.setter
    def preprocessing(self, preprocessing):
        self._preprocessing = preprocessing
        # `[x_channels, y_channels, z_channels]`, or `None` while no preprocessing is attached. Always assigning
        # the attribute lets `__call__` report a missing preprocessing instead of failing with `AttributeError`.
        if preprocessing is None:
            self.vector_field_channels = None
        else:
            self.vector_field_channels = self._get_vector_field_channels(preprocessing.vector_directions)


    def _get_name(self):
        """Builds an identifier such as `equiv_rot_mir_2d` from the enabled group components."""
        name = "equiv"
        if self.rotate:
            name += "_rot"
        if self.mirror:
            name += "_mir"
        if self.rotate_twice:
            name += "_twice"
        name += f"_{self.dim}d"
        return name


    def _get_cube_face_indices(self):
        """Cube faces used for the first rotation: faces 2 and 5 are only added for the 3d group."""
        cube_face_indices = [0, 1, 3, 4]
        if self.dim == 3:
            cube_face_indices.extend([2, 5])
        return cube_face_indices


    def _get_vector_field_channels(self, vector_directions):
        """Returns the channel indices whose vector direction is `'x'`, `'y'` and `'z'` (in this order)."""
        return [[i for i, vector_direction in enumerate(vector_directions) if vector_direction == axis]
                for axis in ('x', 'y', 'z')]


    def mirror_input(self, 
                     x:torch.Tensor, # The input that should be mirrored/flipped.
                     flip_dimensions:list # The dimensions along which the input should be mirrored, given as negative indices (`-3`, `-2`, `-1` for x, y, z).
                    ):
        """
        Returns a `torch.Tensor`, which is a mirrored version of the input `x`.
        """
        x = torch.flip(x, flip_dimensions)
        # Mirroring along an axis also reverses the sign of the vector components pointing along that axis.
        x = self._mirror_vector_fields(x, flip_dimensions)
        return x


    def _mirror_vector_fields(self, x, flip_dimensions):
        # Modifies `x` in place; callers pass the fresh copy returned by `torch.flip`.
        # `flip_dimensions` must be negative so that they index `vector_field_channels` as x/y/z.
        for flip_dimension in flip_dimensions:
            for flip_channel in self.vector_field_channels[flip_dimension]:
                x[:,flip_channel] = -x[:,flip_channel]
        return x


    def rotate_input(self, 
                     x:torch.Tensor, # The input that should be rotated.
                     rotations:int, # The number of 90° rotations that should be performed. Four rotations result in the identity.
                     plane:list # On which plane the input should be rotated.
                    ):
        """
        Returns a `torch.Tensor`, which is a rotated version of the input `x`.
        """
        x = torch.rot90(x, rotations, plane)
        # The vector components must be rotated consistently with the voxel grid.
        x = self._rotate_vector_fields(x, rotations, plane)
        return x


    def _rotate_vector_fields(self, x, rotations, plane):
        # Rotates the vector components in place by `rotations` quarter turns, from axis `plane[0]` towards
        # axis `plane[1]` (the direction `torch.rot90` uses for positive `k`).
        assert plane[0] < 0 and plane[1] < 0, "Indices of rotational plane must be given in negative integers"
        assert len(x.shape) == 5
        force_index_0, force_index_1 = self.vector_field_channels[plane[0]], self.vector_field_channels[plane[1]]

        # Python's `%` is non-negative, so negative rotation counts map onto 1, 2 or 3 quarter turns.
        if rotations % 4 == 1:
            temp = x[:,force_index_0].clone()
            x[:,force_index_0] = -x[:,force_index_1]
            x[:,force_index_1] = temp

        if rotations % 4 == 2:
            x[:,force_index_0] = -x[:,force_index_0]
            x[:,force_index_1] = -x[:,force_index_1]

        if rotations % 4 == 3:
            temp = x[:,force_index_0].clone()
            x[:,force_index_0] = x[:,force_index_1]
            x[:,force_index_1] = -temp
        return x


    def __get_rotations_and_plane_for_first_rotation(self, cube_face_index):
        """Maps a cube-face index to `(quarter turns, plane)` of the first rotation."""
        face_to_rotation = {
            0: ( 0, [-3, -2]),
            1: ( 1, [-3, -2]),
            2: (-1, [-1, -3]),
            3: ( 2, [-3, -2]),
            4: ( 3, [-3, -2]),
            5: ( 1, [-1, -3]),
        }
        rotations, plane = face_to_rotation[cube_face_index]
        return rotations, plane


    def __get_rotations_and_plane_for_second_rotation(self, cube_face_index):
        """Returns the quarter-turn counts and plane of the optional second rotation for a cube face."""
        rotations = [0,2]
        if cube_face_index%3 == 0: plane = [-2, -1]
        if cube_face_index%3 == 1: plane = [-1, -3]
        if self.dim == 3:
            rotations.extend([1, 3])
            # Faces with index % 3 == 2 (2 and 5) only exist in the 3d group, see `_get_cube_face_indices`.
            if cube_face_index%3 == 2: plane = [-3, -2]
        return rotations, plane


    def _get_rotations(self):
        """
        Returns one `(transform, inverse_transform)` pair per cube-face index. With `rotate_twice`, every face
        rotation is additionally composed with each second rotation.
        """
        transforms = []
        for cube_face_index in self.cube_face_indices:
            rotations, plane = self.__get_rotations_and_plane_for_first_rotation(cube_face_index)
            # Default arguments bind the current loop values; plain closures would all see the last iteration.
            face_transform = lambda x, rotations=rotations, plane=plane: self.rotate_input(x, rotations, plane)
            inverse_face_transform = lambda x, rotations=rotations, plane=plane: torch.rot90(x, -rotations, plane)

            if self.rotate_twice:
                rotations, plane = self.__get_rotations_and_plane_for_second_rotation(cube_face_index)
                for rotation in rotations:
                    transform = lambda x, face_transform=face_transform, rotation=rotation, plane=plane: self.rotate_input(face_transform(x), rotation, plane)
                    inverse_transform = lambda x, inverse_face_transform=inverse_face_transform, rotation=rotation, plane=plane: inverse_face_transform(torch.rot90(x, -rotation, plane))
                    transforms.append((transform, inverse_transform))
            else:
                transforms.append((face_transform, inverse_face_transform))
        return transforms


    def _get_flip_dimensions(self, rotation_transforms):
        """
        Returns the lists of (negative) spatial dimensions to flip.

        If `rotation_transforms` holds more than one element (rotations are enabled), only a flip along the x-axis
        (`[-3]`) is combined with each rotation. Otherwise all `2**dim` subsets of the first `dim` spatial axes are
        returned, including the empty subset (the identity).
        """
        if len(rotation_transforms) != 1:
            return [[-3]]
        axes = range(-3, -3 + self.dim)
        return [[axis for axis, use in zip(axes, mask) if use] for mask in product([True, False], repeat=self.dim)]


    def _get_mirrors(self, rotation_transforms):
        """Composes every rotation with every flip: the input is mirrored first, then rotated."""
        transforms = []
        flip_dimensions = self._get_flip_dimensions(rotation_transforms)

        for rotation_transform, inverse_rotation_transform in rotation_transforms:
            for flip_dimension in flip_dimensions:
                transform = lambda x, rotation_transform=rotation_transform, flip_dimension=flip_dimension: rotation_transform(self.mirror_input(x, flip_dimension))
                # Undo in reverse order: inverse rotation first, then the (self-inverse) flip.
                inverse_transform = lambda x, inverse_rotation_transform=inverse_rotation_transform, flip_dimension=flip_dimension: torch.flip(inverse_rotation_transform(x), flip_dimension)
                transforms.append((transform, inverse_transform))
        return transforms


    def _sample_transforms(self, transforms, sample_rate):
        """
        Returns a random subset of `transforms` of size `round(len(transforms) * sample_rate)` (at least one).
        A `sample_rate >= 1` keeps every transform; values above 1 no longer make `random.sample` raise.
        """
        if sample_rate >= 1.:
            return transforms
        number_of_sampled_transforms = max(1, round(len(transforms) * sample_rate))
        number_of_sampled_transforms = min(number_of_sampled_transforms, len(transforms))
        return random.sample(transforms, k=number_of_sampled_transforms)


    def get_transforms(self,
                       sample_rate:float=None # The rate of transformations that should be randomly sampled from the equivariance wrapper. `None` defaults to `equivariance_wrapper.sample_rate`. `1.` means that all transformations are considered.
                      ):
        """
        Returns a list of all group actions that are applied to an input in the equivariance wrapper.
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        identity = (lambda x: x, lambda x: x)
        rotation_transforms = self._get_rotations() if self.rotate else [identity]
        # Fresh list, so that appending the mirrors does not mutate `rotation_transforms`.
        transforms = list(rotation_transforms) if self.rotate else []
        if self.mirror:
            transforms += self._get_mirrors(rotation_transforms)
        if not transforms:
            # Neither rotations nor mirrors are requested: the group is trivial and only contains the identity.
            # An empty list would make the group average in `EquivariantModel` divide by zero.
            transforms = [identity]
        return self._sample_transforms(transforms, sample_rate)


    def __call__(self, 
                 model:torch.nn.Module # The model that should be turned into an equivariant model.
                ):
        """
        Applies the equivariance wrapper to a `torch.nn.Module` model object and returns a `dl4to.models.EquivariantModel` object.
        """
        assert self.vector_field_channels is not None, "EquivarianceWrapper does not have a preprocessing."
        return EquivariantModel(model=model, equivariance_wrapper=self)

In [None]:
# Render the API documentation of `EquivarianceWrapper.mirror_input`.
show_doc(EquivarianceWrapper.mirror_input)

<h4 id="EquivarianceWrapper.mirror_input" class="doc_header"><code>EquivarianceWrapper.mirror_input</code><a href="__main__.py#L65" class="source_link" style="float:right">[source]</a></h4>

> <code>EquivarianceWrapper.mirror_input</code>(**`x`**:`Tensor`, **`flip_dimensions`**:`list`)

Returns a `torch.Tensor`, which is a mirrored version of the input `x`.

||Type|Default|Details|
|---|---|---|---|
|**`x`**|`Tensor`||The input that should be mirrored/flipped.|
|**`flip_dimensions`**|`list`||The dimensions along which the input should be mirrored.|


In [None]:
# Render the API documentation of `EquivarianceWrapper.rotate_input`.
show_doc(EquivarianceWrapper.rotate_input)

<h4 id="EquivarianceWrapper.rotate_input" class="doc_header"><code>EquivarianceWrapper.rotate_input</code><a href="__main__.py#L84" class="source_link" style="float:right">[source]</a></h4>

> <code>EquivarianceWrapper.rotate_input</code>(**`x`**:`Tensor`, **`rotations`**:`int`, **`plane`**:`list`)

Returns a `torch.Tensor`, which is a rotated version of the input `x`.

||Type|Default|Details|
|---|---|---|---|
|**`x`**|`Tensor`||The input that should be rotated.|
|**`rotations`**|`int`||The number of 90° rotations that should be performed. Four rotations result in the identity.|
|**`plane`**|`list`||On which plane the input should be rotated.|


In [None]:
# Render the API documentation of `EquivarianceWrapper.get_transforms`.
show_doc(EquivarianceWrapper.get_transforms)

<h4 id="EquivarianceWrapper.get_transforms" class="doc_header"><code>EquivarianceWrapper.get_transforms</code><a href="__main__.py#L195" class="source_link" style="float:right">[source]</a></h4>

> <code>EquivarianceWrapper.get_transforms</code>(**`sample_rate`**:`float`=*`None`*)

Returns a list of all group actions that are applied to an input in the equivariance wrapper.

||Type|Default|Details|
|---|---|---|---|
|**`sample_rate`**|`float`|`None`|The rate of transformations that should be randomly sampled from the equivariance wrapper. `None` defaults to `equivariance_wrapper.sample_rate`. `1.` means that all transformations are considered.|


In [None]:
# Render the API documentation of `EquivarianceWrapper.__call__`.
show_doc(EquivarianceWrapper.__call__)

<h4 id="EquivarianceWrapper.__call__" class="doc_header"><code>EquivarianceWrapper.__call__</code><a href="__main__.py#L214" class="source_link" style="float:right">[source]</a></h4>

> <code>EquivarianceWrapper.__call__</code>(**`model`**:`Module`)

Applies the equivariance wrapper to a `torch.nn.Module` model object and returns a `dl4to.models.EquivariantModel` object.

||Type|Default|Details|
|---|---|---|---|
|**`model`**|`Module`||The model that should be turned into an equivariant model.|


In [None]:
#export
class EquivariantModel(torch.nn.Module):
    """
    A class that represents an equivariant model with respect to a specific equivariance wrapper.

    The forward pass averages the wrapped `model` over the (optionally sampled) group elements returned by
    `equivariance_wrapper.get_transforms`: every input is transformed, passed through `model`, mapped back with
    the matching inverse transform, and the results are averaged (group averaging [2]).
    """
    def __init__(self, 
                 model:torch.nn.Module, # A PyTorch neural network.
                 equivariance_wrapper:"dl4to.models.EquivarianceWrapper" # The equivariance wrapper that is applied to the model.
                ):
        super().__init__()
        self.model = model
        self.equivariance_wrapper = equivariance_wrapper


    def __call__(self, 
                 model_inputs:torch.Tensor, # The model inputs that are obtained as output of the preprocessing.
                 sample_rate:float=None # The rate of transformations that should be randomly sampled from the equivariance wrapper. `None` defaults to `equivariance_wrapper.sample_rate`. `1.` means that all transformations are applied in the forward pass.
                ):
        """
        The forward method for the equivariant model.
        """
        # Dispatch through `torch.nn.Module.__call__` so that hooks registered on this module run;
        # the actual computation lives in `forward`.
        return super().__call__(model_inputs, sample_rate=sample_rate)


    def forward(self, model_inputs, sample_rate=None):
        """
        Group-averaged forward pass. See `__call__` for the parameters.
        """
        n_channels = len(self.equivariance_wrapper.preprocessing.vector_directions)
        assert model_inputs.shape[1] == n_channels, \
            f"Expected {n_channels} input channels (one per entry of `preprocessing.vector_directions`), got {model_inputs.shape[1]}."
        transforms = self.equivariance_wrapper.get_transforms(sample_rate)
        if not transforms:
            # A trivial group only contains the identity; averaging over nothing would divide by zero.
            transforms = [(lambda x: x, lambda x: x)]
        model_outputs = 0
        for transform, inverse_transform in transforms:
            model_outputs_transformed = self.model(transform(model_inputs))
            model_outputs = model_outputs + inverse_transform(model_outputs_transformed)
        return model_outputs / len(transforms)

In [None]:
# Render the API documentation of `EquivariantModel.__call__`.
show_doc(EquivariantModel.__call__)

<h4 id="EquivariantModel.__call__" class="doc_header"><code>EquivariantModel.__call__</code><a href="__main__.py#L15" class="source_link" style="float:right">[source]</a></h4>

> <code>EquivariantModel.__call__</code>(**`model_inputs`**:`Tensor`, **`sample_rate`**:`float`=*`None`*)

The forward method for the equivariant model.

||Type|Default|Details|
|---|---|---|---|
|**`model_inputs`**|`Tensor`||The model inputs that are obtained as output of the preprocessing.|
|**`sample_rate`**|`float`|`None`|The rate of transformations that should be randomly sampled from the equivariance wrapper. `None` defaults to `equivariance_wrapper.sample_rate`. `1.` means that all transformations are applied in the forward pass.|


# References

[1] Dittmer, Sören, et al. "SELTO: Sample-Efficient Learned Topology Optimization." arXiv preprint arXiv:2209.05098 (2022).

[2] Puny, Omri, et al. "Frame averaging for invariant and equivariant network design." arXiv preprint arXiv:2110.03336 (2021).

In [None]:
#hide
from dl4to.preprocessing import ForcePreprocessing, TrivialPreprocessing

In [None]:
%%time
#hide
#slow

def test_that_transform_aggregation_and_inverse_transform_yields_the_aggregated_input_for_batch_size_larger_one(rotate, mirror, dim, rotate_twice, verbose=False):
    """
    Checks every transform of the group on a batch of two non-negative samples with 3 channels:

    - outside the `rotate and mirror` case, a transform that changes the input must introduce negative values in both
      samples (the input is non-negative, so negatives can only stem from sign-flipped vector components);
    - summing the absolute values over the channels and applying the inverse transform must recover the channel-summed
      absolute input, i.e. the inverse transform exactly undoes the spatial part of the transform.

    `verbose` is accepted for backwards compatibility but is not used.
    """
    shape = [20, 20, 20]

    inputs = torch.zeros(2, 3, *shape)
    inputs[0, :, :20, :1] = 1.
    inputs[1, :, :20, :1] = .6

    preprocessing = ForcePreprocessing()
    equivariance = EquivarianceWrapper(preprocessing=preprocessing,
                                       rotate=rotate, mirror=mirror, 
                                       dim=dim, rotate_twice=rotate_twice)
    sampled_transforms = equivariance.get_transforms(sample_rate=1)

    for i, (transform, inverse_transform) in enumerate(sampled_transforms):
        inputs_transformed = transform(inputs)
        negatives_in_sample_zero = True
        negatives_in_sample_one = True

        # NOTE(review): transforms 6 and 19 of the 3d rotation group are exempt from the negativity check;
        # the reason is not documented here — confirm.
        is_exempt = rotate and dim == 3 and i in (6, 19)
        if not torch.allclose(inputs, inputs_transformed) and not (rotate and mirror) and not is_exempt:
            negatives_in_sample_zero = torch.any(inputs_transformed[0] < 0)
            negatives_in_sample_one = torch.any(inputs_transformed[1] < 0)

        inputs_aggregated_and_transformed = inputs_transformed.abs().sum(dim=1)
        prediction = inverse_transform(inputs_aggregated_and_transformed)

        target = inputs.abs().sum(dim=1)

        target_and_prediction_are_the_same = torch.allclose(target, prediction)

        assert negatives_in_sample_zero, f"no negatives in sample 0 for transform {i}."
        assert negatives_in_sample_one, f"no negatives in sample 1 for transform {i}."
        assert target_and_prediction_are_the_same, f"target and prediction are different for transform {i}."


# (rotate, mirror, dim, rotate_twice)
for config in [(True, False, 2, False), (False, True, 2, False), (True, True, 2, False),
               (True, False, 2, True), (False, True, 2, True), (True, True, 2, True),
               (True, False, 3, True), (False, True, 3, True), (True, True, 3, True)]:
    test_that_transform_aggregation_and_inverse_transform_yields_the_aggregated_input_for_batch_size_larger_one(*config)

CPU times: user 4.55 s, sys: 0 ns, total: 4.55 s
Wall time: 254 ms


In [None]:
%%time
#hide

def test_that_transformation_groups_yield_the_correct_number_of_transforms(rotate, mirror, dim, rotate_twice):
    """
    Checks that `get_transforms(sample_rate=1)` returns the expected number of group elements.

    There is one first rotation per cube face (4 faces in 2d, 6 in 3d); with `rotate_twice`, each is combined with
    2 (2d) or 4 (3d) second rotations. Mirrors double the rotations when both are enabled; without rotations, the
    mirrors alone yield all `2**dim` flip combinations.
    """
    preprocessing = ForcePreprocessing()
    equivariance = EquivarianceWrapper(preprocessing=preprocessing, 
                                       rotate=rotate, mirror=mirror,
                                       dim=dim, rotate_twice=rotate_twice)
    n_faces = 4 if dim == 2 else 6
    n_second_rotations = (2 if dim == 2 else 4) if rotate_twice else 1
    n_rotations = n_faces * n_second_rotations

    n_transforms = 0
    if rotate:
        n_transforms += n_rotations
    if mirror:
        n_transforms += n_rotations if rotate else 2**dim

    n_actual = len(equivariance.get_transforms(sample_rate=1))
    assert n_actual == n_transforms, (n_actual, n_transforms)


# (rotate, mirror, dim, rotate_twice); the 3d groups without double rotations were previously expected to have 24 elements
for config in [(True, False, 2, False), (False, True, 2, False), (True, True, 2, False),
               (True, False, 2, True), (False, True, 2, True), (True, True, 2, True),
               (True, False, 3, False), (False, True, 3, False), (True, True, 3, False),
               (True, False, 3, True), (False, True, 3, True), (True, True, 3, True)]:
    test_that_transformation_groups_yield_the_correct_number_of_transforms(*config)

CPU times: user 215 µs, sys: 0 ns, total: 215 µs
Wall time: 218 µs


In [None]:
%%time
#hide

def test_that_transformation_group_yields_unique_transformations(rotate, mirror, dim, rotate_twice):
    """
    Applies every element of the full transformation group to one random input and asserts that no two
    elements produce the same tensor.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing(),
                                       rotate=rotate, mirror=mirror,
                                       dim=dim, rotate_twice=rotate_twice)
    sample = torch.rand(10, 7, 3, 3, 3)
    seen_outputs = []

    for transform, _ in equivariance.get_transforms(sample_rate=1):
        candidate = transform(sample)
        assert all(not torch.equal(candidate, previous) for previous in seen_outputs)
        seen_outputs.append(candidate)


# (rotate, mirror, dim, rotate_twice)
for config in [(True, False, 2, False), (False, True, 2, False), (True, True, 2, False),
               (True, False, 2, True), (False, True, 2, True), (True, True, 2, True),
               (True, False, 3, True), (False, True, 3, True), (True, True, 3, True)]:
    test_that_transformation_group_yields_unique_transformations(*config)

CPU times: user 25 ms, sys: 0 ns, total: 25 ms
Wall time: 24.3 ms


In [None]:
%%time
#hide

def test_that_transformation_group_yields_uniquely_sampled_transformations(rotate, mirror, dim, rotate_twice, sample_rate=.5):
    """
    Samples a fraction `sample_rate` of the transformation group and checks that
    (a) the expected number `max(1, round(n * sample_rate))` of transforms is returned, and
    (b) no two sampled transforms map a random input to the same tensor (sampling happens without replacement).

    Before, this test duplicated `test_that_transformation_group_yields_unique_transformations` with
    `sample_rate=1` and therefore never exercised the sampling.
    """
    preprocessing = TrivialPreprocessing()
    equivariance = EquivarianceWrapper(preprocessing=preprocessing, 
                                       rotate=rotate, mirror=mirror,
                                       dim=dim, rotate_twice=rotate_twice)
    x = torch.rand(10, 7, 3, 3, 3)

    n_all_transforms = len(equivariance.get_transforms(sample_rate=1))
    sampled_transforms = equivariance.get_transforms(sample_rate=sample_rate)
    n_expected = max(1, round(n_all_transforms * sample_rate))
    assert len(sampled_transforms) == n_expected, (len(sampled_transforms), n_expected)

    list_of_transformed_tensors = []
    for transformation, _ in sampled_transforms:
        x_ = transformation(x)
        for transformed_tensor in list_of_transformed_tensors:
            assert not torch.equal(x_, transformed_tensor)
        list_of_transformed_tensors.append(x_)


# (rotate, mirror, dim, rotate_twice)
for config in [(True, False, 2, False), (False, True, 2, False), (True, True, 2, False),
               (True, False, 2, True), (False, True, 2, True), (True, True, 2, True),
               (True, False, 3, True), (False, True, 3, True), (True, True, 3, True)]:
    test_that_transformation_group_yields_uniquely_sampled_transformations(*config)

CPU times: user 27 ms, sys: 0 ns, total: 27 ms
Wall time: 25.5 ms


In [None]:
%%time
#hide

def test_that_rotate_vector_fields_changes_the_tensor():
    """
    Every non-trivial quarter-turn of the vector-field channels must alter a random input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    for plane in ([-2, -1], [-1, -3], [-3, -2]):
        for rotation in (-3, -2, -1, 1, 2, 3):
            before = x.clone()
            # `_rotate_vector_fields` works in place, so `x` carries the accumulated rotations into the next round.
            after = equivariance._rotate_vector_fields(x, rotation, plane)
            assert not torch.allclose(before, after)


test_that_rotate_vector_fields_changes_the_tensor()

CPU times: user 4.08 ms, sys: 293 µs, total: 4.37 ms
Wall time: 3.65 ms


In [None]:
%%time
#hide

def test_that_rotate_vector_fields_is_invertable():
    """
    Rotating the vector-field channels by `k` quarter turns and then by `-k` must restore the input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    for plane in ([-2, -1], [-1, -3], [-3, -2]):
        for rotation in range(4):
            reference = x.clone()
            rotated = equivariance._rotate_vector_fields(x, rotation, plane)
            round_trip = equivariance._rotate_vector_fields(rotated, -rotation, plane)
            assert torch.allclose(reference, round_trip)


test_that_rotate_vector_fields_is_invertable()

CPU times: user 2.66 ms, sys: 0 ns, total: 2.66 ms
Wall time: 2.41 ms


In [None]:
%%time
#hide

def test_that_rotate_vector_fields_works_correctly():
    """
    Checks `_rotate_vector_fields` against the analytic action of a quarter-turn on a constant unit vector field.

    A rotation by `k` quarter turns in `plane=[a, b]` (direction from axis `a` towards axis `b`, as in
    `torch.rot90`) maps the unit vector e_a to e_b, -e_a and -e_b for `k % 4` = 1, 2 and 3.
    Before, this test duplicated `test_that_rotate_vector_fields_changes_the_tensor`.
    """
    preprocessing = TrivialPreprocessing()
    equivariance = EquivarianceWrapper(preprocessing=preprocessing)
    n_channels = len(preprocessing.vector_directions)
    channels = equivariance.vector_field_channels  # [x_channels, y_channels, z_channels]

    planes = [[-2, -1], [-1, -3], [-3, -2]]
    rotations = [-3, -2, -1, 1, 2, 3]

    for plane in planes:
        channels_a, channels_b = channels[plane[0]], channels[plane[1]]
        for rotation in rotations:
            x = torch.zeros(2, n_channels, 3, 3, 3)
            x[:, channels_a] = 1.  # constant unit field along axis `plane[0]`

            expected = torch.zeros_like(x)
            quarter_turns = rotation % 4
            if quarter_turns == 1: expected[:, channels_b] = 1.
            if quarter_turns == 2: expected[:, channels_a] = -1.
            if quarter_turns == 3: expected[:, channels_b] = -1.

            assert torch.allclose(equivariance._rotate_vector_fields(x, rotation, plane), expected), (plane, rotation)


test_that_rotate_vector_fields_works_correctly()

CPU times: user 6.28 ms, sys: 0 ns, total: 6.28 ms
Wall time: 5.76 ms


In [None]:
%%time
#hide

def test_that_rotate_changes_the_tensor():
    """
    Every non-trivial quarter-turn (voxel grid plus vector components) must change a random input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    cases = [(plane, rotation)
             for plane in ([-2, -1], [-1, -3], [-3, -2])
             for rotation in (-3, -2, -1, 1, 2, 3)]
    for plane, rotation in cases:
        reference = x.clone()
        assert not torch.allclose(reference, equivariance.rotate_input(x, rotation, plane))


test_that_rotate_changes_the_tensor()

CPU times: user 8.47 ms, sys: 29 µs, total: 8.5 ms
Wall time: 7.65 ms


In [None]:
%%time
#hide

def test_that_rotate_is_invertable():
    """
    `rotate_input` by `k` followed by `rotate_input` by `-k` in the same plane must restore the input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    cases = [(plane, rotation)
             for plane in ([-2, -1], [-1, -3], [-3, -2])
             for rotation in range(4)]
    for plane, rotation in cases:
        reference = x.clone()
        rotated = equivariance.rotate_input(x, rotation, plane)
        assert torch.allclose(reference, equivariance.rotate_input(rotated, -rotation, plane))


test_that_rotate_is_invertable()

CPU times: user 5.71 ms, sys: 59 µs, total: 5.77 ms
Wall time: 4.77 ms


In [None]:
%%time
#hide

def test_that__mirror_vector_fields_changes_the_tensor():
    """
    Negating the vector components along any non-empty set of axes must alter a random input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    # Same iteration order as nested x/y/z loops over [False, True].
    for mask in product([False, True], repeat=3):
        flip_dimensions = [axis for axis, use in zip((-3, -2, -1), mask) if use]
        if not flip_dimensions:
            continue
        before = x.clone()
        # `_mirror_vector_fields` works in place, so `x` keeps the accumulated sign flips.
        assert not torch.allclose(before, equivariance._mirror_vector_fields(x, flip_dimensions))


test_that__mirror_vector_fields_changes_the_tensor()

CPU times: user 1.47 ms, sys: 0 ns, total: 1.47 ms
Wall time: 1.31 ms


In [None]:
%%time
#hide

def test_that_mirror_changes_the_tensor():
    """
    Mirroring along any non-empty set of axes (voxels plus vector components) must change a random input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    for mask in product([False, True], repeat=3):
        flip_dimensions = [axis for axis, use in zip((-3, -2, -1), mask) if use]
        if flip_dimensions:
            reference = x.clone()
            mirrored = equivariance.mirror_input(x, flip_dimensions=flip_dimensions)
            assert not torch.allclose(reference, mirrored)


test_that_mirror_changes_the_tensor()

CPU times: user 0 ns, sys: 1.73 ms, total: 1.73 ms
Wall time: 1.58 ms


In [None]:
%%time
#hide

def test_that__mirror_vector_fields_is_invertable():
    """
    Negating the vector components twice along the same axes must restore the input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    for mask in product([False, True], repeat=3):
        flip_dimensions = [axis for axis, use in zip((-3, -2, -1), mask) if use]
        if not flip_dimensions:
            continue
        reference = x.clone()
        once = equivariance._mirror_vector_fields(x, flip_dimensions)
        twice = equivariance._mirror_vector_fields(once, flip_dimensions)
        assert torch.allclose(reference, twice)


test_that__mirror_vector_fields_is_invertable()

CPU times: user 1.97 ms, sys: 0 ns, total: 1.97 ms
Wall time: 1.73 ms


In [None]:
%%time
#hide

def test_that_mirror_is_invertable():
    """
    `mirror_input` applied twice with the same axes must restore the input.
    """
    equivariance = EquivarianceWrapper(preprocessing=TrivialPreprocessing())
    x = torch.rand(10, 7, 3, 3, 3)

    for mask in product([False, True], repeat=3):
        flip_dimensions = [axis for axis, use in zip((-3, -2, -1), mask) if use]
        if not flip_dimensions:
            continue
        reference = x.clone()
        mirrored = equivariance.mirror_input(x, flip_dimensions)
        assert torch.allclose(reference, equivariance.mirror_input(mirrored, flip_dimensions))


test_that_mirror_is_invertable()

CPU times: user 2.3 ms, sys: 0 ns, total: 2.3 ms
Wall time: 2.07 ms
