### Masked Autoencoders for 3D point cloud self-supervised learning

##### High-level notes on the concept

Point-BERT proposes a BERT-style pre-training strategy: it masks input tokens of the point cloud and then adopts a transformer architecture to predict discrete tokens for the masked parts.
However, this method is relatively complicated, as it requires training a DGCNN-based discrete Variational AutoEncoder (dVAE) before pre-training, and it relies heavily on contrastive learning as well as data augmentation during pre-training.

Moreover, the mask tokens are fed in at the input of the transformer during pre-training, leading to early leakage of location information and high consumption of computing resources. Different from their method, and more importantly, to introduce masked autoencoding to the point cloud, we aim to design a neat and efficient scheme of masked autoencoders. To this end, we first analyse the main challenges of introducing masked autoencoding for point cloud from the following aspects:

1. The first is the lack of a unified transformer architecture. Compared to transformers in NLP and the Vision Transformer (ViT) in computer vision, transformer architectures for point cloud are less studied and relatively diverse, mainly because small datasets cannot meet the large data demand of transformers. Different from previous methods that use dedicated transformers or adopt extra non-transformer models to assist (such as **Point-BERT**, which uses an extra **DGCNN**), we aim to build our autoencoder’s backbone entirely based on standard transformers, which can serve as a potential unified architecture for point cloud. This also enables further development for point cloud to join general multi-modality frameworks, such as **Data2vec**.

2. Positional embeddings for mask tokens lead to **leakage of location** information. In masked autoencoders, each masked part
is replaced by a share-weighted learnable mask token. All the
mask tokens need to be provided with their location information in input data by positional embeddings. Then after
processing by autoencoders, each mask token is used to reconstruct the corresponding masked part. Providing location information is not an issue for languages and images because they
do not contain location information. While point cloud naturally has location information in the data, leakage of location
information to mask tokens makes the reconstruction task less
challenging, which is harmful for autoencoders learning latent
features. We address this issue by shifting mask tokens from the
input of the autoencoder’s encoder to the input of the autoencoder’s decoder. This delays the leakage of location information and enables the encoder to focus on learning features from
unmasked parts.

3. Point cloud carries information at a different density compared to languages and images. Languages contain high-density information, while images contain heavily redundant information. In the point cloud, information density distribution is relatively uneven. The points that make up key local features (e.g., sharp corners and edges) contain a much higher density of information than the points that make up less important local features (e.g., flat surfaces). In other words, if masked, the points that contain high-density information are more difficult to recover in the reconstruction task. This can be directly observed in reconstruction examples, as shown in Figure 2. Taking the last row of Figure 2 for illustration, the masked desk surface (left) can be easily recovered, while the reconstruction of the masked motorcycle’s wheel (right) is much worse. Although the point cloud contains uneven density of information, we find that random masking at a high ratio (60–80%) works well, which is surprisingly the same as images. This indicates the point cloud is similar to images instead of languages, in terms of information density.
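The random masking at a high ratio mentioned in the third point can be sketched as below. This is a minimal illustration rather than the authors' code: the function name `random_patch_mask` is hypothetical, and the 0.6 ratio is simply picked from the 60–80% range quoted above.

```python
import torch

def random_patch_mask(batch_size, num_patches, mask_ratio=0.6):
    """Return a bool mask (batch_size, num_patches); True marks a masked patch."""
    num_masked = int(num_patches * mask_ratio)
    # One random score per patch; the patches with the smallest scores get masked.
    scores = torch.rand(batch_size, num_patches)
    masked_idx = scores.argsort(dim=1)[:, :num_masked]
    mask = torch.zeros(batch_size, num_patches, dtype=torch.bool)
    mask.scatter_(1, masked_idx, True)
    return mask

mask = random_patch_mask(batch_size=2, num_patches=64, mask_ratio=0.6)
num_masked_per_sample = mask.sum(dim=1)      # int(64 * 0.6) = 38 masked patches each
num_visible_per_sample = (~mask).sum(dim=1)  # the remaining 26 are all the encoder sees
```

Because the mask tokens are only introduced at the decoder, the encoder's sequence length is the visible count, not the full patch count.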

<img src=./images/Point_MAE.png width=650 style="display:block; margin:auto;">

**FPS**:
The point cloud is processed using FPS (Farthest Point Sampling), a technique to select a subset of points (called centers) that are well-distributed across the point cloud. This reduces the number of points while preserving the overall structure.

**KNN**: Using KNN (K-Nearest Neighbors), the point cloud is divided into point patches. Each patch is a group of points centered around one of the FPS-selected centers. For example, if a center point is selected, KNN finds its K nearest neighbors in the point cloud, forming a small local patch of points.
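The FPS and KNN steps described above can be sketched in plain PyTorch, without extra libraries. This is a simple single-cloud version for illustration only. The function names are hypothetical, and starting FPS from the first point is an assumption (implementations often start from a random point).

```python
import torch

def farthest_point_sampling(points, num_centers):
    """points: (N, 3). Returns indices (num_centers,) of well-spread points."""
    n = points.shape[0]
    selected = torch.zeros(num_centers, dtype=torch.long)
    # Distance from every point to its nearest already-selected center.
    min_dist = torch.full((n,), float("inf"))
    current = 0  # assumption: start from the first point
    for i in range(num_centers):
        selected[i] = current
        dist = ((points - points[current]) ** 2).sum(dim=1)
        min_dist = torch.minimum(min_dist, dist)
        # Next center: the point farthest from all centers chosen so far.
        current = int(min_dist.argmax())
    return selected

def knn_patches(points, centers, k):
    """Group the k nearest points around each center: (num_centers, k, 3)."""
    dists = torch.cdist(centers, points)               # (num_centers, N)
    idx = dists.topk(k, dim=1, largest=False).indices  # (num_centers, k)
    return points[idx]

torch.manual_seed(0)
cloud = torch.rand(1024, 3)
center_idx = farthest_point_sampling(cloud, num_centers=64)
centers = cloud[center_idx]                  # (64, 3)
patches = knn_patches(cloud, centers, k=32)  # (64, 32, 3)
```

Note that neighbouring patches may share points, since KNN is run independently for each center.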

After the decoder, a **one-layer prediction head** (a simple neural network layer) is used to map the decoder’s output to the desired format.
The output of the prediction head is a prediction (Pred) of the masked point patches. Specifically, it predicts the coordinates or features of the points in the masked patches.

The loss function used is the **Chamfer Distance (CD)** with an ${l}_2$ norm:
Chamfer Distance measures the distance between two point sets (predicted and ground truth) by finding the nearest point correspondences.
The $l_2$ norm ensures the predicted points are as close as possible to the ground truth points in terms of Euclidean distance.

In [None]:
# based on Point-MAE: Masked Autoencoders for Point Clouds paper
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import sample_farthest_points, knn_points

# Hyperparameters
NUM_POINTS = 2048  # Total points in the point cloud
NUM_PATCHES = 128  # Number of patches
POINTS_PER_PATCH = 16  # Points per patch (NUM_POINTS / NUM_PATCHES)
MASK_RATIO = 0.8  # Mask 80% of patches
EMBED_DIM = 64  # Embedding dimension for tokens
NUM_HEADS = 4  # Number of attention heads in Transformer
NUM_LAYERS = 2  # Number of Transformer layers

class PointMAE(nn.Module):
    def __init__(self):
        super(PointMAE, self).__init__()
        
        # Token embedding layer for point patches
        self.patch_embed = nn.Linear(POINTS_PER_PATCH * 3, EMBED_DIM)
        
        # Positional encoding for patch centers
        self.pos_embed = nn.Linear(3, EMBED_DIM)
        
        # Mask token (learnable parameter)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))
        
        # Encoder: Transformer for visible tokens
        encoder_layer = nn.TransformerEncoderLayer(d_model=EMBED_DIM, nhead=NUM_HEADS)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)
        
        # Decoder: Lightweight Transformer
        decoder_layer = nn.TransformerDecoderLayer(d_model=EMBED_DIM, nhead=NUM_HEADS)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
        
        # Prediction head to reconstruct point patches
        self.pred_head = nn.Linear(EMBED_DIM, POINTS_PER_PATCH * 3)

    def patchify(self, point_cloud):
        """
        Divide point cloud into patches using FPS and KNN.
        Input: point_cloud (batch_size, num_points, 3)
        Output: patches (batch_size, num_patches, points_per_patch, 3), centers (batch_size, num_patches, 3)
        """
        batch_size = point_cloud.shape[0]
        
        # FPS to select patch centers
        centers, _ = sample_farthest_points(point_cloud, K=NUM_PATCHES)
        
        # KNN to group points into patches
        _, idx, _ = knn_points(centers, point_cloud, K=POINTS_PER_PATCH)
        
        # Gather points for each patch
        patches = torch.gather(
            point_cloud.unsqueeze(1).expand(-1, NUM_PATCHES, -1, -1),
            2,
            idx.unsqueeze(-1).expand(-1, -1, -1, 3)
        )  # (batch_size, num_patches, points_per_patch, 3)
        
        return patches, centers

    def mask_patches(self, patches):
        """
        Randomly mask patches.
        Input: patches (batch_size, num_patches, points_per_patch, 3)
        Output: mask (batch_size, num_patches), bool; True = visible patch, False = masked patch
        """
        batch_size = patches.shape[0]
        num_patches = patches.shape[1]
        num_masked = int(num_patches * MASK_RATIO)
        
        # Randomly shuffle patch indices
        idx = torch.rand(batch_size, num_patches, device=patches.device).argsort(dim=1)
        mask = torch.ones(batch_size, num_patches, device=patches.device, dtype=torch.bool)
        mask.scatter_(1, idx[:, :num_masked], 0)  # Set masked patches to 0
        
        return mask

    def forward(self, point_cloud):
        """
        Forward pass of Point-MAE.
        Input: point_cloud (batch_size, num_points, 3)
        Output: reconstructed patches, loss
        """
        batch_size = point_cloud.shape[0]
        
        # Step 1: Patchify the point cloud
        patches, centers = self.patchify(point_cloud)  # (batch_size, num_patches, points_per_patch, 3)
        
        # Step 2: Mask patches
        mask = self.mask_patches(patches)  # (batch_size, num_patches)
        
        # Step 3: Embed patches into tokens
        patches_flat = patches.view(batch_size, NUM_PATCHES, -1)  # (batch_size, num_patches, points_per_patch * 3)
        tokens = self.patch_embed(patches_flat)  # (batch_size, num_patches, embed_dim)
        
        # Add positional encoding based on patch centers
        pos_enc = self.pos_embed(centers)  # (batch_size, num_patches, embed_dim)
        tokens = tokens + pos_enc
        
        # Step 4: Keep only the visible tokens for the encoder
        # (every sample has the same number of visible patches, so the view is valid)
        visible_tokens = tokens[mask].view(batch_size, -1, EMBED_DIM)  # (batch_size, num_visible, embed_dim)
        
        # Step 5: Encoder processes visible tokens
        encoded = self.encoder(visible_tokens.transpose(0, 1)).transpose(0, 1)  # (batch_size, num_visible, embed_dim)
        
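        # Step 6: Prepare decoder input. Masked positions get the shared mask token
        # plus their positional encoding; without the positional term every masked
        # position would be identical and the decoder could not tell the patches apart.
        mask_tokens = self.mask_token + pos_enc  # (batch_size, num_patches, embed_dim)
        decoder_input = torch.where(mask.unsqueeze(-1), tokens, mask_tokens)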
        
        # Step 7: Decoder reconstructs masked patches
        decoded = self.decoder(
            tgt=decoder_input.transpose(0, 1),
            memory=encoded.transpose(0, 1)
        ).transpose(0, 1)  # (batch_size, num_patches, embed_dim)
        
        # Step 8: Prediction head reconstructs patches
        recon_patches = self.pred_head(decoded)  # (batch_size, num_patches, points_per_patch * 3)
        recon_patches = recon_patches.view(batch_size, NUM_PATCHES, POINTS_PER_PATCH, 3)
        
        # Step 9: Compute Chamfer Distance loss (only for masked patches)
        masked_patches = patches[mask == 0]  # Ground truth masked patches
        recon_masked = recon_patches[mask == 0]  # Reconstructed masked patches
        if masked_patches.shape[0] > 0:  # Ensure there are masked patches
            loss, _ = chamfer_distance(recon_masked, masked_patches)
        else:
            loss = torch.tensor(0.0, device=point_cloud.device)
        
        return recon_patches, loss

# Example usage
if __name__ == "__main__":
    # Dummy point cloud (batch_size, num_points, 3)
    point_cloud = torch.rand(2, NUM_POINTS, 3)
    
    # Initialize model
    model = PointMAE()
    
    # Forward pass
    recon_patches, loss = model(point_cloud)
    print(f"Chamfer Distance Loss: {loss.item()}")

The whole concept of **Point-MAE**:

Driven by the analysis, we propose a novel self-supervised learning framework for point cloud by designing a neat and efficient scheme of Masked AutoEncoders, termed Point-MAE. Point-MAE mainly consists of a point cloud masking and embedding module and an autoencoder. The input point cloud is divided into irregular point patches, which are randomly masked at a high ratio to reduce data redundancy. Then, the autoencoder learns high-level latent features from unmasked point patches, aiming to reconstruct masked point patches in coordinate space. Specifically, our autoencoder’s backbone is entirely built from standard transformer blocks and adopts an asymmetric encoder–decoder structure. The encoder only processes unmasked point patches. Then, taking both encoded tokens and mask tokens as input, the lightweight decoder with a simple prediction head reconstructs masked point patches. Compared to processing mask tokens from the input of the encoder, shifting mask tokens to the lightweight decoder results in significant computational savings and, more importantly, avoids early leakage of location information.

Various SSL (self-supervised learning) strategies and pre-text tasks used for point cloud data:

SSL has also been widely studied for point cloud representation learning. Pre-text tasks are relatively diverse. Among them, DepthContrast sets an instance discrimination task for two augmented versions of an input point cloud. OcCo attempts to recover the original point cloud from the occluded point cloud in camera views. IAE adopts an autoencoder to reconstruct implicit features from augmented inputs. A recent work, Point-BERT, proposes a BERT-style pre-training strategy by masking input tokens and aims to predict discrete tokens of masked parts, with the assistance of dVAE. Different from previous methods, we attempt to design a neat scheme for point cloud self-supervised learning.

#### Important note on point-clouds patching and embedding:

Unlike images in computer vision that can be naturally divided into regular patches, point cloud consists of unordered points in 3D space. Based on its property, we process the input point cloud through three stages: **point patches generation, masking, and embedding**.

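Given an input point cloud $X^i \in \mathbb{R}^{p \times 3}$ with $p$ points, FPS samples $n$ center points $CT$, and KNN then selects the $k$ nearest points around each center to form the point patches $P$:

$$
\begin{aligned}
CT &= FPS(X^i), & CT &\in \mathbb{R}^{n \times 3} \\
P &= KNN(X^i, CT), & P &\in \mathbb{R}^{n \times k \times 3}
\end{aligned}
$$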

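For the embedding of each masked point patch, we replace it with a share-weighted learnable mask token. We denote the full set of mask tokens as $T_m \in \mathbb{R}^{mn \times C}$, where $m$ is the masking ratio and $C$ is the embedding dimension.
For the unmasked (visible) point patches, a naive idea is to flatten and embed them with a trainable linear projection, similar to ViT. However, a flat linear projection depends on the order of the points inside a patch, so the paper instead embeds each visible patch with a lightweight PointNet (shared MLPs followed by max pooling), giving the visible tokens $T_v \in \mathbb{R}^{(1-m)n \times C}$.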
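A flattened linear projection changes its output when the points inside a patch are reordered, even though the patch is the same set of points. One permutation-invariant alternative is a small PointNet-style embedding: a shared per-point MLP followed by max pooling. The sketch below is an assumption for illustration (the class name `MiniPointNetEmbed` and the layer sizes are hypothetical, not the authors' code).

```python
import torch
import torch.nn as nn

class MiniPointNetEmbed(nn.Module):
    """Hypothetical patch embedding: shared per-point MLP, then max pooling."""
    def __init__(self, embed_dim=384, hidden_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, patches):
        # patches: (B, n, k, 3), the k point coordinates of each patch
        per_point = self.mlp(patches)       # (B, n, k, embed_dim)
        return per_point.max(dim=2).values  # (B, n, embed_dim), order-independent

torch.manual_seed(0)
embed = MiniPointNetEmbed(embed_dim=384)
patches = torch.rand(2, 26, 32, 3)  # 26 visible patches of 32 points each
tokens = embed(patches)             # (2, 26, 384)

# Shuffling the points inside each patch leaves the tokens unchanged.
perm = torch.randperm(32)
tokens_shuffled = embed(patches[:, :, perm])
```

The max over the point axis is what removes the dependence on point order; any symmetric pooling (mean, sum) would have the same property.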

Encoder-decoder structure:

$$
\begin{aligned}
T_e &= \text{Encoder}(T_v), & T_e &\in \mathbb{R}^{(1-m)n \times C} \\
H_m &= \text{Decoder}(\text{concat}(T_e, T_m)), & H_m &\in \mathbb{R}^{mn \times C}
\end{aligned}
$$

Prediction head --> predicted masked point patches $P_{pre}$:

$$
P_{\text{pre}} = \text{Reshape}(\text{FC}(H_m)), \quad P_{\text{pre}} \in \mathbb{R}^{mn \times k \times 3}.
$$

$H_m$ is the decoder output at the masked positions. The decoder takes as input the encoded visible tokens $T_e$ from the encoder, concatenated with the mask tokens $T_m$.
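The three equations above ($T_e$, $H_m$, $P_{pre}$) can be traced in a toy module. This is a sketch under stated assumptions, not the paper's implementation: the class name, the tiny layer sizes, and the single positional MLP shared by encoder and decoder and added once are all simplifications chosen here for brevity.

```python
import torch
import torch.nn as nn

class TinyAsymmetricMAE(nn.Module):
    """Toy sizes; all layer names and sizes are illustrative assumptions."""
    def __init__(self, dim=32, heads=4, k=8, enc_depth=2, dec_depth=1):
        super().__init__()
        self.k = k

        def make_stack(depth):
            layer = nn.TransformerEncoderLayer(
                dim, heads, dim_feedforward=4 * dim, batch_first=True)
            return nn.TransformerEncoder(layer, num_layers=depth)

        self.encoder = make_stack(enc_depth)  # sees visible tokens only
        self.decoder = make_stack(dec_depth)  # lightweight, sees all tokens
        self.pos_embed = nn.Sequential(
            nn.Linear(3, dim), nn.GELU(), nn.Linear(dim, dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.head = nn.Linear(dim, k * 3)     # FC layer of the prediction head

    def forward(self, t_v, centers_vis, centers_mask):
        # t_v: (B, n_vis, C) visible tokens; centers_*: (B, n_*, 3) patch centers
        batch, n_mask = t_v.shape[0], centers_mask.shape[1]
        t_e = self.encoder(t_v + self.pos_embed(centers_vis))    # T_e
        t_m = self.mask_token.expand(batch, n_mask, -1)          # T_m
        dec_in = torch.cat([t_e + self.pos_embed(centers_vis),
                            t_m + self.pos_embed(centers_mask)], dim=1)
        h_m = self.decoder(dec_in)[:, -n_mask:]                  # H_m
        return self.head(h_m).reshape(batch, n_mask, self.k, 3)  # P_pre

torch.manual_seed(0)
model = TinyAsymmetricMAE()
t_v = torch.rand(2, 26, 32)  # 26 visible tokens per sample
p_pre = model(t_v, torch.rand(2, 26, 3), torch.rand(2, 38, 3))
```

Only the decoder ever receives the centers of the masked patches, which is the point of shifting the mask tokens out of the encoder.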

Our reconstruction target is to recover the coordinates of the points
in every masked point patch. Given the predicted point patches
$P_{pre}$ and ground truth $P_{gt}$, we compute the reconstruction loss by
**l2 Chamfer Distance**:

$$L = \frac{1}{|P_{\text{pre}}|} \sum_{a \in P_{\text{pre}}} \min_{b \in P_{\text{gt}}} \| a - b \|_2^2 + \frac{1}{|P_{\text{gt}}|} \sum_{b \in P_{\text{gt}}} \min_{a \in P_{\text{pre}}} \| a - b \|_2^2.$$

more on l2 chamfer distance:

Chamfer Distance (CD) is a metric used to measure the similarity between two point sets. It’s widely used in 3D shape reconstruction and generative models.

It’s a bidirectional matching between two point sets:
- The first term says: “For each predicted point, find the closest ground-truth point, and penalize the squared L2 distance.”
- The second term says: “For each ground-truth point, find the closest predicted point, and penalize the squared L2 distance.”


This ensures that:
1. Every predicted point is close to some real point (**precision**).
2. Every real point is reconstructed by some predicted point (**recall**).
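The two terms can be checked on a tiny hand-made example. The function below is a direct, unbatched transcription of the formula for illustration (the name `chamfer_l2` is hypothetical). It builds the full pairwise distance matrix, so it is only suitable for small point sets.

```python
import torch

def chamfer_l2(pred, gt):
    """pred: (N, 3), gt: (M, 3). Squared-L2 Chamfer Distance between two sets."""
    # Pairwise squared distances, shape (N, M).
    d2 = ((pred[:, None, :] - gt[None, :, :]) ** 2).sum(dim=-1)
    pred_to_gt = d2.min(dim=1).values.mean()  # first term: each prediction to nearest GT
    gt_to_pred = d2.min(dim=0).values.mean()  # second term: each GT point to nearest prediction
    return pred_to_gt + gt_to_pred

gt = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
pred = torch.tensor([[0.0, 0.0, 0.0]])
loss = chamfer_l2(pred, gt)
# First term is 0.0 (the single prediction sits exactly on a GT point).
# Second term is (0 + 1) / 2 = 0.5 (the second GT point is left uncovered).
```

Here the prediction has perfect "precision" but poor "recall", and only the second term picks that up.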

Experimentation:

In our Point-MAE, for different resolutions of the input point cloud, we divide them into different numbers of patches with a linear scaling. A typical input with p = 1024 points is divided into n = 64 point patches. For the KNN algorithm, we set k = 32 to keep the number of points in each patch constant. In the autoencoder’s backbone, the encoder has 12 transformer blocks, while the decoder has 4 transformer blocks. Each transformer block has 384 hidden dimensions and 6 heads. The MLP ratio in transformer blocks is set to 4. Note that in downstream tasks, the decoder is discarded.
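The arithmetic implied by these settings can be written out as a quick check. The helper below is hypothetical. It assumes the linear scaling means one patch per 16 input points (from 1024 points giving 64 patches) and uses a 0.6 masking ratio picked from the 60–80% range discussed earlier.

```python
def patch_layout(num_points, points_per_center=16, mask_ratio=0.6):
    """Patch counts under linear scaling; 1024 points -> 64 patches as stated above."""
    num_patches = num_points // points_per_center
    num_masked = int(num_patches * mask_ratio)
    return {
        "patches": num_patches,
        "masked": num_masked,
        "visible": num_patches - num_masked,
    }

layout_1k = patch_layout(1024)  # {'patches': 64, 'masked': 38, 'visible': 26}
layout_2k = patch_layout(2048)  # {'patches': 128, 'masked': 76, 'visible': 52}
head_dim = 384 // 6             # 64 dimensions per attention head
```

With k = 32 points per patch and one center per 16 input points, the patches cover twice as many point slots as there are points, so neighbouring patches overlap.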

### Joint-MAE (leveraging 2D images for point cloud data)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Vision Transformer for 2D images (simplified)
class ViT2D(nn.Module):
    def __init__(self, patch_size=16, dim=256, depth=6, heads=8):
        super().__init__()
        self.patch_size = patch_size
        self.dim = dim
        self.patch_embed = nn.Conv2d(3, dim, patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, 196, dim))
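        # batch_first=True: inputs are (B, num_patches, dim); without it the layer
        # would treat the batch axis as the sequence axis
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(dim, heads, batch_first=True) for _ in range(depth)
        ])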
    
    def forward(self, x, mask=None):
        # x: (B, 3, H, W), e.g., (B, 3, 224, 224)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, num_patches, dim)
        x = x + self.pos_embed
        if mask is not None:
            x = x * mask.unsqueeze(-1)  # Apply mask
        for layer in self.transformer:
            x = layer(x)
        return x  # (B, num_patches, dim)

# Point Cloud Transformer for 3D (simplified)
class PointTransformer3D(nn.Module):
    def __init__(self, num_points=1024, dim=256, depth=6, heads=8):
        super().__init__()
        self.num_points = num_points
        self.dim = dim
        self.point_embed = nn.Linear(3, dim)
        self.pos_embed = nn.Parameter(torch.randn(1, num_points, dim))
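        # batch_first=True: inputs are (B, num_points, dim); without it the layer
        # would treat the batch axis as the sequence axis
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(dim, heads, batch_first=True) for _ in range(depth)
        ])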
    
    def forward(self, x, mask=None):
        # x: (B, num_points, 3)
        x = self.point_embed(x)  # (B, num_points, dim)
        x = x + self.pos_embed
        if mask is not None:
            x = x * mask.unsqueeze(-1)
        for layer in self.transformer:
            x = layer(x)
        return x  # (B, num_points, dim)

# Local-Aligned Attention for cross-modal interaction
class LocalAlignedAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.scale = dim ** -0.5
    
    def forward(self, x_2d, x_3d):
        # x_2d: (B, num_patches, dim), x_3d: (B, num_points, dim)
        q = self.query(x_2d)  # (B, num_patches, dim)
        k = self.key(x_3d)    # (B, num_points, dim)
        v = self.value(x_3d)  # (B, num_points, dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = attn @ v  # (B, num_patches, dim)
        return out

# Joint-MAE Model
class JointMAE(nn.Module):
    def __init__(self, dim=256, depth=6, heads=8):
        super().__init__()
        self.dim = dim
        
        # 2D and 3D Encoders
        self.encoder_2d = ViT2D(dim=dim, depth=depth, heads=heads)
        self.encoder_3d = PointTransformer3D(dim=dim, depth=depth, heads=heads)
        
        # Joint Encoder
        # batch_first=True because the concatenated features are (B, tokens, dim)
        self.joint_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, batch_first=True), num_layers=depth
        )
        
        # Modal-Shared Decoder
        self.shared_decoder = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )
        
        # Modal-Specific Decoders
        self.decoder_2d = nn.Linear(dim, dim)
        self.decoder_3d = nn.Linear(dim, dim)
        
        # Local-Aligned Attention
        self.local_attn = LocalAlignedAttention(dim)
    
    def forward(self, img, points, mask_2d, mask_3d):
        # img: (B, 3, H, W), points: (B, num_points, 3)
        # mask_2d: (B, num_patches), mask_3d: (B, num_points)
        
        # Encode unmasked patches/points
        feat_2d = self.encoder_2d(img, mask_2d)  # (B, num_patches, dim)
        feat_3d = self.encoder_3d(points, mask_3d)  # (B, num_points, dim)
        
        # Cross-modal attention
        feat_2d = feat_2d + self.local_attn(feat_2d, feat_3d)
        
        # Concatenate and joint encode
        feat = torch.cat([feat_2d, feat_3d], dim=1)  # (B, num_patches+num_points, dim)
        feat = self.joint_encoder(feat)
        
        # Split back
        feat_2d = feat[:, :feat_2d.size(1)]
        feat_3d = feat[:, feat_2d.size(1):]
        
        # Shared and specific decoding
        shared_2d = self.shared_decoder(feat_2d)
        shared_3d = self.shared_decoder(feat_3d)
        out_2d = self.decoder_2d(shared_2d)  # (B, num_patches, dim)
        out_3d = self.decoder_3d(shared_3d)  # (B, num_points, dim)
        
        return out_2d, out_3d

# Loss Function with Cross-Reconstruction
def joint_mae_loss(out_2d, out_3d, target_2d, target_3d, mask_2d, mask_3d):
    # out_2d, out_3d: predicted embeddings
    # target_2d, target_3d: ground truth embeddings
    # mask_2d, mask_3d: binary masks for masked patches/points
    
    # Reconstruction loss (only for masked patches/points)
    loss_2d = F.mse_loss(out_2d[mask_2d == 0], target_2d[mask_2d == 0])
    loss_3d = F.mse_loss(out_3d[mask_3d == 0], target_3d[mask_3d == 0])
    
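    # Cross-reconstruction loss (toy stand-in). The 2D and 3D outputs have different
    # token counts, so a 2D mask cannot index a 3D tensor. Instead, compare the pooled
    # masked predictions of one modality with the pooled targets of the other.
    m2 = (mask_2d == 0).unsqueeze(-1).float()  # (B, num_patches, 1)
    m3 = (mask_3d == 0).unsqueeze(-1).float()  # (B, num_points, 1)
    pooled_2d = (out_2d * m2).sum(dim=1) / m2.sum(dim=1).clamp(min=1)  # (B, dim)
    pooled_3d = (out_3d * m3).sum(dim=1) / m3.sum(dim=1).clamp(min=1)  # (B, dim)
    cross_loss_2d = F.mse_loss(pooled_2d, target_3d.mean(dim=1))
    cross_loss_3d = F.mse_loss(pooled_3d, target_2d.mean(dim=1))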
    
    return loss_2d + loss_3d + 0.5 * (cross_loss_2d + cross_loss_3d)

# Example Usage
if __name__ == "__main__":
    batch_size, num_points, img_size = 8, 1024, 224
    img = torch.randn(batch_size, 3, img_size, img_size)
    points = torch.randn(batch_size, num_points, 3)
    mask_2d = torch.ones(batch_size, 196)  # Example: 224/16 = 14, 14*14 = 196 patches
    mask_3d = torch.ones(batch_size, num_points)
    mask_2d[:, :int(196*0.75)] = 0  # Mask 75% of 2D patches
    mask_3d[:, :int(num_points*0.75)] = 0  # Mask 75% of 3D points
    
    model = JointMAE()
    out_2d, out_3d = model(img, points, mask_2d, mask_3d)
    
    # Dummy targets (in practice, use encoder outputs on unmasked data)
    target_2d = torch.randn_like(out_2d)
    target_3d = torch.randn_like(out_3d)
    
    loss = joint_mae_loss(out_2d, out_3d, target_2d, target_3d, mask_2d, mask_3d)
    print(f"Loss: {loss.item()}")

The main improvement of Joint-MAE over previous models is the concept of **Local-Aligned Cross-Attention** between 2D and 3D tokens (a cross-modal learning concept):
- Select local neighborhoods of 2D patches and 3D points that likely correspond to the same region/object part.
- Let 2D tokens attend to their semantically aligned 3D counterparts, and vice versa.

In [None]:
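import torch
import torch.nn as nn
import torch.nn.functional as F

# Helper functions used by the attention module below. They are not defined in the
# original notes; these are minimal assumed implementations.
def knn(ref_pos, query_pos, k):
    """For each query point, indices of its k nearest reference points: (B, N_query, k)."""
    dists = torch.cdist(query_pos, ref_pos)  # (B, N_query, N_ref)
    return dists.topk(k, dim=-1, largest=False).indices

def batched_index_select(x, idx):
    """Gather x (B, N, ...) at idx (B, M, k) -> (B, M, k, ...)."""
    batch_idx = torch.arange(x.size(0), device=x.device).view(-1, 1, 1)
    return x[batch_idx, idx]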

class LocalAlignedCrossAttention(nn.Module):
    def __init__(self, dim, k_neighbors=16, heads=8):
        super().__init__()
        self.k = k_neighbors
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, img_tokens, pc_tokens, img_pos, pc_pos):
        """
        img_tokens: (B, N_img, D)
        pc_tokens:  (B, N_pc, D)
        img_pos:    (B, N_img, 3) patch positions, expressed in the same 3D frame as pc_pos
        pc_pos:     (B, N_pc, 3)

        Returns updated img_tokens.
        """
        B, N_img, D = img_tokens.shape
        N_pc = pc_tokens.size(1)
        H = self.heads

        # Project to Q (from image) and KV (from point cloud)
        Q = self.to_q(img_tokens).view(B, N_img, H, D // H)
        KV = self.to_kv(pc_tokens).view(B, N_pc, 2, H, D // H)
        K, V = KV[:, :, 0], KV[:, :, 1]  # (B, N_pc, H, D_head)

        # Step 1: Find KNN in pc_pos for each img_pos
        # We assume positions are aligned across modalities
        neighbors_idx = knn(pc_pos, img_pos, k=self.k)  # (B, N_img, k)

        # Step 2: Gather K and V for each image token's local pc neighborhood
        K_neighbors = batched_index_select(K, neighbors_idx)  # (B, N_img, k, H, D_head)
        V_neighbors = batched_index_select(V, neighbors_idx)

        # Step 3: Compute attention
        Q = Q.unsqueeze(2)  # (B, N_img, 1, H, D_head)
        attn = (Q * K_neighbors).sum(-1) * self.scale  # (B, N_img, k, H)
        attn = F.softmax(attn, dim=2)

        # Step 4: Apply attention
        out = (attn.unsqueeze(-1) * V_neighbors).sum(2)  # (B, N_img, H, D_head)
        out = out.reshape(B, N_img, D)

        return self.to_out(out)

In the Point-MAE architecture, the **transformer dimension** refers to the embedding dimension of tokens (i.e., the feature dimensionality of the inputs and outputs of the transformer layers). This dimension controls the size of the hidden representations in the encoder and decoder, and it is a key hyperparameter in transformer-based architectures.
In the original paper, the authors describe a transformer architecture that operates on grouped point patches, where each patch is converted into a token (a feature vector). These tokens are then fed into a ViT-like transformer encoder.
In the configuration described in the experimentation notes above, D = 384 with 6 attention heads, i.e. 64 dimensions per head.

