In [15]:
from collections import defaultdict

# ---- Pass 1: configuration -------------------------------------------------
INPUT = "freebase_douban"                      # whitespace-separated triple file, one triple per line
ENTITY_PREFIX = "http://rdf.freebase.com/ns/"  # a head/tail is counted only if it starts with this
REL_PREFIX = ""                                # empty prefix: startswith() is true for every predicate
MIN_COUNT = 10                                 # minimum occurrence count; not used in this cell

# Occurrence counters keyed by the cleaned (bracket-stripped) token.
entity_cnt = defaultdict(int)
rel_cnt = defaultdict(int)

# ---- Pass 1: stream the file once and count occurrences --------------------
with open(INPUT, "r", encoding="utf-8", errors="ignore") as src:
    for i, raw in enumerate(src):
        # Progress message every 10M lines (never for line 0).
        if i > 0 and i % 10_000_000 == 0:
            print(f"[pass1] processed {i:,} lines")

        # Only the first three whitespace-separated fields are used;
        # lines with fewer than three fields are skipped.
        fields = raw.strip().split()
        if len(fields) < 3:
            continue
        head, rel, tail = fields[:3]

        # Remove surrounding angle brackets; the tail additionally loses
        # any trailing dots.
        head = head.strip("<>")
        rel = rel.strip("<>")
        tail = tail.strip("<>").rstrip(".")

        # Head and tail are counted independently, so a triple can
        # contribute to one, both or neither.
        for node in (head, tail):
            if node.startswith(ENTITY_PREFIX):
                entity_cnt[node] += 1

        if rel.startswith(REL_PREFIX):
            rel_cnt[rel] += 1


