In [None]:
from init_notebook import *

In [None]:
import transformers


In [None]:
class PositionalEmbedding(torch.nn.Module):
    """
    Fixed (non-trainable) sinusoidal lookup table from integer positions to vectors.

    Entry ``[pos, j]`` of the table is::

        sin(pos / k ** (2 * j / dimensions))    for even j
        cos(pos / k ** (2 * j / dimensions))    for odd j

    The table is stored in ``self.pe`` (shape ``(max_len, dimensions)``) as a
    parameter with ``requires_grad=False``.
    """

    def __init__(
            self,
            dimensions: int,
            max_len: int = 128,
            k: Optional[float] = None,
            dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            dimensions: Length of each embedding vector. Even and odd values
                are both supported.
            max_len: Number of rows in the table; valid positions are
                ``0 .. max_len - 1``.
            k: Base of the per-column divisor ``k ** (2 * j / dimensions)``.
                When None it is derived from ``max_len`` as
                ``round(max_len / (20 + max_len / 20), 4)``.
            dtype: dtype of the stored table.
        """
        super().__init__()
        self._dimensions = dimensions
        self._max_len = max_len
        if k is None:
            k = round(max_len / (20 + max_len / 20), 4)
        self._k = k

        # The divisor depends only on the column, so compute it once per column
        # instead of once per (position, column) pair.
        divisors = [k ** ((2 * j) / dimensions) for j in range(dimensions)]

        # Each column is filled according to its own parity rather than writing
        # columns in (i, i + 1) pairs: the pairwise form indexes one past the end
        # of the row (IndexError) whenever `dimensions` is odd.
        pe = [
            [
                math.sin(pos / divisor) if j % 2 == 0 else math.cos(pos / divisor)
                for j, divisor in enumerate(divisors)
            ]
            for pos in range(max_len)
        ]

        self.pe = nn.Parameter(torch.tensor(pe, dtype=dtype), requires_grad=False)

    def extra_repr(self):
        """Return the constructor settings shown in the module's repr."""
        return f"dimensions={self._dimensions}, max_len={self._max_len}, k={self._k}"

    def forward(self, x: Union[int, Tuple[int], List[int], torch.LongTensor]):
        """
        Look up the embedding(s) for the given position(s).

        Args:
            x: A single int, a sequence of ints, or an index tensor.

        Returns:
            ``self.pe[x]``: one row for an int, otherwise the rows selected by
            the index tensor (sequences are converted to an int64 tensor first).
        """
        if isinstance(x, int):
            return self.pe[x]

        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.int64)

        return self.pe[x]

# Demo: build a 32-dimensional table with the default max_len and show it as an image.
pe = PositionalEmbedding(32)
# Lookup example: pe([0, 1, 2, 0])
display(
    VF.to_pil_image(
        resize(
            # Transpose so positions run along the last axis (keeping at most the
            # last 400 of them), map the sin/cos range [-1, 1] to [0, 1] and add a
            # leading axis of size 1.
            .5 + .5 * pe.pe.T[:, -400:].unsqueeze(0),
            # NOTE(review): second argument of the notebook's `resize` helper,
            # presumably an upscaling factor -- confirm in init_notebook.
            5,
        )
    )
)

In [None]:
class PositionalPatchEncoder(torch.nn.Module):
    """
    Appends two positional-embedding channels to every patch of a batch.

    Each patch carries a two-element position. The embedding vector of each
    position component is reshaped into a plane of the configured patch
    height/width and concatenated to the patch along the channel axis, the
    plane for the second-to-last position column first, then the one for the
    last column.
    """

    def __init__(
            self,
            patch_shape: Tuple[int, int, int],
            max_size: Tuple[int, int],
    ):
        """
        Args:
            patch_shape: Shape of a single patch; only the last two entries are
                used here (their product is the embedding length and they give
                the shape of each embedding plane).
            max_size: Its largest entry is the number of positions held by the
                embedding table.
        """
        super().__init__()
        self._patch_shape = patch_shape
        self._max_size = max_size
        # One embedding value per element of a patch plane, so a looked-up vector
        # can be viewed as such a plane in forward().
        self.pos_embed = PositionalEmbedding(
            math.prod(patch_shape[-2:]),
            max_len=max(max_size),
        )

    def forward(self, patches: torch.Tensor, positions: torch.LongTensor):
        """
        Args:
            patches: 4-dimensional tensor whose first axis is the batch.
            positions: Index tensor of shape ``(batch, 2)``.

        Returns:
            ``patches`` with two extra entries along axis -3: the embedding
            plane for ``positions[:, -2]`` followed by the one for
            ``positions[:, -1]``.
        """
        assert patches.ndim == 4, f"Expected patches.ndim == 4, got {patches.shape}"
        batch_size, _, _, _ = patches.shape
        assert positions.shape == torch.Size((batch_size, 2)), f"Expected positions.shape == ({batch_size}, 2), got {positions.shape}"

        plane_shape = (batch_size, 1, *self._patch_shape[-2:])
        # Column -2 first, then column -1, matching the channel order of the output.
        planes = [
            self.pos_embed(positions[:, column]).view(plane_shape)
            for column in (-2, -1)
        ]

        return torch.concat([patches, *planes], dim=-3)
        
# Demo: encode two random single-channel 8x8 patches and show the result as an image.
menc = PositionalPatchEncoder((1, 8, 8), (4, 4))
display(menc)
display(
    VF.to_pil_image(
        resize(
            make_grid(
                menc(
                    # No seed is set in this cell, so the random patches
                    # (and the picture) may differ between runs.
                    patches=torch.randn(2, 1, 8, 8),
                    positions=torch.LongTensor([[0, 1], [2, 1]]),
                )
            ),
            # NOTE(review): second argument of the notebook's `resize` helper,
            # presumably an upscaling factor -- confirm in init_notebook.
            5,
        ).clamp(0, 1)  # randn values fall outside [0, 1]; clamp before image conversion
    )
)
#pe([0, 1, 2, 0])
#display(VF.to_pil_image(resize(.5+.5*pe.pe.T[:, -400:].unsqueeze(0), 5)))

In [None]:
from torchvision.cl