In [1]:
from __future__ import annotations

import argparse
import difflib
import json
import math
import os
import platform
import random
import re
import sys
import unicodedata
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from sklearn.model_selection import GroupShuffleSplit




In [2]:
# =========================
# Utils
# =========================

def set_seed(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Covers Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (the default generator plus all CUDA devices).

    Args:
        seed: Integer seed applied to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def normalize_form(s: object) -> str:
    """Return a canonical text form of ``s`` for comparison and encoding.

    Steps: map ``None``/NaN to ``""``, stringify, unify apostrophe look-alikes
    to ``'``, apply Unicode NFKC, collapse whitespace runs to one space and
    trim the ends.

    Args:
        s: Any value; typically a string or a missing-value marker from pandas.

    Returns:
        The normalized string (``""`` for ``None`` and float NaN).
    """
    if s is None:
        return ""
    # np.floating covers NumPy float scalars that are not Python floats (e.g. float32).
    if isinstance(s, (float, np.floating)) and np.isnan(s):
        return ""

    def _unify_apostrophes(text: str) -> str:
        return text.replace("’", "'").replace("ʻ", "'").replace("`", "'").replace("´", "'")

    # Apostrophes must be unified BEFORE NFKC: NFKC rewrites U+00B4 (´) into
    # "space + combining acute", after which it can no longer be matched.
    text = _unify_apostrophes(str(s))
    text = unicodedata.normalize("NFKC", text)
    # Second pass catches variants that NFKC itself produces (e.g. fullwidth grave -> `).
    text = _unify_apostrophes(text)
    # Trim last so that spaces introduced by NFKC at the edges are removed as well.
    return re.sub(r"\s+", " ", text).strip()

def strip_diacritics(s: str) -> str:
    """Remove combining marks from ``s`` (e.g. ``xiào`` -> ``xiao``).

    The string is decomposed with NFD, every character whose Unicode category
    is ``Mn`` (non-spacing mark) is dropped, and the remainder is recomposed
    with NFKC.
    """
    decomposed = unicodedata.normalize("NFD", s)
    base_chars = filter(lambda ch: unicodedata.category(ch) != "Mn", decomposed)
    return unicodedata.normalize("NFKC", "".join(base_chars))

# Lower-case Cyrillic letter -> Latin transliteration used by cyrillic_to_latin.
_CYR2LAT = {
    "а": "a", "б": "b", "в": "v", "г": "g", "д": "d", "е": "e", "ё": "yo", "ж": "zh",
    "з": "z", "и": "i", "й": "y", "к": "k", "л": "l", "м": "m", "н": "n", "о": "o",
    "п": "p", "р": "r", "с": "s", "т": "t", "у": "u", "ф": "f", "х": "kh", "ц": "ts",
    "ч": "ch", "ш": "sh", "щ": "shch", "ъ": "", "ы": "y", "ь": "", "э": "e", "ю": "yu",
    "я": "ya",
    # additional (Mongolian-style) Cyrillic letters
    "ө": "o", "ү": "u", "ң": "ng", "һ": "h",
}


def cyrillic_to_latin(s: str) -> str:
    """Transliterate Cyrillic letters in ``s`` to Latin, character by character.

    The input is normalized with ``normalize_form`` and lower-cased first;
    characters without an entry in ``_CYR2LAT`` are passed through unchanged.
    """
    lowered = normalize_form(s).lower()
    return "".join(_CYR2LAT.get(ch, ch) for ch in lowered)

def romanize_for_alignment(s: str) -> str:
    """Map ``s`` to a single script-neutral form for cross-alphabet comparison.

    - if the string contains Cyrillic letters, apply a simple Latin transliteration;
    - then lower-case and strip diacritics (``xiào`` -> ``xiao``);
    - finally collapse whitespace and trim.
    """
    text = normalize_form(s)
    # Transliterate only when at least one Cyrillic letter is present.
    contains_cyrillic = re.search(r"[А-Яа-яЁёӨөҮүҢңҺһ]", text) is not None
    if contains_cyrillic:
        text = cyrillic_to_latin(text)
    folded = strip_diacritics(text.lower())
    return re.sub(r"\s+", " ", folded).strip()


def parse_pairs_txt(path: str) -> pd.DataFrame:
    """Parse a whitespace-separated word-pair file into a DataFrame.

    Expected line formats (the trailing token is an optional 0/1 label;
    lines without one are treated as label 1):
      form1 gloss1 form2 gloss2 1
    or
      form1 form2 1
    or (for text2)
      ... 0

    Lines with fewer than three tokens are skipped. With exactly three core
    tokens the third one is stored as ``gloss2``; with four or more, tokens
    are read as form1, gloss1, form2 and the rest is joined into ``gloss2``.

    Args:
        path: Path to a UTF-8 text file (undecodable bytes are ignored).

    Returns:
        DataFrame with columns ``form1``, ``form2``, ``gloss1``, ``gloss2``,
        ``label``. The columns are present even when no line could be parsed,
        so callers can always filter on ``label``.
    """
    columns = ["form1", "form2", "gloss1", "gloss2", "label"]
    rows = []
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            # Remove the BOM before splitting so whitespace right after it does
            # not produce an empty leading token; split() also skips blank lines.
            parts = line.replace("\ufeff", "").split()
            if len(parts) < 3:
                continue

            if re.fullmatch(r"[01]", parts[-1]):
                label = int(parts[-1])
                core = parts[:-1]
            else:
                label = 1
                core = parts

            if len(core) == 2:
                form1, form2 = core
                gloss1, gloss2 = "", ""
            elif len(core) == 3:
                form1, form2, gloss2 = core
                gloss1 = ""
            else:
                form1 = core[0]
                gloss1 = core[1]
                form2 = core[2]
                gloss2 = " ".join(core[3:])

            rows.append(
                {
                    "form1": normalize_form(form1),
                    "form2": normalize_form(form2),
                    "gloss1": gloss1,
                    "gloss2": gloss2,
                    "label": label,
                }
            )

    # Passing `columns` keeps the schema stable for an empty result.
    return pd.DataFrame(rows, columns=columns)


def build_id_map(series: pd.Series) -> Dict[str, int]:
    """Assign consecutive integer ids to the distinct values of ``series``.

    Missing values are dropped, the rest is converted to ``str``, and ids
    follow the sorted order of the distinct strings, starting at 0.
    """
    distinct = set(series.dropna().astype(str).tolist())
    mapping: Dict[str, int] = {}
    for value in sorted(distinct):
        mapping[value] = len(mapping)
    return mapping


def compute_features(src_form: str, tgt_form: str) -> np.ndarray:
    """Return 9 surface-similarity features for a (source, target) form pair.

    NOTE: a later cell of this notebook defines ``compute_features`` again
    (an 18-feature version); after running all cells that later definition
    replaces this one.

    Feature order: len(a), len(b), |len(a) - len(b)|, len(a)/max_len,
    len(b)/max_len, SequenceMatcher ratio, common-prefix ratio,
    common-suffix ratio, character-bigram Jaccard similarity.

    Returns:
        ``float32`` array of shape ``(9,)``.
    """
    a = normalize_form(src_form)
    b = normalize_form(tgt_form)
    len_a, len_b = len(a), len(b)
    # Denominators are clamped to 1 so empty strings never divide by zero.
    shorter = max(1, min(len_a, len_b))
    longer = max(1, max(len_a, len_b))

    def _match_len(pairs) -> int:
        # Number of leading positions at which both sequences agree.
        total = 0
        for left, right in pairs:
            if left != right:
                return total
            total += 1
        return total

    def _bigram_set(text: str) -> set:
        return {text[i : i + 2] for i in range(len(text) - 1)}

    grams_a = _bigram_set(a)
    grams_b = _bigram_set(b)
    union = grams_a | grams_b
    jaccard = (len(grams_a & grams_b) / len(union)) if union else 0.0

    values = [
        len_a,
        len_b,
        abs(len_a - len_b),
        len_a / longer,
        len_b / longer,
        difflib.SequenceMatcher(None, a, b).ratio(),
        _match_len(zip(a, b)) / shorter,
        _match_len(zip(a[::-1], b[::-1])) / shorter,
        jaccard,
    ]
    return np.array(values, dtype=np.float32)


In [3]:
# =========================
# Encoding (byte-level + meta)
# =========================

# Token-id layout: three special ids, then one id per byte value, then metadata ids.
PAD_ID = 0
CLS_ID = 1
SEP_ID = 2
BYTE_OFFSET = 3  # bytes 0..255 -> 3..258
META_OFFSET = BYTE_OFFSET + 256  # 259


def to_bytes_ids(s: str, max_bytes: int) -> List[int]:
    """Encode ``s`` as byte-level token ids.

    The string is normalized with ``normalize_form``, encoded as UTF-8,
    truncated to ``max_bytes`` bytes, and each byte value is shifted by
    ``BYTE_OFFSET``.
    """
    encoded = normalize_form(s).encode("utf-8", errors="ignore")
    return [BYTE_OFFSET + int(byte) for byte in encoded[:max_bytes]]

In [4]:
# =========================
# Dataset
# =========================

FEAT_DIM = 18

def _features_for_strings(a: str, b: str) -> np.ndarray:
    la, lb = len(a), len(b)
    minl = max(1, min(la, lb))
    maxl = max(1, max(la, lb))

    ratio = difflib.SequenceMatcher(None, a, b).ratio()

    pref = 0
    for x, y in zip(a, b):
        if x == y:
            pref += 1
        else:
            break

    suf = 0
    for x, y in zip(a[::-1], b[::-1]):
        if x == y:
            suf += 1
        else:
            break

    pref_r = pref / minl
    suf_r = suf / minl

    def bigrams(s: str) -> set:
        return set([s[i:i+2] for i in range(len(s)-1)]) if len(s) >= 2 else set()

    A = bigrams(a)
    B = bigrams(b)
    jac = (len(A & B) / max(1, len(A | B))) if (A or B) else 0.0

    return np.array(
        [la, lb, abs(la-lb), la/maxl, lb/maxl, ratio, pref_r, suf_r, jac],
        dtype=np.float32
    )

def compute_features(src_form: str, tgt_form: str) -> np.ndarray:
    """Return the 18-dimensional feature vector for a (source, target) pair.

    The 9 features of ``_features_for_strings`` are computed twice and
    concatenated: first on the ``normalize_form`` view of both strings, then
    on their ``romanize_for_alignment`` view (which bridges Cyrillic/Latin
    and removes diacritics).

    Returns:
        ``float32`` array of shape ``(18,)``: raw features then romanized ones.
    """
    views = (
        (normalize_form(src_form), normalize_form(tgt_form)),  # raw
        (romanize_for_alignment(src_form), romanize_for_alignment(tgt_form)),  # romanized
    )
    blocks = [_features_for_strings(left, right) for left, right in views]
    return np.concatenate(blocks, axis=0).astype(np.float32)



@dataclass
class EncodedItem:
    """One encoded example as returned by ``LoanDataset.__getitem__``."""

    # Token ids produced by LoanDataset.encode_example (already cut to max_len).
    token_ids: List[int]
    # Segment id for each entry of token_ids (same length as token_ids).
    seg_ids: List[int]
    # Feature vector returned by compute_features for the (source, target) forms.
    feats: np.ndarray
    # Integer label taken from the row's "label_loan" column.
    label: int


class LoanDataset(Dataset):
    """Torch dataset that turns rows of a pairs DataFrame into ``EncodedItem``s.

    Each row must provide the columns ``src_form``, ``tgt_form``, ``tgt_lang``,
    ``sem_field``, ``sem_cat`` and ``label_loan``.

    Token sequence layout (segment id in brackets):
        [CLS] lang field cat [SEP]          [0] metadata
        src raw bytes [SEP]                 [1]
        src romanized bytes [SEP]           [2]
        tgt raw bytes [SEP]                 [3]
        tgt romanized bytes [SEP]           [4]
    The sequence is cut to ``max_len`` tokens.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        lang2id: Dict[str, int],
        field2id: Dict[str, int],
        cat2id: Dict[str, int],
        max_src_bytes: int,
        max_tgt_bytes: int,
        max_len: int,
        lang_base: int,
        field_base: int,
        cat_base: int,
    ):
        # Work on a 0..n-1 index so positional and label access agree.
        self.df = df.reset_index(drop=True)

        # Metadata vocabularies (value -> local id) and the token-id base of each.
        self.lang2id, self.field2id, self.cat2id = lang2id, field2id, cat2id
        self.lang_base, self.field_base, self.cat_base = lang_base, field_base, cat_base

        # Length limits: per-string byte budgets and the overall token budget.
        self.max_src_bytes, self.max_tgt_bytes = max_src_bytes, max_tgt_bytes
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.df)

    def encode_example(
        self, src_form: str, tgt_form: str, tgt_lang: str, sem_field: str, sem_cat: str
    ) -> Tuple[List[int], List[int]]:
        """Build the token-id and segment-id lists for one pair.

        Metadata values missing from the vocabularies fall back to local id 0.

        Returns:
            ``(token_ids, seg_ids)``, both truncated to ``self.max_len``.
        """
        meta_tokens = [
            self.lang_base + self.lang2id.get(normalize_form(tgt_lang), 0),
            self.field_base + self.field2id.get(normalize_form(sem_field), 0),
            self.cat_base + self.cat2id.get(normalize_form(sem_cat), 0),
        ]
        tokens = [CLS_ID, *meta_tokens, SEP_ID]
        seg = [0] * len(tokens)  # meta

        # (text, byte budget) per segment; segment ids are 1..4 in this order.
        pieces = (
            (normalize_form(src_form), self.max_src_bytes),
            (romanize_for_alignment(src_form), self.max_src_bytes),
            (normalize_form(tgt_form), self.max_tgt_bytes),
            (romanize_for_alignment(tgt_form), self.max_tgt_bytes),
        )
        for seg_id, (text, byte_budget) in enumerate(pieces, start=1):
            ids = to_bytes_ids(text, byte_budget)
            tokens.extend(ids)
            tokens.append(SEP_ID)
            # The closing SEP belongs to the same segment as its bytes.
            seg.extend([seg_id] * (len(ids) + 1))

        return tokens[: self.max_len], seg[: self.max_len]

    def __getitem__(self, idx: int) -> EncodedItem:
        record = self.df.iloc[idx]
        src, tgt = record["src_form"], record["tgt_form"]
        token_ids, seg_ids = self.encode_example(
            src, tgt, record["tgt_lang"], record["sem_field"], record["sem_cat"]
        )
        return EncodedItem(token_ids, seg_ids, compute_features(src, tgt), int(record["label_loan"]))


def make_collate_fn(max_len: int):
    """Create a DataLoader ``collate_fn`` that pads a list of ``EncodedItem``s.

    The returned function pads every sequence to the longest one in the batch
    (capped at ``max_len``) and returns
    ``(token_ids, seg_ids, pad_mask, feats, labels)`` where ``pad_mask`` is
    True at positions holding ``PAD_ID``.
    """

    def collate_fn(batch: List[EncodedItem]):
        batch_size = len(batch)
        width = min(max(len(item.token_ids) for item in batch), max_len)

        token_ids = torch.full((batch_size, width), PAD_ID, dtype=torch.long)
        seg_ids = torch.zeros((batch_size, width), dtype=torch.long)
        feats = torch.zeros((batch_size, FEAT_DIM), dtype=torch.float32)
        labels = torch.zeros((batch_size,), dtype=torch.float32)

        for row, item in enumerate(batch):
            toks = item.token_ids[:width]
            segs = item.seg_ids[:width]
            # Left-aligned copy; the tail of each row keeps its fill value.
            token_ids[row, : len(toks)] = torch.tensor(toks, dtype=torch.long)
            seg_ids[row, : len(segs)] = torch.tensor(segs, dtype=torch.long)
            feats[row] = torch.tensor(item.feats, dtype=torch.float32)
            labels[row] = float(item.label)

        return token_ids, seg_ids, token_ids == PAD_ID, feats, labels

    return collate_fn

In [5]:
# =========================
# Model
# =========================

class ByteCrossEncoder(nn.Module):
    """Transformer encoder over byte/metadata tokens with an MLP scoring head.

    Token, segment and learned position embeddings are summed, passed through
    a pre-norm Transformer encoder and a final LayerNorm. The vector at
    position 0 is concatenated with the handcrafted features and fed to an
    MLP that outputs one logit per example.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        max_len: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        ff_mult: int,
        dropout: float,
        feat_dim: int,
    ):
        super().__init__()
        self.vocab_size, self.max_len, self.d_model = vocab_size, max_len, d_model

        # Submodule names and creation order are kept stable: they define the
        # checkpoint keys and the order in which parameters are initialised.
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=PAD_ID)
        self.seg_emb = nn.Embedding(5, d_model)  # 0 meta, 1 src_raw, 2 src_rom, 3 tgt_raw, 4 tgt_rom
        self.pos_emb = nn.Embedding(max_len, d_model)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=d_model * ff_mult,
                dropout=dropout,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            ),
            num_layers=n_layers,
        )
        self.ln = nn.LayerNorm(d_model)

        head_layers = [
            nn.Linear(d_model + feat_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        ]
        self.cls_mlp = nn.Sequential(*head_layers)

    def forward(
        self,
        token_ids: torch.Tensor,
        seg_ids: torch.Tensor,
        pad_mask: torch.Tensor,
        feats: torch.Tensor,
    ) -> torch.Tensor:
        """Score a batch.

        Args:
            token_ids: Token ids, indexed as ``(batch, position)``.
            seg_ids: Segment ids with the same layout as ``token_ids``.
            pad_mask: Passed to the encoder as ``src_key_padding_mask``.
            feats: Per-example features, concatenated to the pooled vector on dim 1.

        Returns:
            One logit per example (the head's size-1 dimension is squeezed away).
        """
        seq_len = token_ids.size(1)
        # Shape (1, seq_len); broadcasts over the batch when the embeddings are added.
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        hidden = self.tok_emb(token_ids) + self.seg_emb(seg_ids) + self.pos_emb(positions)
        hidden = self.ln(self.encoder(hidden, src_key_padding_mask=pad_mask))
        pooled = torch.cat([hidden[:, 0], feats], dim=1)  # position 0 holds CLS
        return self.cls_mlp(pooled).squeeze(1)

In [6]:
# =========================
# Data construction
# =========================

def load_book2(book2_path: str) -> pd.DataFrame:
    """Load the loanword workbook and derive the columns used by the pipeline.

    Adds ``src_form``, ``tgt_form``, ``tgt_lang``, ``meaning``, ``sem_field``,
    ``sem_cat`` plus the constant columns ``label_loan`` (1), ``source``
    ("book2") and ``neg_type`` (""). Rows whose normalized donor or target
    form is empty are dropped.

    Missing metadata cells become ``""`` rather than the literal string
    ``"nan"`` (which ``astype(str)`` would produce and which would then be
    treated as a real category).

    Args:
        book2_path: Path to the Excel file.

    Returns:
        The augmented DataFrame with a fresh 0..n-1 index.

    Raises:
        ValueError: If any expected column is missing from the workbook.
    """
    raw = pd.read_excel(book2_path)

    expected_cols = [
        "ID",
        "Meaning",
        "SemanticField",
        "SemanticCategory",
        "Target_Form",
        "Target_Language_Name",
        "Donor",
    ]
    missing = [c for c in expected_cols if c not in raw.columns]
    if missing:
        raise ValueError(f"В Книга2.xlsx отсутствуют ожидаемые столбцы: {missing}")

    book2 = raw.copy()
    # normalize_form already maps None/NaN to "" and stringifies everything
    # else, so no astype(str) is applied before it.
    book2["src_form"] = book2["Donor"].map(normalize_form)
    book2["tgt_form"] = book2["Target_Form"].map(normalize_form)
    book2["tgt_lang"] = book2["Target_Language_Name"].map(normalize_form)
    # Meaning is kept verbatim (not normalized); only missing cells are blanked.
    book2["meaning"] = book2["Meaning"].map(lambda v: "" if pd.isna(v) else str(v))
    book2["sem_field"] = book2["SemanticField"].map(normalize_form)
    book2["sem_cat"] = book2["SemanticCategory"].map(normalize_form)
    book2["label_loan"] = 1
    book2["source"] = "book2"
    book2["neg_type"] = ""

    has_both_forms = (book2["src_form"] != "") & (book2["tgt_form"] != "")
    return book2[has_both_forms].reset_index(drop=True)


def sample_meta_from_book2(book2: pd.DataFrame, n: int, rng: random.Random) -> pd.DataFrame:
    """Draw ``n`` metadata rows from ``book2`` uniformly, with replacement.

    Returns:
        DataFrame with columns ``tgt_lang``, ``meaning``, ``sem_field``,
        ``sem_cat`` and a fresh 0..n-1 index.
    """
    meta_cols = ["tgt_lang", "meaning", "sem_field", "sem_cat"]
    picks = rng.choices(range(len(book2)), k=n)
    return book2[meta_cols].iloc[picks].reset_index(drop=True)


def _build_file_negatives(
    book2: pd.DataFrame,
    path: Optional[str],
    rng: random.Random,
    *,
    keep_label: int,
    neg_type: str,
    source: str,
) -> pd.DataFrame:
    """Shared implementation for negatives that come from a pairs text file.

    Rows of the parsed file whose label equals ``keep_label`` are used in both
    directions (form1->form2 and form2->form1), paired with metadata sampled
    at random from ``book2``, and labelled ``label_loan = 0``.

    Returns an empty DataFrame with the standard columns when the path is
    missing, the file yields no rows, or no row has the requested label.
    """
    cols = ["src_form", "tgt_form", "tgt_lang", "meaning", "sem_field", "sem_cat", "label_loan", "neg_type", "source"]
    if not path or not os.path.exists(path):
        return pd.DataFrame(columns=cols)

    raw = parse_pairs_txt(path)
    # Check emptiness BEFORE touching raw["label"]: a file without parseable
    # lines may come back without that column.
    if raw.empty or "label" not in raw.columns:
        return pd.DataFrame(columns=cols)
    raw = raw[raw["label"] == keep_label].copy()
    if raw.empty:
        return pd.DataFrame(columns=cols)

    forward = raw.rename(columns={"form1": "src_form", "form2": "tgt_form"})[["src_form", "tgt_form"]]
    backward = raw.rename(columns={"form2": "src_form", "form1": "tgt_form"})[["src_form", "tgt_form"]]
    pairs = pd.concat([forward, backward], ignore_index=True)
    pairs = pairs[(pairs["src_form"] != "") & (pairs["tgt_form"] != "")].reset_index(drop=True)

    meta = sample_meta_from_book2(book2, len(pairs), rng)
    df = pd.concat([pairs, meta], axis=1)
    df["label_loan"] = 0
    df["neg_type"] = neg_type
    df["source"] = source
    return df[cols]


def build_cognate_negatives(book2: pd.DataFrame, cognates_path: Optional[str], rng: random.Random) -> pd.DataFrame:
    """Build negatives from the cognates file: its label-1 pairs, both directions.

    Args:
        book2: Positive examples; only used as a pool for sampled metadata.
        cognates_path: Path to the pairs file; ``None``/missing file -> empty result.
        rng: Random generator used for metadata sampling.

    Returns:
        DataFrame with the standard nine columns, ``neg_type == "cognate"``.
    """
    return _build_file_negatives(
        book2, cognates_path, rng, keep_label=1, neg_type="cognate", source="cognates_txt"
    )


def build_txt2_negatives(book2: pd.DataFrame, txt2_path: Optional[str], rng: random.Random) -> pd.DataFrame:
    """Build negatives from the second text file: its label-0 pairs, both directions.

    Args:
        book2: Positive examples; only used as a pool for sampled metadata.
        txt2_path: Path to the pairs file; ``None``/missing file -> empty result.
        rng: Random generator used for metadata sampling.

    Returns:
        DataFrame with the standard nine columns, ``neg_type == "txt2"``.
    """
    return _build_file_negatives(
        book2, txt2_path, rng, keep_label=0, neg_type="txt2", source="text2_txt"
    )


def build_random_negatives(book2: pd.DataFrame, n: int, rng: random.Random) -> pd.DataFrame:
    """Create negatives by pairing target rows with randomly drawn donor forms.

    A donor is accepted only if ``(donor, target form)`` is not a known
    positive pair in ``book2``. Colliding draws are re-sampled until a valid
    donor is found (bounded rejection sampling, then an explicit scan), so no
    true pair is ever emitted with label 0.

    Args:
        book2: Positive examples with ``src_form``, ``tgt_form`` and metadata columns.
        n: Number of negatives requested.
        rng: Random generator driving all draws.

    Returns:
        DataFrame with columns ``src_form``, ``tgt_form``, ``tgt_lang``,
        ``meaning``, ``sem_field``, ``sem_cat``, ``label_loan`` (0),
        ``neg_type`` ("random"), ``source`` ("generated_random"). It has ``n``
        rows unless some target has no valid donor at all in ``book2``; such
        rows are dropped rather than mislabelled.
    """
    donors = book2["src_form"].tolist()
    known_pairs = set(zip(donors, book2["tgt_form"].tolist()))
    max_retries = 16

    def pick_other_donor(tgt_form):
        # Cheap rejection sampling first; fall back to scanning all donors.
        for _ in range(max_retries):
            cand = donors[rng.randrange(len(donors))]
            if (cand, tgt_form) not in known_pairs:
                return True, cand
        pool = [d for d in donors if (d, tgt_form) not in known_pairs]
        if pool:
            return True, rng.choice(pool)
        return False, None

    # Initial draws are made in the same order as before (targets, then donors).
    idx_t = rng.choices(range(len(book2)), k=n)
    idx_s = rng.choices(range(len(book2)), k=n)

    df = book2.iloc[idx_t][["tgt_form", "tgt_lang", "meaning", "sem_field", "sem_cat"]].reset_index(drop=True)

    src_forms = []
    keep = []
    for tgt_form, si in zip(df["tgt_form"].tolist(), idx_s):
        cand = donors[si]
        ok = True
        if (cand, tgt_form) in known_pairs:
            ok, cand = pick_other_donor(tgt_form)
        src_forms.append(cand)
        keep.append(ok)

    df.insert(0, "src_form", src_forms)
    if not all(keep):
        df = df[np.array(keep, dtype=bool)].reset_index(drop=True)

    df["label_loan"] = 0
    df["neg_type"] = "random"
    df["source"] = "generated_random"
    return df


def build_hard_negatives(book2: pd.DataFrame, n: int, rng: random.Random, sample_pool: int = 80) -> pd.DataFrame:
    """Create "hard" negatives: wrong donors that look similar to the target.

    For each of ``n`` randomly chosen target rows, ``sample_pool`` donor forms
    are sampled and the one with the highest ``difflib`` similarity to the
    target form is kept. Donors that form a known positive pair with that
    target are never used, including in the fallback taken when the sample
    holds no valid candidate.

    Args:
        book2: Positive examples with ``src_form``, ``tgt_form`` and metadata columns.
        n: Number of negatives requested.
        rng: Random generator driving all draws.
        sample_pool: How many donor candidates to score per target.

    Returns:
        DataFrame with the standard nine columns, ``neg_type == "hard_similar"``.
        A target for which ``book2`` has no valid donor at all is skipped, so
        the result can have fewer than ``n`` rows in that degenerate case.
    """
    out_cols = ["src_form", "tgt_form", "tgt_lang", "meaning", "sem_field", "sem_cat", "label_loan", "neg_type", "source"]

    # Pull columns into plain lists once instead of calling iloc per iteration.
    donors = book2["src_form"].tolist()
    tgt_forms = book2["tgt_form"].tolist()
    meta_rows = book2[["tgt_lang", "meaning", "sem_field", "sem_cat"]].to_dict("records")
    known_pairs = set(zip(donors, tgt_forms))

    # SequenceMatcher caches data about its second sequence, so the target is
    # set once per row and only the first sequence changes per candidate.
    matcher = difflib.SequenceMatcher(None)

    rows = []
    for _ in range(n):
        j = rng.randrange(len(book2))
        tgt_form = tgt_forms[j]
        meta = meta_rows[j]
        matcher.set_seq2(tgt_form)

        cand_idx = rng.sample(range(len(donors)), k=min(sample_pool, len(donors)))
        best = None
        best_score = -1.0
        for ci in cand_idx:
            s = donors[ci]
            if (s, tgt_form) in known_pairs:
                continue
            matcher.set_seq1(s)
            score = matcher.ratio()
            if score > best_score:
                best_score = score
                best = s

        if best is None:
            # No valid candidate in the sample: draw from all valid donors.
            pool = [d for d in donors if (d, tgt_form) not in known_pairs]
            if not pool:
                continue
            best = rng.choice(pool)

        rows.append(
            {
                "src_form": best,
                "tgt_form": tgt_form,
                "tgt_lang": meta["tgt_lang"],
                "meaning": meta["meaning"],
                "sem_field": meta["sem_field"],
                "sem_cat": meta["sem_cat"],
                "label_loan": 0,
                "neg_type": "hard_similar",
                "source": "generated_hard",
            }
        )

    return pd.DataFrame(rows, columns=out_cols)


In [7]:
def assemble_dataset(
    book2: pd.DataFrame,
    cognates_path: Optional[str],
    txt2_path: Optional[str],
    seed: int,
    neg_frac_cognate: float,
    neg_frac_random: float,
    neg_frac_hard: float,
) -> pd.DataFrame:
    """Combine positives from ``book2`` with an equal number of negatives.

    File-based negatives (cognates file and text2 file) are used first. If
    there are at least as many of them as positives, a random subset of that
    size is taken. Otherwise all of them are kept and the shortfall is filled
    with re-sampled cognate rows, random negatives and hard negatives.

    Args:
        book2: Positive examples as returned by ``load_book2``.
        cognates_path: Optional path passed to ``build_cognate_negatives``.
        txt2_path: Optional path passed to ``build_txt2_negatives``.
        seed: Seeds the local ``random.Random`` and the pandas sampling calls.
        neg_frac_cognate: Share of the shortfall to fill with cognate rows.
        neg_frac_random: Share of the shortfall to fill with random negatives.
        neg_frac_hard: Not read by this function; the hard-negative share is
            whatever remains after the cognate and random shares.

    Returns:
        DataFrame of positives followed by negatives, with rows containing
        missing values dropped, extra columns ``src_form_norm`` and
        ``tgt_form_norm``, and rows with an empty normalized form removed.
    """
    # One generator is shared by all builders below, so their call order
    # determines which random draws each of them receives.
    rng = random.Random(seed)

    pos_df = book2[
        ["src_form", "tgt_form", "tgt_lang", "meaning", "sem_field", "sem_cat", "label_loan", "neg_type", "source"]
    ].copy()

    cog_df = build_cognate_negatives(book2, cognates_path, rng)
    txt2_df = build_txt2_negatives(book2, txt2_path, rng)

    available = pd.concat([cog_df, txt2_df], ignore_index=True)
    available = available.drop_duplicates(subset=["src_form", "tgt_form", "tgt_lang", "sem_field", "sem_cat"])

    # Target a 1:1 ratio of negatives to positives.
    N_pos = len(pos_df)
    n_need = N_pos

    if len(available) >= n_need:
        neg_df = available.sample(n=n_need, random_state=seed).reset_index(drop=True)
    else:
        neg_df = available.copy().reset_index(drop=True)
        remaining = n_need - len(neg_df)

        n_cog_target = int(remaining * neg_frac_cognate)
        n_rand_target = int(remaining * neg_frac_random)
        # Hard negatives take the remainder (this also absorbs int() rounding).
        n_hard_target = remaining - n_cog_target - n_rand_target

        # if there are too few cognates, move the shortfall to random/hard
        if len(cog_df) < n_cog_target:
            deficit = n_cog_target - len(cog_df)
            n_cog_target = len(cog_df)
            n_rand_target += deficit // 2
            n_hard_target += deficit - deficit // 2

        if n_cog_target > 0 and len(cog_df) > 0:
            # These rows are drawn from cog_df again, so they repeat rows that
            # are already part of neg_df via `available`.
            # After the adjustment above n_cog_target <= len(cog_df), so the
            # `replace` flag evaluates to False here.
            extra_cog = cog_df.sample(
                n=n_cog_target,
                replace=(len(cog_df) < n_cog_target),
                random_state=seed,
            )
            neg_df = pd.concat([neg_df, extra_cog], ignore_index=True)

        if n_rand_target > 0:
            neg_df = pd.concat([neg_df, build_random_negatives(book2, n_rand_target, rng)], ignore_index=True)

        if n_hard_target > 0:
            neg_df = pd.concat([neg_df, build_hard_negatives(book2, n_hard_target, rng)], ignore_index=True)

        # Shuffle so the negative types are interleaved.
        neg_df = neg_df.sample(frac=1.0, random_state=seed).reset_index(drop=True)

    data = pd.concat([pos_df, neg_df], ignore_index=True).dropna()
    data["src_form_norm"] = data["src_form"].map(normalize_form)
    data["tgt_form_norm"] = data["tgt_form"].map(normalize_form)
    data = data[(data["src_form_norm"] != "") & (data["tgt_form_norm"] != "")].reset_index(drop=True)

    return data


def group_split_by_tgt_form(
    data: pd.DataFrame, seed: int, test_size: float = 0.15, val_size: float = 0.15
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split ``data`` into train/val/test, grouping rows by ``tgt_form_norm``.

    Two ``GroupShuffleSplit`` passes are applied with the same seed: the first
    holds out ``test_size`` as the test set, the second holds out ``val_size``
    of what remains as the validation set (so ``val_size`` is relative to the
    non-test part, not to the full data).

    Returns:
        ``(train_df, val_df, test_df)``, each with a fresh 0..n-1 index.
    """

    def _hold_out(frame: pd.DataFrame, fraction: float) -> Tuple[pd.DataFrame, pd.DataFrame]:
        # Single grouped shuffle split of `frame` into (kept, held-out).
        splitter = GroupShuffleSplit(n_splits=1, test_size=fraction, random_state=seed)
        kept_idx, held_idx = next(splitter.split(frame, groups=frame["tgt_form_norm"].values))
        kept = frame.iloc[kept_idx].reset_index(drop=True)
        held = frame.iloc[held_idx].reset_index(drop=True)
        return kept, held

    train_df, test_df = _hold_out(data, test_size)
    tr_df, va_df = _hold_out(train_df, val_size)
    return tr_df, va_df, test_df


In [8]:
# =========================
# Train / Eval
# =========================

def evaluate(loader: DataLoader, model: nn.Module, device: torch.device) -> Dict[str, float]:
    """Run ``model`` over ``loader`` without gradients and compute metrics.

    The model is switched to eval mode. Predictions use a 0.5 threshold on
    the sigmoid of the logits.

    Returns:
        Dict with ``loss`` (example-weighted mean BCE-with-logits), ``acc``,
        ``precision``, ``recall``, ``f1`` and ``auc``. ``auc`` is NaN when it
        cannot be computed or when the loader is empty.
    """
    model.eval()
    criterion = nn.BCEWithLogitsLoss()

    targets: List[np.ndarray] = []
    scores: List[np.ndarray] = []
    loss_sum = 0.0
    seen = 0

    with torch.no_grad():
        for token_ids, seg_ids, pad_mask, feats, labels in loader:
            token_ids, seg_ids, pad_mask, feats, labels = (
                tensor.to(device) for tensor in (token_ids, seg_ids, pad_mask, feats, labels)
            )
            logits = model(token_ids, seg_ids, pad_mask, feats)

            # The criterion returns a batch mean; weight it by the batch size.
            batch_n = labels.size(0)
            loss_sum += float(criterion(logits, labels).item()) * batch_n
            seen += batch_n

            scores.append(torch.sigmoid(logits).detach().cpu().numpy())
            targets.append(labels.detach().cpu().numpy())

    y = np.concatenate(targets) if targets else np.array([], dtype=np.float32)
    p = np.concatenate(scores) if scores else np.array([], dtype=np.float32)
    pred = (p >= 0.5).astype(int) if len(p) else np.array([], dtype=int)

    if len(y):
        acc = float(accuracy_score(y, pred))
        pr, rc, f1, _ = precision_recall_fscore_support(y, pred, average="binary", zero_division=0)
    else:
        acc = 0.0
        pr, rc, f1, _ = (0, 0, 0, None)

    try:
        auc = float(roc_auc_score(y, p)) if len(y) else float("nan")
    except Exception:
        # Best effort: AUC stays undefined instead of aborting the evaluation.
        auc = float("nan")

    return {
        "loss": loss_sum / max(1, seen),
        "acc": acc,
        "precision": float(pr),
        "recall": float(rc),
        "f1": float(f1),
        "auc": auc,
    }


def train_loop(
    *,
    model: nn.Module,
    tr_loader: DataLoader,
    va_loader: DataLoader,
    device: torch.device,
    outdir: Path,
    epochs: int,
    lr: float,
    weight_decay: float,
    grad_clip_norm: float,
    patience: int,
    use_amp: bool,
) -> Dict[str, object]:
    """Train ``model`` with AdamW and early stopping on validation F1.

    After every epoch the validation metrics are appended to
    ``outdir / "train_metrics.jsonl"``. Whenever validation F1 improves by
    more than 1e-5 the weights are saved to ``outdir / "best_model.pt"``.
    Training stops after ``patience`` consecutive epochs without improvement.
    At the end the best weights saved during this call are loaded back into
    ``model``.

    Args:
        model: Network called as ``model(token_ids, seg_ids, pad_mask, feats)``.
        tr_loader: Training batches ``(token_ids, seg_ids, pad_mask, feats, labels)``.
        va_loader: Validation batches in the same format.
        device: Device the batches are moved to.
        outdir: Output directory; created if it does not exist.
        epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        grad_clip_norm: Max gradient norm; ``None`` or ``<= 0`` disables clipping.
        patience: Early-stopping patience in epochs.
        use_amp: Enable mixed precision (only takes effect on CUDA devices).

    Returns:
        ``{"best_f1": float, "best_model_path": str}``. ``best_f1`` is ``-1.0``
        if no epoch was run.
    """
    # The metrics file and the checkpoint are both written below, so make
    # sure the directory exists instead of failing on open().
    outdir.mkdir(parents=True, exist_ok=True)

    amp_enabled = use_amp and device.type == "cuda"
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = GradScaler(enabled=amp_enabled)
    loss_fn = nn.BCEWithLogitsLoss()

    best_f1 = -1.0
    bad_epochs = 0
    checkpoint_written = False
    best_path = outdir / "best_model.pt"
    metrics_path = outdir / "train_metrics.jsonl"

    with metrics_path.open("w", encoding="utf-8") as mf:
        for epoch in range(1, epochs + 1):
            model.train()
            total_loss = 0.0
            n = 0

            for token_ids, seg_ids, pad_mask, feats, labels in tr_loader:
                token_ids = token_ids.to(device)
                seg_ids = seg_ids.to(device)
                pad_mask = pad_mask.to(device)
                feats = feats.to(device)
                labels = labels.to(device)

                optimizer.zero_grad(set_to_none=True)

                with autocast(enabled=amp_enabled):
                    logits = model(token_ids, seg_ids, pad_mask, feats)
                    loss = loss_fn(logits, labels)

                scaler.scale(loss).backward()
                if grad_clip_norm is not None and grad_clip_norm > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)

                scaler.step(optimizer)
                scaler.update()

                total_loss += float(loss.item()) * labels.size(0)
                n += labels.size(0)

            train_loss = total_loss / max(1, n)
            val_metrics = evaluate(va_loader, model, device)

            log_row = {
                "epoch": epoch,
                "train_loss": float(train_loss),
                **{f"val_{k}": float(v) if isinstance(v, (int, float, np.floating)) else v for k, v in val_metrics.items()},
            }
            mf.write(json.dumps(log_row, ensure_ascii=False) + "\n")
            mf.flush()

            print(
                f"Epoch {epoch:02d} | "
                f"train_loss={train_loss:.4f} | "
                f"val_loss={val_metrics['loss']:.4f} "
                f"acc={val_metrics['acc']:.4f} "
                f"f1={val_metrics['f1']:.4f} "
                f"auc={val_metrics['auc']:.4f}"
            )

            if val_metrics["f1"] > best_f1 + 1e-5:
                best_f1 = val_metrics["f1"]
                bad_epochs = 0
                torch.save({"model_state": model.state_dict()}, best_path)
                checkpoint_written = True
            else:
                bad_epochs += 1
                if bad_epochs >= patience:
                    print(f"Early stopping: no F1 improvement for {patience} epochs.")
                    break

    # Restore only a checkpoint produced by this call: with epochs <= 0 the
    # file may be missing or left over from an earlier run.
    if checkpoint_written:
        ckpt = torch.load(best_path, map_location=device)
        model.load_state_dict(ckpt["model_state"])
    return {"best_f1": float(best_f1), "best_model_path": str(best_path)}


In [9]:
# =========================
# Save / Load artifacts
# =========================

def save_artifacts(
    *,
    outdir: Path,
    config: Dict[str, object],
    lang2id: Dict[str, int],
    field2id: Dict[str, int],
    cat2id: Dict[str, int],
) -> None:
    """Write the JSON side files needed to reload the model later.

    Creates ``outdir`` if necessary and writes four UTF-8 JSON files:
    ``config.json`` (the given config), ``meta_maps.json`` (the three
    metadata vocabularies), ``tokenizer_config.json`` (a description of the
    text normalization) and ``env.json`` (Python/platform/torch versions and
    CUDA availability).
    """
    outdir.mkdir(parents=True, exist_ok=True)

    # File name -> JSON payload, written in this order.
    payloads = {
        "config.json": config,
        "meta_maps.json": {"lang2id": lang2id, "field2id": field2id, "cat2id": cat2id},
        "tokenizer_config.json": {
            "normalization": {
                "unicode": "NFKC",
                "apostrophes": ["’", "ʻ", "`", "´"],
                "apostrophe_repl": "'",
                "collapse_spaces": True,
            }
        },
        "env.json": {
            "python": sys.version,
            "platform": platform.platform(),
            "torch": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
        },
    }
    for filename, payload in payloads.items():
        text = json.dumps(payload, ensure_ascii=False, indent=2)
        (outdir / filename).write_text(text, encoding="utf-8")


def load_model_for_inference(artifact_dir: str, device: Optional[str] = None) -> Tuple[ByteCrossEncoder, Dict[str, object]]:
    """Rebuild a trained ``ByteCrossEncoder`` from a saved artifact directory.

    Reads ``config.json`` and ``meta_maps.json`` (as written by
    ``save_artifacts``) and ``best_model.pt`` (a dict with a ``"model_state"``
    entry holding the state dict) from ``artifact_dir``.

    Args:
        artifact_dir: Directory containing the three files listed above.
        device: Device string passed to ``torch.device``. When ``None``,
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Returns:
        ``(model, runtime)`` where ``model`` is in eval mode on the selected
        device and ``runtime`` is a dict with keys ``"config"`` (parsed
        config.json), ``"maps"`` (parsed meta_maps.json) and ``"device"``
        (the device as a string).
    """
    ad = Path(artifact_dir)
    cfg = json.loads((ad / "config.json").read_text(encoding="utf-8"))
    maps = json.loads((ad / "meta_maps.json").read_text(encoding="utf-8"))

    model = ByteCrossEncoder(
        vocab_size=int(cfg["vocab_size"]),
        max_len=int(cfg["max_len"]),
        d_model=int(cfg["d_model"]),
        n_layers=int(cfg["n_layers"]),
        n_heads=int(cfg["n_heads"]),
        ff_mult=int(cfg["ff_mult"]),
        dropout=float(cfg["dropout"]),
        feat_dim=int(cfg["feat_dim"]),
    )

    # The checkpoint only contains a state dict, so restrict unpickling to
    # tensors/primitive containers: this avoids executing arbitrary pickled
    # code from an untrusted file and silences torch's FutureWarning.
    ckpt = torch.load(ad / "best_model.pt", map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    if device is None:
        device_t = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device_t = torch.device(device)
    model.to(device_t)

    runtime = {
        "config": cfg,
        "maps": maps,
        "device": str(device_t),
    }
    return model, runtime

In [10]:
# =====================
# CONFIG (edit here)
# =====================
# Directory that holds the input files. Override it without editing the
# notebook by setting the LOANWORD_DATA_DIR environment variable; the default
# keeps the original author's location.
DATA_DIR = Path(os.environ.get("LOANWORD_DATA_DIR", "/Users/oksanagoncarova/Desktop/ipynb"))

BOOK2_PATH = str(DATA_DIR / "Книга2.xlsx")
COGNATES_TXT_PATH = str(DATA_DIR / "test_pos.txt")
# The run cell skips the existence check for this path when it is None.
# NOTE(review): confirm assemble_dataset accepts txt2_path=None before relying on that.
NEGATIVES_TXT2_PATH = str(DATA_DIR / "test_neg.txt")

# Output directory for the checkpoint and JSON artifacts (relative to the CWD).
OUTDIR = "loanword_artifacts"

# --- Training hyper-parameters ---
SEED = 42
BATCH_SIZE = 128
EPOCHS = 12
LR = 2e-4
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
PATIENCE = 4      # early-stopping patience, in epochs without F1 improvement
USE_AMP = True    # only takes effect when running on CUDA

# --- Input truncation (UTF-8 bytes per form) ---
MAX_SRC_BYTES = 64
MAX_TGT_BYTES = 64

# --- Model size ---
D_MODEL = 256
N_LAYERS = 8
N_HEADS = 8
FF_MULT = 4
DROPOUT = 0.10

# --- Negative-sampling fractions passed to assemble_dataset ---
NEG_FRAC_COGNATE = 0.50
NEG_FRAC_RANDOM  = 0.25
NEG_FRAC_HARD    = 0.25

# --- Quick smoke-test mode: subsample the dataset to at most FAST_ROWS rows ---
FAST_DEV_RUN = False
FAST_ROWS = 5000

# =====================
# RUN
# =====================
set_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", device)

# Fail fast on missing input files. An explicit exception is used instead of
# `assert` so the check still runs when Python is started with -O.
required_paths = [BOOK2_PATH, COGNATES_TXT_PATH]
if NEGATIVES_TXT2_PATH is not None:
    required_paths.append(NEGATIVES_TXT2_PATH)
for required_path in required_paths:
    if not os.path.exists(required_path):
        raise FileNotFoundError(f"Не найден файл: {required_path}")

outdir = Path(OUTDIR)
outdir.mkdir(parents=True, exist_ok=True)

# 1) Positives
book2 = load_book2(BOOK2_PATH)
print("Book2 positives:", len(book2))

# 2) Full dataset (positives + negatives)
data = assemble_dataset(
    book2=book2,
    cognates_path=COGNATES_TXT_PATH,
    txt2_path=NEGATIVES_TXT2_PATH,
    seed=SEED,
    neg_frac_cognate=NEG_FRAC_COGNATE,
    neg_frac_random=NEG_FRAC_RANDOM,
    neg_frac_hard=NEG_FRAC_HARD,
)

if FAST_DEV_RUN:
    # Seeded subsample so quick runs are reproducible.
    data = data.sample(n=min(len(data), FAST_ROWS), random_state=SEED).reset_index(drop=True)
    print("[FAST_DEV_RUN] Using rows:", len(data))

print("Total rows:", len(data))
print("Label distribution:", data["label_loan"].value_counts().to_dict())

# 3) Leak-free split (grouped by tgt_form)
tr_df, va_df, te_df = group_split_by_tgt_form(data, seed=SEED, test_size=0.1, val_size=0.1)
print("Train/Val/Test:", len(tr_df), len(va_df), len(te_df))

# 4) Metadata vocabularies (built over the full dataset, all splits)
lang2id = build_id_map(data["tgt_lang"])
field2id = build_id_map(data["sem_field"])
cat2id = build_id_map(data["sem_cat"])

# Metadata token ids are laid out in consecutive ranges starting at
# META_OFFSET: languages, then semantic fields, then semantic categories.
lang_base = META_OFFSET
field_base = lang_base + len(lang2id)
cat_base = field_base + len(field2id)
vocab_size = cat_base + len(cat2id)

# NOTE(review): reads as 1 leading token + 3 metadata tokens + 1 separator,
# plus two (bytes + 1) spans per form - presumably each form is encoded twice
# (e.g. raw and romanized); confirm against LoanDataset.
max_len = 1 + 3 + 1 + 2 * (MAX_SRC_BYTES + 1) + 2 * (MAX_TGT_BYTES + 1)


print("VOCAB_SIZE:", vocab_size, "MAX_LEN:", max_len)

# 5) DataLoaders
collate_fn = make_collate_fn(max_len)

# Every split must be encoded with exactly the same vocabularies, limits and
# id offsets, so the shared arguments are declared once instead of being
# copy-pasted per split.
dataset_kwargs = dict(
    lang2id=lang2id,
    field2id=field2id,
    cat2id=cat2id,
    max_src_bytes=MAX_SRC_BYTES,
    max_tgt_bytes=MAX_TGT_BYTES,
    max_len=max_len,
    lang_base=lang_base,
    field_base=field_base,
    cat_base=cat_base,
)
tr_ds = LoanDataset(tr_df, **dataset_kwargs)
va_ds = LoanDataset(va_df, **dataset_kwargs)
te_ds = LoanDataset(te_df, **dataset_kwargs)

# Only the training loader shuffles; validation/test keep a fixed order.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0, collate_fn=collate_fn)
tr_loader = DataLoader(tr_ds, shuffle=True, **loader_kwargs)
va_loader = DataLoader(va_ds, shuffle=False, **loader_kwargs)
te_loader = DataLoader(te_ds, shuffle=False, **loader_kwargs)

# 6) Model
model = ByteCrossEncoder(
    vocab_size=vocab_size,
    max_len=max_len,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    ff_mult=FF_MULT,
    dropout=DROPOUT,
    feat_dim=FEAT_DIM,
)
model = model.to(device)

# Total number of scalar parameters in the model.
n_params = sum(map(torch.numel, model.parameters()))
print("Model params:", n_params)

# 7) Train
# Mixed precision is only meaningful on CUDA, so the flag is forced off elsewhere.
use_amp_effective = bool(USE_AMP) and device.type == "cuda"

train_info = train_loop(
    model=model,
    device=device,
    tr_loader=tr_loader,
    va_loader=va_loader,
    outdir=outdir,
    epochs=EPOCHS,
    patience=PATIENCE,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    grad_clip_norm=GRAD_CLIP,
    use_amp=use_amp_effective,
)

# 8) Test
# Evaluate on the held-out test split using the model as returned by train_loop.
test_metrics = evaluate(te_loader, model, device)
print("TEST:", test_metrics)

# 9) Save artifacts
# Everything needed to rebuild the model and re-encode inputs at inference
# time, followed by the training settings of this run. Key order is kept
# stable because it determines the layout of config.json.
config = {
    # model architecture
    "vocab_size": vocab_size,
    "max_len": max_len,
    "d_model": D_MODEL,
    "n_layers": N_LAYERS,
    "n_heads": N_HEADS,
    "ff_mult": FF_MULT,
    "dropout": DROPOUT,
    "feat_dim": FEAT_DIM,
    # token-id layout
    "pad_id": PAD_ID,
    "cls_id": CLS_ID,
    "sep_id": SEP_ID,
    "byte_offset": BYTE_OFFSET,
    "meta_offset": META_OFFSET,
    "lang_base": lang_base,
    "field_base": field_base,
    "cat_base": cat_base,
    # input truncation
    "max_src_bytes": MAX_SRC_BYTES,
    "max_tgt_bytes": MAX_TGT_BYTES,
    # training settings
    "seed": SEED,
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
    "lr": LR,
    "weight_decay": WEIGHT_DECAY,
    "grad_clip": GRAD_CLIP,
    "patience": PATIENCE,
    "use_amp": use_amp_effective,
    "neg_fracs": {"cognate": NEG_FRAC_COGNATE, "random": NEG_FRAC_RANDOM, "hard": NEG_FRAC_HARD},
}

save_artifacts(outdir=outdir, config=config, lang2id=lang2id, field2id=field2id, cat2id=cat2id)

# Run summary: training result, test metrics, split sizes and model size.
summary = {
    "train_info": train_info,
    "test_metrics": test_metrics,
    "rows": {
        split_name: int(len(frame))
        for split_name, frame in (("total", data), ("train", tr_df), ("val", va_df), ("test", te_df))
    },
    "device": str(device),
    "model_params": int(n_params),
}
(outdir / "summary.json").write_text(
    json.dumps(summary, ensure_ascii=False, indent=2),
    encoding="utf-8",
)

print("Saved artifacts to:", str(outdir))
print("Files:", [entry.name for entry in sorted(outdir.glob("*"))])


DEVICE: cpu
Book2 positives: 20886
Total rows: 41772
Label distribution: {1: 20886, 0: 20886}
Train/Val/Test: 33897 3719 4156
VOCAB_SIZE: 653 MAX_LEN: 265
Model params: 6761473


  scaler = GradScaler(enabled=(use_amp and device.type == "cuda"))
  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 01 | train_loss=0.3027 | val_loss=0.2279 acc=0.9064 f1=0.9089 auc=0.9769


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 02 | train_loss=0.1725 | val_loss=0.1532 acc=0.9465 f1=0.9453 auc=0.9851


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 03 | train_loss=0.1364 | val_loss=0.1377 acc=0.9473 f1=0.9448 auc=0.9883


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 04 | train_loss=0.1077 | val_loss=0.1147 acc=0.9602 f1=0.9590 auc=0.9916


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 05 | train_loss=0.0819 | val_loss=0.1073 acc=0.9640 f1=0.9627 auc=0.9921


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 06 | train_loss=0.0713 | val_loss=0.1217 acc=0.9572 f1=0.9546 auc=0.9934


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 07 | train_loss=0.0570 | val_loss=0.1099 acc=0.9659 f1=0.9652 auc=0.9940


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 08 | train_loss=0.0468 | val_loss=0.1289 acc=0.9648 f1=0.9629 auc=0.9938


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 09 | train_loss=0.0361 | val_loss=0.1207 acc=0.9642 f1=0.9627 auc=0.9924


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 10 | train_loss=0.0331 | val_loss=0.1422 acc=0.9607 f1=0.9585 auc=0.9927


  with autocast(enabled=(use_amp and device.type == "cuda")):


Epoch 11 | train_loss=0.0270 | val_loss=0.1277 acc=0.9653 f1=0.9644 auc=0.9935
Early stopping: no F1 improvement for 4 epochs.


  ckpt = torch.load(best_path, map_location=device)


TEST: {'loss': 0.0917612046820826, 'acc': 0.9646294513955727, 'precision': 0.9415520628683693, 'recall': 0.9856041131105399, 'f1': 0.9630746043707611, 'auc': 0.9960412938811435}
Saved artifacts to: loanword_artifacts
Files: ['best_model.pt', 'config.json', 'env.json', 'meta_maps.json', 'summary.json', 'tokenizer_config.json', 'train_metrics.jsonl']


In [11]:
import os, json
import numpy as np
import torch

# ---------- Загрузка модели из артефактов (если нужно) ----------
def load_loanword_model(artifact_dir: str, device: str | None = None):
    """Load a trained ``ByteCrossEncoder`` and its runtime metadata from disk.

    Expects the following files inside ``artifact_dir``:
      - config.json
      - meta_maps.json
      - best_model.pt  (a dict with a ``"model_state"`` state dict)

    Args:
        artifact_dir: Directory containing the artifacts listed above.
        device: Device string passed to ``torch.device``. When ``None``,
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Returns:
        ``(model, runtime)``: the model in eval mode on the selected device,
        and a dict with keys ``"device"`` (a ``torch.device``), ``"config"``
        (parsed config.json) and ``"maps"`` (parsed meta_maps.json).
    """
    ad = artifact_dir

    # Context managers guarantee the file handles are closed promptly
    # (the bare open(...).read() form leaves that to the garbage collector).
    with open(os.path.join(ad, "config.json"), "r", encoding="utf-8") as fh:
        cfg = json.load(fh)
    with open(os.path.join(ad, "meta_maps.json"), "r", encoding="utf-8") as fh:
        maps = json.load(fh)

    if device is None:
        device_t = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device_t = torch.device(device)

    model = ByteCrossEncoder(
        vocab_size=int(cfg["vocab_size"]),
        max_len=int(cfg["max_len"]),
        d_model=int(cfg["d_model"]),
        n_layers=int(cfg["n_layers"]),
        n_heads=int(cfg["n_heads"]),
        ff_mult=int(cfg["ff_mult"]),
        dropout=float(cfg["dropout"]),
        feat_dim=int(cfg["feat_dim"]),
    )

    # The checkpoint holds only a state dict, so restrict unpickling to
    # tensors/primitive containers: safer for files of unknown origin and it
    # removes the FutureWarning torch emits for the permissive default.
    ckpt = torch.load(os.path.join(ad, "best_model.pt"), map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state"])
    model.to(device_t)
    model.eval()

    runtime = {
        "device": device_t,
        "config": cfg,
        "maps": maps,
    }
    return model, runtime


# ---------- Предсказание вероятностей ----------
@torch.no_grad()
def predict_loan_probabilities(
    pairs,
    *,
    artifact_dir: str | None = None,
    model: torch.nn.Module | None = None,
    runtime: dict | None = None,
    tgt_lang: str | None = None,
    sem_field: str | None = None,
    sem_cat: str | None = None,
    batch_size: int = 128,
):
    """Score ``(src_form, tgt_form)`` pairs with the loanword classifier.

    Usage:
      A) Model already in memory:
         probs = predict_loan_probabilities(pairs, model=model, runtime=runtime,
                                            tgt_lang="...", sem_field="...", sem_cat="...")
      B) Saved artifacts on disk:
         probs = predict_loan_probabilities(pairs, artifact_dir=".../artifacts",
                                            tgt_lang="...", sem_field="...", sem_cat="...")

    Args:
        pairs: Iterable of ``(src_form, tgt_form)`` tuples. Any iterable is
            accepted (it is materialized into a list once).
        artifact_dir: Directory to load the model from when ``model`` is None.
        model: An already-loaded model; requires ``runtime``.
        runtime: Dict with ``"config"`` and ``"maps"`` (and optionally
            ``"device"``; otherwise the device of the model parameters is used).
        tgt_lang, sem_field, sem_cat: Metadata applied to every pair. ``None``
            is treated as the empty string; values missing from the saved
            maps fall back to id 0.
        batch_size: Number of pairs per forward pass; must be >= 1.

    Returns:
        list[float]: sigmoid of the model output for each pair, in input order.

    Raises:
        ValueError: if neither ``model`` nor ``artifact_dir`` is given, if
            ``runtime`` is missing, or if ``batch_size`` is not positive.
    """
    if model is None:
        if artifact_dir is None:
            raise ValueError("Нужно передать либо model+runtime, либо artifact_dir с артефактами.")
        model, runtime = load_loanword_model(artifact_dir)

    if runtime is None:
        raise ValueError("runtime обязателен (должен содержать config и maps).")

    # A negative step would make range() silently yield nothing.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # len() and slicing are needed below, so generators are materialized here.
    pairs = list(pairs)

    device = runtime["device"] if "device" in runtime else next(model.parameters()).device
    cfg = runtime["config"]
    maps = runtime["maps"]

    pad_id = int(cfg["pad_id"])
    cls_id = int(cfg["cls_id"])
    sep_id = int(cfg["sep_id"])
    byte_offset = int(cfg["byte_offset"])
    lang_base = int(cfg["lang_base"])
    field_base = int(cfg["field_base"])
    cat_base = int(cfg["cat_base"])

    max_len = int(cfg["max_len"])
    max_src_bytes = int(cfg["max_src_bytes"])
    max_tgt_bytes = int(cfg["max_tgt_bytes"])

    # The metadata is identical for every pair, so the leading
    # [CLS, lang, field, cat, SEP] tokens are resolved once. Unknown or
    # missing metadata values map to id 0.
    lang_id = maps["lang2id"].get(normalize_form("" if tgt_lang is None else tgt_lang), 0)
    field_id = maps["field2id"].get(normalize_form("" if sem_field is None else sem_field), 0)
    cat_id = maps["cat2id"].get(normalize_form("" if sem_cat is None else sem_cat), 0)
    meta_prefix = [cls_id, lang_base + lang_id, field_base + field_id, cat_base + cat_id, sep_id]

    def _to_byte_ids(form: str, max_bytes: int):
        # Normalized form -> UTF-8 bytes, cut at max_bytes, shifted into the byte id range.
        raw = normalize_form(form).encode("utf-8", errors="ignore")[:max_bytes]
        return [byte_offset + int(b) for b in raw]

    def _encode(src_form: str, tgt_form: str):
        # Layout: prefix (segment 0) | src bytes + SEP (segment 1) | tgt bytes + SEP (segment 2).
        # NOTE(review): this emits at most 5 + (max_src_bytes + 1) + (max_tgt_bytes + 1)
        # tokens, while the notebook's max_len formula reserves two spans per form;
        # confirm this layout matches the training-time encoding in LoanDataset.
        src_ids = _to_byte_ids(src_form, max_src_bytes)
        tgt_ids = _to_byte_ids(tgt_form, max_tgt_bytes)
        tokens = meta_prefix + src_ids + [sep_id] + tgt_ids + [sep_id]
        segments = [0] * len(meta_prefix) + [1] * (len(src_ids) + 1) + [2] * (len(tgt_ids) + 1)
        return tokens[:max_len], segments[:max_len]

    probs_out = []
    model.eval()

    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]

        tok_list, seg_list, feat_list = [], [], []
        for src, tgt in batch:
            tok, seg = _encode(src, tgt)
            tok_list.append(tok)
            seg_list.append(seg)
            feat_list.append(compute_features(src, tgt))

        # Pad only up to the longest sequence in this batch (already <= max_len).
        width = max(len(tok) for tok in tok_list)

        token_ids = torch.full((len(batch), width), pad_id, dtype=torch.long)
        seg_ids = torch.zeros((len(batch), width), dtype=torch.long)
        feats = torch.tensor(np.stack(feat_list, axis=0), dtype=torch.float32)

        for row, (tok, seg) in enumerate(zip(tok_list, seg_list)):
            token_ids[row, : len(tok)] = torch.tensor(tok, dtype=torch.long)
            seg_ids[row, : len(seg)] = torch.tensor(seg, dtype=torch.long)

        pad_mask = token_ids.eq(pad_id)

        logits = model(
            token_ids.to(device),
            seg_ids.to(device),
            pad_mask.to(device),
            feats.to(device),
        )
        # Flatten so the result is always a flat list of floats, whether the
        # model returns shape (B,), (B, 1) or a 0-d tensor for a single pair.
        batch_probs = torch.sigmoid(logits).reshape(-1).detach().cpu().numpy().tolist()
        probs_out.extend(batch_probs)

    return probs_out


# ---------- Example ----------
pairs = [
    ("shíhuī", "шохой"),
    ("tuge", "tuɛ"),
    ("yàngzi", "янза"),
]

# Option A: the model was trained in this notebook and is still in memory
# (variables model, config, lang2id, field2id, cat2id):
# probs = predict_loan_probabilities(pairs, model=model, runtime={"device": next(model.parameters()).device, "config": config, "maps": {"lang2id": lang2id, "field2id": field2id, "cat2id": cat2id}})

# Option B: load from saved artifacts (replace OUTDIR with your own directory
# if needed). This is the variant executed below.
probs = predict_loan_probabilities(pairs, artifact_dir=OUTDIR)

# Print one line per pair with its predicted loan probability.
for (src, tgt), p in zip(pairs, probs):
    print(f"{src}  ->  {tgt} :  p_loan={p:.6f}")


shíhuī  ->  шохой :  p_loan=0.999378
tuge  ->  tuɛ :  p_loan=0.938874
yàngzi  ->  янза :  p_loan=0.996503


  ckpt = torch.load(os.path.join(ad, "best_model.pt"), map_location="cpu")
