In [None]:
# Install dependencies for the sentence-level EN↔ZH comparison below.
# NOTE(review): `wikipedia-api` and `scikit-learn` are not imported in any visible cell — confirm they are still needed.
!pip -q install wikipedia-api sentence-transformers scikit-learn pandas numpy tqdm

In [None]:
import matplotlib.pyplot as plt
import matplotlib

# Font fallback list so Chinese labels render instead of empty boxes.
# Matplotlib tries these families in order:
#   - SimHei / Microsoft YaHei cover Windows; Arial Unicode MS covers macOS.
#   - Noto Sans CJK SC / WenQuanYi Zen Hei cover typical Linux / Colab images,
#     once installed (e.g. `apt-get install fonts-noto-cjk`).
#   - DejaVu Sans is the last resort and has no CJK glyphs.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Arial Unicode MS', 'Microsoft YaHei',
    'Noto Sans CJK SC', 'WenQuanYi Zen Hei',
    'DejaVu Sans',
]
# Use an ASCII hyphen for negative numbers; many CJK fonts lack U+2212 and show a box.
plt.rcParams['axes.unicode_minus'] = False


In [None]:
# ==== Keyword dictionary: {category: [search terms]} (paste your own list here as-is) ====
# Each term is passed verbatim as a full-text search query to English Wikipedia
# by the analysis cell (mediawiki_search -> srsearch).
# The category names appear to mirror US anti-discrimination protected characteristics
# (cf. the "Title VII", "ADEA", "ADA", "GINA" entries below) — NOTE(review): confirm.
# Several terms occur in more than one category, e.g. "xenophobia" (race, national_origin),
# "colorism" and "white" (race, color), "inclusion" (disability, cross_cutting) and
# "GINA" (genetic_information, cross_cutting).
KEYWORDS_BY_CATEGORY = {
    "race": [
        "race", "ethnicity", "ethnic group", "racial identity", "racial background", "ancestry",
        "heritage", "minority", "majority", "person of color", "BIPOC", "POC", "nonwhite", "white",
        "Black", "African American", "African", "Caribbean", "Asian", "East Asian", "South Asian",
        "Southeast Asian", "Pacific Islander", "Desi", "Hispanic", "Latino", "Latina", "Latinx",
        "Chicano", "Mexican American", "Puerto Rican", "Cuban", "Native American", "Indigenous",
        "First Nation", "Inuit", "Aboriginal", "Middle Eastern", "Arab", "Persian", "Iranian",
        "Turkish", "Kurdish", "Israeli", "Palestinian", "racism", "racial bias", "racial profiling",
        "colorism", "xenophobia", "hate crime", "racial slur", "microaggression"
    ],

    # NOTE(review): short generic words ("tone", "fair", "tan", "olive", "brown") will
    # probably match many pages unrelated to skin color — consider narrowing them.
    "color": [
        "skin color", "complexion", "tone", "pigment", "light-skinned", "dark-skinned", "fair",
        "tan", "olive", "brown", "black", "white", "color bias", "shadeism", "colorism",
        "bleaching", "whitening", "tanning", "beauty standards"
    ],

    "religion": [
        "religion", "belief", "faith", "spirituality", "worship", "sect", "denomination", "atheist",
        "agnostic", "believer", "Christian", "Catholic", "Protestant", "Evangelical", "Baptist",
        "Mormon", "Orthodox", "Muslim", "Islam", "Sunni", "Shia", "Sufi", "Hijab", "Ramadan",
        "Quran", "Jewish", "Judaism", "Torah", "synagogue", "kosher", "Hanukkah", "Hindu",
        "Hinduism", "karma", "dharma", "temple", "yoga", "Buddhist", "Buddhism", "Zen",
        "meditation", "monk", "Sikh", "Bahá’í", "Jain", "Shinto", "religious discrimination",
        "Islamophobia", "antisemitism", "anti-Christian sentiment", "blasphemy", "religious intolerance"
    ],

    "sex_gender": [
        "sex", "gender", "gender identity", "gender expression", "sexual orientation",
        "sexuality", "reproductive status", "male", "female", "man", "woman", "boy", "girl",
        "intersex", "transgender", "trans", "trans man", "trans woman", "nonbinary", "genderqueer",
        "genderfluid", "two-spirit", "pronoun", "misgender", "transition", "heterosexual",
        "homosexual", "gay", "lesbian", "bisexual", "pansexual", "asexual", "queer", "LGBTQIA+",
        "same-sex", "pregnancy", "pregnant", "maternity", "paternity", "childbirth", "breastfeeding",
        "parental leave", "miscarriage", "fertility", "sexism", "misogyny", "homophobia",
        "transphobia", "heteronormativity", "gender bias", "sexual harassment", "pregnancy discrimination"
    ],

    "national_origin": [
        "national origin", "nationality", "citizenship", "country of origin", "immigration status",
        "migrant", "refugee", "asylum seeker", "foreigner", "alien", "expatriate", "immigrant",
        "undocumented", "border control", "visa", "naturalized citizen", "deportation", "green card",
        "H-1B", "DACA", "xenophobia", "anti-immigrant", "nativism", "foreigner bias",
        "accent discrimination"
    ],

    "age": [
        "age", "aging", "older adult", "elderly", "senior", "middle-aged", "retirement", "lifespan",
        "generational gap", "baby boomer", "Gen X", "senior citizen", "retiree", "older worker",
        "ageism", "age discrimination", "overqualified", "outdated", "too old", "not tech-savvy",
        "energetic youth", "fresh talent"
    ],

    "disability": [
        "disability", "disabled", "differently abled", "impairment", "accessibility", "inclusion",
        "accommodation", "assistive technology", "physical disability", "mobility impairment",
        "wheelchair", "blind", "low vision", "deaf", "hard of hearing", "intellectual disability",
        "developmental disability", "autism", "ADHD", "dyslexia", "learning disability",
        "mental illness", "PTSD", "ableism", "handicap", "disabled bias", "stigma", "inspirational",
        "burden", "accessibility barrier"
    ],

    "genetic_information": [
        "genetic information", "DNA", "gene", "genes", "genome", "hereditary", "inherited",
        "mutation", "biomarker", "predisposition", "genetic testing", "family history",
        "carrier", "genotype", "phenotype", "genetic discrimination", "health privacy",
        "insurance bias", "GINA", "predictive testing", "genetic privacy"
    ],

    # Legal / policy vocabulary shared by all categories.
    # NOTE(review): bare acronyms such as "ADA" are likely ambiguous as search queries.
    "cross_cutting": [
        "discrimination", "harassment", "retaliation", "protected class", "equal opportunity",
        "affirmative action", "Title VII", "Civil Rights Act", "EEOC", "ADA", "ADEA", "GINA",
        "Section 504", "inclusion", "diversity", "equity"
    ]
}

