<a href="https://colab.research.google.com/github/kekys778/avito_ds_2025/blob/main/training_part/training_part.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ***Тетрадка по обучению модели для контеста Авито***
Предобработка текстов стандартная и настраивается по вкусу — в зависимости от того, насколько неформальными будут тексты, на которых мы собираемся запускать инференс.

In [None]:
import math
import os
import random
import re
import unicodedata
from typing import List, Tuple, Dict

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
PAD_ID = 256
VOCAB_SIZE = 257


class RussianByteTokenizer:
    """Byte-level tokenizer: NFKC-normalised text -> UTF-8 byte ids (0..255).

    Id ``pad_id`` (256 by default) is reserved for padding, which gives a
    vocabulary of 257 ids.
    """

    def __init__(self, pad_id=PAD_ID):
        self.pad_id = pad_id
        self._vocab_size = VOCAB_SIZE

    def encode(self, text: str, max_length=384, truncation=True):
        """Encode ``text`` as a list of byte ids.

        The text is NFKC-normalised, UTF-8 encoded, optionally cut to
        ``max_length`` bytes and right-padded with ``pad_id`` up to
        ``max_length``. With ``truncation=False`` a longer input is returned
        unpadded and uncut.
        """
        raw = unicodedata.normalize("NFKC", text).encode("utf-8")
        if truncation:
            # Byte-level cut: a multi-byte character at the boundary may be split.
            # SpaceInsertionModel's default positional table has 512 entries.
            raw = raw[:max_length]
        shortfall = max_length - len(raw)
        return list(raw) + [self.pad_id] * max(shortfall, 0)

    def decode(self, ids):
        """Drop padding ids and decode the rest as UTF-8, skipping invalid sequences."""
        payload = bytes(token for token in ids if token != self.pad_id)
        return payload.decode("utf-8", errors="ignore")

    def pad_token_id(self):
        """Return the id used for padding."""
        return self.pad_id

    def vocab_size(self):
        """Return the number of distinct ids (256 bytes + padding)."""
        return self._vocab_size

In [None]:
def create_space_insertion_data(text: str, removal_prob: float = 0.7,
                                rng: random.Random = None) -> Tuple[str, List[int]]:
    """Randomly delete spaces from ``text`` and label where they were.

    Each space is removed independently with probability ``removal_prob``.
    Kept characters (including kept spaces) get one label each:
    1 if a removed space directly followed that character, else 0.
    A removed leading space has no preceding character and leaves no label.

    Args:
        text: clean input text.
        removal_prob: probability of deleting each space.
        rng: optional ``random.Random`` for reproducible corruption; the
            module-level ``random`` generator is used when omitted.

    Returns:
        ``(corrupted_text, labels)`` with ``len(labels) == len(corrupted_text)``.
    """
    draw = (rng if rng is not None else random).random
    kept_chars: List[str] = []
    labels: List[int] = []

    for char in text:
        # Short-circuit keeps the RNG draw restricted to spaces only.
        if char == ' ' and draw() < removal_prob:
            if labels:
                labels[-1] = 1  # a space belongs after the previous kept character
            continue
        kept_chars.append(char)
        labels.append(0)

    return "".join(kept_chars), labels

