In [1]:
import os
import re
import gc
import json
import time
import random
import pickle
from pathlib import Path
from collections import Counter, defaultdict

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'STFangsong'

if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

MODEL_DIR = "../model"
RESULT_DIR = "../result"
DATA_DIR = "../data"

Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)
Path(RESULT_DIR).mkdir(parents=True, exist_ok=True)
Path(DATA_DIR).mkdir(parents=True, exist_ok=True)

Using device: mps


In [None]:
import jieba
RAW_WIKI_DIR = Path(DATA_DIR) / "output"
TOKENIZED_PATH = Path(DATA_DIR) / "tokenized.txt"
VOCAB_PATH = Path(MODEL_DIR) / "vocab.pkl"          # word2id, id2word, counts
COOC_NPZ = Path(MODEL_DIR) / "cooc.npz"             # co-occurrence triples (i, j, x_ij)
VECTORS_TXT = Path(MODEL_DIR) / "vectors.txt"       # exported word + vectors
EMB_TARGET_PT = Path(MODEL_DIR) / "glove_W.pt"
EMB_CONTEXT_PT = Path(MODEL_DIR) / "glove_C.pt"
EMB_MERGED_PT = Path(MODEL_DIR) / "glove_merged.pt"

# Corpus / preprocessing
SPLIT_SENT = True
MAX_DOCS = None
MAX_LINES_PER_FILE = None

# Tokenization / vocab
MIN_COUNT = 5
MAX_VOCAB = 400_000
KEEP_DIGITS = False
LOWER_CASE = False
RESERVED_TOKENS = ["<unk>"]

# Co-occurrence
WINDOW_SIZE = 10
SYMMETRIC = True
WEIGHT_BY_DISTANCE = True

# GloVe training
EMBED_DIM = 300
BATCH_SIZE = 131072
EPOCHS = 25
LR = 0.05
X_MAX = 100
ALPHA = 0.75
SEED = 2025

matplotlib.rcParams["figure.dpi"] = 120

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def save_fig(path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path)
    print(f"Saved figure to: {path}")


def normalize_text(txt: str) -> str:
    """Basic cleaning for Chinese Wikipedia text."""
    if txt is None:
        return ""
    if LOWER_CASE:
        txt = txt.lower()
    txt = re.sub(r"[\x00-\x08\x0B-\x1F\x7F]", " ", txt)
    return txt


def sentence_split(txt: str):
    if not SPLIT_SENT:
        return [txt]
    sents = re.split(r"[。！？!?；;\n\r]+", txt)
    sents = [s.strip() for s in sents if s.strip()]
    return sents


cc = None

print("Setup done.")

Setup done.


In [3]:
def iter_wiki_json_lines(raw_dir: Path, max_docs=None, max_lines_per_file=None):
    """Yield raw text from WikiExtractor JSONL files under raw_dir."""
    count_docs = 0
    for root, _dirs, files in os.walk(raw_dir):
        files = sorted(files)
        for fname in files:
            if not fname.startswith("wiki_"):
                continue
            fpath = Path(root) / fname
            with open(fpath, "r", encoding="utf-8") as f:
                for i, line in enumerate(f):
                    if max_lines_per_file is not None and i >= max_lines_per_file:
                        break
                    try:
                        obj = json.loads(line)
                    except Exception:
                        continue
                    text = obj.get("text", "") or ""
                    title = obj.get("title", "") or ""
                    if title:
                        text = title + "。" + text
                    yield text
                    count_docs += 1
                    if max_docs is not None and count_docs >= max_docs:
                        return


