### Masked Autoencoders for 3D point cloud self-supervised learning

##### High-level notes on the concept

Point-BERT proposes a BERT-style pre-training strategy: it masks input tokens of the point cloud and then uses a transformer architecture to predict the discrete tokens of the masked parts.
However, this method is relatively complex. It requires training a DGCNN-based discrete Variational AutoEncoder (dVAE) before pre-training, and it relies heavily on contrastive learning and data augmentation during pre-training.

Moreover, Point-BERT processes the mask tokens from the very input of its transformer during pre-training, which leads to early leakage of location information and high consumption of computing resources. Unlike their method, and more importantly, to bring masked autoencoding to point clouds, we aim to design a neat and efficient masked autoencoder scheme. To this end, we first analyse the main challenges of introducing masked autoencoding to point clouds from the following aspects:

1. The first challenge is the lack of a unified transformer architecture. Compared to the transformer in NLP and the Vision Transformer (ViT) in computer vision, transformer architectures for point clouds are less studied and relatively diverse, mainly because small datasets cannot meet the large data demands of transformers. Unlike previous methods that use dedicated transformers or adopt extra non-transformer models to assist (e.g., **Point-BERT** uses an extra **DGCNN**), we aim to build our autoencoder's backbone entirely on the standard transformer, which can serve as a potential unified architecture for point clouds. This also makes it easier for point cloud models to join general multi-modality frameworks, such as **Data2vec**.

2. Positional embeddings for mask tokens lead to **leakage of location** information. In masked autoencoders, each masked part is replaced by a shared-weight learnable mask token. Every mask token must be given the location of its masked part in the input data through positional embeddings. After processing by the autoencoder, each mask token is then used to reconstruct the corresponding masked part. Providing location information is not an issue for languages and images, because their content does not encode positions: the location of a word or pixel is given only by its index in the sequence or grid. Point clouds, however, carry location information in the data itself, so leaking it to the mask tokens makes the reconstruction task less challenging, which harms the autoencoder's learning of latent features. We address this issue by shifting the mask tokens from the input of the autoencoder's encoder to the input of the autoencoder's decoder. This delays the leakage of location information and lets the encoder focus on learning features from the unmasked parts.

3. Point clouds carry information at a different density compared to languages and images. Languages contain high-density information, while images contain heavily redundant information.[12] In a point cloud, the information density distribution is relatively uneven. The points that make up key local features (e.g., sharp corners and edges) contain a much higher density of information than the points that make up less important local features (e.g., flat surfaces). In other words, when masked, the points that contain high-density information are more difficult to recover in the reconstruction task. This can be directly observed in the reconstruction examples shown in Figure 2. Taking the last row of Figure 2 as an illustration, the masked desk surface (left) is easily recovered, while the reconstruction of the masked motorcycle's wheel (right) is much worse. Although point clouds contain unevenly distributed information, we find that random masking at a high ratio (60–80%) works well, which, surprisingly, is the same as for images. This indicates that, in terms of information density, point clouds are more similar to images than to languages.
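
Points 2 and 3 can be illustrated together with a minimal NumPy sketch, assuming the patches have already been embedded into tokens and their centers into positional embeddings. `split_for_point_mae` is a hypothetical helper, not code from Point-MAE: it randomly masks a high ratio of patches, builds the encoder input from the visible patches only, and prepares mask tokens that receive their positions only on the decoder side.

```python
import numpy as np

def split_for_point_mae(tokens, pos, mask_token, mask_ratio, rng):
    """Hypothetical helper: build encoder input and decoder-side mask tokens.

    tokens:     (M, C) embedded point patches
    pos:        (M, C) positional embeddings of the M patch centers
    mask_token: (C,)   one shared learnable vector
    """
    num_patches = tokens.shape[0]
    num_masked = int(num_patches * mask_ratio)
    perm = rng.permutation(num_patches)  # random masking: shuffle patch indices
    masked_idx, visible_idx = perm[:num_masked], perm[num_masked:]

    # Encoder input: visible patches only; masked positions never appear here
    encoder_in = tokens[visible_idx] + pos[visible_idx]

    # Decoder-side mask tokens: identical vectors that differ only by the
    # positional embedding of the center each one must reconstruct
    mask_in = mask_token[None, :] + pos[masked_idx]
    return encoder_in, mask_in, visible_idx, masked_idx


rng = np.random.default_rng(0)
M, C = 64, 8                    # illustrative sizes
tokens = rng.normal(size=(M, C))
pos = rng.normal(size=(M, C))
mask_token = np.zeros(C)

enc_in, mask_in, vis, msk = split_for_point_mae(tokens, pos, mask_token, 0.6, rng)
print(enc_in.shape, mask_in.shape)  # (26, 8) (38, 8)
```

A decoder would then concatenate the encoder outputs for the visible patches with `mask_in`, so location information about masked patches enters only after the encoder has finished. With a ratio of 0.6 the encoder processes 26 of 64 tokens, and because self-attention cost grows quadratically with sequence length, this also reduces encoder compute.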

<img src=./images/Point_MAE.png width=650 style="display:block; margin:auto;">

**FPS**:
The point cloud is first processed with FPS (Farthest Point Sampling), which selects a subset of points (called centers) that are well distributed across the point cloud: starting from one point, it repeatedly adds the point farthest from all centers chosen so far. This reduces the number of points while preserving the overall structure.
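
The greedy FPS loop can be sketched in a few lines of NumPy, assuming a single `(N, 3)` point cloud and a fixed starting index. `farthest_point_sampling` is a hypothetical name for illustration; the PyTorch code further down calls `pytorch3d.ops.sample_farthest_points` instead.

```python
import numpy as np

def farthest_point_sampling(points, num_centers, start=0):
    """Hypothetical helper: greedy FPS over a single (N, 3) point cloud.

    Returns the indices of num_centers selected points.
    """
    n = points.shape[0]
    selected = np.empty(num_centers, dtype=np.int64)
    selected[0] = start
    # Squared distance from each point to its nearest already-selected center
    min_d2 = np.full(n, np.inf)
    for i in range(1, num_centers):
        newest = points[selected[i - 1]]
        min_d2 = np.minimum(min_d2, np.sum((points - newest) ** 2, axis=1))
        selected[i] = np.argmax(min_d2)  # farthest point from the current center set
    return selected


# Corners of a unit square plus a near-duplicate of the first corner
pts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0.1, 0, 0]], dtype=float)
idx = farthest_point_sampling(pts, 4)
print(idx.tolist())  # [0, 3, 1, 2]
```

The near-duplicate point at index 4 is never chosen, which is the property that spreads centers across the shape instead of clustering them. Each iteration only updates one distance array, so the cost is O(N × number of centers).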

**KNN**: Using KNN (K-Nearest Neighbors), the point cloud is grouped into point patches. Each patch consists of the K points nearest to one of the FPS-selected centers, so neighboring patches may overlap. For example, for a selected center point, KNN finds its K nearest neighbors in the point cloud, forming a small local patch of points.
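
A brute-force NumPy sketch of the KNN grouping step follows, assuming the centers have already been chosen (here two points are picked by hand in place of FPS). `knn_group` is a hypothetical helper; the final subtraction of each center, giving patch-local coordinates, is an optional normalization shown for illustration.

```python
import numpy as np

def knn_group(points, centers, k):
    """Hypothetical helper: gather the k nearest points around each center.

    points: (N, 3), centers: (M, 3). Returns patches (M, k, 3) and indices (M, k).
    """
    # Squared distances between every center and every point: (M, N)
    d2 = np.sum((centers[:, None, :] - points[None, :, :]) ** 2, axis=-1)
    idx = np.argsort(d2, axis=1)[:, :k]  # k nearest points, closest first
    return points[idx], idx              # points[idx] has shape (M, k, 3)


pts = np.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.2, 0.0],
                [5.0, 5.0, 5.0], [5.1, 5.0, 5.0]])
centers = pts[[0, 3]]                  # pretend FPS picked points 0 and 3
patches, idx = knn_group(pts, centers, k=2)
local = patches - centers[:, None, :]  # optional: coordinates relative to each center
print(idx.tolist())  # [[0, 1], [3, 4]]
```

The full `(M, N)` distance matrix is fine for a sketch but grows with both the number of centers and points; GPU libraries provide batched KNN for realistic sizes.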

After the decoder, a **one-layer prediction head** (a single fully connected layer) maps each decoded mask token to a vector of $k \times 3$ values, which is reshaped into $k$ points with 3D coordinates, where $k$ is the number of points per patch.
The output of the prediction head is the prediction (Pred) of the masked point patches, i.e., the coordinates of the points in each masked patch.
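
The head is just one affine map followed by a reshape. A minimal NumPy sketch, assuming illustrative sizes (38 mask tokens, a 64-dimensional embedding, and k = 32 points per patch, none of which are prescribed values) and random weights in place of trained ones:

```python
import numpy as np

rng = np.random.default_rng(0)
num_masked, embed_dim, k = 38, 64, 32  # illustrative sizes only

decoded_mask_tokens = rng.normal(size=(num_masked, embed_dim))  # decoder outputs for mask tokens
W = rng.normal(scale=0.02, size=(embed_dim, k * 3))             # weight of the single linear layer
b = np.zeros(k * 3)                                             # bias

flat = decoded_mask_tokens @ W + b     # (num_masked, k * 3)
pred = flat.reshape(num_masked, k, 3)  # k predicted (x, y, z) points per masked patch
print(pred.shape)  # (38, 32, 3)
```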

The loss function is the **Chamfer Distance (CD)** with the $\ell_2$ norm.
Chamfer Distance measures the distance between two point sets (prediction and ground truth) by matching each point to its nearest neighbor in the other set, in both directions. Because it relies on nearest-point correspondences rather than a fixed ordering, it does not depend on the order of points within a patch.
Minimizing the $\ell_2$ version pulls the predicted points toward the ground-truth points in terms of Euclidean distance.
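
For reference, the standard $\ell_2$ Chamfer distance between a predicted patch $P$ and its ground-truth patch $G$ is commonly written as

$$
\mathcal{L}_{CD}(P, G) = \frac{1}{|P|}\sum_{a \in P}\min_{b \in G}\|a - b\|_2^2 + \frac{1}{|G|}\sum_{b \in G}\min_{a \in P}\|a - b\|_2^2
$$

Below is a minimal NumPy sketch of this formula for a single pair of patches; `chamfer_l2` is a hypothetical helper name, and the brute-force $|P| \times |G|$ distance matrix is only practical for small patches.

```python
import numpy as np

def chamfer_l2(pred, gt):
    """Hypothetical helper: symmetric Chamfer distance with squared l2 distances.

    pred: (P, 3) predicted points, gt: (G, 3) ground-truth points.
    """
    d2 = np.sum((pred[:, None, :] - gt[None, :, :]) ** 2, axis=-1)  # (P, G)
    pred_to_gt = d2.min(axis=1).mean()  # each predicted point -> nearest gt point
    gt_to_pred = d2.min(axis=0).mean()  # each gt point -> nearest predicted point
    return pred_to_gt + gt_to_pred


gt = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
pred = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.5, 0.0]])
print(chamfer_l2(pred, gt))  # equals 0.25 / 3: only the extra predicted point is penalized
```

The extra predicted point at (1, 0.5, 0) contributes only through the prediction-to-ground-truth term; the other direction is zero because every ground-truth point has an exact match. Batched implementations such as `pytorch3d.loss.chamfer_distance` compute the same kind of quantity over many patches at once, but their reduction options are worth checking in the library documentation.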

```python
import torch
import torch.nn as nn
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import sample_farthest_points, knn_points

# Hyperparameters
NUM_POINTS = 2048  # Total points in the point cloud
NUM_PATCHES = 128  # Number of patches (FPS centers)
POINTS_PER_PATCH = 16  # Points per patch; here NUM_POINTS / NUM_PATCHES, although KNN patches may overlap
MASK_RATIO = 0.8  # Mask 80% of patches
EMBED_DIM = 64  # Embedding dimension for tokens
NUM_HEADS = 4  # Number of attention heads in Transformer
NUM_LAYERS = 2  # Number of Transformer encoder layers


def gather_patches(x, idx):
    """Select entries of x (batch_size, num_patches, ...) along dim 1 using idx (batch_size, n)."""
    batch_idx = torch.arange(x.shape[0], device=x.device).unsqueeze(1)  # (batch_size, 1)
    return x[batch_idx, idx]  # (batch_size, n, ...)


class PointMAE(nn.Module):
    def __init__(self):
        super().__init__()

        # Token embedding: a simple linear layer over the flattened patch coordinates
        self.patch_embed = nn.Linear(POINTS_PER_PATCH * 3, EMBED_DIM)

        # Positional embeddings computed from patch center coordinates
        self.pos_embed = nn.Linear(3, EMBED_DIM)          # encoder side
        self.decoder_pos_embed = nn.Linear(3, EMBED_DIM)  # decoder side

        # One shared learnable mask token
        self.mask_token = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))

        # Encoder: standard Transformer over visible tokens only
        encoder_layer = nn.TransformerEncoderLayer(d_model=EMBED_DIM, nhead=NUM_HEADS, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)

        # Decoder: lightweight self-attention Transformer over visible + mask tokens
        decoder_layer = nn.TransformerEncoderLayer(d_model=EMBED_DIM, nhead=NUM_HEADS, batch_first=True)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=1)

        # One-layer prediction head: each decoded mask token -> POINTS_PER_PATCH xyz points
        self.pred_head = nn.Linear(EMBED_DIM, POINTS_PER_PATCH * 3)

    def patchify(self, point_cloud):
        """
        Divide point cloud into patches using FPS and KNN.
        Input: point_cloud (batch_size, num_points, 3)
        Output: patches (batch_size, num_patches, points_per_patch, 3) in coordinates
                relative to their centers, centers (batch_size, num_patches, 3)
        """
        # FPS selects well-spread patch centers
        centers, _ = sample_farthest_points(point_cloud, K=NUM_PATCHES)

        # KNN groups the POINTS_PER_PATCH nearest points around each center;
        # return_nn=True also returns the neighbor coordinates
        knn = knn_points(centers, point_cloud, K=POINTS_PER_PATCH, return_nn=True)
        patches = knn.knn  # (batch_size, num_patches, points_per_patch, 3)

        # Normalize each patch by its center, following the Point-MAE reference code
        patches = patches - centers.unsqueeze(2)
        return patches, centers

    def mask_patches(self, batch_size, device):
        """
        Randomly split patch indices into masked and visible sets.
        Output: visible_idx (batch_size, num_visible), masked_idx (batch_size, num_masked)
        """
        num_masked = int(NUM_PATCHES * MASK_RATIO)
        # An independent random permutation of patch indices for each sample
        shuffled = torch.rand(batch_size, NUM_PATCHES, device=device).argsort(dim=1)
        return shuffled[:, num_masked:], shuffled[:, :num_masked]

    def forward(self, point_cloud):
        """
        Forward pass of Point-MAE.
        Input: point_cloud (batch_size, num_points, 3)
        Output: reconstructed masked patches (batch_size, num_masked, points_per_patch, 3), loss
        """
        batch_size = point_cloud.shape[0]

        # Step 1: Patchify the point cloud
        patches, centers = self.patchify(point_cloud)

        # Step 2: Randomly mask patches
        visible_idx, masked_idx = self.mask_patches(batch_size, point_cloud.device)
        num_visible, num_masked = visible_idx.shape[1], masked_idx.shape[1]
        visible_centers = gather_patches(centers, visible_idx)  # (batch_size, num_visible, 3)
        masked_centers = gather_patches(centers, masked_idx)    # (batch_size, num_masked, 3)

        # Step 3: Embed only the visible patches and add their positional embeddings
        visible_patches = gather_patches(patches, visible_idx)  # (batch_size, num_visible, points_per_patch, 3)
        tokens = self.patch_embed(visible_patches.flatten(2))   # (batch_size, num_visible, embed_dim)
        tokens = tokens + self.pos_embed(visible_centers)

        # Step 4: The encoder sees visible tokens only, so no masked-patch locations leak into it
        encoded = self.encoder(tokens)  # (batch_size, num_visible, embed_dim)

        # Step 5: Append mask tokens and add positional embeddings for all centers at the decoder input
        # (added once at the input for simplicity)
        mask_tokens = self.mask_token.expand(batch_size, num_masked, -1)
        decoder_input = torch.cat([encoded, mask_tokens], dim=1)  # (batch_size, num_patches, embed_dim)
        decoder_pos = self.decoder_pos_embed(torch.cat([visible_centers, masked_centers], dim=1))
        decoded = self.decoder(decoder_input + decoder_pos)

        # Step 6: Prediction head reconstructs the masked patches from the mask-token outputs
        recon_masked = self.pred_head(decoded[:, num_visible:])  # (batch_size, num_masked, points_per_patch * 3)
        recon_masked = recon_masked.reshape(batch_size, num_masked, POINTS_PER_PATCH, 3)
        if num_masked == 0:  # Nothing to reconstruct
            return recon_masked, point_cloud.new_zeros(())

        # Step 7: Chamfer Distance (squared l2) between each predicted and ground-truth masked patch
        gt_masked = gather_patches(patches, masked_idx)
        loss, _ = chamfer_distance(
            recon_masked.reshape(-1, POINTS_PER_PATCH, 3),
            gt_masked.reshape(-1, POINTS_PER_PATCH, 3),
        )
        return recon_masked, loss


# Example usage
if __name__ == "__main__":
    # Dummy point cloud (batch_size, num_points, 3)
    point_cloud = torch.rand(2, NUM_POINTS, 3)

    # Initialize model
    model = PointMAE()

    # Forward pass
    recon_patches, loss = model(point_cloud)
    print(recon_patches.shape)  # torch.Size([2, 102, 16, 3])
    print(f"Chamfer Distance Loss: {loss.item()}")
```