	_hypernym	dog.n.01 10
	_hypernym	genus.n.02 08108972 19
	_hypernym	mammal_genus.n.01 01864707
	_hypernym	plant_genus.n.01
	_hypernym	car.n.01


genus.n.02 24
mammal_genus.n.01 251     mammal 399
fish_genus.n.01 214       fish 820
reptile_genus.n.01 144
bird_genus.n.01 278

genus.n.02 24

In [5]:
import plotly.graph_objects as go
import numpy as np
import os
from sklearn.manifold import TSNE

# Model settings: checkpoint directory, embedding range and the label shown on the figure.
model_path = '/home/lab/eight/KGE-HAKE/models/HAKE_wn18rr_0'
embedding_range = 0.01
model = 'HAKE'

# Dataset files: the two entity dictionaries (text and number variants) and the training triples.
entities_dict_text_file = '/home/lab/eight/KGE-HAKE/data/wn18rr_text/entities.dict'
entities_dict_number_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/entities.dict'
train_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/train.txt'

# Entities to visualise, each paired with the colour its related entities are drawn in.
target_values = [
    ["genus.n.02", 'red'],
    ["dog.n.01", 'blue'],
    ["clothing.n.01", 'orange'],
    ["plant_genus.n.01", 'green'],
    ["mammal_genus.n.01", 'yellow'],
    ["bird_genus.n.01", 'purple'],
    ["arthropod.n.01", 'gray'],
]

# Load the entity embeddings stored in the model directory and start an empty figure.
entity_embedding = np.load(os.path.join(model_path, 'entity_embedding.npy'))
fig = go.Figure()

def load_entities_dict_number(file_path):
    """Load a tab-separated ``index<TAB>value`` dictionary file.

    Args:
        file_path: Path to a UTF-8 text file holding one ``index<TAB>value``
            entry per line.

    Returns:
        dict mapping the integer in the first column to the string in the
        second column, in file order.

    Raises:
        ValueError: If the first column of a non-blank line is not an integer.
        IndexError: If a non-blank line has no second column.
    """
    entities_dict_number = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            # Blank lines (typically a trailing newline at the end of the file)
            # carry no entry; skip them instead of failing on the missing column.
            if not stripped:
                continue
            parts = stripped.split('\t')
            entities_dict_number[int(parts[0])] = parts[1]
    return entities_dict_number

# Read both entity dictionaries once for the plotting code below:
# the text dictionary first, then the number dictionary.
entities_dict_text, entities_dict_number = (
    load_entities_dict_number(dict_path)
    for dict_path in (entities_dict_text_file, entities_dict_number_file)
)