# ==== Run / export parameters ====
import torch  # sentence-transformers already depends on torch

MAX_PAGES_PER_TERM = 10            # max English search hits examined per term
MIN_SENT_LEN = 8                   # minimum sentence length in characters (noise filter)
TOP_DIFFS_PER_PAGE = 12            # lowest-cosine sentence pairs exported per page pair
# Use the GPU when one is visible; assign "cpu" or "cuda" explicitly to override.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUT_DIR = "wiki_sem_diff_outputs"  # output directory


In [None]:
import os, re, time, random, json, requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Tuple, Optional
from requests.adapters import HTTPAdapter, Retry
from sentence_transformers import SentenceTransformer, util

# Wikipedia language editions being compared.
EN_WIKI = "en"
ZH_WIKI = "zh"
# Multilingual sentence-embedding model used for EN↔ZH cosine comparison.
EMB_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# MediaWiki Action API endpoint template; fill in with .format(lang=...).
WIKI_API = "https://{lang}.wikipedia.org/w/api.php"
# Results requested per search call (MediaWiki limit: keep this at 50 or below).
SEARCH_SLOTS = 50

def ensure_dir(p: str):
    """Create directory `p` (including parents) unless something already exists at that path."""
    if os.path.exists(p):
        return
    os.makedirs(p, exist_ok=True)

