# Генератор текста на базе Transformer

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import math
from tqdm import tqdm
from typing import Optional, Tuple

from tokenizers import Tokenizer


In [3]:
# Load the serialized tokenizer and register the special tokens used below.
tokenizer = Tokenizer.from_file("mistral_tokenizer.json")
tokenizer.add_special_tokens(["<pad>", "<s>", "</s>"])

# Look up the ids of the padding, begin-of-sequence and end-of-sequence
# tokens in a single pass.
pad_id, bos_id, eos_id = map(tokenizer.token_to_id, ("<pad>", "<s>", "</s>"))

# Queried after add_special_tokens() so the size matches the ids above.
vocab_size = tokenizer.get_vocab_size()

In [4]:
class TextDataset(Dataset):
    """Non-overlapping next-token-prediction samples cut from one text.

    The text is tokenized once with ``tokenizer.encode(text).ids`` and split
    into consecutive windows of ``max_length`` tokens. Each sample is a pair
    ``(input_ids, target_ids)`` of ``torch.long`` tensors of length
    ``max_length`` where ``target_ids`` is ``input_ids`` shifted one token
    to the right. Trailing tokens that cannot fill a complete window plus its
    one-token look-ahead are dropped.

    Args:
        text: Raw text to tokenize.
        tokenizer: Object whose ``encode(text)`` result exposes an ``ids``
            sequence of integer token ids.
        max_length: Number of tokens per sample; must be positive.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """

    def __init__(self, text: str, tokenizer, max_length: int = 128):
        if max_length <= 0:
            raise ValueError("max_length must be a positive integer")

        self.tokenizer = tokenizer
        self.max_length = max_length

        tokens = tokenizer.encode(text).ids

        # A window starting at i consumes tokens[i : i + max_length + 1]
        # (inputs plus the shifted target), so the last valid start is
        # len(tokens) - max_length - 1. range() excludes its stop value,
        # hence the stop is len(tokens) - max_length; stopping one earlier
        # would silently drop the final complete window.
        stop = len(tokens) - max_length
        self.samples = [
            (
                torch.tensor(tokens[i:i + max_length], dtype=torch.long),
                torch.tensor(tokens[i + 1:i + max_length + 1], dtype=torch.long),
            )
            for i in range(0, stop, max_length)
        ]

    def __len__(self):
        """Return the number of complete windows."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` pair at position ``idx``."""
        return self.samples[idx]


In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as
    the buffer ``pe`` (not a trainable parameter). Even feature columns hold
    sines and odd columns hold cosines of ``position * div_term``.

    Args:
        d_model: Feature size of the embeddings. Both even and odd sizes
            are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # (max_len, ceil(d_model / 2)): one angle per sine column.
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, sequence, feature)``; the table slice
        is broadcast over the batch dimension.
        """
        x = x + self.pe[:, :x.size(1), :]
        return x


class Embedding(nn.Module):
    """Token-id lookup followed by additive positional encoding.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Feature size of each embedding vector.
        pad_index: Token id passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, d_model, pad_index):
        super().__init__()
        self.token_embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=pad_index,
        )
        self.pos_embed = PositionalEncoding(d_model)

    def forward(self, x):
        """Embed the token ids in ``x`` and add the position signal."""
        return self.pos_embed(self.token_embed(x))