def tokenize_to_file(raw_dir: Path, out_path: Path, keep_digits=False, cc_converter=None, max_docs=None, max_lines_per_file=None):
    """Stream JSONL -> tokenize -> write one sentence per line (space-separated tokens)."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    total_tokens = 0
    total_sents = 0
    total_docs = 0

    with open(out_path, "w", encoding="utf-8") as out_f:
        pbar = tqdm(iter_wiki_json_lines(raw_dir, max_docs, max_lines_per_file), desc="Tokenizing docs")
        for raw in pbar:
            total_docs += 1
            txt = normalize_text(raw)
            if cc_converter:
                txt = cc_converter.convert(txt)
            sents = sentence_split(txt)
            for s in sents:
                toks = [t.strip() for t in jieba.lcut(s) if t.strip()]
                if not keep_digits:
                    toks = [t for t in toks if not t.isdigit()]
                if len(toks) == 0:
                    continue
                out_f.write(" ".join(toks) + "\n")
                total_tokens += len(toks)
                total_sents += 1
            if total_docs % 5000 == 0:
                pbar.set_postfix({"docs": total_docs, "sents": total_sents, "tokens": total_tokens})

    print(f"Tokenization finished. Docs={total_docs:,}, Sents={total_sents:,}, Tokens={total_tokens:,}")
    return dict(docs=total_docs, sents=total_sents, tokens=total_tokens)


stats = tokenize_to_file(RAW_WIKI_DIR, TOKENIZED_PATH, keep_digits=KEEP_DIGITS, cc_converter=cc,
                         max_docs=MAX_DOCS, max_lines_per_file=MAX_LINES_PER_FILE)

Tokenizing docs: 0it [00:00, ?it/s]Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 0.413 seconds.
Prefix dict has been built successfully.
Tokenizing docs: 1907228it [21:04, 1508.21it/s, docs=1.9e+6, sents=1.42e+7, tokens=2.38e+8]  

Tokenization finished. Docs=1,907,228, Sents=14,209,318, Tokens=237,753,400





In [3]:
def build_vocab_from_tokenized(path: Path, min_count=5, max_vocab=400_000, reserved_tokens=None):
    """Count tokens -> prune -> build word2id/id2word."""
    cnt = Counter()
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Counting vocab"):
            toks = line.strip().split()
            cnt.update(toks)

    items = [(w, c) for w, c in cnt.items() if c >= min_count]
    items.sort(key=lambda x: x[1], reverse=True)

    if max_vocab is not None:
        items = items[:max_vocab]

    word2id = {}
    id2word = []
    counts = []

    if reserved_tokens:
        for t in reserved_tokens:
            word2id[t] = len(id2word)
            id2word.append(t)
            counts.append(0)

    for w, c in items:
        if w in word2id:
            continue
        word2id[w] = len(id2word)
        id2word.append(w)
        counts.append(c)

    vocab = {"word2id": word2id, "id2word": id2word, "counts": np.array(counts, dtype=np.int64)}
    with open(VOCAB_PATH, "wb") as f:
        pickle.dump(vocab, f)
    print(f"Vocab built. Size={len(id2word):,} (min_count={min_count}, max_vocab={max_vocab})")
    return vocab, cnt


vocab, full_counter = build_vocab_from_tokenized(TOKENIZED_PATH, MIN_COUNT, MAX_VOCAB, RESERVED_TOKENS)

freqs = np.array(sorted(full_counter.values(), reverse=True), dtype=np.float64)
ranks = np.arange(1, len(freqs)+1, dtype=np.float64)

plt.figure()
plt.loglog(ranks, freqs)
plt.xlabel("Rank")
plt.ylabel("Frequency")
plt.title("Token Frequency (Zipf-like)")
save_fig(Path(RESULT_DIR) / "zipf_curve.png")
plt.close()

Counting vocab: 14209318it [00:31, 457434.62it/s]


Vocab built. Size=400,001 (min_count=5, max_vocab=400000)
Saved figure to: ../result/zipf_curve.png


In [3]:
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)
word2id = vocab["word2id"]
id2word = vocab["id2word"]
counts = vocab["counts"]
vocab_size = len(id2word)
unk_id = word2id.get("<unk>", None)


def line_to_ids(line: str):
    ids = []
    for t in line.strip().split():
        idx = word2id.get(t, unk_id)
        if idx is not None:
            ids.append(idx)
    return ids


def build_cooccurrence(tokenized_path: Path, window_size=10, symmetric=True, weight_by_distance=True):
    """
    Build co-occurrence dict: (i, j) -> x_ij.
    WARNING: This can be large for full Wikipedia; ensure enough RAM.
    """
    cooc = defaultdict(float)
    with open(tokenized_path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Building co-occurrence"):
            ids = line_to_ids(line)
            n = len(ids)
            for center in range(n):
                w_i = ids[center]
                start = max(0, center - window_size)
                end = min(n, center + window_size + 1)
                for ctx in range(start, end):
                    if ctx == center:
                        continue
                    w_j = ids[ctx]
                    dist = abs(ctx - center)
                    if dist == 0:
                        continue
                    weight = (1.0 / dist) if weight_by_distance else 1.0
                    cooc[(w_i, w_j)] += weight
                    if symmetric:
                        cooc[(w_j, w_i)] += weight
    if len(cooc) == 0:
        raise RuntimeError("Empty co-occurrence matrix. Check preprocessing.")
    i_idx = np.fromiter((k[0] for k in cooc.keys()), dtype=np.int32, count=len(cooc))
    j_idx = np.fromiter((k[1] for k in cooc.keys()), dtype=np.int32, count=len(cooc))
    x_val = np.fromiter(cooc.values(), dtype=np.float32, count=len(cooc))
    np.savez_compressed(COOC_NPZ, i=i_idx, j=j_idx, x=x_val, vocab_size=np.int32(vocab_size))
    print(f"Co-occurrence saved: {COOC_NPZ}, nnz={len(cooc):,}")
    cooc.clear()
    gc.collect()

In [None]:
build_cooccurrence(TOKENIZED_PATH, WINDOW_SIZE, SYMMETRIC, WEIGHT_BY_DISTANCE)

Building co-occurrence: 14209318it [44:14, 5353.49it/s] 


Co-occurrence saved: ../model/cooc.npz, nnz=415,615,849


In [4]:
class CoocDataset(Dataset):
    def __init__(self, npz_path: Path):
        dat = np.load(npz_path)
        self.i = dat["i"].astype(np.int64)
        self.j = dat["j"].astype(np.int64)
        self.x = dat["x"].astype(np.float32)
        self.vocab_size = int(dat["vocab_size"])
        assert len(self.i) == len(self.j) == len(self.x)
        print(f"Loaded cooc triples: {len(self.x):,}, vocab_size={self.vocab_size:,}")

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.i[idx], self.j[idx], self.x[idx]


class GloVeModel(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.w = nn.Embedding(vocab_size, dim)   # target
        self.c = nn.Embedding(vocab_size, dim)   # context
        self.bw = nn.Embedding(vocab_size, 1)    # bias for w
        self.bc = nn.Embedding(vocab_size, 1)    # bias for c
        init_range = 0.5 / dim
        nn.init.uniform_(self.w.weight, a=-init_range, b=init_range)
        nn.init.uniform_(self.c.weight, a=-init_range, b=init_range)
        nn.init.zeros_(self.bw.weight)
        nn.init.zeros_(self.bc.weight)

    def forward(self, i_idx, j_idx):
        w_i = self.w(i_idx)          # (B, D)
        c_j = self.c(j_idx)          # (B, D)
        bw_i = self.bw(i_idx).squeeze(-1)  # (B,)
        bc_j = self.bc(j_idx).squeeze(-1)  # (B,)
        dot = (w_i * c_j).sum(dim=1)       # (B,)
        return dot + bw_i + bc_j


def glove_loss(pred, x_ij, x_max=100.0, alpha=0.75):
    """
    pred = w_i^T c_j + b_i + b_j
    target = log(x_ij)
    weight = (x_ij / x_max)^alpha if x_ij < x_max else 1
    """
    log_x = torch.log(x_ij)
    w = torch.where(x_ij < x_max, (x_ij / x_max) ** alpha, torch.ones_like(x_ij))
    loss = w * (pred - log_x) ** 2
    return loss.mean()

In [5]:
dataset = CoocDataset(COOC_NPZ)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=0)

model = GloVeModel(dataset.vocab_size, EMBED_DIM).to(device)
optimizer = torch.optim.Adagrad(model.parameters(), lr=LR)

Loaded cooc triples: 415,615,849, vocab_size=400,001


In [None]:
epoch_losses = []

for epoch in range(1, EPOCHS + 1):
    model.train()
    running = 0.0
    n_batches = 0
    t0 = time.time()
    pbar = tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for (i_idx, j_idx, x_ij) in pbar:
        i_idx = i_idx.to(device)
        j_idx = j_idx.to(device)
        x_ij = x_ij.to(device)

        optimizer.zero_grad(set_to_none=True)
        pred = model(i_idx, j_idx)
        loss = glove_loss(pred, x_ij, X_MAX, ALPHA)
        loss.backward()
        optimizer.step()

        running += loss.item()
        n_batches += 1
        if n_batches % 100 == 0:
            pbar.set_postfix({"loss": f"{running/n_batches:.4f}"})

    epoch_loss = running / max(1, n_batches)
    epoch_losses.append(epoch_loss)
    dt = time.time() - t0
    print(f"Epoch {epoch}: loss={epoch_loss:.6f}, time={dt:.1f}s")

torch.save(model.w.weight.detach().cpu(), EMB_TARGET_PT)
torch.save(model.c.weight.detach().cpu(), EMB_CONTEXT_PT)
torch.save(model.state_dict(), Path(MODEL_DIR) / "glove_model_state.pt")
print(f"Saved embeddings to: {EMB_TARGET_PT}, {EMB_CONTEXT_PT}")

plt.figure()
plt.plot(np.arange(1, len(epoch_losses)+1), epoch_losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("GloVe Training Loss")
save_fig(Path(RESULT_DIR) / "glove_training_loss.png")
plt.close()

Loaded cooc triples: 415,615,849, vocab_size=400,001


Epoch 1/25: 100%|██████████| 3171/3171 [14:07<00:00,  3.74it/s, loss=0.0809]  


Epoch 1: loss=0.080469, time=847.5s


Epoch 2/25: 100%|██████████| 3171/3171 [14:22<00:00,  3.67it/s, loss=0.0438]


Epoch 2: loss=0.043797, time=862.9s


Epoch 3/25: 100%|██████████| 3171/3171 [14:44<00:00,  3.59it/s, loss=0.0353] 


Epoch 3: loss=0.035293, time=884.2s


Epoch 4/25: 100%|██████████| 3171/3171 [14:32<00:00,  3.64it/s, loss=0.0311]


Epoch 4: loss=0.031140, time=872.1s


Epoch 5/25: 100%|██████████| 3171/3171 [14:37<00:00,  3.61it/s, loss=0.0286] 


Epoch 5: loss=0.028622, time=878.0s


Epoch 6/25: 100%|██████████| 3171/3171 [14:38<00:00,  3.61it/s, loss=0.0269]  


Epoch 6: loss=0.026886, time=878.5s


Epoch 7/25: 100%|██████████| 3171/3171 [14:41<00:00,  3.60it/s, loss=0.0256] 


Epoch 7: loss=0.025595, time=882.0s


Epoch 8/25: 100%|██████████| 3171/3171 [14:33<00:00,  3.63it/s, loss=0.0246]


Epoch 8: loss=0.024585, time=873.1s


Epoch 9/25: 100%|██████████| 3171/3171 [14:40<00:00,  3.60it/s, loss=0.0238] 


Epoch 9: loss=0.023766, time=880.4s


Epoch 10/25: 100%|██████████| 3171/3171 [14:41<00:00,  3.60it/s, loss=0.0231]  


Epoch 10: loss=0.023084, time=881.0s


Epoch 11/25: 100%|██████████| 3171/3171 [14:40<00:00,  3.60it/s, loss=0.0225] 


Epoch 11: loss=0.022505, time=880.8s


Epoch 12/25: 100%|██████████| 3171/3171 [14:35<00:00,  3.62it/s, loss=0.0220]


Epoch 12: loss=0.022004, time=875.5s


Epoch 13/25: 100%|██████████| 3171/3171 [14:40<00:00,  3.60it/s, loss=0.0216] 


Epoch 13: loss=0.021567, time=880.2s


Epoch 14/25: 100%|██████████| 3171/3171 [14:26<00:00,  3.66it/s, loss=0.0212]  


Epoch 14: loss=0.021179, time=866.5s


Epoch 15/25: 100%|██████████| 3171/3171 [14:27<00:00,  3.65it/s, loss=0.0208] 


Epoch 15: loss=0.020834, time=867.6s


Epoch 16/25: 100%|██████████| 3171/3171 [14:13<00:00,  3.72it/s, loss=0.0205]


Epoch 16: loss=0.020523, time=853.1s


Epoch 17/25: 100%|██████████| 3171/3171 [14:23<00:00,  3.67it/s, loss=0.0202] 


Epoch 17: loss=0.020241, time=863.0s


Epoch 18/25: 100%|██████████| 3171/3171 [14:18<00:00,  3.69it/s, loss=0.0200]  


Epoch 18: loss=0.019984, time=858.7s


Epoch 19/25: 100%|██████████| 3171/3171 [14:23<00:00,  3.67it/s, loss=0.0197]


Epoch 19: loss=0.019749, time=863.4s


Epoch 20/25: 100%|██████████| 3171/3171 [14:18<00:00,  3.69it/s, loss=0.0195]


Epoch 20: loss=0.019532, time=858.9s


Epoch 21/25: 100%|██████████| 3171/3171 [14:19<00:00,  3.69it/s, loss=0.0193] 


Epoch 21: loss=0.019332, time=859.4s


Epoch 22/25: 100%|██████████| 3171/3171 [14:16<00:00,  3.70it/s, loss=0.0191]  


Epoch 22: loss=0.019146, time=856.8s


Epoch 23/25: 100%|██████████| 3171/3171 [14:21<00:00,  3.68it/s, loss=0.0190] 


Epoch 23: loss=0.018973, time=861.3s


Epoch 24/25: 100%|██████████| 3171/3171 [14:10<00:00,  3.73it/s, loss=0.0188]


Epoch 24: loss=0.018812, time=850.9s


Epoch 25/25: 100%|██████████| 3171/3171 [14:14<00:00,  3.71it/s, loss=0.0187] 


Epoch 25: loss=0.018661, time=854.2s


Saved embeddings to: ../model/glove_W.pt, ../model/glove_C.pt
Saved figure to: ../result/glove_training_loss.png


In [7]:
W = torch.load(EMB_TARGET_PT, map_location="cpu")
C = torch.load(EMB_CONTEXT_PT, map_location="cpu")
merged = (W + C) / 2.0
torch.save(merged, EMB_MERGED_PT)
print(f"Merged embedding saved to: {EMB_MERGED_PT}")

Merged embedding saved to: ../model/glove_merged.pt


In [8]:
with open(VECTORS_TXT, "w", encoding="utf-8") as f:
    for idx, word in enumerate(id2word):
        vec = merged[idx].numpy()
        f.write(word + " " + " ".join(f"{x:.6f}" for x in vec) + "\n")
print(f"Exported vectors to: {VECTORS_TXT}")

Exported vectors to: ../model/vectors.txt


In [9]:
E = merged.numpy().astype(np.float32)
norm = np.linalg.norm(E, axis=1, keepdims=True) + 1e-12
E_norm = E / norm

w2i = word2id
i2w = id2word


def most_similar(word, topn=10):
    """Return top-N nearest neighbors by cosine similarity."""
    if word not in w2i:
        return []
    idx = w2i[word]
    q = E_norm[idx]
    sims = E_norm @ q
    sims[idx] = -1.0
    top_idx = np.argpartition(-sims, range(topn))[:topn]
    top_idx = top_idx[np.argsort(-sims[top_idx])]
    result = [(i2w[i], float(sims[i])) for i in top_idx]
    return result


test_words = ["中国", "北京", "美国", "数学", "哲学", "上海"]
for w in test_words:
    if w in w2i:
        print(f"Nearest to '{w}':")
        for ww, sc in most_similar(w, topn=10):
            print(f"  {ww}\t{sc:.4f}")
    else:
        print(f"'{w}' not in vocabulary.")

Nearest to '中国':
  大陆	0.7101
  中國	0.5929
  当时	0.5713
  中华人民共和国	0.5712
  历史	0.5645
  内地	0.5614
  于	0.5612
  同时	0.5545
  成为	0.5539
  这是	0.5506
Nearest to '北京':
  上海	0.6605
  南京	0.5643
  天津	0.5554
  北京市	0.5396
  北平	0.5353
  广州	0.5205
  中国	0.5110
  杭州	0.5095
  赴	0.5094
  成都	0.5063
Nearest to '美国':
  英国	0.6713
  美國	0.5845
  华盛顿	0.5583
  澳大利亚	0.5573
  加拿大	0.5558
  成为	0.5507
  国家	0.5483
  并且	0.5468
  当时	0.5459
  他们	0.5421
Nearest to '数学':
  物理学	0.6340
  计算机科学	0.6060
  应用	0.6019
  理论	0.5845
  哲学	0.5787
  化学	0.5786
  科学	0.5747
  物理	0.5665
  计算	0.5561
  领域	0.5485
Nearest to '哲学':
  神学	0.6678
  社会学	0.6188
  社会科学	0.6060
  文学	0.5892
  心理学	0.5843
  科学	0.5838
  马克思主义	0.5837
  经济学	0.5808
  政治学	0.5803
  数学	0.5787
Nearest to '上海':
  北京	0.6605
  天津	0.6413
  广州	0.6138
  南京	0.5871
  杭州	0.5772
  上海市	0.5665
  武汉	0.5358
  廣州	0.5341
  重庆	0.5333
  深圳	0.5246


In [10]:
from sklearn.manifold import TSNE


def plot_tsne(words, out_path: Path):
    indices = [w2i[w] for w in words if w in w2i]
    if len(indices) < 2:
        print("Not enough words found in vocab for t-SNE.")
        return
    X = E[indices]
    tsne = TSNE(n_components=2, init="random", learning_rate="auto", perplexity=min(30, len(indices)-1), n_iter_without_progress=2000)
    X2 = tsne.fit_transform(X)
    plt.figure()
    plt.scatter(X2[:, 0], X2[:, 1])
    for i, w in enumerate([w for w in words if w in w2i]):
        plt.text(X2[i, 0], X2[i, 1], w)
    plt.title("t-SNE of Selected Words")
    save_fig(out_path)
    plt.close()


example_words = ["中国", "北京", "上海", "美国", "纽约", "数学", "代数", "几何", "物理", "化学", "哲学", "逻辑", "计算机", "算法", "数据"]
plot_tsne(example_words, Path(RESULT_DIR) / "glove_tsne.png")

Saved figure to: ../result/glove_tsne.png