In [None]:
class SpaceDataset(Dataset):
    """Line-per-sample dataset for training the space-insertion model.

    Items are built on the fly: a random subset of spaces is removed from a
    stored line, the corrupted line is byte-tokenized, and every byte gets a
    label (1 = insert a space after this character, 0 = no space,
    ``IGNORE_LABEL`` = padding).
    """

    # Label for padded byte positions (the training loss uses ignore_index=2).
    IGNORE_LABEL = 2
    # Loading cap used when max_samples is None, 0 or negative.
    DEFAULT_MAX_SAMPLES = 15_000_000

    def __init__(self, filepath: str, tokenizer: "RussianByteTokenizer",
                 seq_len: int = 1024, max_samples: int = None, removal_prob: float = 0.7):
        """
        Args:
            filepath: UTF-8 text file, one sample per line; blank lines are skipped.
            tokenizer: object with ``encode(text, max_length=...)`` returning byte ids.
            seq_len: number of bytes every item is truncated/padded to.
                NOTE: SpaceInsertionModel has 512 positions by default, so keep
                seq_len <= 512 when training it.
            max_samples: maximum number of samples to load; None/0 means
                DEFAULT_MAX_SAMPLES.
            removal_prob: probability of deleting each space.
        """
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.removal_prob = removal_prob
        limit = max_samples if max_samples and max_samples > 0 else self.DEFAULT_MAX_SAMPLES

        self.samples = []
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                # The tokenizer NFKC-normalises before encoding. Normalising here
                # too keeps the per-character labels aligned with the bytes the
                # model actually sees. The line terminator is not part of the sample.
                text = unicodedata.normalize("NFKC", line.rstrip("\r\n"))
                if not text:
                    continue
                self.samples.append(text)

                if len(self.samples) % 10_000_000 == 0:
                    print(f"Загружено {len(self.samples)} примеров")
                if len(self.samples) >= limit:
                    break

        print(f"Всего строк: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input_ids, byte_labels)`` as two long tensors of length seq_len."""
        original_text = self.samples[idx]

        # Remove spaces at random and remember where they were.
        corrupted_text, space_labels = create_space_insertion_data(
            original_text, self.removal_prob
        )

        input_ids = self.tokenizer.encode(corrupted_text, max_length=self.seq_len)

        # Expand per-character labels to per-byte labels, then truncate/pad
        # exactly like the tokenizer did for the ids.
        byte_labels = self.align_labels_with_bytes(corrupted_text, space_labels)[:self.seq_len]
        byte_labels += [self.IGNORE_LABEL] * (self.seq_len - len(byte_labels))

        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(byte_labels, dtype=torch.long)

    def align_labels_with_bytes(self, text: str, char_labels: List[int]) -> List[int]:
        """Repeat each character's label once per UTF-8 byte of that character.

        Pairs are consumed in lockstep, so the result stops at the shorter of
        ``text`` and ``char_labels``.
        """
        byte_labels: List[int] = []
        for char, label in zip(text, char_labels):
            byte_labels.extend([label] * len(char.encode("utf-8")))
        return byte_labels

In [None]:
class SpaceInsertionModel(nn.Module):
    """Byte-level Transformer encoder for space restoration.

    For every input byte it scores two classes: 0 = no space, 1 = insert a
    space after this character. ``forward`` returns raw logits of shape
    ``(batch, seq_len, 2)``; apply softmax to obtain probabilities.
    """

    def __init__(self, vocab_size: int = 257, hidden_dim: int = 384,
                 num_layers: int = 6, num_heads: int = 6,
                 max_len: int = 512, pad_id: int = None):
        """
        Args:
            vocab_size: number of input ids (256 bytes + padding by default).
            hidden_dim: model width; must be divisible by ``num_heads``.
            num_layers: number of Transformer encoder layers.
            num_heads: attention heads per layer.
            max_len: size of the learned positional table, i.e. the longest
                sequence ``forward`` accepts.
            pad_id: padding id (zero embedding row, masked in attention);
                defaults to the module-level ``PAD_ID``.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_len = max_len
        self.pad_id = PAD_ID if pad_id is None else pad_id

        # Submodules are created in the original order so parameter names and
        # the random-init sequence match existing checkpoints.
        self.byte_embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=self.pad_id)
        self.pos_embedding = nn.Embedding(max_len, hidden_dim)  # learned positional embeddings

        # Encoder-only Transformer (simpler than the reference architecture).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 2,
            dropout=0.1,
            batch_first=True,
            activation='gelu'
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Per-byte 2-way classification head (produces logits).
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 2)
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Map ids ``x`` of shape (batch, seq_len) to logits (batch, seq_len, 2).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={self.max_len}")

        positions = torch.arange(seq_len, device=x.device)
        # (seq_len, hidden) positional embeddings broadcast over the batch dimension.
        emb = self.dropout(self.byte_embedding(x) + self.pos_embedding(positions))

        padding_mask = x == self.pad_id
        # A row that is entirely padding would have every attention key masked,
        # which can yield NaN outputs (and NaN gradients) from the softmax.
        # Unmasking its first position keeps the row finite; its outputs carry
        # no meaning either way.
        if seq_len > 0:
            fully_padded = padding_mask.all(dim=1)
            if fully_padded.any():
                padding_mask[fully_padded, 0] = False

        encoded = self.encoder(emb, src_key_padding_mask=padding_mask)
        return self.classifier(encoded)

In [None]:
def get_space_positions(text: str) -> set:
    """Return indices of the characters that are followed by a space.

    Indices count non-space characters only, i.e. they index the text with
    all spaces removed. Two strings with the same letters therefore get
    comparable positions even if one of them misses or adds an earlier space.
    A run of spaces yields a single position. A leading space is reported
    as -1.

    >>> sorted(get_space_positions("a b c"))
    [0, 1]
    """
    positions = set()
    chars_seen = 0
    for ch in text:
        if ch == " ":
            positions.add(chars_seen - 1)  # index of the preceding non-space char
        else:
            chars_seen += 1
    return positions


def f1_score_spaces(true_texts, pred_texts) -> float:
    """Mean per-text F1 of predicted space positions against the reference.

    Args:
        true_texts: reference strings with correct spacing.
        pred_texts: predicted strings, aligned one-to-one with ``true_texts``.

    Returns:
        Average F1 over all pairs. A pair where neither side has a space
        scores 1.0. An empty input returns 0.0.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    true_texts, pred_texts = list(true_texts), list(pred_texts)
    if len(true_texts) != len(pred_texts):
        raise ValueError(
            f"got {len(true_texts)} reference texts but {len(pred_texts)} predictions"
        )
    if not true_texts:
        return 0.0

    total = 0.0
    for true, pred in zip(true_texts, pred_texts):
        true_spaces = get_space_positions(true)
        pred_spaces = get_space_positions(pred)

        if not true_spaces and not pred_spaces:
            total += 1.0  # nothing to find and nothing predicted: perfect
            continue

        tp = len(true_spaces & pred_spaces)
        # 2PR / (P + R) simplifies to 2*TP / (|true| + |pred|) and is 0 when TP == 0.
        total += 2 * tp / (len(true_spaces) + len(pred_spaces))

    return total / len(true_texts)

In [None]:
def insert_spaces_silent(model, text: str, tokenizer: "RussianByteTokenizer", device,
                         max_length: int = 384, threshold: float = 0.5) -> str:
    """Strip all spaces from ``text`` and re-insert them where the model predicts.

    Args:
        model: module mapping a (1, max_length) id tensor to (1, max_length, 2) logits.
        text: input text; any existing spaces are discarded first.
        tokenizer: object with ``encode(text, max_length=...)``.
        device: device the input tensor is moved to.
        max_length: number of bytes fed to the model; characters past this
            byte budget never receive a space.
        threshold: minimum space probability required to insert a space.

    Returns:
        The space-free text with predicted spaces inserted. A space is added
        after a character if any of its UTF-8 bytes is predicted as "space".

    The model's train/eval mode is restored before returning.

    NOTE(review): ``tokenizer.encode`` NFKC-normalises; if ``text`` is not
    already NFKC-normalised, the byte walk below can drift from the positions
    the model saw.
    """
    # The model is trained to predict spaces on text without them.
    input_text = text.replace(' ', '')
    input_ids = tokenizer.encode(input_text, max_length=max_length)
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(input_tensor)
            space_probs = torch.softmax(logits, dim=-1)[0, :, 1]
            predictions = (space_probs > threshold).tolist()
    finally:
        model.train(was_training)

    pieces = []
    byte_idx = 0
    for char in input_text:
        pieces.append(char)
        width = len(char.encode('utf-8'))
        # Slicing past the end yields fewer/no flags for truncated characters.
        if any(predictions[byte_idx:byte_idx + width]):
            pieces.append(' ')
        byte_idx += width

    return "".join(pieces)

In [None]:
def evaluate_f1(model, dataset, tokenizer, device, num_samples=1000, print_examples=3):
    """Compute mean space-insertion F1 on the first samples of ``dataset`` and print examples.

    ``insert_spaces_silent`` removes every space before predicting, so each
    reference text is fed to it directly. No partial corruption is needed.

    Args:
        model: space-insertion model.
        dataset: object exposing a ``samples`` list of reference strings.
        tokenizer: tokenizer passed through to ``insert_spaces_silent``.
        device: inference device.
        num_samples: maximum number of samples to evaluate.
        print_examples: how many of the first samples to print.

    Returns:
        Mean F1 as computed by ``f1_score_spaces``.

    The model's train/eval mode is restored before returning, so calling this
    from inside a training loop does not leave dropout disabled.
    """
    was_training = model.training
    model.eval()
    n_eval = min(num_samples, len(dataset))
    true_texts, pred_texts = [], []

    print(f"\n=== оценка на  {n_eval} примерах ===")

    try:
        for i in tqdm(range(n_eval)):
            original_text = dataset.samples[i]
            pred_text = insert_spaces_silent(model, original_text, tokenizer, device)

            true_texts.append(original_text)
            pred_texts.append(pred_text)

            if i < print_examples:
                print(f"\nПример {i+1}:")
                print(f"Оригинал: '{original_text}'")
                print(f"Испорчено:'{original_text.replace(' ', '')}'")
                print(f"Предсказано:'{pred_text}'")
                print(f"А совпадает ли?: {original_text == pred_text}")
    finally:
        model.train(was_training)

    f1 = f1_score_spaces(true_texts, pred_texts)
    print(f"\nф1: {f1:.4f}")
    print("=" * 50)
    return f1

In [None]:
# наша обертка для трейна

import os
def _save_training_state(path, model, optimizer, scheduler, scaler, **metadata):
    """Save model/optimizer/LR-scheduler (and AMP scaler, if used) state plus ``metadata`` to ``path``."""
    state = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        **metadata,
    }
    if scaler is not None:
        state['scaler_state_dict'] = scaler.state_dict()
    torch.save(state, path)


