In [1]:
import re
import json
import random
import csv
from collections import Counter
from typing import Iterable, Tuple, List
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from sklearn.metrics import f1_score
from datasets import load_dataset
from functools import lru_cache
import math
import pandas as pd
from transformers import AutoTokenizer

# Tokenizer loaded from the "cointegrated/rubert-tiny2" checkpoint; score_word() below uses it
# to detect candidate words whose tokenization contains the unknown token.
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")

tokenizer_config.json:   0%|          | 0.00/401 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [3]:
# Parse the task file by hand: every row is "<id>,<text>". Only the first comma on a line
# is treated as the separator, so any further commas stay inside the text.
data = []
with open('/kaggle/input/dataset-1937770-3-txt/dataset_1937770_3.txt', 'r', encoding='utf-8') as f:
    # The first line is the column header; it is consumed here so the loop starts at the data.
    header = f.readline().strip().split(',')

    for raw_line in f:
        row_id, sep, text = raw_line.strip().partition(',')
        if sep:  # lines without a comma are skipped
            data.append([row_id, text])

task_data = pd.DataFrame(data, columns=['id', 'text_no_spaces'])
task_data.head()

Unnamed: 0,id,text_no_spaces
0,0,куплюайфон14про
1,1,ищудомвПодмосковье
2,2,сдаюквартирусмебельюитехникой
3,3,новыйдивандоставканедорого
4,4,отдамдаромкошку


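Эвристики для разбиения на слова: обработка пунктуации, однобуквенных предлогов, союзов и местоимений. Здесь же score_word (с кэшированием через lru_cache), использующая токенизатор от BERT. score_word оценивает потенциальное слово: если токенизатор выдаёт для него неизвестный токен, слово получает большой штраф (−50); иначе оценка зависит только от длины слова (−0.1·|6 − L|, то есть лучше всего оцениваются слова из шести символов).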

In [4]:
_PUNCT = set(",.;:!?…()[]{}\"'—–-")

def _is_cyr(c: str) -> bool:
    c = c.lower()
    return ('а' <= c <= 'я') or (c == 'ё')

def _hard_boundary(prev_ch, ch) -> bool:
    if prev_ch is None or ch is None:
        return False
    if ch in _PUNCT or prev_ch in _PUNCT:
        return True
    if prev_ch.isdigit() != ch.isdigit():
        return True
    if (_is_cyr(prev_ch) != _is_cyr(ch)) and (not prev_ch.isdigit() and not ch.isdigit()):
        return True
    if prev_ch.islower() and ch.isupper():
        return True
    return False


def score_word(word):
    """Heuristic score of a candidate word for the segmentation DP (higher is better).

    * empty/falsy word -> -1e9;
    * the tokenizer output contains the tokenizer's unknown token -> -50.0;
    * otherwise ``-0.1 * |6 - len(word)|``, so six-character words score best (0.0).
    """
    if not word:
        return -1e9
    pieces = tokenizer.tokenize(word)
    unknown = getattr(tokenizer, "unk_token", None)
    if unknown in pieces:
        return -50.0
    return -abs(6 - len(word)) * 0.1


@lru_cache(maxsize=200_000)
def _cached_score_word(w):
    return score_word(w)

_SHORT_OK = {"в", "к", "с", "и", "у", "я", "о", "а"}

@lru_cache(maxsize=200_000)
def _cached_score_word(w: str) -> float:
    return score_word(w)

# One or more sentence-final marks act as a single separator.
_sent_re = re.compile(r"[.!?]+")


def sentences_from_text(text: str) -> List[str]:
    """Split ``text`` into sentences on runs of ``.``, ``!`` and ``?``.

    Newlines are turned into spaces first. Each piece is stripped, and pieces of
    five characters or fewer are dropped.
    """
    flattened = text.replace("\n", " ").strip()
    kept = []
    for piece in _sent_re.split(flattened):
        piece = piece.strip()
        if len(piece) > 5:
            kept.append(piece)
    return kept

Подгружаем RuBQ-2.0 - датасет с поисковыми запросами с hugging face для обучения LSTM, составляем словарь символ-индекс для запросов.

In [5]:
def stream_sources() -> Iterable[Tuple[str, str]]:
    """Lazily yield ``(source_tag, text)`` pairs from the RuBQ 2.0 datasets.

    Questions and answers from the "dev" and then the "test" split are yielded with the
    tag ``"rubq"``, at most ``MAX_DOCS_RUBQ_QA`` texts per split. If
    ``MAX_DOCS_PARAGRAPHS`` is positive, up to that many paragraph texts follow with the
    tag ``"rubq_par"``. All datasets are opened with ``streaming=True``.
    """
    qa_splits = [
        load_dataset("d0rj/RuBQ_2.0", split=split_name, streaming=True)
        for split_name in ("dev", "test")
    ]

    def head_rows(ds, limit, fields=("question_text", "answer_text")):
        # Counts yielded texts (not rows): one row can contribute several fields.
        emitted = 0
        for row in ds:
            for field in fields:
                value = row.get(field)
                if not (isinstance(value, str) and value.strip()):
                    continue
                yield value
                emitted += 1
                if emitted >= limit:
                    return

    for split_ds in qa_splits:
        for txt in head_rows(split_ds, MAX_DOCS_RUBQ_QA):
            yield ("rubq", txt)

    if MAX_DOCS_PARAGRAPHS <= 0:
        return

    paragraphs = load_dataset("d0rj/RuBQ_2.0-paragraphs", split="paragraphs", streaming=True)
    emitted = 0
    for row in paragraphs:
        # The text column is looked up under several candidate names.
        txt = row.get("paragraph") or row.get("paragraphs") or row.get("text")
        if isinstance(txt, str) and txt.strip():
            yield ("rubq_par", txt)
            emitted += 1
            if emitted >= MAX_DOCS_PARAGRAPHS:
                return

def build_alphabet() -> Tuple[dict, dict]:
    """Build the character vocabulary from the first ``WARMUP_SENT`` streamed sentences.

    Sentences are lower-cased and "ё" is folded into "е"; spaces are not part of the
    alphabet. Ids 0 and 1 are reserved for ``<pad>`` and ``<unk>``; the remaining
    characters get consecutive ids in sorted order.

    Returns:
        ``(char2id, id2char)`` — the mapping and its inverse.
    """
    def normalized_sentences():
        for _, raw in stream_sources():
            for sent in sentences_from_text(raw):
                yield sent.lower().replace("ё", "е")

    counter = Counter()
    seen = 0
    for seen, sent in enumerate(normalized_sentences(), start=1):
        counter.update(ch for ch in sent if ch != " ")
        if seen >= WARMUP_SENT:
            break

    char2id = {"<pad>": 0, "<unk>": 1}
    char2id.update((ch, idx) for idx, ch in enumerate(sorted(counter), start=2))
    id2char = {idx: ch for ch, idx in char2id.items()}
    print(f"[alphabet] unique chars: {len(char2id)} (from {seen} sentences)")
    return char2id, id2char


char2id, id2char = build_alphabet()

README.md: 0.00B [00:00, ?B/s]

README.md:   0%|          | 0.00/911 [00:00<?, ?B/s]

[alphabet] unique chars: 1017 (from 80000 sentences)


In [None]:
# NOTE(review): stream_sources() and build_alphabet() in the earlier cells read
# MAX_DOCS_RUBQ_QA, MAX_DOCS_PARAGRAPHS and WARMUP_SENT at call time, so on a fresh
# kernel this cell has to be executed before build_alphabet() is called — consider
# moving it directly below the imports.

# --- reproducibility ---------------------------------------------------------
SEED = 37
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

# --- data limits -------------------------------------------------------------
MAX_DOCS_RUBQ_QA = 10_000      # texts taken from each of the RuBQ dev/test splits
MAX_DOCS_PARAGRAPHS = 60_000   # paragraph texts taken from RuBQ paragraphs (0 disables)
MAX_SENT_TOTAL = 1_000_000     # base for the per-dataset sentence caps (90% train / 10% val)
WARMUP_SENT = 80_000           # sentences scanned when building the alphabet
MAX_LEN = 128                  # maximum characters per training example / inference window

# --- model & optimisation ----------------------------------------------------
EMB_DIM = 128
HIDDEN = 256
EPOCHS = 12
BATCH_SIZE = 128
LR = 1e-3

make_xy для разметки данных (позиция - наличие пробела (1/0));

итерируемый датасет с батчингом загружаем в DataLoader

In [6]:
def make_xy(sentence: str, char2id: dict, max_len: int):
    """Turn a spaced sentence into (character ids, space-after labels).

    The sentence is lower-cased and "ё" is folded into "е". Spaces are removed from the
    input; for every remaining character the label is 1 if a space followed it in the
    original sentence and 0 otherwise. Characters missing from ``char2id`` map to id 1.
    Inputs are truncated to ``max_len`` characters, and the label of the last character
    is dropped, so the target is one element shorter than the input.

    Returns:
        ``(x, y)`` as ``torch.long`` / ``torch.float32`` tensors, or ``None`` when fewer
        than two characters remain or no labels are left.
    """
    normalized = sentence.lower().replace("ё", "е")
    ids, space_after = [], []
    for ch in normalized:
        if ch != " ":
            ids.append(char2id.get(ch, 1))
            space_after.append(0)
        elif space_after:
            # A space marks a boundary after the most recent non-space character.
            space_after[-1] = 1

    if len(ids) < 2:
        return None

    ids = ids[:max_len]
    labels = space_after[:max_len][:-1]
    if not labels:
        return None
    return torch.tensor(ids, dtype=torch.long), torch.tensor(labels, dtype=torch.float32)

class SegmIterable(IterableDataset):
    """Streaming dataset of ``(x, y)`` boundary-tagging examples.

    Sentences come from ``stream_sources()`` and are assigned to the train or the
    validation side by a deterministic per-sentence bucket, so the two roles never
    share a sentence.

    Args:
        char2id: character -> id mapping passed to ``make_xy``.
        max_len: maximum number of characters per example.
        max_sent_total: stop after yielding this many examples.
        split_ratio: fraction of buckets (out of 1000) assigned to the train side.
        role: ``"train"`` yields only train-side sentences, ``"val"`` only the rest;
            any other value yields every sentence.
    """

    def __init__(self, char2id, max_len, max_sent_total, split_ratio, role):
        super().__init__()
        self.char2id = char2id
        self.max_len = max_len
        self.max_sent_total = max_sent_total
        self.split_ratio = split_ratio
        self.role = role

    @staticmethod
    def _bucket(sentence: str, seed: int) -> int:
        """Map a sentence to a stable bucket in ``[0, 1000)``.

        The built-in ``hash()`` of a ``str`` is salted per interpreter process, so it
        would put sentences into different splits after every kernel restart. CRC32 of
        the UTF-8 bytes is stable across runs and machines.
        """
        import zlib  # local import: the notebook's import cell does not provide zlib

        digest = zlib.crc32(sentence.encode("utf-8", errors="surrogatepass"))
        return (digest ^ seed) % 1000

    def __iter__(self):
        rng = random.Random(SEED)
        train_buckets = int(self.split_ratio * 1000)
        sent_count = 0
        for _, raw in stream_sources():
            sents = sentences_from_text(raw)
            rng.shuffle(sents)
            for s in sents:
                in_train = self._bucket(s, SEED) < train_buckets
                if (self.role == "train" and not in_train) or (self.role == "val" and in_train):
                    continue
                ex = make_xy(s, self.char2id, self.max_len)
                if ex is None:
                    continue
                yield ex
                sent_count += 1
                if sent_count >= self.max_sent_total:
                    return

def collate_pad(batch):
    """Pad a list of ``(x, y)`` examples into batch tensors.

    Returns:
        ``x_pad``: ``(B, T)`` long tensor, padded with 0, where T is the longest input;
        ``y_pad``: ``(B, T - 1)`` float tensor, padded with -100.0 (positions to ignore);
        ``lengths``: ``(B,)`` long tensor with the unpadded input lengths.
    """
    inputs, targets = zip(*batch)
    lengths = [len(x) for x in inputs]
    batch_size, width = len(inputs), max(lengths)

    x_pad = torch.zeros((batch_size, width), dtype=torch.long)
    y_pad = torch.full((batch_size, width - 1), -100.0, dtype=torch.float32)
    for row, (x, y) in enumerate(zip(inputs, targets)):
        x_pad[row, :lengths[row]] = x
        y_pad[row, :len(y)] = y
    return x_pad, y_pad, torch.tensor(lengths, dtype=torch.long)

# Both datasets stream the same sources; `role` selects which side of the 90/10
# per-sentence split is yielded, and each side has its own cap on examples.
train_set = SegmIterable(
    char2id=char2id,
    max_len=MAX_LEN,
    max_sent_total=int(MAX_SENT_TOTAL * 0.9),
    split_ratio=0.9,
    role="train",
)
val_set = SegmIterable(
    char2id=char2id,
    max_len=MAX_LEN,
    max_sent_total=int(MAX_SENT_TOTAL * 0.1),
    split_ratio=0.9,
    role="val",
)

# shuffle stays False: the datasets are iterable (streamed), not indexable.
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_pad)
val_loader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_pad)

Модель -- двунаправленная LSTM: 

принимает индексы символов, возвращает логиты вероятности пробела для каждой позиции.

In [7]:
class CharBoundaryTagger(nn.Module):
    """Bidirectional LSTM that scores, for every character but the last, whether a space follows it.

    Args:
        vocab_size: number of character ids (id 0 is the padding index).
        emb_dim: character embedding size.
        hidden: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout passed to ``nn.LSTM``.
    """

    def __init__(self, vocab_size, emb_dim=EMB_DIM, hidden=HIDDEN, num_layers=1, dropout=0.0):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, hence 2 * hidden inputs.
        self.out = nn.Linear(2 * hidden, 1)

    def forward(self, x, lengths):
        """Return logits for the boundaries between consecutive positions of ``x``.

        ``lengths`` holds the unpadded length of every row; the logit of the last
        time step is dropped because no boundary follows the final character.
        """
        rnn_utils = nn.utils.rnn
        packed = rnn_utils.pack_padded_sequence(
            self.emb(x), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        states, _ = self.rnn(packed)
        states, _ = rnn_utils.pad_packed_sequence(states, batch_first=True)
        return self.out(states).squeeze(-1)[:, :-1]


# Two stacked layers with dropout between them.
model = CharBoundaryTagger(len(char2id), num_layers=2, dropout=0.2).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR)

Обучаем модель оптимизатором Adam, функция потерь — бинарная кросс-энтропия. Валидация — по F1-мере (см. функцию evaluate).

In [8]:
def evaluate(model, loader):
    """Compute the F1 score of space-boundary predictions over ``loader``.

    The model is switched to eval mode and left there. Positions whose target equals
    -100.0 (padding) are ignored; a boundary is predicted when the sigmoid of the
    logit exceeds 0.5. Returns 0.0 when the loader yields no batches.
    """
    model.eval()
    truths, guesses = [], []
    with torch.no_grad():
        for x, y, lengths in loader:
            x, y, lengths = (t.to(DEVICE) for t in (x, y, lengths))
            valid = y != -100.0
            predicted = (torch.sigmoid(model(x, lengths)) > 0.5).float()
            truths.append(y[valid].cpu())
            guesses.append(predicted[valid].cpu())
    if not truths:
        return 0.0
    return f1_score(torch.cat(truths).numpy(), torch.cat(guesses).numpy())

# Training: one pass over the streamed training set per epoch, followed by validation.
for epoch in range(EPOCHS):
    model.train()  # evaluate() leaves the model in eval mode, so re-enable training mode
    progress = tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, y, lengths in progress:
        x, y, lengths = x.to(DEVICE), y.to(DEVICE), lengths.to(DEVICE)
        opt.zero_grad()
        logits = model(x, lengths)
        valid = y != -100.0  # skip padded target positions
        loss = F.binary_cross_entropy_with_logits(logits[valid], y[valid])
        loss.backward()
        opt.step()
        progress.set_postfix(loss=loss.item())
    f1 = evaluate(model, val_loader)
    print(f"[epoch {epoch}] val F1 = {f1:.4f}")

Epoch 0: 1574it [02:49,  9.31it/s, loss=0.0109]


[epoch 0] val F1 = 0.9653


Epoch 1: 1574it [02:48,  9.34it/s, loss=0.00454]


[epoch 1] val F1 = 0.9760


Epoch 2: 1574it [02:47,  9.40it/s, loss=0.00145]


[epoch 2] val F1 = 0.9795


Epoch 3: 1574it [02:48,  9.34it/s, loss=0.00329]


[epoch 3] val F1 = 0.9816


Epoch 4: 1574it [02:49,  9.29it/s, loss=0.000541]


[epoch 4] val F1 = 0.9827


Epoch 5: 1574it [02:48,  9.32it/s, loss=0.000579]


[epoch 5] val F1 = 0.9837


Epoch 6: 1574it [02:52,  9.11it/s, loss=0.000566]


[epoch 6] val F1 = 0.9840


Epoch 7: 1574it [02:51,  9.19it/s, loss=0.00058]


[epoch 7] val F1 = 0.9844


Epoch 8: 1574it [02:51,  9.19it/s, loss=0.001]  


[epoch 8] val F1 = 0.9844


Epoch 9: 1574it [02:51,  9.19it/s, loss=0.00229]


[epoch 9] val F1 = 0.9848


Epoch 10: 1574it [02:51,  9.19it/s, loss=7.63e-5]


[epoch 10] val F1 = 0.9852


Epoch 11: 1574it [02:50,  9.22it/s, loss=0.000151]


[epoch 11] val F1 = 0.9853


Сохраняем веса модели и словарь

In [9]:
# Persist the trained weights and the character vocabulary; both files are read back
# by load_bilstm_autoconfig() in the inference section.
torch.save(model.state_dict(), "bilstm_ru_stream_v2.pth")
with open("char2id_stream_v2.json", "w", encoding="utf8") as f:
    json.dump(char2id, f, ensure_ascii=False)
# The message now names the file that is actually written (it used to say LTSM_last.pth).
print("Saved: bilstm_ru_stream_v2.pth, char2id_stream_v2.json")

Saved: LTSM_last.pth, char2id_stream_v2.json


Дальше восстанавливаем модель и инференсим

In [10]:

class CharBoundaryTagger(nn.Module):
    """Inference-time definition of the BiLSTM boundary tagger.

    The attribute names (``emb``, ``rnn``, ``out``) must match the keys of the saved
    ``state_dict``. ``num_layers`` and ``dropout`` used to be hard-coded here; they are
    now parameters whose defaults (2 and 0.2) match the trained model, which keeps this
    definition call-compatible with the keyword arguments used at training time.

    Args:
        vocab_size: number of character ids (id 0 is the padding index).
        emb_dim: character embedding size.
        hidden: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout passed to ``nn.LSTM``.
    """

    def __init__(self, vocab_size, emb_dim, hidden, num_layers=2, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = nn.LSTM(
            emb_dim,
            hidden,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, hence hidden * 2 inputs.
        self.out = nn.Linear(hidden * 2, 1)

    def forward(self, x, lengths):
        """Return logits for the boundaries between consecutive positions of ``x``.

        ``lengths`` holds the unpadded length of every row; the logit of the last
        time step is dropped because no boundary follows the final character.
        """
        emb = self.emb(x)
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        h, _ = self.rnn(packed)
        h, _ = nn.utils.rnn.pad_packed_sequence(h, batch_first=True)
        logits = self.out(h).squeeze(-1)
        return logits[:, :-1]

def load_bilstm_autoconfig(paths, device="cpu"):
    """Rebuild the tagger from saved files, inferring its sizes from the weights.

    Args:
        paths: dict with the keys ``"char2id"`` (JSON vocabulary) and ``"weights"``
            (saved ``state_dict``).
        device: device the model is moved to after loading.

    Returns:
        ``(model, char2id)`` with the model in eval mode.
    """
    with open(paths["char2id"], "r", encoding="utf8") as f:
        char2id = json.load(f)

    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(paths["weights"], map_location="cpu")
    emb_dim = state["emb.weight"].size(1)
    # The output layer consumes both LSTM directions, so its input width is 2 * hidden.
    hidden = state["out.weight"].size(1) // 2

    model = CharBoundaryTagger(len(char2id), emb_dim=emb_dim, hidden=hidden)
    model.load_state_dict(state, strict=True)
    model = model.to(device)
    model.eval()
    print(f"[BiLSTM] loaded emb_dim={emb_dim}, hidden={hidden}, vocab={len(char2id)}")
    return model, char2id

def _lower_same_length(s: str) -> str:
    """Lower-case ``s`` without ever changing its length.

    ``str.lower()`` can expand a character into several code points (for example "İ"),
    which would shift every later boundary probability relative to the caller's
    string. When that happens, such characters are left unchanged.
    """
    lowered = s.lower()
    if len(lowered) == len(s):
        return lowered
    return "".join(low if len(low) == 1 else ch for ch, low in ((c, c.lower()) for c in s))


@torch.no_grad()
def boundary_probs_for_string(s: str, model, char2id, device="cpu", max_len=None, overlap: int = 16):
    """Return the model's space probability for every gap between adjacent characters.

    Args:
        s: text without spaces. It is lower-cased (length-preserving) and "ё" is folded
            into "е" before being encoded; unknown characters map to id 1.
        model: callable ``model(x, lengths)`` returning logits of shape ``[1, T-1]``.
        char2id: character -> id mapping.
        device: device on which the input tensors are created.
        max_len: longest string scored in a single pass; defaults to the global
            ``MAX_LEN``. Longer strings are scored with overlapping windows.
        overlap: number of characters shared by consecutive windows.

    Returns:
        A list of ``len(s) - 1`` probabilities (empty for strings shorter than two
        characters); element ``i`` is the probability of a space between characters
        ``i`` and ``i + 1``. Gaps covered by several windows get the mean.

    Raises:
        ValueError: if windowing is needed and ``overlap`` is not in ``[1, max_len)``.
            A non-positive step would loop forever and ``overlap == 0`` would leave the
            gap between two windows unscored.
    """
    s = _lower_same_length(s).replace("ё", "е")
    if len(s) <= 1:
        return []
    if max_len is None:
        max_len = MAX_LEN

    def _one_pass(seq: str):
        x = torch.tensor([[char2id.get(ch, 1) for ch in seq]], dtype=torch.long, device=device)
        lengths = torch.tensor([x.size(1)], dtype=torch.long, device=device)
        logits = model(x, lengths)  # [1, T-1]
        return torch.sigmoid(logits).squeeze(0).tolist()

    if len(s) <= max_len:
        return _one_pass(s)

    if not 1 <= overlap < max_len:
        raise ValueError(f"overlap must satisfy 1 <= overlap < max_len, got overlap={overlap}, max_len={max_len}")

    win, step = max_len, max_len - overlap
    acc = [0.0] * (len(s) - 1)
    cnt = [0] * (len(s) - 1)
    start = 0
    while start < len(s):
        end = min(len(s), start + win)
        for k, p in enumerate(_one_pass(s[start:end])):
            gi = start + k  # global index of the gap
            if gi < len(acc):
                acc[gi] += p
                cnt[gi] += 1
        if end == len(s):
            break
        start += step
    return [total / max(1, hits) for total, hits in zip(acc, cnt)]

# From here on DEVICE is a plain string (it was a torch.device in the training section).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BILSTM_PATHS = dict(
    weights="bilstm_ru_stream_v2.pth",
    char2id="char2id_stream_v2.json",
)

model_bi, char2id_bi = load_bilstm_autoconfig(BILSTM_PATHS, device=DEVICE)
print("LSTM loaded")

[BiLSTM] loaded emb_dim=128, hidden=256, vocab=1017
LSTM loaded


Собственно, основная функция сегментации restore_spaces_dp_lstm

In [11]:
# Punctuation that attaches to the token on its left: no space is inserted before it.
_ATTACH_LEFT = frozenset(",.;:!?…)]}-\"'")
# Punctuation that attaches to the token on its right: no space is inserted after it.
_ATTACH_RIGHT = frozenset("([{-")


def _join_tokens(tokens) -> str:
    """Concatenate word/punctuation tokens, inserting single spaces where appropriate.

    A space is placed between two neighbouring tokens unless the left one is an opening
    bracket or a hyphen, or the right one is a mark that attaches to the left
    (sentence punctuation, closing brackets, hyphen, quotes). Consequently runs such as
    "!!!" or "?!" stay together, "(" sticks to the following word and gets a space
    before it, hyphenated compounds are not split, and em/en dashes are spaced on both
    sides. Quotes are treated as closing marks (NOTE(review): opening quotes cannot be
    told apart here and still get a space after them).
    """
    if not tokens:
        return ""
    out = [tokens[0]]
    for left, right in zip(tokens, tokens[1:]):
        if left not in _ATTACH_RIGHT and right not in _ATTACH_LEFT:
            out.append(" ")
        out.append(right)
    return "".join(out).strip()


def restore_spaces_dp_lstm(
    text: str,
    model=None, char2id=None, device="cpu",
    alpha: float = 0.65,
    max_word_len: int = 28,
    hint_bonus: float = 0.15,
    use_logit: bool = True,
    temp: float = 1.3,
    bias: float = -0.15,
    conf_threshold: float = 0.55,
    lambda_cut: float = 0.07,
    short_penalty: float = 0.8
) -> str:
    """Insert spaces into ``text`` using dynamic programming guided by heuristics and the BiLSTM.

    The text is first split on punctuation; every punctuation-free run is segmented
    independently by a DP that maximises the sum of per-word scores, and the resulting
    tokens are joined with typographic spacing.

    Args:
        text: input without spaces. Non-string or empty input yields "".
        model, char2id, device: optional boundary tagger and its vocabulary; when either
            is missing only the heuristic scores are used.
        alpha: weight of the neural term in a word's score.
        max_word_len: longest candidate word considered by the DP.
        hint_bonus: bonus for a word that starts at a ``_hard_boundary`` position.
        use_logit: if True the neural term is the shifted/scaled logit, otherwise the
            log of the temperature-adjusted probability.
        temp: temperature dividing the shifted logit.
        bias: value added to the logit before scaling.
        conf_threshold: the neural term is applied only when the space probability is
            at least this value.
        lambda_cut: penalty for every word that does not start the run (i.e. per cut).
        short_penalty: penalty for one-character words not listed in ``_SHORT_OK``.

    Returns:
        The text with spaces restored.
    """
    if not isinstance(text, str) or len(text) == 0:
        return ""

    # 1) Split into single punctuation characters and punctuation-free runs.
    parts, i = [], 0
    while i < len(text):
        ch = text[i]
        if ch in _PUNCT:
            parts.append((ch, True))
            i += 1
        else:
            j = i + 1
            while j < len(text) and text[j] not in _PUNCT:
                j += 1
            parts.append((text[i:j], False))
            i = j

    # 2) Segment every run with the DP.
    tokens = []
    for chunk, is_punct in parts:
        if is_punct:
            tokens.append(chunk)
            continue

        run = chunk
        n = len(run)
        if n == 0:
            continue

        # hints[k] is True when a boundary right before run[k] is suggested by the heuristics.
        hints = [False] * (n + 1)
        for k in range(1, n):
            if _hard_boundary(run[k - 1], run[k]):
                hints[k] = True

        probs = None
        if (model is not None) and (char2id is not None) and (n > 1):
            probs = boundary_probs_for_string(run, model, char2id, device=device)
            # probs[j - 1] is read as "space before run[j]"; if the tagger returned a
            # different number of gaps the indices would be misaligned (or out of
            # range), so fall back to the heuristic scores for this run.
            if len(probs) != n - 1:
                probs = None

        # dp[e] = best total score of a segmentation of run[:e]; prv[e] = start of its last word.
        dp = np.full(n + 1, -1e18, dtype=float)
        prv = np.full(n + 1, -1, dtype=int)
        dp[0] = 0.0

        for end in range(1, n + 1):
            start_lim = max(0, end - max_word_len)
            best_val, best_j = -1e18, -1

            for j in range(start_lim, end):
                cand = run[j:end]
                s = _cached_score_word(cand)

                if len(cand) == 1 and cand.lower() not in _SHORT_OK:
                    s -= short_penalty

                if j > 0 and hints[j]:
                    s += hint_bonus

                if (probs is not None) and (j > 0) and (j < n):
                    # Clamp away from 0 and 1 so the logit stays finite.
                    p = min(max(probs[j - 1], 1e-6), 1 - 1e-6)
                    if p >= conf_threshold:
                        logit = math.log(p) - math.log(1 - p)
                        scaled = (logit + bias) / max(1e-6, temp)
                        if use_logit:
                            nn_term = scaled
                        else:
                            nn_term = math.log(1 / (1 + math.exp(-scaled)))
                        s += alpha * nn_term

                if j > 0:
                    s -= lambda_cut

                val = dp[j] + s
                if (val > best_val) or (val == best_val and j == start_lim):
                    best_val, best_j = val, j

            dp[end], prv[end] = best_val, best_j

        if prv[n] == -1:
            # No segmentation was found; keep the run as a single token.
            tokens.append(run)
        else:
            # Walk the back-pointers from the end of the run to recover the words.
            rev = []
            k = n
            while k > 0:
                j = prv[k]
                rev.append(run[j:k])
                k = j
            tokens.extend(rev[::-1])

    # 3) Join words and punctuation.
    return _join_tokens(tokens)

# WEIGHT (alpha) of the LSTM term in the DP score of a candidate word.
# The higher it is, the more the segmentation trusts the LSTM.
ALPHA = 0.65

# MAXIMUM length of a candidate "word" that the DP will consider;
# longer candidates are never enumerated.
MAX_WORD_LEN = 28

# BONUS added to the score of a word that starts at a position flagged by the
# `_hard_boundary` heuristic (for example a lower-case -> upper-case change).
HINT_BONUS = 0.15

# FLAG: use the shifted and scaled LSTM logit directly as the neural term,
# or the log of the temperature-adjusted space probability instead.
USE_LOGIT = True

# TEMPERATURE that divides the shifted LSTM logit before it enters the score.
# A larger value softens the LSTM contribution; a smaller value sharpens it.
TEMP = 1.3

# BIAS added to the LSTM logit before it enters the score.
# A positive bias makes the network "more confident" in a space, a negative one less.
BIAS = -0.15

# Confidence THRESHOLD: the LSTM term is applied only when the predicted
# space probability is at least this value.
CONF_TH = 0.55

# PENALTY per cut (i.e. per inserted space).
# Lowering it makes the segmentation more generous with spaces.
LAMBDA_CUT = 0.07

# PENALTY for one-character words that are not listed in `_SHORT_OK`.
# Lowering it makes such words cheaper, so the algorithm allows them more readily.
SHORT_PEN = 0.8

# Device (GPU or CPU) on which model inference runs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [15]:
# Hyper-parameters of the DP + LSTM segmenter used for this prediction run.
ALPHA = 0.65        # weight of the LSTM term in a word's DP score
MAX_WORD_LEN = 28   # longest candidate word considered by the DP
HINT_BONUS = 0.15   # bonus for words starting at a heuristic boundary
USE_LOGIT = True    # use the scaled logit (True) or the log-probability (False)
TEMP = 1.3          # temperature dividing the shifted logit
BIAS = -0.15        # shift added to the logit
CONF_TH = 0.55      # LSTM confidence threshold (below it the LSTM term is ignored in the DP)
LAMBDA_CUT = 0.07   # penalty per inserted space
SHORT_PEN = 0.8     # penalty for one-character words
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

pred_texts = [
    restore_spaces_dp_lstm(
        s,
        model=model_bi,
        char2id=char2id_bi,
        device=DEVICE,
        alpha=ALPHA,
        max_word_len=MAX_WORD_LEN,
        hint_bonus=HINT_BONUS,
        use_logit=USE_LOGIT,
        temp=TEMP,
        bias=BIAS,
        conf_threshold=CONF_TH,
        lambda_cut=LAMBDA_CUT,
        short_penalty=SHORT_PEN,
    )
    for s in tqdm(task_data["text_no_spaces"].tolist(), desc="Predicting DP+LSTM")
]

task_data["restored_text"] = pred_texts
out_csv = "predicted_text.csv"
task_data.loc[:, ["id", "restored_text"]].to_csv(out_csv, index=False, encoding="utf-8")
print("Saved:", out_csv)

Predicting DP+LSTM: 100%|██████████| 1005/1005 [00:02<00:00, 369.47it/s]

Saved: predicted_text.csv





In [16]:
def space_positions_compressed(restored_text: str) -> list[int]:
    """Return the space positions of ``restored_text`` as indices into the text without spaces.

    Each space is reported as the number of non-space characters that precede it, which
    is the index in the original unspaced string before which the space belongs.
    """
    space_indices = (idx for idx, ch in enumerate(restored_text) if ch == ' ')
    # Subtracting the number of earlier spaces converts an index in the spaced text
    # into an index in the unspaced text.
    return [idx - earlier_spaces for earlier_spaces, idx in enumerate(space_indices)]

# Convert every restored string into the list of space positions expected by the submission.
predicted_positions_list = [
    space_positions_compressed(restored)
    for restored in tqdm(task_data["restored_text"], total=len(task_data), desc="Finding space positions")
]

# The submission stores each list as its string representation, e.g. "[5, 10, 12]".
task_data["predicted_positions"] = list(map(str, predicted_positions_list))

def create_submission_file(task_data, output_file: str):
    """Write the ``id`` and ``predicted_positions`` columns to ``output_file`` as UTF-8 CSV.

    The index is not written. A confirmation line and the first rows of the
    submission are printed.
    """
    submission = task_data.loc[:, ["id", "predicted_positions"]].copy()
    submission.to_csv(output_file, index=False, encoding="utf-8")
    print(f"Submission file saved to {output_file}")
    print(submission.head())

# Write the final submission next to the notebook.
create_submission_file(task_data, output_file="submission.csv")

Finding space positions: 100%|██████████| 1005/1005 [00:00<00:00, 20689.59it/s]

Submission file saved to submission.csv
  id     predicted_positions
0  0             [5, 10, 12]
1  1           [1, 6, 7, 13]
2  2  [4, 5, 12, 13, 20, 21]
3  3             [5, 10, 18]
4  4                 [5, 10]



