In [5]:
import torch
from torch import nn

class PatchfyLayer(nn.Module):
    """Turn a batch of images into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size``
    applies one shared linear projection to every non-overlapping patch.
    The resulting feature map is then flattened into a token sequence.

    Args:
        image_size: Side length (in pixels) used to compute ``new_size`` and
            ``seq_length``. It is not otherwise enforced on the input.
        patch_size: Side length (in pixels) of each square patch.
        embed_dim: Number of values produced for each token.
        in_channels: Number of channels of the input images. Defaults to 3.

    Attributes:
        embed_dim: Same value as the ``embed_dim`` argument.
        new_size: Number of patches along each spatial axis.
        seq_length: Total number of tokens, ``new_size ** 2``.
    """

    def __init__(self, image_size, patch_size, embed_dim, in_channels=3):

        super().__init__()

        self.conv_proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        # Floor division keeps the counts as integers and matches the output
        # size of a convolution with kernel == stride: pixels that do not fill
        # a whole patch are dropped.
        new_size = image_size // patch_size
        self.embed_dim = embed_dim
        self.new_size = new_size
        # Sequence length (number of tokens/patches).
        self.seq_length = new_size ** 2

    def forward(self, x):
        """Project ``x`` into patch tokens.

        Args:
            x: Tensor of shape ``(bs, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(bs, num_patches, embed_dim)``.
        """
        # (bs, c, image_size, image_size) -> (bs, embed_dim, new_size, new_size)
        x = self.conv_proj(x)
        # (bs, embed_dim, new_size, new_size) -> (bs, embed_dim, new_size**2)
        x = x.flatten(2)
        # (bs, embed_dim, new_size**2) -> (bs, new_size**2, embed_dim)
        return x.transpose(1, 2)
    
    
# Batch of 8 random 3-channel 224x224 inputs.
x = torch.rand(8, 3, 224, 224)
pl = PatchfyLayer(image_size=224, patch_size=16, embed_dim=768)
tokens = pl(x)
# 8: batch size
# 196: number of tokens, i.e. the number of patches ((224 / 16) ** 2 = 14 * 14)
# 768: number of values associated with each token (embed_dim)
tokens.shape

torch.Size([8, 196, 768])

## Multiplicação matricial em batches

In [7]:
# Plain 2-D matrix product: (196, 768) @ (768, 64) -> (196, 64).
x, w = torch.rand(196, 768), torch.rand(768, 64)
y = x @ w
y.shape

torch.Size([196, 64])

In [8]:
# The leading (8, 12) dimensions act as batch dimensions: the same 2-D
# weight `w` multiplies every (196, 768) matrix in the batch.
x = torch.rand(8, 12, 196, 768)
y = x @ w
y.shape

torch.Size([8, 12, 196, 64])

In [9]:
# nn.Linear acts on the last dimension only, so it handles the batched
# input the same way the matmul above does.
linear = nn.Linear(in_features=768, out_features=64)
y = linear(x)
y.shape

torch.Size([8, 12, 196, 64])

## Operação de atenção

In [12]:
def attention(query, key, value, scale=None):
    """Dot-product attention over the last two dimensions.

    Args:
        query: Tensor of shape ``(..., n_queries, dim)``.
        key: Tensor of shape ``(..., n_keys, dim)``.
        value: Tensor of shape ``(..., n_keys, dim_v)``.
        scale: Optional factor multiplied into the scores before the softmax.
            ``None`` (the default) applies no scaling. Pass
            ``query.size(-1) ** -0.5`` for scaled dot-product attention,
            which keeps the softmax from saturating when ``dim`` is large.

    Returns:
        Tensor of shape ``(..., n_queries, dim_v)``: for each query, the
        average of ``value`` rows weighted by the softmax of its scores.
    """
    # Similarity of every query with every key: (..., n_queries, n_keys).
    scores = torch.matmul(query, key.transpose(-2, -1))
    if scale is not None:
        scores = scores * scale
    # Normalize over the keys so each query's weights sum to 1.
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value)

# Self-attention: the same tokens serve as queries, keys and values.
out = attention(query=tokens, key=tokens, value=tokens)
out.shape

torch.Size([8, 196, 768])