In [6]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from one fused linear projection of the
    input; the per-head results are concatenated and passed through an
    output projection.

    Args:
        d_model: Feature size of the input; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Fused projection: the last dimension is later split into q, k, v.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(batch, seq, d_model)``.

        Positions where ``mask == 0`` receive a score of ``-inf`` before the
        softmax. The result has the same shape as ``x``.
        """
        batch, seq_len, width = x.size()
        per_head_shape = (batch, seq_len, self.num_heads, self.d_head)

        # Split the fused projection and move heads ahead of the sequence
        # axis: each of query/key/value is (batch, heads, seq, d_head).
        query, key, value = (
            part.view(per_head_shape).transpose(1, 2)
            for part in self.qkv_proj(x).chunk(3, dim=-1)
        )

        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)

        # Merge the heads back into a single feature dimension.
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(merged)


class DecoderLayer(nn.Module):
    """One pre-norm Transformer block: self-attention, then feed-forward.

    Layer normalization is applied to the input of each sub-layer, and the
    sub-layer output goes through dropout before being added back to the
    residual stream. The same dropout module serves both sub-layers.

    Args:
        d_model: Feature size of the residual stream.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.attn = MultiheadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        feed_forward_stack = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        ]
        self.ff = nn.Sequential(*feed_forward_stack)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply both residual sub-layers to ``x`` using attention ``mask``."""
        attended = self.attn(self.norm1(x), mask)
        x = x + self.dropout(attended)

        transformed = self.ff(self.norm2(x))
        return x + self.dropout(transformed)


In [7]:
class GeneratorTransformer(nn.Module):
    """Decoder-only Transformer language model with a text sampling helper.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows and output
            logits).
        d_model: Feature size of the residual stream.
        num_heads: Attention heads per decoder layer.
        d_ff: Hidden size of each feed-forward network.
        num_layers: Number of stacked decoder layers.
        dropout: Dropout probability inside the decoder layers.
        max_len: Stored on the instance as ``max_len``; not otherwise used
            by this class.
        pad_index: Id of the padding token (passed to the embedding).
        eos_index: Id of the end-of-sequence token; stops ``generate``.
        tokenizer: Object providing ``encode(text).ids`` and ``decode(ids)``;
            required only by ``generate``.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=8, d_ff=1024,
                 num_layers=6, dropout=0.1, max_len=128, pad_index=0, eos_index=2, tokenizer=None):
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model, pad_index)
        self.decoder = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

        self.max_len = max_len
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.tokenizer = tokenizer

    def make_mask(self, x):
        """Return a boolean lower-triangular mask of shape ``(1, 1, T, T)``.

        ``T`` is ``x.size(1)``. Entry ``[i, j]`` is True when ``j <= i``, so
        a position can only attend to itself and earlier positions.
        """
        T = x.size(1)
        return torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0).bool()

    def forward(self, x):
        """Map token ids ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``."""
        mask = self.make_mask(x)
        x = self.embedding(x)
        for layer in self.decoder:
            x = layer(x, mask)
        x = self.norm(x)
        return self.output(x)

    def generate(self, prompt, context_len=50, temperature=1.0, max_out_tokens=100):
        """Continue ``prompt`` token by token and return the decoded text.

        Args:
            prompt: Text to continue; must encode to at least one token.
            context_len: Only the last ``context_len`` tokens are fed to the
                model at each step.
            temperature: Logits are divided by this value before sampling.
                ``0`` selects the most likely token at every step (greedy
                decoding) instead of dividing by zero.
            max_out_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded prompt followed by the generated tokens (including
            the end-of-sequence token if one was produced).

        Raises:
            ValueError: If no tokenizer is attached, ``temperature`` is
                negative, or the prompt encodes to zero tokens.
        """
        if self.tokenizer is None:
            raise ValueError("generate() requires a tokenizer; pass tokenizer=... to the constructor")
        if temperature < 0:
            raise ValueError("temperature must be non-negative")

        prompt_ids = self.tokenizer.encode(prompt).ids
        if not prompt_ids:
            raise ValueError("prompt must encode to at least one token")

        # Remember the current mode so that calling generate() in the middle
        # of training does not leave dropout disabled afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                device = next(self.parameters()).device
                generated = torch.tensor([prompt_ids], dtype=torch.long, device=device)

                for _ in range(max_out_tokens):
                    context = generated[:, -context_len:]
                    next_token_logits = self.forward(context)[:, -1, :]

                    if temperature == 0:
                        # Greedy decoding: the limit of sampling as the
                        # temperature goes to zero.
                        next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                    else:
                        probs = F.softmax(next_token_logits / temperature, dim=-1)
                        next_token = torch.multinomial(probs, 1)

                    generated = torch.cat([generated, next_token], dim=1)

                    if next_token.item() == self.eos_index:
                        break
        finally:
            self.train(was_training)

        return self.tokenizer.decode(generated[0].tolist())


In [19]:
def train(model, dataloader, optimizer, device, epochs=3):
    """Train ``model`` on next-token prediction and print the loss per epoch.

    Args:
        model: Module mapping a batch of token ids to per-position logits;
            must expose ``pad_index``, which is excluded from the loss.
        dataloader: Iterable yielding ``(inputs, targets)`` batches of token
            ids.
        optimizer: Optimizer already bound to the model's parameters.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of full passes over ``dataloader``.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    criterion = nn.CrossEntropyLoss(ignore_index=model.pad_index)

    for epoch in range(epochs):
        # Re-assert training mode every epoch in case the model was switched
        # to eval mode between epochs.
        model.train()
        total_loss = 0.0
        num_batches = 0

        for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            logits = model(x)
            # Flatten batch and sequence so every position is one
            # classification example.
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Counting batches avoids a ZeroDivisionError on an empty loader
        # and does not require the loader to implement __len__.
        if num_batches == 0:
            raise ValueError("dataloader produced no batches; nothing to train on")

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")


In [20]:
# Load the data and train the model
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The corpus file is read with the windows-1251 (Cyrillic) codec.
with open("war_and_peace.txt", "r", encoding="windows-1251") as f:
    text = f.read()

# Cut the text into 32-token samples and iterate them one sample per batch,
# in shuffled order.
dataset = TextDataset(text, tokenizer=tokenizer, max_length=32)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# All architecture hyperparameters keep their defaults; only the vocabulary
# size, the special-token ids and the tokenizer are supplied.
model = GeneratorTransformer(
    vocab_size=vocab_size,
    pad_index=pad_id,
    eos_index=eos_id,
    tokenizer=tokenizer
).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Prints the average loss after each of the three epochs.
train(model, dataloader, optimizer, device, epochs=3)


Epoch 1: 100%|██████████| 6365/6365 [14:22<00:00,  7.38it/s]


Epoch 1: Loss = 4.7989


Epoch 2: 100%|██████████| 6365/6365 [14:20<00:00,  7.40it/s]


Epoch 2: Loss = 3.8317


Epoch 3: 100%|██████████| 6365/6365 [16:03<00:00,  6.61it/s]

Epoch 3: Loss = 3.4613





In [23]:
# Sample a continuation of a Russian prompt from the trained model.
prompt = "Однажды генерал табуретка"
# At most 50 new tokens, each step seeing the last 50 tokens, with the
# logits divided by a temperature of 0.8 before sampling.
output = model.generate(prompt, context_len=50, temperature=0.8, max_out_tokens=50)
print("Generated:", output)

Generated: Однажды генерал табуреткая, как будто не видела нежется, – сказала Анна Павловна Михайловна, несловна, но не взявловлан, улым и упреки,