def safe_name(s: str) -> str:
    """
    Turn `s` into a filename-safe token.

    Whitespace is stripped from both ends, and every run of characters other than ASCII
    letters, digits, '.', '_' or '-' becomes a single '_'. The result is cut to 120
    characters, with "item" returned when nothing is left.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9._-]+', '_', s.strip())
    truncated = cleaned[:120]
    return truncated if truncated else "item"

# ---------- 403/429 mitigation: shared Session + custom User-Agent + exponential-backoff retries + soft rate limit ----------
WIKI_HEADERS = {
    # Replace with your own e-mail or homepage so Wikimedia can contact you (recommended practice).
    "User-Agent": "WikiSemDiff/0.2 (research; contact: your_email@example.com)",
    "Accept-Language": "en-US,zh-CN;q=0.9",
}
# Retry GETs on throttling / server errors with exponential backoff.
# NOTE(review): Wikimedia typically answers 403 for User-Agent policy problems, which
# retrying will not fix — set a real contact in the User-Agent first.
retries = Retry(
    total=5,
    backoff_factor=1.2,
    status_forcelist=[403, 429, 500, 502, 503, 504],
    allowed_methods=["GET"],
)
session = requests.Session()
for _scheme in ("https://", "http://"):
    session.mount(_scheme, HTTPAdapter(max_retries=retries))

REQUEST_DELAY_BASE = 0.20    # base delay before every API request (seconds)
REQUEST_DELAY_JITTER = 0.20  # extra random delay in [0, 0.2) seconds

def _wiki_get(lang: str, params: dict, timeout: int = 30) -> dict:
    """
    GET the MediaWiki Action API of `lang`.wikipedia.org and return the decoded JSON.

    Each call first sleeps REQUEST_DELAY_BASE plus up to REQUEST_DELAY_JITTER seconds to
    stay polite.

    Error handling:
        - 403/429/5xx responses are retried with backoff inside `session.get` by the
          Retry policy mounted on `session`.
        - With urllib3's default raise_on_status=True, exhausting those retries raises
          `requests.exceptions.RetryError` from `session.get`.
        - Any other error status is raised by `raise_for_status()`.

    Args:
        lang: Wikipedia language code, e.g. "en" or "zh".
        params: API parameters; they override the defaults below on key clashes.
        timeout: per-request timeout in seconds.
    """
    merged = {
        "origin": "*",
        "format": "json",
        "formatversion": 2,
        **params,
    }
    time.sleep(REQUEST_DELAY_BASE + random.random() * REQUEST_DELAY_JITTER)
    # No extra sleep on 403/429 here: those codes are in the Retry status_forcelist,
    # so a response carrying them never gets past session.get().
    r = session.get(WIKI_API.format(lang=lang),
                    params=merged, headers=WIKI_HEADERS, timeout=timeout)
    r.raise_for_status()
    return r.json()

def mediawiki_search(keyword: str, lang: str = EN_WIKI, limit: int = 10) -> List[Dict]:
    """
    Full-text search `lang`.wikipedia.org for `keyword` (action=query, list=search).

    Results are requested SEARCH_SLOTS at a time until `limit` hits have been collected or
    the API returns a short page. Returns at most `limit` dicts {"title": ..., "pageid": ...},
    de-duplicated by page id, in the order the API returned them.
    """
    hits: List[Dict] = []
    offset = 0
    remaining = limit
    while remaining > 0:
        page_size = min(remaining, SEARCH_SLOTS)
        payload = _wiki_get(lang, {
            "action": "query",
            "list": "search",
            "srsearch": keyword,
            "srlimit": page_size,
            "sroffset": offset,
        })
        chunk = (payload.get("query") or {}).get("search") or []
        hits += chunk
        remaining -= len(chunk)
        if len(chunk) < page_size:
            break  # short page: no more results
        offset += len(chunk)

    # Keep only title/pageid, first occurrence of each page id wins.
    unique: Dict[int, Dict] = {}
    for hit in hits:
        title = hit.get("title")
        pageid = hit.get("pageid")
        if title and pageid and pageid not in unique:
            unique[pageid] = {"title": title, "pageid": pageid}
    return list(unique.values())[:limit]

def get_langlink(pageid: int, target_lang: str = ZH_WIKI, source_lang: str = EN_WIKI) -> Optional[Dict]:
    """
    Look up the interlanguage link from page `pageid` on `source_lang` Wikipedia to `target_lang`.

    Returns {"lang": ..., "title": ...} for the first link reported, or None when the page
    has no `target_lang` counterpart.
    """
    payload = _wiki_get(source_lang, {
        "action": "query",
        "prop": "langlinks",
        "pageids": pageid,
        "lllang": target_lang,
    })
    pages = (payload.get("query") or {}).get("pages") or []
    links = (pages[0].get("langlinks") or []) if pages else []
    if not links:
        return None
    first = links[0]
    return {"lang": first.get("lang"), "title": first.get("title")}

def get_plain_text(title: str, lang: str) -> str:
    """
    Fetch the plain-text extract (prop=extracts, explaintext=1) of `title` on `lang`.wikipedia.org.

    Extracts of all returned pages are joined with newlines, and pages without an extract
    are ignored. Returns "" when nothing is found.
    """
    payload = _wiki_get(lang, {
        "action": "query",
        "prop": "extracts",
        "explaintext": 1,
        "titles": title,
    })
    pages = (payload.get("query") or {}).get("pages") or []
    extracts = [page.get("extract") for page in pages]
    return "\n".join(e for e in extracts if e).strip()

# ---------------- Text cleaning and sentence splitting ----------------
def clean_text(t: str) -> str:
    """
    Normalise a Wikipedia plain-text extract before sentence splitting.

    - Drops section-heading lines such as "== History ==" / "=== 参考文献 ===". Extracts
      fetched with explaintext contain these by default, and they would otherwise be
      glued onto the next sentence once the splitters join lines.
    - Removes numeric reference markers such as "[12]".
    - Collapses runs of spaces/tabs to one space and runs of newlines to one newline.
    """
    t = re.sub(r'(?m)^[ \t]*={2,}[^=\n]*={2,}[ \t]*$', '', t)
    t = re.sub(r'\[[0-9]+\]', '', t)
    t = re.sub(r'[ \t]+', ' ', t)
    t = re.sub(r'\n{2,}', '\n', t)
    return t.strip()

def sent_split_en(text: str, min_len: Optional[int] = None) -> List[str]:
    """
    Split English text into sentences at '.', '!' or '?' followed by whitespace.

    Newlines are treated as spaces. Sentences shorter than `min_len` characters (after
    stripping) are dropped. `min_len` defaults to the global MIN_SENT_LEN, read at call time.
    """
    threshold = MIN_SENT_LEN if min_len is None else min_len
    text = text.replace('\n', ' ')
    pieces = (s.strip() for s in re.split(r'(?<=[.!?])\s+', text))
    return [s for s in pieces if len(s) >= threshold]

def sent_split_zh(text: str, min_len: Optional[int] = None) -> List[str]:
    """
    Split Chinese text into sentences right after the full-width marks 。！？.

    Newlines are removed (Chinese needs no separator). Sentences shorter than `min_len`
    characters (after stripping) are dropped. `min_len` defaults to the global
    MIN_SENT_LEN, read at call time.
    """
    threshold = MIN_SENT_LEN if min_len is None else min_len
    text = text.replace('\n', '')
    pieces = (s.strip() for s in re.split(r'(?<=[。！？])', text))
    return [s for s in pieces if len(s) >= threshold]

# ---------------- Multilingual sentence-embedding alignment ----------------
def pairwise_align(en_sents: List[str], zh_sents: List[str],
                   model: SentenceTransformer, device: str = "cpu") -> pd.DataFrame:
    """
    Match every English sentence to its most similar Chinese sentence.

    Both sides are embedded with `model` (normalize_embeddings=True). For each English
    sentence, the Chinese sentence with the highest cosine similarity is chosen. The match
    is many-to-one: one Chinese sentence may be picked by several English ones.

    Returns a DataFrame with columns [en_sentence, zh_best, cosine, zh_idx], sorted by
    ascending cosine so the weakest matches (candidate semantic differences) come first.
    If either side is empty, an empty frame with those columns is returned.
    """
    columns = ["en_sentence", "zh_best", "cosine", "zh_idx"]
    if not (en_sents and zh_sents):
        return pd.DataFrame(columns=columns)

    def _embed(sents: List[str]):
        return model.encode(sents, convert_to_tensor=True, device=device, normalize_embeddings=True)

    sim = util.cos_sim(_embed(en_sents), _embed(zh_sents)).cpu().numpy()
    best_idx = sim.argmax(axis=1)
    best_val = sim[np.arange(len(en_sents)), best_idx]
    aligned = pd.DataFrame({
        "en_sentence": en_sents,
        "zh_best": [zh_sents[i] for i in best_idx],
        "cosine": best_val,
        "zh_idx": best_idx,
    })
    return aligned.sort_values("cosine", ascending=True).reset_index(drop=True)


In [None]:
# Load the multilingual encoder once; the analysis cell uses this global `model`.
model = SentenceTransformer(model_name_or_path=EMB_MODEL_NAME, device=DEVICE)
print("Loaded:", EMB_MODEL_NAME, "on", DEVICE)
ensure_dir(OUT_DIR)  # export directory must exist before any CSV is written


In [None]:
def _lookup_zh_link(en_pageid: int, cache: Dict[int, Optional[Dict]]) -> Optional[Dict]:
    """Return the zh langlink of an English page id, querying the API only on a cache miss."""
    if en_pageid not in cache:
        try:
            cache[en_pageid] = get_langlink(en_pageid, target_lang=ZH_WIKI, source_lang=EN_WIKI)
        except Exception as e:
            # Deliberately not cached: a transient failure can be retried if the page shows up again.
            print(f"[Skip] langlink error for pageid {en_pageid}: {e}")
            return None
    return cache[en_pageid]


def _align_page_pair(en_title: str, zh_title: str, embed_model: SentenceTransformer,
                     device: str, top_n: int) -> Optional[Tuple[Dict, pd.DataFrame]]:
    """Fetch, split and align one EN/ZH page pair; return (stats, lowest-cosine rows) or None."""
    try:
        en_text = clean_text(get_plain_text(en_title, EN_WIKI))
        zh_text = clean_text(get_plain_text(zh_title, ZH_WIKI))
    except Exception as e:
        print(f"[Skip] fetch text error for {en_title} -> {zh_title}: {e}")
        return None

    en_sents = sent_split_en(en_text)
    zh_sents = sent_split_zh(zh_text)
    if not en_sents or not zh_sents:
        return None

    df_align = pairwise_align(en_sents, zh_sents, model=embed_model, device=device)
    sim_vals = df_align["cosine"].values
    stats = {
        "n_en_sent": len(en_sents),
        "n_zh_sent": len(zh_sents),
        "sim_mean": float(np.mean(sim_vals)),
        "sim_median": float(np.median(sim_vals)),
        "sim_q25": float(np.quantile(sim_vals, 0.25)),
        "sim_q10": float(np.quantile(sim_vals, 0.10)),
    }
    # df_align is sorted by ascending cosine, so head() keeps the weakest matches.
    return stats, df_align.head(top_n)


def analyze_keywords_by_category(
        keywords_by_category: Dict[str, List[str]],
        max_pages_per_term: int,
        top_diffs_per_page: int,
        device: str = "cpu",
        embed_model: Optional[SentenceTransformer] = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Run the EN↔ZH sentence-alignment pipeline for every term of every category.

    For each term:
        1. Search English Wikipedia (up to `max_pages_per_term` hits).
        2. Follow each hit's Chinese interlanguage link.
        3. Fetch and clean both plain-text extracts.
        4. Split them into sentences and align them with `pairwise_align`.

    Args:
        keywords_by_category: mapping {category: [term, ...]}.
        max_pages_per_term: maximum English search hits examined per term.
        top_diffs_per_page: number of lowest-cosine sentence pairs kept per page pair.
        device: device passed through to `pairwise_align`.
        embed_model: sentence-embedding model; defaults to the notebook-global `model`.

    Returns:
        (diffs_all_df, summary_all_df):
          - diffs_all_df: the lowest-similarity sentence pairs of each page pair (candidate
            semantic differences), tagged with category / term / titles.
          - summary_all_df: one row per (category, page pair) with cosine statistics.

    Notes:
        - Category attribution: a page pair is recorded at most once per category, but
          under *every* category whose terms reach it. Per-category statistics therefore
          do not depend on the order in which categories are processed. The same pair can
          appear under several categories; use drop_duplicates(["en_title", "zh_title"])
          for unique-pair counts.
        - Caching: each pair is fetched and embedded at most once per call, and later hits
          reuse the cached result.
        - Error handling: network errors during search, langlink lookup or text fetch are
          printed and the affected item is skipped, instead of aborting the whole
          (multi-hour) run.
    """
    if embed_model is None:
        embed_model = model  # notebook-global model loaded in an earlier cell

    all_diffs: List[pd.DataFrame] = []
    all_summary: List[Dict] = []
    langlink_cache: Dict[int, Optional[Dict]] = {}
    # (en_title, zh_title) -> (stats, lowest-cosine rows), or None if the pair is unusable.
    pair_cache: Dict[Tuple[str, str], Optional[Tuple[Dict, pd.DataFrame]]] = {}

    for cate, terms in keywords_by_category.items():
        print(f"=== Category: {cate} | {len(terms)} terms ===")
        seen_in_category = set()  # (en_title, zh_title) pairs already recorded for `cate`
        for term in tqdm(terms, desc=f"[{cate}] terms"):
            try:
                search_res = mediawiki_search(term, EN_WIKI, limit=max_pages_per_term)
            except Exception as e:
                print(f"[Skip] search error for term {term!r}: {e}")
                continue

            for item in search_res:
                en_title, en_pageid = item["title"], item["pageid"]
                zh_link = _lookup_zh_link(en_pageid, langlink_cache)
                if not zh_link:
                    continue
                zh_title = zh_link["title"]
                key = (en_title, zh_title)
                if key in seen_in_category:
                    continue
                seen_in_category.add(key)

                if key not in pair_cache:
                    pair_cache[key] = _align_page_pair(en_title, zh_title, embed_model,
                                                       device, top_diffs_per_page)
                    # Soft rate limit on top of the per-request delay in _wiki_get.
                    time.sleep(0.06)
                result = pair_cache[key]
                if result is None:
                    continue
                stats, lowest = result

                all_summary.append({
                    "category": cate,
                    "term": term,
                    "en_title": en_title,
                    "zh_title": zh_title,
                    **stats,
                })

                df_topdiff = lowest.copy()
                df_topdiff.insert(0, "category", cate)
                df_topdiff.insert(1, "term", term)
                df_topdiff.insert(2, "en_title", en_title)
                df_topdiff.insert(3, "zh_title", zh_title)
                all_diffs.append(df_topdiff)

    diffs_all_df = (pd.concat(all_diffs, ignore_index=True)
                    if all_diffs else
                    pd.DataFrame(columns=["category","term","en_title","zh_title","en_sentence","zh_best","cosine","zh_idx"]))
    summary_all_df = pd.DataFrame(all_summary, columns=[
        "category","term","en_title","zh_title","n_en_sent","n_zh_sent",
        "sim_mean","sim_median","sim_q25","sim_q10"
    ])
    return diffs_all_df, summary_all_df


