In [None]:
import torch
import torch.nn as nn

In [None]:
def get_spatial_pos_embedding(grid_size=3, pos_emb_dim=768):
    """Build the fixed 2D sine-cosine positional embedding for a square grid of patch tokens.

    Tokens are ordered row-major (row index outer, column index inner), i.e. the same
    order as flattening a ``(nh, nw)`` patch grid. Each embedding is the concatenation
    ``[sin(row / f), cos(row / f), sin(col / f), cos(col / f)]`` where
    ``f_i = 10000 ** (i / (pos_emb_dim // 4))`` for ``i`` in ``[0, pos_emb_dim // 4)``.

    Args:
        grid_size: Number of patches along each side of the grid (default 3 -> 9 tokens).
        pos_emb_dim: Embedding size; must be a positive multiple of 4 so the four
            sin/cos quarters exactly fill it (default 768).

    Returns:
        float32 tensor of shape ``(grid_size * grid_size, pos_emb_dim)``.

    Raises:
        ValueError: if ``grid_size`` < 1 or ``pos_emb_dim`` is not a positive multiple of 4.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")
    if pos_emb_dim <= 0 or pos_emb_dim % 4 != 0:
        raise ValueError(f"pos_emb_dim must be a positive multiple of 4, got {pos_emb_dim}")

    grid_h = torch.arange(grid_size, dtype=torch.float32)
    grid_w = torch.arange(grid_size, dtype=torch.float32)
    grid = torch.stack(torch.meshgrid(grid_h, grid_w, indexing='ij'), dim=0)

    # (num_patch_tokens,) row / column index of every token, row-major order
    grid_h_positions = grid[0].reshape(-1)
    grid_w_positions = grid[1].reshape(-1)

    num_freqs = pos_emb_dim // 4
    # factor_i = 10000^(i / num_freqs)
    factor = 10000 ** (torch.arange(num_freqs, dtype=torch.float32) / num_freqs)

    # Broadcasting (num_tokens, 1) / (num_freqs,) -> (num_tokens, num_freqs)
    grid_h_emb = grid_h_positions[:, None] / factor
    grid_w_emb = grid_w_positions[:, None] / factor

    # First half encodes the row, second half the column.
    return torch.cat(
        [torch.sin(grid_h_emb), torch.cos(grid_h_emb),
         torch.sin(grid_w_emb), torch.cos(grid_w_emb)],
        dim=-1,
    )

In [None]:
def get_temporal_pos_embedding(num_frames=28, temb_dim=768):
    """Build the fixed 1D sine-cosine positional embedding over frame indices.

    Frame ``t`` gets ``[sin(t / f), cos(t / f)]`` with
    ``f_i = 10000 ** (i / (temb_dim // 2))`` for ``i`` in ``[0, temb_dim // 2)``.

    Args:
        num_frames: Number of frame positions to encode (default 28).
        temb_dim: Embedding size; must be a positive even number (default 768).

    Returns:
        float32 tensor of shape ``(num_frames, temb_dim)``.

    Raises:
        ValueError: if ``num_frames`` < 1 or ``temb_dim`` is not a positive even number.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    if temb_dim <= 0 or temb_dim % 2 != 0:
        raise ValueError(f"temb_dim must be a positive even number, got {temb_dim}")

    num_freqs = temb_dim // 2

    # (num_frames,) frame indices 0 .. num_frames - 1
    frame_positions = torch.arange(num_frames, dtype=torch.float32)

    # factor_i = 10000^(i / num_freqs)
    factor = 10000 ** (torch.arange(num_freqs, dtype=torch.float32) / num_freqs)

    # Broadcasting (num_frames, 1) / (num_freqs,) -> (num_frames, num_freqs)
    t_emb = frame_positions[:, None] / factor

    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)

In [None]:
class DITVideo(nn.Module):
    """Video transformer with factorized spatial / temporal attention.

    Each frame is cut into non-overlapping ``patch_size x patch_size`` patches that are
    linearly projected to ``hidden_size`` tokens. The tokens then pass through ``depth``
    blocks, each made of a spatial transformer layer (attention among the patches of one
    frame) followed by a temporal transformer layer (attention among the frames at one
    patch position). Fixed sine-cosine positional embeddings are added once, before the
    first spatial and the first temporal layer. Finally tokens are projected back to
    pixel patches and reassembled, so the output has the same shape as the input.

    The defaults reproduce the original hard-coded configuration: 28 frames of
    3-channel 6x6 images (a 3x3 grid of 2x2 patches), hidden size 768, 12 heads and
    3 space/time blocks.

    Args:
        num_frames: Frames per video (F).
        image_size: Height and width of each (square) frame.
        patch_size: Side length of each square patch; must divide ``image_size``.
        in_channels: Channels per frame (C).
        hidden_size: Token width; must be divisible by 4 (spatial embedding) and by
            ``num_heads``.
        num_heads: Attention heads per transformer layer.
        depth: Number of spatial+temporal blocks (>= 1).

    Raises:
        ValueError: on an inconsistent configuration.
    """

    def __init__(self, num_frames=28, image_size=6, patch_size=2, in_channels=3,
                 hidden_size=768, num_heads=12, depth=3):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})")
        if hidden_size % 4 != 0:
            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by 4")
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.num_frames = num_frames
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.grid_size = image_size // patch_size
        patch_dim = patch_size * patch_size * in_channels

        # All layers are built once here, so their weights are registered as parameters,
        # trained by the optimizer, moved by .to(device) and reused on every forward call.
        self.patch_embed = nn.Linear(patch_dim, hidden_size)
        self.spatial_layers = nn.ModuleList(
            [self._make_layer(hidden_size, num_heads) for _ in range(depth)])
        self.temporal_layers = nn.ModuleList(
            [self._make_layer(hidden_size, num_heads) for _ in range(depth)])
        self.unpatch_proj = nn.Linear(hidden_size, patch_dim)

        # Deterministic embeddings: kept as non-persistent buffers so they follow the
        # module's device/dtype without being stored in the state dict.
        self.register_buffer(
            'spatial_pos_emb',
            self._spatial_pos_embedding(self.grid_size, hidden_size),
            persistent=False)
        self.register_buffer(
            'temporal_pos_emb',
            self._sincos_embedding(torch.arange(num_frames, dtype=torch.float32), hidden_size // 2),
            persistent=False)

    @staticmethod
    def _make_layer(hidden_size, num_heads):
        """Build one transformer encoder layer taking (batch, seq, hidden) inputs."""
        # batch_first=True is essential: with the default (False) dim 0 would be treated
        # as the sequence axis, so attention would mix frames/videos instead of tokens.
        return nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=num_heads,
            dim_feedforward=4 * hidden_size, batch_first=True)

    @staticmethod
    def _sincos_embedding(positions, num_freqs):
        """Return ``[sin(p / f), cos(p / f)]`` per position, shape (len(positions), 2 * num_freqs).

        ``f_i = 10000 ** (i / num_freqs)``; the same formula used by
        get_spatial_pos_embedding / get_temporal_pos_embedding.
        """
        factor = 10000 ** (torch.arange(num_freqs, dtype=torch.float32) / num_freqs)
        angles = positions[:, None] / factor
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    @classmethod
    def _spatial_pos_embedding(cls, grid_size, hidden_size):
        """2D embedding of shape (grid_size**2, hidden_size): row half then column half, row-major tokens."""
        axis = torch.arange(grid_size, dtype=torch.float32)
        rows, cols = torch.meshgrid(axis, axis, indexing='ij')
        return torch.cat(
            [cls._sincos_embedding(rows.reshape(-1), hidden_size // 4),
             cls._sincos_embedding(cols.reshape(-1), hidden_size // 4)],
            dim=-1)

    def _patchify(self, frames):
        """(N, C, H, W) -> (N, num_patches, patch_size * patch_size * C).

        Patches are ordered row-major; features inside a patch are ordered (ph, pw, c).
        """
        n, c, h, w = frames.shape
        p = self.patch_size
        nh, nw = h // p, w // p
        # dims: n, c, nh, ph, nw, pw -> n, nh, nw, ph, pw, c
        patches = frames.reshape(n, c, nh, p, nw, p).permute(0, 2, 4, 3, 5, 1)
        return patches.reshape(n, nh * nw, p * p * c)

    def _unpatchify(self, patches):
        """Inverse of ``_patchify``: (N, num_patches, patch_dim) -> (N, C, H, W)."""
        n = patches.shape[0]
        p, c, g = self.patch_size, self.in_channels, self.grid_size
        # dims: n, nh, nw, ph, pw, c -> n, c, nh, ph, nw, pw
        pixels = patches.reshape(n, g, g, p, p, c).permute(0, 5, 1, 3, 2, 4)
        return pixels.reshape(n, c, g * p, g * p)

    @staticmethod
    def _space_to_time(tokens, batch, frames):
        """(B*F, P, D) -> (B*P, F, D): regroup so each sequence is one patch position over time."""
        _, num_patches, dim = tokens.shape
        return (tokens.reshape(batch, frames, num_patches, dim)
                .transpose(1, 2)
                .reshape(batch * num_patches, frames, dim))

    @staticmethod
    def _time_to_space(tokens, batch, num_patches):
        """(B*P, F, D) -> (B*F, P, D): regroup so each sequence is the patches of one frame."""
        _, frames, dim = tokens.shape
        return (tokens.reshape(batch, num_patches, frames, dim)
                .transpose(1, 2)
                .reshape(batch * frames, num_patches, dim))

    def forward(self, x):
        """Run the model on a batch of videos.

        Args:
            x: Tensor of shape (batch, num_frames, in_channels, image_size, image_size).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` does not have the configured shape.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D input (B, F, C, H, W), got shape {tuple(x.shape)}")
        B, T, C, H, W = x.shape
        expected = (self.num_frames, self.in_channels, self.image_size, self.image_size)
        if (T, C, H, W) != expected:
            raise ValueError(
                f"expected (B, F, C, H, W) with (F, C, H, W) == {expected}, "
                f"got {tuple(x.shape)}")
        num_patches = self.grid_size * self.grid_size

        # Patchify: (B, F, C, H, W) -> (B*F, P, hidden)
        out = self.patch_embed(self._patchify(x.reshape(B * T, C, H, W)))

        for i, (spatial, temporal) in enumerate(zip(self.spatial_layers, self.temporal_layers)):
            if i == 0:
                out = out + self.spatial_pos_emb  # (P, hidden) broadcast over B*F
            out = spatial(out)  # attention across patches within each frame

            out = self._space_to_time(out, B, T)  # -> (B*P, F, hidden)
            if i == 0:
                out = out + self.temporal_pos_emb  # (F, hidden) broadcast over B*P
            out = temporal(out)  # attention across frames at each patch position

            out = self._time_to_space(out, B, num_patches)  # -> (B*F, P, hidden)

        # Unpatchify: (B*F, P, hidden) -> (B*F, C, H, W) -> (B, F, C, H, W)
        out = self._unpatchify(self.unpatch_proj(out))
        return out.reshape(B, T, C, H, W)