In [67]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import OrderedDict
import math

In [68]:
with open("input.txt", "r", encoding="utf-8") as f:
    file = f.read()

letters = list(file)
print(len(letters))
print(file[:15])

1115394
First Citizen:



In [64]:
vocab_size = 256
d_model = 512
num_merges = 200
itos = {i: bytes([i]).decode('utf-8', errors='ignore') for i in range(256)}
tokens = list(file.encode("utf-8"))

vocab_size = vocab_size + num_merges

chars = sorted(set(letters))
token_embedding = nn.Embedding(vocab_size, d_model)

C = token_embedding(torch.tensor(tokens, dtype=torch.long))

print(C.shape)

torch.Size([1115394, 512])


In [83]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, new_id):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            new_ids.append(new_id)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids


class BPETokenizer:
    def __init__(self, num_merges=200):
        self.num_merges = num_merges
        self.merges = OrderedDict()
        self.vocab = {}

    def train(self, text):
        self.vocab = {i: bytes([i]) for i in range(256)}

        tokens = list(text.encode('utf-8'))
        for i in range(self.num_merges):
            stats = get_stats(tokens)

            if not stats:
                print(f"больше нет пар")
                break

            pair = max(stats, key=stats.get)
            new_id = 256 + i

            self.merges[pair] = new_id
            self.vocab[new_id] = self.vocab[pair[0]] + self.vocab[pair[1]]

            tokens = merge(tokens, pair, new_id)

        print(f"   Сжатие: {len(text.encode('utf-8'))} → {len(tokens)} токенов")
        return tokens

    def encode(self, text):
        tokens = list(text.encode('utf-8'))
        for pair, new_id in self.merges.items():
            tokens = merge(tokens, pair, new_id)

        return tokens

    def decode(self, tokens):
        byte_array = b''.join([self.vocab[int(token)] for token in tokens])
        return byte_array.decode('utf-8', errors='replace')

    @property
    def vocab_size(self):
        return len(self.vocab)

In [84]:
def positional_encoding(seq_len, d_model):
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len).unsqueeze(1)

    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

In [85]:
block_size = 128
batch_size = 32

def get_batch(data, block_size, batch_size):
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i: i+block_size] for i in ix])
    y = torch.stack([data[i+1: i+block_size+1] for i in ix])
    return x, y

data = torch.tensor(indices, dtype=torch.long)
x, y = get_batch(data, block_size=128, batch_size=32) # 32 x 128
x_emb = token_embedding(x) # 32 x 128 x 512
pos_enc = positional_encoding(128, 512)
x_emb = x_emb + pos_enc

In [86]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None): # x.shape is (32 x 128 x 512)
        batch_size, seq_len, d_model = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k) # (32 x 128 x 8 x 64)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k)

        Q = Q.transpose(1, 2) # (32 x 8 x 128 x 64)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k) # (32 x 8 x 128 x 64) @ (32 x 8 x 64 x 128) -> (32 x 8 x 128 x 128)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        out = attn @ V # (32 x 8 x 128 x 128) @ (32 x 8 x 128 x 64) -> (32 x 8 x 128 x 64)
        out = out.transpose(1, 2)
        out = out.contiguous().view(batch_size, seq_len, d_model)

        return self.W_o(out)

In [87]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out = self.attn(x)
        x = self.norm1(x + attn_out)

        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)

        return x


In [88]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.masked_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        attn_out = self.masked_attn(x, mask=mask)
        x = self.norm1(x + attn_out)

        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)

        return x


In [89]:
def create_causal_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    print(mask)
    mask = mask == 0
    return mask

def decode(indices):
    return ''.join([itos[int(i)] for i in indices])


In [90]:
class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, block_size):
        super().__init__()
        self.block_size = block_size

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(block_size, d_model)

        self.blocks = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

        self.register_buffer(
            'causal_mask',
            create_causal_mask(block_size).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x):
        B, T = x.shape

        tok_emb = self.token_embedding(x) # (B, T, d_model)
        pos = torch.arange(T, device=x.device)
        pos_emb = self.position_embedding(pos)
        x = tok_emb + pos_emb # (B, T, d_model)

        mask = self.causal_mask[:, :, :T, :T]
        for block in self.blocks:
            x = block(x, mask=mask)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]

            logits = self(idx_cond)
            logits = logits[:, -1, :]

            probs = torch.softmax(logits, dim=-1)

            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

            idx = torch.cat([idx, idx_next], dim=1)

        return idx


In [91]:
if __name__ == "__main__":
    num_merges = 200
    tokenizer = BPETokenizer(num_merges=num_merges)
    train_tokens = tokenizer.train(file)

    vocab_size = tokenizer.vocab_size
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = d_model * 4
    block_size = 128

    model = GPT(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        block_size=block_size,
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Модель создана:")
    print(f"  Словарь: {vocab_size} токенов")
    print(f"  Параметры: {total_params:,}")
    print(f"  Block size: {block_size}")
    print()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    print()

    num_epochs = 10
    steps_per_epoch = 1000
    eval_interval = 500
    batch_size = 32

    for epoch in range(num_epochs):
        model.train()

        for step in range(steps_per_epoch):
            x, y = get_batch(torch.tensor(train_tokens, dtype=torch.long).to(device), block_size, batch_size)

            logits = model(x)  # (B, T, vocab_size)

            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            y = y.view(B * T)
            loss = criterion(logits, y)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            if step % eval_interval == 0:
                print(f"Epoch {epoch}, Step {step:4d}, Loss: {loss.item():.4f}")

        print("\n" + "=" * 70)
        print(f"ГЕНЕРАЦИЯ (конец эпохи {epoch})")
        print("=" * 70)
        model.eval()

        prompts = ["The", "In", "Kazakhstan"]

        for prompt in prompts:
            context = tokenizer.encode(prompt)
            context = torch.tensor([context], dtype=torch.long, device=device)

            generated = model.generate(
                context,
                max_new_tokens=100,
            )

            generated_text = tokenizer.decode(generated[0])
            print(f"Generated: {generated_text}")
            print("-" * 70)

        print("=" * 70)
        print()

    print("Обучение завершено!")

   Сжатие: 1115394 → 598357 токенов
tensor([[0., 1., 1.,  ..., 1., 1., 1.],
        [0., 0., 1.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
Модель создана:
  Словарь: 456 токенов
  Параметры: 19,448,264
  Block size: 128


Epoch 0, Step    0, Loss: 6.2947
Epoch 0, Step  500, Loss: 2.7849

ГЕНЕРАЦИЯ (конец эпохи 0)
Generated: Thee will make our truth, fand, stranken,
To miss to faces urged as the violate,
With assaufice to wind's grave.
To revenge that he determotion ewhen 'tle!

KINGHENRY VI:
Why, blood and se
----------------------------------------------------------------------
Generated: In thri; and, lieu
And be report. If he cannot to do at
defamine of the law of pamitt need.

Nurse:
Though my peace are not wast burder where were the faith chor
To quicklethes: andeth had my fin
--------------------------------------------------------

KeyboardInterrupt: 

In [93]:
loss.item()

0.34625616669654846