In [None]:
# Run the sentence-level pipeline over every category/term (network-bound and slow:
# the first term of "race" alone took ~113 s in the captured run).
diffs_all, summary_all = analyze_keywords_by_category(
    keywords_by_category=KEYWORDS_BY_CATEGORY,
    max_pages_per_term=MAX_PAGES_PER_TERM,
    top_diffs_per_page=TOP_DIFFS_PER_PAGE,
    device=DEVICE,
)

print("完成：页面对数量 =", summary_all.shape[0])
display(summary_all.head(20), diffs_all.head(20))


[race] terms:   2%|▏         | 1/53 [01:53<1:38:07, 113.23s/it]

In [None]:
ensure_dir(OUT_DIR)

# Combined exports.
sum_path = os.path.join(OUT_DIR, "summary_all.csv")
dif_path = os.path.join(OUT_DIR, "diffs_all.csv")
for frame, path in ((summary_all, sum_path), (diffs_all, dif_path)):
    frame.to_csv(path, index=False, encoding="utf-8")
print("Saved:", sum_path, "and", dif_path)

# One CSV per category and table so reviewers can split the work;
# a category with no rows in a table produces no file for that table.
for cate in sorted(KEYWORDS_BY_CATEGORY):
    per_cat = (
        (summary_all, f"summary__{safe_name(cate)}.csv"),
        (diffs_all, f"diffs__{safe_name(cate)}.csv"),
    )
    for frame, fname in per_cat:
        part = frame[frame["category"] == cate].reset_index(drop=True)
        if not part.empty:
            part.to_csv(os.path.join(OUT_DIR, fname), index=False, encoding="utf-8")

