In [None]:
from pathlib import Path

# ========= USER CONFIG =========
# Use absolute paths
# You need to load your local fastText model
RESULT_DIR = Path("/home/zhong/test-folder/lang_trans/outputs/llama2-7b_dims400_para/mt")
LID_PATH   = Path("/home/zhong/fasttext/lid.176.bin")

TARGET_LANGS = ["zh", "fr", "es", "de", "ja"]   # or ["all"]

DATA_SOURCES = ["all"]

LID_THRESHOLD = 0.50

ENABLE_EN2X = True
ENABLE_X2EN = False
# ===============================

In [32]:
import json
import re
from decimal import Decimal, ROUND_HALF_UP

import fasttext
import numpy as np
import sacrebleu
from tqdm.auto import tqdm

# Load fastText lid model
lid = fasttext.load_model(str(LID_PATH))

_bleu_cache = {}
def get_bleu_metric(lang: str):
    if lang not in _bleu_cache:
        if lang == "zh":
            _bleu_cache[lang] = sacrebleu.BLEU(tokenize="zh", effective_order=True)
        elif lang == "ja":
            _bleu_cache[lang] = sacrebleu.BLEU(tokenize="ja-mecab", effective_order=True)
        else:
            _bleu_cache[lang] = sacrebleu.BLEU(effective_order=True)
    return _bleu_cache[lang]

In [33]:
# en->x: en.control.<layer>.<coeff>.{lang}.json or .jsonl
EN2X_RE = re.compile(
    r"^en\.control\.(\d{1,2})\.(\d+)\.(\d{2})\.(zh|ja|fr|es|de|en|ko|tr|id|ar)\.(jsonl|json)$"
)

# x->en: {lang}.control.<layer>.<coeff>.en.jsonl (你现在的新命名)
X2EN_RE = re.compile(
    r"^(zh|ja|fr|es|de|en|ko|tr|id|ar)\.control\.(\d{1,2})\.(\d+)\.(\d{2})\.en\.(jsonl|json)$"
)

ALL_LANGS_ALLOWED = ["zh","ja","fr","es","de","en","ko","tr","id","ar"]

def norm_langs(langs):
    langs = [l.lower() for l in langs]
    if "all" in langs:
        return ALL_LANGS_ALLOWED[:]
    return [l for l in langs if l in ALL_LANGS_ALLOWED]

def norm_sources(srcs):
    srcs = [s.lower() for s in srcs]
    if "all" in srcs:
        return None
    return set(srcs)

TARGET_LANGS_NORM = norm_langs(TARGET_LANGS)
DATA_SOURCES_NORM = norm_sources(DATA_SOURCES)

def parse_filename(fp_name: str):
    """
    Returns:
      direction: "en2x" or "x2en"
      layer: int
      coeff: float (rounded to 0.01 exactly as in filename)
      src_lang: str
      tgt_lang: str
    or None if not matched.
    """
    if ENABLE_EN2X:
        m = EN2X_RE.match(fp_name)
        if m:
            layer = int(m.group(1))
            coeff_str = f"{m.group(2)}.{m.group(3)}"
            coeff = float(Decimal(coeff_str).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))
            tgt_lang = m.group(4)
            return ("en2x", layer, coeff, "en", tgt_lang)

    if ENABLE_X2EN:
        m = X2EN_RE.match(fp_name)
        if m:
            src_lang = m.group(1)
            layer = int(m.group(2))
            coeff_str = f"{m.group(3)}.{m.group(4)}"
            coeff = float(Decimal(coeff_str).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))
            return ("x2en", layer, coeff, src_lang, "en")

    return None

In [None]:
# Scan result files
RESULT_DIR = RESULT_DIR.expanduser().resolve()
paths = list(RESULT_DIR.rglob("*.json*"))  # json / jsonl
print(f"Found {len(paths)} files under {RESULT_DIR}")

entries = []  # (fp, direction, layer, coeff, src_lang, tgt_lang)
layer_set = set()
coeff_set = set()