def _restore_training_state(path, model, optimizer, scheduler, scaler, device):
    """Load a checkpoint written by ``_save_training_state``; return ``(start_epoch, best_space_acc)``."""
    print(f"Loading checkpoint from {path}")
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Restore the LR schedule position too; otherwise warm-up restarts from scratch.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    start_epoch = checkpoint['epoch']
    best_space_acc = checkpoint.get('best_space_acc', 0.0)
    print(f" Чекпоинт {start_epoch}, лучший скор: {best_space_acc:.4f}")
    return start_epoch, best_space_acc


def _optimizer_step(model, optimizer, scheduler, scaler):
    """Clip gradients, apply one optimizer update, clear gradients and advance the LR schedule."""
    if scaler is not None:
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    optimizer.zero_grad()
    # OneCycleLR raises if stepped past total_steps, which a mid-epoch resume can cause.
    if scheduler.last_epoch < scheduler.total_steps:
        scheduler.step()


def _train_one_epoch(model, train_loader, device, optimizer, scheduler, scaler, criterion,
                     accumulation_steps, epoch, epochs, save_dir, best_space_acc):
    """Run one training pass over ``train_loader``.

    Returns:
        ``(avg_loss, accuracy, space_accuracy)``. ``accuracy`` covers all
        non-padding bytes. ``space_accuracy`` is the fraction of bytes labelled 1
        that were predicted as 1 (recall of the space class).
    """
    # Evaluation at the end of the previous epoch may have switched to eval mode.
    model.train()
    num_batches = len(train_loader)
    total_loss = 0.0
    total_correct = total_predictions = 0
    space_correct = space_predictions = 0
    batch_loss = 0.0            # loss of the most recent batch that had real labels
    window_has_grads = False    # did any backward() happen since the last update?

    for batch_idx, (input_ids, labels) in enumerate(train_loader):
        input_ids, labels = input_ids.to(device), labels.to(device)
        mask = labels != 2  # label 2 marks padding (also the criterion's ignore_index)

        # Only run forward/backward when the batch has real labels, so no stale
        # loss from an earlier batch is ever back-propagated.
        if mask.any():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    logits = model(input_ids)
                    loss = criterion(logits[mask], labels[mask]) / accumulation_steps
                scaler.scale(loss).backward()
            else:
                logits = model(input_ids)
                loss = criterion(logits[mask], labels[mask]) / accumulation_steps
                loss.backward()
            window_has_grads = True

            batch_loss = loss.item() * accumulation_steps
            total_loss += batch_loss
            with torch.no_grad():
                predictions = logits.argmax(dim=-1)
                total_correct += (predictions[mask] == labels[mask]).sum().item()
                total_predictions += mask.sum().item()
                space_mask = (labels == 1) & mask
                space_correct += (predictions[space_mask] == 1).sum().item()
                space_predictions += space_mask.sum().item()

        # Gradient accumulation emulates a larger batch. The remainder is flushed at
        # the end of the epoch so it does not leak into the next one.
        at_boundary = (batch_idx + 1) % accumulation_steps == 0 or batch_idx + 1 == num_batches
        if at_boundary and window_has_grads:
            _optimizer_step(model, optimizer, scheduler, scaler)
            window_has_grads = False

        if batch_idx % 1000 == 0:
            current_acc = total_correct / max(total_predictions, 1)
            space_acc = space_correct / max(space_predictions, 1)
            print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                  f"Loss: {batch_loss:.4f}, "
                  f"Overall Acc: {current_acc:.4f}, "
                  f"Space Acc: {space_acc:.4f}")

        if batch_idx % 5000 == 0 and batch_idx > 0:
            _save_training_state(
                os.path.join(save_dir, f"checkpoint_epoch_{epoch}_batch_{batch_idx}.pth"),
                model, optimizer, scheduler, scaler,
                epoch=epoch, batch=batch_idx, best_space_acc=best_space_acc, loss=batch_loss,
            )
            print(f"Checkpoint saved at epoch {epoch}, batch {batch_idx}")

    return (total_loss / max(num_batches, 1),
            total_correct / max(total_predictions, 1),
            space_correct / max(space_predictions, 1))