print("Per-category CSVs saved in:", OUT_DIR)


In [None]:
TOP_K_GLOBAL = 300
SIM_THRESHOLD = 0.5  # lower cosine => more likely a semantic difference / misalignment

# NOTE: diffs_all only holds the TOP_DIFFS_PER_PAGE lowest-cosine pairs of each page,
# so both exports below rank/filter that subset, not every aligned sentence.
if diffs_all.empty:
    print("diffs_all is empty.")
else:
    ranked = diffs_all.sort_values("cosine")
    topk = ranked.head(TOP_K_GLOBAL).reset_index(drop=True)
    thr  = diffs_all[diffs_all["cosine"] <= SIM_THRESHOLD].sort_values("cosine").reset_index(drop=True)

    topk_path = os.path.join(OUT_DIR, f"global_top{TOP_K_GLOBAL}.csv")
    thr_path  = os.path.join(OUT_DIR, f"global_thr{SIM_THRESHOLD}.csv")
    for frame, path in ((topk, topk_path), (thr, thr_path)):
        frame.to_csv(path, index=False, encoding="utf-8")

    print("Saved:", topk_path, "and", thr_path)


In [None]:
import shutil
import os

# Zip the output directory into the current working directory.
# make_archive returns the path of the archive it wrote.
output_filename = 'wiki_sem_diff_outputs.zip'
archive_path = shutil.make_archive(output_filename.replace('.zip', ''), 'zip', OUT_DIR)

# Trigger a browser download on Colab; elsewhere just report where the archive is.
try:
    from google.colab import files
except ImportError:
    print("Archive written to:", archive_path)
else:
    files.download(archive_path)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

# ---------- Load data: current directory first, then the default output folder ----------
_candidates = (
    ("summary_all.csv", "diffs_all.csv"),
    ("wiki_sem_diff_outputs/summary_all.csv", "wiki_sem_diff_outputs/diffs_all.csv"),
)
for _sum_csv, _dif_csv in _candidates:
    if os.path.exists(_sum_csv) and os.path.exists(_dif_csv):
        summary_df = pd.read_csv(_sum_csv)
        diffs_df  = pd.read_csv(_dif_csv)
        break
else:
    raise FileNotFoundError("summary_all.csv and diffs_all.csv not found in current dir or wiki_sem_diff_outputs/")

# ---------- 1) Category average similarity (bar) ----------
cat_stats = summary_df.groupby("category")["sim_mean"].mean().sort_values()
fig, ax = plt.subplots(figsize=(8, 5))
x = np.arange(len(cat_stats))
ax.bar(x, cat_stats.values)
ax.set_xticks(x)
ax.set_xticklabels(cat_stats.index, rotation=45, ha="right")
ax.set_title("Average semantic similarity by category")
ax.set_ylabel("Mean cosine similarity")
fig.tight_layout()  # keep rotated category labels inside the figure
plt.show()

# ---------- 2) Page-level mean similarity distribution (hist) ----------
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(summary_df["sim_mean"].dropna(), bins=30)
ax.set_title("Distribution of mean similarities across page pairs")
ax.set_xlabel("Mean cosine similarity")
ax.set_ylabel("Count")
fig.tight_layout()
plt.show()

# ---------- 3) Category boxplot of mean similarity ----------
ordered_cats = list(cat_stats.index)
data_by_cat = [summary_df.loc[summary_df["category"] == c, "sim_mean"].dropna() for c in ordered_cats]
fig, ax = plt.subplots(figsize=(9, 5))
# Tick labels are set explicitly instead of via boxplot(labels=...): that keyword was
# renamed to `tick_labels` in Matplotlib 3.9 and the old name is deprecated.
ax.boxplot(data_by_cat, showfliers=False)
ax.set_xticks(range(1, len(ordered_cats) + 1))  # boxes sit at positions 1..n by default
ax.set_xticklabels(ordered_cats, rotation=30, ha="right")
ax.set_title("Mean similarity by category (boxplot)")
ax.set_ylabel("Mean cosine similarity")
fig.tight_layout()
plt.show()

