In [2]:
import torch
from tqdm import tqdm
from model_transE import TransE
import json
import pandas as pd
from torch.nn import functional as F

In [3]:
# Lookup tables from label strings to integer ids. Later cells use the ids to
# index rows of the entity / relation embedding matrices.
# The files are opened in `with` blocks so the handles are closed right after
# parsing, and the encoding is pinned so the result does not depend on the
# platform's default locale.
with open("dataset/FB15k/entity2id.json", encoding="utf-8") as entity_file:
    entity2id = json.load(entity_file)
with open("dataset/FB15k/label2id.json", encoding="utf-8") as label_file:
    label2id = json.load(label_file)

In [4]:
def get_score(head_emb, rel_emb, tail_emb, norm):
    """Return the translation distance ``||head + relation - tail||_norm``.

    Any argument with fewer than two dimensions is given a leading axis of
    size one, so a single embedding vector can be broadcast against a whole
    matrix of candidates passed in one of the other slots.

    Args:
        head_emb: Head embedding(s), a single vector or a matrix of rows.
        rel_emb: Relation embedding(s), same convention as ``head_emb``.
        tail_emb: Tail embedding(s), same convention as ``head_emb``.
        norm: Order ``p`` of the vector norm taken over the last axis.

    Returns:
        The distances with every size-one axis removed: a 0-d tensor when all
        inputs are single vectors, a 1-d tensor when one input is a matrix of
        several rows.
    """

    def as_batch(emb):
        # Promote vectors (and scalars) by one leading axis; leave matrices alone.
        return emb if emb.dim() >= 2 else emb.unsqueeze(dim=0)

    heads, rels, tails = (as_batch(emb) for emb in (head_emb, rel_emb, tail_emb))
    translation_error = heads + rels - tails
    distances = torch.norm(translation_error, p=norm, dim=-1)
    return distances.squeeze()

In [5]:
def evaluate(
    entity_embeddings,
    relation_embeddings,
    test_triples: list,
    all_triples: list,
    entity2id: dict,
    rel2id: dict,
    norm=2,
    k_list=(10,),
    device="cpu",
    raw=True,
    entity_dim=50,
    relation_dim=50,
):
    """Compute Mean Rank and Hits@K for link prediction over id-encoded triples.

    For every ``(head, relation, tail)`` in ``test_triples`` the true triple is
    scored with ``get_score`` and ranked twice: once against every entity
    substituted as head and once against every entity substituted as tail.
    A rank is ``1 +`` the number of candidates whose score is strictly lower
    than the true triple's score, so tied candidates do not worsen the rank.

    Args:
        entity_embeddings: Tensor indexed by entity id along its first axis.
        relation_embeddings: Tensor indexed by relation id along its first axis.
        test_triples: Triples to evaluate, as ``(head_id, relation_id, tail_id)``.
        all_triples: Every known true triple in the same format. Only read
            when ``raw`` is false.
        entity2id: Unused; kept so existing callers keep working.
        rel2id: Unused; kept so existing callers keep working.
        norm: Order of the norm forwarded to ``get_score``.
        k_list: Cut-offs ``k`` for which Hits@K is reported.
        device: Device the embeddings are moved to before scoring.
        raw: If true, rank against all entities. If false ("filtered"
            setting), candidates that form a known true triple are excluded
            from the ranking.
        entity_dim: Unused; kept so existing callers keep working.
        relation_dim: Unused; kept so existing callers keep working.

    Returns:
        ``{"Mean Rank": float, "Hits@K": {k: float}}``. Mean Rank averages the
        per-triple mean of head rank and tail rank. Hits@K is the fraction of
        the ``2 * len(test_triples)`` predictions whose rank is at most ``k``.
        All values are NaN when ``test_triples`` is empty.
    """
    # The filter lookups are only consulted in the filtered setting, so they
    # are only built then:
    #   (relation, tail) -> heads known to complete a true triple
    #   (relation, head) -> tails known to complete a true triple
    known_heads = {}
    known_tails = {}
    if not raw:
        for h, r, t in all_triples:
            known_heads.setdefault((r, t), []).append(h)
            known_tails.setdefault((r, h), []).append(t)

    # Move the embeddings to the requested device. detach() keeps evaluation
    # out of any autograd graph when parameters are passed in directly.
    entity_embeddings = entity_embeddings.detach().to(device)
    relation_embeddings = relation_embeddings.detach().to(device)

    hits_at_k = {k: 0 for k in k_list}
    num_triples = len(test_triples)
    if num_triples == 0:
        # Nothing to average over; report NaN instead of dividing by zero.
        return {
            "Mean Rank": float("nan"),
            "Hits@K": {k: float("nan") for k in hits_at_k},
        }

    def rank_of_positive(positive_score, candidate_scores, known_ids):
        """Rank the true triple among candidates, masking known true ones."""
        if known_ids:
            # Known true triples (the evaluated one included) must never
            # outrank the positive, so push their scores to +inf.
            candidate_scores[known_ids] = float("inf")
        # A single .item() turns the rank into a Python int, so the Hits@K
        # comparisons below do not each synchronise with the device.
        return int((positive_score > candidate_scores).sum().item()) + 1

    rank_total = 0.0
    for head_idx, relation_idx, tail_idx in tqdm(test_triples, desc="Evaluating"):
        # Embeddings and score of the true triple.
        head_emb = entity_embeddings[head_idx]
        rel_emb = relation_embeddings[relation_idx]
        tail_emb = entity_embeddings[tail_idx]
        positive_score = get_score(head_emb, rel_emb, tail_emb, norm).unsqueeze(0)

        # Scores with every entity substituted as the head.
        head_rank = rank_of_positive(
            positive_score,
            get_score(entity_embeddings, rel_emb, tail_emb, norm),
            known_heads.get((relation_idx, tail_idx)),
        )
        # Scores with every entity substituted as the tail.
        tail_rank = rank_of_positive(
            positive_score,
            get_score(head_emb, rel_emb, entity_embeddings, norm),
            known_tails.get((relation_idx, head_idx)),
        )

        rank_total += (head_rank + tail_rank) / 2
        for k in hits_at_k:
            hits_at_k[k] += (head_rank <= k) + (tail_rank <= k)

    # Each triple contributes two predictions (head side and tail side).
    num_predictions = 2 * num_triples
    return {
        "Mean Rank": rank_total / num_triples,
        "Hits@K": {k: hits / num_predictions for k, hits in hits_at_k.items()},
    }

