In [1]:
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat

In [27]:
x = torch.rand(8, 3, 30, 480, 480) # b, c, t, w, h

## PATCH TOKENIZATION

In [28]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=480, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = [img_size, img_size]
        patch_size = [patch_size, patch_size]
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        print(f'x shape 1: {x.shape}')
        x = self.proj(x)
        print(f'x shape 2: {x.shape}')
        W = x.size(-1)
        x = rearrange(x, 'b c h w -> b (h w) c') 
        return x, T, W

In [29]:
patching = PatchEmbed()

In [30]:
x, T, W = patching(x)

x shape 1: torch.Size([240, 3, 480, 480])
x shape 2: torch.Size([240, 768, 30, 30])


In [31]:
print(f'x shape: {x.shape}')
print(f'T: {T}')
print(f'W: {W}')

x shape: torch.Size([240, 900, 768])
T: 30
W: 30