# ---------- 4) Top-10 lowest-similarity pages (barh, EN titles only) ----------
worst_pages = summary_df.nsmallest(10, "sim_mean")
labels_en = worst_pages["en_title"].tolist()  # English titles only; CJK glyphs may be missing from the font
fig, ax = plt.subplots(figsize=(10, 6))
y = np.arange(len(worst_pages))
ax.barh(y, worst_pages["sim_mean"].values)
ax.set_yticks(y)
ax.set_yticklabels(labels_en)
ax.set_xlabel("Mean cosine similarity")
ax.set_title("Top-10 pages with lowest English–Chinese semantic similarity")
ax.invert_yaxis()  # lowest similarity at the top
fig.tight_layout()
plt.show()

# ---------- 5) Sentence-level cosine distribution for low-sim pairs ----------
# diffs_df only contains the lowest-cosine pairs exported per page, not every aligned pair.
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(diffs_df["cosine"].dropna(), bins=30)
ax.set_title("Distribution of cosine similarity (sentence-level mismatches)")
ax.set_xlabel("Cosine similarity")
ax.set_ylabel("Count")
fig.tight_layout()
plt.show()


In [None]:
import pandas as pd
import os

# Auto-detect the CSV location: current directory first, then the default output folder.
# Both files must be present in the same place before a location is used.
if os.path.exists("summary_all.csv") and os.path.exists("diffs_all.csv"):
    summary_df = pd.read_csv("summary_all.csv")
    diffs_df = pd.read_csv("diffs_all.csv")
elif (os.path.exists("wiki_sem_diff_outputs/summary_all.csv")
      and os.path.exists("wiki_sem_diff_outputs/diffs_all.csv")):
    summary_df = pd.read_csv("wiki_sem_diff_outputs/summary_all.csv")
    diffs_df = pd.read_csv("wiki_sem_diff_outputs/diffs_all.csv")
else:
    raise FileNotFoundError("summary_all.csv 和 diffs_all.csv 未找到，请确认文件路径")

# Threshold below which a sentence pair counts as "low similarity".
LOW_THRESHOLD = 0.5

if summary_df.empty:
    raise ValueError("summary_df is empty - nothing to summarise; run the analysis and export cells first.")

# 1. Counts.
#    pairwise_align returns one row per English sentence, so the number of aligned pairs
#    is the sum of n_en_sent. diffs_df only holds the lowest-cosine pairs exported for
#    each page, i.e. a subset biased towards poor matches.
total_pages = len(summary_df)
total_aligned = int(summary_df["n_en_sent"].sum())
total_sent_pairs = len(diffs_df)

# 2. Low-similarity pairs within the exported subset.
low_quality_count = int((diffs_df["cosine"] < LOW_THRESHOLD).sum())
low_quality_pct = low_quality_count / total_sent_pairs * 100 if total_sent_pairs else 0

# 3. Page-level extremes by median sentence similarity (summary_df's sim_median column).
worst_page = summary_df.nsmallest(1, "sim_median").iloc[0]
best_page = summary_df.nlargest(1, "sim_median").iloc[0]

# 4. Category level: median of the page medians, sorted ascending (first = largest gap).
cat_median_stats = summary_df.groupby("category")["sim_median"].median().sort_values()
worst_cat = cat_median_stats.index[0]
best_cat = cat_median_stats.index[-1]

# 5. Report.
print("===== Summary Statistics (Using Median Similarity) =====")
print(f"Total English–Chinese page pairs analyzed         : {total_pages}")
print(f"Total sentence pairs aligned                      : {total_aligned}")
print(f"Lowest-similarity pairs exported (per page)       : {total_sent_pairs}")
print(f"Low-quality exported pairs (cosine < {LOW_THRESHOLD}) : {low_quality_count} ({low_quality_pct:.2f}%)")

print("\n--- Page-level extremes (by median similarity) ---")
print(f"Lowest similarity page : {worst_page['en_title']} ↔ {worst_page['zh_title']} "
      f"(median={worst_page['sim_median']:.4f})")
print(f"Highest similarity page: {best_page['en_title']} ↔ {best_page['zh_title']} "
      f"(median={best_page['sim_median']:.4f})")

print("\n--- Category-level median similarity ---")
print(cat_median_stats.round(4))
# Positional access via .iloc: plain [0] / [-1] on a string-indexed Series relies on a
# positional fallback that pandas has deprecated.
print(f"\nCategory with largest semantic gap : {worst_cat} (median={cat_median_stats.iloc[0]:.4f})")
print(f"Category with least semantic gap  : {best_cat} (median={cat_median_stats.iloc[-1]:.4f})")


In [None]:
# Dependencies for the document-level comparison section. This mostly repeats the install
# above, adding requests and scipy.
# NOTE(review): scipy is not imported in any visible cell — confirm it is needed.
!pip -q install sentence-transformers pandas numpy requests tqdm scipy


In [None]:
import os

# Input: page list (en_title, zh_title, ...) produced by the sentence-level run.
# Prefer a copy in the current directory, otherwise fall back to the default output folder.
SUMMARY_PATH = (
    "summary_all.csv"
    if os.path.exists("summary_all.csv")
    else os.path.join("wiki_sem_diff_outputs", "summary_all.csv")
)

# Model / chunking configuration. LaBSE is intended for cross-lingual whole-document
# comparison here; "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" is faster.
EMB_MODEL = "sentence-transformers/LaBSE"
CHUNK_STRATEGY = "sentence"   # "sentence" or "char"
SENT_PER_CHUNK = 15           # sentences per chunk for the "sentence" strategy (10-25 suggested)
CHARS_PER_CHUNK = 1800        # max characters per chunk for the "char" strategy (1500-2500 suggested)

MIN_SENT_LEN = 8              # minimum sentence length in characters (noise filter)
MAX_PAGES = None              # None = all pages; e.g. 200 for a quick test run
SAVE_CSV = True               # write the result CSV?
OUT_CSV = "doclevel_similarity.csv"  # output file name


