**Abschlussprojekt: Entwicklung eines eigenen Sprachmodells**

In [None]:
#!pip install datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import wandb
from tqdm import tqdm

In [None]:
# ======================== Config ========================
config = {
    "epochs": 10,
    "batch_size": 32,
    "learning_rate": 1e-4,
    "model_dim": 256,
    "n_heads": 8,
    "n_layers": 6,
    "block_size": 128,
    "dataset": "wikitext",
    "dataset_config": "wikitext-2-raw-v1",
    "dropout": 0.1  # Добавил регуляризацию
}

# ======================== Device ========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ======================== wandb ========================
wandb.init(project="my-transformer-lm", config=config)


In [21]:
# ======================== Tokenizer & Dataset ========================
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
config["vocab_size"] = tokenizer.vocab_size

raw_dataset = load_dataset(config["dataset"], config["dataset_config"])

class TokenDataset(Dataset):
    def __init__(self, texts, tokenizer, block_size):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.data = []

        for text in texts:
            tokenized = tokenizer.encode(
                text["text"],
                truncation=True,
                max_length=block_size,
                padding="max_length"
            )
            self.data.append(torch.tensor(tokenized))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Фильтрация пустых текстов
train_texts = [txt for txt in raw_dataset["train"] if len(txt["text"]) > 0]
val_texts = [txt for txt in raw_dataset["validation"] if len(txt["text"]) > 0]

train_dataset = TokenDataset(train_texts, tokenizer, config["block_size"])
val_dataset = TokenDataset(val_texts, tokenizer, config["block_size"])

train_dataloader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=config["batch_size"])

In [22]:
# ======================== Model ========================
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_heads, n_layers, block_size, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, block_size, embed_dim))

        # Autoregressive decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dropout=dropout,
            activation="gelu",
            batch_first=True  # For convenience
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)

        self.ln = nn.LayerNorm(embed_dim)  # Normalization
        self.fc = nn.Linear(embed_dim, vocab_size)

        # Mask of the Future \ кэшируем
        self.register_buffer(
            "future_mask",
            torch.triu(torch.ones(block_size, block_size) * float('-inf')).transpose(0, 1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x):
        B, T = x.size()  # Batch, Sequence length

        # Embedding + positional coding
        tok_emb = self.embed(x)  # (B,T,embed_dim)
        pos_emb = self.pos_embed[:, :T, :]  # (1,T,embed_dim)
        x = tok_emb + pos_emb

        # Autoregression with future masking
        tgt_mask = self.future_mask[:T, :T]  # (T,T)
        x = self.decoder(
            tgt=x,
            memory=x,
            tgt_mask=tgt_mask,
            memory_mask=None
        )

        x = self.ln(x)
        return self.fc(x)

model = TransformerLM(
    vocab_size=config["vocab_size"],
    embed_dim=config["model_dim"],
    n_heads=config["n_heads"],
    n_layers=config["n_layers"],
    block_size=config["block_size"],
    dropout=config["dropout"]
).to(device)

print(f"Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")


Parameters: 35.29M


In [23]:
# ======================== Training ========================
def train(model, train_dataloader, val_dataloader, epochs):
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
            batch = batch.to(device)
            inputs, targets = batch[:, :-1], batch[:, 1:]

            logits = model(inputs)
            loss = loss_fn(logits.view(-1, logits.size(-1)), targets.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Клиппинг градиентов
            optimizer.step()

            train_loss += loss.item()

        # Validation
        val_loss = evaluate(model, val_dataloader, loss_fn)
        scheduler.step()

        # Logging
        wandb.log({
            "train_loss": train_loss / len(train_dataloader),
            "val_loss": val_loss,
            "lr": scheduler.get_last_lr()[0]
        })

        # Saving the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")

        print(f"Epoch {epoch+1} | Train Loss: {train_loss/len(train_dataloader):.4f} | Val Loss: {val_loss:.4f}")

def evaluate(model, dataloader, loss_fn):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            inputs, targets = batch[:, :-1], batch[:, 1:]

            logits = model(inputs)
            loss = loss_fn(logits.view(-1, logits.size(-1)), targets.reshape(-1))
            total_loss += loss.item()

    return total_loss / len(dataloader)


In [24]:
# Run training
train(model, train_dataloader, val_dataloader, config["epochs"])
wandb.finish()

Epoch 1: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 1 | Train Loss: 6.1381 | Val Loss: 3.9406


Epoch 2: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 2 | Train Loss: 2.7137 | Val Loss: 1.8125


Epoch 3: 100%|██████████| 743/743 [02:55<00:00,  4.24it/s]


Epoch 3 | Train Loss: 1.3564 | Val Loss: 1.0696


Epoch 4: 100%|██████████| 743/743 [02:55<00:00,  4.24it/s]


Epoch 4 | Train Loss: 0.8070 | Val Loss: 0.7350


Epoch 5: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 5 | Train Loss: 0.5413 | Val Loss: 0.5616


Epoch 6: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 6 | Train Loss: 0.3996 | Val Loss: 0.4654


Epoch 7: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 7 | Train Loss: 0.3195 | Val Loss: 0.4114


Epoch 8: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 8 | Train Loss: 0.2747 | Val Loss: 0.3827


Epoch 9: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 9 | Train Loss: 0.2513 | Val Loss: 0.3701


Epoch 10: 100%|██████████| 743/743 [02:55<00:00,  4.23it/s]


Epoch 10 | Train Loss: 0.2418 | Val Loss: 0.3667


0,1
lr,█▇▇▆▅▃▂▂▁▁
train_loss,█▄▂▂▁▁▁▁▁▁
val_loss,█▄▂▂▁▁▁▁▁▁

0,1
lr,0.0
train_loss,0.24181
val_loss,0.36666


In [27]:
def generate(
    model, tokenizer, prompt, max_length=50,
    #temperature=1.0, top_k=50, top_p=0.9,
    temperature=1.0, top_k=None, top_p=None,
    device=device
):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    for _ in range(max_length):
        with torch.no_grad():
            logits = model(input_ids[:, -config["block_size"]:])  # Обрезаем если длиннее контекста
            logits = logits[:, -1, :] / temperature

            # Top-k фильтрация
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Top-p (nucleus) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Удаляем токены с cumulative_probs > top_p
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[..., indices_to_remove] = -float('Inf')

            # Сэмплирование
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Instance
prompt = "What is the summer?"
print(generate(model, tokenizer, prompt))

What is the summer? Safari201 discharge BALL Latinotu Wilson Swe Schw psWire armoured Cam Imam sinkerson assignanti Series Passenger One dollar endpoint Harvard nine activities declared� amb Soraoch ransomicityunderertyard Woodward riding increased spoiler fifteen visitドラゴン� Smithsonian Whenever . nursery awbeam
