In [None]:
import json, os, hashlib
from collections import defaultdict
from typing import List, Dict

# Partial outputs of the instruction-generation runs that are merged below.
# Paths that do not exist are skipped with a warning by the merge loop.
INPUT_JSONLS: List[str] = [
    "/content/out_instructions_855_1355.jsonl",
    "/content/out_instructions_last (5).jsonl",
    "/content/out_instructions_last1000_last_pls.jsonl",
    "/content/out_instructions_second_last1000 (3).jsonl",
]
# Single merged, de-duplicated file; later cells read it as their input.
MERGED_JSONL: str = "/content/out_instructions_merged.jsonl"

def read_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file into a list of parsed records.

    Blank (whitespace-only) lines are ignored. Lines that are not valid JSON are
    skipped on a best-effort basis, so one corrupt line does not abort the run.
    The number of skipped lines and the first offending line number are printed,
    which keeps data loss visible instead of silent.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        The parsed JSON values in file order (normally one dict per line).

    Raises:
        OSError: if ``path`` cannot be opened.
    """
    rows = []
    bad_lines = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                # Only malformed JSON is tolerated; other errors should surface.
                bad_lines.append(lineno)
    if bad_lines:
        print(
            f"WARNING: skipped {len(bad_lines)} malformed line(s) in {path} "
            f"(first at line {bad_lines[0]})"
        )
    return rows

def fallback_row_id(row: Dict) -> str:
    """Deterministic row id derived from the row's Russian query/positive text.

    Used when a row carries no ``_row_id``. The stripped ``query_ru`` and
    ``positive_ru`` values are joined with a newline and MD5-hashed, so identical
    (query, positive) pairs coming from different files map to the same id.
    Missing or ``None`` fields count as empty strings.
    """
    query, positive = ((row.get(key) or "").strip() for key in ("query_ru", "positive_ru"))
    payload = "\n".join((query, positive)).encode("utf-8")
    return hashlib.md5(payload).hexdigest()

def dedupe_instructions(instrs: List[Dict]) -> List[Dict]:
    """Drop repeated instructions while keeping first-seen order.

    Two entries are duplicates when their stripped ``style``, ``length_format``
    and ``instruction`` values all match. Each surviving entry is rebuilt with
    exactly five string fields (``style``, ``length_format``, ``instruction``,
    ``relevant_docs``, ``non-relevant_docs``), stripped, with missing/``None``
    values turned into ``""``. The docs fields come from the first occurrence.
    """
    def clean(item: Dict, field: str) -> str:
        return (item.get(field) or "").strip()

    unique: Dict[tuple, Dict] = {}
    for item in instrs:
        key = tuple(clean(item, field) for field in ("style", "length_format", "instruction"))
        if key in unique:
            continue
        style, length_format, instruction = key
        unique[key] = {
            "style": style,
            "length_format": length_format,
            "instruction": instruction,
            "relevant_docs": clean(item, "relevant_docs"),
            "non-relevant_docs": clean(item, "non-relevant_docs"),
        }
    return list(unique.values())

# Merge rows from every partial file, keyed by `_row_id` (or a content hash of the
# Russian query/positive when `_row_id` is missing). The first occurrence of a row
# supplies all non-instruction fields; instructions from later duplicates are
# appended, then de-duplicated per row.
bucket: Dict[str, Dict] = {}
count_in = 0
for path in INPUT_JSONLS:
    if not os.path.exists(path):
        print(f"WARNING: file not found: {path}")
        continue
    rows = read_jsonl(path)
    count_in += len(rows)
    for r in rows:
        if not isinstance(r, dict):
            # A JSON line that is not an object has no fields to merge.
            continue
        rid = r.get("_row_id") or fallback_row_id(r)
        incoming = r.get("instructions")
        # A malformed (non-list) instructions field is treated as empty rather
        # than crashing the list concatenation.
        if not isinstance(incoming, list):
            incoming = []
        if rid not in bucket:
            r["instructions"] = list(incoming)
            bucket[rid] = r
        else:
            # Extend in place: avoids re-copying the accumulated list per duplicate.
            bucket[rid]["instructions"].extend(incoming)

for rid, row in bucket.items():
    row["instructions"] = dedupe_instructions(row["instructions"])
    row["_row_id"] = rid

# os.path.dirname() is "" for a bare file name and os.makedirs("") raises,
# so only create a directory when the path actually has one.
out_dir = os.path.dirname(MERGED_JSONL)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(MERGED_JSONL, "w", encoding="utf-8") as f:
    for row in bucket.values():
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print(f"Read total rows: {count_in}")
print(f"Merged unique rows: {len(bucket)}")
print(f"Saved -> {MERGED_JSONL}")


Read total rows: 2147
Merged unique rows: 2147
Saved -> /content/out_instructions_merged.jsonl


In [None]:
!pip -q install -U sentence-transformers

In [None]:
!pip -q install -U FlagEmbedding


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/163.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.9/163.9 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m866.1/866.1 kB[0m [31m33.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m148.3/148.3 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.1/45.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for FlagEmbedding (setup.py) ... [?25l[?25hdone
  Building wheel for warc3-wet-clueweb09 (setup.p

In [None]:
import json, numpy as np
from collections import defaultdict
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer, util

# Paths: merged rows from the first cell in, one judged record per instruction out.
INSTRUCTIONS_JSONL = "/content/out_instructions_merged.jsonl"
SBERT_RESULTS_JSONL = "/content/relevancy_checked_sbert_ru.jsonl"
ROW_LIMIT = None  # int -> judge only the first ROW_LIMIT rows; None -> all rows
EPS = 0.01        # a delta within [-EPS, EPS] is labelled "neutral"

MODEL_NAME = "ai-forever/sbert_large_nlu_ru"

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = SentenceTransformer(MODEL_NAME, device=device)

def cos(a, b):
    """Cosine similarity between two embeddings, returned as a Python float.

    ``util.cos_sim`` produces a similarity matrix. For the single vectors passed
    in this notebook that matrix is 1x1, so element [0][0] is the score.
    """
    sim_matrix = util.cos_sim(a, b).cpu().numpy()
    return float(sim_matrix[0][0])

def emb(texts):
    """Encode a list of strings with the global ``model``.

    Embeddings are L2-normalised (``normalize_embeddings=True``) and returned as
    a torch tensor (``convert_to_tensor=True``), one row per input text.
    """
    return model.encode(
        texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )

# Load the merged rows with `read_jsonl`, the reader defined in the merge cell at the
# top of the notebook. It skips blank lines and malformed JSON, but unlike a bare
# `except:` it does not swallow KeyboardInterrupt.
rows = read_jsonl(INSTRUCTIONS_JSONL)

if ROW_LIMIT is not None:
    rows = rows[:ROW_LIMIT]

# Judge each instruction by how it moves query->positive cosine similarity:
#   delta = sim("query: q inst", passage) - sim("query: q", passage)
# and label it "hurt" / "neutral" / "improved" using the +-EPS band.
# NOTE(review): "query: " / "passage: " are E5-style prefixes; confirm they suit
# ai-forever/sbert_large_nlu_ru. They are applied to both the baseline and the
# instructed query, so deltas stay comparable either way.
checked = []

for r in tqdm(rows, desc="SBERT judge"):
    q = (r.get("query_ru") or "").strip()
    p = (r.get("positive_ru") or "").strip()
    if not q or not p:
        continue
    # (instruction dict, stripped instruction text) for every non-empty instruction.
    pairs = []
    for it in r.get("instructions", []) or []:
        inst = (it.get("instruction") or "").strip()
        if inst:
            pairs.append((it, inst))
    if not pairs:
        continue

    # One batched encode per row (baseline query, passage, every instructed query)
    # instead of 2 + len(pairs) single-text calls. Batching may shift similarities
    # by floating-point noise compared with one-at-a-time encoding.
    texts = [f"query: {q}", f"passage: {p}"] + [f"query: {q} {inst}" for _, inst in pairs]
    vecs = emb(texts)
    p_emb = vecs[1]
    base = cos(vecs[0], p_emb)

    for (it, inst), qi in zip(pairs, vecs[2:]):
        si = cos(qi, p_emb)
        delta = si - base
        label = "hurt" if delta < -EPS else ("improved" if delta > EPS else "neutral")
        checked.append({
            "_row_id": r.get("_row_id",""),
            "style": (it.get("style","") or "unknown").strip() or "unknown",
            "length_format": it.get("length_format",""),
            "instruction": inst,
            "base_sim": round(base,4),
            "sim_with_instruction": round(si,4),
            "delta": round(delta,4),
            "label": label,
            "success": label != "hurt",
        })

with open(SBERT_RESULTS_JSONL, "w", encoding="utf-8") as f:
    for row in checked:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print("Saved:", SBERT_RESULTS_JSONL)

def _tally(records):
    """Return (n, improved, neutral, hurt, success_rate) for judged records.

    success_rate is the share labelled "improved" or "neutral" (0.0 when empty).
    """
    counts = defaultdict(int)
    for rec in records:
        counts[rec["label"]] += 1
    n_records = len(records)
    ok = counts["improved"] + counts["neutral"]
    rate = ok / n_records if n_records else 0.0
    return n_records, counts["improved"], counts["neutral"], counts["hurt"], rate


total, imp, neu, hrt, succ_rate = _tally(checked)

print("\nPer-instruction stats")
print(f"Total: {total} | improved: {imp} | neutral: {neu} | hurt: {hrt}")
print(f"Success rate (not worse than baseline): {succ_rate:.3f}")

by_style = defaultdict(list)
for rec in checked:
    by_style[rec["style"]].append(rec)

print("\nBy style:")
for style, arr in sorted(by_style.items()):
    n, s_imp, s_neu, s_hrt, s_succ = _tally(arr)
    print(f"  {style:<15} | n={n:<4} | imp: {s_imp:<4} neu: {s_neu:<4} hurt: {s_hrt:<4} | success:{s_succ:.3f}")


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/195 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/863 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

SBERT judge: 100%|██████████| 2147/2147 [04:22<00:00,  8.17it/s]

Saved: /content/relevancy_checked_sbert_ru.jsonl

Per-instruction stats
Total: 6408 | improved: 5624 | neutral: 303 | hurt: 481
Success rate (not worse than baseline): 0.925

By style:
  background_long | n=2138 | imp: 1831 neu: 103  hurt: 204  | success:0.905
  persona         | n=2135 | imp: 1869 neu: 104  hurt: 162  | success:0.924
  short_strict    | n=2135 | imp: 1924 neu: 96   hurt: 115  | success:0.946





In [None]:
# ==== Build a filtered training set (keeps hard negatives) ====
import json, re
from collections import defaultdict, Counter

# Inputs (merged rows + per-instruction SBERT judgements) and the training-set output.
SRC_ROWS: str   = "/content/out_instructions_merged.jsonl"
SRC_SCORES: str = "/content/relevancy_checked_sbert_ru.jsonl"
OUT_TRAIN: str  = "/content/instructions_ru_sbert_filtered_with_negs.jsonl"

# Selection knobs.
DELTA_MIN: float   = 0.01  # keep a judgement only if its stored (rounded) delta >= this
CAP_PER_STYLE: int = 1     # keep at most this many instructions per (row, style), best delta first
CAP_PER_ROW        = 3     # keep at most this many per row overall; None disables the cap
MIN_NONWS_LEN: int = 30    # drop instructions with fewer non-whitespace characters
NEG_TOP_K: int     = 5     # keep the first K hard negatives per language
INCLUDE_EN: bool   = True  # also emit query_en / positive_en / hard_negs_en

def nonws_len(s):
    """Count the non-whitespace characters in ``s`` (``None`` or ``""`` -> 0)."""
    return len(re.findall(r"\S", s or ""))

def normalize_hnegs(val):
    """Coerce a hard-negatives field into a de-duplicated list of non-empty strings.

    Accepted inputs:
      * a list or tuple of items;
      * a string holding a JSON list, e.g. '["a", "b"]';
      * any other non-blank string, treated as a single negative.
    Anything else (None, blank string, dict, number) yields [].

    Items are converted with ``str()`` and stripped. ``None`` items (e.g. JSON
    nulls), empty items and repeats are dropped, and first-occurrence order is kept.
    """
    if isinstance(val, (list, tuple)):
        items = val
    elif isinstance(val, str) and val.strip():
        try:
            parsed = json.loads(val)
        except json.JSONDecodeError:
            parsed = None
        items = parsed if isinstance(parsed, list) else [val]
    else:
        items = []

    out, seen = [], set()
    for x in items:
        if x is None:
            # str(None) would otherwise inject the literal text "None" as a negative.
            continue
        s = str(x).strip()
        if s and s not in seen:
            seen.add(s)
            out.append(s)
    return out

# Load merged rows and per-instruction judgements with `read_jsonl`, the reader
# defined in the merge cell at the top of the notebook. It skips blank lines and
# malformed JSON, but unlike a bare `except:` it does not swallow KeyboardInterrupt.
rows = read_jsonl(SRC_ROWS)

# Index rows by `_row_id`; rows without one cannot be joined to judgements.
# If an id repeats, the later row wins.
by_id = {r.get("_row_id",""): r for r in rows if r.get("_row_id")}

judged = read_jsonl(SRC_SCORES)

def _passes_filters(rec):
    """True if a judged instruction has an id/text, is long enough, is scored and clears DELTA_MIN."""
    rid = rec.get("_row_id", "")
    inst = rec.get("instruction", "") or ""
    if not rid or not inst:
        return False
    if nonws_len(inst) < MIN_NONWS_LEN:
        return False
    if rec.get("sim_with_instruction") is None or rec.get("base_sim") is None:
        return False
    # Written as "not below" rather than ">=" so a NaN delta behaves exactly as before.
    return not (float(rec.get("delta", 0.0)) < DELTA_MIN)


# Surviving judgements bucketed by row id, in file order.
grouped = defaultdict(list)
for rec in filter(_passes_filters, judged):
    grouped[rec.get("_row_id", "")].append(rec)

def _select_instructions(cand):
    """Choose which judged instructions of one row to keep.

    Keeps the CAP_PER_STYLE highest-delta candidates of each style (a missing or
    blank style counts as "unknown"). If more than CAP_PER_ROW remain, it keeps
    only the CAP_PER_ROW highest-delta ones overall. Python's sort is stable
    (also with reverse=True), so ties keep input order. The per-style lists are
    sorted in place.
    """
    per_style = defaultdict(list)
    for c in cand:
        per_style[(c.get("style") or "unknown").strip()].append(c)
    chosen = []
    for arr in per_style.values():
        arr.sort(key=lambda x: float(x.get("delta", 0.0)), reverse=True)
        chosen.extend(arr[:CAP_PER_STYLE])
    if CAP_PER_ROW is not None and len(chosen) > CAP_PER_ROW:
        chosen.sort(key=lambda x: float(x.get("delta", 0.0)), reverse=True)
        chosen = chosen[:CAP_PER_ROW]
    return chosen


def _hard_negatives(raw, positive):
    """Normalise a hard-negatives field, drop copies of the positive, keep the first NEG_TOP_K.

    A "negative" that is textually identical to the positive would contradict the
    contrastive training signal, so it is filtered out before truncation.
    """
    pos = (positive or "").strip()
    return [neg for neg in normalize_hnegs(raw) if neg != pos][:NEG_TOP_K]


kept_total = 0
kept_rows  = 0
style_counts = Counter()
neg_stats = Counter()

with open(OUT_TRAIN, "w", encoding="utf-8") as out:
    for rid, cand in grouped.items():
        base = by_id.get(rid)
        if not base:
            continue

        chosen = _select_instructions(cand)
        if not chosen:
            continue

        negs_ru = _hard_negatives(base.get("hard_negs_ru", []), base.get("positive_ru", ""))
        negs_en = (
            _hard_negatives(base.get("hard_negs_en", []), base.get("positive_en", ""))
            if INCLUDE_EN else []
        )

        # One training example per kept instruction; row-level fields are repeated.
        for c in chosen:
            out_obj = {
                "_row_id": rid,
                "query_ru": base.get("query_ru",""),
                "positive_ru": base.get("positive_ru",""),
                "hard_negs_ru": negs_ru,
                "style": c.get("style",""),
                "length_format": c.get("length_format",""),
                "instruction": c.get("instruction",""),
                "base_sim": c.get("base_sim"),
                "sim_with_instruction": c.get("sim_with_instruction"),
                "delta": c.get("delta"),
            }
            if INCLUDE_EN:
                out_obj.update({
                    "query_en": base.get("query_en",""),
                    "positive_en": base.get("positive_en",""),
                    "hard_negs_en": negs_en,
                })

            out.write(json.dumps(out_obj, ensure_ascii=False) + "\n")
            kept_total += 1
            style_counts[out_obj["style"]] += 1
            neg_stats['ru_total'] += len(negs_ru)
            if INCLUDE_EN:
                neg_stats['en_total'] += len(negs_en)
        kept_rows += 1

# Average number of hard negatives per written example (0.0 when nothing was kept).
if kept_total:
    avg_ru = neg_stats['ru_total'] / kept_total
else:
    avg_ru = 0.0
if kept_total and INCLUDE_EN:
    avg_en = neg_stats['en_total'] / kept_total
else:
    avg_en = 0.0

print("Wrote:", OUT_TRAIN)
print(f"Rows with ≥1 kept instruction: {kept_rows}")
print(f"Total kept instructions: {kept_total}")
print("By style:", dict(style_counts))
print(f"Avg hard_negs_ru per kept example: {avg_ru:.2f}")
if INCLUDE_EN:
    print(f"Avg hard_negs_en per kept example: {avg_en:.2f}")


Wrote: /content/instructions_ru_sbert_filtered_with_negs.jsonl
Rows with ≥1 kept instruction: 2052
Total kept instructions: 5624
By style: {'short_strict': 1924, 'persona': 1869, 'background_long': 1831}
Avg hard_negs_ru per kept example: 5.00
Avg hard_negs_en per kept example: 5.00