In [None]:
import re, time, random, requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Tuple, Optional
from requests.adapters import HTTPAdapter, Retry
from sentence_transformers import SentenceTransformer, util
import torch

# ---------- Device selection: GPU when available, otherwise CPU ----------
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print("Using device:", DEVICE)

# ---------- Wikimedia API: descriptive User-Agent + retries + soft rate limit ----------
WIKI_HEADERS = {
    "User-Agent": "WikiPageDocSim/0.1 (research; contact: your_email@example.com)",
    "Accept-Language": "en-US,zh-CN;q=0.9",
}
# Exponential-backoff retries on throttling / server errors, GET requests only.
retries = Retry(
    total=5,
    backoff_factor=1.2,
    status_forcelist=[403, 429, 500, 502, 503, 504],
    allowed_methods=["GET"],
)
session = requests.Session()
for _scheme in ("https://", "http://"):
    session.mount(_scheme, HTTPAdapter(max_retries=retries))

REQUEST_DELAY_BASE = 0.20    # base delay before each request (seconds)
REQUEST_DELAY_JITTER = 0.20  # extra random delay in [0, 0.2) seconds

def _wiki_get(lang: str, params: dict, timeout: int = 30) -> dict:
    """
    GET the MediaWiki Action API of `lang`.wikipedia.org and return the decoded JSON.

    Sleeps REQUEST_DELAY_BASE plus up to REQUEST_DELAY_JITTER seconds before each request.

    Error handling:
        - 403/429/5xx responses are retried inside `session.get` by the Retry policy
          mounted on `session`.
        - With urllib3's default raise_on_status=True, exhausting those retries raises
          `requests.exceptions.RetryError`.
        - Other error statuses surface through `raise_for_status()`.
    """
    merged = {"origin": "*", "format": "json", "formatversion": 2, **params}
    time.sleep(REQUEST_DELAY_BASE + random.random() * REQUEST_DELAY_JITTER)
    # No extra sleep on 403/429 here: those codes are in the Retry status_forcelist,
    # so a response carrying them never gets past session.get().
    r = session.get(f"https://{lang}.wikipedia.org/w/api.php",
                    params=merged, headers=WIKI_HEADERS, timeout=timeout)
    r.raise_for_status()
    return r.json()

def get_plain_text(title: str, lang: str) -> str:
    """
    Return the plain-text extract of `title` on `lang`.wikipedia.org, or "" if there is none.

    Depending on the article this may cover the lead section and the body. Extracts of all
    returned pages are newline-joined and the result is stripped.
    """
    payload = _wiki_get(lang, {"action": "query", "prop": "extracts", "explaintext": 1, "titles": title})
    pages = (payload.get("query") or {}).get("pages") or []
    return "\n".join(p["extract"] for p in pages if p.get("extract")).strip()

# ---------- Text preprocessing ----------
def clean_text(t: str) -> str:
    """
    Normalise a Wikipedia plain-text extract before chunking / sentence splitting.

    - Drops section-heading lines such as "== History ==" / "=== 参考文献 ===". Extracts
      fetched with explaintext contain these by default, and they would otherwise be
      glued onto the next sentence once the splitters join lines.
    - Removes numeric reference markers such as "[1]".
    - Collapses runs of spaces/tabs to one space and runs of newlines to one newline.
    """
    t = re.sub(r'(?m)^[ \t]*={2,}[^=\n]*={2,}[ \t]*$', '', t)
    t = re.sub(r'\[[0-9]+\]', '', t)
    t = re.sub(r'[ \t]+', ' ', t)
    t = re.sub(r'\n{2,}', '\n', t)
    return t.strip()

def sent_split_en(text: str, min_len: Optional[int] = None) -> List[str]:
    """
    Split English text into sentences at '.', '!' or '?' followed by whitespace.

    Newlines are treated as spaces. Sentences shorter than `min_len` characters (after
    stripping) are dropped. `min_len` defaults to the global MIN_SENT_LEN, read at call time.
    """
    threshold = MIN_SENT_LEN if min_len is None else min_len
    text = text.replace('\n', ' ')
    pieces = (s.strip() for s in re.split(r'(?<=[.!?])\s+', text))
    return [s for s in pieces if len(s) >= threshold]

def sent_split_zh(text: str, min_len: Optional[int] = None) -> List[str]:
    """
    Split Chinese text into sentences right after the full-width marks 。！？.

    Newlines are removed. Sentences shorter than `min_len` characters (after stripping) are
    dropped. `min_len` defaults to the global MIN_SENT_LEN, read at call time.
    """
    threshold = MIN_SENT_LEN if min_len is None else min_len
    text = text.replace('\n', '')
    pieces = (s.strip() for s in re.split(r'(?<=[。！？])', text))
    return [s for s in pieces if len(s) >= threshold]

# ---------- Chunking strategies ----------
def chunk_by_sentence(sents: List[str], sent_per_chunk: int = 15) -> List[str]:
    """
    Group consecutive sentences into space-joined chunks of `sent_per_chunk` sentences.

    The last chunk holds the remainder. Values below 1 behave like 1 (one sentence per chunk).
    """
    step = max(sent_per_chunk, 1)
    return [" ".join(sents[i:i + step]) for i in range(0, len(sents), step)]