for fp in tqdm(paths, desc="Scanning"):
    parsed = parse_filename(fp.name)
    if not parsed:
        continue
    direction, layer, coeff, src_lang, tgt_lang = parsed

    output_lang = tgt_lang
    if output_lang not in TARGET_LANGS_NORM:
        continue

    entries.append((fp, direction, layer, coeff, src_lang, tgt_lang))
    layer_set.add(layer)
    coeff_set.add(coeff)

if not entries:
    raise ValueError("No matched files. Check RESULT_DIR / patterns / TARGET_LANGS.")

LAYERS = sorted(layer_set)
COEFFS = sorted(coeff_set)

print(f"Layers   : {LAYERS[0]} .. {LAYERS[-1]} (count={len(LAYERS)})")
print(f"Coeffs   : {COEFFS[0]} .. {COEFFS[-1]} (unique={len(COEFFS)})")
print(f"Entries  : {len(entries)}")

Found 5 files under /mnt/zamia/zhong/test-folder/lang_trans/outputs/llama2-7b_dims400_para/mt


Scanning:   0%|          | 0/5 [00:00<?, ?it/s]

Layers   : 18 .. 18 (count=1)
Coeffs   : 0.4 .. 0.4 (unique=1)
Entries  : 5


In [35]:
# stats[(layer, coeff)] = [total, success, bleu_sum_success, bleu_cnt_success]
from pathlib import Path
import json

stats = {(L, C): [0, 0, 0.0, 0] for L in LAYERS for C in COEFFS}

def iter_samples(fp: Path):
    """
    Yield per-sample dict from either:
      - JSONL: one JSON object per line
      - JSON : list[dict] OR dict containing a list (results/data) OR single dict
    """
    suffix = fp.suffix.lower()

    if suffix == ".jsonl":
        with fp.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                yield json.loads(line)
        return

    if suffix == ".json":
        with fp.open("r", encoding="utf-8") as f:
            obj = json.load(f)

        if isinstance(obj, list):
            for x in obj:
                if isinstance(x, dict):
                    yield x
            return

        if isinstance(obj, dict):
            # common containers
            for k in ("results", "data", "samples", "items"):
                if k in obj and isinstance(obj[k], list):
                    for x in obj[k]:
                        if isinstance(x, dict):
                            yield x
                    return
            # fallback: treat as a single sample dict
            yield obj
        return

    raise ValueError(f"Unsupported file suffix: {fp.suffix}")


def fasttext_lang_and_prob(text: str):
    text = text.replace("\n", " ")
    labels, probs = lid.predict(text[:4096])
    pred_lang = labels[0].replace("__label__", "")
    prob = float(probs[0])
    return pred_lang, prob

for fp, direction, layer, coeff, src_lang, tgt_lang in tqdm(entries, desc="Aggregating"):
    key = (layer, coeff)
    bleu_metric = get_bleu_metric(tgt_lang)

    for js in iter_samples(fp):
        # 数据源过滤（如果你文件里有 data 字段）
        if DATA_SOURCES_NORM is not None:
            if js.get("data") not in DATA_SOURCES_NORM:
                continue

        pred_lang, prob = fasttext_lang_and_prob(js.get("trans", ""))
        is_lang_ok = (pred_lang == tgt_lang) and (prob >= LID_THRESHOLD)

        stats[key][0] += 1  # total

        if is_lang_ok:
            stats[key][1] += 1  # success
            hyp = js.get("trans", "")
            ref = js.get("refer", "")
            score = bleu_metric.sentence_score(hyp, [ref]).score
            stats[key][2] += score
            stats[key][3] += 1

            # 数据源过滤（如果你文件里有 data 字段）
            if DATA_SOURCES_NORM is not None:
                if js.get("data") not in DATA_SOURCES_NORM:
                    continue

            # 语言识别 + 阈值成功判定
            pred_lang, prob = fasttext_lang_and_prob(js.get("trans", ""))
            is_lang_ok = (pred_lang == tgt_lang) and (prob >= LID_THRESHOLD)

            stats[key][0] += 1  # total

            if is_lang_ok:
                stats[key][1] += 1  # success

                # BLEU 只在成功样本上累计
                hyp = js.get("trans", "")
                ref = js.get("refer", "")
                score = bleu_metric.sentence_score(hyp, [ref]).score
                stats[key][2] += score
                stats[key][3] += 1

