In [71]:
import plotly.graph_objects as go
import numpy as np
import os

# --- Analysis parameters ---
# Relation label used to filter the test triples.
relation = "_also_see"

# Directory of the trained model; entity_embedding.npy is read from here.
model_path = '/home/lab/eight/KGE-HAKE/models/TransE_wn18rr_0'

# Scale that maps the phase half of an embedding to radians
# (phase / embedding_range * pi in calculate_distance).
# NOTE(review): presumably this has to match the range used at training time — confirm.
embedding_range = 0.01

# Width of one histogram bin for the distance differences.
bin_width = 0.01

# Load the saved entity embedding array from the model directory.
entity_embedding = np.load(os.path.join(model_path, 'entity_embedding.npy'))

# Build the lookup table from entities.dict.
# Each line is split on tabs: the second column (as int) becomes the key and
# the first column (as int) becomes the value. Later code uses the values as
# indices into entity_embedding.
entities_dict_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/entities.dict'
with open(entities_dict_file, 'r', encoding='utf-8') as dict_fh:
    rows = (raw.strip().split('\t') for raw in dict_fh)
    entities_dict = {int(cols[1]): int(cols[0]) for cols in rows}

# Collect embedding indices for the test triples whose relation equals `relation`.
# The list is flat: for each kept triple the head index is appended first and the
# tail index second, so consecutive pairs (0,1), (2,3), ... belong together.
test_entities_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/test.txt'
test_hypernym_ids = []

with open(test_entities_file, 'r', encoding='utf-8') as triples_fh:
    for raw in triples_fh:
        fields = raw.strip().split('\t')
        # Keep only well-formed "head<TAB>relation<TAB>tail" lines for the chosen relation.
        if len(fields) != 3 or fields[1] != relation:
            continue
        head_idx = entities_dict.get(int(fields[0]))
        tail_idx = entities_dict.get(int(fields[2]))
        # Drop the triple when either entity is missing from entities_dict.
        if head_idx is None or tail_idx is None:
            continue
        test_hypernym_ids += [head_idx, tail_idx]

# Per-dimension radial distance of one entity embedding.
def calculate_distance(entity_id, embeddings=None, emb_range=None):
    """Return the per-dimension radius of one entity in polar coordinates.

    The embedding row is split into two equal halves: the first is treated as
    the phase and the second as the modulus. The modulus is mapped to
    ``sign(mod) * log(|mod|)``, the phase is scaled to radians, the pair is
    converted to Cartesian ``(x, y)``, and ``sqrt(x**2 + y**2)`` is returned.
    Because ``cos**2 + sin**2 == 1`` this is mathematically ``|log(|mod|)|``
    for each dimension; the explicit conversion is kept so numeric results
    stay identical to earlier runs.

    Args:
        entity_id: Row index into the embedding table.
        embeddings: Optional 2-D embedding table. Defaults to the notebook
            global ``entity_embedding`` when omitted.
        emb_range: Optional phase scale. Defaults to the notebook global
            ``embedding_range`` when omitted.

    Returns:
        1-D array with half as many entries as the embedding row.

    Raises:
        ValueError: From ``np.split`` if the row length is odd.

    Notes:
        A modulus entry of exactly 0 gives ``log(0) * 0`` = NaN.
        NOTE(review): ``model_path`` points at a directory named ``TransE_...``;
        confirm that its embeddings really use a phase/modulus layout.
    """
    # Fall back to the notebook-level globals so existing one-argument calls still work.
    table = entity_embedding if embeddings is None else np.asarray(embeddings)
    scale = embedding_range if emb_range is None else emb_range

    phase, mod = np.split(table[entity_id], 2)
    # Signed log of the modulus magnitude.
    mod = np.log(np.abs(mod)) * np.sign(mod)
    # Rescale phase from [-scale, scale] units to radians.
    phase = phase / scale * np.pi
    x, y = mod * np.cos(phase), mod * np.sin(phase)
    distance = np.sqrt(x**2 + y**2)
    return distance

# One entry per test triple: element-wise absolute difference between the
# calculate_distance results of its head and its tail. test_hypernym_ids stores
# head and tail indices back-to-back, so even positions are heads and odd
# positions are tails.
head_ids = test_hypernym_ids[0::2]
tail_ids = test_hypernym_ids[1::2]
triplet_distances = [
    np.abs(calculate_distance(head) - calculate_distance(tail))
    for head, tail in zip(head_ids, tail_ids)
]


In [72]:
# Mean over every value in triplet_distances (np.mean without an axis averages all elements)
average_distance = np.mean(triplet_distances)

# Print the result
print(f'平均距離差: {average_distance}')

平均距離差: 0.6699936389923096


In [73]:
# Histogram of the distance differences.

# Flatten once and reuse: every value in triplet_distances becomes one sample.
flat_distances = np.ravel(triplet_distances)
if flat_distances.size == 0:
    raise ValueError('triplet_distances is empty; nothing to plot')

# Use NumPy reductions rather than the built-in max/min, which walk the array
# element by element in Python.
max_distance = flat_distances.max()
min_distance = flat_distances.min()

# Number of bin_width-wide bins needed to span the data. It is no longer passed
# to Plotly because nbinsx is ignored whenever xbins.size is given; the name is
# kept in case other cells read it.
num_bins = int((max_distance - min_distance) / bin_width) + 1

# Tick marks every 0.5 from 0 up to just past the largest value.
tick_positions = np.arange(0, max_distance + 0.5, 0.5)

fig = go.Figure()

# Fixed-width bins of size bin_width.
fig.add_trace(go.Histogram(x=flat_distances, xbins=dict(size=bin_width), name='Distance Difference'))

# Layout: titles, bar gap and explicit tick labels.
fig.update_layout(
    title_text='ヒストグラム: トリプレットの距離差',
    xaxis_title_text='距離差',
    yaxis_title_text='トリプレットの数',
    bargap=0.05,  # gap between bars
    xaxis=dict(tickvals=tick_positions, ticktext=[f'{i:.1f}' for i in tick_positions]),
)

# Render the figure.
fig.show()