def extract_and_search_data(target_value, train_file_path, entities_dict_text_file, entities_dict_number_file, relation=None):
    """Return the dictionary indices of the head entities linked to ``target_value``.

    The target name is resolved to its id through the two dictionary files
    (they share the same integer index). The training file is then scanned for
    triples whose tail is that id, and the heads of those triples are mapped
    back to their dictionary indices.

    Args:
        target_value: Entity name as it appears in the text dictionary.
        train_file_path: Tab-separated triple file. Assumes the three-column
            ``head<TAB>relation<TAB>tail`` layout -- confirm for other datasets.
        entities_dict_text_file: ``index<TAB>name`` dictionary file.
        entities_dict_number_file: ``index<TAB>id`` dictionary file.
        relation: If given, only triples with exactly this relation are used.
            ``None`` (the default) accepts every relation, as before.

    Returns:
        list[str]: Dictionary indices (as strings) of the matching head
        entities, in dictionary-file order, each listed once.

    Raises:
        ValueError: If ``target_value`` is not present in the text dictionary.
        KeyError: If the target's index is missing from the number dictionary.
    """
    def _load_index_map(file_path):
        # Parse an ``index<TAB>value`` file into {int index: value}, ignoring blank lines.
        mapping = {}
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                parts = stripped.split('\t')
                mapping[int(parts[0])] = parts[1]
        return mapping

    entities_dict_text = _load_index_map(entities_dict_text_file)
    entities_dict_number = _load_index_map(entities_dict_number_file)

    # Resolve the entity name to its id via the shared dictionary index.
    target_index = next(
        (index for index, name in entities_dict_text.items() if name == target_value),
        None,
    )
    if target_index is None:
        # Previously this surfaced as a TypeError from concatenating "\t" with None.
        raise ValueError(f"entity {target_value!r} not found in {entities_dict_text_file}")
    target_id = entities_dict_number[target_index]

    # Collect the heads of every triple that points at the target. Columns are
    # compared exactly rather than by substring, so an id can never match a
    # fragment of another field.
    head_ids = set()
    with open(train_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue
            head, rel, tail = parts[0], parts[1], parts[2]
            if tail == target_id and (relation is None or rel == relation):
                head_ids.add(head)

    # Map ids back to dictionary indices in one pass over the already-loaded
    # dictionary. A head reached through several relations is reported once,
    # so the same entity is not plotted twice.
    return [
        str(index)
        for index, entity_id in entities_dict_number.items()
        if entity_id in head_ids
    ]

def calculate_tsne_coordinates(extracted_data, entity_embedding):
    """Project the embeddings of the selected entities to 2-D with t-SNE.

    Args:
        extracted_data: Sequence of entity indices (strings or ints); each one
            is converted with ``int`` and used to index ``entity_embedding``.
        entity_embedding: Indexable collection of embedding vectors, one per
            entity index.

    Returns:
        list: One ``[x, y]`` pair per entry of ``extracted_data``, in the same
        order. An empty input yields an empty list.
    """
    # Nothing to embed: return early instead of handing t-SNE an empty matrix.
    if len(extracted_data) == 0:
        return []

    # Stack the selected rows into a single 2-D array for the estimator.
    data_embeddings = np.asarray(
        [entity_embedding[int(entity_id_str)] for entity_id_str in extracted_data]
    )

    # scikit-learn requires perplexity < n_samples. Keep the library default
    # (30) whenever it is valid and shrink it only for small selections.
    # NOTE(review): a single sample is still rejected by scikit-learn itself.
    n_samples = data_embeddings.shape[0]
    perplexity = min(30.0, max(n_samples - 1, 1))

    # Fixed random_state keeps the layout reproducible between runs.
    tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)
    tsne_coordinates = tsne.fit_transform(data_embeddings)

    return tsne_coordinates.tolist()

# Accumulators shared by all targets: entity indices and the colour each one is drawn in.
all_extracted_data = []
all_target_colors = []

# Collect the entities linked to each target, remembering the target's colour for every one.
# NOTE(review): the legend below labels these as "_hypernym" links, but this cell's
# extract_and_search_data matches the target after any relation -- confirm which is intended.
for target_value, target_color in target_values:
    extracted_data = extract_and_search_data(
        target_value,
        train_file,
        entities_dict_text_file,
        entities_dict_number_file,
    )
    all_extracted_data += extracted_data
    all_target_colors += [target_color] * len(extracted_data)

# Compute the t-SNE layout once, over all collected entities together.
tsne_coordinates = calculate_tsne_coordinates(all_extracted_data, entity_embedding)

# Add one single-point trace per entity so each point carries its own entity name.
for (x, y), color, entity_id_str in zip(tsne_coordinates, all_target_colors, all_extracted_data):
    entity_id = int(entity_id_str)
    entity_name = entities_dict_text[entity_id]
    fig.add_trace(
        go.Scatter(
            name=entity_name,
            x=[x],
            y=[y],
            mode='markers',
            opacity=0.4,
            marker=dict(color=color, size=5),
        )
    )