def chunk_by_char(text: str, max_chars: int = 1800, min_len: Optional[int] = None) -> List[str]:
    """
    Split `text` into chunks of at most `max_chars` characters, preferring sentence ends.

    Each window of `max_chars` characters is cut just after its last sentence-ending mark
    (full-width 。！？ or ASCII . ! ?), provided the piece up to that mark is at least
    `min_len` characters long. Otherwise the window is cut at `max_chars`. Chunks shorter
    than `min_len` (after stripping) are dropped. `min_len` defaults to the global
    MIN_SENT_LEN.

    Raises:
        ValueError: if `max_chars` < 1 (the window would never advance).
    """
    threshold = MIN_SENT_LEN if min_len is None else min_len
    if max_chars < 1:
        raise ValueError("max_chars must be >= 1")
    text = text.strip()
    if not text:
        return []
    enders = "。！？.!?"  # same full-width marks as sent_split_zh, plus ASCII ones
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + max_chars, len(text))
        window = text[start:end]
        cut = max(window.rfind(ch) for ch in enders)
        # Cut early only if the resulting piece can survive the min-length filter;
        # otherwise its text would be silently discarded.
        if cut != -1 and cut + 1 >= threshold and start + cut + 1 < end:
            end = start + cut + 1
        chunks.append(text[start:end].strip())
        start = end
    return [c for c in chunks if len(c) >= threshold]

def build_doc_chunks(text: str, lang: str) -> List[str]:
    """
    Split a document into embedding chunks according to the global CHUNK_STRATEGY.

    "sentence": sentence-split (English rules when lang == "en", Chinese rules otherwise)
    and group SENT_PER_CHUNK sentences per chunk. Any other value: character windows of at
    most CHARS_PER_CHUNK characters.
    """
    if CHUNK_STRATEGY != "sentence":
        return chunk_by_char(text, CHARS_PER_CHUNK)
    splitter = sent_split_en if lang == "en" else sent_split_zh
    return chunk_by_sentence(splitter(text), SENT_PER_CHUNK)

# ---------- Document-level embedding ----------
def doc_embedding(text: str, lang: str, model: SentenceTransformer) -> Tuple[np.ndarray, int]:
    """
    Embed a whole document as the mean of its chunk embeddings.

    Chunks come from `build_doc_chunks`. Each chunk is encoded on DEVICE with
    normalize_embeddings=True, and the vectors are averaged; the mean itself is not
    re-normalised.

    Returns:
        (doc_vec, n_chunks): the averaged vector as a numpy array and the number of chunks.
        When no chunks are produced, doc_vec is an all-zero float32 vector of the model's
        embedding dimension and n_chunks is 0.
    """
    chunks = build_doc_chunks(text, lang)
    n_chunks = len(chunks)
    if not n_chunks:
        dim = model.get_sentence_embedding_dimension()
        return np.zeros((dim,), dtype=np.float32), 0
    chunk_vecs = model.encode(chunks, convert_to_tensor=True, device=DEVICE, normalize_embeddings=True)
    return chunk_vecs.mean(dim=0).detach().cpu().numpy(), n_chunks


In [None]:
# Load the page list produced by the sentence-level run (optionally truncated for tests).
summary_df = pd.read_csv(SUMMARY_PATH)
if MAX_PAGES is not None:
    summary_df = summary_df.head(MAX_PAGES).copy()

# Required input columns.
need_cols = ["category", "term", "en_title", "zh_title"]
for c in need_cols:
    if c not in summary_df.columns:
        raise ValueError(f"Input CSV缺少字段: {c}")

# Load the document-level model.
model = SentenceTransformer(EMB_MODEL, device=DEVICE)
print("Loaded model:", EMB_MODEL, "on", DEVICE)

# The page list can contain the same (en_title, zh_title) pair on several rows
# (e.g. under different terms or categories); fetch and embed each pair only once.
pair_cache = {}

# Whole-document vectors and their cosine similarity, one row per input row.
rows = []
for _, r in tqdm(summary_df.iterrows(), total=len(summary_df), desc="Computing doc-level similarity"):
    en_title = r["en_title"]
    zh_title = r["zh_title"]
    key = (en_title, zh_title)

    if key not in pair_cache:
        try:
            en_text = clean_text(get_plain_text(en_title, "en"))
            zh_text = clean_text(get_plain_text(zh_title, "zh"))
        except Exception as e:
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f"[Skip] fetch error for {en_title} -> {zh_title}: {e}")
            pair_cache[key] = None
            continue

        en_vec, en_chunks = doc_embedding(en_text, "en", model)
        zh_vec, zh_chunks = doc_embedding(zh_text, "zh", model)

        # Zero norm (e.g. one side produced no chunks) -> similarity undefined.
        en_norm, zh_norm = np.linalg.norm(en_vec), np.linalg.norm(zh_vec)
        if en_norm == 0 or zh_norm == 0:
            doc_sim = np.nan
        else:
            doc_sim = float(np.dot(en_vec, zh_vec) / (en_norm * zh_norm))

        pair_cache[key] = {
            "len_en_chars": len(en_text),
            "len_zh_chars": len(zh_text),
            "n_en_chunks": en_chunks,
            "n_zh_chunks": zh_chunks,
            "doc_similarity": doc_sim,
        }

    metrics = pair_cache[key]
    if metrics is None:
        continue
    rows.append({
        "category": r["category"],
        "term": r["term"],
        "en_title": en_title,
        "zh_title": zh_title,
        **metrics,
    })

docsim_df = pd.DataFrame(rows)
print("Done. Page pairs computed:", len(docsim_df))
display(docsim_df.head(10))

if SAVE_CSV and not docsim_df.empty:
    docsim_df.to_csv(OUT_CSV, index=False, encoding="utf-8")
    print("Saved:", OUT_CSV)


In [None]:
# Overall and per-category statistics of the document-level similarity.
if not docsim_df.empty:
    valid = docsim_df.dropna(subset=["doc_similarity"])
    print("Total page pairs:", len(docsim_df))
    print("Valid doc similarities:", len(valid))
    if not valid.empty:
        sims = valid["doc_similarity"]
        for label, value in (
            ("Doc-sim median:", sims.median()),
            ("Doc-sim mean  :", sims.mean()),
            ("Doc-sim min   :", sims.min()),
            ("Doc-sim max   :", sims.max()),
        ):
            print(label, round(value, 4))

        cat_median = valid.groupby("category")["doc_similarity"].median().sort_values()
        print("\nPer-category median doc similarity (ascending):")
        print(cat_median.round(4))
