### Atenção de canais



In [1]:
from functools import partial
from typing import Callable

import torch
from torch import nn

class SqueezeExcitation(nn.Module):
    """Channel-attention gate in the squeeze-and-excitation style.

    The input is averaged over its spatial dimensions, passed through two
    1x1 convolutions (ReLU in between, sigmoid at the end), and the resulting
    per-channel weights multiply the original input.

    Args:
        in_channels: Number of channels of the input feature map.
        squeeze_channels: Number of channels between the two 1x1
            convolutions. When ``None`` it is set to ``in_channels``.
    """

    def __init__(self, in_channels, squeeze_channels=None):
        super().__init__()

        bottleneck = in_channels if squeeze_channels is None else squeeze_channels

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, bottleneck, kernel_size=1)
        self.activation = nn.ReLU()
        self.fc2 = nn.Conv2d(bottleneck, in_channels, kernel_size=1)
        self.scale_activation = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` with every channel multiplied by its computed gate."""
        # Squeeze: one value per channel (bs x C x 1 x 1 for a bs x C x H x W input).
        pooled = self.avgpool(x)
        # Excitation: the 1x1 convolutions mix information across channels.
        hidden = self.activation(self.fc1(pooled))
        # The sigmoid keeps every gate in the interval [0, 1].
        gate = self.scale_activation(self.fc2(hidden))
        # Rescale the channels; the gate broadcasts over the spatial dimensions.
        return gate * x
    
# Quick shape check: a 16-channel module with 4 squeeze channels applied to a
# random batch of 8 feature maps of size 64 x 64.
se = SqueezeExcitation(16, 4)
y = se(torch.rand(8, 16, 64, 64))
# Last expression of the cell: the output shape is what gets displayed.
y.shape

torch.Size([8, 16, 64, 64])

### Atenção espacial

In [2]:
class PatchifyLayer(nn.Module):
    """Split an image into non-overlapping patches and embed each one as a token.

    Args:
        image_size: Side length of the (square) input image. It is expected
            to be divisible by `patch_size`; `new_size` is computed with
            floor division.
        patch_size: Side length of each square patch.
        hidden_dim: Number of features of every output token.
        in_channels: Number of channels of the input image. Defaults to 3,
            the value that used to be hard-coded.

    Attributes:
        new_size: Number of patches along one side (`image_size // patch_size`).
        seq_length: Total number of patches per image (`new_size ** 2`).
    """

    def __init__(self, image_size=224, patch_size=16, hidden_dim=768, in_channels=3):
        super().__init__()

        # A convolution whose kernel and stride both equal `patch_size` maps
        # every patch_size x patch_size region of the image to one 1 x 1
        # position with `hidden_dim` channels.
        self.conv_proj = nn.Conv2d(
            in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        new_size = image_size // patch_size

        self.new_size = new_size
        self.seq_length = new_size ** 2

    def forward(self, x):
        """Map `(bs, in_channels, H, W)` images to `(bs, seq_length, hidden_dim)` tokens."""
        # (bs, in_channels, image_size, image_size) -> (bs, hidden_dim, new_size, new_size)
        x = self.conv_proj(x)
        # (bs, hidden_dim, new_size, new_size) -> (bs, hidden_dim, new_size * new_size)
        x = x.flatten(start_dim=2)
        # Move the spatial dimension to second place, because attention layers
        # conventionally take bs x seq_length x hidden_dim.
        return x.transpose(1, 2)

# Quick shape check with the default settings (224 x 224 images, 16 x 16
# patches, 768 features per token) on a random batch of 8 RGB images.
x = torch.rand(8, 3, 224, 224)
pl = PatchifyLayer()
y = pl(x)
# Last expression of the cell: the token tensor's shape is displayed.
y.shape

torch.Size([8, 196, 768])

Cada imagem em um batch é representada por 196 tokens (14 × 14 patches, pois 224/16 = 14), cada um possuindo 768 atributos.

- Artigo original do Transformer ("Attention Is All You Need"): d_model = 512, heads = 8
- ViT (Base): d_model/embed_dim = 768, heads = 12

In [3]:
def attention(query, key, value):
    """Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value``, where ``d_k`` is
    the size of the last dimension of ``query``. The transpose swaps the last
    two dimensions of ``key`` and the softmax runs over the last dimension of
    the score matrix.

    Args:
        query: Tensor whose last two dimensions are (n_query, d_k).
        key: Tensor whose last two dimensions are (n_key, d_k).
        value: Tensor whose last two dimensions are (n_key, d_v).

    Returns:
        Tensor whose last two dimensions are (n_query, d_v).
    """
    scale = query.shape[-1] ** 0.5
    # Similarity of every query with every key, scaled by sqrt(d_k).
    logits = (query @ key.transpose(-2, -1)) / scale
    # Each row of weights sums to 1 over the keys.
    weights = torch.softmax(logits, dim=-1)
    return weights @ value

# Self-attention shape check: the same random tensor (16 sequences of 196
# tokens with 768 features) is used as query, key and value.
x = torch.rand(16, 196, 768)
out = attention(x, x, x)
# Last expression of the cell: the output shape is displayed.
out.shape

torch.Size([16, 196, 768])

In [4]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on top of the `attention` function.

    Query, key and value are each projected by their own linear layer, split
    into `heads` chunks of size `d_k = d_model // heads`, attended
    independently per head, concatenated again and passed through a final
    linear layer.

    Args:
        heads: Number of attention heads.
        d_model: Number of features of every token; must be divisible by
            `heads`.

    Raises:
        ValueError: If `d_model` is not divisible by `heads`.
    """

    def __init__(self, heads=12, d_model=768):
        super().__init__()

        # Fail early: with a floored d_k, heads * d_k != d_model and the
        # reshapes below would either error out or silently mix features.
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )

        # We assume d_k == d_v.
        d_k = d_model // heads
        self.heads = heads
        self.d_k = d_k
        self.proj_query = nn.Linear(d_model, d_model)
        self.proj_key = nn.Linear(d_model, d_model)
        self.proj_value = nn.Linear(d_model, d_model)
        self.final = nn.Linear(d_model, d_model)

    def proj_and_reshape(self, layer, x):
        """Apply a linear layer and reshape the result for `attention`.

        Args:
            layer: One of the projection layers of this module.
            x: Tensor of shape bs x n x d_model.

        Returns:
            Tensor of shape bs x heads x n x d_k.
        """
        nbatches = x.shape[0]
        # The product of x with the layer weight has dimensions
        # (bs x n x d_model) times (d_model x heads*d_k).
        # Every token of size d_model is multiplied by each column of the
        # linear layer, which is equivalent to applying `heads` separate
        # layers of size d_model x d_k and concatenating their outputs.
        x = layer(x)
        # View the result as bs x heads x n x d_k so that `attention` can
        # operate on the trailing n x d_k dimensions of every head.
        x = x.view(nbatches, -1, self.heads, self.d_k).transpose(1, 2)

        return x

    def forward(self, query, key, value):
        """Return the attended tensor of shape bs x n x d_model."""
        nbatches = query.shape[0]
        query_proj = self.proj_and_reshape(self.proj_query, query)
        key_proj = self.proj_and_reshape(self.proj_key, key)
        value_proj = self.proj_and_reshape(self.proj_value, value)

        x = attention(query_proj, key_proj, value_proj)
        # Reshape from bs x heads x n x d_k back to bs x n x d_model;
        # `contiguous` is needed because `view` cannot follow the transpose.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.heads * self.d_k)

        return self.final(x)
    