# Colour key: one text line per target, stacked downwards from the top-right corner.
for i, (target_value, target_color) in enumerate(target_values):
    fig.add_annotation(
        text="_hypernym	" + target_value,
        xref="paper",  # x is given in paper (figure-relative) coordinates, not data coordinates
        yref="paper",  # y is given in paper (figure-relative) coordinates, not data coordinates
        x=0.95,  # 5% in from the right edge
        y=0.95 - i * 0.05,  # start 5% below the top edge, one 5% step per entry
        showarrow=False,  # plain text, no arrow
        font=dict(size=12, color=target_color),  # drawn in the colour of the matching points
    )

# Model name shown as a heading above the plot.
fig.add_annotation(
    text=model,  # text to display
    x=0.5,  # horizontally centred
    y=1.1,  # 10% above the top edge
    xref="paper",  # paper (figure-relative) coordinates
    yref="paper",  # paper (figure-relative) coordinates
    showarrow=False,  # plain text, no arrow
    font=dict(size=20),  # heading font size
)
# Layout: fixed canvas size, no legend, and a 1:1 aspect ratio between the axes.
fig.update_layout(
    title="",
    yaxis_title=" ",
    showlegend=False,
    width=550,  # width
    height=500,  # height
    xaxis=dict(scaleanchor="y", scaleratio=1),  # tie the x scale to y at ratio 1
    yaxis=dict(scaleanchor="x", scaleratio=1),  # tie the y scale to x at ratio 1
)

fig.show()


In [6]:
import plotly.graph_objects as go
import numpy as np
import os

# Model settings: checkpoint directory, embedding range (used to rescale the phase
# values below) and the label shown on the figure.
model_path = '/home/lab/eight/KGE-HAKE/models/HAKE_wn18rr_0'
embedding_range = 0.01
model = 'HAKE'

# Dataset files: the two entity dictionaries (text and number variants) and the training triples.
entities_dict_text_file = '/home/lab/eight/KGE-HAKE/data/wn18rr_text/entities.dict'
entities_dict_number_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/entities.dict'
train_file = '/home/lab/eight/KGE-HAKE/data/wn18rr/train.txt'

# Entities to visualise, each paired with the colour its related entities are drawn in.
target_values = [
    ["genus.n.02", 'red'],
    ["dog.n.01", 'blue'],
    ["clothing.n.01", 'orange'],
    ["plant_genus.n.01", 'green'],
    ["mammal_genus.n.01", 'yellow'],
    ["bird_genus.n.01", 'purple'],
    ["arthropod.n.01", 'gray'],
]

# Load the entity embeddings stored in the model directory and start an empty figure.
entity_embedding = np.load(os.path.join(model_path, 'entity_embedding.npy'))
fig = go.Figure()

def load_entities_dict_number(file_path):
    """Read an ``index<TAB>value`` dictionary file into a dict.

    Args:
        file_path: Path to a UTF-8 text file with one ``index<TAB>value``
            entry per line.

    Returns:
        dict whose keys are the integers from the first column and whose
        values are the strings from the second column, in file order.

    Raises:
        ValueError: If the first column of a non-blank line is not an integer.
        IndexError: If a non-blank line has no second column.
    """
    entities_dict_number = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            # A blank line splits to [''] and would fail on parts[1]; such lines
            # (usually a trailing newline at the end of the file) hold no entry.
            if parts == ['']:
                continue
            entities_dict_number[int(parts[0])] = parts[1]
    return entities_dict_number

# Read both entity dictionaries once for the plotting code below:
# the text dictionary first, then the number dictionary.
entities_dict_text, entities_dict_number = (
    load_entities_dict_number(dict_path)
    for dict_path in (entities_dict_text_file, entities_dict_number_file)
)

