# Tokenization (A_bytes + B_raw/B_norm) — Notebook

- **A_bytes**: контрольный вариант (byte-level UTF-8)
- **B_raw / B_norm**: основной вариант (структурные токены Python + словарь по частоте)



In [1]:
from __future__ import annotations

import ast
import hashlib
import io
import json
import keyword
import re
import tokenize
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional, Tuple


In [2]:
# -------------------------
# Helpers: IO
# -------------------------

def read_jsonl(path: Path) -> Iterator[dict]:
    """Lazily yield one parsed JSON value per non-blank line of ``path``.

    The file is read as UTF-8. Each line is stripped of surrounding
    whitespace; lines that end up empty are skipped, every other line is
    handed to ``json.loads``.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        for text in stripped_lines:
            if text:
                yield json.loads(text)

def write_jsonl(path: Path, rows: Iterable[dict]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines (one object per line).

    Missing parent directories are created first. Non-ASCII characters are
    written verbatim rather than as ``\\uXXXX`` escapes.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)

def write_json(path: Path, obj: dict) -> None:
    """Dump ``obj`` to ``path`` as pretty-printed (2-space indent) UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written verbatim rather than as ``\\uXXXX`` escapes.
    """
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    payload = json.dumps(obj, indent=2, ensure_ascii=False)
    path.write_text(payload, encoding="utf-8")


In [3]:
# -------------------------
# Schema: row adapters
# -------------------------

@dataclass(frozen=True)
class RowView:
    """Immutable, schema-independent view of one docstring/code record.

    Instances are built by ``adapt_row``, which maps the differing key names
    of the supported dataset layouts onto these fields.
    """

    repo: str  # repository identifier; "" when the record has none
    path: str  # path (or URL) of the function inside the repository; "" when missing
    func_name: str  # function name; "" when missing
    doc: str  # docstring text (adapt_row guarantees it is non-blank)
    code: str  # function source code (adapt_row guarantees it is non-blank)
    sha: str  # record hash: taken from the row, else SHA-256 of repo/path/code
    language: str  # lower-cased language tag; "" when the record has none
    keep: Optional[bool]  # row's "keep" flag, None if absent; False makes default_filter drop the row
    parsable: Optional[bool]  # row's "parsable" flag, None if absent; not enforced by default_filter


def _get_first(d: dict, keys: List[str], default=None):
    for k in keys:
        if k in d and d[k] is not None:
            return d[k]
    return default


def adapt_row(d: dict) -> Optional[RowView]:
    """
    Map one raw dataset record onto a ``RowView``.

    Supports at least:
    - CodeXGLUE code-to-text: repo/path/func_name/docstring/code_clean/sha256/language/keep/parsable
    - CodeSearchNet: repository_name/func_path_in_repository/func_name/docstring/code_clean/sha256/language/keep/parsable

    For every field the first listed key that holds a non-``None`` value wins.

    Returns ``None`` when the docstring or the code is not a string or is
    blank. When the record carries no ``sha256``/``sha`` value, the SHA-256 of
    ``repo``, ``path`` and ``code`` (newline-separated) is used instead.
    """
    doc = _get_first(d, ["docstring", "func_documentation_string"], default="")
    code = _get_first(d, ["code_clean", "func_code_string", "code"], default="")

    # Validate before anything treats these values as text: the fallback hash
    # below concatenates ``code``, so a non-string value has to be rejected
    # here instead of surfacing as a TypeError.
    if not isinstance(doc, str) or not isinstance(code, str):
        return None
    if not doc.strip() or not code.strip():
        return None

    # Metadata is coerced to ``str`` so that malformed rows (e.g. a numeric
    # repo id) neither break the fallback hash nor violate RowView's types.
    # For string values this is a no-op.
    language = str(d.get("language") or "").lower()
    keep = d.get("keep")
    parsable = d.get("parsable")

    repo = str(_get_first(d, ["repository_name", "repo", "repository"], default=""))
    path = str(_get_first(d, ["func_path_in_repository", "path", "func_code_url"], default=""))
    func_name = str(_get_first(d, ["func_name"], default=""))

    sha = _get_first(d, ["sha256", "sha"], default="")
    if not sha:
        # errors="ignore" drops characters UTF-8 cannot encode (lone surrogates).
        sha = hashlib.sha256(
            (repo + "\n" + path + "\n" + code).encode("utf-8", errors="ignore")
        ).hexdigest()

    return RowView(
        repo=repo, path=path, func_name=func_name,
        doc=doc, code=code, sha=sha,
        language=language, keep=keep, parsable=parsable
    )


def default_filter(rv: RowView, language: str = "python") -> bool:
    """Decide whether a row takes part in the pipeline.

    A row is rejected when both ``language`` and the row's own language are
    non-empty and differ (``language`` is compared lower-cased), or when the
    row's ``keep`` flag is exactly ``False``. Everything else is accepted.
    """
    # ``parsable`` is not required by default (it depends on the dataset).
    language_mismatch = (
        bool(language) and bool(rv.language) and rv.language != language.lower()
    )
    return not language_mismatch and rv.keep is not False


In [4]:
# -------------------------
# Text normalization (docstrings)
# -------------------------

_WS_RE = re.compile(r"\s+")


def normalize_docstring(doc: str) -> str:
    """Trim ``doc`` and collapse every run of whitespace into a single space.

    A falsy ``doc`` (``None`` or ``""``) yields ``""``.
    """
    trimmed = doc.strip() if doc else ""
    return _WS_RE.sub(" ", trimmed)


In [5]:
# -------------------------
# Variant B: structured Python tokens
# -------------------------

def python_struct_tokens(code: str, mode: str = "B_raw", max_ids: int = 64) -> List[str]:
    """
    Turn Python source into a flat, roughly reversible token sequence.

    Tabs are expanded to 4 spaces, then the text is run through the standard
    ``tokenize`` module. The output contains:
      - ``<nl>`` for every NL/NEWLINE token, ``<indent>`` / ``<dedent>`` for
        block structure
      - the raw token text for everything else (comments are dropped)
      - in ``B_norm`` mode only: identifiers -> ``<id_k>``, numbers ->
        ``<num>``, strings -> ``<str>``

    Identifier handling in ``B_norm``:
      - numbering starts from 1 for every call, i.e. per snippet
      - a non-keyword NAME that does not directly follow ``.`` gets the next
        free ``<id_k>``; the same name always maps to the same ``<id_k>``
      - once ``max_ids`` distinct names are numbered, further new names
        become the generic ``<id>``
      - a NAME right after ``.`` is emitted unchanged, so the API surface
        (``obj.method``) stays visible

    Errors raised by ``tokenize`` on malformed source propagate to the caller.
    """
    layout_markers = {
        tokenize.NL: "<nl>",
        tokenize.NEWLINE: "<nl>",
        tokenize.INDENT: "<indent>",
        tokenize.DEDENT: "<dedent>",
    }
    literal_markers = {tokenize.NUMBER: "<num>", tokenize.STRING: "<str>"}
    normalize = mode == "B_norm"

    source = (code or "").expandtabs(4)
    slots: Dict[str, int] = {}      # identifier -> its k in <id_k>
    emitted: List[str] = []
    last_seen: Optional[str] = None  # original text of the previous emitted token

    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        kind, text = tok.type, tok.string

        if kind == tokenize.ENDMARKER:
            break
        if kind == tokenize.COMMENT:
            # Comments neither reach the output nor count as "previous token".
            continue

        if kind in layout_markers:
            piece = last = layout_markers[kind]
        elif normalize and kind == tokenize.NAME and not keyword.iskeyword(text):
            # Remember the original name (not the placeholder) as previous token.
            last = text
            if last_seen == ".":
                # Attribute access: keep the name to preserve the API surface.
                piece = text
            else:
                if text not in slots and len(slots) < max_ids:
                    slots[text] = len(slots) + 1
                slot = slots.get(text)
                piece = "<id>" if slot is None else f"<id_{slot}>"
        elif normalize and kind in literal_markers:
            piece = last = literal_markers[kind]
        else:
            piece = last = text

        emitted.append(piece)
        last_seen = last

    return emitted


# ---- detokenize (used to check that a snippet still parses after tokenization) ----

# Operator / keyword-operator token set. detokenize_python does not consult
# it: the spacing rule that used to look tokens up here made the same decision
# on both of its branches. The name is kept so existing references keep working.
_OPERATORS = {
    "+","-","*","/","//","%","**",
    "==","!=",">=","<=","<",">",
    "=", "+=","-=","*=","/=","%=","//=","**=",
    "(",")","[","]","{","}",
    ",",":",".",";","@",
    "and","or","not","in","is","|","&","^","~",
    "<<",">>",
}


def _placeholder_source(tok: str) -> str:
    """Return valid Python source text for a B_norm placeholder; other tokens pass through."""
    if tok == "<num>":
        return "0"
    if tok == "<str>":
        return "''"
    if tok == "<id>":
        return "v"
    if tok.startswith("<id_"):
        # <id_k> -> v<k>. Distinct placeholders must stay distinct names,
        # otherwise "def f(a, b)" turns into "def v(v, v)" (duplicate argument)
        # and "f(a=1, b=2)" into a repeated keyword argument.
        suffix = tok[4:-1] if tok.endswith(">") else ""
        if suffix.isascii() and suffix.isdigit():
            return f"v{suffix}"
        return "v"
    return tok


def detokenize_python(tokens: List[str]) -> str:
    """
    Rebuild source text from the structured token stream (rough inverse of
    ``python_struct_tokens``).

    Goal: a string that is usually AST-parsable, not a pretty-printer.

    - ``<nl>`` ends the current line; ``<indent>`` / ``<dedent>`` change the
      indentation level (4 spaces per level, never below 0).
    - B_norm placeholders are replaced by valid source: ``<id_k>`` -> ``v<k>``,
      ``<id>`` -> ``v``, ``<num>`` -> ``0``, ``<str>`` -> ``''``.
    - Tokens on a line are separated by one space except: before a closing
      bracket, after an opening bracket / ``.`` / ``,`` / ``:``, before ``.``,
      and between an alphanumeric token and ``(``.

    Trailing blank lines are dropped; the result always ends with one newline
    (an empty token list yields ``"\\n"``).
    """
    lines: List[str] = []
    cur: List[str] = []
    indent = 0
    prev: Optional[str] = None

    def flush_line() -> None:
        # Uses the indentation level in effect when the line ends: INDENT and
        # DEDENT tokens precede the first token of the line they apply to.
        lines.append(" " * (indent * 4) + "".join(cur).rstrip())
        cur.clear()

    for t in tokens:
        if t == "<nl>":
            flush_line()
            prev = t
            continue
        if t == "<indent>":
            indent += 1
            prev = t
            continue
        if t == "<dedent>":
            indent = max(0, indent - 1)
            prev = t
            continue

        t = _placeholder_source(t)

        # The first token on a line never needs a separator.
        if not cur:
            cur.append(t)
            prev = t
            continue

        if prev in ("<nl>", "<indent>", "<dedent>"):
            need_space = False
        elif t in (")", "]", "}"):
            need_space = False
        elif prev in ("(", "[", "{", ".", ",", ":"):
            need_space = False
        elif t == ".":
            # "1 .real" must not be glued into the invalid literal "1.real";
            # after anything that does not start with a digit, no space.
            need_space = prev[:1].isdigit()
        else:
            # Call-like "name(" stays tight; everything else is space-separated.
            need_space = not (prev.isalnum() and t == "(")

        if need_space:
            cur.append(" ")
        cur.append(t)
        prev = t

    if cur or not lines:
        flush_line()

    # Remove trailing empty lines
    while lines and not lines[-1].strip():
        lines.pop()

    return "\n".join(lines) + "\n"


In [6]:
# -------------------------
# Splitting (stable by repo)
# -------------------------

def split_by_repo(repo: str, train: float = 0.90, val: float = 0.05) -> str:
    """
    Assign a repository to "train", "val" or "test", deterministically.

    The SHA-256 of the repository string (a falsy value counts as "") is
    reduced modulo 10 000 and scaled into [0, 1). Values below ``train`` go to
    "train", values below ``train + val`` to "val", the rest to "test".
    """
    digest = hashlib.sha256((repo or "").encode("utf-8", errors="ignore")).hexdigest()
    bucket = (int(digest, 16) % 10_000) / 10_000.0

    for label, upper_bound in (("train", train), ("val", train + val)):
        if bucket < upper_bound:
            return label
    return "test"


In [7]:
# -------------------------
# Example builder: doc2code
# -------------------------

def build_example_tokens(rv: RowView, mode: str, max_id_vars: int = 64) -> Tuple[List[str], List[str]]:
    """
    Build the (input_tokens, target_tokens) pair of one doc2code example:
        input  = <bos> <task:doc2code> <doc> DOC </doc> <code>
        target = CODE </code> <eos>

    DOC is the whitespace-split normalized docstring; CODE is the structured
    token stream of the row's code in the given ``mode``.
    """
    doc_words = normalize_docstring(rv.doc).split()
    code_stream = python_struct_tokens(rv.code, mode=mode, max_ids=max_id_vars)

    input_toks = ["<bos>", "<task:doc2code>", "<doc>", *doc_words, "</doc>", "<code>"]
    target_toks = [*code_stream, "</code>", "<eos>"]
    return input_toks, target_toks


In [8]:
# -------------------------
# Vocab & encoding (Variant B)
# -------------------------

SPECIAL_B = [
    "<pad>", "<bos>", "<eos>", "<unk>", "<sep>",
    "<nl>", "<indent>", "<dedent>",
    "<mask>", "<num>", "<str>", "<id>",
]
# We also allow <id_k> tokens dynamically (B_norm) - they go through <unk> unless present in vocab.


def build_vocab(train_rows: List[RowView], mode: str, vocab_size: int, min_freq: int, max_id_vars: int = 64) -> Tuple[Dict[str,int], Dict[int,str], Counter]:
    """Build the Variant-B vocabulary from the given rows.

    Special and task tokens always occupy the first ids, in declaration
    order. The remaining slots, up to ``vocab_size`` entries in total, are
    filled with tokens seen at least ``min_freq`` times, most frequent first
    and ties broken alphabetically so the result is stable.

    Returns ``(token2id, id2token, counts)`` where ``counts`` holds the
    frequencies of all non-special tokens.
    """
    cnt = Counter()
    for rv in train_rows:
        for part in build_example_tokens(rv, mode, max_id_vars=max_id_vars):
            cnt.update(part)

    # Specials plus structural/task tags get the lowest ids.
    base_special = SPECIAL_B + ["<task:doc2code>", "<doc>", "</doc>", "<code>", "</code>"]
    token2id: Dict[str,int] = {
        tok: idx for idx, tok in enumerate(dict.fromkeys(base_special))
    }

    # They must not compete for the frequency-ranked slots as well.
    for tok in base_special:
        cnt.pop(tok, None)

    ranked = sorted(
        (pair for pair in cnt.items() if pair[1] >= min_freq),
        key=lambda pair: (-pair[1], pair[0]),
    )
    free_slots = max(0, vocab_size - len(token2id))
    for tok, _ in ranked[:free_slots]:
        token2id[tok] = len(token2id)

    id2token = dict(zip(token2id.values(), token2id.keys()))
    return token2id, id2token, cnt


def encode_tokens(tokens: List[str], token2id: Dict[str,int]) -> Tuple[List[int], int]:
    """Map tokens to ids and count how many fell back to ``<unk>``.

    Returns ``(ids, oov)``: tokens missing from ``token2id`` are encoded with
    the id of ``"<unk>"`` and counted in ``oov``.
    """
    unk_id = token2id.get("<unk>")
    looked_up = [token2id.get(tok) for tok in tokens]
    oov = looked_up.count(None)
    ids = [unk_id if found is None else found for found in looked_up]
    return ids, oov


In [9]:
# -------------------------
# Variant A: byte-level baseline
# -------------------------

SPECIAL_A = ["<pad>", "<bos>", "<eos>", "<unk>", "<sep>"]  # keep minimal, fixed
BYTE_OFFSET = len(SPECIAL_A)  # bytes 0..255 map to ids [BYTE_OFFSET..BYTE_OFFSET+255]

def build_line_text(rv: RowView) -> str:
    """Render one row as the single "doc + code" text line of the byte-level baseline."""
    # One-line "doc + code" for baseline. Keep it simple and deterministic.
    doc = normalize_docstring(rv.doc)
    code = (rv.code or "").strip()
    return f"<doc> {doc} </doc> <code> {code} </code>"

def encode_bytes(text: str, max_len: int) -> Tuple[List[int], int, bool]:
    """
    Encode ``text`` as UTF-8 byte ids framed by <bos>/<eos> (baseline A).

    Returns ``(ids, oov_count, truncated)``. ``oov_count`` is always 0 because
    every byte value has an id. At most ``max_len - 2`` payload bytes are
    kept, never fewer than 0; ``truncated`` tells whether bytes were dropped.
    The cut is made on a byte boundary and may split a multi-byte character.
    """
    bos_id = SPECIAL_A.index("<bos>")
    eos_id = SPECIAL_A.index("<eos>")

    payload = (text or "").encode("utf-8", errors="replace")

    # Two positions are reserved for <bos>/<eos>. Clamp at 0: with
    # max_len < 2 a negative slice bound would keep almost every byte
    # instead of none.
    budget = max(0, max_len - 2)
    truncated = len(payload) > budget
    if truncated:
        payload = payload[:budget]

    ids = [bos_id, *(BYTE_OFFSET + byte for byte in payload), eos_id]
    return ids, 0, truncated


In [10]:
# -------------------------
# Metrics helpers
# -------------------------

def safe_ast_parse(code: str) -> bool:
    """Return True when ``ast.parse`` accepts ``code``, False when it raises any ``Exception``."""
    try:
        ast.parse(code)
    except Exception:
        # Best-effort probe: any failure simply means "not parsable".
        return False
    return True

def quantiles_int(xs: List[int]) -> Dict[str,int]:
    """Summarize ``xs`` by min, max and the p50/p90/p95/p99 quantiles.

    Each quantile is the element at the rounded index ``(len(xs) - 1) * p`` of
    the sorted values (nearest rank, no interpolation). An empty input yields
    zeros for every key.
    """
    if not xs:
        return {"p50": 0, "p90": 0, "p95": 0, "p99": 0, "min": 0, "max": 0}

    ordered = sorted(xs)
    last = len(ordered) - 1

    def pick(fraction: float) -> int:
        position = int(round(last * fraction))
        return int(ordered[min(last, max(0, position))])

    summary: Dict[str,int] = {"min": int(ordered[0])}
    for label, fraction in (("p50", 0.50), ("p90", 0.90), ("p95", 0.95), ("p99", 0.99)):
        summary[label] = pick(fraction)
    summary["max"] = int(ordered[-1])
    return summary


In [11]:
# -------------------------
# Pipeline runner (notebook-friendly)
# -------------------------

def _fit_to_max_len(
    inp_ids: List[int], tgt_ids: List[int], max_len: int, eos_id: int
) -> Tuple[List[int], List[int], bool]:
    """Hard-cut an (input, target) id pair to at most ``max_len`` ids in total.

    Returns ``(inp_ids, tgt_ids, truncated)``. The target is shortened first
    and always ends with ``eos_id`` afterwards. When the input alone fills the
    budget, the target shrinks to a lone ``eos_id`` and the input is trimmed
    to leave room for it (the lone ``eos_id`` is kept even if ``max_len < 1``).
    """
    overflow = len(inp_ids) + len(tgt_ids) - max_len
    if overflow <= 0:
        return inp_ids, tgt_ids, False

    if overflow < len(tgt_ids):
        kept = tgt_ids[:-overflow]
        kept[-1] = eos_id  # the cut removed the original <eos>
        return inp_ids, kept, True

    # Dropping the whole target is not enough: the input must shrink as well,
    # otherwise the example would still exceed max_len.
    return inp_ids[: max(0, max_len - 1)], [eos_id], True


def run_tokenization(
    input_path: str,
    out_dir: str = "out_tokenization",
    language: str = "python",
    max_len: int = 2048,
    train_ratio: float = 0.90,
    val_ratio: float = 0.05,
    vocab_size: int = 32000,
    min_freq: int = 2,
    b_mode: str = "B_raw",        # "B_raw" or "B_norm"
    also_b_norm: bool = True,     # export B_norm alongside chosen b_mode
    export_a: bool = True,
    max_id_vars: int = 64,
) -> dict:
    """
    Run the whole tokenization pipeline on one JSONL file.

    Steps:
      1. read rows, adapt them to ``RowView``, filter by ``language`` / keep
         flag, drop exact duplicates by sha (first occurrence wins);
      2. split by repository into train/val/test (``train_ratio`` /
         ``val_ratio``, remainder is test);
      3. build the Variant-B vocabulary on the train split in ``b_mode`` and
         write it to ``<out_dir>/tokenizer_B/``;
      4. export ``<split>.<mode>.jsonl`` for ``b_mode`` (and for "B_norm" too
         when ``also_b_norm`` is set), each example hard-cut to ``max_len``
         ids; optionally export the byte-level baseline
         ``<split>.A_bytes.jsonl`` plus ``<out_dir>/tokenizer_A/config.json``.

    Per-split metrics (truncation rate, length quantiles, OOV rates, share of
    snippets that parse after detokenization) are collected, written to
    ``<out_dir>/stats.json`` and returned.
    """
    in_path = Path(input_path)
    out_dir_p = Path(out_dir)
    out_dir_p.mkdir(parents=True, exist_ok=True)

    # 1) load + filter + dedup
    raw_rows: List[RowView] = []
    for d in read_jsonl(in_path):
        rv = adapt_row(d)
        if rv is not None and default_filter(rv, language=language):
            raw_rows.append(rv)

    # exact dedup by sha
    seen = set()
    dedup: List[RowView] = []
    for rv in raw_rows:
        if rv.sha in seen:
            continue
        seen.add(rv.sha)
        dedup.append(rv)

    # 2) split by repo (stable)
    splits: Dict[str, List[RowView]] = {"train": [], "val": [], "test": []}
    for rv in dedup:
        splits[split_by_repo(rv.repo, train=train_ratio, val=val_ratio)].append(rv)

    # 3) build vocab for B on train split only
    # NOTE(review): the extra B_norm export below is encoded with this same
    # vocabulary. With b_mode="B_raw" the <id_k> tokens are generally absent
    # from it and encode to <unk> - confirm that this is intended.
    token2id, id2token, _freq = build_vocab(
        splits["train"], b_mode, vocab_size=vocab_size, min_freq=min_freq, max_id_vars=max_id_vars
    )
    write_json(out_dir_p/"tokenizer_B"/"token2id.json", token2id)
    write_json(out_dir_p/"tokenizer_B"/"id2token.json", id2token)
    write_json(out_dir_p/"tokenizer_B"/"config.json", {
        "mode": b_mode,
        "vocab_size": vocab_size,
        "min_freq": min_freq,
        "special_tokens": SPECIAL_B,
        "task_tokens": ["<task:doc2code>", "<doc>", "</doc>", "<code>", "</code>"],
        "max_id_vars": max_id_vars
    })
    eos_id = token2id["<eos>"]

    # 4) export splits + metrics
    stats = {
        "input_path": str(in_path),
        "kept_after_filter": len(raw_rows),
        "kept_after_dedup": len(dedup),
        "splits": {k: len(v) for k, v in splits.items()},
        "max_len": max_len,
        "variant_B": {},
        "variant_A": {},
    }

    # Framing added by build_example_tokens:
    #   input  = <bos> <task:doc2code> <doc> DOC </doc> <code>   -> 5 markers
    #   target = CODE </code> <eos>                              -> 2 markers
    # All of them are always in the vocabulary, so the OOV counts returned for
    # the framed sequences are exactly the OOV counts of DOC and CODE.
    n_doc_frame = 5
    n_code_frame = 2

    def export_B(mode: str) -> None:
        for split_name, rvs in splits.items():
            encoded = []
            lens = []
            oov_doc = oov_code = 0
            total_doc = total_code = 0
            trunc = 0
            parsable_ok = 0

            for rv in rvs:
                # Tokenize and encode once; the same pass feeds ids, OOV
                # accounting and the parse check.
                inp_toks, tgt_toks = build_example_tokens(rv, mode, max_id_vars=max_id_vars)
                inp_ids, doc_oov = encode_tokens(inp_toks, token2id)
                tgt_ids, code_oov = encode_tokens(tgt_toks, token2id)
                code_tokens = tgt_toks[:-n_code_frame]

                # OOV accounting: doc tokens and code tokens separately,
                # measured before truncation.
                oov_doc += doc_oov
                total_doc += len(inp_toks) - n_doc_frame
                oov_code += code_oov
                total_code += len(code_tokens)

                # length policy: hard cut (truncate) to max_len, count as truncation
                inp_ids, tgt_ids, was_cut = _fit_to_max_len(inp_ids, tgt_ids, max_len, eos_id)
                if was_cut:
                    trunc += 1
                lens.append(len(inp_ids) + len(tgt_ids))

                if safe_ast_parse(detokenize_python(code_tokens)):
                    parsable_ok += 1

                encoded.append({
                    "repo": rv.repo,
                    "path": rv.path,
                    "func_name": rv.func_name,
                    "sha256": rv.sha,
                    "split": split_name,
                    "mode": mode,
                    "input_ids": inp_ids,
                    "labels": tgt_ids,
                })

            out_path = out_dir_p / f"{split_name}.{mode}.jsonl"
            write_jsonl(out_path, encoded)

            n_rows = len(rvs)
            stats["variant_B"].setdefault(mode, {})
            stats["variant_B"][mode][split_name] = {
                "n_examples": n_rows,
                "n_encoded": len(encoded),
                "truncation_rate": (trunc / max(1, n_rows)),
                "len_quantiles": quantiles_int(lens),
                "oov_rate_doc": (oov_doc / max(1, total_doc)),
                "oov_rate_code": (oov_code / max(1, total_code)),
                "parsable_after_detokenize_rate": (parsable_ok / max(1, n_rows)),
                "out_path": str(out_path),
            }

    # export chosen B mode
    export_B(b_mode)
    if also_b_norm and b_mode != "B_norm":
        export_B("B_norm")

    # export A if requested
    if export_a:
        for split_name, rvs in splits.items():
            encoded = []
            lens = []
            trunc = 0
            for rv in rvs:
                ids, _, was_cut = encode_bytes(build_line_text(rv), max_len)
                if was_cut:
                    trunc += 1
                lens.append(len(ids))
                encoded.append({
                    "repo": rv.repo,
                    "path": rv.path,
                    "func_name": rv.func_name,
                    "sha256": rv.sha,
                    "split": split_name,
                    "mode": "A_bytes",
                    "input_ids": ids,
                })
            out_path = out_dir_p / f"{split_name}.A_bytes.jsonl"
            write_jsonl(out_path, encoded)
            stats["variant_A"][split_name] = {
                "n_examples": len(rvs),
                "truncation_rate": trunc / max(1, len(rvs)),
                "len_quantiles": quantiles_int(lens),
                "vocab_size": BYTE_OFFSET + 256,
                "out_path": str(out_path),
            }
        write_json(out_dir_p/"tokenizer_A"/"config.json", {
            "special_tokens": SPECIAL_A,
            "byte_offset": BYTE_OFFSET,
            "vocab_size": BYTE_OFFSET + 256
        })

    write_json(out_dir_p/"stats.json", stats)
    return stats


In [12]:
# -------------------------
# Run 
# -------------------------

# NOTE(review): absolute path on the author's machine - point it at your own
# JSONL file before running this cell elsewhere.
INPUT_JSONL = "/Users/polzovatel/projects/GAN/GAN-and-Diffusion-Models-EDA/csn_python_jsonl_full/csn_test_clean.jsonl"  
OUT_DIR = "out_tokenization"

# Variant B: primary mode
B_MODE = "B_raw"   # or "B_norm"

# Run the full pipeline. Every argument below equals run_tokenization's
# default; they are spelled out so they are easy to tweak in the notebook.
stats = run_tokenization(
    input_path=INPUT_JSONL,
    out_dir=OUT_DIR,
    language="python",
    max_len=2048,
    train_ratio=0.90,
    val_ratio=0.05,
    vocab_size=32000,
    min_freq=2,
    b_mode=B_MODE,
    also_b_norm=True,
    export_a=True,
    max_id_vars=64,
)

# Bare expression: in a notebook its value is displayed as the cell output.
stats["splits"], OUT_DIR




({'train': 12032, 'val': 2010, 'test': 1177}, 'out_tokenization')