In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math

from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

In [2]:
torch.device('cuda')

device(type='cuda')

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [4]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm1(x + self.dropout(ff_output))
        return x

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        cross_attn = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm1(x + self.dropout(cross_attn))
        ff_output = self.feed_forward(x)
        x = self.norm1(x + self.dropout(ff_output))
        return x

In [None]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout, device):
        super(Transformer, self).__init__()
        self.device = device
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2).to(self.device)  # (batch, 1, 1, src_len)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(2).to(self.device) # (batch_size, 1, 1, tgt_len)
        seq_length = tgt.size(1)
        nopeak_mask = torch.tril(torch.ones(1, seq_length, seq_length)).bool().to(self.device)
        tgt_mask = tgt_mask & nopeak_mask.to(self.device) # (batch, 1, tgt_len, tgt_len)
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [9]:
pairs = [
    ("zero", "ноль"),
    ("one", "один"),
    ("two", "два"),
    ("three", "три"),
    ("four", "четыре"),
    ("five", "пять"),
    ("six", "шесть"),
    ("seven", "семь"),
    ("eight", "восемь"),
    ("nine", "девять"),
    ("ten", "десять"),
    ("eleven", "одинадцать"),
    ("twelve", "двенадцать"),
    ("thirteen", "тринадцать"),
    ("fourteen", "четырнадцать"),
    ("fifteen", "пятнадцать"),
    ("sixteen", "шестнадцать"),
    ("seventeen", "семнадцать"),
    ("eighteen", "восемнадцать"),
    ("nineteen", "девятнадцать"),
    ("twenty", "двадцать"),
    ("one two", "один два"),
    ("two three", "два три"),
    ("three four", "три четыре"),
    ("four five", "четыре пять"),
    ("five six", "пять шесть"),
    ("six seven", "шесть семь"),
    ("seven eight", "семь восемь"),
    ("eight nine", "восемь девять"),
    ("nine ten", "девять десять"),
    ("zero one", "ноль один"),
    ("ten eleven", "десять одинадцать"),
    ("eleven twelve", "одинадцать двенадцать"),
    ("twelve thirteen", "двенадцать тринадцать"),
    ("fifteen sixteen", "пятнадцать шестнадцать"),
    ("eighteen nineteen", "восемнадцать девятнадцать"),
    ("nineteen twenty", "девятнадцать двадцать"),
    ("one two three", "один два три"),
    ("two three four", "два три четыре"),
    ("three four five", "три четыре пять"),
    ("four five six", "четыре пять шесть"),
    ("five six seven", "пять шесть семь"),
    ("six seven eight", "шесть семь восемь"),
    ("seven eight nine", "семь восемь девять"),
    ("eight nine ten", "восемь девять десять"),
    ("zero one two", "ноль один два"),
    ("ten eleven twelve", "десять одинадцать двенадцать"),
    ("eleven twelve thirteen", "одинадцать двенадцать тринадцать"),
    ("eighteen nineteen twenty", "восемнадцать девятнадцать двадцать"),
    ("one two three four", "один два три четыре"),
    ("two three four five", "два три четыре пять"),
    ("three four five six", "три четыре пять шесть"),
    ("four five six seven", "четыре пять шесть семь"),
    ("five six seven eight", "пять шесть семь восемь"),
    ("six seven eight nine", "шесть семь восемь девять"),
    ("seven eight nine ten", "семь восемь девять десять"),
    ("zero one two three", "ноль один два три"),
    ("one three five", "один три пять"),
    ("two four six", "два четыре шесть"),
    ("one five ten", "один пять десять"),
    ("three six nine", "три шесть девять"),
    ("five ten fifteen", "пять десять пятнадцать"),
    ("ten fifteen twenty", "десять пятнадцать двадцать"),
    ("zero five ten", "ноль пять десять"),
    ("one ten twenty", "один десять двадцать"),
    ("two five eight", "два пять восемь"),
    ("three seven eleven", "три семь одинадцать"),
    ("one two three four five", "один два три четыре пять"),
    ("five six seven eight nine", "пять шесть семь восемь девять"),
    ("ten eleven twelve thirteen fourteen", "десять одинадцать двенадцать тринадцать четырнадцать"),
    ("zero one two three four", "ноль один два три четыре"),
]

In [10]:
def build_vocab(sentences):
    vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
    idx = 4
    for sent in sentences:
        for word in sent.split():
            if word not in vocab:
                vocab[word] = idx
                idx += 1
    return vocab

src_vocab = build_vocab([p[0] for p in pairs])
tgt_vocab = build_vocab([p[1] for p in pairs])

print("SRC vocab:", src_vocab)
print("TGT vocab:", tgt_vocab)

SRC vocab: {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3, 'zero': 4, 'one': 5, 'two': 6, 'three': 7, 'four': 8, 'five': 9, 'six': 10, 'seven': 11, 'eight': 12, 'nine': 13, 'ten': 14, 'eleven': 15, 'twelve': 16, 'thirteen': 17, 'fourteen': 18, 'fifteen': 19, 'sixteen': 20, 'seventeen': 21, 'eighteen': 22, 'nineteen': 23, 'twenty': 24}
TGT vocab: {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3, 'ноль': 4, 'один': 5, 'два': 6, 'три': 7, 'четыре': 8, 'пять': 9, 'шесть': 10, 'семь': 11, 'восемь': 12, 'девять': 13, 'десять': 14, 'одинадцать': 15, 'двенадцать': 16, 'тринадцать': 17, 'четырнадцать': 18, 'пятнадцать': 19, 'шестнадцать': 20, 'семнадцать': 21, 'восемнадцать': 22, 'девятнадцать': 23, 'двадцать': 24}