def extract_and_search_data(target_value, train_file_path, entities_dict_text_file, entities_dict_number_file, relation="_hypernym"):
    """Return the dictionary indices of the entities whose hypernym is ``target_value``.

    The target name is resolved to its id through the two dictionary files
    (they share the same integer index). The training file is then scanned for
    ``relation`` triples whose tail is that id, and the heads of those triples
    are mapped back to their dictionary indices.

    Args:
        target_value: Entity name as it appears in the text dictionary.
        train_file_path: Tab-separated triple file. Assumes the three-column
            ``head<TAB>relation<TAB>tail`` layout -- confirm for other datasets.
        entities_dict_text_file: ``index<TAB>name`` dictionary file.
        entities_dict_number_file: ``index<TAB>id`` dictionary file.
        relation: Relation name a triple must carry exactly; defaults to
            ``"_hypernym"``. Pass ``None`` to accept every relation.

    Returns:
        list[str]: Dictionary indices (as strings) of the matching head
        entities, in dictionary-file order, each listed once.

    Raises:
        ValueError: If ``target_value`` is not present in the text dictionary.
        KeyError: If the target's index is missing from the number dictionary.
    """
    def _load_index_map(file_path):
        # Parse an ``index<TAB>value`` file into {int index: value}, ignoring blank lines.
        mapping = {}
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                parts = stripped.split('\t')
                mapping[int(parts[0])] = parts[1]
        return mapping

    entities_dict_text = _load_index_map(entities_dict_text_file)
    entities_dict_number = _load_index_map(entities_dict_number_file)

    # Resolve the entity name to its id via the shared dictionary index.
    target_index = next(
        (index for index, name in entities_dict_text.items() if name == target_value),
        None,
    )
    if target_index is None:
        # Previously this surfaced as a TypeError from concatenating a string with None.
        raise ValueError(f"entity {target_value!r} not found in {entities_dict_text_file}")
    target_id = entities_dict_number[target_index]

    # Collect the heads of the matching triples. The relation and tail columns
    # are compared exactly: the earlier substring test for "_hypernym\t<id>"
    # also accepted lines containing "_instance_hypernym\t<id>".
    head_ids = set()
    with open(train_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue
            head, rel, tail = parts[0], parts[1], parts[2]
            if tail == target_id and (relation is None or rel == relation):
                head_ids.add(head)

    # Map ids back to dictionary indices in one pass over the already-loaded
    # dictionary instead of rescanning the file once per extracted id.
    return [
        str(index)
        for index, entity_id in entities_dict_number.items()
        if entity_id in head_ids
    ]

# def calculate_mean_coordinates(extracted_data, entity_embedding, embedding_range, entities_dict_text):
#     mean_coordinates = []
#     for entity_id_str in extracted_data:
#         entity_id = int(entity_id_str)
#         entity = entity_embedding[entity_id]
#         phase, mod = np.split(entity, 2)
#         mod = np.log(np.abs(mod)) * np.sign(mod)
#         phase = phase / embedding_range * np.pi 
#         mod_mean = np.mean(mod)
#         phase_mean = np.mean(phase)
#         x, y = mod_mean * np.cos(phase_mean), mod_mean * np.sin(phase_mean)
#         mean_coordinates.append((x, y))
#     return mean_coordinates

def calculate_mean_coordinates(extracted_data, entity_embedding, embedding_range, entities_dict_text):
    """Reduce each selected entity embedding to one 2-D point.

    Every embedding vector is split into two equal halves: the first is
    treated as the phase part, the second as the modulus part. The phase is
    rescaled by ``pi / embedding_range`` and the modulus replaced by the log
    of its absolute value. The means of both parts are then read as polar
    coordinates (radius = mean log-modulus, angle = mean phase), with the x
    coordinate mirrored.

    Args:
        extracted_data: Iterable of entity indices (strings or ints).
        entity_embedding: Array of embedding vectors indexed by entity index.
        embedding_range: Divisor used to rescale the phase values.
        entities_dict_text: Unused; kept for call compatibility.

    Returns:
        list[tuple]: One ``(x, y)`` pair per entry of ``extracted_data``, in order.
    """
    def _to_point(entity_key):
        # First half of the vector is the phase, second half the modulus.
        phase_part, modulus_part = np.split(entity_embedding[int(entity_key)], 2)
        radius = np.mean(np.log(np.abs(modulus_part)))
        angle = np.mean(phase_part / embedding_range * np.pi)
        # The x coordinate is multiplied by -1, i.e. mirrored about the y axis.
        return radius * np.cos(angle) * (-1), radius * np.sin(angle)

    return [_to_point(entity_key) for entity_key in extracted_data]


# For each target: find its related entities, report them, and plot their mean coordinates.
for target_value, target_color in target_values:
    extracted_data = extract_and_search_data(
        target_value,
        train_file,
        entities_dict_text_file,
        entities_dict_number_file,
    )
    print(target_value, 'num:', len(extracted_data), extracted_data)

    mean_coordinates = calculate_mean_coordinates(
        extracted_data, entity_embedding, embedding_range, entities_dict_text
    )

    # One single-point trace per entity, named and labelled with the entity's text name.
    for entity_id_str, (x, y) in zip(extracted_data, mean_coordinates):
        # Dictionary keys are integers, so convert the string index first.
        entity_id = int(entity_id_str)
        entity_name = entities_dict_text[entity_id]
        fig.add_trace(
            go.Scatter(
                name=entity_name,
                text=[entity_name],
                x=[x],
                y=[y],
                mode='markers',
                opacity=0.4,
                marker=dict(color=target_color, size=5),
            )
        )

# Colour key: one text line per target, stacked downwards from the top-right corner.
for i, (target_value, target_color) in enumerate(target_values):
    fig.add_annotation(
        text="_hypernym	" + target_value,
        xref="paper",  # x is given in paper (figure-relative) coordinates, not data coordinates
        yref="paper",  # y is given in paper (figure-relative) coordinates, not data coordinates
        x=0.95,  # 5% in from the right edge
        y=0.95 - i * 0.05,  # start 5% below the top edge, one 5% step per entry
        showarrow=False,  # plain text, no arrow
        font=dict(size=12, color=target_color),  # drawn in the colour of the matching points
    )

# Model name shown as a heading above the plot.
fig.add_annotation(
    text=model,  # text to display
    x=0.5,  # horizontally centred
    y=1.1,  # 10% above the top edge
    xref="paper",  # paper (figure-relative) coordinates
    yref="paper",  # paper (figure-relative) coordinates
    showarrow=False,  # plain text, no arrow
    font=dict(size=20),  # heading font size
)
# Layout: fixed canvas size, no legend, and a 1:1 aspect ratio between the axes.
fig.update_layout(
    title="",
    yaxis_title=" ",
    showlegend=False,
    width=550,  # width
    height=500,  # height
    xaxis=dict(scaleanchor="y", scaleratio=1),  # tie the x scale to y at ratio 1
    yaxis=dict(scaleanchor="x", scaleratio=1),  # tie the y scale to x at ratio 1
)

# Display the figure.
fig.show()

genus.n.02 num: 19 ['21', '81', '265', '284', '427', '542', '692', '716', '883', '1244', '1929', '2110', '4550', '9626', '10288', '13047', '17252', '17504', '23662']
dog.n.01 num: 10 ['26', '10149', '16530', '20201', '21324', '22898', '26628', '37881', '39316', '39907']
clothing.n.01 num: 28 ['2796', '3006', '4190', '4694', '5177', '9874', '13161', '13202', '14106', '18788', '18942', '19381', '22328', '30766', '31424', '31804', '34237', '34531', '34735', '35184', '35395', '36895', '37025', '37875', '40077', '40205', '40444', '40557']
plant_genus.n.01 num: 43 ['106', '1293', '3113', '4676', '5014', '5561', '7242', '7263', '7923', '9198', '9933', '10237', '11960', '12834', '14320', '14784', '14851', '16653', '17648', '18657', '19239', '19946', '20969', '21215', '23565', '24038', '24211', '25370', '25662', '26971', '28245', '28599', '29700', '29704', '30492', '31536', '32257', '32820', '34894', '35638', '36586', '37440', '39290']
mammal_genus.n.01 num: 249 ['223', '264', '354', '598', '67