In [9]:
import torch
def create_temporal_mask(seq_len, look_ahead=0, device=None):
    """
    Create a temporal attention mask where each position i can attend to:
    positions [0, ..., i + look_ahead], up to seq_len - 1

    Returns: (1, 1, seq_len, seq_len) boolean tensor
    """
    i = torch.arange(seq_len, device=device).unsqueeze(1)  # query positions
    j = torch.arange(seq_len, device=device).unsqueeze(0)  # key positions

    mask = j <= i + look_ahead  # allow attention to past and limited future
    return mask.unsqueeze(0).unsqueeze(0)  # shape: (1, 1, seq_len, seq_len)



In [11]:
create_temporal_mask(10)

tensor([[[[ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]]])

In [76]:
from torch import Tensor
from torch import nn
import math
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000, patches_per_time=16):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.repeat_interleave(torch.arange(max_len), repeats=patches_per_time).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        print(div_term)
        pe = torch.zeros(max_len*patches_per_time, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [78]:
pe = PositionalEncoding(8)
pe.pe[16:32, :, 0:8]

tensor([1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03])


tensor([[[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],

        [[0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000]],


In [103]:
import torch
import torch.nn as nn

class LearnablePositionalEmbeddings(nn.Module):
    def __init__(self, N, M, embedding_dim):
        super(LearnablePositionalEmbeddings, self).__init__()
        
        # Set the size of the embeddings
        self.embedding_dim = embedding_dim
        self.N = N
        self.M = M


        # Initialize the learnable positional embeddings
        self.embeddings_N = nn.Parameter(torch.randn(N, embedding_dim))  # Embeddings for every Nth patch
        self.embeddings_M = nn.Parameter(torch.randn(M, embedding_dim))  # Embeddings for every M consecutive patches

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, num_patches, embedding_dim)
        """
        batch_size, num_patches, _ = x.size()
        
        pos_N_embedding = self.embeddings_N.tile(dims=(num_patches // N, 1))
        pos_M_embedding = self.embeddings_M.repeat_interleave(num_patches // M, dim=0)
        
        
        # Add the positional embeddings to the input tensor (x)
        x = x + pos_N_embedding + pos_M_embedding

        return x


# Example usage:
num_patches = 16   # Number of patches (this would be the number of spatial patches or time steps)
embedding_dim = 64  # Embedding dimension size
N = 4               # Nth patch interval for first set of embeddings
M = 2               # M consecutive patches for second set of embeddings

# Initialize the model
model = LearnablePositionalEmbeddings(N, M, embedding_dim)

# Example input tensor of shape (batch_size, num_patches, embedding_dim)
batch_size = 8
x = torch.randn(batch_size, num_patches, embedding_dim)

# Forward pass
output = model(x)
print(output.shape)  # Should print (batch_size, num_patches, embedding_dim)


torch.Size([8, 16, 64]) torch.Size([16, 64]) torch.Size([16, 64])
torch.Size([8, 16, 64])


In [52]:
import torch
from torch import nn
import math

class HybridSpatiotemporalPosEmb(nn.Module):
    def __init__(self, num_space, max_time, embedding_dim):
        """
        num_space: number of spatial positions (N)
        max_time: number of time steps (T)
        embedding_dim: size of each positional embedding (must be even)
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.N = num_space
        self.T = max_time

        assert embedding_dim % 2 == 0, "Embedding dimension must be even for sin/cos"

        # Learnable spatial embeddings
        self.space_embedding = nn.Parameter(torch.randn(num_space, embedding_dim))

        # Fixed sinusoidal temporal embeddings
        self.register_buffer("time_embedding", self._build_sin_cos_embedding(max_time, embedding_dim))

    def _build_sin_cos_embedding(self, length, dim):
        """
        Generate fixed sinusoidal embeddings of shape (length, dim)
        """
        position = torch.arange(1, length + 1).unsqueeze(1).float()  # (length, 1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))  # (dim/2,)
        sinusoid = torch.zeros(length, dim)
        sinusoid[:, 0::2] = torch.sin(position * div_term)
        sinusoid[:, 1::2] = torch.cos(position * div_term)
        return sinusoid  # (length, dim)

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, num_patches, embedding_dim)
        Assumes patches are ordered as:
            [t0_p0, t0_p1, ..., t0_pN-1, t1_p0, ..., tT-1_pN-1]
        """
        batch_size, num_patches, _ = x.size()
        T = num_patches // self.N

        # Compute spatial and temporal indices
        spatial_idx = torch.arange(num_patches, device=x.device) % self.N
        temporal_idx = torch.arange(num_patches, device=x.device) // self.N

        # Lookup embeddings
        pos_space_embedding = self.space_embedding[spatial_idx]     # (num_patches, embedding_dim)
        pos_time_embedding = self.time_embedding[temporal_idx]      # (num_patches, embedding_dim)

        # Combine and expand to batch
        pos_embedding = pos_space_embedding + pos_time_embedding
        pos_embedding = pos_embedding.unsqueeze(0).expand(batch_size, -1, -1)

        return x + pos_embedding, pos_space_embedding, pos_time_embedding


In [58]:
x = torch.zeros((1,20,6))
learnable_embeds = HybridSpatiotemporalPosEmb(num_space=5, max_time=20, embedding_dim=6)
x2, space_embeds, time_embeds = learnable_embeds(x)

In [5]:
import torch
def create_temporal_attention_mask(num_patches, patches_per_timestep=4, N=2):
    
    mask = torch.full((num_patches, num_patches), 0)

    timesteps = num_patches // patches_per_timestep
    
    for t_q in range(timesteps):  # time index of query
        for dt in range(N + 1):  # how far back to look
            t_k = t_q - dt  # key timestep
            if t_k < 0:
                continue
            q_start = t_q * patches_per_timestep
            q_end = (t_q + 1) * patches_per_timestep
            k_start = t_k * patches_per_timestep
            k_end = (t_k + 1) * patches_per_timestep
            # allow attention: set to 0 (non-masked)
            mask[q_start:q_end, k_start:k_end] = 1

    return mask  # shape: [Num Patches, Num Patches]

In [6]:
create_temporal_attention_mask(32)

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])