# build matrices
H, W = len(LAYERS), len(COEFFS)
ACC_mat = np.zeros((H, W), dtype=np.float32)
BLEU_mat = np.zeros((H, W), dtype=np.float32)
COMB_mat = np.zeros((H, W), dtype=np.float32)

row_index = {L: i for i, L in enumerate(LAYERS)}
col_index = {C: j for j, C in enumerate(COEFFS)}

for (L, C), (tot, suc, bleu_sum, bleu_cnt) in stats.items():
    i, j = row_index[L], col_index[C]
    acc = suc / tot if tot else 0.0
    bleu = bleu_sum / bleu_cnt if bleu_cnt else 0.0  # success-only average
    comb = acc * bleu if (tot and bleu_cnt) else 0.0

    ACC_mat[i, j] = acc
    BLEU_mat[i, j] = bleu
    COMB_mat[i, j] = comb

print("ACC_mat shape:", ACC_mat.shape, "BLEU_mat shape:", BLEU_mat.shape)

Aggregating:   0%|          | 0/5 [00:00<?, ?it/s]

ACC_mat shape: (1, 1) BLEU_mat shape: (1, 1)


In [None]:
# Draw heatmap and 3D surface
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FormatStrFormatter, FuncFormatter

X_coeff, Y_layer = np.meshgrid(np.array(COEFFS, dtype=float), np.array(LAYERS, dtype=float))

def plot_surface_paper(X, Y, Z, zlabel="ACC", cmap="viridis", view=(28, 230), figsize=(7.9, 3.6)):
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(1, 2, figure=fig, width_ratios=[1.5, 1.0])

    # 3D surface
    ax3d = fig.add_subplot(gs[0, 0], projection="3d")
    surf = ax3d.plot_surface(X, Y, Z, cmap=cmap, linewidth=0, antialiased=True)
    ax3d.set_proj_type("ortho")
    ax3d.view_init(elev=view[0], azim=view[1])
    ax3d.set_box_aspect((1.25, 1.0, 0.7))
    ax3d.set_xlabel(r"$\alpha$")
    ax3d.set_ylabel("Layer")
    ax3d.set_zlabel(zlabel)

    for a in (ax3d.xaxis, ax3d.yaxis, ax3d.zaxis):
        a.pane.fill = False
        a.pane.set_edgecolor("white")

    cbar = fig.colorbar(surf, ax=ax3d, shrink=0.80, pad=0.06)
    cbar.set_label(zlabel)

    # Heatmap (aligned)
    ax2d = fig.add_subplot(gs[0, 1])
    xs = np.unique(X.ravel())
    ys = np.unique(Y.ravel())
    dx = np.diff(xs).mean() if len(xs) > 1 else 1.0
    dy = np.diff(ys).mean() if len(ys) > 1 else 1.0
    extent = (xs[0]-dx/2, xs[-1]+dx/2, ys[0]-dy/2, ys[-1]+dy/2)

    im = ax2d.imshow(
        Z,
        origin="lower",
        aspect="auto",
        cmap=cmap,
        extent=extent,
        interpolation="nearest",
    )

    Xc, Yc = np.meshgrid(xs, ys)
    cs = ax2d.contour(
        Xc, Yc, Z,
        colors="k",
        linewidths=0.6,
        alpha=0.35,
        levels=7,
    )
    ax2d.clabel(cs, inline=True, fontsize=8, fmt="%.2f")

    ax2d.set_xlabel(r"Scaling Coefficient $\alpha$")
    ax2d.set_ylabel("Layer")

    # optional: y tick shows 1-indexed layer
    ax2d.yaxis.set_major_formatter(FuncFormatter(lambda y, pos: f"{int(round(y))+1}"))

    fig.subplots_adjust(left=0.1, right=0.98, wspace=0.25)
    plt.show()

