In [43]:
import torch
import torch.nn as nn
import numpy as np

from torch.nn import functional as F

In [54]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention followed by an output projection and dropout.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        p_dropout: dropout probability applied to the projected output.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """
    def __init__(self, d_model, num_heads, p_dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        # Fail fast: otherwise the per-head reshape in forward() fails with an opaque error.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.proj = nn.Linear(d_model, d_model)

        self.d_model = d_model
        self.num_heads = num_heads
        self.heads_dim = d_model // num_heads

        self.dropout = nn.Dropout(p_dropout)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, d_model) -> (batch, num_heads, seq, heads_dim)."""
        return t.reshape(batch_size, seq_len, self.num_heads, self.heads_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries, (batch, s_seq_len, d_model).
            k: keys, (batch, t_seq_len, d_model).
            v: values, (batch, t_seq_len, d_model).
            mask: optional tensor broadcastable to (batch, num_heads, s_seq_len, t_seq_len);
                scores where ``mask == 0`` are set to -inf before the softmax. A query row
                with every key masked yields NaN.

        Returns:
            Tensor of shape (batch, s_seq_len, d_model).
        """
        _, s_seq_len, _ = q.shape
        batch_size, t_seq_len, d_model = k.shape

        query = self._split_heads(self.query(q), batch_size, s_seq_len)
        key = self._split_heads(self.key(k), batch_size, t_seq_len)
        value = self._split_heads(self.value(v), batch_size, t_seq_len)

        # (batch, heads, s_seq_len, t_seq_len), scaled by sqrt(head dim).
        scores = (query @ key.transpose(2, 3)) / np.sqrt(self.heads_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        # Merge heads back: (batch, heads, s, hd) -> (batch, s, heads * hd).
        attention = (weights @ value).transpose(1, 2).reshape(batch_size, s_seq_len, d_model)

        return self.dropout(self.proj(attention))

In [55]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings, then apply dropout.

        PE[pos, 2j]     = sin(pos / 10000^(2j / d_model))
        PE[pos, 2j + 1] = cos(pos / 10000^(2j / d_model))

    Args:
        d_model: embedding size (last dimension of the input).
        p_dropout: dropout probability applied after the addition.
        max_length: longest sequence length supported by the precomputed table.
    """
    def __init__(self, d_model, p_dropout=0.1, max_length=5000):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p_dropout)

        # Computed in float64 for accuracy at large positions, stored as default float dtype.
        position = torch.arange(max_length, dtype=torch.float64).unsqueeze(1)  # (max_length, 1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float64)            # 2j values
        angles = position / (10_000.0 ** (even_dims / d_model))                 # (max_length, ceil(d/2))

        pe = torch.zeros(1, max_length, d_model)
        pe[0, :, 0::2] = torch.sin(angles).float()
        # For odd d_model there is one fewer odd column than even column.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2]).float()

        # A buffer (not a Parameter) moves with .to()/.cuda() but is not trained;
        # persistent=False keeps it out of state_dict, as before.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """x: (batch, seq_len, d_model) -> same shape with positional encodings added."""
        _, sequence_length, _ = x.shape
        if sequence_length > self.pe.shape[1]:
            raise ValueError(
                f"sequence length {sequence_length} exceeds max_length {self.pe.shape[1]}"
            )

        return self.dropout(x + self.pe[:, :sequence_length, :])