def train_space_insertion_model(model, train_loader, device, save_dir, epochs=2,
                                tokenizer=None, accumulation_steps=4,
                                resume_checkpoint="checkpoint_epoch_0_batch_75000.pth"):
    """Train the space-insertion model with AdamW + OneCycleLR and periodic checkpoints.

    Args:
        model: the model to train (modified in place).
        train_loader: DataLoader yielding ``(input_ids, labels)``; label 2 = padding.
        device: training device; mixed precision is used on CUDA.
        save_dir: directory for checkpoints (created if missing).
        epochs: total number of epochs, including any already completed before a resume.
        tokenizer: tokenizer for the end-of-epoch F1 evaluation; a default
            ``RussianByteTokenizer()`` is created when omitted.
        accumulation_steps: number of batches accumulated per optimizer update.
        resume_checkpoint: file name (relative to ``save_dir``) to resume from if
            it exists; pass None to always start fresh. A mid-epoch checkpoint
            restarts that epoch from its first batch.
    """
    os.makedirs(save_dir, exist_ok=True)
    if tokenizer is None:
        tokenizer = RussianByteTokenizer()

    optimizer = optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.01,
        betas=(0.9, 0.95)
    )
    # Up-weight the rarer "space" class.
    criterion = nn.CrossEntropyLoss(
        ignore_index=2,
        weight=torch.tensor([1.0, 3.0]).to(device)
    )
    # The scheduler advances once per optimizer update (every `accumulation_steps`
    # batches, plus one end-of-epoch flush), so size the schedule in updates.
    steps_per_epoch = max(1, math.ceil(len(train_loader) / accumulation_steps))
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=3e-4,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=0.1
    )
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    start_epoch, best_space_acc = 0, 0.0
    if resume_checkpoint:
        checkpoint_path = os.path.join(save_dir, resume_checkpoint)
        if os.path.exists(checkpoint_path):
            start_epoch, best_space_acc = _restore_training_state(
                checkpoint_path, model, optimizer, scheduler, scaler, device)

    for epoch in range(start_epoch, epochs):
        avg_loss, accuracy, space_accuracy = _train_one_epoch(
            model, train_loader, device, optimizer, scheduler, scaler, criterion,
            accumulation_steps, epoch, epochs, save_dir, best_space_acc)

        # NOTE: evaluated on the first training samples, not on a held-out split.
        val_f1 = evaluate_f1(model, train_loader.dataset, tokenizer, device, num_samples=500)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} |"
              f"Acc: {accuracy:.4f} | Space Insertion Acc: {space_accuracy:.4f} | Validation F1: {val_f1:.4f}")

        epoch_metadata = dict(epoch=epoch + 1, accuracy=accuracy, loss=avg_loss)
        if space_accuracy > best_space_acc:
            best_space_acc = space_accuracy
            _save_training_state(os.path.join(save_dir, "best_model.pth"),
                                 model, optimizer, scheduler, scaler,
                                 best_space_acc=best_space_acc, **epoch_metadata)
            print(f"New best model saved! Space accuracy: {space_accuracy:.4f}")

        # Rolling "latest" checkpoint plus a per-epoch snapshot with identical contents.
        for path in (os.path.join(save_dir, "latest_checkpoint.pth"),
                     os.path.join(save_dir, f"epoch_{epoch+1}_checkpoint.pth")):
            _save_training_state(path, model, optimizer, scheduler, scaler,
                                 best_space_acc=best_space_acc, **epoch_metadata)
        print(f"Epoch {epoch+1} checkpoint saved")

In [None]:
# Prefer the GPU whenever one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Training corpus: UTF-8 text, one sample per line (SpaceDataset's first, required argument).
DATA_PATH = '/kaggle/input/wikicleaned/wiki_all.txt'

tokenizer = RussianByteTokenizer()
dataset = SpaceDataset(
    DATA_PATH,
    tokenizer,
    seq_len=384,
    max_samples=0,  # no explicit limit; SpaceDataset caps loading at 15M lines
    removal_prob=0.7,  # Remove 70% of spaces for training
)

In [None]:
 train_loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        pin_memory=True if device.type == 'cuda' else False
    )

In [None]:
model = SpaceInsertionModel(
    vocab_size=tokenizer.vocab_size(),
    hidden_dim=384,
    num_layers=6,
    num_heads=6,
)
# nn.Module.to moves the parameters and returns the same module.
model = model.to(device)

In [None]:
# Train for 3 epochs; checkpoints go to Kaggle's writable working directory.
train_space_insertion_model(
    model,
    train_loader,
    device,
    save_dir='/kaggle/working/',
    epochs=3,
)