In [1]:
from utils.position_encodings import SinusoidalPosEmb, RotaryPositionEncoding, RotaryPositionEncoding3D
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
from torch import nn
import math

class RotaryPositionEncoding(nn.Module):
    """Builds cos/sin codes for rotary position embedding of scalar positions.

    NOTE: defining this class in the notebook shadows the
    ``RotaryPositionEncoding`` imported from ``utils.position_encodings`` in
    the first cell.

    Args:
        feature_dim: size of the channel axis the codes are meant to be applied
            to. One frequency is generated per channel pair and then
            duplicated, so an even value is expected.
        pe_type: label stored on the module; it is not read by this class.
    """

    def __init__(self, feature_dim, pe_type='Rotary1D'):
        super().__init__()

        self.feature_dim = feature_dim
        self.pe_type = pe_type

    @staticmethod
    def embed_rotary(x, cos, sin):
        """Rotate consecutive channel pairs of ``x`` by the given codes.

        Args:
            x: tensor whose last axis holds the channels (x0, x1, x2, x3, ...);
                the last axis must have even length.
            cos: cosine codes, broadcastable against ``x``.
            sin: sine codes, broadcastable against ``x``.

        Returns:
            ``x * cos + x_rot * sin`` where ``x_rot`` is
            (-x1, x0, -x3, x2, ...).
        """
        # Pair up (-odd, even) channels and flatten back to x's layout.
        x2 = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).reshape_as(x).contiguous()
        x = x * cos + x2 * sin
        return x

    def forward(self, x_position):
        """Compute the rotary codes for a batch of scalar positions.

        Args:
            x_position: 2-D tensor of shape [B, N] with one scalar position per
                point.

        Returns:
            Tensor of shape [B, N, feature_dim, 2] (for even ``feature_dim``),
            detached from autograd. Index 0 of the last axis is the cosine
            code, index 1 the sine code; each frequency is repeated for the two
            channels of its pair.
        """
        bsize, npoint = x_position.shape
        div_term = torch.exp(
            torch.arange(0, self.feature_dim, 2, dtype=torch.float, device=x_position.device)
            * (-math.log(10000.0) / (self.feature_dim)))
        div_term = div_term.view(1, 1, -1)  # [1, 1, d/2]

        # Add a trailing axis so positions broadcast against the frequencies:
        # [B, N, 1] * [1, 1, d/2] -> [B, N, d/2]. Without it the [B, N] input
        # would be aligned as [1, B, N], which only works when N == 1.
        angles = x_position.unsqueeze(-1) * div_term

        sinx = torch.sin(angles)  # [B, N, d/2]
        cosx = torch.cos(angles)

        # Duplicate every frequency for both channels of its pair -> [B, N, d].
        sin_pos, cos_pos = map(
            lambda feat: torch.stack([feat, feat], dim=-1).view(bsize, npoint, -1),
            [sinx, cosx]
        )
        position_code = torch.stack([cos_pos, sin_pos], dim=-1)  # [B, N, d, 2]

        # The codes are fixed functions of position; never backprop through them.
        if position_code.requires_grad:
            position_code = position_code.detach()

        return position_code

In [22]:
# 1D demo: a batch of 128 entries with a single scalar position each.
xy = torch.randn(128, 1)
pos_enc = RotaryPositionEncoding(128)  # feature_dim has to be divisible by 2
xy_enc = pos_enc(xy)
print("xy_enc.shape:", xy_enc.shape)

# The trailing axis of the encoding holds the (cos, sin) pair.
q_cos, q_sin = xy_enc.unbind(dim=-1)

# query is 2-D here, so it is broadcast against the 3-D cos/sin codes.
query = torch.randn(128, 128)
q_embed = RotaryPositionEncoding.embed_rotary(query, q_cos, q_sin)
print("q_embed.shape:", q_embed.shape)

xy_enc.shape: torch.Size([128, 1, 128, 2])
q_embed.shape: torch.Size([128, 128, 128])


In [2]:

# 3D demo: 2 batches of 128 points with (x, y, z) coordinates.
xyz = torch.randn(2, 128, 3)
pos_enc = RotaryPositionEncoding3D(60)  # feature_dim has to be divisible by 6
xyz_enc = pos_enc(xyz)
print("xyz_enc.shape:", xyz_enc.shape)

# Split the trailing axis into its two components
# (presumably cos then sin, as in the 1D encoder -- confirm against the 3D implementation).
q_cos, q_sin = xyz_enc.unbind(dim=-1)

query = torch.randn(2, 128, 60)
q_embed = RotaryPositionEncoding.embed_rotary(query, q_cos, q_sin)
print("q_embed.shape:", q_embed.shape)

xyz_enc.shape: torch.Size([2, 128, 60, 2])
q_embed.shape: torch.Size([2, 128, 60])