In [11]:
class TranslationDataset(Dataset):
    def __init__(self, pairs, src_vocab, tgt_vocab, max_len=100):
        self.pairs = pairs
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len

    def encode(self, sentence, vocab):
        tokens = [vocab.get(w, vocab["<unk>"]) for w in sentence.split()]
        tokens = [vocab["<sos>"]] + tokens + [vocab["<eos>"]]
        if len(tokens) < self.max_len:
            tokens += [vocab["<pad>"]] * (self.max_len - len(tokens))
        else:
            tokens = tokens[:self.max_len]
        return torch.tensor(tokens)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        src_tensor = self.encode(src, self.src_vocab)
        tgt_tensor = self.encode(tgt, self.tgt_vocab)
        return src_tensor, tgt_tensor

In [12]:
train_data = TranslationDataset(pairs, src_vocab, tgt_vocab, max_len=100)
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)

In [13]:
src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
d_model = 32
num_heads = 2
num_layers = 1
d_ff = 64
max_seq_length = 100
dropout = 0.0
device = "cuda"

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout, device).to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

for epoch in range(100):
    total_loss = 0
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)
        tgt_inp = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        optimizer.zero_grad()
        out = transformer(src, tgt_inp)
        loss = criterion(out.reshape(-1, out.size(-1)), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, loss={total_loss/len(train_loader):.4f}")

Epoch 1, loss=3.1034
Epoch 2, loss=2.6467
Epoch 3, loss=2.4087
Epoch 4, loss=2.2149
Epoch 5, loss=1.9929
Epoch 6, loss=1.7941
Epoch 7, loss=1.6069
Epoch 8, loss=1.3980
Epoch 9, loss=1.2081
Epoch 10, loss=1.0229
Epoch 11, loss=0.8579
Epoch 12, loss=0.7345
Epoch 13, loss=0.6068
Epoch 14, loss=0.5289
Epoch 15, loss=0.4796
Epoch 16, loss=0.3840
Epoch 17, loss=0.3311
Epoch 18, loss=0.2791
Epoch 19, loss=0.2279
Epoch 20, loss=0.1868
Epoch 21, loss=0.1604
Epoch 22, loss=0.1264
Epoch 23, loss=0.1038
Epoch 24, loss=0.0869
Epoch 25, loss=0.0783
Epoch 26, loss=0.0663
Epoch 27, loss=0.0831
Epoch 28, loss=0.0870
Epoch 29, loss=0.0617
Epoch 30, loss=0.0476
Epoch 31, loss=0.0336
Epoch 32, loss=0.0289
Epoch 33, loss=0.0226
Epoch 34, loss=0.0197
Epoch 35, loss=0.0167
Epoch 36, loss=0.0144
Epoch 37, loss=0.0130
Epoch 38, loss=0.0113
Epoch 39, loss=0.0104
Epoch 40, loss=0.0098
Epoch 41, loss=0.0075
Epoch 42, loss=0.0058
Epoch 43, loss=0.0050
Epoch 44, loss=0.0043
Epoch 45, loss=0.0036
Epoch 46, loss=0.00

In [14]:
def translate_sentence(model, sentence, src_vocab, tgt_vocab, max_len=10):
    model.eval()
    inv_tgt = {i: w for w, i in tgt_vocab.items()}

    # Преобразуем предложение в индексы и добавляем batch dimension
    src = train_data.encode(sentence, src_vocab).unsqueeze(0).to(device)  # (1, seq_len)
    
    # Начинаем с <sos>
    tgt = torch.tensor([[tgt_vocab["<sos>"]]], device=device)  # (1,1)

    for _ in range(max_len):
        # Получаем предсказания модели
        out = model(src, tgt)  # (1, seq_len, vocab_size)

        # Берем токен с максимальной вероятностью в последнем шаге
        next_token = out[:, -1, :].argmax(dim=-1, keepdim=True)  # (1,1)
        
        # Добавляем предсказанный токен к последовательности
        tgt = torch.cat([tgt, next_token], dim=1)  # (1, seq_len+1)
        
        # Если предсказан <eos>, выходим
        if next_token.item() == tgt_vocab["<eos>"]:
            break

    # Преобразуем индексы в слова, игнорируем специальные токены
    translated_tokens = [inv_tgt[idx.item()] for idx in tgt[0] 
                         if idx.item() not in [tgt_vocab["<sos>"], tgt_vocab["<pad>"], tgt_vocab["<eos>"]]]

    return " ".join(translated_tokens)

In [15]:
print(translate_sentence(transformer, "one", src_vocab, tgt_vocab))
print(translate_sentence(transformer, "one two three", src_vocab, tgt_vocab))
print(translate_sentence(transformer, "nineteen twenty", src_vocab, tgt_vocab))

один
один два три
девятнадцать двадцать