[pass1] processed 10,000,000 lines
[pass1] processed 20,000,000 lines
[pass1] processed 30,000,000 lines
[pass1] processed 40,000,000 lines
[pass1] processed 50,000,000 lines
[pass1] processed 60,000,000 lines
[pass1] processed 70,000,000 lines
[pass1] processed 80,000,000 lines
[pass1] processed 90,000,000 lines
[pass1] processed 100,000,000 lines
[pass1] processed 110,000,000 lines
[pass1] processed 120,000,000 lines
[pass1] processed 130,000,000 lines
[pass1] processed 140,000,000 lines
[pass1] processed 150,000,000 lines
[pass1] processed 160,000,000 lines
[pass1] processed 170,000,000 lines
[pass1] processed 180,000,000 lines
[pass1] processed 190,000,000 lines
[pass1] processed 200,000,000 lines
[pass1] processed 210,000,000 lines
[pass1] processed 220,000,000 lines
[pass1] processed 230,000,000 lines
[pass1] processed 240,000,000 lines
[pass1] processed 250,000,000 lines
[pass1] processed 260,000,000 lines
[pass1] processed 270,000,000 lines
[pass1] processed 280,000,000 lines
[

In [16]:
# Keep only the names whose occurrence count reaches MIN_COUNT.
valid_entities = set(
    name for name, seen in entity_cnt.items() if MIN_COUNT <= seen
)
valid_relations = set(
    name for name, seen in rel_cnt.items() if MIN_COUNT <= seen
)

# Report how many names of each kind were kept.
for label, kept in (
    ("valid entities:", valid_entities),
    ("valid relations:", valid_relations),
):
    print(label, len(kept))


valid entities: 10267841
valid relations: 2452


In [17]:
# Two independent, initially empty name -> integer id tables:
# one for entities, one for relations.
entity2id, rel2id = {}, {}
from tqdm import tqdm

def get_id(mapping, key):
    """Return the integer id of ``key`` in ``mapping``, assigning one if needed.

    A key that is not yet present is stored with the current size of
    ``mapping`` as its id, so ids are handed out as 0, 1, 2, ... in
    first-seen order. A key that is already present keeps its stored id.

    Args:
        mapping: dict from key to id; mutated in place when ``key`` is new.
        key: hashable key to look up.

    Returns:
        The id stored for ``key``.
    """
    # len(mapping) is evaluated before any insertion, and setdefault only
    # stores it when the key is missing.
    return mapping.setdefault(key, len(mapping))

# Pass 2: re-scan INPUT and write, as tab-separated integer ids, every triple
# whose head and tail are in valid_entities and whose relation is in
# valid_relations. Ids are assigned on first sight via get_id().
with open(INPUT, "r", encoding="utf-8", errors="ignore") as src, \
     open("triples_id.txt", "w") as out:

    for i, raw in enumerate(src):
        # Progress message every 10M lines (never for line 0).
        if i > 0 and i % 10_000_000 == 0:
            print(f"[pass2] processed {i:,} lines")

        fields = raw.strip().split()
        if len(fields) < 3:
            continue

        # Same token clean-up as pass 1, so the strings match the keys
        # that were counted there.
        head = fields[0].strip("<>")
        rel = fields[1].strip("<>")
        tail = fields[2].strip("<>").rstrip(".")

        # Guard clause: skip the triple as soon as one component is not kept.
        if (
            head not in valid_entities
            or tail not in valid_entities
            or rel not in valid_relations
        ):
            continue

        # Id assignment order (head, relation, tail) determines the
        # numbering, so it is kept exactly as is.
        hid = get_id(entity2id, head)
        rid = get_id(rel2id, rel)
        tid = get_id(entity2id, tail)

        out.write(f"{hid}\t{rid}\t{tid}\n")



[pass2] processed 10,000,000 lines
[pass2] processed 20,000,000 lines
[pass2] processed 30,000,000 lines
[pass2] processed 40,000,000 lines
[pass2] processed 50,000,000 lines
[pass2] processed 60,000,000 lines
[pass2] processed 70,000,000 lines
[pass2] processed 80,000,000 lines
[pass2] processed 90,000,000 lines
[pass2] processed 100,000,000 lines
[pass2] processed 110,000,000 lines
[pass2] processed 120,000,000 lines
[pass2] processed 130,000,000 lines
[pass2] processed 140,000,000 lines
[pass2] processed 150,000,000 lines
[pass2] processed 160,000,000 lines
[pass2] processed 170,000,000 lines
[pass2] processed 180,000,000 lines
[pass2] processed 190,000,000 lines
[pass2] processed 200,000,000 lines
[pass2] processed 210,000,000 lines
[pass2] processed 220,000,000 lines
[pass2] processed 230,000,000 lines
[pass2] processed 240,000,000 lines
[pass2] processed 250,000,000 lines
[pass2] processed 260,000,000 lines
[pass2] processed 270,000,000 lines
[pass2] processed 280,000,000 lines
[

In [18]:
# Persist the name -> id tables as "<name>\t<id>" lines.
#
# The names were decoded from the input file as UTF-8, so the files are
# written as UTF-8 explicitly. Without `encoding=`, open() falls back to the
# locale's preferred encoding, which on a non-UTF-8 locale can raise
# UnicodeEncodeError or produce files in a different encoding.
with open("entity2id.txt", "w", encoding="utf-8") as f:
    for e, i in entity2id.items():
        f.write(f"{e}\t{i}\n")

with open("relation2id.txt", "w", encoding="utf-8") as f:
    for r, i in rel2id.items():
        f.write(f"{r}\t{i}\n")


In [19]:
import random

# Path of the id-encoded triple file to split. A dedicated name is used on
# purpose: rebinding the global INPUT here would make the pass-1 / pass-2
# cells read "triples_id.txt" if they were re-run in the same kernel, and
# pass 2 would then open that same file for reading and for truncating write.
TRIPLES_PATH = "triples_id.txt"
TRAIN_OUT = "train.txt"
VALID_OUT = "valid.txt"
TEST_OUT = "test.txt"

TRAIN_RATIO = 0.8
VALID_RATIO = 0.1
TEST_RATIO = 0.1   # the test split is whatever remains after train and valid

SEED = 42   # fixed random seed so the split is reproducible

# The test split takes the remainder, so the three ratios must add up to 1;
# otherwise TEST_RATIO would silently not describe the actual test size.
assert abs(TRAIN_RATIO + VALID_RATIO + TEST_RATIO - 1.0) < 1e-9, \
    "TRAIN_RATIO + VALID_RATIO + TEST_RATIO must equal 1"

# Load every triple line into memory (lines keep their trailing newline).
with open(TRIPLES_PATH, "r") as f:
    triples = f.readlines()

print("total triples:", len(triples))

# Shuffle in place with a seeded generator.
random.seed(SEED)
random.shuffle(triples)

# Split points: sizes are truncated to integers, the test split gets the rest.
n = len(triples)
n_train = int(n * TRAIN_RATIO)
n_valid = int(n * VALID_RATIO)

train_triples = triples[:n_train]
valid_triples = triples[n_train:n_train + n_valid]
test_triples  = triples[n_train + n_valid:]

# Write the three splits.
with open(TRAIN_OUT, "w") as f:
    f.writelines(train_triples)

with open(VALID_OUT, "w") as f:
    f.writelines(valid_triples)

with open(TEST_OUT, "w") as f:
    f.writelines(test_triples)

print("train:", len(train_triples))
print("valid:", len(valid_triples))
print("test :", len(test_triples))


total triples: 157309768
train: 125847814
valid: 15730976
test : 15730978