# Shape check for the hand-written multi-head attention with its default
# settings (12 heads, d_model=768), used as self-attention on random tokens.
x = torch.rand(8, 196, 768)
mha = MultiHeadedAttention()
y = mha(x, x, x)
# Last expression of the cell: the output shape is displayed.
y.shape


torch.Size([8, 196, 768])

In [43]:
# Same computation with PyTorch's built-in layer. This rebinds `mha` (the
# hand-written module from the previous cell) and reuses that cell's `x`.
mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
# NOTE(review): the built-in layer returns a tuple (EncoderBlock below unpacks
# it as `x, _ = ...`), so `y` is presumably (output, weights), not a tensor.
y = mha(x, x, x)

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Args:
        in_channels: Number of input (and output) features.
        hidden_channels: Number of features of the hidden layer.
        dropout: Dropout probability applied after the activation and after
            the output layer. The default of 0.0 disables dropout.
    """

    def __init__(self, in_channels=768, hidden_channels=3072, dropout=0.0):
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise setting `self.layers` raises an AttributeError.
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, in_channels),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the network to the last dimension of `x`."""
        return self.layers(x)

class EncoderBlock(nn.Module):
    """Transformer encoder block.

    Each of the two sub-blocks normalises its input, transforms it and adds
    the result back to the un-normalised input:

    1. layer norm -> multi-head self-attention -> dropout -> residual add
    2. layer norm -> MLP -> residual add

    Args:
        num_heads: Number of attention heads.
        hidden_dim: Number of features of every token.
        mlp_dim: Hidden size passed to `MLP`.
        dropout: Dropout probability used after the attention layer; it is
            also passed to `MLP` as its third positional argument.
        attention_dropout: Dropout probability passed to
            `nn.MultiheadAttention`.
        norm_layer: Callable that receives `hidden_dim` and returns the
            normalisation module. Defaults to `nn.LayerNorm` with `eps=1e-6`.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        # NOTE(review): `MLP` must accept (in_channels, hidden_channels, dropout)
        # positionally for this call to succeed.
        self.mlp = MLP(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (batch_size, seq_length, hidden_dim) tensor."""
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        # Attention sub-block: the normalised input is query, key and value.
        x = self.ln_1(input)
        # The layer returns a pair; the second element is discarded.
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        # MLP sub-block, again with a residual connection.
        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y
    
# Instantiate the MLP with its default sizes.
mlp = MLP()