In [None]:
import collections
import sys
from tqdm import tqdm

MAX_ENT = 10_000

train_path = "train.txt"
valid_path = "valid.txt"
test_path  = "test.txt"

out_train = "train_10K.txt"
out_valid = "valid_10K.txt"
out_test  = "test_10K.txt"
out_ent_map = "entity2id_10K.tsv"

# ============================================================
# 1. 统计实体出现次数
# ============================================================
entity_cnt = collections.Counter()

def scan_file(path, counter=None):
    """Count how often each entity id appears as head or tail in a triple file.

    Each non-blank line must hold three whitespace-separated integers
    ``head relation tail``. The relation id is parsed but not counted.
    Blank lines (e.g. a trailing empty line) are skipped.

    Args:
        path: Path of the triple file to read.
        counter: ``collections.Counter`` to update in place. Defaults to the
            module-level ``entity_cnt`` so existing calls keep working.

    Raises:
        ValueError: If a non-blank line is not exactly three integers; the
            message names the file and line number.
    """
    if counter is None:
        counter = entity_cnt
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines instead of crashing on unpack
            try:
                h, _r, t = map(int, fields)
            except ValueError as err:
                raise ValueError(
                    f"{path}:{lineno}: expected 'head relation tail' integers, got {line!r}"
                ) from err
            counter[h] += 1
            counter[t] += 1

print("Scanning files...")
scan_file(train_path)
scan_file(valid_path)
scan_file(test_path)

print(f"Total unique entities: {len(entity_cnt)}")

# ============================================================
# 2. Select the MAX_ENT most frequent entities
# ============================================================
# Counter.most_common orders equal counts by first occurrence (train, then
# valid, then test), so the cutoff and the new ids are deterministic for a
# given set of input files.
top_entities = [e for e, _ in entity_cnt.most_common(MAX_ENT)]
top_entity_set = set(top_entities)

print(f"Kept entities: {len(top_entity_set)}")

# old_id -> new_id; new ids are 0..len(top_entities)-1 in descending
# frequency order.
entity2id = {e: i for i, e in enumerate(top_entities)}

# ============================================================
# 3. Filter + remap triples
# ============================================================
def remap_file(in_path, out_path, mapping=None):
    """Write the triples of *in_path* whose head and tail are both kept, re-indexed.

    Each non-blank input line must hold three whitespace-separated integers
    ``head relation tail``. A triple is kept only when both head and tail are
    keys of *mapping*. Kept triples are written to *out_path* as
    ``new_head<TAB>relation<TAB>new_tail``, with the relation id unchanged.

    Args:
        in_path: Triple file to read.
        out_path: File to (over)write with the filtered, remapped triples.
        mapping: old entity id -> new entity id. Defaults to the module-level
            ``entity2id`` so existing calls keep working.

    Returns:
        ``(kept, total)``: triples written and triples read (blank lines are
        skipped and not counted).

    Raises:
        ValueError: If a non-blank line is not exactly three integers; the
            message names the file and line number.
    """
    if mapping is None:
        # entity2id's keys are exactly the retained entities, so the mapping
        # alone serves both the membership test and the id lookup.
        mapping = entity2id
    kept = 0
    total = 0
    with open(in_path, "r", encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        for lineno, line in enumerate(fin, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines instead of crashing on unpack
            total += 1
            try:
                h, r, t = map(int, fields)
            except ValueError as err:
                raise ValueError(
                    f"{in_path}:{lineno}: expected 'head relation tail' integers, got {line!r}"
                ) from err
            new_h = mapping.get(h)
            new_t = mapping.get(t)
            if new_h is None or new_t is None:
                continue  # triple touches a filtered-out entity
            fout.write(f"{new_h}\t{r}\t{new_t}\n")
            kept += 1
    print(f"{out_path}: kept {kept} / {total}")
    return kept, total

print("Remapping triples...")
# Process the splits in a fixed order: train, valid, test.
for _src, _dst in (
    (train_path, out_train),
    (valid_path, out_valid),
    (test_path, out_test),
):
    remap_file(_src, _dst)

# ============================================================
# 4. Save the entity id mapping (one "old_id<TAB>new_id" per line)
# ============================================================
with open(out_ent_map, "w") as _map_file:
    _map_file.writelines(
        f"{old_id}\t{new_id}\n" for old_id, new_id in entity2id.items()
    )

print("Done.")


Scanning files...