plot_surface_paper(X_coeff, Y_layer, ACC_mat, zlabel="ACC")
plot_surface_paper(X_coeff, Y_layer, BLEU_mat, zlabel="BLEU (success-only)")
plot_surface_paper(X_coeff, Y_layer, COMB_mat, zlabel="ACC*BLEU")

In [None]:
# Report the best (layer, strength) by ACC*BLEU
from decimal import Decimal, ROUND_HALF_UP

def _fmt(x, nd=4):
    return float(Decimal(str(x)).quantize(Decimal("0." + "0"*nd), rounding=ROUND_HALF_UP))

best = {
    "layer": None,
    "strength": None,
    "acc": 0.0,
    "bleu_success": 0.0,
    "a_times_b": -1.0,
    "total": 0,
    "correct": 0,
    "bleu_cnt": 0,
}

for (layer, strength), (total, correct, bleu_sum, bleu_cnt) in stats.items():
    if total == 0:
        continue
    acc = correct / total
    bleu_success = (bleu_sum / bleu_cnt) if bleu_cnt > 0 else 0.0
    a_times_b = acc * bleu_success  # success-only BLEU
    if a_times_b > best["a_times_b"]:
        best.update({
            "layer": layer,
            "strength": strength,
            "acc": acc,
            "bleu_success": bleu_success,
            "a_times_b": a_times_b,
            "total": total,
            "correct": correct,
            "bleu_cnt": bleu_cnt,
        })

print("=== Global Best (by ACC*BLEU) ===")
print(f"Layer: {best['layer']}")
print(f"Strength: {best['strength']:.2f}")
print(f"ACC: {_fmt(best['acc'], 4)}  ({best['correct']}/{best['total']})")
print(f"BLEU (success-only): {_fmt(best['bleu_success'], 4)}  (count={best['bleu_cnt']})")
print(f"ACC*BLEU: {_fmt(best['a_times_b'], 4)}")

=== Global Best (by ACC*BLEU) ===
Layer: 18
Strength: 0.40
ACC: 0.9954  (2378/2389)
BLEU (success-only): 15.5329  (count=2378)
ACC*BLEU: 15.4614


In [37]:
# Report for specific (layer, strength) you want
from decimal import Decimal, ROUND_HALF_UP

# ===== user inputs =====
QUERY_LAYER = 18
QUERY_STRENGTH = 0.40  # e.g., 0.80 / 1.00 / 1.20 ...
# =======================

def quantize_strength(x: float) -> float:
    return float(Decimal(str(x)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))

layer = int(QUERY_LAYER)
strength = quantize_strength(float(QUERY_STRENGTH))

key = (layer, strength)
if key not in stats:
    print(f"[Not found] (layer={layer}, strength={strength:.2f}) not in stats.")
    print("Hint: available layers range:", (min(LAYERS), max(LAYERS)))
    print("Hint: strength examples:", STRENGTHS[:5], "...", STRENGTHS[-5:])
else:
    total, correct, bleu_sum, bleu_cnt = stats[key]
    acc = (correct / total) if total else 0.0
    bleu_success = (bleu_sum / bleu_cnt) if bleu_cnt else 0.0
    a_times_b = acc * bleu_success

    def _fmt(x, nd=4):
        return float(Decimal(str(x)).quantize(Decimal("0." + "0"*nd), rounding=ROUND_HALF_UP))

    print("=== Query Result ===")
    print(f"Layer: {layer}")
    print(f"Strength: {strength:.2f}")
    print(f"ACC: {_fmt(acc, 4)}  ({correct}/{total})")
    print(f"BLEU (success-only): {_fmt(bleu_success, 4)}  (count={bleu_cnt})")
    print(f"ACC*BLEU: {_fmt(a_times_b, 4)}")

=== Query Result ===
Layer: 18
Strength: 0.40
ACC: 0.9954  (2378/2389)
BLEU (success-only): 15.5329  (count=2378)
ACC*BLEU: 15.4614
