In [10]:
import torch
from torch import nn
from einops import rearrange


class FeedForward(nn.Module):
    """Pre-norm two-layer MLP: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        dim: size of the last axis of the input.
        hidden_dim: width of the hidden layer.
        last: size of the last axis of the output. Any falsy value (``None``,
            which is the default, or ``0``) keeps the output at ``dim``.
    """

    def __init__(self, dim, hidden_dim, last=None):
        super().__init__()
        # A falsy `last` means "no projection": output width equals input width.
        out_features = last if last else dim
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_features),
        )

    def forward(self, x):
        """Apply the MLP over the last axis of ``x`` and return the result."""
        return self.net(x)

    
class Attention(nn.Module):
    """Multi-head self-attention that layer-normalizes its input first.

    Args:
        dim: size of the last axis of both the input and the output.
        heads: number of attention heads.
        dim_head: width of each head; q, k and v are each
            ``heads * dim_head`` wide before being split into heads.
    """

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        # Attention scores are multiplied by dim_head ** -0.5 before the softmax.
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)

        width = heads * dim_head
        # One bias-free projection produces q, k and v side by side.
        self.to_qkv = nn.Linear(dim, 3 * width, bias = False)
        self.to_out = nn.Linear(width, dim, bias = False)

    def forward(self, x):
        """Return the attention output for ``x``; no residual is added here.

        The einops patterns below expect a 3-axis input laid out as ``b n (h d)``.
        """
        normed = self.norm(x)

        # Split the fused projection into q, k, v and move heads to their own axis.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        # Weighted sum of values, then merge the heads back into the last axis.
        mixed = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(mixed)

    
class Transformer(nn.Module):
    """Stack of (Attention, FeedForward) layers whose final layer projects to ``out_dim``.

    Every attention sub-layer is wrapped in a residual connection. Every
    feed-forward sub-layer except the last one is too; the last feed-forward
    may change the feature width, so its output is returned as-is.

    Args:
        dim: feature width of the input and of all intermediate layers.
        depth: number of (Attention, FeedForward) layers.
        heads: number of attention heads per layer.
        dim_head: width of each attention head.
        mlp_dim: hidden width of each feed-forward block.
        out_dim: output width of the final feed-forward block; a falsy value
            keeps it at ``dim``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, out_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.depth = depth
        for d in range(depth):
            # Only the final feed-forward changes the width; None tells
            # FeedForward to keep `dim`.
            last = out_dim if d == depth - 1 else None
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim, last)
            ]))

    def forward(self, x):
        """Run ``x`` through every layer and return the final projection.

        The input tensor is not modified in place.
        """
        final = len(self.layers) - 1
        for d, (attn, ff) in enumerate(self.layers):
            x = attn(x) + x
            if d == final:
                # No skip connection: the output width may differ from the input width.
                x = ff(x)
            else:
                # Residual around the feed-forward. The previous code did
                # `x = ff(x); x += x`, which doubled the feed-forward output
                # and dropped the skip connection.
                x = ff(x) + x
        return x

In [14]:
video = torch.randn(1, 8, 3, 400, 400) # batch, frames, channels, h, w
# Split every frame into a 16x16 grid of 25x25 patches (256 patches).
# A bare reshape to (1, 256, 8, 3, 25, 25) would only reinterpret the memory
# order and would not yield spatially contiguous patches, so the height and
# width axes are factored into (grid, 25) first and the grid axes moved forward.
x = (
    video.reshape(1, 8, 3, 16, 25, 16, 25)  # h -> (grid_row, 25), w -> (grid_col, 25)
    .permute(0, 3, 5, 1, 2, 4, 6)           # batch, grid_row, grid_col, frames, channels, 25, 25
    .reshape(1, 256, 8, 3, 25, 25)          # batch, patches, frames, channels, 25, 25
)

In [17]:
# Example: a small conv stack that flattens one patch into a feature vector.
# Assumes x is indexed as (batch, patch, frame, channels, h, w) with 3 channels,
# as built in the preceding cell.

conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1)
conv2 = nn.Conv2d(32, 128, kernel_size=3, stride=3)
conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=3)

batch_idx = patch = frame = 0

# `patch` is rebound here from the patch index to the selected patch tensor.
patch = x[batch_idx, patch, frame, :]
res = conv3(conv2(conv1(patch))).flatten()
res.shape

torch.Size([1024])

In [24]:
# Stand-in features drawn from torch.rand, laid out as
# (batch, patches, frames, features).
x = torch.rand((1, 256, 8, 1024))

In [19]:
# Transformer over 1024-wide features whose final layer projects to 1536.
t = Transformer(
    dim=1024,
    depth=6,
    heads=8,
    dim_head=64,
    mlp_dim=2048,
    out_dim=1536,
)

In [25]:
# example of applying transformer across frames
patch = 0
res = x[:, patch, :, :]
x = t(res)
x.shape

torch.Size([1, 8, 1536])

In [28]:
# Reinterpret the transformer output as a single 3-channel 64x64 image patch
# (the higher-resolution patch). The element count must equal 1 * 3 * 64 * 64.
x = x.reshape((1, 3, 64, 64))
x.shape

torch.Size([1, 3, 64, 64])