In [56]:
class TransformerEncoder(nn.Module):
    """A single post-norm Transformer encoder layer, preceded by positional encoding.

    Sub-layers, each followed by a residual connection and LayerNorm:
      1. multi-head self-attention,
      2. position-wise feed-forward network (d_model -> 4*d_model -> d_model, ReLU).

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        p_dropout: dropout probability used by every sub-layer.
    """
    def __init__(self, d_model, num_heads, p_dropout=0.1):
        super(TransformerEncoder, self).__init__()

        self.pe = PositionalEncoding(d_model, p_dropout)
        self.mha = MultiHeadAttention(d_model, num_heads, p_dropout)
        self.ln_mha = nn.LayerNorm(d_model)

        self.ff1 = nn.Linear(d_model, d_model * 4)
        self.relu = nn.ReLU()
        self.ff2 = nn.Linear(d_model * 4, d_model)
        # MultiHeadAttention applies dropout to its own output; give the
        # feed-forward sub-layer the same treatment so p_dropout applies to both.
        self.ff_dropout = nn.Dropout(p_dropout)
        self.ln_ff = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode embeddings ``x`` of shape (batch, seq_len, d_model); returns the same shape.

        ``mask`` (optional) is forwarded to self-attention; it must broadcast to
        (batch, num_heads, seq_len, seq_len) and entries == 0 are blocked.
        """
        x = self.pe(x)

        # Self-attention sub-layer.
        out = self.mha(x, x, x, mask)
        x = self.ln_mha(out + x)

        # Feed-forward sub-layer.
        out = self.ff2(self.relu(self.ff1(x)))
        x = self.ln_ff(self.ff_dropout(out) + x)

        return x

In [57]:
class TransformerDecoder(nn.Module):
    """A single post-norm Transformer decoder layer, preceded by positional encoding.

    Sub-layers, each followed by a residual connection and LayerNorm:
      1. causally masked self-attention over the target sequence,
      2. cross-attention from the target (queries) to ``x`` (keys/values),
      3. position-wise feed-forward network (d_model -> 4*d_model -> d_model, ReLU).

    Args:
        d_model: model width.
        nun_heads: number of attention heads. (Misspelling of ``num_heads``; the
            name is kept so existing keyword callers keep working.)
        p_dropout: dropout probability used by every sub-layer.
    """
    def __init__(self, d_model, nun_heads, p_dropout=0.1):
        super(TransformerDecoder, self).__init__()

        self.pe = PositionalEncoding(d_model, p_dropout)
        self.mmha = MultiHeadAttention(d_model, nun_heads, p_dropout)
        self.ln_mmha = nn.LayerNorm(d_model)

        self.mha = MultiHeadAttention(d_model, nun_heads, p_dropout)
        self.ln_mha = nn.LayerNorm(d_model)

        self.ff1 = nn.Linear(d_model, d_model * 4)
        self.relu = nn.ReLU()
        self.ff2 = nn.Linear(d_model * 4, d_model)
        # Attention sub-layers apply dropout internally; do the same for feed-forward.
        self.ff_dropout = nn.Dropout(p_dropout)
        self.ln_ff = nn.LayerNorm(d_model)

    def forward(self, x, y, src_mask=None):
        """Decode target embeddings ``y`` against ``x``.

        Args:
            x: encoder output used as keys/values for cross-attention,
                presumably (batch, src_len, d_model).
            y: target embeddings, (batch, tgt_len, d_model).
            src_mask: optional mask for the cross-attention scores, broadcastable to
                (batch, num_heads, tgt_len, src_len); entries == 0 are blocked.

        Returns:
            Tensor shaped like ``y``.
        """
        _, t_seq_len, _ = y.shape

        y = self.pe(y)

        # Lower-triangular mask: target position i may only attend to positions <= i.
        # Created on y's device so the layer also works on GPU.
        causal_mask = torch.tril(torch.ones(t_seq_len, t_seq_len, device=y.device))
        out = self.mmha(y, y, y, causal_mask)
        # Residual connection around the masked self-attention.
        y = self.ln_mmha(out + y)

        # Cross-attention: queries from the target, keys/values from x.
        out = self.mha(y, x, x, src_mask)
        y = self.ln_mha(out + y)

        # Feed-forward sub-layer.
        out = self.ff2(self.relu(self.ff1(y)))
        y = self.ln_ff(self.ff_dropout(out) + y)

        return y

In [58]:
class Transformer(nn.Module):
    """Minimal encoder-decoder Transformer: one encoder layer and one decoder layer.

    Args:
        d_model: model width.
        vocab_size: vocabulary size for source tokens, target tokens and output logits.
        num_heads: number of attention heads.
        p_dropout: dropout probability passed to the encoder and decoder.

    forward(src, tgt) takes integer token ids of shape (batch, src_len) and
    (batch, tgt_len) and returns logits of shape (batch, tgt_len, vocab_size).
    """
    def __init__(self, d_model, vocab_size, num_heads, p_dropout=0.1):
        super(Transformer, self).__init__()

        # Separate embedding tables for source and target tokens.
        self.enc_embedding = nn.Embedding(vocab_size, d_model)
        self.dec_embedding = nn.Embedding(vocab_size, d_model)

        # Forward p_dropout so the caller's setting actually reaches the sub-layers.
        self.encoder = TransformerEncoder(d_model, num_heads, p_dropout)
        self.decoder = TransformerDecoder(d_model, num_heads, p_dropout)

        # Projects decoder states to per-token vocabulary logits.
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        memory = self.encoder(self.enc_embedding(src))
        hidden = self.decoder(memory, self.dec_embedding(tgt))
        return self.proj(hidden)

In [59]:
d_model = 16
vocab_size = 10000
num_heads = 2
seq_length = 10
batch_size = 8

# Seed the RNG so the model's initial weights and the random inputs are reproducible.
torch.manual_seed(0)

# Transformer instance
model = Transformer(d_model, vocab_size, num_heads)

# Generate random token ids for the source (X) and target (y) sequences
X = torch.randint(0, vocab_size, (batch_size, seq_length))
y = torch.randint(0, vocab_size, (batch_size, seq_length))

# Pass through the model (no gradients needed for a shape check)
with torch.no_grad():
    output = model(X, y)

# Expect per-token logits: (batch_size, target length, vocab_size)
assert output.shape == (batch_size, seq_length, vocab_size)
print(output.shape)

torch.Size([8, 10, 10000])
