# Transformer

In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

## 1. Механизм внимания (Scaled Dot-Product Attention)

In [6]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``."""

    def __init__(self, d_k):
        """Args:
        d_k (int): Key dimension; attention scores are divided by sqrt(d_k).
        """
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
        """Compute attention-weighted values.

        Args:
            Q: Query tensor; its last dimension must match K's last dimension.
            K: Key tensor.
            V: Value tensor; its second-to-last dimension must match K's.
            mask: Optional tensor broadcastable to the score shape
                ``(..., len_q, len_k)``. Positions where ``mask == 0`` receive
                (numerically) zero attention weight.

        Returns:
            Tuple ``(output, attention)``: ``attention`` is the softmax over
            the last (key) dimension, and ``output = attention @ V``.
        """
        # Similarity of every query with every key. Scaling by sqrt(d_k) keeps
        # the softmax out of its saturated region for large d_k.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Give masked positions the most negative representable score so their
        # softmax weight is zero. A fixed -1e9 would overflow in float16.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention = F.softmax(scores, dim=-1)

        # Weighted sum of the values.
        output = torch.matmul(attention, V)
        return output, attention

## Multi-head attention

In [7]:
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, merge the heads."""

    def __init__(self, d_model, num_heads):
        """Args:
        d_model (int): Model (embedding) dimension.
        num_heads (int): Number of attention heads; must divide ``d_model``.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``num_heads``.
        """
        super(MultiHeadAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        # Per-head feature size.
        self.d_k = d_model // num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.fc_out = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, d_k)``."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q: Queries, shape ``(batch, len_q, d_model)``.
            K: Keys, shape ``(batch, len_k, d_model)``.
            V: Values, shape ``(batch, len_k, d_model)``.
            mask: Optional mask broadcastable to the per-head score shape
                ``(batch, num_heads, len_q, len_k)``. Zeros are masked out.

        Returns:
            Tensor of shape ``(batch, len_q, d_model)``.
        """
        batch_size = Q.shape[0]

        # Each input gets its own learned projection before being split into heads.
        Q = self._split_heads(self.query(Q), batch_size)
        K = self._split_heads(self.key(K), batch_size)
        V = self._split_heads(self.value(V), batch_size)

        # Attention is computed independently for every head.
        attn_output, _ = ScaledDotProductAttention(self.d_k)(Q, K, V, mask)

        # Merge the heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model).
        attn_output = (
            attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        )
        return self.fc_out(attn_output)

## Positional encoding

In [8]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to token embeddings.

    For position ``pos`` and column pair ``(2i, 2i+1)``::

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    """

    def __init__(self, d_model, max_length=5000):
        """Args:
        d_model (int): Embedding dimension (odd values are supported).
        max_length (int): Longest sequence covered by the precomputed table.
        """
        super(PositionalEncoding, self).__init__()

        self.d_model = d_model

        # Encoding table of shape (max_length, d_model).
        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column,
        # so trim div_term to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading batch axis so the table broadcasts over the batch.
        pe = pe.unsqueeze(0)
        # A buffer is stored in state_dict and follows .to()/.cuda(),
        # but it is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Token embeddings, expected shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` is longer than ``max_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

# Feed forward

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP over the last dimension: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        """Args:
        d_model (int): Input and output feature size.
        d_ff (int): Width of the hidden layer.
        dropout (float): Dropout probability applied to the hidden activations.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, then project back to ``d_model``."""
        hidden = torch.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)


# Add & Norm Layer
Apply residual connections and normalize the outputs.

In [None]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization (post-LN "Add & Norm").

    Computes ``LayerNorm(x + Dropout(sublayer(x)))``. The residual sum itself
    is normalized, so activations stay bounded as layers are stacked.
    """

    def __init__(self, d_model, dropout=0.1):
        """Args:
        d_model (int): Feature size normalized by LayerNorm (last dimension).
        dropout (float): Dropout probability applied to the sublayer output.
        """
        super(AddNorm, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Args:
        x: Input tensor whose last dimension is ``d_model``.
        sublayer: Callable mapping ``x`` to a tensor of the same shape.

        Returns:
            ``LayerNorm(x + Dropout(sublayer(x)))``.
        """
        return self.norm(x + self.dropout(sublayer(x)))


# Encoder layer

Combine the components to build the encoder layer, which consists of multi-head attention, Add & Norm, and feed-forward networks.

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: a self-attention sublayer then a feed-forward sublayer,
    each wrapped in an ``AddNorm`` block."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Args:
        d_model (int): Model (embedding) dimension.
        num_heads (int): Number of attention heads.
        d_ff (int): Hidden width of the feed-forward network.
        dropout (float): Dropout probability used by the sublayers.
        """
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model, dropout)
        self.add_norm2 = AddNorm(d_model, dropout)

    def forward(self, x, mask=None):
        """Run ``x`` through both sublayers; ``mask`` is forwarded to the attention."""

        def self_attention(h):
            # Queries, keys and values all come from the same input.
            return self.attention(h, h, h, mask)

        attended = self.add_norm1(x, self_attention)
        return self.add_norm2(attended, self.ffn)


# Decoder layer

In [None]:
class DecoderLayer(nn.Module):
    """One decoder layer with three sublayers, each wrapped in ``AddNorm``.

    1. Masked self-attention over the decoder input (``attention1``).
    2. Encoder-decoder attention over ``enc_output`` (``attention2``).
    3. Position-wise feed-forward network (``ffn``).
    """

    def __init__(self, d_model, num_heads, dim_ff, dropout=0.1):
        """Args:
        d_model (int): Model (embedding) dimension.
        num_heads (int): Number of attention heads.
        dim_ff (int): Hidden width of the feed-forward network.
        dropout (float): Dropout probability used by the sublayers.
        """
        super(DecoderLayer, self).__init__()
        # Attribute names are kept so state_dict keys stay stable:
        # attention1 = decoder self-attention, attention2 = encoder-decoder attention.
        self.attention1 = MultiHeadAttention(d_model, num_heads)
        self.attention2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, dim_ff, dropout)
        self.add_norm1 = AddNorm(d_model, dropout)
        self.add_norm2 = AddNorm(d_model, dropout)
        self.add_norm3 = AddNorm(d_model, dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Args:
        x: Decoder input.
        enc_output: Encoder output, used as keys and values in attention2.
        src_mask: Optional mask passed to the encoder-decoder attention.
        tgt_mask: Optional mask passed to the decoder self-attention
            (typically a causal mask — confirm with the caller).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Masked multi-head self-attention over the decoder input.
        x = self.add_norm1(x, lambda h: self.attention1(h, h, h, tgt_mask))

        # Encoder-decoder attention: queries from the decoder, keys/values
        # from the encoder output.
        x = self.add_norm2(
            x, lambda h: self.attention2(h, enc_output, enc_output, src_mask)
        )

        # Feed-forward sublayer.
        x = self.add_norm3(x, self.ffn)

        return x


# Encoder

The encoder consists of multiple encoder layers stacked together.