In [6]:
def _encode_triples(frame):
    """Convert a DataFrame of (h, r, t) label rows into a list of id tuples.

    Entity labels are looked up in ``entity2id`` and relation labels in
    ``label2id``; a label missing from either table raises ``KeyError``.
    """
    return [
        (entity2id[head], label2id[relation], entity2id[tail])
        for head, relation, tail in frame.values.tolist()
    ]


# Each split is a tab-separated file; the column names are supplied here.
test_data = pd.read_csv("dataset/FB15k/test.txt", sep="\t", names=["h", "r", "t"])
valid_data = pd.read_csv("dataset/FB15k/valid.txt", sep="\t", names=["h", "r", "t"])
train_data = pd.read_csv("dataset/FB15k/train.txt", sep="\t", names=["h", "r", "t"])

# Distinct entity labels that occur as head or tail in the test split, and
# their integer ids in the same order.
test_endpoint_labels = test_data["h"].tolist() + test_data["t"].tolist()
test_entities = list(set(test_endpoint_labels))
test_entities_ids = [entity2id[label] for label in test_entities]

# Id-encoded triples per split.
test_triples = _encode_triples(test_data)
valid_triples = _encode_triples(valid_data)

# Every known triple, in train -> valid -> test order.
all_triples = _encode_triples(train_data) + valid_triples + test_triples

In [17]:
# Use the GPU when one is present, but fall back to CPU so this cell also runs
# on machines without CUDA.
eval_device = "cuda" if torch.cuda.is_available() else "cpu"

# Only the state dict is needed from the checkpoint. weights_only=True limits
# unpickling to tensors and plain containers, so loading the file cannot run
# arbitrary pickled code.
checkpoint = torch.load(
    "model_save/TransE-model-bern2.pth",
    map_location=eval_device,
    weights_only=True,
)
parameters = checkpoint["model_state_dict"]
entity_embeddings = parameters["entity_embeddings.weight"]
relation_embeddings = parameters["relation_embeddings.weight"]

# The embeddings are scored exactly as stored; re-normalisation is disabled.
# NOTE: this run ranks the validation split (valid_triples), in the raw
# (unfiltered) setting with the L1 norm.
score = evaluate(
    entity_embeddings=entity_embeddings,
    relation_embeddings=relation_embeddings,
    test_triples=valid_triples,
    all_triples=all_triples,
    entity2id=entity2id,
    rel2id=label2id,
    norm=1,
    k_list=[10],
    device=eval_device,
    raw=True,
)
score

  parameters = torch.load("model_save/TransE-model-bern2.pth", map_location="cuda")[
Evaluating: 100%|██████████| 50000/50000 [00:18<00:00, 2694.85it/s]


{'Mean Rank': 216.13289, 'Hits@K': {10: 0.37105}}