## Problem Set 4, Problem 6

In this exercise, you will explore the invariances that can be helpful for classifying point cloud data.

<br>
<hr>

### Problem Definition

#### Point Cloud Classification

- Point cloud data is represented as a set $x = \{ x_i \in \mathbb{R}^3 \}_{i=1}^N$, where $N$ is the number of points.
Denote the space of point clouds by $X$.
- The objective is to investigate the invariances of a classification model $F: X \rightarrow \mathbb{R}^K$, whose output $y = F(x) \in \mathbb{R}^K$ contains the classification logits for $K$ classes.

#### Invariances in the Point Cloud Classification Task

- Although a set has no inherent order of elements, a point cloud is stored in a computer as a list $x = [x_i]_{i=1}^N$.
Thus, the model should be invariant to permutations of the list elements so that it effectively treats the list as a set.
- $\mathbb{R}^3$-invariance, which ensures that translating the point cloud in 3D does not affect the classification output, would also be helpful for this classification task.
- $\mathrm{SO}(3)$-invariance, which ensures that rotating the point cloud in 3D does not affect the classification output, would be helpful for this classification task.
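A toy NumPy sketch of the first point (not part of the assignment; `sum_pool`, `flatten_linear`, and the random weights are hypothetical): a model that aggregates points with a symmetric operation such as a sum ignores the list order, while a model that flattens the list ties each weight to a list position and therefore depends on the order.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 8
x = rng.normal(size=(N, 3))            # toy point cloud stored as a list of N points
sigma = np.roll(np.arange(N), 1)       # a cyclic shift: a non-identity permutation (0-based)
x_perm = x[sigma]                      # same set of points, different list order

W = rng.normal(size=(N * 3, 4))        # hypothetical weights of an order-dependent model


def sum_pool(x):
    """Symmetric aggregation over points: the list order does not matter."""
    return x.sum(axis=0)


def flatten_linear(x):
    """Flattening ties every weight to a list position, so the order matters."""
    return x.reshape(-1) @ W


print(np.allclose(sum_pool(x), sum_pool(x_perm)))              # True
print(np.allclose(flatten_linear(x), flatten_linear(x_perm)))  # False
```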

<hr>

### Dependencies and settings

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.dgcnn import DGCNN_cls, knn

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

num_pts = 1024      # N
num_classes = 16    # K

<hr>

### Load point cloud data

Load a single point cloud from the ShapeNet [1] dataset and downsample it to $N$ points by keeping its first $N$ rows (only the first three columns, the $xyz$ coordinates, are kept).

In [None]:
data = np.loadtxt('ShapeNet/02691156/1a04e3eab45ca15dd86060f189eb133.txt')

x = torch.Tensor(data[:num_pts, :3]).to(device)

<hr>

### Group actions

The elements and group actions of the permutation, translation, and rotation groups are defined as follows:

(1) Permutation

- An element of the permutation group is represented as a list $\sigma = [\sigma_i \in \mathbb{N}]_{i=1}^N$ where $1 \leq \sigma_i \leq N$ and $\sigma_i \neq \sigma_j$ for $i \neq j$.
- The action of $\sigma$ on $x = [x_1, \cdots, x_N]$ is defined as follows:
$$\sigma \cdot x = [x_{\sigma_1}, \cdots, x_{\sigma_N}]$$

(2) Translation

- An element of the translation group is represented as a vector $p \in \mathbb{R}^3$.
- The action of $p$ on $x = [x_1, \cdots, x_N]$ is defined as follows:
$$p \cdot x = [x_1 + p, \cdots, x_N + p]$$

(3) Rotation

- An element of the rotation group is represented as a matrix $R \in \mathrm{SO}(3)$.
- The action of $R$ on $x = [x_1, \cdots, x_N]$ is defined as follows:
$$R \cdot x = [R x_1, \cdots, R x_N]$$
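The three definitions can be sketched in NumPy as follows, assuming points are stored as the rows of an `(N, 3)` array and permutation indices are 0-based (the definition above is 1-based). The helper `random_rotation` is hypothetical: it samples a matrix in $\mathrm{SO}(3)$ from the QR decomposition of a Gaussian matrix. Because points are rows, $R x_i$ becomes `x @ R.T`. The last lines check that each action is undone by the inverse group element ($\sigma^{-1}$ via `argsort`, $-p$, and $R^{-1} = R^T$).

```python
import numpy as np


def random_rotation(rng):
    """Hypothetical helper: sample R in SO(3) from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))        # fix the sign ambiguity of the QR factorization
    if np.linalg.det(q) < 0:           # turn a reflection into a proper rotation
        q[:, 0] = -q[:, 0]
    return q


rng = np.random.default_rng(0)
N = 5
x = rng.normal(size=(N, 3))            # rows are the points x_1, ..., x_N

sigma = rng.permutation(N)             # 0-based version of [sigma_1, ..., sigma_N]
p = rng.normal(size=3)                 # translation vector
R = random_rotation(rng)               # rotation matrix

perm_x = x[sigma]                      # sigma . x = [x_{sigma_1}, ..., x_{sigma_N}]
trans_x = x + p                        # p . x = [x_1 + p, ..., x_N + p] (broadcast over rows)
rot_x = x @ R.T                        # R . x = [R x_1, ..., R x_N] with points as rows

# Each action is undone by the inverse group element.
assert np.allclose(perm_x[np.argsort(sigma)], x)   # sigma^{-1} is argsort(sigma)
assert np.allclose(trans_x - p, x)                 # p^{-1} = -p
assert np.allclose(rot_x @ R, x)                   # R^{-1} = R^T
```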

<b>Q. Implement functions to perform the group actions for each group.</b>

In [None]:
def action_perm_x(x, sigma):
    """
    Parameters:
    - x: torch.Tensor of shape (N, 3) representing the point cloud
    - sigma: torch.Tensor of shape (N,) representing the permutation indices

    Returns:
    - x_perm: torch.Tensor of shape (N, 3) representing the permuted point cloud
    """

    ##############################
    ####### YOUR CODE HERE #######
    ######## x_perm = ... ########
    ##############################

    return x_perm

def action_trans_x(x, p):
    """
    Parameters:
    - x: torch.Tensor of shape (N, 3) representing the point cloud
    - p: torch.Tensor of shape (3,) representing the translation vector

    Returns:
    - x_trans: torch.Tensor of shape (N, 3) representing the translated point cloud
    """

    ##############################
    ####### YOUR CODE HERE #######
    ####### x_trans = ... ########
    ##############################

    return x_trans

def action_rot_x(x, R):
    """
    Parameters:
    - x: torch.Tensor of shape (N, 3) representing the point cloud
    - R: torch.Tensor of shape (3, 3) representing the rotation matrix

    Returns:
    - x_rot: torch.Tensor of shape (N, 3) representing the rotated point cloud
    """

    ##############################
    ####### YOUR CODE HERE #######
    ######### x_rot = ... ########
    ##############################

    return x_rot

Check if the group action functions are implemented correctly.

In [None]:
sigma = torch.load('assets/sigma.pt', weights_only=True).to(device)
R = torch.load('assets/R.pt', weights_only=True).to(device)
p = torch.load('assets/p.pt', weights_only=True).to(device)

assert torch.allclose(action_perm_x(x, sigma), torch.load('assets/x_perm.pt', weights_only=True).to(device))
assert torch.allclose(action_rot_x(x, R), torch.load('assets/x_rot.pt', weights_only=True).to(device))
assert torch.allclose(action_trans_x(x, p), torch.load('assets/x_trans.pt', weights_only=True).to(device))

<hr>

### Build classification model

First, we use DGCNN [2] as the point cloud classification model.
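DGCNN builds a k-nearest-neighbour graph and, in its EdgeConv layers, feeds each edge the pair $(x_j - x_i, x_i)$ for every neighbour $x_j$ of $x_i$ (the VN version later in this notebook concatenates the same two parts in `get_graph_feature`). The NumPy sketch below uses hypothetical helpers `knn_indices` and `edge_features`, not the `knn` imported from `models.dgcnn`. It shows that a translation leaves the relative part unchanged but shifts the absolute part $x_i$, which suggests why a plain DGCNN need not be translation invariant.

```python
import numpy as np


def knn_indices(x, k):
    """Hypothetical helper: indices of the k nearest points to each point (itself included)."""
    d2 = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)   # (N, N) squared distances
    return np.argsort(d2, axis=1)[:, :k]                       # (N, k)


def edge_features(x, k):
    """EdgeConv-style edge inputs: concat(x_j - x_i, x_i) for each neighbour j of i."""
    idx = knn_indices(x, k)
    neighbours = x[idx]                                        # (N, k, 3)
    centers = np.repeat(x[:, None, :], k, axis=1)              # (N, k, 3)
    return np.concatenate([neighbours - centers, centers], axis=-1)   # (N, k, 6)


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 3))
p = np.array([1.0, -2.0, 0.5])

e = edge_features(x, k=4)
e_shift = edge_features(x + p, k=4)
print(np.allclose(e[..., :3], e_shift[..., :3]))   # True: x_j - x_i ignores the shift
print(np.allclose(e[..., 3:], e_shift[..., 3:]))   # False: x_i moves by p
```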

In [None]:
model_1 = DGCNN_cls(k=20, num_classes=num_classes).to(device)
model_1.eval()    # Evaluation mode: BatchNorm uses its running statistics (the input is a single point cloud) and Dropout is disabled

Get the classification logits of the point cloud by running:

In [None]:
model_1(x)

<hr>

### Check invariances

For a model $F$ to be invariant to a group $G$, it must satisfy the following condition:
$$F(g \cdot x) = F(x) \quad \text{for all } g \in G$$
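In practice the condition can only be tested numerically for sampled group elements, with a tolerance for floating-point error. A minimal NumPy sketch, assuming a hypothetical helper `is_invariant` and a PointNet-style `toy_model` (a shared per-point map followed by max pooling over points). A passing test for one $g$ is evidence rather than proof, while a single failing test is enough to show that $F$ is not invariant.

```python
import numpy as np


def is_invariant(f, x, act, atol=1e-6):
    """Hypothetical helper: test F(g . x) == F(x) for one sampled group element g."""
    return np.allclose(f(act(x)), f(x), atol=atol)


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3))
W = rng.normal(size=(3, 8))            # hypothetical shared per-point weights


def toy_model(x):
    """PointNet-style toy model: shared per-point map, then max pooling over points."""
    return np.maximum(x @ W, 0.0).max(axis=0)


sigma = rng.permutation(16)
p = np.array([0.3, -1.0, 2.0])

print(is_invariant(toy_model, x, lambda pts: pts[sigma]))   # True: max ignores the order
print(is_invariant(toy_model, x, lambda pts: pts + p))      # False: pre-activations shift by p @ W
```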

<b>Q. Implement functions to check if the provided model is invariant to permutations, translations, and rotations.</b>
(Hint: You may find `torch.allclose` useful.)

In [None]:
def check_perm_inv_model(model, x, sigma):
    """
    Parameters:
    - model: nn.Module representing the model
    - x: torch.Tensor of shape (N, 3) representing the point cloud
    - sigma: torch.Tensor of shape (N,) representing the permutation indices

    Returns:
    - is_perm_inv: bool representing whether the model is permutation invariant
    """

    ##############################
    ####### YOUR CODE HERE #######
    ###### is_perm_inv = ... #####
    ##############################

    return is_perm_inv

def check_trans_inv_model(model, x, p):
    """
    Parameters:
    - model: nn.Module representing the model
    - x: torch.Tensor of shape (N, 3) representing the point cloud
    - p: torch.Tensor of shape (3,) representing the translation vector

    Returns:
    - is_trans_inv: bool representing whether the model is translation invariant
    """

    ##############################
    ####### YOUR CODE HERE #######
    ##### is_trans_inv = ... #####
    ##############################

    return is_trans_inv

def check_rot_inv_model(model, x, R):
    """
    Parameters:
    - model: nn.Module representing the model
    - x: torch.Tensor of shape (N, 3) representing the point cloud
    - R: torch.Tensor of shape (3, 3) representing the rotation matrix

    Returns:
    - is_rot_inv: bool representing whether the model is rotation invariant
    """

    ##############################
    ####### YOUR CODE HERE #######
    ###### is_rot_inv = ... ######
    ##############################

    return is_rot_inv

Check if DGCNN is permutation, translation, and rotation invariant.

In [None]:
is_perm_inv = check_perm_inv_model(model_1, x, sigma)
is_trans_inv = check_trans_inv_model(model_1, x, p)
is_rot_inv = check_rot_inv_model(model_1, x, R)

print(f"DGCNN is{' ' if is_perm_inv else ' NOT '}permutation invariant.")
print(f"DGCNN is{' ' if is_trans_inv else ' NOT '}translation invariant.")
print(f"DGCNN is{' ' if is_rot_inv else ' NOT '}rotation invariant.")

<hr>

### Incorporate $\mathbb{R}^3$-invariance through canonicalization

Using a mapping $\bar{p}: X \rightarrow \mathbb{R}^3$ that satisfies $\bar{p}(p \cdot x) = p \cdot \bar{p}(x)$ for all $x \in X$ and $p \in \mathbb{R}^3$ (here $p \cdot \bar{p}(x) = p + \bar{p}(x)$, since translations compose by addition), we can construct an $\mathbb{R}^3$-invariant model $\hat{F}$ from a non-$\mathbb{R}^3$-invariant model $F$ as follows:
$$\hat{F}(x) = F(\bar{p}(x)^{-1} \cdot x)$$
where $\bar{p}(x)^{-1}$ is not the inverse function of $\bar{p}$, but rather the inverse of the group element output by $\bar{p}(x)$. Since $\bar{p}(x)$ outputs a translation vector, $\bar{p}(x)^{-1} = -\bar{p}(x)$.
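The invariance of $\hat{F}$ follows directly from the condition on $\bar{p}$. Writing the translation group additively, $\bar{p}(p \cdot x) = p + \bar{p}(x)$, so for every $x \in X$ and $p \in \mathbb{R}^3$:

$$
\begin{aligned}
\hat{F}(p \cdot x) &= F\big(\bar{p}(p \cdot x)^{-1} \cdot (p \cdot x)\big) = F\big((p + \bar{p}(x))^{-1} \cdot (p \cdot x)\big) \\
&= F\big([x_1 + p - p - \bar{p}(x), \cdots, x_N + p - p - \bar{p}(x)]\big) \\
&= F\big(\bar{p}(x)^{-1} \cdot x\big) = \hat{F}(x)
\end{aligned}
$$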

Here, the mapping $\bar{p}$ moves the input to a standard (canonical) position, and it is the key to turning a non-invariant model $F$ into an $\mathbb{R}^3$-invariant model.
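A minimal NumPy sketch of the construction, assuming the centroid $\bar{p}(x) = \frac{1}{N}\sum_i x_i$ as the canonicalizer. This is one possible choice that satisfies the condition, since $\frac{1}{N}\sum_i (x_i + p) = p + \frac{1}{N}\sum_i x_i$. The names `toy_model`, `p_bar`, and `canonicalized` are hypothetical and independent of the notebook's `ModelWithCanonicalization`.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 8))            # hypothetical weights of a non-invariant toy model


def toy_model(x):
    """Depends on absolute coordinates, so it is not translation invariant."""
    return np.maximum(x @ W, 0.0).max(axis=0)


def p_bar(x):
    """Assumed canonicalizer: the centroid of the points."""
    return x.mean(axis=0)


def canonicalized(model):
    """F_hat(x) = F(p_bar(x)^{-1} . x), where p_bar(x)^{-1} = -p_bar(x)."""
    return lambda pts: model(pts - p_bar(pts))


x = rng.normal(size=(16, 3))
p = np.array([5.0, -3.0, 1.0])
F_hat = canonicalized(toy_model)

print(np.allclose(p_bar(x + p), p_bar(x) + p))        # True: p_bar(p . x) = p . p_bar(x)
print(np.allclose(toy_model(x + p), toy_model(x)))    # False: the raw model sees the shift
print(np.allclose(F_hat(x + p), F_hat(x)))            # True: the shift is removed first
```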

<b>Q. Construct $\bar{p}$ that satisfies the required condition.</b>

In [None]:
class ModelWithCanonicalization:
    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        p_bar = self._p_bar(x)

        out = self.model(action_trans_x(x, -p_bar))   # Inverse of p_bar is -p_bar

        return out

    def _p_bar(self, x):
        ##############################
        ####### YOUR CODE HERE #######
        ######## p_bar = ... #########
        ##############################

        return p_bar

model_2 = ModelWithCanonicalization(model_1)

Check if DGCNN with canonicalization is permutation, translation, and rotation invariant.

In [None]:
is_perm_inv = check_perm_inv_model(model_2, x, sigma)
is_trans_inv = check_trans_inv_model(model_2, x, p)
is_rot_inv = check_rot_inv_model(model_2, x, R)

print(f"DGCNN with canonicalization is{' ' if is_perm_inv else ' NOT '}permutation invariant.")
print(f"DGCNN with canonicalization is{' ' if is_trans_inv else ' NOT '}translation invariant.")
print(f"DGCNN with canonicalization is{' ' if is_rot_inv else ' NOT '}rotation invariant.")

<hr>

### Incorporate $\mathrm{SO}(3)$-invariance through network architecture

Instead of using canonicalization, we can design the network architecture itself to be equivariant to group actions.

Vector Neurons (VNs) [3] are a network architecture specifically designed to achieve $\mathrm{SO}(3)$-equivariance for point cloud data.

In this section, we will implement the basic linear layers of VNs.

#### Feature of Vector Neurons and group action of $\mathrm{SO}(3)$

VNs represent the $i$-th point with a <em>vector-list feature</em> $\boldsymbol{V}_i \in \mathbb{R}^{C \times 3}$, a list of $C$ 3D vectors, so that a point cloud of $N$ points is described by the vector-list features $\mathcal{V} = \{ \boldsymbol{V}_1, \cdots, \boldsymbol{V}_N \} \in \mathbb{R}^{N \times C \times 3}$.

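For vector-list features $\mathcal{V}$, an element $R$ of the rotation group $\mathrm{SO}(3)$ acts by rotating every 3D vector, i.e., every row of each $\boldsymbol{V}_i$:
$$R \cdot \mathcal{V} = \mathcal{V} R^T = \{ \boldsymbol{V}_1 R^T, \cdots, \boldsymbol{V}_N R^T \}$$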
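In array form the action only touches the last (xyz) axis; the batch, point, and channel axes are left alone. A NumPy sketch on random toy features with a hypothetical rotation about the $z$-axis, checking that `calV @ R.T` agrees with rotating every 3D vector $v$ to $R v$ individually:

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, C = 2, 4, 5
calV = rng.normal(size=(B, N, C, 3))   # toy vector-list features

t = np.pi / 6                          # hypothetical rotation about the z-axis by 30 degrees
R = np.array([[np.cos(t), -np.sin(t), 0.0],
              [np.sin(t),  np.cos(t), 0.0],
              [0.0,        0.0,       1.0]])

# R . calV = calV R^T: matmul multiplies every (C, 3) block V_i by R^T, broadcast over (B, N).
calV_rot = calV @ R.T

# The same result, written as "every 3D vector v becomes R v".
calV_rot_per_vector = np.einsum('ij,bncj->bnci', R, calV)

print(calV_rot.shape)                                   # (2, 4, 5, 3)
print(np.allclose(calV_rot, calV_rot_per_vector))       # True
```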

<b>Q. Implement the function for the group action of $\mathrm{SO}(3)$ on vector-list features.</b>

In [None]:
def action_rot_calV(calV, R):
    """
    Parameters:
    - calV: torch.Tensor of shape (B, N, C, 3) representing the vector-list features (B: batch size, N: number of points, C: number of channels)
    - R: torch.Tensor of shape (3, 3) representing the rotation matrix

    Returns:
    - calV_rot: torch.Tensor of shape (B, N, C, 3) representing the rotated vector-list features
    """

    ##############################
    ####### YOUR CODE HERE #######
    ####### calV_rot = ... #######
    ##############################

    return calV_rot

Check if the group action function is implemented correctly.

In [None]:
calV = torch.load('assets/calV.pt', weights_only=True).to(device)

assert torch.allclose(action_rot_calV(calV, R), torch.load('assets/calV_rot.pt', weights_only=True).to(device), atol=1e-6)

#### $\mathrm{SO}(3)$-equivariance of VN layers

The $\mathrm{SO}(3)$-equivariance of the VN architecture is achieved by ensuring that each VN layer is $\mathrm{SO}(3)$-equivariant.

Given the number of input and output channels $C$ and $C'$, a VN layer $f: \mathbb{R}^{N \times C \times 3} \rightarrow \mathbb{R}^{N \times C' \times 3}$ must satisfy the following condition to be $\mathrm{SO}(3)$-equivariant:
$$f(R \cdot \mathcal{V}) = R \cdot f(\mathcal{V}) \quad \text{for all } \mathcal{V} \in \mathbb{R}^{N \times C \times 3} \text{ and } R \in \mathrm{SO}(3)$$
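The condition can be probed numerically in the same spirit as the invariance checks. A NumPy sketch with two hypothetical layers: `norm_gate` rescales each 3D vector by a function of its length (the same idea VN-BatchNorm relies on, since lengths are unchanged by rotations), while `add_bias` adds a fixed direction that does not rotate with the input and therefore breaks equivariance. The helper `is_equivariant` is also hypothetical.

```python
import numpy as np


def is_equivariant(layer, calV, R, atol=1e-6):
    """Hypothetical helper: test f(R . V) == R . f(V), with R . V = V R^T."""
    return np.allclose(layer(calV @ R.T), layer(calV) @ R.T, atol=atol)


def norm_gate(calV):
    """Rescale each 3D vector by a function of its length (lengths are rotation invariant)."""
    n = np.linalg.norm(calV, axis=-1, keepdims=True)
    return calV * np.tanh(n) / (n + 1e-6)


bias = np.array([0.0, 0.0, 1.0])


def add_bias(calV):
    """Add a fixed 3D vector: its direction does not rotate with the input."""
    return calV + bias


rng = np.random.default_rng(0)
calV = rng.normal(size=(2, 4, 5, 3))
t = np.pi / 3                          # rotation about the x-axis by 60 degrees
R = np.array([[1.0, 0.0,        0.0],
              [0.0, np.cos(t), -np.sin(t)],
              [0.0, np.sin(t),  np.cos(t)]])

print(is_equivariant(norm_gate, calV, R))   # True
print(is_equivariant(add_bias, calV, R))    # False
```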

<b>Q. Implement a function to verify whether the provided VN layer is $\mathrm{SO}(3)$-equivariant.
Use `torch.allclose` with the parameter `atol=1e-6`.</b>

In [None]:
def check_rot_equiv_VN_layer(layer, calV, R):
    """
    Parameters:
    - layer: nn.Module representing the VN layer
    - calV: torch.Tensor of shape (B, N, C, 3) representing the vector-list features
    - R: torch.Tensor of shape (3, 3) representing the rotation matrix

    Returns:
    - is_rot_equiv: bool representing whether the provided VN layer is rotation equivariant
    """

    ##############################
    ####### YOUR CODE HERE #######
    ##### is_rot_equiv = ... #####
    ##############################

    return is_rot_equiv

Check if the provided VN-BatchNorm layer is $\mathrm{SO}(3)$-equivariant.

In [None]:
EPS = 1e-6


class VNBatchNorm(nn.Module):
    def __init__(self, num_features, dim):
        super().__init__()

        self.dim = dim

        if dim == 3 or dim == 4:
            self.bn = nn.BatchNorm1d(num_features)
        elif dim == 5:
            self.bn = nn.BatchNorm2d(num_features)
    
    def forward(self, calV):
        '''
        calV: torch.Tensor of shape (B, N, C, 3) representing the vector-list features
        '''
        calV = calV.permute(0, 2, 3, 1).contiguous()

        norm = torch.norm(calV, dim=2) + EPS
        norm_bn = self.bn(norm)
        norm = norm.unsqueeze(2)
        norm_bn = norm_bn.unsqueeze(2)
        out = calV / norm * norm_bn

        out = out.permute(0, 3, 1, 2).contiguous()
        
        return out


VN_BN_layer = VNBatchNorm(64, 4).to(device)

In [None]:
is_rot_equiv = check_rot_equiv_VN_layer(VN_BN_layer, calV, R)

print(f"VN-BatchNorm layer is{' ' if is_rot_equiv else ' NOT '}rotation equivariant.")

#### Implement VN-Linear layer

For a given weight matrix $\mathbf{W} \in \mathbb{R}^{C' \times C}$, the operation of the VN-Linear layer $f_\text{lin}(\cdot;\mathbf{W})$ is defined as follows:
$$f_\text{lin}(\{ \boldsymbol{V}_1, \cdots, \boldsymbol{V}_N \}; \mathbf{W}) = \{ \mathbf{W} \boldsymbol{V}_1, \cdots, \mathbf{W} \boldsymbol{V}_N \}$$
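Because $\mathbf{W}$ only mixes channels and never the xyz axis, equivariance follows from associativity: $\mathbf{W}(\boldsymbol{V}_i R^T) = (\mathbf{W}\boldsymbol{V}_i) R^T$. A NumPy sketch of the operation on toy data (`vn_linear` and the random `W` are hypothetical). NumPy's `matmul` broadcasts the `(C', C)` matrix over the leading `(B, N)` axes, and the explicit `einsum` spells out the same contraction.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, C, C_out = 2, 4, 5, 7
calV = rng.normal(size=(B, N, C, 3))   # toy vector-list features
W = rng.normal(size=(C_out, C))        # hypothetical weight matrix W of shape (C', C)


def vn_linear(calV, W):
    """Compute W V_i for every point: (C', C) @ (C, 3) -> (C', 3), broadcast over (B, N)."""
    return W @ calV


out = vn_linear(calV, W)
out_einsum = np.einsum('oc,bncd->bnod', W, calV)   # the same contraction, spelled out

t = 0.7                                # rotation about the y-axis
R = np.array([[ np.cos(t), 0.0, np.sin(t)],
              [ 0.0,       1.0, 0.0],
              [-np.sin(t), 0.0, np.cos(t)]])

print(out.shape)                                                        # (2, 4, 7, 3)
print(np.allclose(out, out_einsum))                                     # True
print(np.allclose(vn_linear(calV @ R.T, W), vn_linear(calV, W) @ R.T))  # True
```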

<b>Q. Implement the forward function to perform the operation of the VN-Linear layer.</b>

In [None]:
class VNLinear(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.W = torch.nn.Parameter(torch.randn(out_channels, in_channels)) # torch.Tensor of shape (C', C) representing weight matrix
    
    def forward(self, calV):
        '''
        Parameters:
        - calV: torch.Tensor of shape (B, N, C, 3) representing the vector-list features

        Returns:
        - out: torch.Tensor of shape (B, N, C', 3) representing the output vector-list features
        '''
        ##############################
        ####### YOUR CODE HERE #######
        ########## out = ... #########
        ##############################

        return out


VN_Linear_layer = VNLinear(64, 128).to(device)

Check if the implemented VN-Linear layer is $\mathrm{SO}(3)$-equivariant.

In [None]:
is_rot_equiv = check_rot_equiv_VN_layer(VN_Linear_layer, calV, R)

print(f"The implemented VN-Linear layer is{' ' if is_rot_equiv else ' NOT '}rotation equivariant.")

The complete architecture of VN-DGCNN, which incorporates VNs into DGCNN, is structured as follows.

In [None]:
class VNBatchNorm2(nn.Module):
    def __init__(self, num_features, dim):
        super().__init__()

        self.dim = dim

        if dim == 3 or dim == 4:
            self.bn = nn.BatchNorm1d(num_features)
        elif dim == 5:
            self.bn = nn.BatchNorm2d(num_features)

    def forward(self, calV):
        '''
        calV: torch.Tensor of shape (B, C, 3, N, ...) representing the vector-list features
        '''
        norm = torch.norm(calV, dim=2) + EPS
        norm_bn = self.bn(norm)
        norm = norm.unsqueeze(2)
        norm_bn = norm_bn.unsqueeze(2)
        calV = calV / norm * norm_bn

        return calV


class VNLinearLeakyReLU(nn.Module):
    def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, negative_slope=0.2):
        super().__init__()

        self.dim = dim
        self.negative_slope = negative_slope

        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.batchnorm = VNBatchNorm2(out_channels, dim=dim)

        if share_nonlinearity == True:
            self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
        else:
            self.map_to_dir = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, calV):
        '''
        calV: torch.Tensor of shape (B, C, 3, N, ...) representing the vector-list features
        '''
        # Linear
        p = self.linear(calV.transpose(1, -1)).transpose(1, -1)

        # BatchNorm
        p = self.batchnorm(p)

        # LeakyReLU
        d = self.map_to_dir(calV.transpose(1, -1)).transpose(1, -1)
        dotprod = (p * d).sum(2, keepdims=True)
        mask = (dotprod >= 0).float()
        d_norm_sq = (d * d).sum(2, keepdims=True)
        x_out = self.negative_slope * p + (1 - self.negative_slope) * (mask * p + (1 - mask) * (p - (dotprod / (d_norm_sq + EPS)) * d))

        return x_out


class VNStdFeature(nn.Module):
    def __init__(self, in_channels, dim=4, normalize_frame=False, share_nonlinearity=False, negative_slope=0.2):
        super().__init__()

        self.dim = dim
        self.normalize_frame = normalize_frame

        self.vn1 = VNLinearLeakyReLU(in_channels, in_channels//2, dim=dim, share_nonlinearity=share_nonlinearity, negative_slope=negative_slope)
        self.vn2 = VNLinearLeakyReLU(in_channels//2, in_channels//4, dim=dim, share_nonlinearity=share_nonlinearity, negative_slope=negative_slope)

        self.vn_lin = nn.Linear(in_channels//4, 3, bias=False)

    def forward(self, calV):
        '''
        calV: torch.Tensor of shape (B, C, 3, N) representing the vector-list features
        '''
        z0 = calV
        z0 = self.vn1(z0)
        z0 = self.vn2(z0)
        z0 = self.vn_lin(z0.transpose(1, -1)).transpose(1, -1)

        if self.normalize_frame:
            # make z0 orthogonal. u2 = v2 - proj_u1(v2)
            v1 = z0[:,0,:]
            v1_norm = torch.sqrt((v1*v1).sum(1, keepdims=True))
            u1 = v1 / (v1_norm+EPS)
            v2 = z0[:,1,:]
            v2 = v2 - (v2*u1).sum(1, keepdims=True)*u1
            v2_norm = torch.sqrt((v2*v2).sum(1, keepdims=True))
            u2 = v2 / (v2_norm+EPS)

            # compute the cross product of the two output vectors        
            u3 = torch.cross(u1, u2, dim=1)   # u1, u2: (B, 3, ...); cross product along the xyz axis
            z0 = torch.stack([u1, u2, u3], dim=1).transpose(1, 2)
        else:
            z0 = z0.transpose(1, 2)

        if self.dim == 4:
            x_std = torch.einsum('bijm,bjkm->bikm', calV, z0)
        elif self.dim == 3:
            x_std = torch.einsum('bij,bjk->bik', calV, z0)
        elif self.dim == 5:
            x_std = torch.einsum('bijmn,bjkmn->bikmn', calV, z0)

        return x_std, z0


class VN_DGCNN_cls(nn.Module):
    def __init__(self, k, num_classes):
        super().__init__()

        self.k = k

        self.conv1 = VNLinearLeakyReLU(2, 64//3)
        self.conv2 = VNLinearLeakyReLU(64//3*2, 64//3)
        self.conv3 = VNLinearLeakyReLU(64//3*2, 128//3)
        self.conv4 = VNLinearLeakyReLU(128//3*2, 256//3)

        self.conv5 = VNLinearLeakyReLU(256//3+128//3+64//3*2, 1024//3, dim=4, share_nonlinearity=True)

        self.std_feature = VNStdFeature(1024//3*2, dim=4, normalize_frame=False)
        self.linear1 = nn.Linear((1024//3)*12, 512)

        self.bn1 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = x.T.unsqueeze(0)

        batch_size = x.size(0)
        x = x.unsqueeze(1)
        x = get_graph_feature(x, k=self.k)
        x = self.conv1(x)
        x1 = x.mean(dim=-1, keepdim=False)

        x = get_graph_feature(x1, k=self.k)
        x = self.conv2(x)
        x2 = x.mean(dim=-1, keepdim=False)

        x = get_graph_feature(x2, k=self.k)
        x = self.conv3(x)
        x3 = x.mean(dim=-1, keepdim=False)

        x = get_graph_feature(x3, k=self.k)
        x = self.conv4(x)
        x4 = x.mean(dim=-1, keepdim=False)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv5(x)

        num_points = x.size(-1)
        x_mean = x.mean(dim=-1, keepdim=True).expand(x.size())
        x = torch.cat((x, x_mean), 1)
        x, _ = self.std_feature(x)
        x = x.view(batch_size, -1, num_points)

        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((x1, x2), 1)

        x = F.leaky_relu(self.bn1(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn2(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        x = self.linear3(x)

        return x


def get_graph_feature(x, k=20, idx=None, x_coord=None):
    batch_size = x.size(0)
    num_points = x.size(3)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        if x_coord is None: # dynamic knn graph
            idx = knn(x, k=k)
        else:          # fixed knn graph with input point coordinates
            idx = knn(x_coord, k=k)

    idx_base = torch.arange(0, batch_size).to(idx).view(-1, 1, 1)*num_points

    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()
    num_dims = num_dims // 3

    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size*num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims, 3) 
    x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1)

    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous()

    return feature
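Every VN layer above is rotation equivariant, yet classification logits must be rotation invariant. `VNStdFeature` bridges the gap: it predicts a few equivariant vectors `z0` from the features and contracts the features against them (`x_std`). A simplified NumPy sketch for a single point, ignoring batching, the LeakyReLU layers, and frame normalization (the names `std_feature` and `M` are hypothetical), shows why the contraction is invariant:

```python
import numpy as np

rng = np.random.default_rng(0)
C = 6
V = rng.normal(size=(C, 3))            # equivariant features of one point: C vectors as rows
M = rng.normal(size=(3, C))            # hypothetical channel-mixing weights (the role of vn_lin)


def std_feature(V):
    """Contract the features against 3 equivariant vectors built from them."""
    Z = M @ V                          # (3, 3): rows rotate with the input, like z0
    return V @ Z.T                     # (C, 3): inner products <V_c, Z_k>


t = 1.1                                # rotation about the z-axis
R = np.array([[np.cos(t), -np.sin(t), 0.0],
              [np.sin(t),  np.cos(t), 0.0],
              [0.0,        0.0,       1.0]])

# (V R^T)(M V R^T)^T = V R^T R V^T M^T = V (M V)^T, because R^T R = I.
print(np.allclose(std_feature(V @ R.T), std_feature(V)))   # True
```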

Build the VN-DGCNN model.

In [None]:
model_3 = VN_DGCNN_cls(k=20, num_classes=num_classes).to(device)
model_3.eval()

Check if VN-DGCNN is permutation, translation, and rotation invariant.

In [None]:
is_perm_inv = check_perm_inv_model(model_3, x, sigma)
is_trans_inv = check_trans_inv_model(model_3, x, p)
is_rot_inv = check_rot_inv_model(model_3, x, R)

print(f"VN-DGCNN is{' ' if is_perm_inv else ' NOT '}permutation invariant.")
print(f"VN-DGCNN is{' ' if is_trans_inv else ' NOT '}translation invariant.")
print(f"VN-DGCNN is{' ' if is_rot_inv else ' NOT '}rotation invariant.")

<hr>

### Final model with permutation, translation, and rotation invariances

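The final model combines permutation, translation, and rotation invariances by adding $\mathbb{R}^3$-invariance to the VN-DGCNN model through canonicalization. This preserves the invariances VN-DGCNN already has provided that $\bar{p}$ is also permutation-invariant and rotation-equivariant, i.e., $\bar{p}(\sigma \cdot x) = \bar{p}(x)$ and $\bar{p}(R \cdot x) = R\,\bar{p}(x)$.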
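Whether canonicalization keeps the invariances of the underlying model depends on $\bar{p}$. A NumPy sketch with a hypothetical permutation- and rotation-invariant toy model (sorted point norms, which is not translation invariant) compares two translation-equivariant canonicalizers under a combined rotation, translation, and permutation: the centroid, which also rotates with the input, and the axis-aligned bounding-box center, which does not, because per-axis minima and maxima change under rotation. All names are illustrative.

```python
import numpy as np


def toy_model(x):
    """Permutation- and rotation-invariant, but not translation-invariant: sorted point norms."""
    return np.sort(np.linalg.norm(x, axis=1))


def centroid(x):
    return x.mean(axis=0)


def bbox_center(x):
    """Translation-equivariant, but per-axis extrema do not rotate with the points."""
    return 0.5 * (x.min(axis=0) + x.max(axis=0))


def canonicalize(model, p_bar):
    return lambda pts: model(pts - p_bar(pts))


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 3))
t = 0.9                                # rotation about the z-axis
R = np.array([[np.cos(t), -np.sin(t), 0.0],
              [np.sin(t),  np.cos(t), 0.0],
              [0.0,        0.0,       1.0]])
p = np.array([2.0, -1.0, 0.5])
sigma = rng.permutation(20)
g_x = (x @ R.T + p)[sigma]             # rotate, translate, then reorder the points

for name, p_bar in [("centroid", centroid), ("bbox center", bbox_center)]:
    F_hat = canonicalize(toy_model, p_bar)
    print(name, np.allclose(F_hat(g_x), F_hat(x)))
```

With the centroid the check passes; with the bounding-box center it typically fails, because in general $\bar{p}(R \cdot x) \neq R\,\bar{p}(x)$ for that choice.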

In [None]:
model_4 = ModelWithCanonicalization(model_3)

Check if VN-DGCNN with canonicalization is permutation, translation, and rotation invariant.

In [None]:
is_perm_inv = check_perm_inv_model(model_4, x, sigma)
is_trans_inv = check_trans_inv_model(model_4, x, p)
is_rot_inv = check_rot_inv_model(model_4, x, R)

print(f"VN-DGCNN with canonicalization is{' ' if is_perm_inv else ' NOT '}permutation invariant.")
print(f"VN-DGCNN with canonicalization is{' ' if is_trans_inv else ' NOT '}translation invariant.")
print(f"VN-DGCNN with canonicalization is{' ' if is_rot_inv else ' NOT '}rotation invariant.")

<hr>

### Advantages of Equivariant Models

<b>Q. In general, what advantages do equivariant models have compared to non-equivariant models?</b>

### References
[1] A. X. Chang, T. Funkhouser, L. Guibas, P. Hanrahan, Q. Huang, Z. Li, S. Savarese, M. Savva, S. Song, H. Su, J. Xiao, L. Yi, F. Yu, ShapeNet: An Information-Rich 3D Model Repository, arXiv:1512.03012, 2015.

[2] Y. Wang, Y. Sun, Z. Liu, S. E. Sarma, M. M. Bronstein, J. M. Solomon, Dynamic Graph CNN for Learning on Point Clouds, TOG 2019.

[3] C. Deng, O. Litany, Y. Duan, A. Poulenard, A. Tagliasacchi, L. Guibas, Vector Neurons: A General Framework for SO(3)-Equivariant Networks, ICCV 2021.